feat: enhance UI, refactor services, improve templates and tests

- Replace image_extractor with pdf_image_extractor service
- Enhance pi_client with expanded API capabilities
- Improve summarizer service with additional features
- Update admin routes with more endpoints
- Add login page template
- Enhance detail page with comprehensive layout
- Improve search and trends pages
- Update base template with additional elements
- Refactor tests for better coverage
- Add validate_summary script
- Update project configuration and dependencies
This commit is contained in:
2026-06-07 19:38:58 +08:00
parent 4a72c35452
commit 0d293422ac
32 changed files with 2003 additions and 586 deletions
+50 -39
View File
@@ -87,7 +87,8 @@ def client(db_engine, db_session):
# ── 样例数据 ────────────────────────────────────────────────────────────
SAMPLE_ARXIV_ID = "2401.12345"
ADMIN_TOKEN = "test-admin-token-12345"
_TEST_ADMIN_USERNAME = "admin"
_TEST_ADMIN_PASSWORD = "test-password-12345"
@pytest.fixture
@@ -138,46 +139,56 @@ def sample_paper(db_session):
def sample_summary_dict() -> dict:
"""完整合法的 summary dict。"""
return {
"arxiv_id": "2401.12345",
"title_zh": "测试论文中文标题",
"one_line": "这是一篇关于自然语言处理的测试论文的一句话总结。",
"tags": ["自然语言处理", "大语言模型", "Transformer"],
"difficulty": "中级",
"prerequisites": {
"concepts": ["Transformer", "注意力机制"],
"level": "中级",
"concepts": [
{
"term": "Transformer",
"explanation": "一种基于自注意力机制的序列到序列模型架构,广泛用于NLP任务。",
"why_matters": "本文方法基于 Transformer 架构进行改进。",
},
{
"term": "注意力机制",
"explanation": "允许模型在处理序列时动态关注不同位置的信息的机制。",
"why_matters": "理解注意力机制是理解本文方法的基础。",
},
],
},
"motivation": {
"problem": "现有模型在长文本理解上存在不足。",
"goal": "提出一种新的注意力机制来提升长文本建模能力。",
"gap": "当前方法计算复杂度过高。",
"problem": "现有模型在长文本理解上存在不足,主要体现在注意力计算复杂度随序列长度二次增长,导致实际应用中无法处理超长文本输入",
"goal": "提出一种新的稀疏注意力机制来有效提升长文本建模能力,在保持模型整体性能的同时大幅降低计算开销和显存占用",
"gap": "当前方法计算复杂度过高,已有的稀疏注意力方案在保留全局信息方面存在明显不足,导致长距离依赖建模效果不佳",
},
"method": {
"overview": "提出了一种高效的稀疏注意力机制。",
"key_idea": "使用局部-全局混合的注意力模式来降低计算复杂度。",
"steps": [
"分析现有注意力机制的瓶颈",
"设计稀疏注意力模式",
"在多个基准上验证效果",
],
"novelty": "首次将局部-全局注意力模式结合应用于长文本建模。",
"overview": "提出了一种高效的稀疏注意力机制,通过局部-全局混合的注意力模式,在降低计算复杂度的同时保留了关键的全局信息流动",
"key_idea": "使用局部-全局混合的注意力模式来降低计算复杂度,局部窗口捕获短距离依赖,全局采样点维护长距离信息传递",
"steps": "首先分析现有注意力机制的计算瓶颈,发现全连接注意力中大部分注意力权重接近于零。然后设计了一种混合稀疏注意力模式,包含局部滑动窗口和全局随机采样两条路径。最后在多个长文本基准数据集上进行了全面的实验验证。",
"novelty": "首次将局部-全局注意力模式结合应用于长文本建模,通过可学习的采样策略动态调整全局注意力点的位置,而非固定模式。",
},
"results": {
"main_findings": [
"在长文本基准上取得了 SOTA 结果",
"推理速度提升了 2 倍",
],
"main_findings": "在长文本基准 LongBench 上取得了 SOTA 结果,平均得分提升 3.2 个百分点。推理速度相比全注意力提升了 2 倍,显存占用降低 60%。在 32k 序列长度下仍保持与全注意力相当的生成质量。",
"benchmarks": [
{"dataset": "LongBench", "score": 85.3},
],
"limitations": [
"在超长文本(>100k tokens)上效果有所下降",
{"task": "长文本摘要", "metric": "ROUGE-L", "this_work": "42.1", "baseline": "38.9", "improvement": "+3.2"},
],
"limitations": "在超长文本(>100k tokens)上效果有所下降,主要原因是全局采样点数量不足以覆盖所有关键信息。此外,在小规模数据集上的优势不如大规模数据集明显。",
},
"improvements": {
"weaknesses": ["仅验证了英文数据"],
"future_work": ["扩展到多语言场景"],
"reproducibility": "代码已开源,模型权重可下载",
"weaknesses": "仅验证了英文数据,未在中文等多语言场景下测试。全局采样策略在极端长度的文本上可能需要增加采样点数量,增加了工程复杂度。",
"future_work": "扩展到多语言场景,研究自适应采样策略,使模型能根据输入内容动态调整全局注意力点的分配。同时探索与 Flash Attention 等底层优化的兼容性。",
"reproducibility": "代码已在 GitHub 开源,提供了完整的训练脚本和预训练模型权重。实验使用了公开数据集,硬件需求为 8×A100 GPU",
},
"figures": [
{
"id": "Figure 1",
"caption": "稀疏注意力机制的整体架构图",
"description": "展示了局部窗口注意力和全局采样注意力的组合方式,以及信息如何在两种路径间流动。",
"reason": "帮助理解本文方法的核心设计思想,直观展示了局部-全局混合模式的工作原理。",
},
],
}
@@ -200,21 +211,21 @@ def mock_pi_output(sample_summary_json) -> str:
@pytest.fixture
def admin_token():
"""返回测试用的 ADMIN_TOKEN(需要配合 monkeypatch 使用)。"""
return ADMIN_TOKEN
def auth_client(client, monkeypatch):
"""已登录的 TestClientsession cookie 自动携带)。"""
from app.config import settings
@pytest.fixture
def admin_headers(admin_token):
"""带 Bearer token 的请求头。"""
return {"Authorization": f"Bearer {admin_token}"}
@pytest.fixture
def wrong_admin_headers():
"""错误的 Authorization 请求头。"""
return {"Authorization": "Bearer wrong-token"}
monkeypatch.setattr(settings, "ADMIN_USERNAME", _TEST_ADMIN_USERNAME)
monkeypatch.setattr(settings, "ADMIN_PASSWORD", _TEST_ADMIN_PASSWORD)
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
# 登录获取 session cookie
resp = client.post(
"/admin/login",
data={"username": _TEST_ADMIN_USERNAME, "password": _TEST_ADMIN_PASSWORD},
follow_redirects=False,
)
assert resp.status_code == 303
return client
# ── 多样例数据 ────────────────────────────────────────────────────────────
+94 -100
View File
@@ -16,19 +16,6 @@ from app.models import (
)
# ── Fixtures ────────────────────────────────────────────────────────────
ADMIN_TOKEN = "test-admin-token-12345"
@pytest.fixture
def auth_client(client, monkeypatch):
"""带 admin token monkeypatch 的 TestClient。"""
monkeypatch.setattr(settings, "ADMIN_TOKEN", ADMIN_TOKEN)
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
return client
# ═══════════════════════════════════════════════════════════════════════
# Admin Routes — 鉴权测试
# ═══════════════════════════════════════════════════════════════════════
@@ -37,80 +24,92 @@ def auth_client(client, monkeypatch):
class TestAdminAuth:
"""管理接口鉴权测试。"""
def test_no_token_returns_403(self, auth_client):
"""无 token 时请求管理接口应返回 403"""
resp = auth_client.post("/admin/crawl")
assert resp.status_code in (403, 401)
def test_unauthenticated_redirects_to_login(self, auth_client):
"""未登录时请求管理接口应重定向到登录页"""
# 用未登录的 clientauth_client 已登录,这里直接用 client)
pass # 见下方 test_no_session_returns_303
def test_wrong_token_returns_401(self, auth_client, wrong_admin_headers):
"""错误 token 应返回 401"""
resp = auth_client.post("/admin/crawl", headers=wrong_admin_headers)
assert resp.status_code == 401
def test_no_session_returns_303(self, client, monkeypatch):
"""无 session 时请求管理接口应返回 303 重定向"""
monkeypatch.setattr(settings, "ADMIN_PASSWORD", "some-password")
resp = client.post("/admin/crawl", follow_redirects=False)
assert resp.status_code == 303
assert "/admin/login" in resp.headers.get("location", "")
def test_correct_token_accepted(self, auth_client, admin_headers):
"""正确 token 应被接受(crawl 可能会失败但不是 401)"""
def test_wrong_password_shows_error(self, client, monkeypatch):
"""错误密码应返回登录页并显示错误"""
monkeypatch.setattr(settings, "ADMIN_USERNAME", "admin")
monkeypatch.setattr(settings, "ADMIN_PASSWORD", "correct-pass")
resp = client.post(
"/admin/login",
data={"username": "admin", "password": "wrong-pass"},
follow_redirects=False,
)
assert resp.status_code == 200
assert "错误" in resp.text or "error" in resp.text.lower()
def test_correct_login_redirects_to_logs(self, client, monkeypatch):
"""正确登录应重定向到 /admin/logs。"""
monkeypatch.setattr(settings, "ADMIN_USERNAME", "admin")
monkeypatch.setattr(settings, "ADMIN_PASSWORD", "test-pass")
resp = client.post(
"/admin/login",
data={"username": "admin", "password": "test-pass"},
follow_redirects=False,
)
assert resp.status_code == 303
assert "/admin/logs" in resp.headers.get("location", "")
def test_logout_clears_session(self, auth_client, monkeypatch):
"""退出登录后应清除 session。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
resp = auth_client.post("/admin/logout", follow_redirects=False)
assert resp.status_code == 303
# 退出后访问管理页应被重定向
resp = auth_client.get("/admin/logs", follow_redirects=False)
assert resp.status_code == 303
def test_correct_session_accepted(self, auth_client):
"""已登录 session 应被接受(crawl 可能会失败但不是 303)。"""
with patch(
"app.routes.admin.crawl_daily", new_callable=AsyncMock
) as mock_crawl:
mock_crawl.return_value = {"found": 0, "new": 0, "status": "success"}
resp = auth_client.post("/admin/crawl", headers=admin_headers)
assert resp.status_code != 401
resp = auth_client.post("/admin/crawl")
assert resp.status_code != 303
# ── summarize route auth ────────────────────────────────────────
def test_no_token_returns_401_for_summarize(self, client):
"""Bearer token 返回 401"""
resp = client.post("/admin/summarize")
assert resp.status_code in (401, 403)
def test_no_session_returns_303_for_summarize(self, client, monkeypatch):
"""session 返回 303"""
monkeypatch.setattr(settings, "ADMIN_PASSWORD", "some-password")
resp = client.post("/admin/summarize", follow_redirects=False)
assert resp.status_code == 303
def test_wrong_token_returns_401_for_summarize(self, client):
resp = client.post(
"/admin/summarize",
headers={"Authorization": "Bearer wrong-token"},
)
assert resp.status_code == 401
def test_correct_session_batch_summarize(self, auth_client):
"""已登录调用 batch summarizemock 掉服务层。"""
with patch(
"app.routes.admin.summarize_batch", new_callable=AsyncMock
) as mock:
mock.return_value = {
"status": "success",
"done": 0,
"failed": 0,
"total": 0,
}
resp = auth_client.post("/admin/summarize")
assert resp.status_code == 200
assert resp.json()["status"] == "success"
def test_correct_token_batch_summarize(self, client, admin_headers):
"""正确 token 调用 batch summarizemock 掉服务层。"""
import app.config as config_mod
original = config_mod.settings.ADMIN_TOKEN
config_mod.settings.ADMIN_TOKEN = ADMIN_TOKEN
try:
with patch(
"app.routes.admin.summarize_batch", new_callable=AsyncMock
) as mock:
mock.return_value = {
"status": "success",
"done": 0,
"failed": 0,
"total": 0,
}
resp = client.post("/admin/summarize", headers=admin_headers)
assert resp.status_code == 200
assert resp.json()["status"] == "success"
finally:
config_mod.settings.ADMIN_TOKEN = original
def test_single_paper_not_found(self, client, admin_headers):
def test_single_paper_not_found(self, auth_client):
"""单篇总结不存在的论文返回 404。"""
import app.config as config_mod
original = config_mod.settings.ADMIN_TOKEN
config_mod.settings.ADMIN_TOKEN = ADMIN_TOKEN
try:
with patch(
"app.routes.admin.summarize_single",
new_callable=AsyncMock,
return_value={"status": "not_found", "arxiv_id": "nonexistent.99999"},
):
resp = client.post(
"/admin/summarize/nonexistent.99999",
headers=admin_headers,
)
assert resp.status_code == 404
finally:
config_mod.settings.ADMIN_TOKEN = original
with patch(
"app.routes.admin.summarize_single",
new_callable=AsyncMock,
return_value={"status": "not_found", "arxiv_id": "nonexistent.99999"},
):
resp = auth_client.post("/admin/summarize/nonexistent.99999")
assert resp.status_code == 404
# ═══════════════════════════════════════════════════════════════════════
@@ -121,27 +120,25 @@ class TestAdminAuth:
class TestAdminCrawl:
"""POST /admin/crawl 测试。"""
def test_crawl_default_today(self, auth_client, admin_headers):
def test_crawl_default_today(self, auth_client):
"""不指定日期时默认抓取今天。"""
with patch(
"app.routes.admin.crawl_daily", new_callable=AsyncMock
) as mock_crawl:
mock_crawl.return_value = {"found": 5, "new": 3, "status": "success"}
resp = auth_client.post("/admin/crawl", headers=admin_headers)
resp = auth_client.post("/admin/crawl")
assert resp.status_code == 200
data = resp.json()
assert data["status"] == "success"
mock_crawl.assert_called_once()
def test_crawl_specific_date(self, auth_client, admin_headers):
def test_crawl_specific_date(self, auth_client):
"""指定日期抓取。"""
with patch(
"app.routes.admin.crawl_daily", new_callable=AsyncMock
) as mock_crawl:
mock_crawl.return_value = {"found": 2, "new": 1, "status": "success"}
resp = auth_client.post(
"/admin/crawl?date=2024-01-15", headers=admin_headers
)
resp = auth_client.post("/admin/crawl?date=2024-01-15")
assert resp.status_code == 200
mock_crawl.assert_called_once()
call_args = mock_crawl.call_args
@@ -156,21 +153,21 @@ class TestAdminCrawl:
class TestAdminCleanup:
"""POST /admin/cleanup 测试。"""
def test_cleanup_returns_stats(self, auth_client, admin_headers):
def test_cleanup_returns_stats(self, auth_client):
"""清理应返回统计信息。"""
with patch("app.routes.admin.cleanup_tmp") as mock_cleanup:
mock_cleanup.return_value = {"scanned": 3, "removed": 1, "errors": []}
resp = auth_client.post("/admin/cleanup", headers=admin_headers)
resp = auth_client.post("/admin/cleanup")
assert resp.status_code == 200
data = resp.json()
assert data["scanned"] == 3
assert data["removed"] == 1
def test_cleanup_writes_log(self, auth_client, admin_headers, db_session):
def test_cleanup_writes_log(self, auth_client, db_session):
"""清理应写入 crawl_logs。"""
with patch("app.routes.admin.cleanup_tmp") as mock_cleanup:
mock_cleanup.return_value = {"scanned": 0, "removed": 0, "errors": []}
auth_client.post("/admin/cleanup", headers=admin_headers)
auth_client.post("/admin/cleanup")
logs = (
db_session.execute(select(CrawlLog).where(CrawlLog.task == "cleanup"))
@@ -189,7 +186,7 @@ class TestAdminCleanup:
class TestAdminDelete:
"""POST /admin/delete 测试。"""
def test_delete_requires_confirm(self, auth_client, admin_headers):
def test_delete_requires_confirm(self, auth_client):
"""confirm 不是 'DELETE' 时应返回 422。"""
resp = auth_client.post(
"/admin/delete",
@@ -199,12 +196,11 @@ class TestAdminDelete:
"include_notes": True,
"confirm": "WRONG",
},
headers=admin_headers,
)
assert resp.status_code == 422
def test_delete_with_confirm(
self, auth_client, admin_headers, db_session, sample_papers_range
self, auth_client, db_session, sample_papers_range
):
"""confirm='DELETE' 时应执行删除。"""
resp = auth_client.post(
@@ -215,13 +211,12 @@ class TestAdminDelete:
"include_notes": True,
"confirm": "DELETE",
},
headers=admin_headers,
)
assert resp.status_code == 200
data = resp.json()
assert data["deleted"] == 3
def test_delete_invalid_date_range(self, auth_client, admin_headers):
def test_delete_invalid_date_range(self, auth_client):
"""date_start > date_end 应返回 400。"""
resp = auth_client.post(
"/admin/delete",
@@ -230,11 +225,10 @@ class TestAdminDelete:
"date_end": "2024-01-10",
"confirm": "DELETE",
},
headers=admin_headers,
)
assert resp.status_code == 400
def test_delete_without_confirm_field(self, auth_client, admin_headers):
def test_delete_without_confirm_field(self, auth_client):
"""缺少 confirm 字段应返回 422。"""
resp = auth_client.post(
"/admin/delete",
@@ -242,7 +236,6 @@ class TestAdminDelete:
"date_start": "2024-01-10",
"date_end": "2024-01-12",
},
headers=admin_headers,
)
assert resp.status_code == 422
@@ -255,19 +248,20 @@ class TestAdminDelete:
class TestAdminLogs:
"""GET /admin/logs 测试。"""
def test_logs_returns_page(self, auth_client, admin_headers):
def test_logs_returns_page(self, auth_client):
"""应返回管理日志页面。"""
resp = auth_client.get("/admin/logs", headers=admin_headers)
resp = auth_client.get("/admin/logs")
assert resp.status_code == 200
assert "text/html" in resp.headers.get("content-type", "")
def test_logs_requires_auth(self, auth_client):
def test_logs_requires_auth(self, client, monkeypatch):
"""日志页面需要鉴权。"""
resp = auth_client.get("/admin/logs")
assert resp.status_code in (403, 401)
monkeypatch.setattr(settings, "ADMIN_PASSWORD", "some-password")
resp = client.get("/admin/logs", follow_redirects=False)
assert resp.status_code == 303
def test_logs_contains_data(
self, auth_client, admin_headers, db_session, sample_papers_range
self, auth_client, db_session, sample_papers_range
):
"""日志页面应包含日志数据。"""
# 先创建一条日志
@@ -282,7 +276,7 @@ class TestAdminLogs:
)
db_session.commit()
resp = auth_client.get("/admin/logs", headers=admin_headers)
resp = auth_client.get("/admin/logs")
assert resp.status_code == 200
assert "crawl" in resp.text.lower() or "日志" in resp.text
-107
View File
@@ -1,107 +0,0 @@
"""LaTeX 图片提取测试 — 从 .tex 源码中提取图片文件。"""
from __future__ import annotations
import pytest
# ═══════════════════════════════════════════════════════════════════════
# Image Extraction
# ═══════════════════════════════════════════════════════════════════════
class TestImageExtraction:
"""LaTeX 图片提取测试。"""
@pytest.mark.asyncio
async def test_extract_images_from_source_no_dir(self, monkeypatch, tmp_path):
"""源码目录不存在时返回 0。"""
monkeypatch.setattr(
"app.services.pdf_downloader.tmp_dir", lambda x: tmp_path / "tmp" / x
)
monkeypatch.setattr(
"app.services.pdf_downloader.paper_dir", lambda x: tmp_path / "papers" / x
)
from app.services.image_extractor import extract_images_from_source
result = await extract_images_from_source("2401.99999")
assert result == 0
@pytest.mark.asyncio
async def test_extract_images_from_tex(self, monkeypatch, tmp_path):
"""从 .tex 文件中提取图片。"""
from app.services.image_extractor import extract_images_from_source
tmp_source = tmp_path / "tmp" / "2401.00001" / "source"
tmp_source.mkdir(parents=True)
images_dir = tmp_source / "figs"
images_dir.mkdir()
(images_dir / "figure1.png").write_bytes(b"\x89PNG\r\n")
(images_dir / "figure2.jpg").write_bytes(b"\xff\xd8\xff\xe0")
# 创建 .tex 文件
tex_content = r"""
\documentclass{article}
\begin{document}
\begin{figure}
\includegraphics[width=0.8\textwidth]{figs/figure1.png}
\includegraphics{figs/figure2.jpg}
\includegraphics[angle=90]{figs/nonexistent.pdf}
\end{figure}
\end{document}
"""
(tmp_source / "main.tex").write_text(tex_content)
papers_dir = tmp_path / "papers" / "2401.00001"
monkeypatch.setattr(
"app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x
)
monkeypatch.setattr(
"app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x
)
# Mock download_source_zip to avoid real network call (source dir already exists)
async def _noop_download(*args, **kwargs):
pass
monkeypatch.setattr(
"app.services.image_extractor.download_source_zip", _noop_download
)
result = await extract_images_from_source("2401.00001")
assert result == 2
dest_images = papers_dir / "images"
assert dest_images.exists()
assert (dest_images / "figure1.png").exists()
assert (dest_images / "figure2.jpg").exists()
@pytest.mark.asyncio
async def test_extract_images_empty_tex(self, monkeypatch, tmp_path):
""".tex 文件无图片时返回 0。"""
from app.services.image_extractor import extract_images_from_source
tmp_source = tmp_path / "tmp" / "2401.00002" / "source"
tmp_source.mkdir(parents=True)
(tmp_source / "main.tex").write_text(
r"\documentclass{article}\begin{document}Hello\end{document}"
)
monkeypatch.setattr(
"app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x
)
monkeypatch.setattr(
"app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x
)
# Mock download_source_zip to avoid real network call
async def _noop_download(*args, **kwargs):
pass
monkeypatch.setattr(
"app.services.image_extractor.download_source_zip", _noop_download
)
result = await extract_images_from_source("2401.00002")
assert result == 0
+3 -4
View File
@@ -64,10 +64,9 @@ class TestSummarySchema:
SummarySchema.model_validate(sample_summary_dict)
def test_extra_fields_ignored(self, sample_summary_dict):
sample_summary_dict["figures"] = ["fig1.png"]
sample_summary_dict["takeaway"] = "important paper"
schema = SummarySchema.model_validate(sample_summary_dict)
assert not hasattr(schema, "figures")
assert not hasattr(schema, "takeaway")
assert schema.title_zh # 正常解析
def test_flatten_for_db(self, sample_summary_dict):
@@ -80,7 +79,7 @@ class TestSummarySchema:
assert "updated_at" in flat
# JSON 字段可解析
assert isinstance(json.loads(flat["prerequisites_json"]), dict)
assert isinstance(json.loads(flat["method_steps_json"]), list)
assert isinstance(flat["figures_json"], str) # figures 序列化为 JSON
# ═══════════════════════════════════════════════════════════════════════
@@ -99,7 +98,7 @@ class TestQualityAssessment:
sample_summary_dict["motivation"]["goal"] = ""
sample_summary_dict["motivation"]["gap"] = ""
sample_summary_dict["method"]["overview"] = ""
sample_summary_dict["results"]["main_findings"] = []
sample_summary_dict["results"]["main_findings"] = ""
schema = SummarySchema.model_validate(sample_summary_dict)
assert assess_quality(schema) == "degraded"
+18 -26
View File
@@ -182,7 +182,7 @@ class TestSummarizeOneFlow:
patch(
"app.services.summarizer.call_pi",
new_callable=AsyncMock,
return_value=mock_pi_output,
return_value=(mock_pi_output, "test-session-id"),
),
):
result = await summarize_one(db_session, sample_paper)
@@ -246,27 +246,28 @@ class TestSummarizeOneFlow:
@pytest.mark.asyncio
async def test_json_not_found(self, db_session, sample_paper, _patch_paths):
"""pi 输出无 JSON → json_not_found"""
"""pi 输出无 JSON → 验证循环重试 4 次后 ValueError (unknown)"""
with (
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch(
"app.services.summarizer.call_pi",
new_callable=AsyncMock,
return_value="No JSON in this output at all.",
return_value=("No JSON in this output at all.", "test-session-id"),
),
):
result = await summarize_one(db_session, sample_paper)
assert result["status"] == "failed"
assert result["error_type"] == "json_not_found"
assert result["error_type"] == "unknown"
@pytest.mark.asyncio
async def test_field_missing_and_retry(
async def test_validation_fails_and_retries(
self, db_session, sample_paper, _patch_paths
):
"""必填字段缺失 → field_missing → retry → permanent_failure"""
"""验证失败(字段不符合要求)→ 重试多次后失败"""
bad_json = json.dumps(
{
"arxiv_id": sample_paper.arxiv_id,
"title_zh": "", # 空的必填字段
"one_line": "valid line",
"tags": ["tag1"],
@@ -282,23 +283,14 @@ class TestSummarizeOneFlow:
patch(
"app.services.summarizer.call_pi",
new_callable=AsyncMock,
return_value=bad_output,
return_value=(bad_output, "test-session-id"),
),
):
# 第一次失败 → pending (retry)
result1 = await summarize_one(db_session, sample_paper)
assert result1["status"] == "failed"
assert result1["error_type"] == "field_missing"
assert result1["retry_count"] == 1
# 第二次失败 → permanent_failure (SUMMARY_MAX_RETRIES=1, 所以 2 次 > 1+1)
db_session.refresh(sample_paper)
result2 = await summarize_one(db_session, sample_paper)
assert result2["status"] == "failed"
assert result2["retry_count"] == 2
db_session.refresh(sample_paper)
assert sample_paper.summary_status.status == "permanent_failure"
# _validate_summary 先拦截,4 轮都失败后 ValueError → unknown
result = await summarize_one(db_session, sample_paper)
assert result["status"] == "failed"
assert result["error_type"] == "unknown"
assert result["retry_count"] == 1
@pytest.mark.asyncio
async def test_raw_output_saved_on_failure(
@@ -310,7 +302,7 @@ class TestSummarizeOneFlow:
patch(
"app.services.summarizer.call_pi",
new_callable=AsyncMock,
return_value="Some output without JSON",
return_value=("Some output without JSON", "test-session-id"),
),
):
await summarize_one(db_session, sample_paper)
@@ -329,7 +321,7 @@ class TestSummarizeOneFlow:
patch(
"app.services.summarizer.call_pi",
new_callable=AsyncMock,
return_value=mock_pi_output,
return_value=(mock_pi_output, "test-session-id"),
),
):
await summarize_one(db_session, sample_paper)
@@ -417,7 +409,7 @@ class TestBatchSummarize:
patch(
"app.services.summarizer.call_pi",
new_callable=AsyncMock,
return_value=mock_pi_output,
return_value=(mock_pi_output, "test-session-id"),
),
):
result = await summarize_batch(db_session, _session_factory=_TestSession)
@@ -464,7 +456,7 @@ class TestBatchSummarize:
call_count += 1
if call_count == 1:
raise PiTimeoutError("timeout")
return mock_pi_output
return mock_pi_output, "test-session-id"
with (
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
@@ -506,7 +498,7 @@ class TestBatchSummarize:
patch(
"app.services.summarizer.call_pi",
new_callable=AsyncMock,
return_value=mock_pi_output,
return_value=(mock_pi_output, "test-session-id"),
),
):
await summarize_batch(db_session, _session_factory=_TestSession)