"""测试 fixtures — 内存 SQLite、TestClient、样例数据。""" from __future__ import annotations import json from datetime import date import pytest from fastapi.testclient import TestClient from sqlalchemy import create_engine, event from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import StaticPool from app.database import get_db from app.main import create_app from app.database import init_db from app.models import ( Paper, PaperAuthor, PaperSummary, PaperTag, SummaryStatus, ) from app.utils import utc_now # ── 内存数据库 ────────────────────────────────────────────────────────── @pytest.fixture def db_engine(): """创建内存 SQLite 引擎 + FTS5。""" engine = create_engine( "sqlite:///:memory:", connect_args={"check_same_thread": False}, poolclass=StaticPool, ) @event.listens_for(engine, "connect") def _pragma(dbapi_connection, _record): cursor = dbapi_connection.cursor() cursor.execute("PRAGMA foreign_keys=ON") cursor.close() init_db(engine) return engine @pytest.fixture def db_session(db_engine): """提供事务隔离的数据库 session。""" Session = sessionmaker(bind=db_engine, autoflush=False, autocommit=False) session = Session() try: yield session finally: session.close() @pytest.fixture def client(db_engine, db_session): """FastAPI TestClient,override get_db。""" app = create_app() def _override_get_db(): yield db_session app.dependency_overrides[get_db] = _override_get_db with TestClient(app, raise_server_exceptions=False) as c: yield c app.dependency_overrides.clear() # ── 样例数据 ──────────────────────────────────────────────────────────── SAMPLE_ARXIV_ID = "2401.12345" _TEST_ADMIN_USERNAME = "admin" _TEST_ADMIN_PASSWORD = "test-password-12345" @pytest.fixture def sample_paper(db_session): """插入一篇测试论文 + 作者 + 标签 + summary_status(pending)。""" now = utc_now() paper = Paper( arxiv_id=SAMPLE_ARXIV_ID, title_en="Test Paper Title", abstract="This is a test abstract for the paper.", published_at=date(2024, 1, 15), paper_date=date(2024, 1, 15), crawled_at=now, upvotes=42, hf_url=f"https://huggingface.co/papers/{SAMPLE_ARXIV_ID}", arxiv_url=f"https://arxiv.org/abs/{SAMPLE_ARXIV_ID}", pdf_url=f"https://arxiv.org/pdf/{SAMPLE_ARXIV_ID}.pdf", ) db_session.add(paper) db_session.flush() db_session.add(PaperAuthor(paper_id=paper.id, name="Alice Smith", position=0)) db_session.add(PaperAuthor(paper_id=paper.id, name="Bob Jones", position=1)) db_session.add(PaperTag(paper_id=paper.id, tag="NLP", source="hf")) db_session.add(PaperTag(paper_id=paper.id, tag="LLM", source="hf")) db_session.add(SummaryStatus(paper_id=paper.id, status="pending")) # FTS5 初始行(与 crawler 一致) db_session.execute( __import__("sqlalchemy").text( "INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) " "VALUES (:id, :title, :abstract, :authors, :tags)" ), { "id": paper.id, "title": paper.title_en, "abstract": paper.abstract or "", "authors": "Alice Smith, Bob Jones", "tags": "NLP, LLM", }, ) db_session.commit() return paper @pytest.fixture def sample_summary_dict() -> dict: """完整合法的 summary dict。""" return { "arxiv_id": "2401.12345", "title_zh": "测试论文中文标题", "one_line": "这是一篇关于自然语言处理的测试论文的一句话总结。", "tags": ["自然语言处理", "大语言模型", "Transformer"], "difficulty": "中级", "prerequisites": { "concepts": [ { "term": "Transformer", "explanation": "一种基于自注意力机制的序列到序列模型架构,广泛用于NLP任务。", "why_matters": "本文方法基于 Transformer 架构进行改进。", }, { "term": "注意力机制", "explanation": "允许模型在处理序列时动态关注不同位置的信息的机制。", "why_matters": "理解注意力机制是理解本文方法的基础。", }, ], }, "motivation": { "problem": "现有模型在长文本理解上存在不足,主要体现在注意力计算复杂度随序列长度二次增长,导致实际应用中无法处理超长文本输入。", "goal": "提出一种新的稀疏注意力机制来有效提升长文本建模能力,在保持模型整体性能的同时大幅降低计算开销和显存占用。", "gap": "当前方法计算复杂度过高,已有的稀疏注意力方案在保留全局信息方面存在明显不足,导致长距离依赖建模效果不佳。", }, "method": { "overview": "提出了一种高效的稀疏注意力机制,通过局部-全局混合的注意力模式,在降低计算复杂度的同时保留了关键的全局信息流动。", "key_idea": "使用局部-全局混合的注意力模式来降低计算复杂度,局部窗口捕获短距离依赖,全局采样点维护长距离信息传递。", "steps": "首先分析现有注意力机制的计算瓶颈,发现全连接注意力中大部分注意力权重接近于零。然后设计了一种混合稀疏注意力模式,包含局部滑动窗口和全局随机采样两条路径。最后在多个长文本基准数据集上进行了全面的实验验证。", "novelty": "首次将局部-全局注意力模式结合应用于长文本建模,通过可学习的采样策略动态调整全局注意力点的位置,而非固定模式。", }, "results": { "main_findings": "在长文本基准 LongBench 上取得了 SOTA 结果,平均得分提升 3.2 个百分点。推理速度相比全注意力提升了 2 倍,显存占用降低 60%。在 32k 序列长度下仍保持与全注意力相当的生成质量。", "benchmarks": [ { "task": "长文本摘要", "metric": "ROUGE-L", "this_work": "42.1", "baseline": "38.9", "improvement": "+3.2", }, ], "limitations": "在超长文本(>100k tokens)上效果有所下降,主要原因是全局采样点数量不足以覆盖所有关键信息。此外,在小规模数据集上的优势不如大规模数据集明显。", }, "improvements": { "weaknesses": "仅验证了英文数据,未在中文等多语言场景下测试。全局采样策略在极端长度的文本上可能需要增加采样点数量,增加了工程复杂度。", "future_work": "扩展到多语言场景,研究自适应采样策略,使模型能根据输入内容动态调整全局注意力点的分配。同时探索与 Flash Attention 等底层优化的兼容性。", "reproducibility": "代码已在 GitHub 开源,提供了完整的训练脚本和预训练模型权重。实验使用了公开数据集,硬件需求为 8×A100 GPU。", }, "figures": [ { "id": "Figure 1", "caption": "稀疏注意力机制的整体架构图", "description": "展示了局部窗口注意力和全局采样注意力的组合方式,以及信息如何在两种路径间流动。", "reason": "帮助理解本文方法的核心设计思想,直观展示了局部-全局混合模式的工作原理。", }, ], } @pytest.fixture def sample_summary_json(sample_summary_dict) -> str: """合法 summary 的 JSON 字符串。""" return json.dumps(sample_summary_dict, ensure_ascii=False, indent=2) @pytest.fixture def mock_pi_output(sample_summary_json) -> str: """模拟 pi CLI 的完整输出(包含 JSON)。""" return f"""以下是论文的深度解读: ```json {sample_summary_json} ``` 希望这个总结对你有帮助!""" @pytest.fixture def auth_client(client, monkeypatch): """已登录的 TestClient(session cookie 自动携带)。""" from app.config import settings monkeypatch.setattr(settings, "ADMIN_USERNAME", _TEST_ADMIN_USERNAME) monkeypatch.setattr(settings, "ADMIN_PASSWORD", _TEST_ADMIN_PASSWORD) monkeypatch.setattr(settings, "CHROMA_ENABLED", False) # 登录获取 session cookie resp = client.post( "/admin/login", data={"username": _TEST_ADMIN_USERNAME, "password": _TEST_ADMIN_PASSWORD}, follow_redirects=False, ) assert resp.status_code == 303 return client # ── 多样例数据 ──────────────────────────────────────────────────────────── @pytest.fixture def sample_papers_range(db_session): """插入 5 篇不同日期的论文(用于 admin / cleaner 测试)。""" now = utc_now() papers = [] for i, (arxiv_id, paper_date_str) in enumerate( [ ("2401.10001", "2024-01-10"), ("2401.10002", "2024-01-11"), ("2401.10003", "2024-01-12"), ("2401.10004", "2024-01-13"), ("2401.10005", "2024-01-14"), ] ): paper_date = date.fromisoformat(paper_date_str) p = Paper( arxiv_id=arxiv_id, title_en=f"Test Paper {i + 1}", abstract=f"Abstract for paper {i + 1}.", paper_date=paper_date, crawled_at=now, upvotes=i * 10, ) db_session.add(p) db_session.flush() db_session.add(PaperAuthor(paper_id=p.id, name=f"Author {i + 1}", position=0)) db_session.add(PaperTag(paper_id=p.id, tag=f"Tag{i + 1}", source="hf")) db_session.add(SummaryStatus(paper_id=p.id, status="pending")) # FTS5 db_session.execute( __import__("sqlalchemy").text( "INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) " "VALUES (:id, :title, :abstract, :authors, :tags)" ), { "id": p.id, "title": p.title_en, "abstract": p.abstract, "authors": f"Author {i + 1}", "tags": f"Tag{i + 1}", }, ) papers.append(p) db_session.commit() return papers @pytest.fixture def sample_papers_with_summary(db_session): """插入 5 篇带总结的论文(用于 search / pages / trends 测试)。""" now = utc_now() papers = [] for i, (arxiv_id, paper_date_str) in enumerate( [ ("2401.20001", "2024-01-10"), ("2401.20002", "2024-01-11"), ("2401.20003", "2024-01-12"), ("2401.20004", "2024-01-13"), ("2401.20005", "2024-01-14"), ] ): paper_date = date.fromisoformat(paper_date_str) p = Paper( arxiv_id=arxiv_id, title_en=f"Test Paper {i + 1}", title_zh=f"测试论文 {i + 1}", abstract=f"Abstract for paper {i + 1}.", paper_date=paper_date, crawled_at=now, upvotes=i * 10 + 5, ) db_session.add(p) db_session.flush() db_session.add(PaperAuthor(paper_id=p.id, name=f"Author {i + 1}", position=0)) db_session.add(PaperTag(paper_id=p.id, tag="NLP", source="hf")) db_session.add(PaperTag(paper_id=p.id, tag=f"Tag{i + 1}", source="hf")) db_session.add( SummaryStatus( paper_id=p.id, status="done" if i < 4 else "pending", quality="normal", ) ) # 添加总结(前 4 篇) if i < 4: summary = PaperSummary( paper_id=p.id, one_line=f"这是论文{i + 1}的一句话摘要", difficulty="中级", motivation_problem=f"论文{i + 1}的研究问题", motivation_goal=f"论文{i + 1}的研究目标", method_key_idea=f"论文{i + 1}的关键思路", method_overview=f"论文{i + 1}的方法概述", updated_at=now, full_json=json.dumps({"title_zh": f"测试论文 {i + 1}"}), ) db_session.add(summary) # FTS5 db_session.execute( __import__("sqlalchemy").text( "INSERT INTO papers_fts(rowid, title_en, title_zh, abstract, authors, tags) " "VALUES (:id, :title_en, :title_zh, :abstract, :authors, :tags)" ), { "id": p.id, "title_en": p.title_en, "title_zh": p.title_zh or "", "abstract": p.abstract or "", "authors": f"Author {i + 1}", "tags": f"NLP, Tag{i + 1}", }, ) papers.append(p) db_session.commit() return papers