feat: add admin routes, summarizer service, and CLI summarize command

- Add /admin routes for manual trigger and status inspection
- Add summarizer service with batch/single summary support
- Add summarize CLI command (single arxiv_id or batch pending)
- Register admin router in main app
- Add tests for summarizer
This commit is contained in:
2026-06-05 22:29:33 +08:00
parent d69df2be10
commit 29e6797c12
7 changed files with 1874 additions and 0 deletions
+40
View File
@@ -49,6 +49,46 @@ def crawl(
db.close()
@cli_app.command()
def summarize(
    arxiv_id: str = typer.Argument(
        None,
        help="指定论文 arXiv ID;留空则批量处理所有 pending",
    ),
):
    """手动触发 AI 总结。"""
    # NOTE: the docstring above is what Typer prints as the command's help
    # text, so it is intentionally left unchanged.
    #
    # Manually trigger AI summarization. With an ``arxiv_id`` a single paper
    # is summarized; without one, every pending paper is processed as a batch.
    # The process exits with code 1 on a lock conflict, an unknown paper, or
    # any failed/partial result, so that shell scripts and schedulers can
    # detect the failure. A "skipped" result is reported but is not an error.
    import os

    # Project imports are deferred to call time, as in the sibling commands.
    from app.config import settings
    from app.database import SessionLocal, engine
    from app.models import init_db as _init
    from app.services.summarizer import summarize_batch, summarize_single

    # Make sure the database directory and tables exist before opening a session.
    os.makedirs(settings.db_path.parent, exist_ok=True)
    _init(engine)

    db = SessionLocal()
    try:
        if arxiv_id:
            typer.echo(f"🤖 开始总结 {arxiv_id} ...")
            result = asyncio.run(summarize_single(db, arxiv_id))
        else:
            typer.echo("🤖 开始批量总结 pending 论文 ...")
            result = asyncio.run(summarize_batch(db))

        status = result.get("status")
        if status in ("success", "done"):
            typer.echo(f"✅ 总结完成:{result}")
        elif status == "conflict":
            typer.echo("⚠️ 已有批量总结任务在运行中", err=True)
            raise typer.Exit(code=1)
        elif status == "not_found":
            typer.echo(f"❌ 论文未找到:{arxiv_id}", err=True)
            raise typer.Exit(code=1)
        else:
            typer.echo(f"⚠️ 总结结果:{result}", err=True)
            # "failed" / "partial" (or anything unexpected) must not look like
            # success to the calling shell; only "skipped" is benign.
            if status != "skipped":
                raise typer.Exit(code=1)
    finally:
        db.close()
@cli_app.command()
def init_db():
"""初始化数据库表。"""
+2
View File
@@ -9,6 +9,7 @@ from fastapi.staticfiles import StaticFiles
from app.config import settings
from app.database import engine
from app.models import init_db
from app.routes.admin import router as admin_router
from app.routes.pages import router as pages_router
logging.basicConfig(
@@ -41,6 +42,7 @@ def create_app() -> FastAPI:
# 路由
app.include_router(pages_router)
app.include_router(admin_router)
return app
+48
View File
@@ -0,0 +1,48 @@
"""管理接口 — AI 总结触发,需要 ADMIN_TOKEN 鉴权。"""
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.orm import Session
from app.config import settings
from app.database import get_db
from app.services.summarizer import summarize_batch, summarize_single
router = APIRouter(prefix="/admin", tags=["admin"])
security = HTTPBearer()


async def verify_admin(
    credentials: HTTPAuthorizationCredentials = Depends(security),
) -> str:
    """Check the bearer token against ``settings.ADMIN_TOKEN``.

    Returns the supplied token when it matches. Raises ``HTTPException`` with
    status 401 otherwise, including when no non-empty string token is
    configured, so a missing configuration can never grant access.
    """
    import secrets

    supplied = credentials.credentials
    expected = settings.ADMIN_TOKEN

    # Compare as bytes in constant time: a plain ``!=`` leaks how long the
    # matching prefix is, and ``compare_digest`` on ``str`` rejects non-ASCII.
    token_ok = (
        isinstance(expected, str)
        and bool(expected)
        and secrets.compare_digest(
            supplied.encode("utf-8"),
            expected.encode("utf-8"),
        )
    )
    if not token_ok:
        raise HTTPException(
            status_code=401,
            detail="Invalid admin token",
            # Tell the client which authentication scheme is expected.
            headers={"WWW-Authenticate": "Bearer"},
        )
    return supplied
@router.post("/summarize")
async def admin_summarize_batch(
    _admin: str = Depends(verify_admin),
    db: Session = Depends(get_db),
):
    """批量总结所有 pending 论文。"""
    # The docstring is kept verbatim because FastAPI publishes it as the
    # endpoint description.
    #
    # Runs the batch summarizer and returns its result dict unchanged, except
    # that a "conflict" status is translated into HTTP 409.
    outcome = await summarize_batch(db)
    if outcome.get("status") != "conflict":
        return outcome
    raise HTTPException(
        status_code=409,
        detail=outcome.get("error", "batch already running"),
    )
@router.post("/summarize/{arxiv_id}")
async def admin_summarize_single(
    arxiv_id: str,
    _admin: str = Depends(verify_admin),
    db: Session = Depends(get_db),
):
    """总结或重跑单篇论文。"""
    # The docstring is kept verbatim because FastAPI publishes it as the
    # endpoint description.
    #
    # Summarizes one paper with ``force=True`` and returns the result dict;
    # a "not_found" status is translated into HTTP 404.
    outcome = await summarize_single(db, arxiv_id, force=True)
    paper_missing = outcome.get("status") == "not_found"
    if paper_missing:
        raise HTTPException(status_code=404, detail=f"Paper not found: {arxiv_id}")
    return outcome
+168
View File
@@ -0,0 +1,168 @@
"""AI 总结 schema — Pydantic 校验模型、质量评估、DB 展平。"""
from __future__ import annotations
import json
from datetime import datetime, timezone
from pydantic import BaseModel, Field, ValidationError, field_validator
# ── 子模型 ──────────────────────────────────────────────────────────────
class PrerequisitesSchema(BaseModel):
    # "prerequisites" section of a summary. Both fields carry defaults, so the
    # section can be left out of the input entirely.
    concepts: list[str] = Field(default_factory=list)  # defaults to an empty list
    level: str = ""  # defaults to the empty string
class MotivationSchema(BaseModel):
    # "motivation" section of a summary. Only ``problem`` is mandatory;
    # ``goal`` and ``gap`` default to the empty string.
    problem: str
    goal: str = ""
    gap: str = ""

    @field_validator("problem")
    @classmethod
    def non_empty_problem(cls, v: str) -> str:
        # Reject empty / whitespace-only text and store the trimmed value.
        trimmed = v.strip() if v else ""
        if trimmed:
            return trimmed
        raise ValueError("motivation.problem cannot be empty")
class MethodSchema(BaseModel):
    # "method" section of a summary. Only ``key_idea`` is mandatory; the other
    # fields default to empty values.
    overview: str = ""
    key_idea: str
    steps: list[str] = Field(default_factory=list)
    novelty: str = ""

    @field_validator("key_idea")
    @classmethod
    def non_empty_key_idea(cls, v: str) -> str:
        # Reject empty / whitespace-only text and store the trimmed value.
        trimmed = v.strip() if v else ""
        if trimmed:
            return trimmed
        raise ValueError("method.key_idea cannot be empty")
class ResultsSchema(BaseModel):
    # "results" section of a summary. Every field defaults to an empty list,
    # so the section can be left out of the input entirely.
    main_findings: list[str] = Field(default_factory=list)
    benchmarks: list[dict] = Field(default_factory=list)  # dict contents are not validated here
    limitations: list[str] = Field(default_factory=list)
class ImprovementsSchema(BaseModel):
    # "improvements" section of a summary. Every field has an empty default,
    # so the section can be left out of the input entirely.
    weaknesses: list[str] = Field(default_factory=list)
    future_work: list[str] = Field(default_factory=list)
    reproducibility: str = ""
# ── 顶层 schema ─────────────────────────────────────────────────────────
class SummarySchema(BaseModel):
    # Top-level shape of one AI-generated summary. ``title_zh``, ``one_line``,
    # ``tags``, ``motivation`` and ``method`` have no defaults and are
    # therefore required; the remaining sections fall back to empty models.
    # Keys that are not declared here are ignored ("extra": "ignore").
    model_config = {"extra": "ignore"}

    title_zh: str
    one_line: str
    tags: list[str]
    difficulty: str = ""
    paper_date: str | None = None
    prerequisites: PrerequisitesSchema = Field(default_factory=PrerequisitesSchema)
    motivation: MotivationSchema
    method: MethodSchema
    results: ResultsSchema = Field(default_factory=ResultsSchema)
    improvements: ImprovementsSchema = Field(default_factory=ImprovementsSchema)

    @field_validator("title_zh", "one_line")
    @classmethod
    def non_empty_text(cls, v: str) -> str:
        # Reject empty / whitespace-only text and store the trimmed value.
        trimmed = v.strip() if v else ""
        if trimmed:
            return trimmed
        raise ValueError("field cannot be empty")

    @field_validator("tags")
    @classmethod
    def non_empty_tags(cls, v: list[str]) -> list[str]:
        # Trim every tag, drop the blank ones, and require at least one survivor.
        cleaned: list[str] = []
        for raw_tag in v:
            if not raw_tag:
                continue
            stripped = raw_tag.strip()
            if stripped:
                cleaned.append(stripped)
        if cleaned:
            return cleaned
        raise ValueError("tags cannot be empty")
# ── 质量评估 ────────────────────────────────────────────────────────────
# 必填字段:title_zh, one_line, tags, motivation.problem, method.key_idea
# — 缺失时 Pydantic 校验就会报错,不会走到 assess_quality
# 重要字段:motivation.goal, motivation.gap, method.overview, results.main_findings
# — 缺失可入库,标记 degraded
_OPTIONAL_BUT_IMPORTANT_FIELDS = [
    "motivation.goal",
    "motivation.gap",
    "method.overview",
    "results.main_findings",
]


def _is_blank(value: object) -> bool:
    """Return True for an empty/whitespace-only string or any other falsy value."""
    if isinstance(value, str):
        return not value.strip()
    return not value


def _resolve_dotted(root: object, dotted: str) -> object:
    """Follow a dotted attribute path such as ``"method.overview"`` from ``root``."""
    value = root
    for attr in dotted.split("."):
        value = getattr(value, attr)
    return value


def assess_quality(schema: SummarySchema) -> str:
    """Grade a validated summary as ``"normal"``, ``"degraded"`` or ``"low"``.

    * ``"low"``: ``one_line`` or ``method.key_idea`` is shorter than 10
      characters after trimming (a heuristic for hollow content).
    * ``"degraded"``: at least one field listed in
      ``_OPTIONAL_BUT_IMPORTANT_FIELDS`` is blank or empty.
    * ``"normal"``: everything else.
    """
    if len(schema.one_line.strip()) < 10 or len(schema.method.key_idea.strip()) < 10:
        return "low"

    # Drive the check from the module-level list so the documented set of
    # important fields and the logic that enforces it cannot drift apart.
    for dotted in _OPTIONAL_BUT_IMPORTANT_FIELDS:
        if _is_blank(_resolve_dotted(schema, dotted)):
            return "degraded"
    return "normal"
# ── DB 展平 ─────────────────────────────────────────────────────────────
def flatten_for_db(schema: SummarySchema) -> dict:
    """Flatten a ``SummarySchema`` into a dict of ``paper_summaries`` column values.

    List and nested values are stored as JSON strings (non-ASCII characters
    are kept as-is), ``full_json`` holds the whole model serialized, and
    ``updated_at`` is set to the current UTC time.
    """

    def as_json(value) -> str:
        return json.dumps(value, ensure_ascii=False)

    motivation, method = schema.motivation, schema.method
    results, improvements = schema.results, schema.improvements

    row = {
        "one_line": schema.one_line,
        "difficulty": schema.difficulty,
        "prerequisites_json": as_json(schema.prerequisites.model_dump()),
        "motivation_problem": motivation.problem,
        "motivation_goal": motivation.goal,
        "motivation_gap": motivation.gap,
        "method_overview": method.overview,
        "method_key_idea": method.key_idea,
        "method_steps_json": as_json(method.steps),
        "method_novelty": method.novelty,
        "results_main_json": as_json(results.main_findings),
        "results_benchmarks_json": as_json(results.benchmarks),
        "limitations_json": as_json(results.limitations),
        "weaknesses_json": as_json(improvements.weaknesses),
        "future_work_json": as_json(improvements.future_work),
        "reproducibility": improvements.reproducibility,
        # NOTE(review): ``ensure_ascii`` on ``model_dump_json`` presumably needs
        # a recent pydantic 2.x release — confirm against the pinned version.
        "full_json": schema.model_dump_json(ensure_ascii=False),
        "updated_at": datetime.now(timezone.utc),
    }
    return row
# ── 错误分类 ────────────────────────────────────────────────────────────
_REQUIRED_FIELDS = {"title_zh", "one_line", "tags", "problem", "key_idea"}


def classify_validation_error(exc: ValidationError) -> str:
    """Classify a validation failure as ``"field_missing"`` or ``"schema_error"``.

    ``"field_missing"`` is returned when any reported error has type
    ``"missing"`` or ``"value_error"`` and its last location component is one
    of the required field names; every other failure is a ``"schema_error"``.
    """
    missing_like = ("missing", "value_error")
    required_field_absent = any(
        bool(err["loc"])
        and err["loc"][-1] in _REQUIRED_FIELDS
        and err["type"] in missing_like
        for err in exc.errors()
    )
    return "field_missing" if required_field_absent else "schema_error"
+682
View File
@@ -0,0 +1,682 @@
"""AI 总结服务 — 调用 pi CLI 生成论文中文结构化总结。"""
from __future__ import annotations
import asyncio
import json
import logging
import re
import shutil
from datetime import datetime, timezone
from pathlib import Path
import httpx
from pydantic import ValidationError
from sqlalchemy import select, text
from sqlalchemy.orm import Session, joinedload
from app.config import settings
from app.database import SessionLocal
from app.models import (
CrawlLog,
Paper,
PaperSummary,
PaperTag,
SummaryStatus,
TaskLock,
)
from app.services.schemas import (
SummarySchema,
assess_quality,
classify_validation_error,
flatten_for_db,
)
logger = logging.getLogger(__name__)
# ── 自定义异常 ──────────────────────────────────────────────────────────
class PdfDownloadError(Exception):
    """Raised when a paper's PDF has no URL or cannot be downloaded."""
class PiTimeoutError(Exception):
    """Raised when the pi CLI does not finish within the configured timeout."""
class PiProcessError(Exception):
    """Raised when the pi CLI exits with a non-zero return code.

    ``returncode`` and the full ``stderr`` text are kept as attributes; the
    exception message includes at most the first 500 characters of stderr.
    """

    def __init__(self, returncode: int, stderr: str):
        message = f"pi exited with code {returncode}: {stderr[:500]}"
        super().__init__(message)
        self.returncode = returncode
        self.stderr = stderr
class JsonNotFoundError(Exception):
    """Raised when no JSON object can be extracted from the pi output."""
# ── 路径工具 ────────────────────────────────────────────────────────────
# All paths are relative to the process working directory.
_DATA_DIR = Path("data")
_PAPERS_DIR = _DATA_DIR.joinpath("papers")
_TMP_DIR = _DATA_DIR.joinpath("tmp")


def _paper_dir(arxiv_id: str) -> Path:
    """Return the per-paper output directory, ``data/papers/{arxiv_id}``."""
    return _PAPERS_DIR.joinpath(arxiv_id)


def _tmp_dir(arxiv_id: str) -> Path:
    """Return the per-paper scratch directory, ``data/tmp/{arxiv_id}``."""
    return _TMP_DIR.joinpath(arxiv_id)
# ── PDF 下载 ────────────────────────────────────────────────────────────
async def _download_pdf(arxiv_id: str, pdf_url: str) -> Path:
    """Download a paper's PDF to ``data/tmp/{arxiv_id}/paper.pdf`` and return that path.

    Raises ``PdfDownloadError`` when ``pdf_url`` is empty or when anything
    goes wrong while fetching or writing the file (the original exception is
    chained as the cause).
    """
    if not pdf_url:
        raise PdfDownloadError(f"no pdf_url for {arxiv_id}")

    target_dir = _tmp_dir(arxiv_id)
    target_dir.mkdir(parents=True, exist_ok=True)
    dest = target_dir / "paper.pdf"

    # Route through the configured proxy only when one is set.
    transport = (
        httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
        if settings.http_proxy
        else None
    )

    try:
        client_options = {
            "timeout": settings.HTTP_TIMEOUT_SECONDS,
            "headers": {"User-Agent": settings.HTTP_USER_AGENT},
            "transport": transport,
            "follow_redirects": True,
        }
        async with httpx.AsyncClient(**client_options) as client:
            response = await client.get(pdf_url)
            response.raise_for_status()
            dest.write_bytes(response.content)
    except Exception as exc:
        raise PdfDownloadError(f"failed to download PDF for {arxiv_id}: {exc}") from exc

    logger.info("Downloaded PDF: %s (%d bytes)", arxiv_id, dest.stat().st_size)
    return dest
# ── meta.json ───────────────────────────────────────────────────────────
def _write_meta_json(paper: Paper) -> Path:
    """Write ``data/papers/{arxiv_id}/meta.json`` for ``paper`` and return its path.

    The file is UTF-8, indented JSON holding the paper's id, English title,
    abstract, publication date (ISO format or null), author names, tags and
    upvote count.
    """
    target_dir = _paper_dir(paper.arxiv_id)
    target_dir.mkdir(parents=True, exist_ok=True)

    published = paper.published_at
    payload = {
        "arxiv_id": paper.arxiv_id,
        "title_en": paper.title_en,
        "abstract": paper.abstract or "",
        "published_at": published.isoformat() if published else None,
        "authors": [author.name for author in paper.authors],
        "tags": [tag.tag for tag in paper.tags],
        "upvotes": paper.upvotes,
    }

    meta_path = target_dir / "meta.json"
    serialized = json.dumps(payload, ensure_ascii=False, indent=2)
    meta_path.write_text(serialized, encoding="utf-8")
    return meta_path
# ── pi CLI 调用 ────────────────────────────────────────────────────────
async def _call_pi(meta_path: Path, pdf_path: Path) -> str:
    """Run the pi CLI non-interactively on a paper and return its stdout as text.

    Raises ``PiTimeoutError`` when the process exceeds
    ``settings.SUMMARY_TIMEOUT_SECONDS`` (the process is killed first) and
    ``PiProcessError`` when it exits with a non-zero return code.
    """
    argv = [
        settings.PI_BIN,
        "-p",
        "--no-tools",
        "--skill",
        settings.SUMMARY_SKILL,
        "请深度解读以下论文,并按指定 JSON schema 输出:",
        f"@{meta_path}",
        f"@{pdf_path}",
    ]
    # Only the executable and leading flags are logged, not the prompt or paths.
    logger.info("Calling pi: %s %s", paper_id_from_path(meta_path), " ".join(argv[:4]))

    pipe = asyncio.subprocess.PIPE
    proc = await asyncio.create_subprocess_exec(*argv, stdout=pipe, stderr=pipe)

    try:
        out_bytes, err_bytes = await asyncio.wait_for(
            proc.communicate(),
            timeout=settings.SUMMARY_TIMEOUT_SECONDS,
        )
    except asyncio.TimeoutError:
        # Kill and reap the child before reporting the timeout.
        proc.kill()
        await proc.wait()
        raise PiTimeoutError(
            f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s"
        )

    if proc.returncode == 0:
        return out_bytes.decode("utf-8", errors="replace")
    raise PiProcessError(proc.returncode, err_bytes.decode("utf-8", errors="replace"))
def paper_id_from_path(meta_path: Path) -> str:
    """Return the name of the directory that contains ``meta_path``.

    For ``data/papers/{arxiv_id}/meta.json`` this is the arXiv id.
    """
    containing_dir = meta_path.parent
    return containing_dir.name
# ── JSON 提取 ──────────────────────────────────────────────────────────
def _extract_json(raw_output: str) -> dict:
    """Extract the summary JSON object from raw pi output.

    Strategies are tried in order and the first hit wins:

    1. the whole (trimmed) output is a JSON object containing ``"title_zh"``;
    2. a fenced code block whose content is such an object;
    3. a flat (non-nested) ``{...}`` span that mentions ``"title_zh"``;
    4. the longest JSON object that can be decoded anywhere in the text.

    Raises ``JsonNotFoundError`` when none of them yields an object.
    """
    # Strategy 1: the output is the JSON document itself.
    try:
        whole = json.loads(raw_output.strip())
    except json.JSONDecodeError:
        whole = None
    if isinstance(whole, dict) and "title_zh" in whole:
        return whole

    # Strategy 2: a ```json ... ``` (or bare ```) fenced block.
    fence_pattern = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL)
    for match in fence_pattern.finditer(raw_output):
        try:
            fenced = json.loads(match.group(1).strip())
        except json.JSONDecodeError:
            continue
        if isinstance(fenced, dict) and "title_zh" in fenced:
            return fenced

    # Strategy 3: a brace span without nested braces that mentions title_zh.
    brace_pattern = re.compile(r"\{[^{}]*\"title_zh\"[^{}]*\}", re.DOTALL)
    for match in brace_pattern.finditer(raw_output):
        try:
            return json.loads(match.group(0))
        except json.JSONDecodeError:
            continue

    # Strategy 4: decode an object at every "{" and keep the longest one.
    # The JSON decoder is used instead of counting braces by hand because a
    # brace inside a string value (e.g. "a{b") throws a naive depth counter
    # off and makes it return an inner fragment instead of the real object.
    decoder = json.JSONDecoder()
    best = None
    best_len = 0
    pos = raw_output.find("{")
    while pos != -1:
        try:
            parsed, end = decoder.raw_decode(raw_output, pos)
        except json.JSONDecodeError:
            pos = raw_output.find("{", pos + 1)
            continue
        if isinstance(parsed, dict) and end - pos > best_len:
            best, best_len = parsed, end - pos
        # Anything nested inside this object is shorter, so resume after it.
        pos = raw_output.find("{", end)

    if best is not None:
        return best
    raise JsonNotFoundError("no JSON object found in pi output")
# ── 错误分类 ────────────────────────────────────────────────────────────
def _classify_error(exc: Exception) -> str:
    """Map an exception raised while summarizing to its ``error_type`` label.

    Pydantic validation failures are refined further by
    ``classify_validation_error``; anything unrecognized is ``"unknown"``.
    """
    fixed_labels = (
        (PdfDownloadError, "pdf_download_failed"),
        (PiTimeoutError, "timeout"),
        (PiProcessError, "process_error"),
        (JsonNotFoundError, "json_not_found"),
        (json.JSONDecodeError, "json_invalid"),
    )
    for exc_type, label in fixed_labels:
        if isinstance(exc, exc_type):
            return label

    if isinstance(exc, ValidationError):
        return classify_validation_error(exc)
    return "unknown"
# ── FTS5 文本构建 ───────────────────────────────────────────────────────
def _build_fts_summary_text(schema: SummarySchema) -> str:
    """Build the text that is indexed in the ``summary_text`` FTS column.

    Joins, with single spaces, the one-line summary, the motivation problem
    and goal, the method overview and key idea, and all main findings.
    Empty or missing pieces are skipped.
    """
    # The former ``schema.method_overview`` lookup was dead code: the schema
    # has no such top-level attribute (the overview lives at method.overview).
    parts = (
        schema.one_line,
        schema.motivation.problem,
        schema.motivation.goal,
        schema.method.overview,
        schema.method.key_idea,
        " ".join(schema.results.main_findings or []),
    )
    return " ".join(part for part in parts if part)
# ── DB 更新 ─────────────────────────────────────────────────────────────
def _update_summary_in_db(
    db: Session,
    paper: Paper,
    schema: SummarySchema,
    quality: str,
    raw_output: str,
) -> None:
    """Persist a validated summary and commit.

    Touches, in one transaction: the ``PaperSummary`` row (insert or update),
    the summary-related columns on ``paper``, any AI tags not yet attached to
    the paper, and the paper's row in the ``papers_fts`` table.

    ``raw_output`` is accepted for interface compatibility; it is not used here.
    """
    # 1. paper_summaries: update the existing row or insert a new one.
    flat = flatten_for_db(schema)
    existing = db.get(PaperSummary, paper.id)
    if existing:
        for column, value in flat.items():
            setattr(existing, column, value)
    else:
        db.add(PaperSummary(paper_id=paper.id, **flat))

    # 2. papers: title, quality grade and the on-disk artefact paths.
    paper.title_zh = schema.title_zh
    paper.summary_quality = quality
    paper_dir = _paper_dir(paper.arxiv_id)
    paper.summary_path = str(paper_dir / "summary.json")
    paper.raw_output_path = str(paper_dir / "raw_output.txt")

    # 3. AI tags: add only names the paper does not already carry, and guard
    #    against duplicates within the new tag list itself.
    existing_tag_names = {t.tag for t in paper.tags}
    for tag_name in schema.tags:
        if tag_name not in existing_tag_names:
            db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="ai"))
            existing_tag_names.add(tag_name)

    # 4. Full-text index.
    # NOTE(review): this is an UPDATE, so it presumably relies on a papers_fts
    # row already existing for the paper — confirm where that row is created.
    summary_text = _build_fts_summary_text(schema)
    db.execute(
        text(
            "UPDATE papers_fts SET title_zh=:title_zh, summary_text=:summary_text "
            "WHERE rowid=:paper_id"
        ),
        {
            "title_zh": schema.title_zh,
            "summary_text": summary_text,
            "paper_id": paper.id,
        },
    )

    db.commit()
    logger.info("DB updated: paper=%s quality=%s", paper.arxiv_id, quality)
