Files
daily-paper/tests/conftest.py
T
Rain-Bus 85c4cfb9e8 refactor: restructure services and add image/pdf extraction utilities
- Add image_extractor, pdf_downloader, pi_client, trends services
- Add shared utils module
- Refactor summarizer, embedder, routes for cleaner separation
- Update tests to match new service structure
2026-06-06 00:00:55 +08:00

212 lines
6.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""测试 fixtures — 内存 SQLite、TestClient、样例数据。"""
from __future__ import annotations
import json
from datetime import date, datetime, timezone
from pathlib import Path
from unittest.mock import AsyncMock
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine, event
from sqlalchemy.orm import DeclarativeBase, sessionmaker
from sqlalchemy.pool import StaticPool
from app.database import get_db
from app.main import create_app
from app.database import init_db
from app.models import (
Paper,
PaperAuthor,
PaperSummary,
PaperTag,
SummaryStatus,
)
# ── In-memory database ──────────────────────────────────────────────────
class _TestBase(DeclarativeBase):
    """Test-local declarative base; its ``metadata`` is rebound right below."""

    pass


# Reuse the metadata of the application's Base (imported from app.database)
# so that this test base refers to the very same MetaData object.
from app.database import Base as _AppBase  # noqa: E402

_TestBase.metadata = _AppBase.metadata
@pytest.fixture
def db_engine():
    """Yield an in-memory SQLite engine prepared by ``init_db``.

    The engine uses ``StaticPool`` with ``check_same_thread=False`` so that
    every user of the engine shares one underlying connection (and therefore
    one in-memory database), even across threads.  Foreign-key enforcement is
    switched on for each new DBAPI connection.

    ``init_db`` is presumably responsible for creating the schema, including
    the ``papers_fts`` table that ``sample_paper`` writes to — confirm in
    ``app.database``.

    The engine is disposed on teardown so the pooled connection is released
    after each test instead of lingering until garbage collection.
    """
    engine = create_engine(
        "sqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )

    @event.listens_for(engine, "connect")
    def _pragma(dbapi_connection, _record):
        # SQLite leaves foreign keys off by default; enable them per connection.
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute("PRAGMA foreign_keys=ON")
        finally:
            cursor.close()

    init_db(engine)
    try:
        yield engine
    finally:
        # Close the single pooled connection, discarding the in-memory DB.
        engine.dispose()
@pytest.fixture
def db_session(db_engine):
    """Yield a SQLAlchemy session bound to the per-test engine.

    The session is created with ``autoflush=False`` and ``autocommit=False``
    and is closed when the test finishes.  Tests do not see each other's data
    because ``db_engine`` builds a fresh in-memory database for every test.
    """
    session_factory = sessionmaker(bind=db_engine, autoflush=False, autocommit=False)
    # Leaving the ``with`` block closes the session, whether or not an
    # exception was raised while it was in use.
    with session_factory() as session:
        yield session
@pytest.fixture
def client(db_engine, db_session):
    """Yield a FastAPI ``TestClient`` whose ``get_db`` dependency is overridden.

    Every request handled by the client receives the same ``db_session`` the
    test itself uses, so rows inserted by fixtures are visible to the routes.
    The client is built with ``raise_server_exceptions=False``.  The overrides
    are cleared once the client context has exited.
    """
    application = create_app()

    def _provide_test_session():
        # Hand out the fixture's session instead of opening a new one.
        yield db_session

    application.dependency_overrides[get_db] = _provide_test_session
    with TestClient(application, raise_server_exceptions=False) as test_client:
        yield test_client
    application.dependency_overrides.clear()
# ── Sample data ─────────────────────────────────────────────────────────
# arXiv identifier of the paper inserted by the ``sample_paper`` fixture.
SAMPLE_ARXIV_ID = "2401.12345"
# Token value handed out by the ``admin_token`` fixture.
ADMIN_TOKEN = "test-admin-token-12345"
@pytest.fixture
def sample_paper(db_session):
    """Insert one sample paper with its related rows and return it.

    Besides the ``Paper`` itself the fixture adds two authors, two tags, a
    ``SummaryStatus`` row with status ``"pending"`` and a matching row in the
    ``papers_fts`` table, then commits.

    Returns:
        The persisted ``Paper`` instance.
    """
    # Imported locally so the fixture does not rely on the ``__import__``
    # trick and the file's top-level import block stays untouched.
    from sqlalchemy import text

    # Single source of truth for the related rows and the FTS text columns.
    author_names = ["Alice Smith", "Bob Jones"]
    tag_names = ["NLP", "LLM"]

    now = datetime.now(timezone.utc)
    paper = Paper(
        arxiv_id=SAMPLE_ARXIV_ID,
        title_en="Test Paper Title",
        abstract="This is a test abstract for the paper.",
        published_at=date(2024, 1, 15),
        paper_date=date(2024, 1, 15),
        crawled_at=now,
        upvotes=42,
        hf_url=f"https://huggingface.co/papers/{SAMPLE_ARXIV_ID}",
        arxiv_url=f"https://arxiv.org/abs/{SAMPLE_ARXIV_ID}",
        pdf_url=f"https://arxiv.org/pdf/{SAMPLE_ARXIV_ID}.pdf",
    )
    db_session.add(paper)
    # Flush so that ``paper.id`` is populated before the child rows use it.
    db_session.flush()

    for position, name in enumerate(author_names):
        db_session.add(PaperAuthor(paper_id=paper.id, name=name, position=position))
    for tag in tag_names:
        db_session.add(PaperTag(paper_id=paper.id, tag=tag, source="hf"))
    db_session.add(SummaryStatus(paper_id=paper.id, status="pending"))

    # Initial FTS5 row; the original note says this matches what the crawler
    # writes — NOTE(review): confirm against the crawler implementation.
    db_session.execute(
        text(
            "INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) "
            "VALUES (:id, :title, :abstract, :authors, :tags)"
        ),
        {
            "id": paper.id,
            "title": paper.title_en,
            "abstract": paper.abstract or "",
            "authors": ", ".join(author_names),
            "tags": ", ".join(tag_names),
        },
    )
    db_session.commit()
    return paper
@pytest.fixture
def sample_summary_dict() -> dict:
    """Return a fully populated summary dictionary.

    The top-level keys are, in order: ``title_zh``, ``one_line``, ``tags``,
    ``difficulty``, ``prerequisites``, ``motivation``, ``method``,
    ``results`` and ``improvements``.  A new dictionary is built on every
    call.
    """
    prerequisites = {
        "concepts": ["Transformer", "注意力机制"],
        "level": "中级",
    }
    motivation = {
        "problem": "现有模型在长文本理解上存在不足。",
        "goal": "提出一种新的注意力机制来提升长文本建模能力。",
        "gap": "当前方法计算复杂度过高。",
    }
    method_steps = [
        "分析现有注意力机制的瓶颈",
        "设计稀疏注意力模式",
        "在多个基准上验证效果",
    ]
    method = {
        "overview": "提出了一种高效的稀疏注意力机制。",
        "key_idea": "使用局部-全局混合的注意力模式来降低计算复杂度。",
        "steps": method_steps,
        "novelty": "首次将局部-全局注意力模式结合应用于长文本建模。",
    }
    results = {
        "main_findings": [
            "在长文本基准上取得了 SOTA 结果",
            "推理速度提升了 2 倍",
        ],
        "benchmarks": [
            {"dataset": "LongBench", "score": 85.3},
        ],
        "limitations": [
            "在超长文本(>100k tokens)上效果有所下降",
        ],
    }
    improvements = {
        "weaknesses": ["仅验证了英文数据"],
        "future_work": ["扩展到多语言场景"],
        "reproducibility": "代码已开源,模型权重可下载。",
    }

    # Assemble in the same key order as before; JSON output depends on it.
    summary = {
        "title_zh": "测试论文中文标题",
        "one_line": "这是一篇关于自然语言处理的测试论文的一句话总结。",
        "tags": ["自然语言处理", "大语言模型", "Transformer"],
        "difficulty": "中级",
        "prerequisites": prerequisites,
        "motivation": motivation,
        "method": method,
        "results": results,
        "improvements": improvements,
    }
    return summary
@pytest.fixture
def sample_summary_json(sample_summary_dict) -> str:
    """Serialise ``sample_summary_dict`` to a JSON string.

    Non-ASCII characters are written verbatim (``ensure_ascii=False``) and
    the output is indented by two spaces.
    """
    serialised = json.dumps(sample_summary_dict, indent=2, ensure_ascii=False)
    return serialised
@pytest.fixture
def mock_pi_output(sample_summary_json) -> str:
    """Return text imitating the pi CLI's full output with the JSON embedded.

    The sample JSON is wrapped in a Markdown ``json`` code fence, preceded by
    one introductory line and followed by one closing line.
    """
    intro_line = "以下是论文的深度解读:"
    closing_line = "希望这个总结对你有帮助!"
    pieces = [
        intro_line,
        "```json",
        f"{sample_summary_json}",
        "```",
        closing_line,
    ]
    return "\n".join(pieces)
@pytest.fixture
def admin_token():
    """Provide the admin token used in tests.

    The original note says it is meant to be used together with
    ``monkeypatch``; this fixture itself only returns the module-level
    ``ADMIN_TOKEN`` constant.
    """
    token = ADMIN_TOKEN
    return token
@pytest.fixture
def admin_headers(admin_token):
    """Build request headers carrying ``admin_token`` as a Bearer credential."""
    bearer_value = f"Bearer {admin_token}"
    headers = {"Authorization": bearer_value}
    return headers