fix: validate summarizer inputs and add missing torch dependency

This commit is contained in:
cesnimda
2026-03-22 14:00:41 +01:00
parent a974e80ca4
commit 10e10bb6a7
2 changed files with 14 additions and 8 deletions
+13 -8
View File
@@ -1,5 +1,5 @@
from fastapi import FastAPI from fastapi import FastAPI, HTTPException
from pydantic import BaseModel from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from cachetools import TTLCache from cachetools import TTLCache
import hashlib import hashlib
@@ -9,7 +9,10 @@ import torch
app = FastAPI(title="Local Summarizer") app = FastAPI(title="Local Summarizer")
MODEL_NAME = "sshleifer/distilbart-cnn-12-6" MODEL_NAME = "sshleifer/distilbart-cnn-12-6"
MAX_INPUT_CHARS = 20000
# The local summarizer is intentionally simple, but we still validate request sizes
# so accidental giant pastes do not cause avoidable latency or memory spikes.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME) model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
model.eval() model.eval()
@@ -19,9 +22,9 @@ cache = TTLCache(maxsize=1024, ttl=60 * 60) # 1 hour cache
class SummarizeRequest(BaseModel): class SummarizeRequest(BaseModel):
text: str text: str = Field(min_length=1, max_length=MAX_INPUT_CHARS)
max_length: int = 160 max_length: int = Field(default=160, ge=24, le=256)
min_length: int = 45 min_length: int = Field(default=45, ge=8, le=180)
def _key(text: str, max_length: int, min_length: int) -> str: def _key(text: str, max_length: int, min_length: int) -> str:
@@ -64,8 +67,6 @@ _TECH = [
"rest", "rest",
] ]
_SOFT = [ _SOFT = [
"communication", "communication",
"collaboration", "collaboration",
@@ -80,6 +81,7 @@ _SOFT = [
"detail oriented", "detail oriented",
] ]
def _strip_html(text: str) -> str: def _strip_html(text: str) -> str:
# Good enough for job descriptions pasted from the web. # Good enough for job descriptions pasted from the web.
text = re.sub(r"<\s*br\s*/?>", "\n", text, flags=re.IGNORECASE) text = re.sub(r"<\s*br\s*/?>", "\n", text, flags=re.IGNORECASE)
@@ -114,7 +116,7 @@ def _role_focused_excerpt(text: str) -> dict:
} }
def match_heading(s: str): def match_heading(s: str):
sl = s.lower().strip(":- ") sl = s.lower().strip(":-\x7f ")
for k, words in headings.items(): for k, words in headings.items():
for w in words: for w in words:
if sl == w or sl.startswith(w + " "): if sl == w or sl.startswith(w + " "):
@@ -201,6 +203,9 @@ def _model_summarize(text: str, max_length: int, min_length: int) -> str:
@app.post("/summarize") @app.post("/summarize")
async def summarize(req: SummarizeRequest): async def summarize(req: SummarizeRequest):
if req.min_length >= req.max_length:
raise HTTPException(status_code=400, detail="min_length must be smaller than max_length.")
key = _key(req.text, req.max_length, req.min_length) key = _key(req.text, req.max_length, req.min_length)
if key in cache: if key in cache:
return {"summary": cache[key], "cached": True} return {"summary": cache[key], "cached": True}
+1
View File
@@ -3,3 +3,4 @@ uvicorn[standard]==0.34.0
transformers==4.48.3 transformers==4.48.3
cachetools==5.5.2 cachetools==5.5.2
pydantic==2.10.6 pydantic==2.10.6
torch==2.6.0