feat: overhaul UI styling, improve templates, enhance services and tests
This commit is contained in:
+49
-38
@@ -225,26 +225,28 @@ def sample_papers_range(db_session):
|
||||
"""插入 5 篇不同日期的论文(用于 admin / cleaner 测试)。"""
|
||||
now = datetime.now(timezone.utc)
|
||||
papers = []
|
||||
for i, (arxiv_id, paper_date_str) in enumerate([
|
||||
("2401.10001", "2024-01-10"),
|
||||
("2401.10002", "2024-01-11"),
|
||||
("2401.10003", "2024-01-12"),
|
||||
("2401.10004", "2024-01-13"),
|
||||
("2401.10005", "2024-01-14"),
|
||||
]):
|
||||
for i, (arxiv_id, paper_date_str) in enumerate(
|
||||
[
|
||||
("2401.10001", "2024-01-10"),
|
||||
("2401.10002", "2024-01-11"),
|
||||
("2401.10003", "2024-01-12"),
|
||||
("2401.10004", "2024-01-13"),
|
||||
("2401.10005", "2024-01-14"),
|
||||
]
|
||||
):
|
||||
paper_date = date.fromisoformat(paper_date_str)
|
||||
p = Paper(
|
||||
arxiv_id=arxiv_id,
|
||||
title_en=f"Test Paper {i+1}",
|
||||
abstract=f"Abstract for paper {i+1}.",
|
||||
title_en=f"Test Paper {i + 1}",
|
||||
abstract=f"Abstract for paper {i + 1}.",
|
||||
paper_date=paper_date,
|
||||
crawled_at=now,
|
||||
upvotes=i * 10,
|
||||
)
|
||||
db_session.add(p)
|
||||
db_session.flush()
|
||||
db_session.add(PaperAuthor(paper_id=p.id, name=f"Author {i+1}", position=0))
|
||||
db_session.add(PaperTag(paper_id=p.id, tag=f"Tag{i+1}", source="hf"))
|
||||
db_session.add(PaperAuthor(paper_id=p.id, name=f"Author {i + 1}", position=0))
|
||||
db_session.add(PaperTag(paper_id=p.id, tag=f"Tag{i + 1}", source="hf"))
|
||||
db_session.add(SummaryStatus(paper_id=p.id, status="pending"))
|
||||
# FTS5
|
||||
db_session.execute(
|
||||
@@ -252,8 +254,13 @@ def sample_papers_range(db_session):
|
||||
"INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) "
|
||||
"VALUES (:id, :title, :abstract, :authors, :tags)"
|
||||
),
|
||||
{"id": p.id, "title": p.title_en, "abstract": p.abstract,
|
||||
"authors": f"Author {i+1}", "tags": f"Tag{i+1}"},
|
||||
{
|
||||
"id": p.id,
|
||||
"title": p.title_en,
|
||||
"abstract": p.abstract,
|
||||
"authors": f"Author {i + 1}",
|
||||
"tags": f"Tag{i + 1}",
|
||||
},
|
||||
)
|
||||
papers.append(p)
|
||||
db_session.commit()
|
||||
@@ -265,19 +272,21 @@ def sample_papers_with_summary(db_session):
|
||||
"""插入 5 篇带总结的论文(用于 search / pages / trends 测试)。"""
|
||||
now = datetime.now(timezone.utc)
|
||||
papers = []
|
||||
for i, (arxiv_id, paper_date_str) in enumerate([
|
||||
("2401.20001", "2024-01-10"),
|
||||
("2401.20002", "2024-01-11"),
|
||||
("2401.20003", "2024-01-12"),
|
||||
("2401.20004", "2024-01-13"),
|
||||
("2401.20005", "2024-01-14"),
|
||||
]):
|
||||
for i, (arxiv_id, paper_date_str) in enumerate(
|
||||
[
|
||||
("2401.20001", "2024-01-10"),
|
||||
("2401.20002", "2024-01-11"),
|
||||
("2401.20003", "2024-01-12"),
|
||||
("2401.20004", "2024-01-13"),
|
||||
("2401.20005", "2024-01-14"),
|
||||
]
|
||||
):
|
||||
paper_date = date.fromisoformat(paper_date_str)
|
||||
p = Paper(
|
||||
arxiv_id=arxiv_id,
|
||||
title_en=f"Test Paper {i+1}",
|
||||
title_zh=f"测试论文 {i+1}",
|
||||
abstract=f"Abstract for paper {i+1}.",
|
||||
title_en=f"Test Paper {i + 1}",
|
||||
title_zh=f"测试论文 {i + 1}",
|
||||
abstract=f"Abstract for paper {i + 1}.",
|
||||
paper_date=paper_date,
|
||||
crawled_at=now,
|
||||
upvotes=i * 10 + 5,
|
||||
@@ -285,28 +294,30 @@ def sample_papers_with_summary(db_session):
|
||||
db_session.add(p)
|
||||
db_session.flush()
|
||||
|
||||
db_session.add(PaperAuthor(paper_id=p.id, name=f"Author {i+1}", position=0))
|
||||
db_session.add(PaperAuthor(paper_id=p.id, name=f"Author {i + 1}", position=0))
|
||||
db_session.add(PaperTag(paper_id=p.id, tag="NLP", source="hf"))
|
||||
db_session.add(PaperTag(paper_id=p.id, tag=f"Tag{i+1}", source="hf"))
|
||||
db_session.add(PaperTag(paper_id=p.id, tag=f"Tag{i + 1}", source="hf"))
|
||||
|
||||
db_session.add(SummaryStatus(
|
||||
paper_id=p.id,
|
||||
status="done" if i < 4 else "pending",
|
||||
quality="normal",
|
||||
))
|
||||
db_session.add(
|
||||
SummaryStatus(
|
||||
paper_id=p.id,
|
||||
status="done" if i < 4 else "pending",
|
||||
quality="normal",
|
||||
)
|
||||
)
|
||||
|
||||
# 添加总结(前 4 篇)
|
||||
if i < 4:
|
||||
summary = PaperSummary(
|
||||
paper_id=p.id,
|
||||
one_line=f"这是论文{i+1}的一句话摘要",
|
||||
one_line=f"这是论文{i + 1}的一句话摘要",
|
||||
difficulty="中级",
|
||||
motivation_problem=f"论文{i+1}的研究问题",
|
||||
motivation_goal=f"论文{i+1}的研究目标",
|
||||
method_key_idea=f"论文{i+1}的关键思路",
|
||||
method_overview=f"论文{i+1}的方法概述",
|
||||
motivation_problem=f"论文{i + 1}的研究问题",
|
||||
motivation_goal=f"论文{i + 1}的研究目标",
|
||||
method_key_idea=f"论文{i + 1}的关键思路",
|
||||
method_overview=f"论文{i + 1}的方法概述",
|
||||
updated_at=now,
|
||||
full_json=json.dumps({"title_zh": f"测试论文 {i+1}"}),
|
||||
full_json=json.dumps({"title_zh": f"测试论文 {i + 1}"}),
|
||||
)
|
||||
db_session.add(summary)
|
||||
|
||||
@@ -321,8 +332,8 @@ def sample_papers_with_summary(db_session):
|
||||
"title_en": p.title_en,
|
||||
"title_zh": p.title_zh or "",
|
||||
"abstract": p.abstract or "",
|
||||
"authors": f"Author {i+1}",
|
||||
"tags": f"NLP, Tag{i+1}",
|
||||
"authors": f"Author {i + 1}",
|
||||
"tags": f"NLP, Tag{i + 1}",
|
||||
},
|
||||
)
|
||||
papers.append(p)
|
||||
|
||||
+67
-23
@@ -49,7 +49,9 @@ class TestAdminAuth:
|
||||
|
||||
def test_correct_token_accepted(self, auth_client, admin_headers):
|
||||
"""正确 token 应被接受(crawl 可能会失败但不是 401)。"""
|
||||
with patch("app.routes.admin.crawl_daily", new_callable=AsyncMock) as mock_crawl:
|
||||
with patch(
|
||||
"app.routes.admin.crawl_daily", new_callable=AsyncMock
|
||||
) as mock_crawl:
|
||||
mock_crawl.return_value = {"found": 0, "new": 0, "status": "success"}
|
||||
resp = auth_client.post("/admin/crawl", headers=admin_headers)
|
||||
assert resp.status_code != 401
|
||||
@@ -75,8 +77,15 @@ class TestAdminAuth:
|
||||
original = config_mod.settings.ADMIN_TOKEN
|
||||
config_mod.settings.ADMIN_TOKEN = ADMIN_TOKEN
|
||||
try:
|
||||
with patch("app.routes.admin.summarize_batch", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = {"status": "success", "done": 0, "failed": 0, "total": 0}
|
||||
with patch(
|
||||
"app.routes.admin.summarize_batch", new_callable=AsyncMock
|
||||
) as mock:
|
||||
mock.return_value = {
|
||||
"status": "success",
|
||||
"done": 0,
|
||||
"failed": 0,
|
||||
"total": 0,
|
||||
}
|
||||
resp = client.post("/admin/summarize", headers=admin_headers)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "success"
|
||||
@@ -114,7 +123,9 @@ class TestAdminCrawl:
|
||||
|
||||
def test_crawl_default_today(self, auth_client, admin_headers):
|
||||
"""不指定日期时默认抓取今天。"""
|
||||
with patch("app.routes.admin.crawl_daily", new_callable=AsyncMock) as mock_crawl:
|
||||
with patch(
|
||||
"app.routes.admin.crawl_daily", new_callable=AsyncMock
|
||||
) as mock_crawl:
|
||||
mock_crawl.return_value = {"found": 5, "new": 3, "status": "success"}
|
||||
resp = auth_client.post("/admin/crawl", headers=admin_headers)
|
||||
assert resp.status_code == 200
|
||||
@@ -124,9 +135,13 @@ class TestAdminCrawl:
|
||||
|
||||
def test_crawl_specific_date(self, auth_client, admin_headers):
|
||||
"""指定日期抓取。"""
|
||||
with patch("app.routes.admin.crawl_daily", new_callable=AsyncMock) as mock_crawl:
|
||||
with patch(
|
||||
"app.routes.admin.crawl_daily", new_callable=AsyncMock
|
||||
) as mock_crawl:
|
||||
mock_crawl.return_value = {"found": 2, "new": 1, "status": "success"}
|
||||
resp = auth_client.post("/admin/crawl?date=2024-01-15", headers=admin_headers)
|
||||
resp = auth_client.post(
|
||||
"/admin/crawl?date=2024-01-15", headers=admin_headers
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
mock_crawl.assert_called_once()
|
||||
call_args = mock_crawl.call_args
|
||||
@@ -157,9 +172,11 @@ class TestAdminCleanup:
|
||||
mock_cleanup.return_value = {"scanned": 0, "removed": 0, "errors": []}
|
||||
auth_client.post("/admin/cleanup", headers=admin_headers)
|
||||
|
||||
logs = db_session.execute(
|
||||
select(CrawlLog).where(CrawlLog.task == "cleanup")
|
||||
).scalars().all()
|
||||
logs = (
|
||||
db_session.execute(select(CrawlLog).where(CrawlLog.task == "cleanup"))
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
assert len(logs) >= 1
|
||||
assert logs[-1].status == "success"
|
||||
|
||||
@@ -186,7 +203,9 @@ class TestAdminDelete:
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
def test_delete_with_confirm(self, auth_client, admin_headers, db_session, sample_papers_range):
|
||||
def test_delete_with_confirm(
|
||||
self, auth_client, admin_headers, db_session, sample_papers_range
|
||||
):
|
||||
"""confirm='DELETE' 时应执行删除。"""
|
||||
resp = auth_client.post(
|
||||
"/admin/delete",
|
||||
@@ -247,13 +266,20 @@ class TestAdminLogs:
|
||||
resp = auth_client.get("/admin/logs")
|
||||
assert resp.status_code in (403, 401)
|
||||
|
||||
def test_logs_contains_data(self, auth_client, admin_headers, db_session, sample_papers_range):
|
||||
def test_logs_contains_data(
|
||||
self, auth_client, admin_headers, db_session, sample_papers_range
|
||||
):
|
||||
"""日志页面应包含日志数据。"""
|
||||
# 先创建一条日志
|
||||
now = datetime.now(timezone.utc)
|
||||
db_session.add(CrawlLog(
|
||||
task="crawl", status="success", started_at=now, completed_at=now,
|
||||
))
|
||||
db_session.add(
|
||||
CrawlLog(
|
||||
task="crawl",
|
||||
status="success",
|
||||
started_at=now,
|
||||
completed_at=now,
|
||||
)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
resp = auth_client.get("/admin/logs", headers=admin_headers)
|
||||
@@ -273,9 +299,11 @@ class TestScheduler:
|
||||
"""SCHEDULER_ENABLED=false 时不应启动调度器。"""
|
||||
monkeypatch.setattr(settings, "SCHEDULER_ENABLED", False)
|
||||
import app.services.scheduler as sched_mod
|
||||
|
||||
sched_mod._scheduler = None
|
||||
|
||||
from app.services.scheduler import start_scheduler
|
||||
|
||||
result = start_scheduler()
|
||||
assert result is None
|
||||
|
||||
@@ -285,9 +313,11 @@ class TestScheduler:
|
||||
monkeypatch.setattr(settings, "SCHEDULER_ENABLED", True)
|
||||
monkeypatch.setattr(settings, "APP_WORKERS", 1)
|
||||
import app.services.scheduler as sched_mod
|
||||
|
||||
sched_mod._scheduler = None
|
||||
|
||||
from app.services.scheduler import start_scheduler, stop_scheduler
|
||||
|
||||
scheduler = start_scheduler()
|
||||
assert scheduler is not None
|
||||
|
||||
@@ -305,9 +335,11 @@ class TestScheduler:
|
||||
monkeypatch.setattr(settings, "SCHEDULER_ENABLED", True)
|
||||
monkeypatch.setattr(settings, "APP_WORKERS", 4)
|
||||
import app.services.scheduler as sched_mod
|
||||
|
||||
sched_mod._scheduler = None
|
||||
|
||||
from app.services.scheduler import start_scheduler, stop_scheduler
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
scheduler = start_scheduler()
|
||||
|
||||
@@ -356,15 +388,21 @@ class TestTaskLocks:
|
||||
"""同一 task + lock_key 只能有一个 running 锁。"""
|
||||
now = datetime.now(timezone.utc)
|
||||
lock1 = TaskLock(
|
||||
task="crawl", lock_key="2024-01-15",
|
||||
status="running", owner="test1", acquired_at=now,
|
||||
task="crawl",
|
||||
lock_key="2024-01-15",
|
||||
status="running",
|
||||
owner="test1",
|
||||
acquired_at=now,
|
||||
)
|
||||
db_session.add(lock1)
|
||||
db_session.commit()
|
||||
|
||||
lock2 = TaskLock(
|
||||
task="crawl", lock_key="2024-01-15",
|
||||
status="running", owner="test2", acquired_at=now,
|
||||
task="crawl",
|
||||
lock_key="2024-01-15",
|
||||
status="running",
|
||||
owner="test2",
|
||||
acquired_at=now,
|
||||
)
|
||||
db_session.add(lock2)
|
||||
with pytest.raises(Exception):
|
||||
@@ -375,16 +413,22 @@ class TestTaskLocks:
|
||||
"""已释放的锁允许新的 running 锁。"""
|
||||
now = datetime.now(timezone.utc)
|
||||
lock1 = TaskLock(
|
||||
task="crawl", lock_key="2024-01-16",
|
||||
status="finished", owner="test1",
|
||||
acquired_at=now, released_at=now,
|
||||
task="crawl",
|
||||
lock_key="2024-01-16",
|
||||
status="finished",
|
||||
owner="test1",
|
||||
acquired_at=now,
|
||||
released_at=now,
|
||||
)
|
||||
db_session.add(lock1)
|
||||
db_session.commit()
|
||||
|
||||
lock2 = TaskLock(
|
||||
task="crawl", lock_key="2024-01-16",
|
||||
status="running", owner="test2", acquired_at=now,
|
||||
task="crawl",
|
||||
lock_key="2024-01-16",
|
||||
status="running",
|
||||
owner="test2",
|
||||
acquired_at=now,
|
||||
)
|
||||
db_session.add(lock2)
|
||||
db_session.commit() # 应成功
|
||||
|
||||
+49
-22
@@ -29,13 +29,17 @@ def sample_paper_with_user_data(db_session, sample_papers_range):
|
||||
paper = sample_papers_range[0]
|
||||
now = datetime.now(timezone.utc)
|
||||
db_session.add(UserBookmark(paper_id=paper.id, created_at=now))
|
||||
db_session.add(UserReadingStatus(paper_id=paper.id, status="read_summary", updated_at=now))
|
||||
db_session.add(UserNote(
|
||||
paper_id=paper.id,
|
||||
content="My notes on this paper",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
))
|
||||
db_session.add(
|
||||
UserReadingStatus(paper_id=paper.id, status="read_summary", updated_at=now)
|
||||
)
|
||||
db_session.add(
|
||||
UserNote(
|
||||
paper_id=paper.id,
|
||||
content="My notes on this paper",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
)
|
||||
db_session.commit()
|
||||
return paper
|
||||
|
||||
@@ -64,6 +68,7 @@ class TestCleanupTmp:
|
||||
|
||||
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir)
|
||||
from app.services.cleaner import cleanup_tmp
|
||||
|
||||
result = cleanup_tmp()
|
||||
|
||||
assert result["scanned"] == 1
|
||||
@@ -81,6 +86,7 @@ class TestCleanupTmp:
|
||||
|
||||
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir)
|
||||
from app.services.cleaner import cleanup_tmp
|
||||
|
||||
result = cleanup_tmp()
|
||||
|
||||
assert result["scanned"] == 1
|
||||
@@ -91,6 +97,7 @@ class TestCleanupTmp:
|
||||
"""data/tmp/ 不存在时安全返回。"""
|
||||
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_path / "nonexistent")
|
||||
from app.services.cleaner import cleanup_tmp
|
||||
|
||||
result = cleanup_tmp()
|
||||
assert result["scanned"] == 0
|
||||
assert result["removed"] == 0
|
||||
@@ -110,6 +117,7 @@ class TestCleanupTmp:
|
||||
|
||||
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir)
|
||||
from app.services.cleaner import cleanup_tmp
|
||||
|
||||
result = cleanup_tmp()
|
||||
|
||||
assert result["scanned"] == 2
|
||||
@@ -178,14 +186,18 @@ class TestDeletePapersByDateRange:
|
||||
date(2024, 1, 14),
|
||||
)
|
||||
|
||||
logs = db_session.execute(
|
||||
select(CrawlLog).where(CrawlLog.task == "delete")
|
||||
).scalars().all()
|
||||
logs = (
|
||||
db_session.execute(select(CrawlLog).where(CrawlLog.task == "delete"))
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
assert len(logs) == 1
|
||||
assert logs[0].status == "success"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_cascade_user_data(self, db_session, sample_paper_with_user_data):
|
||||
async def test_delete_cascade_user_data(
|
||||
self, db_session, sample_paper_with_user_data
|
||||
):
|
||||
"""删除论文时应 cascade 删除关联的用户数据。"""
|
||||
from app.services.cleaner import delete_papers_by_date_range
|
||||
|
||||
@@ -200,15 +212,24 @@ class TestDeletePapersByDateRange:
|
||||
assert result["deleted"] == 1
|
||||
|
||||
# 确认用户数据被 cascade 删除
|
||||
assert db_session.execute(
|
||||
select(UserBookmark).where(UserBookmark.paper_id == paper.id)
|
||||
).scalar_one_or_none() is None
|
||||
assert db_session.execute(
|
||||
select(UserReadingStatus).where(UserReadingStatus.paper_id == paper.id)
|
||||
).scalar_one_or_none() is None
|
||||
assert db_session.execute(
|
||||
select(UserNote).where(UserNote.paper_id == paper.id)
|
||||
).scalar_one_or_none() is None
|
||||
assert (
|
||||
db_session.execute(
|
||||
select(UserBookmark).where(UserBookmark.paper_id == paper.id)
|
||||
).scalar_one_or_none()
|
||||
is None
|
||||
)
|
||||
assert (
|
||||
db_session.execute(
|
||||
select(UserReadingStatus).where(UserReadingStatus.paper_id == paper.id)
|
||||
).scalar_one_or_none()
|
||||
is None
|
||||
)
|
||||
assert (
|
||||
db_session.execute(
|
||||
select(UserNote).where(UserNote.paper_id == paper.id)
|
||||
).scalar_one_or_none()
|
||||
is None
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_removes_fts(self, db_session, sample_papers_range):
|
||||
@@ -229,7 +250,9 @@ class TestDeletePapersByDateRange:
|
||||
assert rows == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_removes_local_files(self, db_session, sample_papers_range, tmp_path, monkeypatch):
|
||||
async def test_delete_removes_local_files(
|
||||
self, db_session, sample_papers_range, tmp_path, monkeypatch
|
||||
):
|
||||
"""删除论文时应删除本地文件目录。"""
|
||||
from app.services.cleaner import delete_papers_by_date_range
|
||||
|
||||
@@ -263,13 +286,17 @@ class TestDeletePapersByDateRange:
|
||||
assert result["status"] == "success"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleaner_works_without_chroma(self, db_session, sample_papers_with_summary, monkeypatch):
|
||||
async def test_cleaner_works_without_chroma(
|
||||
self, db_session, sample_papers_with_summary, monkeypatch
|
||||
):
|
||||
"""CHROMA 关闭时删除论文正常工作。"""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
import app.services.embedder as emb
|
||||
|
||||
emb._chroma.reset()
|
||||
|
||||
from app.services.cleaner import delete_papers_by_date_range
|
||||
|
||||
result = await delete_papers_by_date_range(
|
||||
db_session,
|
||||
date(2024, 1, 10),
|
||||
|
||||
+13
-1
@@ -21,6 +21,7 @@ class TestEmbedderInit:
|
||||
"""CHROMA_ENABLED=false 时不初始化。"""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
import app.services.embedder as emb
|
||||
|
||||
emb._chroma.reset()
|
||||
emb.init_chroma()
|
||||
assert emb._chroma._client is None
|
||||
@@ -31,6 +32,7 @@ class TestEmbedderInit:
|
||||
monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
|
||||
|
||||
import app.services.embedder as emb
|
||||
|
||||
emb._chroma.reset()
|
||||
emb.init_chroma()
|
||||
|
||||
@@ -44,6 +46,7 @@ class TestEmbedderInit:
|
||||
"""CHROMA_ENABLED=false 时 get_collection 返回 None。"""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
import app.services.embedder as emb
|
||||
|
||||
emb._chroma.reset()
|
||||
assert emb.get_collection() is None
|
||||
|
||||
@@ -60,6 +63,7 @@ class TestEmbedderIndexing:
|
||||
"""CHROMA_ENABLED=false 时 index_paper 返回 False。"""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
import app.services.embedder as emb
|
||||
|
||||
emb._chroma.reset()
|
||||
assert emb.index_paper("test-id") is False
|
||||
|
||||
@@ -71,6 +75,7 @@ class TestEmbedderIndexing:
|
||||
monkeypatch.setattr(settings, "EMBED_MODEL", "")
|
||||
|
||||
import app.services.embedder as emb
|
||||
|
||||
emb._chroma.reset()
|
||||
emb.init_chroma()
|
||||
|
||||
@@ -83,6 +88,7 @@ class TestEmbedderIndexing:
|
||||
"""CHROMA_ENABLED=false 时 index_batch 返回全失败。"""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
import app.services.embedder as emb
|
||||
|
||||
emb._chroma.reset()
|
||||
result = emb.index_batch(["a", "b"])
|
||||
assert result["success"] == 0
|
||||
@@ -92,6 +98,7 @@ class TestEmbedderIndexing:
|
||||
"""空列表时返回 0。"""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
import app.services.embedder as emb
|
||||
|
||||
result = emb.index_batch([])
|
||||
assert result["total"] == 0
|
||||
|
||||
@@ -99,6 +106,7 @@ class TestEmbedderIndexing:
|
||||
"""CHROMA_ENABLED=false 时 delete_paper 返回 False。"""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
import app.services.embedder as emb
|
||||
|
||||
emb._chroma.reset()
|
||||
assert emb.delete_paper("test-id") is False
|
||||
|
||||
@@ -106,6 +114,7 @@ class TestEmbedderIndexing:
|
||||
"""CHROMA_ENABLED=false 时 search_similar 返回空列表。"""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
import app.services.embedder as emb
|
||||
|
||||
emb._chroma.reset()
|
||||
assert emb.search_similar("test query") == []
|
||||
|
||||
@@ -123,6 +132,7 @@ class TestEmbeddingApi:
|
||||
monkeypatch.setattr(settings, "EMBED_API_BASE", "")
|
||||
monkeypatch.setattr(settings, "EMBED_MODEL", "")
|
||||
import app.services.embedder as emb
|
||||
|
||||
assert emb._get_embedding("test") is None
|
||||
|
||||
def test_dimension_mismatch_returns_none(self, monkeypatch):
|
||||
@@ -158,6 +168,8 @@ class TestEmbeddingApi:
|
||||
with patch("httpx.Client") as mock_client:
|
||||
mock_client.return_value.__enter__ = MagicMock()
|
||||
mock_client.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_client.return_value.__enter__.return_value.post.side_effect = Exception("timeout")
|
||||
mock_client.return_value.__enter__.return_value.post.side_effect = (
|
||||
Exception("timeout")
|
||||
)
|
||||
result = emb._get_embedding("test")
|
||||
assert result is None
|
||||
|
||||
@@ -16,9 +16,14 @@ class TestImageExtraction:
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_images_from_source_no_dir(self, monkeypatch, tmp_path):
|
||||
"""源码目录不存在时返回 0。"""
|
||||
monkeypatch.setattr("app.services.pdf_downloader.tmp_dir", lambda x: tmp_path / "tmp" / x)
|
||||
monkeypatch.setattr("app.services.pdf_downloader.paper_dir", lambda x: tmp_path / "papers" / x)
|
||||
monkeypatch.setattr(
|
||||
"app.services.pdf_downloader.tmp_dir", lambda x: tmp_path / "tmp" / x
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.services.pdf_downloader.paper_dir", lambda x: tmp_path / "papers" / x
|
||||
)
|
||||
from app.services.image_extractor import extract_images_from_source
|
||||
|
||||
result = await extract_images_from_source("2401.99999")
|
||||
assert result == 0
|
||||
|
||||
@@ -49,14 +54,20 @@ class TestImageExtraction:
|
||||
(tmp_source / "main.tex").write_text(tex_content)
|
||||
|
||||
papers_dir = tmp_path / "papers" / "2401.00001"
|
||||
monkeypatch.setattr("app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x)
|
||||
monkeypatch.setattr("app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x)
|
||||
monkeypatch.setattr(
|
||||
"app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x
|
||||
)
|
||||
|
||||
# Mock download_source_zip to avoid real network call (source dir already exists)
|
||||
async def _noop_download(*args, **kwargs):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr("app.services.image_extractor.download_source_zip", _noop_download)
|
||||
monkeypatch.setattr(
|
||||
"app.services.image_extractor.download_source_zip", _noop_download
|
||||
)
|
||||
|
||||
result = await extract_images_from_source("2401.00001")
|
||||
|
||||
@@ -73,16 +84,24 @@ class TestImageExtraction:
|
||||
|
||||
tmp_source = tmp_path / "tmp" / "2401.00002" / "source"
|
||||
tmp_source.mkdir(parents=True)
|
||||
(tmp_source / "main.tex").write_text(r"\documentclass{article}\begin{document}Hello\end{document}")
|
||||
(tmp_source / "main.tex").write_text(
|
||||
r"\documentclass{article}\begin{document}Hello\end{document}"
|
||||
)
|
||||
|
||||
monkeypatch.setattr("app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x)
|
||||
monkeypatch.setattr("app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x)
|
||||
monkeypatch.setattr(
|
||||
"app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x
|
||||
)
|
||||
|
||||
# Mock download_source_zip to avoid real network call
|
||||
async def _noop_download(*args, **kwargs):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr("app.services.image_extractor.download_source_zip", _noop_download)
|
||||
monkeypatch.setattr(
|
||||
"app.services.image_extractor.download_source_zip", _noop_download
|
||||
)
|
||||
|
||||
result = await extract_images_from_source("2401.00002")
|
||||
assert result == 0
|
||||
|
||||
+15
-5
@@ -162,7 +162,9 @@ class TestComparePage:
|
||||
resp = client.get("/compare?ids=nonexistent.99999")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_compare_page_shows_no_summary_placeholder(self, client, sample_papers_with_summary):
|
||||
def test_compare_page_shows_no_summary_placeholder(
|
||||
self, client, sample_papers_with_summary
|
||||
):
|
||||
"""无总结的论文显示占位文本。"""
|
||||
# 2401.20005 没有 summary(status=pending)
|
||||
resp = client.get("/compare?ids=2401.20005")
|
||||
@@ -198,26 +200,34 @@ class TestNavBar:
|
||||
class TestGracefulDegradation:
|
||||
"""CHROMA_ENABLED=false 时优雅降级测试。"""
|
||||
|
||||
def test_search_works_without_chroma(self, client, monkeypatch, sample_papers_with_summary):
|
||||
def test_search_works_without_chroma(
|
||||
self, client, monkeypatch, sample_papers_with_summary
|
||||
):
|
||||
"""CHROMA 关闭时 FTS5 搜索正常工作。"""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
resp = client.get("/search?q=Test")
|
||||
assert resp.status_code == 200
|
||||
assert "Test Paper" in resp.text or "测试论文" in resp.text
|
||||
|
||||
def test_detail_works_without_chroma(self, client, monkeypatch, sample_papers_with_summary):
|
||||
def test_detail_works_without_chroma(
|
||||
self, client, monkeypatch, sample_papers_with_summary
|
||||
):
|
||||
"""CHROMA 关闭时详情页正常工作。"""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
resp = client.get("/paper/2401.20001")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_trends_works_without_chroma(self, client, monkeypatch, sample_papers_with_summary):
|
||||
def test_trends_works_without_chroma(
|
||||
self, client, monkeypatch, sample_papers_with_summary
|
||||
):
|
||||
"""CHROMA 关闭时趋势看板正常工作。"""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
resp = client.get("/trends")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_compare_works_without_chroma(self, client, monkeypatch, sample_papers_with_summary):
|
||||
def test_compare_works_without_chroma(
|
||||
self, client, monkeypatch, sample_papers_with_summary
|
||||
):
|
||||
"""CHROMA 关闭时对比页正常工作。"""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
resp = client.get("/compare?ids=2401.20001,2401.20002")
|
||||
|
||||
+32
-5
@@ -18,46 +18,54 @@ class TestSearchService:
|
||||
|
||||
def test_search_by_title(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="Test Paper")
|
||||
assert result["total"] == 1
|
||||
assert result["results"][0].arxiv_id == "2401.12345"
|
||||
|
||||
def test_search_by_abstract(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="test abstract")
|
||||
assert result["total"] == 1
|
||||
|
||||
def test_search_by_author(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="Alice")
|
||||
assert result["total"] == 1
|
||||
|
||||
def test_search_by_tag_in_fts(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
# FTS5 索引中包含 tags 列,可以搜到
|
||||
result = search_papers(db_session, query="NLP")
|
||||
assert result["total"] == 1
|
||||
|
||||
def test_search_no_results(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="quantum entanglement")
|
||||
assert result["total"] == 0
|
||||
assert result["results"] == []
|
||||
|
||||
def test_search_empty_query_returns_empty(self, db_session):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="")
|
||||
assert result["total"] == 0
|
||||
assert result["results"] == []
|
||||
|
||||
def test_search_special_characters_sanitized(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
# 特殊字符被清除后,剩下 "Test" 仍然能搜到
|
||||
result = search_papers(db_session, query='Test "Paper" {test}')
|
||||
assert result["total"] >= 1
|
||||
|
||||
def test_search_with_tag_filter(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
# 关键词 + 标签筛选
|
||||
result = search_papers(db_session, query="Paper", tag="NLP")
|
||||
assert result["total"] == 1
|
||||
@@ -67,6 +75,7 @@ class TestSearchService:
|
||||
|
||||
def test_search_tag_only_no_query(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
# 只有标签,无关键词
|
||||
result = search_papers(db_session, tag="NLP")
|
||||
assert result["total"] == 1
|
||||
@@ -74,12 +83,14 @@ class TestSearchService:
|
||||
|
||||
def test_search_pagination(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="Test", page=2, page_size=10)
|
||||
assert result["page"] == 2
|
||||
assert result["total_pages"] == 1 # 只有 1 条结果,1 页
|
||||
|
||||
def test_search_returns_snippets(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="test abstract")
|
||||
assert result["total"] == 1
|
||||
paper_id = result["results"][0].id
|
||||
@@ -89,6 +100,7 @@ class TestSearchService:
|
||||
|
||||
def test_get_all_tags(self, db_session, sample_paper):
|
||||
from app.services.searcher import get_all_tags
|
||||
|
||||
tags = get_all_tags(db_session)
|
||||
assert "NLP" in tags
|
||||
assert "LLM" in tags
|
||||
@@ -105,20 +117,27 @@ class TestSearchSemanticMode:
|
||||
def test_keyword_mode_default(self, db_session, sample_papers_with_summary):
|
||||
"""默认 keyword 模式走 FTS5。"""
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="Test Paper", mode="keyword")
|
||||
assert result["total"] >= 1
|
||||
assert result["distances"] == {}
|
||||
|
||||
def test_semantic_mode_disabled_fallback(self, db_session, monkeypatch, sample_papers_with_summary):
|
||||
def test_semantic_mode_disabled_fallback(
|
||||
self, db_session, monkeypatch, sample_papers_with_summary
|
||||
):
|
||||
"""CHROMA_ENABLED=false + semantic 模式走 FTS5。"""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="Test", mode="semantic")
|
||||
assert result["total"] >= 1
|
||||
|
||||
def test_search_returns_distances_dict(self, db_session, sample_papers_with_summary):
|
||||
def test_search_returns_distances_dict(
|
||||
self, db_session, sample_papers_with_summary
|
||||
):
|
||||
"""搜索结果应包含 distances 字段。"""
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="Test Paper")
|
||||
assert "distances" in result
|
||||
assert isinstance(result["distances"], dict)
|
||||
@@ -126,6 +145,7 @@ class TestSearchSemanticMode:
|
||||
def test_empty_query_returns_empty_no_tags(self, db_session):
|
||||
"""空查询无标签时返回空。"""
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session)
|
||||
assert result["total"] == 0
|
||||
assert result["results"] == []
|
||||
@@ -133,6 +153,7 @@ class TestSearchSemanticMode:
|
||||
def test_tag_only_search(self, db_session, sample_papers_with_summary):
|
||||
"""仅标签搜索。"""
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, tag="NLP")
|
||||
assert result["total"] >= 1
|
||||
|
||||
@@ -169,7 +190,9 @@ class TestSearchRoutes:
|
||||
assert resp.status_code == 200
|
||||
assert "Test" in resp.text or "测试" in resp.text
|
||||
|
||||
def test_search_page_semantic_disabled(self, client, monkeypatch, sample_papers_with_summary):
|
||||
def test_search_page_semantic_disabled(
|
||||
self, client, monkeypatch, sample_papers_with_summary
|
||||
):
|
||||
"""语义模式 CHROMA_ENABLED=false 时仍能工作。"""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
resp = client.get("/search?q=Test&mode=semantic")
|
||||
@@ -221,7 +244,9 @@ class TestSearchRoutes:
|
||||
class TestSimilarAPI:
|
||||
"""相似论文 API 测试。"""
|
||||
|
||||
def test_similar_api_disabled(self, client, monkeypatch, sample_papers_with_summary):
|
||||
def test_similar_api_disabled(
|
||||
self, client, monkeypatch, sample_papers_with_summary
|
||||
):
|
||||
"""CHROMA_ENABLED=false 时返回空列表。"""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
resp = client.get("/api/similar/2401.20001")
|
||||
@@ -236,7 +261,9 @@ class TestSimilarAPI:
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["results"] == []
|
||||
|
||||
def test_similar_api_with_top_k(self, client, monkeypatch, sample_papers_with_summary):
|
||||
def test_similar_api_with_top_k(
|
||||
self, client, monkeypatch, sample_papers_with_summary
|
||||
):
|
||||
"""top_k 参数控制返回数量。"""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
resp = client.get("/api/similar/2401.20001?top_k=3")
|
||||
|
||||
+62
-32
@@ -51,7 +51,9 @@ class TestDbUpdate:
|
||||
assert summary.motivation_problem == schema.motivation.problem
|
||||
assert json.loads(summary.full_json)["title_zh"] == schema.title_zh
|
||||
|
||||
def test_paper_title_zh_updated(self, db_session, sample_paper, sample_summary_dict):
|
||||
def test_paper_title_zh_updated(
|
||||
self, db_session, sample_paper, sample_summary_dict
|
||||
):
|
||||
schema = SummarySchema.model_validate(sample_summary_dict)
|
||||
_update_summary_in_db(db_session, sample_paper, schema, "normal", "raw")
|
||||
|
||||
@@ -85,7 +87,9 @@ class TestDbUpdate:
|
||||
assert "自然语言处理" in tag_names
|
||||
assert "大语言模型" in tag_names
|
||||
|
||||
def test_existing_tags_not_duplicated(self, db_session, sample_paper, sample_summary_dict):
|
||||
def test_existing_tags_not_duplicated(
|
||||
self, db_session, sample_paper, sample_summary_dict
|
||||
):
|
||||
"""已存在的标签名(同 name)不会被 AI source 重复插入。"""
|
||||
# sample_paper 已有 NLP (hf)、LLM (hf)
|
||||
# 让 AI 输出包含 NLP(与 HF 重复)和 "新标签"(新的)
|
||||
@@ -157,7 +161,10 @@ class TestSummarizeOneFlow:
|
||||
def _patch_paths(self, tmp_path):
|
||||
"""将 data 目录重定向到 tmp_path。"""
|
||||
with (
|
||||
patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / "papers" / aid),
|
||||
patch(
|
||||
"app.services.summarizer.paper_dir",
|
||||
lambda aid: tmp_path / "papers" / aid,
|
||||
),
|
||||
patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"),
|
||||
patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"),
|
||||
patch("app.utils.PAPERS_DIR", tmp_path / "papers"),
|
||||
@@ -172,7 +179,11 @@ class TestSummarizeOneFlow:
|
||||
"""pending → processing → done 全流程。"""
|
||||
with (
|
||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||
patch("app.services.summarizer.call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
|
||||
patch(
|
||||
"app.services.summarizer.call_pi",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_pi_output,
|
||||
),
|
||||
):
|
||||
result = await summarize_one(db_session, sample_paper)
|
||||
|
||||
@@ -198,9 +209,7 @@ class TestSummarizeOneFlow:
|
||||
assert fts_row[0] == "测试论文中文标题"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pdf_download_failure(
|
||||
self, db_session, sample_paper, _patch_paths
|
||||
):
|
||||
async def test_pdf_download_failure(self, db_session, sample_paper, _patch_paths):
|
||||
"""PDF 下载失败 → error_type=pdf_download_failed,tmp 被清理。"""
|
||||
with (
|
||||
patch(
|
||||
@@ -256,13 +265,16 @@ class TestSummarizeOneFlow:
|
||||
self, db_session, sample_paper, _patch_paths
|
||||
):
|
||||
"""必填字段缺失 → field_missing → retry → permanent_failure。"""
|
||||
bad_json = json.dumps({
|
||||
"title_zh": "", # 空的必填字段
|
||||
"one_line": "valid line",
|
||||
"tags": ["tag1"],
|
||||
"motivation": {"problem": "valid problem"},
|
||||
"method": {"key_idea": "valid idea"},
|
||||
}, ensure_ascii=False)
|
||||
bad_json = json.dumps(
|
||||
{
|
||||
"title_zh": "", # 空的必填字段
|
||||
"one_line": "valid line",
|
||||
"tags": ["tag1"],
|
||||
"motivation": {"problem": "valid problem"},
|
||||
"method": {"key_idea": "valid idea"},
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
bad_output = f"```json\n{bad_json}\n```"
|
||||
|
||||
with (
|
||||
@@ -314,7 +326,11 @@ class TestSummarizeOneFlow:
|
||||
"""成功后清理 tmp 目录。"""
|
||||
with (
|
||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||
patch("app.services.summarizer.call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
|
||||
patch(
|
||||
"app.services.summarizer.call_pi",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_pi_output,
|
||||
),
|
||||
):
|
||||
await summarize_one(db_session, sample_paper)
|
||||
|
||||
@@ -359,7 +375,10 @@ class TestBatchSummarize:
|
||||
@pytest.fixture
|
||||
def _patch_paths(self, tmp_path):
|
||||
with (
|
||||
patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / "papers" / aid),
|
||||
patch(
|
||||
"app.services.summarizer.paper_dir",
|
||||
lambda aid: tmp_path / "papers" / aid,
|
||||
),
|
||||
patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"),
|
||||
patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"),
|
||||
patch("app.utils.PAPERS_DIR", tmp_path / "papers"),
|
||||
@@ -390,15 +409,18 @@ class TestBatchSummarize:
|
||||
|
||||
# 每个 worker 用独立 session(同一个内存引擎)
|
||||
from sqlalchemy.orm import sessionmaker as _sm
|
||||
|
||||
_TestSession = _sm(bind=db_engine, autoflush=False, autocommit=False)
|
||||
|
||||
with (
|
||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||
patch("app.services.summarizer.call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
|
||||
patch(
|
||||
"app.services.summarizer.call_pi",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_pi_output,
|
||||
),
|
||||
):
|
||||
result = await summarize_batch(
|
||||
db_session, _session_factory=_TestSession
|
||||
)
|
||||
result = await summarize_batch(db_session, _session_factory=_TestSession)
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert result["done"] == 3
|
||||
@@ -432,6 +454,7 @@ class TestBatchSummarize:
|
||||
db_session.commit()
|
||||
|
||||
from sqlalchemy.orm import sessionmaker as _sm
|
||||
|
||||
_TestSession = _sm(bind=db_engine, autoflush=False, autocommit=False)
|
||||
|
||||
call_count = 0
|
||||
@@ -447,9 +470,7 @@ class TestBatchSummarize:
|
||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||
patch("app.services.summarizer.call_pi", side_effect=_mock_call_pi),
|
||||
):
|
||||
result = await summarize_batch(
|
||||
db_session, _session_factory=_TestSession
|
||||
)
|
||||
result = await summarize_batch(db_session, _session_factory=_TestSession)
|
||||
|
||||
assert result["done"] == 1
|
||||
assert result["failed"] == 1
|
||||
@@ -472,23 +493,32 @@ class TestBatchSummarize:
|
||||
assert result["status"] == "conflict"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_lock_released(self, db_session, db_engine, mock_pi_output, _patch_paths):
|
||||
async def test_task_lock_released(
|
||||
self, db_session, db_engine, mock_pi_output, _patch_paths
|
||||
):
|
||||
"""完成后释放 TaskLock。"""
|
||||
from sqlalchemy.orm import sessionmaker as _sm
|
||||
|
||||
_TestSession = _sm(bind=db_engine, autoflush=False, autocommit=False)
|
||||
|
||||
with (
|
||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||
patch("app.services.summarizer.call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
|
||||
patch(
|
||||
"app.services.summarizer.call_pi",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_pi_output,
|
||||
),
|
||||
):
|
||||
await summarize_batch(
|
||||
db_session, _session_factory=_TestSession
|
||||
)
|
||||
await summarize_batch(db_session, _session_factory=_TestSession)
|
||||
|
||||
locks = db_session.query(TaskLock).filter(
|
||||
TaskLock.task == "summarize",
|
||||
TaskLock.lock_key == "batch",
|
||||
).all()
|
||||
locks = (
|
||||
db_session.query(TaskLock)
|
||||
.filter(
|
||||
TaskLock.task == "summarize",
|
||||
TaskLock.lock_key == "batch",
|
||||
)
|
||||
.all()
|
||||
)
|
||||
for lock in locks:
|
||||
assert lock.status == "finished"
|
||||
assert lock.released_at is not None
|
||||
|
||||
Reference in New Issue
Block a user