diff --git a/tools/summarizer/app.py b/tools/summarizer/app.py index 1fe51f6..a572a07 100644 --- a/tools/summarizer/app.py +++ b/tools/summarizer/app.py @@ -1,5 +1,5 @@ -from fastapi import FastAPI -from pydantic import BaseModel +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel, Field from transformers import AutoTokenizer, AutoModelForSeq2SeqLM from cachetools import TTLCache import hashlib @@ -9,7 +9,10 @@ import torch app = FastAPI(title="Local Summarizer") MODEL_NAME = "sshleifer/distilbart-cnn-12-6" +MAX_INPUT_CHARS = 20000 +# The local summarizer is intentionally simple, but we still validate request sizes +# so accidental giant pastes do not cause avoidable latency or memory spikes. tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME) model.eval() @@ -19,9 +22,9 @@ cache = TTLCache(maxsize=1024, ttl=60 * 60) # 1 hour cache class SummarizeRequest(BaseModel): - text: str - max_length: int = 160 - min_length: int = 45 + text: str = Field(min_length=1, max_length=MAX_INPUT_CHARS) + max_length: int = Field(default=160, ge=24, le=256) + min_length: int = Field(default=45, ge=8, le=180) def _key(text: str, max_length: int, min_length: int) -> str: @@ -64,8 +67,6 @@ _TECH = [ "rest", ] - - _SOFT = [ "communication", "collaboration", @@ -80,6 +81,7 @@ _SOFT = [ "detail oriented", ] + def _strip_html(text: str) -> str: # Good enough for job descriptions pasted from the web. text = re.sub(r"<\s*br\s*/?>", "\n", text, flags=re.IGNORECASE) @@ -114,7 +116,7 @@ def _role_focused_excerpt(text: str) -> dict: } def match_heading(s: str): - sl = s.lower().strip(":- ") + sl = s.lower().strip(":-\x7f ") for k, words in headings.items(): for w in words: if sl == w or sl.startswith(w + " "): @@ -201,6 +203,9 @@ def _model_summarize(text: str, max_length: int, min_length: int) -> str: @app.post("/summarize") async def summarize(req: SummarizeRequest): + if req.min_length >= req.max_length: + raise HTTPException(status_code=400, detail="min_length must be smaller than max_length.") + key = _key(req.text, req.max_length, req.min_length) if key in cache: return {"summary": cache[key], "cached": True} diff --git a/tools/summarizer/requirements.txt b/tools/summarizer/requirements.txt index be9b26e..f11de3e 100644 --- a/tools/summarizer/requirements.txt +++ b/tools/summarizer/requirements.txt @@ -3,3 +3,4 @@ uvicorn[standard]==0.34.0 transformers==4.48.3 cachetools==5.5.2 pydantic==2.10.6 +torch==2.6.0