Files
daily-paper/app/cli.py
T

164 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""CLI tool for HF Daily Papers: crawl papers, trigger AI summaries, initialise the DB."""
import asyncio
import logging

import typer
from dotenv import load_dotenv

# Load .env before any `app.*` module is imported; the commands below import
# those modules lazily inside their own bodies, after this call has run.
load_dotenv()

# Root Typer application; each command is registered via @cli_app.command().
cli_app = typer.Typer(help="HF Daily Papers 管理 CLI")
@cli_app.command()
def crawl(
    date_str: str = typer.Argument(
        None,
        help="抓取日期 (YYYY-MM-DD),留空则自动探测",
    ),
    top_n: int = typer.Option(None, "--top", "-n", help="取前 N 篇"),
    force: bool = typer.Option(False, "--force", "-f", help="强制重抓(即使已抓取过)"),
):
    """手动抓取指定日期的 HuggingFace Daily Papers。"""
    # NOTE: the docstring above doubles as the Typer `--help` text, so it is
    # kept verbatim.
    #
    # Crawl HuggingFace Daily Papers for one date via crawl_daily().
    #
    # * Explicit date_str: validated as YYYY-MM-DD, then always crawled.
    # * No date_str: target today_str(); if that crawl reports "failed" or
    #   finds 0 papers, fall back to yesterday_str().
    # * Without --force, an auto-detected date (today or the fallback) that
    #   already has Paper rows is skipped and the command exits 0.
    # * Exits with code 1 on a malformed date or a non-"success" final result.
    import os
    from datetime import datetime

    from app.config import settings
    from app.database import SessionLocal, engine
    from app.database import init_db as _init
    from app.models import Paper
    from app.services.crawler import crawl_daily
    from app.utils import today_str, yesterday_str
    from sqlalchemy import func, select

    if date_str:
        # Fail fast with a clear message instead of handing a malformed date
        # to the crawler; also normalises e.g. "2024-1-5" to "2024-01-05".
        try:
            date_str = datetime.strptime(date_str, "%Y-%m-%d").strftime("%Y-%m-%d")
        except ValueError:
            typer.echo(f"❌ 无效的日期: {date_str},格式应为 YYYY-MM-DD", err=True)
            raise typer.Exit(code=1) from None

    auto_detect = not date_str
    target = date_str or today_str()

    # Make sure the database directory and tables exist.
    os.makedirs(settings.db_path.parent, exist_ok=True)
    _init(engine)

    db = SessionLocal()
    try:
        def stored_count(day):
            """Return the number of Paper rows whose paper_date equals ``day``."""
            stmt = select(func.count(Paper.id)).where(Paper.paper_date == day)
            return db.scalar(stmt) or 0

        # Skip an auto-detected date that is already stored (unless --force).
        if auto_detect and not force:
            existing = stored_count(target)
            if existing > 0:
                typer.echo(f"⏭️ {target} 已有 {existing} 篇论文,跳过(用 --force 强制重抓)")
                return

        typer.echo(f"📡 开始抓取 {target} ...")
        result = asyncio.run(crawl_daily(db, target, top_n))

        # Only the auto-detected "today" falls back to yesterday.
        need_fallback = auto_detect and (
            result["status"] == "failed" or result["found"] == 0
        )
        if need_fallback:
            fallback = yesterday_str()
            # Honour --force for the fallback date too: previously it was
            # skipped even with --force while telling the user to pass it.
            if not force:
                existing = stored_count(fallback)
                if existing > 0:
                    typer.echo(
                        f"⏭️ {fallback} 已有 {existing} 篇论文,跳过(用 --force 强制重抓)"
                    )
                    # Nothing left to do: exit 0 like the skip above, instead
                    # of falling through and reporting today's failed/empty
                    # result (which used to exit 1).
                    return
            typer.echo(f"🔄 {target} 无数据,尝试 {fallback} ...")
            target = fallback
            result = asyncio.run(crawl_daily(db, target, top_n))

        if result["status"] != "success":
            # .get(): don't let a missing "error" key turn the failure report
            # into a KeyError traceback.
            typer.echo(f"❌ 抓取失败:{result.get('error', '未知错误')}", err=True)
            raise typer.Exit(code=1)
        typer.echo(
            f"✅ 抓取完成:发现 {result['found']} 篇,新增 {result['new']}"
        )
    finally:
        db.close()
@cli_app.command()
def summarize(
    arxiv_id: str = typer.Argument(
        None,
        help="指定论文 arXiv ID;留空则批量处理所有 pending",
    ),
    pdf_mode: str = typer.Option(
        "auto",
        "--pdf-mode",
        help="PDF 传递方式:auto(自动选择)| inject(全量注入)| search(pi 自主搜索)",
    ),
    backend: str = typer.Option(
        None,
        "--backend",
        help="总结后端:pi | claude(留空则使用 .env 配置)",
    ),
):
    """手动触发 AI 总结。"""
    # NOTE: the docstring above is also the Typer `--help` text; kept verbatim.
    #
    # Run the AI summarizer for a single paper (arxiv_id given) or for the
    # batch of pending papers (arxiv_id omitted).
    #
    # * pdf_mode must be auto / inject / search; it is forwarded as-is to
    #   summarize_single / summarize_batch.
    # * backend, when given, must be pi / claude and overwrites
    #   settings.SUMMARY_BACKEND for this process.
    # * Exits with code 1 on an invalid option, a "conflict" result or a
    #   "not_found" result; any other non-success status is only printed.
    import os

    from app.config import settings
    from app.database import SessionLocal, engine
    from app.database import init_db as _init
    from app.services.summarizer import summarize_batch, summarize_single

    if pdf_mode not in ("auto", "inject", "search"):
        typer.echo(f"❌ 无效的 pdf_mode: {pdf_mode},只支持 auto / inject / search", err=True)
        raise typer.Exit(code=1)
    if backend:
        if backend not in ("pi", "claude"):
            typer.echo(f"❌ 无效的 backend: {backend},只支持 pi / claude", err=True)
            raise typer.Exit(code=1)
        # Override the configured backend for this run only.
        settings.SUMMARY_BACKEND = backend

    # Make sure the database directory and tables exist.
    os.makedirs(settings.db_path.parent, exist_ok=True)
    _init(engine)

    # Send INFO-level log records to the terminal so summarizer progress is visible.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)-5s %(name)s | %(message)s",
        datefmt="%H:%M:%S",
    )

    db = SessionLocal()
    try:
        if arxiv_id:
            typer.echo(f"🤖 开始总结 {arxiv_id} (mode={pdf_mode}) ...")
            result = asyncio.run(summarize_single(db, arxiv_id, pdf_mode=pdf_mode))
        else:
            typer.echo(f"🤖 开始批量总结 pending 论文 (mode={pdf_mode}) ...")
            result = asyncio.run(summarize_batch(db, pdf_mode=pdf_mode))

        status = result.get("status")
        if status in ("success", "done"):
            typer.echo(f"✅ 总结完成:{result}")
        elif status == "conflict":
            typer.echo("⚠️ 已有批量总结任务在运行中", err=True)
            raise typer.Exit(code=1)
        elif status == "not_found":
            typer.echo(f"❌ 论文未找到:{arxiv_id}", err=True)
            raise typer.Exit(code=1)
        else:
            # NOTE(review): unrecognised statuses exit 0; confirm that is intended.
            typer.echo(f"⚠️ 总结结果:{result}", err=True)
    finally:
        db.close()
@cli_app.command()
def init_db():
    """初始化数据库表。"""
    # (Docstring above is the Typer `--help` text and is kept verbatim.)
    # Create the database directory if needed, create the tables, and report
    # where the database file lives.
    import os

    from app.config import settings
    from app.database import engine
    from app.database import init_db as _init

    db_dir = settings.db_path.parent
    os.makedirs(db_dir, exist_ok=True)
    _init(engine)

    typer.echo(f"✅ 数据库已初始化:{settings.db_path}")
if __name__ == "__main__":
    # Run the Typer application when this module is executed as a script.
    cli_app()