# ── 文件操作 ────────────────────────────────────────────────────────────
def _save_files(arxiv_id: str, schema: SummarySchema, raw_output: str) -> None:
    """Write ``summary.json`` and ``raw_output.txt`` into the paper's directory."""
    out_dir = _paper_dir(arxiv_id)
    out_dir.mkdir(parents=True, exist_ok=True)

    summary_json = schema.model_dump_json(ensure_ascii=False, indent=2)
    artefacts = (
        ("summary.json", summary_json),
        ("raw_output.txt", raw_output),
    )
    for filename, content in artefacts:
        (out_dir / filename).write_text(content, encoding="utf-8")
def _save_raw_output_only(arxiv_id: str, raw_output: str) -> None:
    """Write only ``raw_output.txt`` into the paper's directory (failure path)."""
    out_dir = _paper_dir(arxiv_id)
    out_dir.mkdir(parents=True, exist_ok=True)
    raw_path = out_dir / "raw_output.txt"
    raw_path.write_text(raw_output, encoding="utf-8")
def _cleanup_tmp(arxiv_id: str) -> None:
    """Remove ``data/tmp/{arxiv_id}/`` if present; failures are logged, never raised."""
    scratch = _tmp_dir(arxiv_id)
    if not scratch.exists():
        return
    try:
        shutil.rmtree(scratch)
    except Exception:
        # Best effort: a leftover scratch directory must not fail the job.
        logger.warning("Failed to clean tmp for %s", arxiv_id, exc_info=True)
    else:
        logger.debug("Cleaned tmp: %s", arxiv_id)
