Files

368 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""测试 fixtures — 内存 SQLite、TestClient、样例数据。"""
from __future__ import annotations
import json
from datetime import date
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from app.database import get_db
from app.main import create_app
from app.database import init_db
from app.models import (
Paper,
PaperAuthor,
PaperSummary,
PaperTag,
SummaryStatus,
)
from app.utils import utc_now
# ── ChromaDB 隔离(autouse,所有测试)──────────────────────────────────
@pytest.fixture(autouse=True)
def _isolate_chroma(monkeypatch, tmp_path):
    """Redirect ChromaDB to a per-test temp directory and reset its singleton.

    Runs automatically for every test (``autouse=True``) so nothing is ever
    written to the real ``data/chroma`` directory.

    Same rationale as the in-memory DB isolation: summarize post-processing
    goes through the real ``_maybe_index_chroma`` -> ``index_paper`` write
    path, so without isolation the test fixtures (``2401.*``) would leak into
    the production ``data/chroma`` store and pollute semantic search.

    The ``_chroma`` singleton is reset before each test so that it picks up
    the ``CHROMA_DIR`` patched for this test's ``tmp_path``, and reset again
    afterwards.
    """
    import app.services.embedder as embedder_module
    from app.config import settings

    chroma_dir = tmp_path / "chroma"
    monkeypatch.setattr(settings, "CHROMA_DIR", str(chroma_dir))

    # Reset before the test so the singleton is rebuilt against the patched dir.
    embedder_module._chroma.reset()
    yield
    # Reset after the test as well.
    embedder_module._chroma.reset()
# ── 内存数据库 ──────────────────────────────────────────────────────────
@pytest.fixture
def db_engine():
    """Yield an in-memory SQLite engine initialised by ``init_db``.

    The engine uses ``StaticPool`` (one shared connection) together with
    ``check_same_thread=False`` so the same in-memory database can be reached
    from whichever thread ends up using the connection.

    Per the original docstring, ``init_db`` also sets up FTS5; the fixtures
    in this file insert into a ``papers_fts`` table afterwards.

    The engine is disposed once the test finishes so the pooled connection is
    released deterministically instead of lingering until garbage collection.

    Yields:
        The configured SQLAlchemy engine.
    """
    engine = create_engine(
        "sqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )

    @event.listens_for(engine, "connect")
    def _pragma(dbapi_connection, _record):
        # SQLite only enforces foreign keys when the pragma is enabled on the
        # connection, so switch it on for every new DBAPI connection.
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute("PRAGMA foreign_keys=ON")
        finally:
            cursor.close()

    init_db(engine)
    try:
        yield engine
    finally:
        # Close the pooled connection; this also discards the in-memory DB.
        engine.dispose()
@pytest.fixture
def db_session(db_engine):
    """Yield a SQLAlchemy session bound to ``db_engine`` and close it afterwards.

    The session is created with ``autoflush=False`` and ``autocommit=False``.
    No rollback is issued here; the session is simply closed when the test
    is done.

    Yields:
        The open session.
    """
    session_factory = sessionmaker(
        autocommit=False,
        autoflush=False,
        bind=db_engine,
    )
    db = session_factory()
    try:
        yield db
    finally:
        db.close()
@pytest.fixture
def client(db_engine, db_session):
    """Yield a FastAPI ``TestClient`` with ``get_db`` overridden.

    The application is built with ``create_app`` and its ``get_db``
    dependency is replaced so request handlers receive the test's
    ``db_session``. ``raise_server_exceptions=False`` makes server-side
    errors come back as HTTP responses instead of being re-raised in the test.

    All dependency overrides are cleared after the client context exits.

    Yields:
        The active ``TestClient``.
    """
    application = create_app()

    def _use_test_session():
        # Generator dependency that hands out the shared test session.
        yield db_session

    application.dependency_overrides[get_db] = _use_test_session

    with TestClient(application, raise_server_exceptions=False) as test_client:
        yield test_client

    application.dependency_overrides.clear()
# ── 样例数据 ────────────────────────────────────────────────────────────
# arXiv identifier of the single paper inserted by the sample_paper fixture.
SAMPLE_ARXIV_ID = "2401.12345"
# Admin credentials that the auth_client fixture patches onto settings and
# then submits to /admin/login.
_TEST_ADMIN_USERNAME = "admin"
_TEST_ADMIN_PASSWORD = "test-password-12345"
@pytest.fixture
def sample_paper(db_session):
    """Insert one test paper with authors, tags and a pending summary status.

    Also inserts the matching ``papers_fts`` row (the original comment notes
    this mirrors what the crawler does) and commits.

    Returns:
        The persisted ``Paper`` instance.
    """
    # Function-scope import instead of the __import__("sqlalchemy") hack.
    from sqlalchemy import text

    # Single source of truth for both the ORM rows and the FTS text columns,
    # so the two can never drift apart.
    author_names = ["Alice Smith", "Bob Jones"]
    tag_names = ["NLP", "LLM"]

    now = utc_now()
    paper = Paper(
        arxiv_id=SAMPLE_ARXIV_ID,
        title_en="Test Paper Title",
        abstract="This is a test abstract for the paper.",
        published_at=date(2024, 1, 15),
        paper_date=date(2024, 1, 15),
        crawled_at=now,
        upvotes=42,
        hf_url=f"https://huggingface.co/papers/{SAMPLE_ARXIV_ID}",
        arxiv_url=f"https://arxiv.org/abs/{SAMPLE_ARXIV_ID}",
        pdf_url=f"https://arxiv.org/pdf/{SAMPLE_ARXIV_ID}.pdf",
    )
    db_session.add(paper)
    # Flush before the child rows so paper.id is available to them.
    db_session.flush()

    for position, name in enumerate(author_names):
        db_session.add(PaperAuthor(paper_id=paper.id, name=name, position=position))
    for tag in tag_names:
        db_session.add(PaperTag(paper_id=paper.id, tag=tag, source="hf"))
    db_session.add(SummaryStatus(paper_id=paper.id, status="pending"))

    # Initial FTS5 row.
    db_session.execute(
        text(
            "INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) "
            "VALUES (:id, :title, :abstract, :authors, :tags)"
        ),
        {
            "id": paper.id,
            "title": paper.title_en,
            "abstract": paper.abstract or "",
            "authors": ", ".join(author_names),
            "tags": ", ".join(tag_names),
        },
    )
    db_session.commit()
    return paper
@pytest.fixture
def sample_summary_dict() -> dict:
    """Return a complete, valid summary dict for the sample paper.

    ``arxiv_id`` is taken from ``SAMPLE_ARXIV_ID`` (rather than repeating the
    literal) so the summary always refers to the paper created by the
    ``sample_paper`` fixture.
    """
    return {
        "arxiv_id": SAMPLE_ARXIV_ID,
        "title_zh": "测试论文中文标题",
        "one_line": "这是一篇关于自然语言处理的测试论文的一句话总结。",
        "tags": ["自然语言处理", "大语言模型", "Transformer"],
        "difficulty": "中级",
        "prerequisites": {
            "concepts": [
                {
                    "term": "Transformer",
                    "explanation": "一种基于自注意力机制的序列到序列模型架构,广泛用于NLP任务。",
                    "why_matters": "本文方法基于 Transformer 架构进行改进。",
                },
                {
                    "term": "注意力机制",
                    "explanation": "允许模型在处理序列时动态关注不同位置的信息的机制。",
                    "why_matters": "理解注意力机制是理解本文方法的基础。",
                },
            ],
        },
        "motivation": {
            "problem": "现有模型在长文本理解上存在不足,主要体现在注意力计算复杂度随序列长度二次增长,导致实际应用中无法处理超长文本输入。",
            "goal": "提出一种新的稀疏注意力机制来有效提升长文本建模能力,在保持模型整体性能的同时大幅降低计算开销和显存占用。",
            "gap": "当前方法计算复杂度过高,已有的稀疏注意力方案在保留全局信息方面存在明显不足,导致长距离依赖建模效果不佳。",
        },
        "method": {
            "overview": "提出了一种高效的稀疏注意力机制,通过局部-全局混合的注意力模式,在降低计算复杂度的同时保留了关键的全局信息流动。",
            "key_idea": "使用局部-全局混合的注意力模式来降低计算复杂度,局部窗口捕获短距离依赖,全局采样点维护长距离信息传递。",
            "steps": "首先分析现有注意力机制的计算瓶颈,发现全连接注意力中大部分注意力权重接近于零。然后设计了一种混合稀疏注意力模式,包含局部滑动窗口和全局随机采样两条路径。最后在多个长文本基准数据集上进行了全面的实验验证。",
            "novelty": "首次将局部-全局注意力模式结合应用于长文本建模,通过可学习的采样策略动态调整全局注意力点的位置,而非固定模式。",
        },
        "results": {
            "main_findings": "在长文本基准 LongBench 上取得了 SOTA 结果,平均得分提升 3.2 个百分点。推理速度相比全注意力提升了 2 倍,显存占用降低 60%。在 32k 序列长度下仍保持与全注意力相当的生成质量。",
            "benchmarks": [
                {
                    "task": "长文本摘要",
                    "metric": "ROUGE-L",
                    "this_work": "42.1",
                    "baseline": "38.9",
                    "improvement": "+3.2",
                },
            ],
            "limitations": "在超长文本(>100k tokens)上效果有所下降,主要原因是全局采样点数量不足以覆盖所有关键信息。此外,在小规模数据集上的优势不如大规模数据集明显。",
        },
        "improvements": {
            "weaknesses": "仅验证了英文数据,未在中文等多语言场景下测试。全局采样策略在极端长度的文本上可能需要增加采样点数量,增加了工程复杂度。",
            "future_work": "扩展到多语言场景,研究自适应采样策略,使模型能根据输入内容动态调整全局注意力点的分配。同时探索与 Flash Attention 等底层优化的兼容性。",
            "reproducibility": "代码已在 GitHub 开源,提供了完整的训练脚本和预训练模型权重。实验使用了公开数据集,硬件需求为 8×A100 GPU。",
        },
        "figures": [
            {
                "id": "Figure 1",
                "caption": "稀疏注意力机制的整体架构图",
                "description": "展示了局部窗口注意力和全局采样注意力的组合方式,以及信息如何在两种路径间流动。",
                "reason": "帮助理解本文方法的核心设计思想,直观展示了局部-全局混合模式的工作原理。",
            },
        ],
    }
@pytest.fixture
def sample_summary_json(sample_summary_dict) -> str:
    """Return the valid sample summary serialised as a JSON string.

    Non-ASCII characters are kept as-is (``ensure_ascii=False``) and the
    output is pretty-printed with a two-space indent.
    """
    serialized = json.dumps(
        sample_summary_dict,
        indent=2,
        ensure_ascii=False,
    )
    return serialized
@pytest.fixture
def mock_pi_output(sample_summary_json) -> str:
    """Return simulated full output of the pi CLI with the JSON embedded.

    The summary JSON sits in a fenced ``json`` code block between an
    introductory line and a closing line of chat-style text.
    """
    fenced_json = f"```json\n{sample_summary_json}\n```"
    output_lines = [
        "以下是论文的深度解读:",
        fenced_json,
        "希望这个总结对你有帮助!",
    ]
    return "\n".join(output_lines)
@pytest.fixture
def auth_client(client, monkeypatch):
    """Return the ``client`` fixture after logging in as the test admin.

    Patches the admin credentials onto ``settings``, sets ``CHROMA_ENABLED``
    to ``False``, then posts the login form. The login is expected to answer
    with a 303 redirect; per the original docstring the session cookie is
    then carried automatically on later requests made with the same client.

    Returns:
        The same ``TestClient`` instance, now logged in.
    """
    from app.config import settings

    patched_settings = {
        "ADMIN_USERNAME": _TEST_ADMIN_USERNAME,
        "ADMIN_PASSWORD": _TEST_ADMIN_PASSWORD,
        "CHROMA_ENABLED": False,
    }
    for attr_name, attr_value in patched_settings.items():
        monkeypatch.setattr(settings, attr_name, attr_value)

    # Log in to obtain the session cookie.
    credentials = {
        "username": _TEST_ADMIN_USERNAME,
        "password": _TEST_ADMIN_PASSWORD,
    }
    login_response = client.post(
        "/admin/login",
        data=credentials,
        follow_redirects=False,
    )
    assert login_response.status_code == 303
    return client
# ── 多样例数据 ────────────────────────────────────────────────────────────
@pytest.fixture
def sample_papers_range(db_session):
    """Insert 5 papers with consecutive dates (for admin / cleaner tests).

    Each paper gets one author, one tag, a pending ``SummaryStatus`` and a
    matching ``papers_fts`` row. Everything is committed in one transaction.

    Returns:
        The list of persisted ``Paper`` objects, in insertion order.
    """
    # Function-scope import instead of the __import__("sqlalchemy") hack.
    from sqlalchemy import text

    # Loop-invariant statement, built once.
    fts_insert = text(
        "INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) "
        "VALUES (:id, :title, :abstract, :authors, :tags)"
    )
    seed_rows = [
        ("2401.10001", "2024-01-10"),
        ("2401.10002", "2024-01-11"),
        ("2401.10003", "2024-01-12"),
        ("2401.10004", "2024-01-13"),
        ("2401.10005", "2024-01-14"),
    ]

    now = utc_now()
    papers = []
    for i, (arxiv_id, paper_date_str) in enumerate(seed_rows):
        number = i + 1
        author_name = f"Author {number}"
        tag_name = f"Tag{number}"

        p = Paper(
            arxiv_id=arxiv_id,
            title_en=f"Test Paper {number}",
            abstract=f"Abstract for paper {number}.",
            paper_date=date.fromisoformat(paper_date_str),
            crawled_at=now,
            upvotes=i * 10,
        )
        db_session.add(p)
        # Flush before the child rows so p.id is available to them.
        db_session.flush()

        db_session.add(PaperAuthor(paper_id=p.id, name=author_name, position=0))
        db_session.add(PaperTag(paper_id=p.id, tag=tag_name, source="hf"))
        db_session.add(SummaryStatus(paper_id=p.id, status="pending"))

        # FTS5 row.
        db_session.execute(
            fts_insert,
            {
                "id": p.id,
                "title": p.title_en,
                "abstract": p.abstract,
                "authors": author_name,
                "tags": tag_name,
            },
        )
        papers.append(p)

    db_session.commit()
    return papers
@pytest.fixture
def sample_papers_with_summary(db_session):
    """Insert 5 papers, the first 4 with summaries (for search / pages / trends tests).

    Every paper gets one author, the shared ``NLP`` tag plus a unique tag, a
    ``SummaryStatus`` row and a ``papers_fts`` row. The first four papers are
    marked ``done`` and receive a ``PaperSummary``; the fifth stays
    ``pending`` with no summary. Everything is committed in one transaction.

    Returns:
        The list of persisted ``Paper`` objects, in insertion order.
    """
    # Function-scope import instead of the __import__("sqlalchemy") hack.
    from sqlalchemy import text

    # Loop-invariant statement, built once.
    fts_insert = text(
        "INSERT INTO papers_fts(rowid, title_en, title_zh, abstract, authors, tags) "
        "VALUES (:id, :title_en, :title_zh, :abstract, :authors, :tags)"
    )
    seed_rows = [
        ("2401.20001", "2024-01-10"),
        ("2401.20002", "2024-01-11"),
        ("2401.20003", "2024-01-12"),
        ("2401.20004", "2024-01-13"),
        ("2401.20005", "2024-01-14"),
    ]
    # Number of leading papers that get a finished summary.
    summarised_count = 4

    now = utc_now()
    papers = []
    for i, (arxiv_id, paper_date_str) in enumerate(seed_rows):
        number = i + 1
        has_summary = i < summarised_count
        author_name = f"Author {number}"
        unique_tag = f"Tag{number}"
        title_zh = f"测试论文 {number}"

        p = Paper(
            arxiv_id=arxiv_id,
            title_en=f"Test Paper {number}",
            title_zh=title_zh,
            abstract=f"Abstract for paper {number}.",
            paper_date=date.fromisoformat(paper_date_str),
            crawled_at=now,
            upvotes=i * 10 + 5,
        )
        db_session.add(p)
        # Flush before the child rows so p.id is available to them.
        db_session.flush()

        db_session.add(PaperAuthor(paper_id=p.id, name=author_name, position=0))
        db_session.add(PaperTag(paper_id=p.id, tag="NLP", source="hf"))
        db_session.add(PaperTag(paper_id=p.id, tag=unique_tag, source="hf"))
        db_session.add(
            SummaryStatus(
                paper_id=p.id,
                status="done" if has_summary else "pending",
                quality="normal",
            )
        )

        # Attach a summary to the summarised papers only.
        if has_summary:
            db_session.add(
                PaperSummary(
                    paper_id=p.id,
                    one_line=f"这是论文{number}的一句话摘要",
                    difficulty="中级",
                    motivation_problem=f"论文{number}的研究问题",
                    motivation_goal=f"论文{number}的研究目标",
                    method_key_idea=f"论文{number}的关键思路",
                    method_overview=f"论文{number}的方法概述",
                    updated_at=now,
                    full_json=json.dumps({"title_zh": title_zh}),
                )
            )

        # FTS5 row.
        db_session.execute(
            fts_insert,
            {
                "id": p.id,
                "title_en": p.title_en,
                "title_zh": p.title_zh or "",
                "abstract": p.abstract or "",
                "authors": author_name,
                "tags": f"NLP, {unique_tag}",
            },
        )
        papers.append(p)

    db_session.commit()
    return papers