212 lines
6.4 KiB
Python
212 lines
6.4 KiB
Python
"""测试 fixtures — 内存 SQLite、TestClient、样例数据。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
from datetime import date, datetime, timezone
|
||
from pathlib import Path
|
||
from unittest.mock import AsyncMock
|
||
|
||
import pytest
|
||
from fastapi.testclient import TestClient
|
||
from sqlalchemy import create_engine, event
|
||
from sqlalchemy.orm import DeclarativeBase, sessionmaker
|
||
from sqlalchemy.pool import StaticPool
|
||
|
||
from app.database import get_db
|
||
from app.main import create_app
|
||
from app.models import (
|
||
Paper,
|
||
PaperAuthor,
|
||
PaperSummary,
|
||
PaperTag,
|
||
SummaryStatus,
|
||
init_db,
|
||
)
|
||
|
||
|
||
# ── 内存数据库 ──────────────────────────────────────────────────────────
|
||
|
||
|
||
class _TestBase(DeclarativeBase):
    """Declarative base that is re-pointed at the application's ``MetaData`` below.

    NOTE(review): nothing visible in this conftest references ``_TestBase``;
    the test schema is built by ``init_db(engine)`` inside ``db_engine``.
    Presumably kept for other test modules — confirm before removing.
    """


# Share the application's declarative ``MetaData`` (imported from
# ``app.database``, the same module that provides ``get_db``) so both bases
# describe the same tables.
from app.database import Base as _AppBase  # noqa: E402

_TestBase.metadata = _AppBase.metadata
|
||
|
||
|
||
@pytest.fixture
def db_engine():
    """Create a fresh in-memory SQLite engine for one test; dispose it on teardown.

    The engine is configured so the in-memory database survives for the
    whole test:

    * ``StaticPool`` hands out the single same DBAPI connection for every
      checkout. Otherwise each new connection would open a brand-new, empty
      ``:memory:`` database.
    * ``check_same_thread=False`` lets that connection be used from threads
      other than the one that created it (e.g. a ``TestClient`` worker).
    * A ``connect`` listener enables ``PRAGMA foreign_keys=ON``, because
      SQLite does not enforce foreign keys by default.

    ``init_db(engine)`` then builds the schema. Presumably this includes the
    ``papers_fts`` FTS5 table that ``sample_paper`` writes to; confirm in
    ``app.models``.

    Yields:
        The configured ``Engine``. It is disposed in ``finally``, so the pooled
        connection (and with it the in-memory database) is released
        deterministically. It is also released when ``init_db`` fails during
        setup.
    """
    engine = create_engine(
        "sqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )

    @event.listens_for(engine, "connect")
    def _pragma(dbapi_connection, _record):
        # Runs for each new DBAPI connection; with StaticPool there is only one.
        cursor = dbapi_connection.cursor()
        cursor.execute("PRAGMA foreign_keys=ON")
        cursor.close()

    try:
        init_db(engine)
        yield engine
    finally:
        # Previously the engine was returned and never disposed, leaking the
        # StaticPool connection until garbage collection.
        engine.dispose()
|
||
|
||
|
||
@pytest.fixture
def db_session(db_engine):
    """Yield a SQLAlchemy session bound to the per-test engine; close it on teardown.

    Isolation between tests comes from ``db_engine`` being a fresh in-memory
    database for every test. This fixture does not roll anything back, so
    commits made during a test persist until that engine goes away.

    Autoflush is disabled, so pending objects reach the database only on an
    explicit ``flush()`` or ``commit()``.
    """
    session_factory = sessionmaker(bind=db_engine, autoflush=False, autocommit=False)
    # Leaving the ``with`` closes the session, even if teardown is interrupted.
    with session_factory() as session:
        yield session
|
||
|
||
|
||
@pytest.fixture
def client(db_engine, db_session):
    """Yield a FastAPI ``TestClient`` whose ``get_db`` dependency returns the test session.

    Every request made during the test receives the very same ``db_session``
    object the test uses. Data inserted by other fixtures is therefore
    visible to the app.

    ``raise_server_exceptions=False`` makes unhandled server errors come back
    as error responses instead of being re-raised inside the test.

    ``db_engine`` is not used directly; it is only requested as a fixture
    dependency.
    """
    application = create_app()

    def _yield_shared_session():
        # Dependency generator: hands the test's session to route handlers.
        yield db_session

    application.dependency_overrides[get_db] = _yield_shared_session

    # Using TestClient as a context manager also runs the app's startup and
    # shutdown (lifespan) handlers around the test.
    with TestClient(application, raise_server_exceptions=False) as test_client:
        yield test_client

    application.dependency_overrides.clear()
|
||
|
||
|
||
# ── 样例数据 ────────────────────────────────────────────────────────────
|
||
|
||
# arXiv identifier of the paper inserted by the ``sample_paper`` fixture; also
# used to build its Hugging Face / arXiv / PDF URLs.
SAMPLE_ARXIV_ID = "2401.12345"

# Admin token for authenticated-endpoint tests. The app must be made to expect
# this value (e.g. via monkeypatch) for requests using it to be accepted.
ADMIN_TOKEN = "test-admin-token-12345"
|
||
|
||
|
||
@pytest.fixture
def sample_paper(db_session):
    """Insert one test paper with two authors, two tags and a pending summary status.

    The fixture also writes the matching ``papers_fts`` full-text row. Author
    and tag values are defined once and reused for both the ORM rows and the
    FTS row, so the two cannot drift apart.

    Args:
        db_session: The per-test SQLAlchemy session.

    Returns:
        The committed ``Paper`` instance.
    """
    # Function-local import replaces the former ``__import__("sqlalchemy").text``.
    from sqlalchemy import text

    authors = ["Alice Smith", "Bob Jones"]
    tags = ["NLP", "LLM"]
    now = datetime.now(timezone.utc)

    paper = Paper(
        arxiv_id=SAMPLE_ARXIV_ID,
        title_en="Test Paper Title",
        abstract="This is a test abstract for the paper.",
        published_at=date(2024, 1, 15),
        paper_date=date(2024, 1, 15),
        crawled_at=now,
        upvotes=42,
        hf_url=f"https://huggingface.co/papers/{SAMPLE_ARXIV_ID}",
        arxiv_url=f"https://arxiv.org/abs/{SAMPLE_ARXIV_ID}",
        pdf_url=f"https://arxiv.org/pdf/{SAMPLE_ARXIV_ID}.pdf",
    )
    db_session.add(paper)
    # Flush so ``paper.id`` is populated before the child rows reference it.
    db_session.flush()

    for position, name in enumerate(authors):
        db_session.add(
            PaperAuthor(paper_id=paper.id, name=name, position=position)
        )
    for tag in tags:
        db_session.add(PaperTag(paper_id=paper.id, tag=tag, source="hf"))

    db_session.add(SummaryStatus(paper_id=paper.id, status="pending"))

    # Seed the FTS5 row keyed by the paper's id. Authors and tags are stored
    # comma-separated; the original note says this matches the crawler.
    db_session.execute(
        text(
            "INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) "
            "VALUES (:id, :title, :abstract, :authors, :tags)"
        ),
        {
            "id": paper.id,
            "title": paper.title_en,
            "abstract": paper.abstract or "",
            "authors": ", ".join(authors),
            "tags": ", ".join(tags),
        },
    )
    db_session.commit()
    return paper
|
||
|
||
|
||
@pytest.fixture
def sample_summary_dict() -> dict:
    """Return a complete, valid paper-summary dict (content in Chinese).

    Top-level keys, in order:

    * ``title_zh``, ``one_line``, ``tags``, ``difficulty``;
    * the nested sections ``prerequisites``, ``motivation``, ``method``,
      ``results`` and ``improvements``.

    A new dict is built on every call, so tests may mutate it freely.
    """
    prerequisites = {
        "concepts": ["Transformer", "注意力机制"],
        "level": "中级",
    }
    motivation = {
        "problem": "现有模型在长文本理解上存在不足。",
        "goal": "提出一种新的注意力机制来提升长文本建模能力。",
        "gap": "当前方法计算复杂度过高。",
    }
    method = {
        "overview": "提出了一种高效的稀疏注意力机制。",
        "key_idea": "使用局部-全局混合的注意力模式来降低计算复杂度。",
        "steps": [
            "分析现有注意力机制的瓶颈",
            "设计稀疏注意力模式",
            "在多个基准上验证效果",
        ],
        "novelty": "首次将局部-全局注意力模式结合应用于长文本建模。",
    }
    results = {
        "main_findings": [
            "在长文本基准上取得了 SOTA 结果",
            "推理速度提升了 2 倍",
        ],
        "benchmarks": [{"dataset": "LongBench", "score": 85.3}],
        "limitations": ["在超长文本(>100k tokens)上效果有所下降"],
    }
    improvements = {
        "weaknesses": ["仅验证了英文数据"],
        "future_work": ["扩展到多语言场景"],
        "reproducibility": "代码已开源,模型权重可下载。",
    }

    # Key order matters: sample_summary_json serializes this dict as-is.
    return {
        "title_zh": "测试论文中文标题",
        "one_line": "这是一篇关于自然语言处理的测试论文的一句话总结。",
        "tags": ["自然语言处理", "大语言模型", "Transformer"],
        "difficulty": "中级",
        "prerequisites": prerequisites,
        "motivation": motivation,
        "method": method,
        "results": results,
        "improvements": improvements,
    }
|
||
|
||
|
||
@pytest.fixture
def sample_summary_json(sample_summary_dict) -> str:
    """Serialize ``sample_summary_dict`` to pretty-printed JSON text.

    ``ensure_ascii=False`` keeps the Chinese text as literal characters
    instead of escape sequences. ``indent=2`` produces multi-line output.
    """
    dump_options = {"ensure_ascii": False, "indent": 2}
    return json.dumps(sample_summary_dict, **dump_options)
|
||
|
||
|
||
@pytest.fixture
def mock_pi_output(sample_summary_json) -> str:
    """Simulate the complete stdout of the pi CLI.

    The output is the summary JSON inside a ```json fenced block, surrounded
    by a Chinese preamble line and a closing remark. Blank lines separate the
    fence from the surrounding text. Useful for testing JSON extraction from
    chatty model output.
    """
    return (
        "以下是论文的深度解读:\n"
        "\n"
        "```json\n"
        f"{sample_summary_json}\n"
        "```\n"
        "\n"
        "希望这个总结对你有帮助!"
    )
|
||
|
||
|
||
@pytest.fixture
def admin_token():
    """Provide the module-level ``ADMIN_TOKEN`` test credential.

    The token is only accepted if the test also makes the application expect
    it (e.g. by monkeypatching the app's configured admin token).
    """
    token = ADMIN_TOKEN
    return token
|
||
|
||
|
||
@pytest.fixture
def admin_headers(admin_token):
    """Build request headers carrying ``admin_token`` as a Bearer credential."""
    return dict(Authorization=f"Bearer {admin_token}")
|