# ── 单篇总结 ────────────────────────────────────────────────────────────
async def summarize_one(
    db: Session,
    paper: Paper,
    semaphore: asyncio.Semaphore | None = None,
    *,
    force: bool = False,
) -> dict:
    """Run the full summarization flow for one paper and return a result dict.

    A ``SummaryStatus`` row is created (as "pending") when the paper has none.
    Papers whose status is "done" or "permanent_failure" are skipped unless
    ``force`` is true. When a ``semaphore`` is given, the actual work runs
    while holding it.
    """
    arxiv_id = paper.arxiv_id

    # Make sure there is a status row to work with.
    if not paper.summary_status:
        db.add(SummaryStatus(paper_id=paper.id, status="pending"))
        db.commit()
        db.refresh(paper)

    current = paper.summary_status.status
    if not force and current in ("done", "permanent_failure"):
        reason = "already_done" if current == "done" else "permanent_failure"
        return {"arxiv_id": arxiv_id, "status": "skipped", "reason": reason}

    if not semaphore:
        return await _do_summarize_one(db, paper)
    async with semaphore:
        return await _do_summarize_one(db, paper)
async def _do_summarize_one(db: Session, paper: Paper) -> dict:
    """Execute one summarization attempt (called with the semaphore held, if any).

    Steps: mark the status "processing", write ``meta.json``, download the
    PDF, call pi, extract and validate the JSON, grade it, save the files and
    update the database, then mark the status "done".

    Any exception is caught and recorded on the status row: the retry counter
    is incremented and the status becomes "pending" again, or
    "permanent_failure" once the counter reaches
    ``settings.SUMMARY_MAX_RETRIES + 1``. The scratch directory is always
    removed afterwards.

    Returns a dict with ``status`` "done" (plus ``quality``) or "failed"
    (plus ``error_type``, ``error`` and ``retry_count``).
    """
    arxiv_id = paper.arxiv_id
    status = paper.summary_status

    # Status -> processing.
    status.status = "processing"
    status.started_at = datetime.now(timezone.utc)
    db.commit()

    raw_output = ""
    try:
        meta_path = _write_meta_json(paper)
        pdf_path = await _download_pdf(arxiv_id, paper.pdf_url)
        raw_output = await _call_pi(meta_path, pdf_path)

        json_data = _extract_json(raw_output)
        schema = SummarySchema.model_validate(json_data)
        quality = assess_quality(schema)

        _save_files(arxiv_id, schema, raw_output)
        _update_summary_in_db(db, paper, schema, quality, raw_output)

        # Status -> done.
        status.status = "done"
        status.quality = quality
        status.completed_at = datetime.now(timezone.utc)
        status.raw_output_saved = True
        db.commit()

        logger.info("Summarize done: %s quality=%s", arxiv_id, quality)
        return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}

    except Exception as exc:
        error_type = _classify_error(exc)
        logger.error(
            "Summarize failed: %s error_type=%s %s",
            arxiv_id,
            error_type,
            str(exc)[:200],
        )

        # Discard whatever the failed attempt left pending in the session.
        # Without this, a database error would leave the session unusable for
        # the commit below, and a later-stage failure would commit half-written
        # summary rows together with the failure status.
        db.rollback()

        # Keep the raw model output for debugging when there is any.
        if raw_output:
            _save_raw_output_only(arxiv_id, raw_output)
            status.raw_output_saved = True

        # Retry bookkeeping.
        status.retry_count = (status.retry_count or 0) + 1
        status.error_type = error_type
        status.error = str(exc)[:2000]

        if status.retry_count >= settings.SUMMARY_MAX_RETRIES + 1:
            status.status = "permanent_failure"
        else:
            status.status = "pending"

        status.completed_at = datetime.now(timezone.utc)
        db.commit()

        return {
            "arxiv_id": arxiv_id,
            "status": "failed",
            "error_type": error_type,
            "error": str(exc)[:200],
            "retry_count": status.retry_count,
        }

    finally:
        _cleanup_tmp(arxiv_id)
