fix: validate summarizer inputs and add missing torch dependency

This commit is contained in:
cesnimda
2026-03-22 14:00:41 +01:00
parent a974e80ca4
commit 10e10bb6a7
2 changed files with 14 additions and 8 deletions
+13 -8
View File
@@ -1,5 +1,5 @@
from fastapi import FastAPI
from pydantic import BaseModel
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from cachetools import TTLCache
import hashlib
@@ -9,7 +9,10 @@ import torch
app = FastAPI(title="Local Summarizer")
MODEL_NAME = "sshleifer/distilbart-cnn-12-6"
MAX_INPUT_CHARS = 20000
# The local summarizer is intentionally simple, but we still validate request sizes
# so accidental giant pastes do not cause avoidable latency or memory spikes.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
model.eval()
@@ -19,9 +22,9 @@ cache = TTLCache(maxsize=1024, ttl=60 * 60) # 1 hour cache
class SummarizeRequest(BaseModel):
text: str
max_length: int = 160
min_length: int = 45
text: str = Field(min_length=1, max_length=MAX_INPUT_CHARS)
max_length: int = Field(default=160, ge=24, le=256)
min_length: int = Field(default=45, ge=8, le=180)
def _key(text: str, max_length: int, min_length: int) -> str:
@@ -64,8 +67,6 @@ _TECH = [
"rest",
]
_SOFT = [
"communication",
"collaboration",
@@ -80,6 +81,7 @@ _SOFT = [
"detail oriented",
]
def _strip_html(text: str) -> str:
# Good enough for job descriptions pasted from the web.
text = re.sub(r"<\s*br\s*/?>", "\n", text, flags=re.IGNORECASE)
@@ -114,7 +116,7 @@ def _role_focused_excerpt(text: str) -> dict:
}
def match_heading(s: str):
sl = s.lower().strip(":- ")
sl = s.lower().strip(":-\x7f ")
for k, words in headings.items():
for w in words:
if sl == w or sl.startswith(w + " "):
@@ -201,6 +203,9 @@ def _model_summarize(text: str, max_length: int, min_length: int) -> str:
@app.post("/summarize")
async def summarize(req: SummarizeRequest):
if req.min_length >= req.max_length:
raise HTTPException(status_code=400, detail="min_length must be smaller than max_length.")
key = _key(req.text, req.max_length, req.min_length)
if key in cache:
return {"summary": cache[key], "cached": True}
+1
View File
@@ -3,3 +3,4 @@ uvicorn[standard]==0.34.0
transformers==4.48.3
cachetools==5.5.2
pydantic==2.10.6
torch==2.6.0