Python + FAISS: Build a RAG System in Five Minutes

Published 2025-09-17 00:02

I had a problem: dozens (okay, hundreds, really) of PDF files, research papers, API docs, white papers, scattered across folders. Searching them was painfully slow and browsing was worse. So I built a PDF Q&A engine that ingests the files, chunks them, embeds them, indexes them with FAISS, retrieves the best passages, and produces a concise answer (with a fallback that needs no API). This post gives you the whole thing: end-to-end code, explained in plain language.

What you get

• Local PDF loading (no cloud)

• Smarter chunking (keeps context)

• Embeddings with sentence-transformers

• Vector search with FAISS (cosine) + SQLite for metadata

• Retriever → Answerer, with an optional LLM (extractive-summary fallback)

• A clean single-page app with Gradio

Keywords (so others can find this post): AI document search, PDF search, vector database, FAISS, embeddings, Sentence Transformers, RAG, Gradio, OpenAI (optional).

Project structure (copy these straight into files)

pdfqa/
  settings.py
  loader.py
  chunker.py
  embedder.py
  store.py
  searcher.py
  answerer.py
  app.py
  build_index.py
  requirements.txt

0) Requirements

# requirements.txt
pdfplumber>=0.11.0
sentence-transformers>=3.0.1
faiss-cpu>=1.8.0
numpy>=1.26.4
scikit-learn>=1.5.1
tqdm>=4.66.4
gradio>=4.40.0
python-dotenv>=1.0.1
nltk>=3.9.1
rank-bm25>=0.2.2
openai>=1.40.0     # optional; runs without an API key

Install and prepare NLTK (only needed once):

python -c "import nltk; nltk.download('punkt_tab')" || python -c "import nltk; nltk.download('punkt')"
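
The rest of the packages install the usual way (a virtual environment is recommended; openai can be skipped if you never set an API key):

pip install -r requirements.txt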

1) Settings

# settings.py
from pathlib import Path
from dataclasses import dataclass

@dataclass(frozen=True)
class Config:
    PDF_DIR: Path = Path("./pdfs")                 # drop your PDF files here
    DB_PATH: Path = Path("./chunks.sqlite")        # SQLite for chunk metadata
    INDEX_PATH: Path = Path("./index.faiss")       # FAISS index file
    MODEL_NAME: str = "sentence-transformers/all-MiniLM-L12-v2"
    CHUNK_SIZE: int = 1000                         # target characters per chunk
    CHUNK_OVERLAP: int = 200
    TOP_K: int = 4                                 # number of passages to retrieve
    MAX_ANSWER_TOKENS: int = 500                   # for the LLM
CFG = Config()

2) Load PDFs (local, fast)

# loader.py
import pdfplumber
from pathlib import Path
from typing import List, Dict
from settings import CFG

def load_pdfs(pdf_dir: Path = CFG.PDF_DIR) -> List[Dict]:
    pdf_dir.mkdir(parents=True, exist_ok=True)
    docs = []
    for pdf_path in sorted(pdf_dir.glob("*.pdf")):
        text_parts = []
        with pdfplumber.open(pdf_path) as pdf:
            for page in pdf.pages:
                # more robust than a naive get_text(); tune as needed
                text_parts.append(page.extract_text() or "")
        text = "\n".join(text_parts).strip()
        if text:
            docs.append({"filename": pdf_path.name, "text": text})
            print(f"✅ 加载 {pdf_path.name} ({len(text)} 字符)")
        else:
            print(f"⚠️ 空的或无法提取:{pdf_path.name}")
    return docs

if __name__ == "__main__":
    load_pdfs()

3) Chunking that keeps context (paragraph-aware)

# chunker.py
from typing import List, Dict
from settings import CFG

def _paragraphs(txt: str) -> List[str]:
    # split on blank lines; keeps the structure lightweight
    blocks = [b.strip() for b in txt.split("\n\n") if b.strip()]
    return blocks or [txt]

def chunk_document(doc: Dict, size: int = CFG.CHUNK_SIZE, overlap: int = CFG.CHUNK_OVERLAP) -> List[Dict]:
    paras = _paragraphs(doc["text"])
    chunks = []
    buf, start_char = [], 0
    cur_len = 0

    for p in paras:
        if cur_len + len(p) + 1 <= size:
            buf.append(p)
            cur_len += len(p) + 1
            continue
        # flush the buffer
        block = "\n\n".join(buf).strip()
        if block:
            chunks.append({
                "filename": doc["filename"],
                "start": start_char,
                "end": start_char + len(block),
                "text": block
            })
        # carry the tail over as overlap
        tail = block[-overlap:] if overlap > 0 and len(block) > overlap else ""
        buf = [tail, p] if tail else [p]
        start_char += max(0, len(block) - overlap)
        cur_len = len("\n\n".join(buf))

    # final chunk
    block = "\n\n".join(buf).strip()
    if block:
        chunks.append({
            "filename": doc["filename"],
            "start": start_char,
            "end": start_char + len(block),
            "text": block
        })
    return chunks

def chunk_all(docs: List[Dict]) -> List[Dict]:
    out = []
    for d in docs:
        out.extend(chunk_document(d))
    print(f"🔹 创建了 {len(out)} 个 chunk")
    return out

if __name__ == "__main__":
    from loader import load_pdfs
    all_chunks = chunk_all(load_pdfs())

4) Embeddings (cosine-ready)

# embedder.py
import numpy as np
from typing import List, Dict
from sentence_transformers import SentenceTransformer
from settings import CFG
from tqdm import tqdm
from sklearn.preprocessing import normalize

_model = None

def get_model() -> SentenceTransformer:
    global _model
    if _model is None:
        _model = SentenceTransformer(CFG.MODEL_NAME)
    return _model

def embed_texts(chunks: List[Dict]) -> np.ndarray:
    model = get_model()
    texts = [c["text"] for c in chunks]
    # encode → L2-normalize so that inner product == cosine similarity
    vecs = model.encode(texts, show_progress_bar=True, convert_to_numpy=True)
    return normalize(vecs)  # important for IndexFlatIP
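
A toy check (not part of the pipeline) of why the normalization matters: the inner product of two L2-normalized vectors is exactly their cosine similarity, which is what IndexFlatIP computes.

# standalone toy check
import numpy as np
from sklearn.preprocessing import normalize

a = np.array([[3.0, 4.0]])
b = np.array([[4.0, 3.0]])
cosine = float(a @ b.T) / (np.linalg.norm(a) * np.linalg.norm(b))  # 0.96
inner = float(normalize(a) @ normalize(b).T)                       # also 0.96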

5) Store vectors + metadata (FAISS + SQLite)

# store.py
import sqlite3
import faiss
import numpy as np
from typing import List, Dict
from settings import CFG

def init_db():
    con = sqlite3.connect(CFG.DB_PATH)
    cur = con.cursor()
    cur.execute("""
      CREATE TABLE IF NOT EXISTS chunks (
        id INTEGER PRIMARY KEY,
        filename TEXT,
        start INTEGER,
        "end" INTEGER,   -- quoted to avoid clashing with the SQL keyword END
        text TEXT
      )
    """)
    con.commit()
    con.close()

def save_chunks(chunks: List[Dict]):
    con = sqlite3.connect(CFG.DB_PATH)
    cur = con.cursor()
    cur.execute("DELETE FROM chunks")
    cur.executemany(
        "INSERT INTO chunks (filename, start, end, text) VALUES (?,?,?,?)",
        [(c["filename"], c["start"], c["end"], c["text"]) for c in chunks]
    )
    con.commit()
    con.close()

def build_faiss_index(vecs: np.ndarray):
    dim = vecs.shape[1]
    index = faiss.IndexFlatIP(dim)  # cosine (because we normalized the vectors)
    index.add(vecs.astype(np.float32))
    faiss.write_index(index, str(CFG.INDEX_PATH))
    print(f"📦 FAISS index saved to {CFG.INDEX_PATH}")

def read_faiss_index() -> faiss.Index:
    return faiss.read_index(str(CFG.INDEX_PATH))

def get_chunk_by_ids(ids: List[int]) -> List[Dict]:
    con = sqlite3.connect(CFG.DB_PATH)
    cur = con.cursor()
    rows = []
    for i in ids:
        cur.execute("SELECT id, filename, start, end, text FROM chunks WHERE id=?", (i+1,))
        r = cur.fetchone()
        if r:
            rows.append({
                "id": r[0]-1, "filename": r[1], "start": r[2], "end": r[3], "text": r[4]
            })
    con.close()
    return rows

Note: SQLite row IDs start at 1, while FAISS vectors are indexed from 0. We store chunks in insertion order, so lookups use (faiss_id + 1).
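
If you'd rather not rely on insertion order, FAISS can also carry explicit IDs via IndexIDMap. A minimal sketch of that variant of build_faiss_index, assuming vecs is the normalized matrix from embed_texts():

# sketch: have FAISS return your SQLite rowids directly
import numpy as np
import faiss
from settings import CFG

def build_faiss_index_with_ids(vecs: np.ndarray):
    dim = vecs.shape[1]
    index = faiss.IndexIDMap(faiss.IndexFlatIP(dim))
    ids = np.arange(1, len(vecs) + 1).astype(np.int64)   # 1-based, matching SQLite rowids
    index.add_with_ids(vecs.astype(np.float32), ids)
    faiss.write_index(index, str(CFG.INDEX_PATH))
# index.search() then returns those IDs, and the +1 shift in get_chunk_by_ids goes away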

6) Search (embed the query → find nearest neighbors)

# searcher.py
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from sklearn.preprocessing import normalize
from settings import CFG
from store import read_faiss_index, get_chunk_by_ids

_qmodel = None
def _qembed(q: str) -> np.ndarray:
    global _qmodel
    if _qmodel is None:
        _qmodel = SentenceTransformer(CFG.MODEL_NAME)
    qv = _qmodel.encode([q], convert_to_numpy=True)
    return normalize(qv)  # cosine, consistent with the corpus

def search(query: str, k: int = CFG.TOP_K):
    index: faiss.Index = read_faiss_index()
    qv = _qembed(query)
    D, I = index.search(qv.astype(np.float32), k)
    ids = I[0].tolist()
    return get_chunk_by_ids(ids)
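
Once the index exists (step 8 below), a quick smoke test from a REPL or script might look like this (the query text is just an example):

# smoke test: requires build_index.py to have run first
from searcher import search

for hit in search("what is attention in transformers?", k=3):
    print(hit["filename"], hit["text"][:80])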

7) Answer generation (optional LLM + extractive fallback)

# answerer.py
import os, re
from typing import List, Dict
from rank_bm25 import BM25Okapi

SYSTEM_PROMPT = (
    "Answer only from the provided context.\n"
    "Cite passages as [1], [2], ... in snippet order.\n"
    "If the information is insufficient, say so briefly.\n"
)

def _try_openai(question: str, snippets: List[str]) -> str:
    try:
        from openai import OpenAI
        client = OpenAI()  # requires the OPENAI_API_KEY environment variable
        ctx = "\n\n".join(f"[{i+1}] {s}" for i, s in enumerate(snippets))
        prompt = f"{SYSTEM_PROMPT}\nContext:\n{ctx}\n\nQuestion: {question}\nAnswer:"
        resp = client.chat.completions.create(
            model=os.getenv("LLM_MODEL", "gpt-4o-mini"),
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
            max_tokens=500
        )
        return resp.choices[0].message.content
    except Exception:
        return""

def _extractive_fallback(question: str, snippets: List[str]) -> str:
    # score sentences from the snippets with BM25 and stitch a short summary
    sents, source_ids = [], []
    for i, s in enumerate(snippets):
        for sent in re.split(r"(?<=[.!?])\s+", s):
            if sent.strip():
                sents.append(sent.strip())
                source_ids.append(i)
    tokenized = [st.lower().split() for st in sents]
    bm25 = BM25Okapi(tokenized)
    scores = bm25.get_scores(question.lower().split())
    ranked = sorted(zip(sents, source_ids, scores), key=lambda x: x[2], reverse=True)[:6]
    stitched = []
    used_sources = set()
    for sent, sid, _ in ranked:
        stitched.append(sent + f" [{sid+1}]")
        used_sources.add(sid+1)
    return" ".join(stitched) or"我没有足够的信息来回答。"

def answer(question: str, passages: List[Dict]) -> str:
    snippets = [p["text"] for p in passages]
    ans = _try_openai(question, snippets)
    return ans if ans.strip() else _extractive_fallback(question, snippets)
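
To enable the LLM path, set OPENAI_API_KEY in your environment before launching (and optionally LLM_MODEL to override the gpt-4o-mini default). If the key is missing or the call fails, answer() silently falls back to the BM25 extractive summary.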

8) Build script (one-time indexing)

# build_index.py
from loader import load_pdfs
from chunker import chunk_all
from embedder import embed_texts
from store import init_db, save_chunks, build_faiss_index

if __name__ == "__main__":
    docs = load_pdfs()
    chunks = chunk_all(docs)
    init_db()
    save_chunks(chunks)
    vecs = embed_texts(chunks)
    build_faiss_index(vecs)
    print("✅ 索引完成!可以开始提问了!")

Run it:

python build_index.py

9) A clean web app (Gradio)

# app.py
import gradio as gr
from searcher import search
from answerer import answer

def ask(query: str):
    if not query.strip():
        return "Type a question to get started.", ""
    results = search(query, k=4)
    ctx = "\n\n---\n\n".join([r["text"] for r in results])
    ans = answer(query, results)
    cites = "\n".join(f"[{i+1}] {r['filename']} ({r['start']}–{r['end']})" for i, r in enumerate(results))
    return ans, cites

with gr.Blocks(title="PDF Q&A") as demo:
    gr.Markdown("## 📚 PDF Q&A — 从你的文档中问任何问题")
    inp = gr.Textbox(label="你的问题", placeholder="例如:解释 transformers 中的 attention")
    btn = gr.Button("搜索并回答")
    out = gr.Markdown(label="回答")
    refs = gr.Markdown(label="引用")
    btn.click(fn=ask, inputs=inp, outputs=[out, refs])

if __name__ == "__main__":
    demo.launch()

Run it:

python app.py
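
By default the app stays on your machine; if you want to share it temporarily, demo.launch(share=True) asks Gradio to create a public link.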

Lessons learned (so you don't repeat my mistakes)

Retrieval quality = answer quality. An accurate top-k matters more than fancy prompting.

Chunking is a balancing act. Paragraph-aware merging plus a small overlap keeps chunks readable while preserving context.

Cosine + normalization matters. L2-normalize the embeddings and use IndexFlatIP so the scores FAISS returns are true cosine similarities.

Make the LLM optional and the fallback mandatory. Don't let the tool depend on an API key; the extractive approach with BM25 sentence ranking works surprisingly well.

Metadata saves time. Store (filename, start, end) and you can deep-link or show citations immediately.

Next upgrades

• Semantic chunking (heading- and table-of-contents-aware)

• Rerankers (Cohere Rerank or a BGE cross-encoder) to refine the final list; see the sketch after this list

• Conversation memory (for follow-up questions)

• Persistence in a vector DB (Weaviate, Qdrant, etc.)

• Server-side deployment (Docker + a small FastAPI wrapper)
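
A minimal reranking sketch, assuming you pull in a cross-encoder such as cross-encoder/ms-marco-MiniLM-L-6-v2 from sentence-transformers and layer it on top of the existing search() (illustrative only, not part of the code above):

# rerank.py (sketch)
from sentence_transformers import CrossEncoder
from searcher import search

_reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

def search_reranked(query: str, k: int = 4, candidates: int = 20):
    # over-fetch with FAISS, then let the cross-encoder reorder by relevance
    hits = search(query, k=candidates)
    scores = _reranker.predict([(query, h["text"]) for h in hits])
    ranked = sorted(zip(hits, scores), key=lambda x: x[1], reverse=True)
    return [h for h, _ in ranked[:k]]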

Mini FAQ

Can it run fully offline? Yes. Embeddings, FAISS, and the extractive fallback are all local; the LLM is optional.

Can it handle thousands of chunks? Yes. FAISS scales well on CPU. If your corpus gets really large, switch to an IVF or HNSW index (see the sketch after this FAQ).

Why Gradio instead of Streamlit? Gradio is lightweight and quick to wire up. Use whatever you prefer; porting is trivial.
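
For reference, a rough IVF variant of build_faiss_index, sketched under the assumption that you have comfortably more vectors than nlist clusters (nlist and nprobe are tunables, not values from the code above):

# sketch: approximate search for large corpora
import faiss
import numpy as np
from settings import CFG

def build_ivf_index(vecs: np.ndarray, nlist: int = 256):
    dim = vecs.shape[1]
    quantizer = faiss.IndexFlatIP(dim)                        # coarse quantizer, still inner-product
    index = faiss.IndexIVFFlat(quantizer, dim, nlist, faiss.METRIC_INNER_PRODUCT)
    index.train(vecs.astype(np.float32))                      # IVF indexes need a training pass
    index.add(vecs.astype(np.float32))
    index.nprobe = 16                                         # clusters probed per query (speed/recall trade-off)
    faiss.write_index(index, str(CFG.INDEX_PATH))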

A few useful official docs

• Sentence Transformers: https://www.sbert.net/

• FAISS: https://github.com/facebookresearch/faiss

• Gradio: https://www.gradio.app/

Reposted from PyTorch研习社; author: AI研究生.