# ── 单篇入口 ────────────────────────────────────────────────────────────
async def summarize_single(
    db: Session,
    arxiv_id: str,
    *,
    force: bool = True,
    _session_factory=None,
) -> dict:
    """Summarize one paper by arXiv id (entry point for the admin route and the CLI).

    ``db`` is only used to check that the paper exists; the work itself runs
    in a separate session created by ``_session_factory`` (``SessionLocal``
    by default; tests can inject a factory bound to an in-memory database).

    Returns ``{"status": "not_found", "arxiv_id": ...}`` when the paper does
    not exist, otherwise the result dict of ``summarize_one``.
    """
    # Existence check only: no need to load the row or its relationships.
    if db.query(Paper.id).filter(Paper.arxiv_id == arxiv_id).first() is None:
        return {"status": "not_found", "arxiv_id": arxiv_id}

    make_session = _session_factory or SessionLocal

    # A dedicated session keeps this job isolated from the caller's session.
    paper_db = make_session()
    try:
        paper = (
            paper_db.query(Paper)
            .filter(Paper.arxiv_id == arxiv_id)
            .options(
                joinedload(Paper.authors),
                joinedload(Paper.tags),
                joinedload(Paper.summary_status),
            )
            .first()
        )
        if paper is None:
            # The paper vanished between the two queries, or the injected
            # factory points at a database that does not contain it. Report
            # it instead of crashing on ``None`` inside summarize_one.
            return {"status": "not_found", "arxiv_id": arxiv_id}
        return await summarize_one(paper_db, paper, force=force)
    finally:
        paper_db.close()
