fix: validate summarizer inputs and add missing torch dependency
This commit is contained in:
+13
-8
@@ -1,5 +1,5 @@
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
||||
from cachetools import TTLCache
|
||||
import hashlib
|
||||
@@ -9,7 +9,10 @@ import torch
|
||||
app = FastAPI(title="Local Summarizer")
|
||||
|
||||
MODEL_NAME = "sshleifer/distilbart-cnn-12-6"
|
||||
MAX_INPUT_CHARS = 20000
|
||||
|
||||
# The local summarizer is intentionally simple, but we still validate request sizes
|
||||
# so accidental giant pastes do not cause avoidable latency or memory spikes.
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
|
||||
model.eval()
|
||||
@@ -19,9 +22,9 @@ cache = TTLCache(maxsize=1024, ttl=60 * 60) # 1 hour cache
|
||||
|
||||
|
||||
class SummarizeRequest(BaseModel):
|
||||
text: str
|
||||
max_length: int = 160
|
||||
min_length: int = 45
|
||||
text: str = Field(min_length=1, max_length=MAX_INPUT_CHARS)
|
||||
max_length: int = Field(default=160, ge=24, le=256)
|
||||
min_length: int = Field(default=45, ge=8, le=180)
|
||||
|
||||
|
||||
def _key(text: str, max_length: int, min_length: int) -> str:
|
||||
@@ -64,8 +67,6 @@ _TECH = [
|
||||
"rest",
|
||||
]
|
||||
|
||||
|
||||
|
||||
_SOFT = [
|
||||
"communication",
|
||||
"collaboration",
|
||||
@@ -80,6 +81,7 @@ _SOFT = [
|
||||
"detail oriented",
|
||||
]
|
||||
|
||||
|
||||
def _strip_html(text: str) -> str:
|
||||
# Good enough for job descriptions pasted from the web.
|
||||
text = re.sub(r"<\s*br\s*/?>", "\n", text, flags=re.IGNORECASE)
|
||||
@@ -114,7 +116,7 @@ def _role_focused_excerpt(text: str) -> dict:
|
||||
}
|
||||
|
||||
def match_heading(s: str):
|
||||
sl = s.lower().strip(":- ")
|
||||
sl = s.lower().strip(":-\x7f ")
|
||||
for k, words in headings.items():
|
||||
for w in words:
|
||||
if sl == w or sl.startswith(w + " "):
|
||||
@@ -201,6 +203,9 @@ def _model_summarize(text: str, max_length: int, min_length: int) -> str:
|
||||
|
||||
@app.post("/summarize")
|
||||
async def summarize(req: SummarizeRequest):
|
||||
if req.min_length >= req.max_length:
|
||||
raise HTTPException(status_code=400, detail="min_length must be smaller than max_length.")
|
||||
|
||||
key = _key(req.text, req.max_length, req.min_length)
|
||||
if key in cache:
|
||||
return {"summary": cache[key], "cached": True}
|
||||
|
||||
@@ -3,3 +3,4 @@ uvicorn[standard]==0.34.0
|
||||
transformers==4.48.3
|
||||
cachetools==5.5.2
|
||||
pydantic==2.10.6
|
||||
torch==2.6.0
|
||||
|
||||
Reference in New Issue
Block a user