"""测试 fixtures — 内存 SQLite、TestClient、样例数据。""" from __future__ import annotations import json from datetime import date, datetime, timezone from pathlib import Path from unittest.mock import AsyncMock import pytest from fastapi.testclient import TestClient from sqlalchemy import create_engine, event from sqlalchemy.orm import DeclarativeBase, sessionmaker from sqlalchemy.pool import StaticPool from app.database import get_db from app.main import create_app from app.database import init_db from app.models import ( Paper, PaperAuthor, PaperSummary, PaperTag, SummaryStatus, ) # ── 内存数据库 ────────────────────────────────────────────────────────── class _TestBase(DeclarativeBase): pass # 复用 app.models 的 Base metadata from app.database import Base as _AppBase # noqa: E402 _TestBase.metadata = _AppBase.metadata @pytest.fixture def db_engine(): """创建内存 SQLite 引擎 + FTS5。""" engine = create_engine( "sqlite:///:memory:", connect_args={"check_same_thread": False}, poolclass=StaticPool, ) @event.listens_for(engine, "connect") def _pragma(dbapi_connection, _record): cursor = dbapi_connection.cursor() cursor.execute("PRAGMA foreign_keys=ON") cursor.close() init_db(engine) return engine @pytest.fixture def db_session(db_engine): """提供事务隔离的数据库 session。""" Session = sessionmaker(bind=db_engine, autoflush=False, autocommit=False) session = Session() try: yield session finally: session.close() @pytest.fixture def client(db_engine, db_session): """FastAPI TestClient,override get_db。""" app = create_app() def _override_get_db(): yield db_session app.dependency_overrides[get_db] = _override_get_db with TestClient(app, raise_server_exceptions=False) as c: yield c app.dependency_overrides.clear() # ── 样例数据 ──────────────────────────────────────────────────────────── SAMPLE_ARXIV_ID = "2401.12345" ADMIN_TOKEN = "test-admin-token-12345" @pytest.fixture def sample_paper(db_session): """插入一篇测试论文 + 作者 + 标签 + summary_status(pending)。""" now = datetime.now(timezone.utc) paper = Paper( arxiv_id=SAMPLE_ARXIV_ID, title_en="Test Paper Title", abstract="This is a test abstract for the paper.", published_at=date(2024, 1, 15), paper_date=date(2024, 1, 15), crawled_at=now, upvotes=42, hf_url=f"https://huggingface.co/papers/{SAMPLE_ARXIV_ID}", arxiv_url=f"https://arxiv.org/abs/{SAMPLE_ARXIV_ID}", pdf_url=f"https://arxiv.org/pdf/{SAMPLE_ARXIV_ID}.pdf", ) db_session.add(paper) db_session.flush() db_session.add(PaperAuthor(paper_id=paper.id, name="Alice Smith", position=0)) db_session.add(PaperAuthor(paper_id=paper.id, name="Bob Jones", position=1)) db_session.add(PaperTag(paper_id=paper.id, tag="NLP", source="hf")) db_session.add(PaperTag(paper_id=paper.id, tag="LLM", source="hf")) db_session.add(SummaryStatus(paper_id=paper.id, status="pending")) # FTS5 初始行(与 crawler 一致) db_session.execute( __import__("sqlalchemy").text( "INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) " "VALUES (:id, :title, :abstract, :authors, :tags)" ), { "id": paper.id, "title": paper.title_en, "abstract": paper.abstract or "", "authors": "Alice Smith, Bob Jones", "tags": "NLP, LLM", }, ) db_session.commit() return paper @pytest.fixture def sample_summary_dict() -> dict: """完整合法的 summary dict。""" return { "title_zh": "测试论文中文标题", "one_line": "这是一篇关于自然语言处理的测试论文的一句话总结。", "tags": ["自然语言处理", "大语言模型", "Transformer"], "difficulty": "中级", "prerequisites": { "concepts": ["Transformer", "注意力机制"], "level": "中级", }, "motivation": { "problem": "现有模型在长文本理解上存在不足。", "goal": "提出一种新的注意力机制来提升长文本建模能力。", "gap": "当前方法计算复杂度过高。", }, "method": { "overview": "提出了一种高效的稀疏注意力机制。", "key_idea": "使用局部-全局混合的注意力模式来降低计算复杂度。", "steps": [ "分析现有注意力机制的瓶颈", "设计稀疏注意力模式", "在多个基准上验证效果", ], "novelty": "首次将局部-全局注意力模式结合应用于长文本建模。", }, "results": { "main_findings": [ "在长文本基准上取得了 SOTA 结果", "推理速度提升了 2 倍", ], "benchmarks": [ {"dataset": "LongBench", "score": 85.3}, ], "limitations": [ "在超长文本(>100k tokens)上效果有所下降", ], }, "improvements": { "weaknesses": ["仅验证了英文数据"], "future_work": ["扩展到多语言场景"], "reproducibility": "代码已开源,模型权重可下载。", }, } @pytest.fixture def sample_summary_json(sample_summary_dict) -> str: """合法 summary 的 JSON 字符串。""" return json.dumps(sample_summary_dict, ensure_ascii=False, indent=2) @pytest.fixture def mock_pi_output(sample_summary_json) -> str: """模拟 pi CLI 的完整输出(包含 JSON)。""" return f"""以下是论文的深度解读: ```json {sample_summary_json} ``` 希望这个总结对你有帮助!""" @pytest.fixture def admin_token(): """返回测试用的 ADMIN_TOKEN(需要配合 monkeypatch 使用)。""" return ADMIN_TOKEN @pytest.fixture def admin_headers(admin_token): """带 Bearer token 的请求头。""" return {"Authorization": f"Bearer {admin_token}"} @pytest.fixture def wrong_admin_headers(): """错误的 Authorization 请求头。""" return {"Authorization": "Bearer wrong-token"} # ── 多样例数据 ──────────────────────────────────────────────────────────── @pytest.fixture def sample_papers_range(db_session): """插入 5 篇不同日期的论文(用于 admin / cleaner 测试)。""" now = datetime.now(timezone.utc) papers = [] for i, (arxiv_id, paper_date_str) in enumerate([ ("2401.10001", "2024-01-10"), ("2401.10002", "2024-01-11"), ("2401.10003", "2024-01-12"), ("2401.10004", "2024-01-13"), ("2401.10005", "2024-01-14"), ]): paper_date = date.fromisoformat(paper_date_str) p = Paper( arxiv_id=arxiv_id, title_en=f"Test Paper {i+1}", abstract=f"Abstract for paper {i+1}.", paper_date=paper_date, crawled_at=now, upvotes=i * 10, ) db_session.add(p) db_session.flush() db_session.add(PaperAuthor(paper_id=p.id, name=f"Author {i+1}", position=0)) db_session.add(PaperTag(paper_id=p.id, tag=f"Tag{i+1}", source="hf")) db_session.add(SummaryStatus(paper_id=p.id, status="pending")) # FTS5 db_session.execute( __import__("sqlalchemy").text( "INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) " "VALUES (:id, :title, :abstract, :authors, :tags)" ), {"id": p.id, "title": p.title_en, "abstract": p.abstract, "authors": f"Author {i+1}", "tags": f"Tag{i+1}"}, ) papers.append(p) db_session.commit() return papers @pytest.fixture def sample_papers_with_summary(db_session): """插入 5 篇带总结的论文(用于 search / pages / trends 测试)。""" now = datetime.now(timezone.utc) papers = [] for i, (arxiv_id, paper_date_str) in enumerate([ ("2401.20001", "2024-01-10"), ("2401.20002", "2024-01-11"), ("2401.20003", "2024-01-12"), ("2401.20004", "2024-01-13"), ("2401.20005", "2024-01-14"), ]): paper_date = date.fromisoformat(paper_date_str) p = Paper( arxiv_id=arxiv_id, title_en=f"Test Paper {i+1}", title_zh=f"测试论文 {i+1}", abstract=f"Abstract for paper {i+1}.", paper_date=paper_date, crawled_at=now, upvotes=i * 10 + 5, ) db_session.add(p) db_session.flush() db_session.add(PaperAuthor(paper_id=p.id, name=f"Author {i+1}", position=0)) db_session.add(PaperTag(paper_id=p.id, tag="NLP", source="hf")) db_session.add(PaperTag(paper_id=p.id, tag=f"Tag{i+1}", source="hf")) db_session.add(SummaryStatus( paper_id=p.id, status="done" if i < 4 else "pending", quality="normal", )) # 添加总结(前 4 篇) if i < 4: summary = PaperSummary( paper_id=p.id, one_line=f"这是论文{i+1}的一句话摘要", difficulty="中级", motivation_problem=f"论文{i+1}的研究问题", motivation_goal=f"论文{i+1}的研究目标", method_key_idea=f"论文{i+1}的关键思路", method_overview=f"论文{i+1}的方法概述", updated_at=now, full_json=json.dumps({"title_zh": f"测试论文 {i+1}"}), ) db_session.add(summary) # FTS5 db_session.execute( __import__("sqlalchemy").text( "INSERT INTO papers_fts(rowid, title_en, title_zh, abstract, authors, tags) " "VALUES (:id, :title_en, :title_zh, :abstract, :authors, :tags)" ), { "id": p.id, "title_en": p.title_en, "title_zh": p.title_zh or "", "abstract": p.abstract or "", "authors": f"Author {i+1}", "tags": f"NLP, Tag{i+1}", }, ) papers.append(p) db_session.commit() return papers