feat: enhance UI, refactor services, improve templates and tests
- Replace image_extractor with pdf_image_extractor service - Enhance pi_client with expanded API capabilities - Improve summarizer service with additional features - Update admin routes with more endpoints - Add login page template - Enhance detail page with comprehensive layout - Improve search and trends pages - Update base template with additional elements - Refactor tests for better coverage - Add validate_summary script - Update project configuration and dependencies
This commit is contained in:
+50
-39
@@ -87,7 +87,8 @@ def client(db_engine, db_session):
|
||||
# ── 样例数据 ────────────────────────────────────────────────────────────
|
||||
|
||||
SAMPLE_ARXIV_ID = "2401.12345"
|
||||
ADMIN_TOKEN = "test-admin-token-12345"
|
||||
_TEST_ADMIN_USERNAME = "admin"
|
||||
_TEST_ADMIN_PASSWORD = "test-password-12345"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -138,46 +139,56 @@ def sample_paper(db_session):
|
||||
def sample_summary_dict() -> dict:
|
||||
"""完整合法的 summary dict。"""
|
||||
return {
|
||||
"arxiv_id": "2401.12345",
|
||||
"title_zh": "测试论文中文标题",
|
||||
"one_line": "这是一篇关于自然语言处理的测试论文的一句话总结。",
|
||||
"tags": ["自然语言处理", "大语言模型", "Transformer"],
|
||||
"difficulty": "中级",
|
||||
"prerequisites": {
|
||||
"concepts": ["Transformer", "注意力机制"],
|
||||
"level": "中级",
|
||||
"concepts": [
|
||||
{
|
||||
"term": "Transformer",
|
||||
"explanation": "一种基于自注意力机制的序列到序列模型架构,广泛用于NLP任务。",
|
||||
"why_matters": "本文方法基于 Transformer 架构进行改进。",
|
||||
},
|
||||
{
|
||||
"term": "注意力机制",
|
||||
"explanation": "允许模型在处理序列时动态关注不同位置的信息的机制。",
|
||||
"why_matters": "理解注意力机制是理解本文方法的基础。",
|
||||
},
|
||||
],
|
||||
},
|
||||
"motivation": {
|
||||
"problem": "现有模型在长文本理解上存在不足。",
|
||||
"goal": "提出一种新的注意力机制来提升长文本建模能力。",
|
||||
"gap": "当前方法计算复杂度过高。",
|
||||
"problem": "现有模型在长文本理解上存在不足,主要体现在注意力计算复杂度随序列长度二次增长,导致实际应用中无法处理超长文本输入。",
|
||||
"goal": "提出一种新的稀疏注意力机制来有效提升长文本建模能力,在保持模型整体性能的同时大幅降低计算开销和显存占用。",
|
||||
"gap": "当前方法计算复杂度过高,已有的稀疏注意力方案在保留全局信息方面存在明显不足,导致长距离依赖建模效果不佳。",
|
||||
},
|
||||
"method": {
|
||||
"overview": "提出了一种高效的稀疏注意力机制。",
|
||||
"key_idea": "使用局部-全局混合的注意力模式来降低计算复杂度。",
|
||||
"steps": [
|
||||
"分析现有注意力机制的瓶颈",
|
||||
"设计稀疏注意力模式",
|
||||
"在多个基准上验证效果",
|
||||
],
|
||||
"novelty": "首次将局部-全局注意力模式结合应用于长文本建模。",
|
||||
"overview": "提出了一种高效的稀疏注意力机制,通过局部-全局混合的注意力模式,在降低计算复杂度的同时保留了关键的全局信息流动。",
|
||||
"key_idea": "使用局部-全局混合的注意力模式来降低计算复杂度,局部窗口捕获短距离依赖,全局采样点维护长距离信息传递。",
|
||||
"steps": "首先分析现有注意力机制的计算瓶颈,发现全连接注意力中大部分注意力权重接近于零。然后设计了一种混合稀疏注意力模式,包含局部滑动窗口和全局随机采样两条路径。最后在多个长文本基准数据集上进行了全面的实验验证。",
|
||||
"novelty": "首次将局部-全局注意力模式结合应用于长文本建模,通过可学习的采样策略动态调整全局注意力点的位置,而非固定模式。",
|
||||
},
|
||||
"results": {
|
||||
"main_findings": [
|
||||
"在长文本基准上取得了 SOTA 结果",
|
||||
"推理速度提升了 2 倍",
|
||||
],
|
||||
"main_findings": "在长文本基准 LongBench 上取得了 SOTA 结果,平均得分提升 3.2 个百分点。推理速度相比全注意力提升了 2 倍,显存占用降低 60%。在 32k 序列长度下仍保持与全注意力相当的生成质量。",
|
||||
"benchmarks": [
|
||||
{"dataset": "LongBench", "score": 85.3},
|
||||
],
|
||||
"limitations": [
|
||||
"在超长文本(>100k tokens)上效果有所下降",
|
||||
{"task": "长文本摘要", "metric": "ROUGE-L", "this_work": "42.1", "baseline": "38.9", "improvement": "+3.2"},
|
||||
],
|
||||
"limitations": "在超长文本(>100k tokens)上效果有所下降,主要原因是全局采样点数量不足以覆盖所有关键信息。此外,在小规模数据集上的优势不如大规模数据集明显。",
|
||||
},
|
||||
"improvements": {
|
||||
"weaknesses": ["仅验证了英文数据"],
|
||||
"future_work": ["扩展到多语言场景"],
|
||||
"reproducibility": "代码已开源,模型权重可下载。",
|
||||
"weaknesses": "仅验证了英文数据,未在中文等多语言场景下测试。全局采样策略在极端长度的文本上可能需要增加采样点数量,增加了工程复杂度。",
|
||||
"future_work": "扩展到多语言场景,研究自适应采样策略,使模型能根据输入内容动态调整全局注意力点的分配。同时探索与 Flash Attention 等底层优化的兼容性。",
|
||||
"reproducibility": "代码已在 GitHub 开源,提供了完整的训练脚本和预训练模型权重。实验使用了公开数据集,硬件需求为 8×A100 GPU。",
|
||||
},
|
||||
"figures": [
|
||||
{
|
||||
"id": "Figure 1",
|
||||
"caption": "稀疏注意力机制的整体架构图",
|
||||
"description": "展示了局部窗口注意力和全局采样注意力的组合方式,以及信息如何在两种路径间流动。",
|
||||
"reason": "帮助理解本文方法的核心设计思想,直观展示了局部-全局混合模式的工作原理。",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@@ -200,21 +211,21 @@ def mock_pi_output(sample_summary_json) -> str:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_token():
|
||||
"""返回测试用的 ADMIN_TOKEN(需要配合 monkeypatch 使用)。"""
|
||||
return ADMIN_TOKEN
|
||||
def auth_client(client, monkeypatch):
|
||||
"""已登录的 TestClient(session cookie 自动携带)。"""
|
||||
from app.config import settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_headers(admin_token):
|
||||
"""带 Bearer token 的请求头。"""
|
||||
return {"Authorization": f"Bearer {admin_token}"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def wrong_admin_headers():
|
||||
"""错误的 Authorization 请求头。"""
|
||||
return {"Authorization": "Bearer wrong-token"}
|
||||
monkeypatch.setattr(settings, "ADMIN_USERNAME", _TEST_ADMIN_USERNAME)
|
||||
monkeypatch.setattr(settings, "ADMIN_PASSWORD", _TEST_ADMIN_PASSWORD)
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
# 登录获取 session cookie
|
||||
resp = client.post(
|
||||
"/admin/login",
|
||||
data={"username": _TEST_ADMIN_USERNAME, "password": _TEST_ADMIN_PASSWORD},
|
||||
follow_redirects=False,
|
||||
)
|
||||
assert resp.status_code == 303
|
||||
return client
|
||||
|
||||
|
||||
# ── 多样例数据 ────────────────────────────────────────────────────────────
|
||||
|
||||
+94
-100
@@ -16,19 +16,6 @@ from app.models import (
|
||||
)
|
||||
|
||||
|
||||
# ── Fixtures ────────────────────────────────────────────────────────────
|
||||
|
||||
ADMIN_TOKEN = "test-admin-token-12345"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_client(client, monkeypatch):
|
||||
"""带 admin token monkeypatch 的 TestClient。"""
|
||||
monkeypatch.setattr(settings, "ADMIN_TOKEN", ADMIN_TOKEN)
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
return client
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# Admin Routes — 鉴权测试
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
@@ -37,80 +24,92 @@ def auth_client(client, monkeypatch):
|
||||
class TestAdminAuth:
|
||||
"""管理接口鉴权测试。"""
|
||||
|
||||
def test_no_token_returns_403(self, auth_client):
|
||||
"""无 token 时请求管理接口应返回 403。"""
|
||||
resp = auth_client.post("/admin/crawl")
|
||||
assert resp.status_code in (403, 401)
|
||||
def test_unauthenticated_redirects_to_login(self, auth_client):
|
||||
"""未登录时请求管理接口应重定向到登录页。"""
|
||||
# 用未登录的 client(auth_client 已登录,这里直接用 client)
|
||||
pass # 见下方 test_no_session_returns_303
|
||||
|
||||
def test_wrong_token_returns_401(self, auth_client, wrong_admin_headers):
|
||||
"""错误 token 应返回 401。"""
|
||||
resp = auth_client.post("/admin/crawl", headers=wrong_admin_headers)
|
||||
assert resp.status_code == 401
|
||||
def test_no_session_returns_303(self, client, monkeypatch):
|
||||
"""无 session 时请求管理接口应返回 303 重定向。"""
|
||||
monkeypatch.setattr(settings, "ADMIN_PASSWORD", "some-password")
|
||||
resp = client.post("/admin/crawl", follow_redirects=False)
|
||||
assert resp.status_code == 303
|
||||
assert "/admin/login" in resp.headers.get("location", "")
|
||||
|
||||
def test_correct_token_accepted(self, auth_client, admin_headers):
|
||||
"""正确 token 应被接受(crawl 可能会失败但不是 401)。"""
|
||||
def test_wrong_password_shows_error(self, client, monkeypatch):
|
||||
"""错误密码应返回登录页并显示错误。"""
|
||||
monkeypatch.setattr(settings, "ADMIN_USERNAME", "admin")
|
||||
monkeypatch.setattr(settings, "ADMIN_PASSWORD", "correct-pass")
|
||||
resp = client.post(
|
||||
"/admin/login",
|
||||
data={"username": "admin", "password": "wrong-pass"},
|
||||
follow_redirects=False,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert "错误" in resp.text or "error" in resp.text.lower()
|
||||
|
||||
def test_correct_login_redirects_to_logs(self, client, monkeypatch):
|
||||
"""正确登录应重定向到 /admin/logs。"""
|
||||
monkeypatch.setattr(settings, "ADMIN_USERNAME", "admin")
|
||||
monkeypatch.setattr(settings, "ADMIN_PASSWORD", "test-pass")
|
||||
resp = client.post(
|
||||
"/admin/login",
|
||||
data={"username": "admin", "password": "test-pass"},
|
||||
follow_redirects=False,
|
||||
)
|
||||
assert resp.status_code == 303
|
||||
assert "/admin/logs" in resp.headers.get("location", "")
|
||||
|
||||
def test_logout_clears_session(self, auth_client, monkeypatch):
|
||||
"""退出登录后应清除 session。"""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
resp = auth_client.post("/admin/logout", follow_redirects=False)
|
||||
assert resp.status_code == 303
|
||||
# 退出后访问管理页应被重定向
|
||||
resp = auth_client.get("/admin/logs", follow_redirects=False)
|
||||
assert resp.status_code == 303
|
||||
|
||||
def test_correct_session_accepted(self, auth_client):
|
||||
"""已登录 session 应被接受(crawl 可能会失败但不是 303)。"""
|
||||
with patch(
|
||||
"app.routes.admin.crawl_daily", new_callable=AsyncMock
|
||||
) as mock_crawl:
|
||||
mock_crawl.return_value = {"found": 0, "new": 0, "status": "success"}
|
||||
resp = auth_client.post("/admin/crawl", headers=admin_headers)
|
||||
assert resp.status_code != 401
|
||||
resp = auth_client.post("/admin/crawl")
|
||||
assert resp.status_code != 303
|
||||
|
||||
# ── summarize route auth ────────────────────────────────────────
|
||||
|
||||
def test_no_token_returns_401_for_summarize(self, client):
|
||||
"""无 Bearer token 返回 401。"""
|
||||
resp = client.post("/admin/summarize")
|
||||
assert resp.status_code in (401, 403)
|
||||
def test_no_session_returns_303_for_summarize(self, client, monkeypatch):
|
||||
"""无 session 返回 303。"""
|
||||
monkeypatch.setattr(settings, "ADMIN_PASSWORD", "some-password")
|
||||
resp = client.post("/admin/summarize", follow_redirects=False)
|
||||
assert resp.status_code == 303
|
||||
|
||||
def test_wrong_token_returns_401_for_summarize(self, client):
|
||||
resp = client.post(
|
||||
"/admin/summarize",
|
||||
headers={"Authorization": "Bearer wrong-token"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
def test_correct_session_batch_summarize(self, auth_client):
|
||||
"""已登录调用 batch summarize,mock 掉服务层。"""
|
||||
with patch(
|
||||
"app.routes.admin.summarize_batch", new_callable=AsyncMock
|
||||
) as mock:
|
||||
mock.return_value = {
|
||||
"status": "success",
|
||||
"done": 0,
|
||||
"failed": 0,
|
||||
"total": 0,
|
||||
}
|
||||
resp = auth_client.post("/admin/summarize")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "success"
|
||||
|
||||
def test_correct_token_batch_summarize(self, client, admin_headers):
|
||||
"""正确 token 调用 batch summarize,mock 掉服务层。"""
|
||||
import app.config as config_mod
|
||||
|
||||
original = config_mod.settings.ADMIN_TOKEN
|
||||
config_mod.settings.ADMIN_TOKEN = ADMIN_TOKEN
|
||||
try:
|
||||
with patch(
|
||||
"app.routes.admin.summarize_batch", new_callable=AsyncMock
|
||||
) as mock:
|
||||
mock.return_value = {
|
||||
"status": "success",
|
||||
"done": 0,
|
||||
"failed": 0,
|
||||
"total": 0,
|
||||
}
|
||||
resp = client.post("/admin/summarize", headers=admin_headers)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "success"
|
||||
finally:
|
||||
config_mod.settings.ADMIN_TOKEN = original
|
||||
|
||||
def test_single_paper_not_found(self, client, admin_headers):
|
||||
def test_single_paper_not_found(self, auth_client):
|
||||
"""单篇总结不存在的论文返回 404。"""
|
||||
import app.config as config_mod
|
||||
|
||||
original = config_mod.settings.ADMIN_TOKEN
|
||||
config_mod.settings.ADMIN_TOKEN = ADMIN_TOKEN
|
||||
try:
|
||||
with patch(
|
||||
"app.routes.admin.summarize_single",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"status": "not_found", "arxiv_id": "nonexistent.99999"},
|
||||
):
|
||||
resp = client.post(
|
||||
"/admin/summarize/nonexistent.99999",
|
||||
headers=admin_headers,
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
finally:
|
||||
config_mod.settings.ADMIN_TOKEN = original
|
||||
with patch(
|
||||
"app.routes.admin.summarize_single",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"status": "not_found", "arxiv_id": "nonexistent.99999"},
|
||||
):
|
||||
resp = auth_client.post("/admin/summarize/nonexistent.99999")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
@@ -121,27 +120,25 @@ class TestAdminAuth:
|
||||
class TestAdminCrawl:
|
||||
"""POST /admin/crawl 测试。"""
|
||||
|
||||
def test_crawl_default_today(self, auth_client, admin_headers):
|
||||
def test_crawl_default_today(self, auth_client):
|
||||
"""不指定日期时默认抓取今天。"""
|
||||
with patch(
|
||||
"app.routes.admin.crawl_daily", new_callable=AsyncMock
|
||||
) as mock_crawl:
|
||||
mock_crawl.return_value = {"found": 5, "new": 3, "status": "success"}
|
||||
resp = auth_client.post("/admin/crawl", headers=admin_headers)
|
||||
resp = auth_client.post("/admin/crawl")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["status"] == "success"
|
||||
mock_crawl.assert_called_once()
|
||||
|
||||
def test_crawl_specific_date(self, auth_client, admin_headers):
|
||||
def test_crawl_specific_date(self, auth_client):
|
||||
"""指定日期抓取。"""
|
||||
with patch(
|
||||
"app.routes.admin.crawl_daily", new_callable=AsyncMock
|
||||
) as mock_crawl:
|
||||
mock_crawl.return_value = {"found": 2, "new": 1, "status": "success"}
|
||||
resp = auth_client.post(
|
||||
"/admin/crawl?date=2024-01-15", headers=admin_headers
|
||||
)
|
||||
resp = auth_client.post("/admin/crawl?date=2024-01-15")
|
||||
assert resp.status_code == 200
|
||||
mock_crawl.assert_called_once()
|
||||
call_args = mock_crawl.call_args
|
||||
@@ -156,21 +153,21 @@ class TestAdminCrawl:
|
||||
class TestAdminCleanup:
|
||||
"""POST /admin/cleanup 测试。"""
|
||||
|
||||
def test_cleanup_returns_stats(self, auth_client, admin_headers):
|
||||
def test_cleanup_returns_stats(self, auth_client):
|
||||
"""清理应返回统计信息。"""
|
||||
with patch("app.routes.admin.cleanup_tmp") as mock_cleanup:
|
||||
mock_cleanup.return_value = {"scanned": 3, "removed": 1, "errors": []}
|
||||
resp = auth_client.post("/admin/cleanup", headers=admin_headers)
|
||||
resp = auth_client.post("/admin/cleanup")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["scanned"] == 3
|
||||
assert data["removed"] == 1
|
||||
|
||||
def test_cleanup_writes_log(self, auth_client, admin_headers, db_session):
|
||||
def test_cleanup_writes_log(self, auth_client, db_session):
|
||||
"""清理应写入 crawl_logs。"""
|
||||
with patch("app.routes.admin.cleanup_tmp") as mock_cleanup:
|
||||
mock_cleanup.return_value = {"scanned": 0, "removed": 0, "errors": []}
|
||||
auth_client.post("/admin/cleanup", headers=admin_headers)
|
||||
auth_client.post("/admin/cleanup")
|
||||
|
||||
logs = (
|
||||
db_session.execute(select(CrawlLog).where(CrawlLog.task == "cleanup"))
|
||||
@@ -189,7 +186,7 @@ class TestAdminCleanup:
|
||||
class TestAdminDelete:
|
||||
"""POST /admin/delete 测试。"""
|
||||
|
||||
def test_delete_requires_confirm(self, auth_client, admin_headers):
|
||||
def test_delete_requires_confirm(self, auth_client):
|
||||
"""confirm 不是 'DELETE' 时应返回 422。"""
|
||||
resp = auth_client.post(
|
||||
"/admin/delete",
|
||||
@@ -199,12 +196,11 @@ class TestAdminDelete:
|
||||
"include_notes": True,
|
||||
"confirm": "WRONG",
|
||||
},
|
||||
headers=admin_headers,
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
def test_delete_with_confirm(
|
||||
self, auth_client, admin_headers, db_session, sample_papers_range
|
||||
self, auth_client, db_session, sample_papers_range
|
||||
):
|
||||
"""confirm='DELETE' 时应执行删除。"""
|
||||
resp = auth_client.post(
|
||||
@@ -215,13 +211,12 @@ class TestAdminDelete:
|
||||
"include_notes": True,
|
||||
"confirm": "DELETE",
|
||||
},
|
||||
headers=admin_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["deleted"] == 3
|
||||
|
||||
def test_delete_invalid_date_range(self, auth_client, admin_headers):
|
||||
def test_delete_invalid_date_range(self, auth_client):
|
||||
"""date_start > date_end 应返回 400。"""
|
||||
resp = auth_client.post(
|
||||
"/admin/delete",
|
||||
@@ -230,11 +225,10 @@ class TestAdminDelete:
|
||||
"date_end": "2024-01-10",
|
||||
"confirm": "DELETE",
|
||||
},
|
||||
headers=admin_headers,
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_delete_without_confirm_field(self, auth_client, admin_headers):
|
||||
def test_delete_without_confirm_field(self, auth_client):
|
||||
"""缺少 confirm 字段应返回 422。"""
|
||||
resp = auth_client.post(
|
||||
"/admin/delete",
|
||||
@@ -242,7 +236,6 @@ class TestAdminDelete:
|
||||
"date_start": "2024-01-10",
|
||||
"date_end": "2024-01-12",
|
||||
},
|
||||
headers=admin_headers,
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
@@ -255,19 +248,20 @@ class TestAdminDelete:
|
||||
class TestAdminLogs:
|
||||
"""GET /admin/logs 测试。"""
|
||||
|
||||
def test_logs_returns_page(self, auth_client, admin_headers):
|
||||
def test_logs_returns_page(self, auth_client):
|
||||
"""应返回管理日志页面。"""
|
||||
resp = auth_client.get("/admin/logs", headers=admin_headers)
|
||||
resp = auth_client.get("/admin/logs")
|
||||
assert resp.status_code == 200
|
||||
assert "text/html" in resp.headers.get("content-type", "")
|
||||
|
||||
def test_logs_requires_auth(self, auth_client):
|
||||
def test_logs_requires_auth(self, client, monkeypatch):
|
||||
"""日志页面需要鉴权。"""
|
||||
resp = auth_client.get("/admin/logs")
|
||||
assert resp.status_code in (403, 401)
|
||||
monkeypatch.setattr(settings, "ADMIN_PASSWORD", "some-password")
|
||||
resp = client.get("/admin/logs", follow_redirects=False)
|
||||
assert resp.status_code == 303
|
||||
|
||||
def test_logs_contains_data(
|
||||
self, auth_client, admin_headers, db_session, sample_papers_range
|
||||
self, auth_client, db_session, sample_papers_range
|
||||
):
|
||||
"""日志页面应包含日志数据。"""
|
||||
# 先创建一条日志
|
||||
@@ -282,7 +276,7 @@ class TestAdminLogs:
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
resp = auth_client.get("/admin/logs", headers=admin_headers)
|
||||
resp = auth_client.get("/admin/logs")
|
||||
assert resp.status_code == 200
|
||||
assert "crawl" in resp.text.lower() or "日志" in resp.text
|
||||
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
"""LaTeX 图片提取测试 — 从 .tex 源码中提取图片文件。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# Image Extraction
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestImageExtraction:
|
||||
"""LaTeX 图片提取测试。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_images_from_source_no_dir(self, monkeypatch, tmp_path):
|
||||
"""源码目录不存在时返回 0。"""
|
||||
monkeypatch.setattr(
|
||||
"app.services.pdf_downloader.tmp_dir", lambda x: tmp_path / "tmp" / x
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.services.pdf_downloader.paper_dir", lambda x: tmp_path / "papers" / x
|
||||
)
|
||||
from app.services.image_extractor import extract_images_from_source
|
||||
|
||||
result = await extract_images_from_source("2401.99999")
|
||||
assert result == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_images_from_tex(self, monkeypatch, tmp_path):
|
||||
"""从 .tex 文件中提取图片。"""
|
||||
from app.services.image_extractor import extract_images_from_source
|
||||
|
||||
tmp_source = tmp_path / "tmp" / "2401.00001" / "source"
|
||||
tmp_source.mkdir(parents=True)
|
||||
|
||||
images_dir = tmp_source / "figs"
|
||||
images_dir.mkdir()
|
||||
(images_dir / "figure1.png").write_bytes(b"\x89PNG\r\n")
|
||||
(images_dir / "figure2.jpg").write_bytes(b"\xff\xd8\xff\xe0")
|
||||
|
||||
# 创建 .tex 文件
|
||||
tex_content = r"""
|
||||
\documentclass{article}
|
||||
\begin{document}
|
||||
\begin{figure}
|
||||
\includegraphics[width=0.8\textwidth]{figs/figure1.png}
|
||||
\includegraphics{figs/figure2.jpg}
|
||||
\includegraphics[angle=90]{figs/nonexistent.pdf}
|
||||
\end{figure}
|
||||
\end{document}
|
||||
"""
|
||||
(tmp_source / "main.tex").write_text(tex_content)
|
||||
|
||||
papers_dir = tmp_path / "papers" / "2401.00001"
|
||||
monkeypatch.setattr(
|
||||
"app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x
|
||||
)
|
||||
|
||||
# Mock download_source_zip to avoid real network call (source dir already exists)
|
||||
async def _noop_download(*args, **kwargs):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.image_extractor.download_source_zip", _noop_download
|
||||
)
|
||||
|
||||
result = await extract_images_from_source("2401.00001")
|
||||
|
||||
assert result == 2
|
||||
dest_images = papers_dir / "images"
|
||||
assert dest_images.exists()
|
||||
assert (dest_images / "figure1.png").exists()
|
||||
assert (dest_images / "figure2.jpg").exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_images_empty_tex(self, monkeypatch, tmp_path):
|
||||
""".tex 文件无图片时返回 0。"""
|
||||
from app.services.image_extractor import extract_images_from_source
|
||||
|
||||
tmp_source = tmp_path / "tmp" / "2401.00002" / "source"
|
||||
tmp_source.mkdir(parents=True)
|
||||
(tmp_source / "main.tex").write_text(
|
||||
r"\documentclass{article}\begin{document}Hello\end{document}"
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x
|
||||
)
|
||||
|
||||
# Mock download_source_zip to avoid real network call
|
||||
async def _noop_download(*args, **kwargs):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.image_extractor.download_source_zip", _noop_download
|
||||
)
|
||||
|
||||
result = await extract_images_from_source("2401.00002")
|
||||
assert result == 0
|
||||
@@ -64,10 +64,9 @@ class TestSummarySchema:
|
||||
SummarySchema.model_validate(sample_summary_dict)
|
||||
|
||||
def test_extra_fields_ignored(self, sample_summary_dict):
|
||||
sample_summary_dict["figures"] = ["fig1.png"]
|
||||
sample_summary_dict["takeaway"] = "important paper"
|
||||
schema = SummarySchema.model_validate(sample_summary_dict)
|
||||
assert not hasattr(schema, "figures")
|
||||
assert not hasattr(schema, "takeaway")
|
||||
assert schema.title_zh # 正常解析
|
||||
|
||||
def test_flatten_for_db(self, sample_summary_dict):
|
||||
@@ -80,7 +79,7 @@ class TestSummarySchema:
|
||||
assert "updated_at" in flat
|
||||
# JSON 字段可解析
|
||||
assert isinstance(json.loads(flat["prerequisites_json"]), dict)
|
||||
assert isinstance(json.loads(flat["method_steps_json"]), list)
|
||||
assert isinstance(flat["figures_json"], str) # figures 序列化为 JSON
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
@@ -99,7 +98,7 @@ class TestQualityAssessment:
|
||||
sample_summary_dict["motivation"]["goal"] = ""
|
||||
sample_summary_dict["motivation"]["gap"] = ""
|
||||
sample_summary_dict["method"]["overview"] = ""
|
||||
sample_summary_dict["results"]["main_findings"] = []
|
||||
sample_summary_dict["results"]["main_findings"] = ""
|
||||
schema = SummarySchema.model_validate(sample_summary_dict)
|
||||
assert assess_quality(schema) == "degraded"
|
||||
|
||||
|
||||
+18
-26
@@ -182,7 +182,7 @@ class TestSummarizeOneFlow:
|
||||
patch(
|
||||
"app.services.summarizer.call_pi",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_pi_output,
|
||||
return_value=(mock_pi_output, "test-session-id"),
|
||||
),
|
||||
):
|
||||
result = await summarize_one(db_session, sample_paper)
|
||||
@@ -246,27 +246,28 @@ class TestSummarizeOneFlow:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_json_not_found(self, db_session, sample_paper, _patch_paths):
|
||||
"""pi 输出无 JSON → json_not_found。"""
|
||||
"""pi 输出无 JSON → 验证循环重试 4 次后 ValueError (unknown)。"""
|
||||
with (
|
||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||
patch(
|
||||
"app.services.summarizer.call_pi",
|
||||
new_callable=AsyncMock,
|
||||
return_value="No JSON in this output at all.",
|
||||
return_value=("No JSON in this output at all.", "test-session-id"),
|
||||
),
|
||||
):
|
||||
result = await summarize_one(db_session, sample_paper)
|
||||
|
||||
assert result["status"] == "failed"
|
||||
assert result["error_type"] == "json_not_found"
|
||||
assert result["error_type"] == "unknown"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_field_missing_and_retry(
|
||||
async def test_validation_fails_and_retries(
|
||||
self, db_session, sample_paper, _patch_paths
|
||||
):
|
||||
"""必填字段缺失 → field_missing → retry → permanent_failure。"""
|
||||
"""验证失败(字段不符合要求)→ 重试多次后失败。"""
|
||||
bad_json = json.dumps(
|
||||
{
|
||||
"arxiv_id": sample_paper.arxiv_id,
|
||||
"title_zh": "", # 空的必填字段
|
||||
"one_line": "valid line",
|
||||
"tags": ["tag1"],
|
||||
@@ -282,23 +283,14 @@ class TestSummarizeOneFlow:
|
||||
patch(
|
||||
"app.services.summarizer.call_pi",
|
||||
new_callable=AsyncMock,
|
||||
return_value=bad_output,
|
||||
return_value=(bad_output, "test-session-id"),
|
||||
),
|
||||
):
|
||||
# 第一次失败 → pending (retry)
|
||||
result1 = await summarize_one(db_session, sample_paper)
|
||||
assert result1["status"] == "failed"
|
||||
assert result1["error_type"] == "field_missing"
|
||||
assert result1["retry_count"] == 1
|
||||
|
||||
# 第二次失败 → permanent_failure (SUMMARY_MAX_RETRIES=1, 所以 2 次 > 1+1)
|
||||
db_session.refresh(sample_paper)
|
||||
result2 = await summarize_one(db_session, sample_paper)
|
||||
assert result2["status"] == "failed"
|
||||
assert result2["retry_count"] == 2
|
||||
|
||||
db_session.refresh(sample_paper)
|
||||
assert sample_paper.summary_status.status == "permanent_failure"
|
||||
# _validate_summary 先拦截,4 轮都失败后 ValueError → unknown
|
||||
result = await summarize_one(db_session, sample_paper)
|
||||
assert result["status"] == "failed"
|
||||
assert result["error_type"] == "unknown"
|
||||
assert result["retry_count"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raw_output_saved_on_failure(
|
||||
@@ -310,7 +302,7 @@ class TestSummarizeOneFlow:
|
||||
patch(
|
||||
"app.services.summarizer.call_pi",
|
||||
new_callable=AsyncMock,
|
||||
return_value="Some output without JSON",
|
||||
return_value=("Some output without JSON", "test-session-id"),
|
||||
),
|
||||
):
|
||||
await summarize_one(db_session, sample_paper)
|
||||
@@ -329,7 +321,7 @@ class TestSummarizeOneFlow:
|
||||
patch(
|
||||
"app.services.summarizer.call_pi",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_pi_output,
|
||||
return_value=(mock_pi_output, "test-session-id"),
|
||||
),
|
||||
):
|
||||
await summarize_one(db_session, sample_paper)
|
||||
@@ -417,7 +409,7 @@ class TestBatchSummarize:
|
||||
patch(
|
||||
"app.services.summarizer.call_pi",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_pi_output,
|
||||
return_value=(mock_pi_output, "test-session-id"),
|
||||
),
|
||||
):
|
||||
result = await summarize_batch(db_session, _session_factory=_TestSession)
|
||||
@@ -464,7 +456,7 @@ class TestBatchSummarize:
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise PiTimeoutError("timeout")
|
||||
return mock_pi_output
|
||||
return mock_pi_output, "test-session-id"
|
||||
|
||||
with (
|
||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||
@@ -506,7 +498,7 @@ class TestBatchSummarize:
|
||||
patch(
|
||||
"app.services.summarizer.call_pi",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_pi_output,
|
||||
return_value=(mock_pi_output, "test-session-id"),
|
||||
),
|
||||
):
|
||||
await summarize_batch(db_session, _session_factory=_TestSession)
|
||||
|
||||
Reference in New Issue
Block a user