828 lines
27 KiB
Python
828 lines
27 KiB
Python
"""AI 总结服务 — 调用 pi CLI 生成论文中文结构化总结。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import re
|
||
import shutil
|
||
from datetime import datetime, timezone
|
||
from pathlib import Path
|
||
|
||
import httpx
|
||
from pydantic import ValidationError
|
||
from sqlalchemy import select, text
|
||
from sqlalchemy.orm import Session, joinedload
|
||
|
||
from app.config import settings
|
||
from app.database import SessionLocal
|
||
from app.models import (
|
||
CrawlLog,
|
||
Paper,
|
||
PaperSummary,
|
||
PaperTag,
|
||
SummaryStatus,
|
||
TaskLock,
|
||
)
|
||
from app.services.schemas import (
|
||
SummarySchema,
|
||
assess_quality,
|
||
classify_validation_error,
|
||
flatten_for_db,
|
||
)
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# ── 自定义异常 ──────────────────────────────────────────────────────────
|
||
|
||
|
||
class PdfDownloadError(Exception):
    """Raised when a paper has no PDF URL or its PDF cannot be downloaded."""
|
||
|
||
|
||
class PiTimeoutError(Exception):
    """Raised when the pi CLI does not finish within the configured timeout."""
|
||
|
||
|
||
class PiProcessError(Exception):
    """Error describing a pi CLI run that ended with a failing return code.

    The exception message embeds the return code and at most the first 500
    characters of stderr; the untruncated text stays available on ``stderr``.
    """

    def __init__(self, returncode: int, stderr: str):
        message = f"pi exited with code {returncode}: {stderr[:500]}"
        super().__init__(message)
        # Exit code reported by the process.
        self.returncode = returncode
        # Complete stderr text, not truncated.
        self.stderr = stderr
|
||
|
||
|
||
class JsonNotFoundError(Exception):
    """Raised when no JSON object can be extracted from the pi output."""
|
||
|
||
|
||
# ── 路径工具 ────────────────────────────────────────────────────────────
|
||
|
||
# Root of the on-disk storage tree (a relative path).
_DATA_DIR = Path("data")
# Parent of the per-paper output directories: data/papers/.
_PAPERS_DIR = _DATA_DIR.joinpath("papers")
# Parent of the per-paper scratch directories: data/tmp/.
_TMP_DIR = _DATA_DIR.joinpath("tmp")
|
||
|
||
|
||
def _paper_dir(arxiv_id: str) -> Path:
    """Return the output directory for one paper (``_PAPERS_DIR / arxiv_id``)."""
    return _PAPERS_DIR.joinpath(arxiv_id)
|
||
|
||
|
||
def _tmp_dir(arxiv_id: str) -> Path:
    """Return the scratch directory for one paper (``_TMP_DIR / arxiv_id``)."""
    return _TMP_DIR.joinpath(arxiv_id)
|
||
|
||
|
||
# ── PDF 下载 ────────────────────────────────────────────────────────────
|
||
|
||
|
||
async def _download_pdf(arxiv_id: str, pdf_url: str) -> Path:
    """Download a paper's PDF to ``data/tmp/{arxiv_id}/paper.pdf``.

    Args:
        arxiv_id: Identifier used for the scratch directory and in messages.
        pdf_url: URL to fetch; an empty value is rejected up front.

    Returns:
        Path of the written PDF file.

    Raises:
        PdfDownloadError: ``pdf_url`` is empty, or anything went wrong while
            requesting the URL or writing the file.
    """
    if not pdf_url:
        raise PdfDownloadError(f"no pdf_url for {arxiv_id}")

    scratch = _tmp_dir(arxiv_id)
    scratch.mkdir(parents=True, exist_ok=True)
    dest = scratch / "paper.pdf"

    # A custom transport is only needed when a proxy is configured.
    transport = (
        httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
        if settings.http_proxy
        else None
    )

    try:
        async with httpx.AsyncClient(
            timeout=settings.HTTP_TIMEOUT_SECONDS,
            headers={"User-Agent": settings.HTTP_USER_AGENT},
            transport=transport,
            follow_redirects=True,
        ) as client:
            response = await client.get(pdf_url)
            response.raise_for_status()
            dest.write_bytes(response.content)
    except Exception as exc:
        # Every failure mode is reported to callers as one error type.
        raise PdfDownloadError(f"failed to download PDF for {arxiv_id}: {exc}") from exc

    size_bytes = dest.stat().st_size
    logger.info("Downloaded PDF: %s (%d bytes)", arxiv_id, size_bytes)
    return dest
|
||
|
||
|
||
# ── meta.json ───────────────────────────────────────────────────────────
|
||
|
||
|
||
def _write_meta_json(paper: Paper) -> Path:
    """Write ``data/papers/{arxiv_id}/meta.json`` for a paper and return its path.

    The file holds the arxiv id, English title, abstract (empty string when
    missing), ISO-formatted publication time (or null), author names, tag
    names and upvotes, serialized as indented UTF-8 JSON without ASCII escaping.
    """
    out_dir = _paper_dir(paper.arxiv_id)
    out_dir.mkdir(parents=True, exist_ok=True)

    published = paper.published_at
    payload = {
        "arxiv_id": paper.arxiv_id,
        "title_en": paper.title_en,
        "abstract": paper.abstract or "",
        "published_at": published.isoformat() if published else None,
        "authors": [author.name for author in paper.authors],
        "tags": [entry.tag for entry in paper.tags],
        "upvotes": paper.upvotes,
    }

    target = out_dir / "meta.json"
    serialized = json.dumps(payload, ensure_ascii=False, indent=2)
    target.write_text(serialized, encoding="utf-8")
    return target
|
||
|
||
|
||
# ── pi CLI 调用 ────────────────────────────────────────────────────────
|
||
|
||
|
||
async def _call_pi(meta_path: Path, pdf_path: Path) -> str:
    """Run the pi CLI in non-interactive mode and return its stdout text.

    Args:
        meta_path: Path to the paper's meta.json, passed to pi as ``@{meta_path}``.
        pdf_path: Path to the paper's PDF, passed to pi as ``@{pdf_path}``.

    Returns:
        The process stdout decoded as UTF-8 (undecodable bytes are replaced).

    Raises:
        PiTimeoutError: pi did not finish within
            ``settings.SUMMARY_TIMEOUT_SECONDS``; the process is killed first.
        PiProcessError: pi exited with a non-zero return code.
    """
    cmd = [
        settings.PI_BIN,
        "-p",
        "--no-tools",
        "--skill",
        settings.SUMMARY_SKILL,
        "请深度解读以下论文,并按指定 JSON schema 输出:",
        f"@{meta_path}",
        f"@{pdf_path}",
    ]
    # Only the first four arguments are logged, which leaves out the prompt and file arguments.
    logger.info("Calling pi: %s %s", paper_id_from_path(meta_path), " ".join(cmd[:4]))

    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )

    async def _terminate() -> None:
        """Kill the subprocess (if still alive) and reap it."""
        try:
            proc.kill()
        except ProcessLookupError:
            # The process exited on its own between the timeout and the kill.
            pass
        await proc.wait()

    timeout = settings.SUMMARY_TIMEOUT_SECONDS
    try:
        stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout)
    except asyncio.TimeoutError as exc:
        await _terminate()
        raise PiTimeoutError(f"pi timed out after {timeout}s") from exc
    except asyncio.CancelledError:
        # Do not leave an orphaned pi process behind when the caller is cancelled.
        await _terminate()
        raise

    if proc.returncode != 0:
        raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace"))

    return stdout.decode("utf-8", errors="replace")
|
||
|
||
|
||
def paper_id_from_path(meta_path: Path) -> str:
    """Derive the arxiv_id from a meta.json path: the name of its parent directory."""
    containing_dir = meta_path.parent
    return containing_dir.name
|
||
|
||
|
||
# ── JSON 提取 ──────────────────────────────────────────────────────────
|
||
|
||
|
||
def _extract_json(raw_output: str) -> dict:
    """Extract the summary JSON object from raw pi output.

    Strategies, tried in order:

    1. Parse the whole (stripped) output as JSON.
    2. Parse the contents of each fenced code block (three backticks,
       optionally tagged ``json``).
    3. Try to decode a JSON object starting at every ``{`` in the text.

    Strategies 1 and 2 only accept a dict containing the key ``"title_zh"``.
    Strategy 3 returns the first decoded object containing ``"title_zh"``;
    if there is none, it falls back to the decoded object that spans the
    most characters.

    Args:
        raw_output: Text produced by pi, possibly with prose around the JSON.

    Returns:
        The extracted JSON object as a dict.

    Raises:
        JsonNotFoundError: No JSON object could be decoded from the text.
    """
    # Strategy 1: the whole output is the JSON document.
    stripped = raw_output.strip()
    try:
        result = json.loads(stripped)
    except json.JSONDecodeError:
        pass
    else:
        if isinstance(result, dict) and "title_zh" in result:
            return result

    # Strategy 2: JSON inside a fenced code block.
    fence_pattern = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL)
    for match in fence_pattern.finditer(raw_output):
        try:
            result = json.loads(match.group(1).strip())
        except json.JSONDecodeError:
            continue
        if isinstance(result, dict) and "title_zh" in result:
            return result

    # Strategy 3: decode an object starting at each "{". raw_decode uses the
    # real JSON grammar, so braces inside string values do not confuse it
    # (naive brace counting would cut the object short at such a brace).
    decoder = json.JSONDecoder()
    best = None
    best_len = 0
    pos = raw_output.find("{")
    while pos != -1:
        try:
            parsed, end = decoder.raw_decode(raw_output, pos)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(parsed, dict):
                if "title_zh" in parsed:
                    return parsed
                if end - pos > best_len:
                    best, best_len = parsed, end - pos
        pos = raw_output.find("{", pos + 1)

    if best is not None:
        return best

    raise JsonNotFoundError("no JSON object found in pi output")
|
||
|
||
|
||
# ── 错误分类 ────────────────────────────────────────────────────────────
|
||
|
||
|
||
def _classify_error(exc: Exception) -> str:
    """Map an exception to its ``error_type`` string.

    Known pipeline exceptions map to fixed labels, pydantic validation
    failures are delegated to ``classify_validation_error``, and anything
    else becomes ``"unknown"``.
    """
    # Checked in order; the first matching exception type wins.
    known_errors = (
        (PdfDownloadError, "pdf_download_failed"),
        (PiTimeoutError, "timeout"),
        (PiProcessError, "process_error"),
        (JsonNotFoundError, "json_not_found"),
        (json.JSONDecodeError, "json_invalid"),
    )
    for exc_type, label in known_errors:
        if isinstance(exc, exc_type):
            return label

    if isinstance(exc, ValidationError):
        return classify_validation_error(exc)
    return "unknown"
|
||
|
||
|
||
# ── FTS5 文本构建 ───────────────────────────────────────────────────────
|
||
|
||
|
||
def _build_fts_summary_text(schema: SummarySchema) -> str:
    """Join the summary fields used for full-text indexing into one string.

    Empty or missing fields are dropped; the remaining pieces are separated
    by single spaces, in a fixed order.
    """
    motivation = schema.motivation
    method = schema.method
    fragments = (
        schema.one_line or "",
        motivation.problem or "",
        motivation.goal or "",
        # Optional flat attribute; contributes nothing when the schema lacks it.
        getattr(schema, "method_overview", ""),
        method.overview or "",
        method.key_idea or "",
        " ".join(schema.results.main_findings or []),
    )
    return " ".join(filter(None, fragments))
|
||
|
||
|
||
# ── DB 更新 ─────────────────────────────────────────────────────────────
|
||
|
||
|
||
def _update_summary_in_db(
    db: Session,
    paper: Paper,
    schema: SummarySchema,
    quality: str,
    raw_output: str,
) -> None:
    """Persist a validated summary and commit.

    Writes, in one transaction: the ``PaperSummary`` row (update or insert),
    summary fields on the ``Paper`` row, any new AI tags, and the
    ``papers_fts`` full-text row.

    Args:
        db: Session that owns ``paper``; committed at the end.
        paper: Paper being summarized.
        schema: Validated summary.
        quality: Quality label stored on the paper.
        raw_output: Not used by this function; kept so the signature stays
            compatible with existing callers.
    """
    # 1. paper_summaries: update the existing row, or insert one.
    existing = db.get(PaperSummary, paper.id)
    flat = flatten_for_db(schema)
    if existing:
        for column, value in flat.items():
            setattr(existing, column, value)
    else:
        db.add(PaperSummary(paper_id=paper.id, **flat))

    # 2. papers table: title, quality and the on-disk artifact paths.
    paper.title_zh = schema.title_zh
    paper.summary_quality = quality
    paper_dir = _paper_dir(paper.arxiv_id)
    paper.summary_path = str(paper_dir / "summary.json")
    paper.raw_output_path = str(paper_dir / "raw_output.txt")

    # 3. AI tags: add only names the paper does not already carry; the set
    #    is updated as we go so duplicates within schema.tags are skipped.
    existing_tag_names = {t.tag for t in paper.tags}
    for tag_name in schema.tags:
        if tag_name not in existing_tag_names:
            db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="ai"))
            existing_tag_names.add(tag_name)

    # 4. FTS5: refresh the indexed title and summary text for this paper.
    summary_text = _build_fts_summary_text(schema)
    db.execute(
        text(
            "UPDATE papers_fts SET title_zh=:title_zh, summary_text=:summary_text "
            "WHERE rowid=:paper_id"
        ),
        {
            "title_zh": schema.title_zh,
            "summary_text": summary_text,
            "paper_id": paper.id,
        },
    )

    db.commit()
    logger.info("DB updated: paper=%s quality=%s", paper.arxiv_id, quality)
|
||
|
||
|
||
# ── 文件操作 ────────────────────────────────────────────────────────────
|
||
|
||
|
||
def _save_files(arxiv_id: str, schema: SummarySchema, raw_output: str) -> None:
    """Write ``summary.json`` and ``raw_output.txt`` into the paper directory."""
    out_dir = _paper_dir(arxiv_id)
    out_dir.mkdir(parents=True, exist_ok=True)

    summary_json = schema.model_dump_json(ensure_ascii=False, indent=2)
    # Insertion order is preserved, so summary.json is written first.
    payloads = {"summary.json": summary_json, "raw_output.txt": raw_output}
    for filename, content in payloads.items():
        (out_dir / filename).write_text(content, encoding="utf-8")
|
||
|
||
|
||
def _save_raw_output_only(arxiv_id: str, raw_output: str) -> None:
    """Write only ``raw_output.txt`` into the paper directory (used on failure)."""
    target = _paper_dir(arxiv_id) / "raw_output.txt"
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(raw_output, encoding="utf-8")
|
||
|
||
|
||
def _cleanup_tmp(arxiv_id: str) -> None:
    """Remove the ``data/tmp/{arxiv_id}/`` directory if it exists.

    Best-effort: a failure to delete is logged as a warning, never raised.
    """
    scratch = _tmp_dir(arxiv_id)
    if not scratch.exists():
        return
    try:
        shutil.rmtree(scratch)
    except Exception:
        logger.warning("Failed to clean tmp for %s", arxiv_id, exc_info=True)
    else:
        logger.debug("Cleaned tmp: %s", arxiv_id)
|
||
|
||
|
||
# ── LaTeX 图片提取(Phase 5)───────────────────────────────────────────
|
||
|
||
# Matches \includegraphics with an optional [options] part; group 1 captures
# the text between the braces (the file argument).
_INCLUDEGRAPHICS_RE = re.compile(
    r"\\includegraphics\s*(?:\[[^\]]*\])?\s*\{([^}]+)\}",
    flags=re.MULTILINE,
)
# File extensions treated as images.
_IMAGE_EXTS = set((".png", ".jpg", ".jpeg", ".gif", ".svg", ".pdf", ".eps"))
|
||
|
||
|
||
async def _extract_images_from_source(arxiv_id: str, tmp_source: Path | None = None) -> int:
    """Copy images referenced by a paper's LaTeX source into its images folder.

    Steps:
        1. If the source directory does not exist yet, download and unpack
           the arXiv source into it.
        2. Scan every ``.tex`` file for ``\\includegraphics`` arguments.
        3. Copy each referenced file found under the source directory to
           ``data/papers/{arxiv_id}/images/``.

    This function does not delete the unpacked source directory.

    Args:
        arxiv_id: Paper identifier.
        tmp_source: Directory holding (or to receive) the unpacked source.
            Defaults to ``data/tmp/{arxiv_id}/source``.

    Returns:
        Number of image files copied. Any error is logged and reported as 0.
    """
    if tmp_source is None:
        tmp_source = _tmp_dir(arxiv_id) / "source"
    images_dest = _paper_dir(arxiv_id) / "images"

    try:
        # Download the source archive unless it is already unpacked.
        if not tmp_source.exists():
            source_url = f"https://arxiv.org/e-print/{arxiv_id}"
            await _download_source_zip(arxiv_id, source_url, tmp_source)

        if not tmp_source.exists():
            return 0

        # Collect every \includegraphics argument from the .tex files.
        image_paths: set[str] = set()
        for tex_file in tmp_source.rglob("*.tex"):
            try:
                content = tex_file.read_text(encoding="utf-8", errors="replace")
            except Exception:
                # Best-effort: an unreadable .tex file is skipped.
                continue
            image_paths.update(
                match.group(1).strip()
                for match in _INCLUDEGRAPHICS_RE.finditer(content)
            )

        if not image_paths:
            return 0

        source_root = tmp_source.resolve()
        images_dest.mkdir(parents=True, exist_ok=True)
        copied = 0
        # Sorted so that collision suffixes are assigned deterministically.
        for img_rel in sorted(image_paths):
            # LaTeX allows the extension to be omitted, so try common ones.
            for ext in ("", ".png", ".jpg", ".jpeg", ".gif", ".pdf", ".eps"):
                candidate = tmp_source / (img_rel + ext)
                if not candidate.is_file():
                    continue
                # The path comes from untrusted .tex content: refuse anything
                # that resolves outside the source tree ("../", symlinks).
                if source_root not in candidate.resolve().parents:
                    continue

                dest = images_dest / candidate.name
                # Avoid overwriting an existing file of the same name.
                suffix_index = copied
                while dest.exists():
                    dest = images_dest / f"{candidate.stem}_{suffix_index}{candidate.suffix}"
                    suffix_index += 1

                shutil.copy2(candidate, dest)
                copied += 1
                break

        if copied > 0:
            logger.info("Extracted %d images from source for %s", copied, arxiv_id)
        return copied

    except Exception:
        logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
        return 0
|
||
|
||
|
||
def _safe_tar_members(tf, dest_root: Path):
    """Yield only regular files/directories of ``tf`` that stay inside ``dest_root``."""
    for member in tf.getmembers():
        # Links and special files can point or write outside the target directory.
        if not (member.isfile() or member.isdir()):
            continue
        target = (dest_root / member.name).resolve()
        if target != dest_root and dest_root not in target.parents:
            continue
        yield member


async def _download_source_zip(
    arxiv_id: str, source_url: str, dest_dir: Path
) -> None:
    """Download the arXiv source archive and unpack it into ``dest_dir``.

    The download is stored temporarily as ``data/tmp/{arxiv_id}/source.zip``.
    It is first opened as a zip; if it is not a zip, it is opened as a tar
    archive (any compression). Failures are logged, never raised.

    Args:
        arxiv_id: Paper identifier (used for the temp path and log messages).
        source_url: URL of the source archive.
        dest_dir: Directory to unpack into; created if missing.
    """
    import zipfile

    dest_dir.mkdir(parents=True, exist_ok=True)
    zip_path = _tmp_dir(arxiv_id) / "source.zip"

    transport = None
    if settings.http_proxy:
        transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy)

    try:
        async with httpx.AsyncClient(
            timeout=settings.HTTP_TIMEOUT_SECONDS,
            headers={"User-Agent": settings.HTTP_USER_AGENT},
            transport=transport,
            follow_redirects=True,
        ) as client:
            resp = await client.get(source_url)
            resp.raise_for_status()
            zip_path.write_bytes(resp.content)
    except Exception as exc:
        logger.debug("Failed to download source for %s: %s", arxiv_id, exc)
        return

    try:
        # ZipFile.extractall sanitizes member names (absolute paths and ".."
        # components are stripped), so members stay inside dest_dir.
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(dest_dir)
        logger.debug("Extracted source for %s", arxiv_id)
    except zipfile.BadZipFile:
        # Not a zip; it may be a (compressed) tar archive.
        import tarfile

        try:
            with tarfile.open(zip_path, "r:*") as tf:
                # The archive is untrusted: extract only validated members,
                # and use the stdlib "data" filter where it is available.
                extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
                tf.extractall(
                    dest_dir,
                    members=_safe_tar_members(tf, dest_dir.resolve()),
                    **extract_kwargs,
                )
            logger.debug("Extracted source (tar) for %s", arxiv_id)
        except Exception:
            logger.warning("Cannot extract source for %s", arxiv_id)
    except Exception:
        logger.warning("Cannot extract source for %s", arxiv_id, exc_info=True)
    finally:
        # The downloaded archive is only needed for unpacking.
        if zip_path.exists():
            zip_path.unlink()
|
||
|
||
|
||
# ── 单篇总结 ────────────────────────────────────────────────────────────
|
||
|
||
|
||
async def summarize_one(
    db: Session,
    paper: Paper,
    semaphore: asyncio.Semaphore | None = None,
    *,
    force: bool = False,
) -> dict:
    """Summarize one paper, honoring its current summary status.

    A missing ``summary_status`` row is created as ``pending`` first. Papers
    whose status is ``done`` or ``permanent_failure`` are skipped unless
    ``force`` is set. When a semaphore is given, the actual work runs while
    holding it.

    Returns:
        A result dict; skipped papers get
        ``{"arxiv_id", "status": "skipped", "reason"}``.
    """
    arxiv_id = paper.arxiv_id

    # Make sure a status row exists before reading it.
    if not paper.summary_status:
        db.add(SummaryStatus(paper_id=paper.id, status="pending"))
        db.commit()
        db.refresh(paper)

    current = paper.summary_status.status
    if not force and current in ("done", "permanent_failure"):
        reason = "already_done" if current == "done" else "permanent_failure"
        return {"arxiv_id": arxiv_id, "status": "skipped", "reason": reason}

    if not semaphore:
        return await _do_summarize_one(db, paper)
    async with semaphore:
        return await _do_summarize_one(db, paper)
|
||
|
||
|
||
async def _do_summarize_one(db: Session, paper: Paper) -> dict:
    """Run the summarization pipeline for one paper and record the outcome.

    Pipeline: mark ``processing`` -> write meta.json -> download PDF -> call
    pi -> extract JSON -> validate -> assess quality -> save files -> update
    DB -> mark ``done``. Image extraction and semantic indexing follow as
    optional steps whose failures are only logged.

    On any pipeline error the status row gets the error details and an
    incremented ``retry_count``; it goes back to ``pending``, or to
    ``permanent_failure`` once ``retry_count`` reaches
    ``settings.SUMMARY_MAX_RETRIES + 1``. The paper's tmp directory is
    removed in every case.

    Returns:
        ``{"arxiv_id", "status": "done", "quality"}`` on success, otherwise
        ``{"arxiv_id", "status": "failed", "error_type", "error", "retry_count"}``.
    """
    arxiv_id = paper.arxiv_id
    status = paper.summary_status
    now = datetime.now(timezone.utc)

    # Status -> processing
    status.status = "processing"
    status.started_at = now
    db.commit()

    raw_output = ""
    try:
        # Write meta.json
        meta_path = _write_meta_json(paper)

        # Download the PDF
        await _download_pdf(arxiv_id, paper.pdf_url)

        # Call pi
        raw_output = await _call_pi(meta_path, _tmp_dir(arxiv_id) / "paper.pdf")

        # Extract JSON
        json_data = _extract_json(raw_output)

        # Pydantic validation
        schema = SummarySchema.model_validate(json_data)

        # Quality assessment
        quality = assess_quality(schema)

        # Save files
        _save_files(arxiv_id, schema, raw_output)

        # Update the DB
        _update_summary_in_db(db, paper, schema, quality, raw_output)

        # Status -> done
        status.status = "done"
        status.quality = quality
        status.completed_at = datetime.now(timezone.utc)
        status.raw_output_saved = True
        db.commit()

        # Phase 5: LaTeX image extraction (optional; failure does not affect the summary)
        try:
            await _extract_images_from_source(arxiv_id)
        except Exception:
            logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)

        # Phase 5: write to the semantic index (failure is only logged)
        try:
            from app.services.embedder import index_paper

            # Read the nested schema fields, the same way
            # _build_fts_summary_text does for the FTS text.
            texts_dict = {
                "arxiv_id": arxiv_id,
                "title_zh": schema.title_zh or "",
                "title_en": paper.title_en or "",
                "tags": " ".join(t.tag for t in paper.tags) if paper.tags else "",
                "one_line": schema.one_line or "",
                "motivation_problem": schema.motivation.problem or "",
                "method_key_idea": schema.method.key_idea or "",
                "paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
            }
            index_paper(arxiv_id, texts_dict)
        except Exception:
            logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True)

        logger.info("Summarize done: %s quality=%s", arxiv_id, quality)
        return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}

    except Exception as exc:
        # Discard anything the failed attempt left pending in the session.
        # Without this, a failed flush would make the commit below raise,
        # and half-written summary rows could be committed with the failure.
        db.rollback()

        error_type = _classify_error(exc)
        logger.error(
            "Summarize failed: %s error_type=%s %s",
            arxiv_id,
            error_type,
            str(exc)[:200],
        )

        # Keep the raw output (if any) for later inspection.
        if raw_output:
            _save_raw_output_only(arxiv_id, raw_output)
            status.raw_output_saved = True

        # Retry bookkeeping
        status.retry_count = (status.retry_count or 0) + 1
        status.error_type = error_type
        status.error = str(exc)[:2000]

        if status.retry_count >= settings.SUMMARY_MAX_RETRIES + 1:
            status.status = "permanent_failure"
        else:
            status.status = "pending"

        status.completed_at = datetime.now(timezone.utc)
        db.commit()

        return {
            "arxiv_id": arxiv_id,
            "status": "failed",
            "error_type": error_type,
            "error": str(exc)[:200],
            "retry_count": status.retry_count,
        }

    finally:
        _cleanup_tmp(arxiv_id)
|
||
|
||
|
||
# ── 单篇入口 ────────────────────────────────────────────────────────────
|
||
|
||
|
||
async def summarize_single(
    db: Session,
    arxiv_id: str,
    *,
    force: bool = True,
    _session_factory=None,
) -> dict:
    """Summarize a single paper by arxiv id.

    The paper is first looked up in ``db``; if it is missing, a
    ``not_found`` result is returned. The work itself runs in a separate
    session, which is closed afterwards.

    Args:
        db: Session used for the existence check.
        arxiv_id: Paper to summarize.
        force: Passed through to ``summarize_one``; defaults to True here.
        _session_factory: Optional session factory used instead of
            ``SessionLocal`` (for example, for tests).
    """

    def _load_paper(session):
        # Eager-load the relationships the pipeline reads.
        return (
            session.query(Paper)
            .filter(Paper.arxiv_id == arxiv_id)
            .options(
                joinedload(Paper.authors),
                joinedload(Paper.tags),
                joinedload(Paper.summary_status),
            )
            .first()
        )

    if not _load_paper(db):
        return {"status": "not_found", "arxiv_id": arxiv_id}

    session_factory = _session_factory or SessionLocal

    # Run in a dedicated session rather than the caller's.
    worker_db = session_factory()
    try:
        return await summarize_one(worker_db, _load_paper(worker_db), force=force)
    finally:
        worker_db.close()
|
||
|
||
|
||
# ── 批量总结 ────────────────────────────────────────────────────────────
|
||
|
||
|
||
async def summarize_batch(
    db: Session,
    arxiv_ids: list[str] | None = None,
    *,
    _session_factory=None,
) -> dict:
    """Summarize many papers concurrently.

    A ``TaskLock`` row guards against overlapping batches and a ``CrawlLog``
    row records the run. Each paper is processed in its own session, with
    concurrency bounded by ``settings.SUMMARY_CONCURRENCY``.

    Args:
        db: Session used for the lock, the log entry and the paper query.
        arxiv_ids: Papers to process. When None or empty, every paper whose
            summary status is ``pending`` or ``failed`` is selected.
        _session_factory: Optional session factory used instead of
            ``SessionLocal`` (for example, for tests).

    Returns:
        ``{"status", "total", "done", "failed", "skipped"}`` where status is
        ``"success"`` or ``"partial"``; ``{"status": "conflict", "error"}``
        when the lock cannot be taken; ``{"status": "failed", "error"}`` on
        an unexpected error.
    """
    now = datetime.now(timezone.utc)

    # TaskLock guards against re-entry: if the insert cannot be committed,
    # the batch is treated as already running.
    lock = TaskLock(
        task="summarize",
        lock_key="batch",
        status="running",
        owner="summarize_batch",
        acquired_at=now,
    )
    try:
        db.add(lock)
        db.commit()
    except Exception:
        db.rollback()
        logger.warning("Summarize batch already running (lock conflict)")
        return {"status": "conflict", "error": "summarize batch already running"}

    # CrawlLog
    log_entry = CrawlLog(
        task="summarize",
        status="running",
        started_at=now,
    )
    try:
        db.add(log_entry)
        db.commit()
    except Exception:
        # Do not leave the lock held if the log row cannot be written.
        db.rollback()
        _release_lock(db, lock)
        raise

    try:
        # Select the papers to summarize.
        query = db.query(Paper).options(
            joinedload(Paper.authors),
            joinedload(Paper.tags),
            joinedload(Paper.summary_status),
        )
        if arxiv_ids:
            query = query.filter(Paper.arxiv_id.in_(arxiv_ids))
        else:
            # Only pending or failed papers.
            query = query.join(SummaryStatus).filter(
                SummaryStatus.status.in_(["pending", "failed"])
            )

        papers = query.all()
        total = len(papers)
        logger.info("Summarize batch: %d papers to process", total)

        if total == 0:
            log_entry.status = "success"
            log_entry.papers_found = 0
            log_entry.papers_new = 0
            log_entry.completed_at = datetime.now(timezone.utc)
            db.commit()
            # The lock is released by the finally clause below.
            return {"status": "success", "done": 0, "failed": 0, "skipped": 0, "total": 0}

        # Concurrency control
        semaphore = asyncio.Semaphore(settings.SUMMARY_CONCURRENCY)
        make_session = _session_factory or SessionLocal

        async def _process_paper(paper: Paper) -> dict:
            """Summarize one paper in its own session."""
            paper_db = make_session()
            try:
                p = (
                    paper_db.query(Paper)
                    .filter(Paper.id == paper.id)
                    .options(
                        joinedload(Paper.authors),
                        joinedload(Paper.tags),
                        joinedload(Paper.summary_status),
                    )
                    .first()
                )
                if p is None:
                    # The row is not visible in the worker session; report
                    # an explicit failure instead of an AttributeError.
                    return {
                        "arxiv_id": paper.arxiv_id,
                        "status": "failed",
                        "error_type": "unknown",
                        "error": "paper not found in worker session",
                    }
                return await summarize_one(paper_db, p, semaphore)
            finally:
                paper_db.close()

        results = await asyncio.gather(
            *[_process_paper(p) for p in papers],
            return_exceptions=True,
        )

        # Tally the outcomes.
        done = 0
        failed = 0
        skipped = 0
        for r in results:
            if isinstance(r, Exception):
                logger.error("Unexpected error in batch: %s", r)
                failed += 1
            elif isinstance(r, dict):
                if r.get("status") == "done":
                    done += 1
                elif r.get("status") == "skipped":
                    skipped += 1
                else:
                    failed += 1

        log_entry.status = "success" if failed == 0 else "failed"
        log_entry.papers_found = total
        log_entry.papers_new = done
        log_entry.completed_at = datetime.now(timezone.utc)
        db.commit()

        logger.info(
            "Summarize batch done: total=%d done=%d failed=%d skipped=%d",
            total, done, failed, skipped,
        )
        return {
            "status": "success" if failed == 0 else "partial",
            "total": total,
            "done": done,
            "failed": failed,
            "skipped": skipped,
        }

    except Exception as exc:
        logger.exception("Summarize batch failed")
        # The session may be in a failed state; reset it so the failure can
        # actually be recorded on the log entry.
        db.rollback()
        log_entry.status = "failed"
        log_entry.error = str(exc)[:2000]
        log_entry.completed_at = datetime.now(timezone.utc)
        db.commit()
        return {"status": "failed", "error": str(exc)}

    finally:
        _release_lock(db, lock)
|
||
|
||
|
||
def _release_lock(db: Session, lock: TaskLock) -> None:
    """Mark the lock as finished with a UTC release time, then commit.

    Best-effort: if anything fails, the session is rolled back and a warning
    is logged; nothing is raised.
    """
    try:
        lock.status, lock.released_at = "finished", datetime.now(timezone.utc)
        db.commit()
    except Exception:
        db.rollback()
        logger.warning("Failed to release summarize lock", exc_info=True)
|