# ── 批量总结 ────────────────────────────────────────────────────────────
async def summarize_batch(
    db: Session,
    arxiv_ids: list[str] | None = None,
    *,
    _session_factory=None,
) -> dict:
    """Summarize many papers concurrently.

    With ``arxiv_ids`` the listed papers are processed; otherwise every paper
    whose summary status is "pending" or "failed". Each paper runs in its own
    session (created by ``_session_factory``, ``SessionLocal`` by default),
    with at most ``settings.SUMMARY_CONCURRENCY`` running at once.

    Returns a dict whose ``status`` is:

    * ``"conflict"``: the batch lock could not be acquired;
    * ``"success"``: nothing failed (includes the "nothing to do" case);
    * ``"partial"``: at least one paper failed;
    * ``"failed"``: the batch itself raised unexpectedly.

    The non-conflict results also carry ``total`` / ``done`` / ``failed`` /
    ``skipped`` counters, except for the unexpected-failure case, which
    carries ``error``.
    """
    now = datetime.now(timezone.utc)

    # Re-entrancy guard: inserting the lock row fails if a batch is running.
    # NOTE(review): any commit error here is reported as a conflict, and the
    # row is later only marked "finished" rather than deleted — confirm
    # against TaskLock's constraints that a subsequent run can insert again.
    lock = TaskLock(
        task="summarize",
        lock_key="batch",
        status="running",
        owner="summarize_batch",
        acquired_at=now,
    )
    try:
        db.add(lock)
        db.commit()
    except Exception:
        db.rollback()
        logger.warning("Summarize batch already running (lock conflict)")
        return {"status": "conflict", "error": "summarize batch already running"}

    log_entry = CrawlLog(
        task="summarize",
        status="running",
        started_at=now,
    )

    try:
        # Inside the try so that a failure here still releases the lock.
        db.add(log_entry)
        db.commit()

        # Select the papers to process.
        query = db.query(Paper).options(
            joinedload(Paper.authors),
            joinedload(Paper.tags),
            joinedload(Paper.summary_status),
        )

        if arxiv_ids:
            query = query.filter(Paper.arxiv_id.in_(arxiv_ids))
        else:
            # Only papers that are still pending or marked failed.
            query = query.join(SummaryStatus).filter(
                SummaryStatus.status.in_(["pending", "failed"])
            )

        papers = query.all()
        total = len(papers)
        logger.info("Summarize batch: %d papers to process", total)

        if total == 0:
            log_entry.status = "success"
            log_entry.papers_found = 0
            log_entry.papers_new = 0
            log_entry.completed_at = datetime.now(timezone.utc)
            db.commit()
            # The lock is released exactly once, by the ``finally`` below.
            return {"status": "success", "done": 0, "failed": 0, "skipped": 0, "total": 0}

        # Concurrency control.
        semaphore = asyncio.Semaphore(settings.SUMMARY_CONCURRENCY)
        make_session = _session_factory or SessionLocal

        async def _process_paper(paper: Paper) -> dict:
            # Each paper gets its own session and a fresh copy of the row.
            paper_db = make_session()
            try:
                p = (
                    paper_db.query(Paper)
                    .filter(Paper.id == paper.id)
                    .options(
                        joinedload(Paper.authors),
                        joinedload(Paper.tags),
                        joinedload(Paper.summary_status),
                    )
                    .first()
                )
                return await summarize_one(paper_db, p, semaphore)
            finally:
                paper_db.close()

        results = await asyncio.gather(
            *[_process_paper(p) for p in papers],
            return_exceptions=True,
        )

        # Tally the outcomes; an exception object counts as a failure.
        done = 0
        failed = 0
        skipped = 0
        for r in results:
            if isinstance(r, Exception):
                logger.error("Unexpected error in batch: %s", r)
                failed += 1
            elif isinstance(r, dict):
                if r.get("status") == "done":
                    done += 1
                elif r.get("status") == "skipped":
                    skipped += 1
                else:
                    failed += 1

        log_entry.status = "success" if failed == 0 else "failed"
        log_entry.papers_found = total
        log_entry.papers_new = done
        log_entry.completed_at = datetime.now(timezone.utc)
        db.commit()

        logger.info(
            "Summarize batch done: total=%d done=%d failed=%d skipped=%d",
            total, done, failed, skipped,
        )

        return {
            "status": "success" if failed == 0 else "partial",
            "total": total,
            "done": done,
            "failed": failed,
            "skipped": skipped,
        }

    except Exception as exc:
        logger.exception("Summarize batch failed")
        # Clear any failed transaction first; otherwise the commit below (and
        # the lock release after it) could not succeed.
        db.rollback()
        log_entry.status = "failed"
        log_entry.error = str(exc)[:2000]
        log_entry.completed_at = datetime.now(timezone.utc)
        db.commit()
        return {"status": "failed", "error": str(exc)}

    finally:
        _release_lock(db, lock)
def _release_lock(db: Session, lock: TaskLock) -> None:
    """Mark ``lock`` as finished and commit; on failure roll back and log a warning."""
    try:
        released_at = datetime.now(timezone.utc)
        lock.status = "finished"
        lock.released_at = released_at
        db.commit()
    except Exception:
        # Best effort: releasing the lock must never raise into the caller.
        db.rollback()
        logger.warning("Failed to release summarize lock", exc_info=True)