fix: validate summarizer inputs and add missing torch dependency
This commit is contained in:
+13
-8
@@ -1,5 +1,5 @@
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
||||||
from cachetools import TTLCache
|
from cachetools import TTLCache
|
||||||
import hashlib
|
import hashlib
|
||||||
@@ -9,7 +9,10 @@ import torch
|
|||||||
app = FastAPI(title="Local Summarizer")
|
app = FastAPI(title="Local Summarizer")
|
||||||
|
|
||||||
MODEL_NAME = "sshleifer/distilbart-cnn-12-6"
|
MODEL_NAME = "sshleifer/distilbart-cnn-12-6"
|
||||||
|
MAX_INPUT_CHARS = 20000
|
||||||
|
|
||||||
|
# The local summarizer is intentionally simple, but we still validate request sizes
|
||||||
|
# so accidental giant pastes do not cause avoidable latency or memory spikes.
|
||||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||||
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
|
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -19,9 +22,9 @@ cache = TTLCache(maxsize=1024, ttl=60 * 60) # 1 hour cache
|
|||||||
|
|
||||||
|
|
||||||
class SummarizeRequest(BaseModel):
|
class SummarizeRequest(BaseModel):
|
||||||
text: str
|
text: str = Field(min_length=1, max_length=MAX_INPUT_CHARS)
|
||||||
max_length: int = 160
|
max_length: int = Field(default=160, ge=24, le=256)
|
||||||
min_length: int = 45
|
min_length: int = Field(default=45, ge=8, le=180)
|
||||||
|
|
||||||
|
|
||||||
def _key(text: str, max_length: int, min_length: int) -> str:
|
def _key(text: str, max_length: int, min_length: int) -> str:
|
||||||
@@ -64,8 +67,6 @@ _TECH = [
|
|||||||
"rest",
|
"rest",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
_SOFT = [
|
_SOFT = [
|
||||||
"communication",
|
"communication",
|
||||||
"collaboration",
|
"collaboration",
|
||||||
@@ -80,6 +81,7 @@ _SOFT = [
|
|||||||
"detail oriented",
|
"detail oriented",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _strip_html(text: str) -> str:
|
def _strip_html(text: str) -> str:
|
||||||
# Good enough for job descriptions pasted from the web.
|
# Good enough for job descriptions pasted from the web.
|
||||||
text = re.sub(r"<\s*br\s*/?>", "\n", text, flags=re.IGNORECASE)
|
text = re.sub(r"<\s*br\s*/?>", "\n", text, flags=re.IGNORECASE)
|
||||||
@@ -114,7 +116,7 @@ def _role_focused_excerpt(text: str) -> dict:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def match_heading(s: str):
|
def match_heading(s: str):
|
||||||
sl = s.lower().strip(":- ")
|
sl = s.lower().strip(":-\x7f ")
|
||||||
for k, words in headings.items():
|
for k, words in headings.items():
|
||||||
for w in words:
|
for w in words:
|
||||||
if sl == w or sl.startswith(w + " "):
|
if sl == w or sl.startswith(w + " "):
|
||||||
@@ -201,6 +203,9 @@ def _model_summarize(text: str, max_length: int, min_length: int) -> str:
|
|||||||
|
|
||||||
@app.post("/summarize")
|
@app.post("/summarize")
|
||||||
async def summarize(req: SummarizeRequest):
|
async def summarize(req: SummarizeRequest):
|
||||||
|
if req.min_length >= req.max_length:
|
||||||
|
raise HTTPException(status_code=400, detail="min_length must be smaller than max_length.")
|
||||||
|
|
||||||
key = _key(req.text, req.max_length, req.min_length)
|
key = _key(req.text, req.max_length, req.min_length)
|
||||||
if key in cache:
|
if key in cache:
|
||||||
return {"summary": cache[key], "cached": True}
|
return {"summary": cache[key], "cached": True}
|
||||||
|
|||||||
@@ -3,3 +3,4 @@ uvicorn[standard]==0.34.0
|
|||||||
transformers==4.48.3
|
transformers==4.48.3
|
||||||
cachetools==5.5.2
|
cachetools==5.5.2
|
||||||
pydantic==2.10.6
|
pydantic==2.10.6
|
||||||
|
torch==2.6.0
|
||||||
|
|||||||
Reference in New Issue
Block a user