0d293422ac
- Replace image_extractor with pdf_image_extractor service - Enhance pi_client with expanded API capabilities - Improve summarizer service with additional features - Update admin routes with more endpoints - Add login page template - Enhance detail page with comprehensive layout - Improve search and trends pages - Update base template with additional elements - Refactor tests for better coverage - Add validate_summary script - Update project configuration and dependencies
353 lines
14 KiB
Python
353 lines
14 KiB
Python
"""测试 fixtures — 内存 SQLite、TestClient、样例数据。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
from datetime import date, datetime, timezone
|
||
from pathlib import Path
|
||
from unittest.mock import AsyncMock
|
||
|
||
import pytest
|
||
from fastapi.testclient import TestClient
|
||
from sqlalchemy import create_engine, event
|
||
from sqlalchemy.orm import DeclarativeBase, sessionmaker
|
||
from sqlalchemy.pool import StaticPool
|
||
|
||
from app.database import get_db
|
||
from app.main import create_app
|
||
from app.database import init_db
|
||
from app.models import (
|
||
Paper,
|
||
PaperAuthor,
|
||
PaperSummary,
|
||
PaperTag,
|
||
SummaryStatus,
|
||
)
|
||
|
||
|
||
# ── 内存数据库 ──────────────────────────────────────────────────────────
|
||
|
||
|
||
class _TestBase(DeclarativeBase):
    """Declarative base local to the test suite.

    Its ``metadata`` attribute is rebound right after the class definition so
    that it refers to the application's metadata object.
    """


# Reuse the metadata object of the application's declarative Base
# (imported from app.database) instead of a separate, empty one.
from app.database import Base as _AppBase  # noqa: E402

_TestBase.metadata = _AppBase.metadata
|
||
|
||
|
||
@pytest.fixture
def db_engine():
    """Yield an in-memory SQLite engine initialised through ``init_db``.

    ``StaticPool`` together with ``check_same_thread=False`` makes every
    checkout reuse one DBAPI connection, so all sessions bound to this engine
    see the same in-memory database.

    The schema is created by ``init_db`` (defined in ``app.database``);
    presumably this also creates the ``papers_fts`` FTS5 table that the sample
    data fixtures insert into -- confirm against ``init_db``.

    Yields:
        The configured engine. It is disposed during teardown so the
        underlying SQLite connection is closed instead of leaking per test.
    """
    engine = create_engine(
        "sqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )

    @event.listens_for(engine, "connect")
    def _pragma(dbapi_connection, _record):
        # SQLite only enforces foreign keys when asked to, per connection.
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute("PRAGMA foreign_keys=ON")
        finally:
            cursor.close()

    try:
        init_db(engine)
        yield engine
    finally:
        # Release the pooled connection even if schema creation failed.
        engine.dispose()
|
||
|
||
|
||
@pytest.fixture
def db_session(db_engine):
    """Yield a SQLAlchemy session bound to the test engine.

    The session is created with ``autoflush=False`` and ``autocommit=False``
    and is closed when the test finishes. No explicit rollback or
    savepoint handling happens here.
    """
    session_factory = sessionmaker(bind=db_engine, autoflush=False, autocommit=False)
    # The Session context manager calls close() on exit.
    with session_factory() as session:
        yield session
|
||
|
||
|
||
@pytest.fixture
def client(db_engine, db_session):
    """Yield a FastAPI ``TestClient`` whose ``get_db`` dependency is overridden.

    Every request handled by the client receives the test's ``db_session``
    instead of whatever ``get_db`` would normally provide. The client is
    built with ``raise_server_exceptions=False``, so server-side errors are
    returned as responses rather than raised into the test.
    """
    application = create_app()

    def _provide_test_session():
        # Generator form mirrors a yield-style dependency.
        yield db_session

    application.dependency_overrides[get_db] = _provide_test_session

    with TestClient(application, raise_server_exceptions=False) as test_client:
        yield test_client

    application.dependency_overrides.clear()
|
||
|
||
|
||
# ── 样例数据 ────────────────────────────────────────────────────────────
|
||
|
||
# arXiv identifier used by the sample_paper fixture (also embedded in its URLs).
SAMPLE_ARXIV_ID = "2401.12345"
# Admin credentials that the auth_client fixture patches into the settings
# object and then submits to /admin/login.
_TEST_ADMIN_USERNAME = "admin"
_TEST_ADMIN_PASSWORD = "test-password-12345"
|
||
|
||
|
||
@pytest.fixture
def sample_paper(db_session):
    """Insert one test paper with authors, tags and a pending summary status.

    Besides the ORM rows, a matching row is written to ``papers_fts`` (the
    original comment notes this mirrors what the crawler does). Everything is
    committed before returning.

    Returns:
        The persisted ``Paper`` instance.
    """
    # Function-scope import so the fixture does not rely on a dynamic
    # ``__import__("sqlalchemy")`` lookup.
    from sqlalchemy import text

    # Single source of truth for the ORM rows and the FTS row below.
    author_names = ["Alice Smith", "Bob Jones"]
    tag_names = ["NLP", "LLM"]

    paper = Paper(
        arxiv_id=SAMPLE_ARXIV_ID,
        title_en="Test Paper Title",
        abstract="This is a test abstract for the paper.",
        published_at=date(2024, 1, 15),
        paper_date=date(2024, 1, 15),
        crawled_at=datetime.now(timezone.utc),
        upvotes=42,
        hf_url=f"https://huggingface.co/papers/{SAMPLE_ARXIV_ID}",
        arxiv_url=f"https://arxiv.org/abs/{SAMPLE_ARXIV_ID}",
        pdf_url=f"https://arxiv.org/pdf/{SAMPLE_ARXIV_ID}.pdf",
    )
    db_session.add(paper)
    # Flush so paper.id is available to the child rows that reference it.
    db_session.flush()

    db_session.add_all(
        [
            PaperAuthor(paper_id=paper.id, name=name, position=position)
            for position, name in enumerate(author_names)
        ]
    )
    db_session.add_all(
        [PaperTag(paper_id=paper.id, tag=tag, source="hf") for tag in tag_names]
    )
    db_session.add(SummaryStatus(paper_id=paper.id, status="pending"))

    # Initial full-text-search row, keyed by the paper's primary key.
    db_session.execute(
        text(
            "INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) "
            "VALUES (:id, :title, :abstract, :authors, :tags)"
        ),
        {
            "id": paper.id,
            "title": paper.title_en,
            "abstract": paper.abstract or "",
            "authors": ", ".join(author_names),
            "tags": ", ".join(tag_names),
        },
    )
    db_session.commit()
    return paper
|
||
|
||
|
||
@pytest.fixture
def sample_summary_dict() -> dict:
    """Return a fully populated summary dictionary for the sample paper.

    The ``arxiv_id`` is taken from ``SAMPLE_ARXIV_ID`` so it always matches
    the paper inserted by ``sample_paper``. The remaining sections
    (prerequisites, motivation, method, results, improvements, figures) hold
    fixed Chinese-language sample text. It is intended to represent a
    complete, valid summary; what counts as valid is decided by the
    application code, not here.
    """
    return {
        "arxiv_id": SAMPLE_ARXIV_ID,
        "title_zh": "测试论文中文标题",
        "one_line": "这是一篇关于自然语言处理的测试论文的一句话总结。",
        "tags": ["自然语言处理", "大语言模型", "Transformer"],
        "difficulty": "中级",
        "prerequisites": {
            "concepts": [
                {
                    "term": "Transformer",
                    "explanation": "一种基于自注意力机制的序列到序列模型架构,广泛用于NLP任务。",
                    "why_matters": "本文方法基于 Transformer 架构进行改进。",
                },
                {
                    "term": "注意力机制",
                    "explanation": "允许模型在处理序列时动态关注不同位置的信息的机制。",
                    "why_matters": "理解注意力机制是理解本文方法的基础。",
                },
            ],
        },
        "motivation": {
            "problem": "现有模型在长文本理解上存在不足,主要体现在注意力计算复杂度随序列长度二次增长,导致实际应用中无法处理超长文本输入。",
            "goal": "提出一种新的稀疏注意力机制来有效提升长文本建模能力,在保持模型整体性能的同时大幅降低计算开销和显存占用。",
            "gap": "当前方法计算复杂度过高,已有的稀疏注意力方案在保留全局信息方面存在明显不足,导致长距离依赖建模效果不佳。",
        },
        "method": {
            "overview": "提出了一种高效的稀疏注意力机制,通过局部-全局混合的注意力模式,在降低计算复杂度的同时保留了关键的全局信息流动。",
            "key_idea": "使用局部-全局混合的注意力模式来降低计算复杂度,局部窗口捕获短距离依赖,全局采样点维护长距离信息传递。",
            "steps": "首先分析现有注意力机制的计算瓶颈,发现全连接注意力中大部分注意力权重接近于零。然后设计了一种混合稀疏注意力模式,包含局部滑动窗口和全局随机采样两条路径。最后在多个长文本基准数据集上进行了全面的实验验证。",
            "novelty": "首次将局部-全局注意力模式结合应用于长文本建模,通过可学习的采样策略动态调整全局注意力点的位置,而非固定模式。",
        },
        "results": {
            "main_findings": "在长文本基准 LongBench 上取得了 SOTA 结果,平均得分提升 3.2 个百分点。推理速度相比全注意力提升了 2 倍,显存占用降低 60%。在 32k 序列长度下仍保持与全注意力相当的生成质量。",
            "benchmarks": [
                {"task": "长文本摘要", "metric": "ROUGE-L", "this_work": "42.1", "baseline": "38.9", "improvement": "+3.2"},
            ],
            "limitations": "在超长文本(>100k tokens)上效果有所下降,主要原因是全局采样点数量不足以覆盖所有关键信息。此外,在小规模数据集上的优势不如大规模数据集明显。",
        },
        "improvements": {
            "weaknesses": "仅验证了英文数据,未在中文等多语言场景下测试。全局采样策略在极端长度的文本上可能需要增加采样点数量,增加了工程复杂度。",
            "future_work": "扩展到多语言场景,研究自适应采样策略,使模型能根据输入内容动态调整全局注意力点的分配。同时探索与 Flash Attention 等底层优化的兼容性。",
            "reproducibility": "代码已在 GitHub 开源,提供了完整的训练脚本和预训练模型权重。实验使用了公开数据集,硬件需求为 8×A100 GPU。",
        },
        "figures": [
            {
                "id": "Figure 1",
                "caption": "稀疏注意力机制的整体架构图",
                "description": "展示了局部窗口注意力和全局采样注意力的组合方式,以及信息如何在两种路径间流动。",
                "reason": "帮助理解本文方法的核心设计思想,直观展示了局部-全局混合模式的工作原理。",
            },
        ],
    }
|
||
|
||
|
||
@pytest.fixture
def sample_summary_json(sample_summary_dict) -> str:
    """Serialise the sample summary dict to a JSON string.

    Non-ASCII characters are written as-is (no ``\\uXXXX`` escapes) and the
    output is indented by two spaces.
    """
    dump_options = {"ensure_ascii": False, "indent": 2}
    return json.dumps(sample_summary_dict, **dump_options)
|
||
|
||
|
||
@pytest.fixture
def mock_pi_output(sample_summary_json) -> str:
    """Build text imitating the pi CLI's full output around the summary JSON.

    The result is a short introductory sentence, a blank line, the JSON
    wrapped in a fenced ``json`` code block, another blank line, and a
    closing sentence.
    """
    output_lines = (
        "以下是论文的深度解读:",
        "",
        "```json",
        f"{sample_summary_json}",
        "```",
        "",
        "希望这个总结对你有帮助!",
    )
    return "\n".join(output_lines)
|
||
|
||
|
||
@pytest.fixture
def auth_client(client, monkeypatch):
    """Return the shared ``TestClient`` after logging in as the test admin.

    The settings object is patched with the test admin credentials and with
    ``CHROMA_ENABLED`` set to ``False``; then the credentials are posted to
    ``/admin/login`` without following redirects. A 303 response is required.
    The same ``client`` object is returned, so any cookies set by the login
    response stay on it for subsequent requests.
    """
    from app.config import settings

    settings_overrides = (
        ("ADMIN_USERNAME", _TEST_ADMIN_USERNAME),
        ("ADMIN_PASSWORD", _TEST_ADMIN_PASSWORD),
        ("CHROMA_ENABLED", False),
    )
    for attribute, value in settings_overrides:
        monkeypatch.setattr(settings, attribute, value)

    # Log in to obtain the session cookie.
    credentials = {"username": _TEST_ADMIN_USERNAME, "password": _TEST_ADMIN_PASSWORD}
    login_response = client.post(
        "/admin/login",
        data=credentials,
        follow_redirects=False,
    )
    assert login_response.status_code == 303
    return client
|
||
|
||
|
||
# ── 多样例数据 ────────────────────────────────────────────────────────────
|
||
|
||
|
||
@pytest.fixture
def sample_papers_range(db_session):
    """Insert five papers dated 2024-01-10 through 2024-01-14.

    Each paper gets one author, one tag, a ``pending`` summary status and a
    matching ``papers_fts`` row. Upvotes are 0, 10, 20, 30 and 40 in date
    order. Everything is committed once at the end. Per the original
    docstring this data is meant for the admin / cleaner tests.

    Returns:
        The list of persisted ``Paper`` instances, in insertion order.
    """
    # Function-scope import so the fixture does not rely on a dynamic
    # ``__import__("sqlalchemy")`` lookup.
    from sqlalchemy import text

    # Loop-invariant statement, built once.
    insert_fts_row = text(
        "INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) "
        "VALUES (:id, :title, :abstract, :authors, :tags)"
    )
    crawled_at = datetime.now(timezone.utc)
    dated_ids = [
        ("2401.10001", "2024-01-10"),
        ("2401.10002", "2024-01-11"),
        ("2401.10003", "2024-01-12"),
        ("2401.10004", "2024-01-13"),
        ("2401.10005", "2024-01-14"),
    ]

    papers = []
    for number, (arxiv_id, paper_date_str) in enumerate(dated_ids, start=1):
        author_name = f"Author {number}"
        tag_name = f"Tag{number}"

        paper = Paper(
            arxiv_id=arxiv_id,
            title_en=f"Test Paper {number}",
            abstract=f"Abstract for paper {number}.",
            paper_date=date.fromisoformat(paper_date_str),
            crawled_at=crawled_at,
            upvotes=(number - 1) * 10,
        )
        db_session.add(paper)
        # Flush so paper.id is available to the rows that reference it.
        db_session.flush()

        db_session.add(PaperAuthor(paper_id=paper.id, name=author_name, position=0))
        db_session.add(PaperTag(paper_id=paper.id, tag=tag_name, source="hf"))
        db_session.add(SummaryStatus(paper_id=paper.id, status="pending"))

        # Full-text-search row keyed by the paper's primary key.
        db_session.execute(
            insert_fts_row,
            {
                "id": paper.id,
                "title": paper.title_en,
                "abstract": paper.abstract,
                "authors": author_name,
                "tags": tag_name,
            },
        )
        papers.append(paper)

    db_session.commit()
    return papers
|
||
|
||
|
||
@pytest.fixture
def sample_papers_with_summary(db_session):
    """Insert five papers, the first four of which have a stored summary.

    Each paper (dated 2024-01-10 through 2024-01-14) gets one author, the
    shared tag ``NLP`` plus a per-paper tag, a summary status with
    ``quality="normal"`` and a ``papers_fts`` row including ``title_zh``.
    The first four papers have status ``done`` and a ``PaperSummary`` row;
    the fifth is ``pending`` with no summary. Upvotes are 5, 15, 25, 35
    and 45. Per the original docstring this data is meant for the
    search / pages / trends tests.

    Returns:
        The list of persisted ``Paper`` instances, in insertion order.
    """
    # Function-scope import so the fixture does not rely on a dynamic
    # ``__import__("sqlalchemy")`` lookup.
    from sqlalchemy import text

    # Loop-invariant statement, built once.
    insert_fts_row = text(
        "INSERT INTO papers_fts(rowid, title_en, title_zh, abstract, authors, tags) "
        "VALUES (:id, :title_en, :title_zh, :abstract, :authors, :tags)"
    )
    # Papers numbered 1..summarized_count receive a summary.
    summarized_count = 4
    now = datetime.now(timezone.utc)
    dated_ids = [
        ("2401.20001", "2024-01-10"),
        ("2401.20002", "2024-01-11"),
        ("2401.20003", "2024-01-12"),
        ("2401.20004", "2024-01-13"),
        ("2401.20005", "2024-01-14"),
    ]

    papers = []
    for number, (arxiv_id, paper_date_str) in enumerate(dated_ids, start=1):
        has_summary = number <= summarized_count
        author_name = f"Author {number}"
        extra_tag = f"Tag{number}"

        paper = Paper(
            arxiv_id=arxiv_id,
            title_en=f"Test Paper {number}",
            title_zh=f"测试论文 {number}",
            abstract=f"Abstract for paper {number}.",
            paper_date=date.fromisoformat(paper_date_str),
            crawled_at=now,
            upvotes=(number - 1) * 10 + 5,
        )
        db_session.add(paper)
        # Flush so paper.id is available to the rows that reference it.
        db_session.flush()

        db_session.add(PaperAuthor(paper_id=paper.id, name=author_name, position=0))
        db_session.add(PaperTag(paper_id=paper.id, tag="NLP", source="hf"))
        db_session.add(PaperTag(paper_id=paper.id, tag=extra_tag, source="hf"))

        db_session.add(
            SummaryStatus(
                paper_id=paper.id,
                status="done" if has_summary else "pending",
                quality="normal",
            )
        )

        if has_summary:
            db_session.add(
                PaperSummary(
                    paper_id=paper.id,
                    one_line=f"这是论文{number}的一句话摘要",
                    difficulty="中级",
                    motivation_problem=f"论文{number}的研究问题",
                    motivation_goal=f"论文{number}的研究目标",
                    method_key_idea=f"论文{number}的关键思路",
                    method_overview=f"论文{number}的方法概述",
                    updated_at=now,
                    full_json=json.dumps({"title_zh": f"测试论文 {number}"}),
                )
            )

        # Full-text-search row keyed by the paper's primary key.
        db_session.execute(
            insert_fts_row,
            {
                "id": paper.id,
                "title_en": paper.title_en,
                "title_zh": paper.title_zh or "",
                "abstract": paper.abstract or "",
                "authors": author_name,
                "tags": f"NLP, {extra_tag}",
            },
        )
        papers.append(paper)

    db_session.commit()
    return papers
|