Files
daily-paper/tests/conftest.py
T
Rain-Bus f7f1a4c0cb refactor: split monolithic phase tests into per-module test files
- rename test_admin_phase4.py -> test_admin.py, test_search.py -> test_searcher.py
- split test_phase5.py into test_cleaner, test_embedder, test_image_extractor, test_pages
- move schema tests from test_summarizer.py into dedicated test_schemas.py
- add sample_papers_range and sample_papers_with_summary fixtures in conftest
- update .gitignore to exclude all of data/
2026-06-06 00:34:30 +08:00

331 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""测试 fixtures — 内存 SQLite、TestClient、样例数据。"""
from __future__ import annotations

import json
from datetime import date, datetime, timezone
from pathlib import Path
from unittest.mock import AsyncMock

import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine, event, text
from sqlalchemy.orm import DeclarativeBase, sessionmaker
from sqlalchemy.pool import StaticPool

from app.database import get_db
from app.database import init_db
from app.main import create_app
from app.models import (
    Paper,
    PaperAuthor,
    PaperSummary,
    PaperTag,
    SummaryStatus,
)
# ── In-memory database ──────────────────────────────────────────────────
class _TestBase(DeclarativeBase):
    # Test-local declarative base. Its ``metadata`` attribute is rebound
    # below so that it points at the application's metadata object.
    pass


# Reuse the application's Base metadata. The original note attributes it to
# app.models, but the Base is imported from app.database here.
# NOTE(review): nothing visible in this file uses _TestBase afterwards —
# confirm whether other test modules rely on it.
from app.database import Base as _AppBase  # noqa: E402

_TestBase.metadata = _AppBase.metadata
@pytest.fixture
def db_engine():
    """Yield an in-memory SQLite engine initialised through ``init_db``.

    The engine uses ``StaticPool`` together with ``check_same_thread=False``
    so that every checkout reuses the same underlying connection; otherwise
    each new connection would see its own empty ``:memory:`` database.

    ``init_db`` is expected to create the schema (the original docstring
    mentions FTS5 as part of it — confirm in ``app.database``).

    The engine is disposed on teardown so the connection does not outlive
    the test.
    """
    engine = create_engine(
        "sqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )

    # Registered before init_db so the pragma is already in effect when the
    # first connection is opened.
    @event.listens_for(engine, "connect")
    def _enable_foreign_keys(dbapi_connection, _record):
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute("PRAGMA foreign_keys=ON")
        finally:
            cursor.close()

    init_db(engine)
    yield engine
    # Teardown: release the pooled connection held by StaticPool.
    engine.dispose()
@pytest.fixture
def db_session(db_engine):
    """Yield a SQLAlchemy session bound to ``db_engine``; close it on teardown.

    The session is created with ``autoflush=False`` and ``autocommit=False``.
    No rollback wrapper is applied: anything committed through this session
    stays in the database for as long as the engine fixture lives.
    """
    session_factory = sessionmaker(bind=db_engine, autoflush=False, autocommit=False)
    # Leaving the ``with`` block closes the session, including on error.
    with session_factory() as session:
        yield session
@pytest.fixture
def client(db_engine, db_session):
    """Yield a FastAPI ``TestClient`` whose ``get_db`` dependency is overridden.

    Every request handled by the client receives the test's ``db_session``
    instead of a session produced by the real ``get_db``. Server-side
    exceptions are not re-raised in the test
    (``raise_server_exceptions=False``). The overrides are cleared once the
    client has been closed.
    """
    application = create_app()

    def _yield_test_session():
        # Generator-style dependency, mirroring how the override was
        # originally written.
        yield db_session

    application.dependency_overrides[get_db] = _yield_test_session
    with TestClient(application, raise_server_exceptions=False) as test_client:
        yield test_client
    application.dependency_overrides.clear()
# ── Sample data ─────────────────────────────────────────────────────────
# arXiv identifier used for the single paper inserted by `sample_paper`
# (also interpolated into that paper's HF / arXiv / PDF URLs).
SAMPLE_ARXIV_ID = "2401.12345"
# Value handed out by the `admin_token` fixture and embedded in the
# Authorization header built by `admin_headers`.
ADMIN_TOKEN = "test-admin-token-12345"
@pytest.fixture
def sample_paper(db_session):
    """Insert one test paper and return the committed ``Paper``.

    Besides the paper itself the fixture creates two authors, two tags with
    source ``"hf"``, a ``SummaryStatus`` row with status ``"pending"``, and
    the matching ``papers_fts`` row.
    """
    author_names = ["Alice Smith", "Bob Jones"]
    tag_names = ["NLP", "LLM"]

    paper = Paper(
        arxiv_id=SAMPLE_ARXIV_ID,
        title_en="Test Paper Title",
        abstract="This is a test abstract for the paper.",
        published_at=date(2024, 1, 15),
        paper_date=date(2024, 1, 15),
        crawled_at=datetime.now(timezone.utc),
        upvotes=42,
        hf_url=f"https://huggingface.co/papers/{SAMPLE_ARXIV_ID}",
        arxiv_url=f"https://arxiv.org/abs/{SAMPLE_ARXIV_ID}",
        pdf_url=f"https://arxiv.org/pdf/{SAMPLE_ARXIV_ID}.pdf",
    )
    db_session.add(paper)
    # Flush explicitly so paper.id is available for the child rows below.
    db_session.flush()

    db_session.add_all(
        [
            PaperAuthor(paper_id=paper.id, name=name, position=position)
            for position, name in enumerate(author_names)
        ]
    )
    db_session.add_all(
        [PaperTag(paper_id=paper.id, tag=tag, source="hf") for tag in tag_names]
    )
    db_session.add(SummaryStatus(paper_id=paper.id, status="pending"))

    # Seed the FTS5 row by hand (the original note says this mirrors what the
    # crawler does). The author and tag strings are derived from the same
    # lists used for the ORM rows so the two cannot drift apart.
    db_session.execute(
        text(
            "INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) "
            "VALUES (:id, :title, :abstract, :authors, :tags)"
        ),
        {
            "id": paper.id,
            "title": paper.title_en,
            "abstract": paper.abstract or "",
            "authors": ", ".join(author_names),
            "tags": ", ".join(tag_names),
        },
    )
    db_session.commit()
    return paper
@pytest.fixture
def sample_summary_dict() -> dict:
    """Return a fully populated summary payload as a plain dict.

    The payload is meant to be a complete, valid summary (per the original
    docstring). Key order matches the original literal, so serialising it
    produces the same text.
    """
    prerequisites = {
        "concepts": ["Transformer", "注意力机制"],
        "level": "中级",
    }
    motivation = {
        "problem": "现有模型在长文本理解上存在不足。",
        "goal": "提出一种新的注意力机制来提升长文本建模能力。",
        "gap": "当前方法计算复杂度过高。",
    }
    method = {
        "overview": "提出了一种高效的稀疏注意力机制。",
        "key_idea": "使用局部-全局混合的注意力模式来降低计算复杂度。",
        "steps": [
            "分析现有注意力机制的瓶颈",
            "设计稀疏注意力模式",
            "在多个基准上验证效果",
        ],
        "novelty": "首次将局部-全局注意力模式结合应用于长文本建模。",
    }
    results = {
        "main_findings": [
            "在长文本基准上取得了 SOTA 结果",
            "推理速度提升了 2 倍",
        ],
        "benchmarks": [
            {"dataset": "LongBench", "score": 85.3},
        ],
        "limitations": [
            "在超长文本(>100k tokens)上效果有所下降",
        ],
    }
    improvements = {
        "weaknesses": ["仅验证了英文数据"],
        "future_work": ["扩展到多语言场景"],
        "reproducibility": "代码已开源,模型权重可下载。",
    }

    summary = {
        "title_zh": "测试论文中文标题",
        "one_line": "这是一篇关于自然语言处理的测试论文的一句话总结。",
        "tags": ["自然语言处理", "大语言模型", "Transformer"],
        "difficulty": "中级",
        "prerequisites": prerequisites,
        "motivation": motivation,
        "method": method,
        "results": results,
        "improvements": improvements,
    }
    return summary
@pytest.fixture
def sample_summary_json(sample_summary_dict) -> str:
    """Serialise ``sample_summary_dict`` to a JSON string.

    Non-ASCII characters are kept as-is and the output is indented by two
    spaces.
    """
    serialised = json.dumps(
        sample_summary_dict,
        ensure_ascii=False,
        indent=2,
    )
    return serialised
@pytest.fixture
def mock_pi_output(sample_summary_json) -> str:
    """Return text imitating full ``pi`` CLI output with the JSON embedded.

    The summary JSON sits inside a fenced ``json`` block, between an
    introductory line and a closing line.
    """
    output_lines = (
        "以下是论文的深度解读:",
        "```json",
        sample_summary_json,
        "```",
        "希望这个总结对你有帮助!",
    )
    return "\n".join(output_lines)
@pytest.fixture
def admin_token():
    """Return the module-level ``ADMIN_TOKEN`` test value.

    This fixture only returns the string. Per the original docstring, tests
    are expected to combine it with ``monkeypatch`` so the application
    accepts it.
    """
    token = ADMIN_TOKEN
    return token
@pytest.fixture
def admin_headers(admin_token):
    """Return request headers carrying ``admin_token`` as a Bearer credential."""
    authorization = "Bearer {}".format(admin_token)
    return {"Authorization": authorization}
@pytest.fixture
def wrong_admin_headers():
    """Return request headers whose Bearer token is deliberately wrong."""
    headers = {"Authorization": "Bearer wrong-token"}
    return headers
# ── 多样例数据 ────────────────────────────────────────────────────────────
@pytest.fixture
def sample_papers_range(db_session):
    """Insert five papers dated 2024-01-10 through 2024-01-14 and return them.

    Intended for the admin / cleaner tests (per the original docstring).
    Each paper gets one author, one tag with source ``"hf"``, a
    ``SummaryStatus`` row with status ``"pending"`` and a ``papers_fts`` row.
    Upvotes are 0, 10, 20, 30, 40 in insertion order.

    Returns:
        The list of committed ``Paper`` objects, in insertion order.
    """
    crawled_at = datetime.now(timezone.utc)
    paper_specs = [
        ("2401.10001", "2024-01-10"),
        ("2401.10002", "2024-01-11"),
        ("2401.10003", "2024-01-12"),
        ("2401.10004", "2024-01-13"),
        ("2401.10005", "2024-01-14"),
    ]
    # Built once: the statement is identical for every paper.
    insert_fts_row = text(
        "INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) "
        "VALUES (:id, :title, :abstract, :authors, :tags)"
    )

    papers = []
    for index, (arxiv_id, paper_date_str) in enumerate(paper_specs):
        number = index + 1  # 1-based label used in titles, authors and tags
        author_name = f"Author {number}"
        tag_name = f"Tag{number}"

        paper = Paper(
            arxiv_id=arxiv_id,
            title_en=f"Test Paper {number}",
            abstract=f"Abstract for paper {number}.",
            paper_date=date.fromisoformat(paper_date_str),
            crawled_at=crawled_at,
            upvotes=index * 10,
        )
        db_session.add(paper)
        # Flush explicitly so paper.id is available for the child rows below.
        db_session.flush()

        db_session.add(PaperAuthor(paper_id=paper.id, name=author_name, position=0))
        db_session.add(PaperTag(paper_id=paper.id, tag=tag_name, source="hf"))
        db_session.add(SummaryStatus(paper_id=paper.id, status="pending"))

        # FTS5 row matching the paper just inserted.
        db_session.execute(
            insert_fts_row,
            {
                "id": paper.id,
                "title": paper.title_en,
                "abstract": paper.abstract,
                "authors": author_name,
                "tags": tag_name,
            },
        )
        papers.append(paper)

    db_session.commit()
    return papers
@pytest.fixture
def sample_papers_with_summary(db_session):
    """Insert five papers, the first four with summaries, and return them.

    Intended for the search / pages / trends tests (per the original
    docstring). Papers are dated 2024-01-10 through 2024-01-14 and each gets
    one author, a shared ``"NLP"`` tag plus one paper-specific tag, and a
    ``papers_fts`` row. The first four papers have ``SummaryStatus`` status
    ``"done"`` and a ``PaperSummary`` row; the fifth is ``"pending"`` with
    no summary. Upvotes are 5, 15, 25, 35, 45 in insertion order.

    Returns:
        The list of committed ``Paper`` objects, in insertion order.
    """
    now = datetime.now(timezone.utc)
    paper_specs = [
        ("2401.20001", "2024-01-10"),
        ("2401.20002", "2024-01-11"),
        ("2401.20003", "2024-01-12"),
        ("2401.20004", "2024-01-13"),
        ("2401.20005", "2024-01-14"),
    ]
    summarised_count = 4  # only the first four papers receive a summary
    # Built once: the statement is identical for every paper.
    insert_fts_row = text(
        "INSERT INTO papers_fts(rowid, title_en, title_zh, abstract, authors, tags) "
        "VALUES (:id, :title_en, :title_zh, :abstract, :authors, :tags)"
    )

    papers = []
    for index, (arxiv_id, paper_date_str) in enumerate(paper_specs):
        number = index + 1  # 1-based label used in titles, authors and tags
        has_summary = index < summarised_count
        author_name = f"Author {number}"
        own_tag = f"Tag{number}"

        paper = Paper(
            arxiv_id=arxiv_id,
            title_en=f"Test Paper {number}",
            title_zh=f"测试论文 {number}",
            abstract=f"Abstract for paper {number}.",
            paper_date=date.fromisoformat(paper_date_str),
            crawled_at=now,
            upvotes=index * 10 + 5,
        )
        db_session.add(paper)
        # Flush explicitly so paper.id is available for the child rows below.
        db_session.flush()

        db_session.add(PaperAuthor(paper_id=paper.id, name=author_name, position=0))
        db_session.add(PaperTag(paper_id=paper.id, tag="NLP", source="hf"))
        db_session.add(PaperTag(paper_id=paper.id, tag=own_tag, source="hf"))
        db_session.add(
            SummaryStatus(
                paper_id=paper.id,
                status="done" if has_summary else "pending",
                quality="normal",
            )
        )

        if has_summary:
            db_session.add(
                PaperSummary(
                    paper_id=paper.id,
                    one_line=f"这是论文{number}的一句话摘要",
                    difficulty="中级",
                    motivation_problem=f"论文{number}的研究问题",
                    motivation_goal=f"论文{number}的研究目标",
                    method_key_idea=f"论文{number}的关键思路",
                    method_overview=f"论文{number}的方法概述",
                    updated_at=now,
                    full_json=json.dumps({"title_zh": f"测试论文 {number}"}),
                )
            )

        # FTS5 row matching the paper just inserted.
        db_session.execute(
            insert_fts_row,
            {
                "id": paper.id,
                "title_en": paper.title_en,
                "title_zh": paper.title_zh or "",
                "abstract": paper.abstract or "",
                "authors": author_name,
                "tags": f"NLP, {own_tag}",
            },
        )
        papers.append(paper)

    db_session.commit()
    return papers