
Python + FAISS: Build a RAG System in Five Minutes
Here's my problem: I had dozens (OK, honestly hundreds) of PDFs scattered across folders: research papers, API docs, whitepapers. Search was painfully slow and browsing was worse. So I built a PDF Q&A engine that ingests the files, chunks them, embeds them, indexes them with FAISS, retrieves the best passages, and produces a concise answer (with a fallback that needs no API). This post gives you the whole thing: end-to-end code, explained in plain language.
What you get
• Local PDF loading (no cloud)
• Smarter chunking (preserves context)
• Embeddings with sentence-transformers
• Vector search with FAISS (cosine) + SQLite for metadata
• Retriever → Answerer with an optional LLM (extractive-summary fallback)
• A clean single-page app with Gradio
Keywords (to help others find this post): AI document search, PDF search, vector database, FAISS, embeddings, Sentence Transformers, RAG, Gradio, OpenAI (optional).
Project structure (copy these straight into files)
pdfqa/
settings.py
loader.py
chunker.py
embedder.py
store.py
searcher.py
answerer.py
app.py
build_index.py
requirements.txt
0) Requirements
# requirements.txt
pdfplumber>=0.11.0
sentence-transformers>=3.0.1
faiss-cpu>=1.8.0
numpy>=1.26.4
scikit-learn>=1.5.1
tqdm>=4.66.4
gradio>=4.40.0
python-dotenv>=1.0.1
nltk>=3.9.1
rank-bm25>=0.2.2
openai>=1.40.0  # optional; everything runs without an API key
Install the requirements, then prepare NLTK (one time only):
python -c "import nltk; nltk.download('punkt_tab')" || python -c "import nltk; nltk.download('punkt')"
1) Settings
# settings.py
from pathlib import Path
from dataclasses import dataclass

@dataclass(frozen=True)
class Config:
    PDF_DIR: Path = Path("./pdfs")              # drop your PDFs here
    DB_PATH: Path = Path("./chunks.sqlite")     # SQLite for chunk metadata
    INDEX_PATH: Path = Path("./index.faiss")    # FAISS index file
    MODEL_NAME: str = "sentence-transformers/all-MiniLM-L12-v2"
    CHUNK_SIZE: int = 1000                      # target characters per chunk
    CHUNK_OVERLAP: int = 200
    TOP_K: int = 4                              # number of passages to retrieve
    MAX_ANSWER_TOKENS: int = 500                # for the LLM

CFG = Config()
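If you want to experiment with different chunk sizes or k values from a notebook or test script, dataclasses.replace can copy the frozen config with new values. A tiny sketch, not one of the project files:

# hypothetical experiment snippet
from dataclasses import replace
from settings import CFG

cfg = replace(CFG, CHUNK_SIZE=800, TOP_K=6)   # fresh Config instance; CFG itself stays frozen
print(cfg.CHUNK_SIZE, cfg.TOP_K)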
2) Load PDFs (local and fast)
# loader.py
import pdfplumber
from pathlib import Path
from typing import List, Dict
from settings import CFG

def load_pdfs(pdf_dir: Path = CFG.PDF_DIR) -> List[Dict]:
    pdf_dir.mkdir(parents=True, exist_ok=True)
    docs = []
    for pdf_path in sorted(pdf_dir.glob("*.pdf")):
        text_parts = []
        with pdfplumber.open(pdf_path) as pdf:
            for page in pdf.pages:
                # More robust than a naive get_text; tune as needed
                text_parts.append(page.extract_text() or "")
        text = "\n".join(text_parts).strip()
        if text:
            docs.append({"filename": pdf_path.name, "text": text})
            print(f"✅ Loaded {pdf_path.name} ({len(text)} chars)")
        else:
            print(f"⚠️ Empty or could not extract: {pdf_path.name}")
    return docs

if __name__ == "__main__":
    load_pdfs()
3) Chunking that preserves context (paragraph-aware)
# chunker.py
from typing import List, Dict
from settings import CFG

def _paragraphs(txt: str) -> List[str]:
    # Split on blank lines; keeps the structure lightweight
    blocks = [b.strip() for b in txt.split("\n\n") if b.strip()]
    return blocks or [txt]

def chunk_document(doc: Dict, size: int = CFG.CHUNK_SIZE, overlap: int = CFG.CHUNK_OVERLAP) -> List[Dict]:
    paras = _paragraphs(doc["text"])
    chunks = []
    buf, start_char = [], 0
    cur_len = 0
    for p in paras:
        if cur_len + len(p) + 1 <= size:
            buf.append(p)
            cur_len += len(p) + 1
            continue
        # Flush the buffer
        block = "\n\n".join(buf).strip()
        if block:
            chunks.append({
                "filename": doc["filename"],
                "start": start_char,
                "end": start_char + len(block),
                "text": block
            })
        # Carry an overlap from the tail of the previous block
        tail = block[-overlap:] if overlap > 0 and len(block) > overlap else ""
        buf = [tail, p] if tail else [p]
        start_char += max(0, len(block) - overlap)
        cur_len = len("\n\n".join(buf))
    # Last chunk
    block = "\n\n".join(buf).strip()
    if block:
        chunks.append({
            "filename": doc["filename"],
            "start": start_char,
            "end": start_char + len(block),
            "text": block
        })
    return chunks

def chunk_all(docs: List[Dict]) -> List[Dict]:
    out = []
    for d in docs:
        out.extend(chunk_document(d))
    print(f"🔹 Created {len(out)} chunks")
    return out

if __name__ == "__main__":
    from loader import load_pdfs
    all_chunks = chunk_all(load_pdfs())
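A quick sanity check helps when tuning size and overlap. The toy document below is made up purely for illustration; it just prints the span and length of each chunk so you can see how the two settings interact:

# hypothetical sanity check, not one of the project files
from chunker import chunk_document

toy = {"filename": "toy.pdf", "text": ("One paragraph of filler text. " * 20 + "\n\n") * 6}
for c in chunk_document(toy, size=300, overlap=50):
    # note: a single paragraph longer than `size` is never split, so chunks can exceed `size`
    print(c["start"], c["end"], len(c["text"]))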
4) Embeddings (cosine-ready)
# embedder.py
import numpy as np
from typing import List, Dict
from sentence_transformers import SentenceTransformer
from settings import CFG
from tqdm import tqdm
from sklearn.preprocessing import normalize

_model = None

def get_model() -> SentenceTransformer:
    global _model
    if _model is None:
        _model = SentenceTransformer(CFG.MODEL_NAME)
    return _model

def embed_texts(chunks: List[Dict]) -> np.ndarray:
    model = get_model()
    texts = [c["text"] for c in chunks]
    # encode, then L2-normalize so that inner product == cosine similarity
    vecs = model.encode(texts, show_progress_bar=True, convert_to_numpy=True)
    return normalize(vecs)  # important for IndexFlatIP
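As a small check that the normalization did what we want: the inner product of two normalized vectors equals their cosine similarity, and every row norm should be close to 1.0. The two toy chunks below are made up:

# hypothetical check, not one of the project files
import numpy as np
from embedder import embed_texts

chunks = [{"text": "FAISS builds vector indexes."},
          {"text": "FAISS is a library for similarity search."}]
v = embed_texts(chunks)
print(float(v[0] @ v[1]))           # cosine similarity of the two chunks
print(np.linalg.norm(v, axis=1))    # each norm ≈ 1.0 after normalize()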
5) Store vectors + metadata (FAISS + SQLite)
# store.py
import sqlite3
import faiss
import numpy as np
from typing import List, Dict
from settings import CFG

def init_db():
    con = sqlite3.connect(CFG.DB_PATH)
    cur = con.cursor()
    cur.execute("""
        CREATE TABLE IF NOT EXISTS chunks (
            id INTEGER PRIMARY KEY,
            filename TEXT,
            start INTEGER,
            end INTEGER,
            text TEXT
        )
    """)
    con.commit()
    con.close()

def save_chunks(chunks: List[Dict]):
    con = sqlite3.connect(CFG.DB_PATH)
    cur = con.cursor()
    cur.execute("DELETE FROM chunks")
    cur.executemany(
        "INSERT INTO chunks (filename, start, end, text) VALUES (?,?,?,?)",
        [(c["filename"], c["start"], c["end"], c["text"]) for c in chunks]
    )
    con.commit()
    con.close()

def build_faiss_index(vecs: np.ndarray):
    dim = vecs.shape[1]
    index = faiss.IndexFlatIP(dim)  # cosine, because we normalized the vectors
    index.add(vecs.astype(np.float32))
    faiss.write_index(index, str(CFG.INDEX_PATH))
    print(f"📦 FAISS index saved to {CFG.INDEX_PATH}")

def read_faiss_index() -> faiss.Index:
    return faiss.read_index(str(CFG.INDEX_PATH))

def get_chunk_by_ids(ids: List[int]) -> List[Dict]:
    con = sqlite3.connect(CFG.DB_PATH)
    cur = con.cursor()
    rows = []
    for i in ids:
        cur.execute("SELECT id, filename, start, end, text FROM chunks WHERE id=?", (i + 1,))
        r = cur.fetchone()
        if r:
            rows.append({
                "id": r[0] - 1, "filename": r[1], "start": r[2], "end": r[3], "text": r[4]
            })
    con.close()
    return rows
Note: SQLite row IDs start at 1, while FAISS vectors are indexed from 0. Since we store chunks in insertion order, the lookup key is simply (faiss_id + 1).
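If that off-by-one bookkeeping makes you nervous, FAISS can also store your own integer IDs next to the vectors. A minimal sketch using IndexIDMap2, as an alternative to the build_faiss_index above (not what this project uses):

# hypothetical variant of build_faiss_index
from typing import List
import faiss
import numpy as np

def build_faiss_index_with_ids(vecs: np.ndarray, row_ids: List[int]) -> faiss.Index:
    base = faiss.IndexFlatIP(vecs.shape[1])
    index = faiss.IndexIDMap2(base)                   # wraps the flat index and keeps your IDs
    ids = np.asarray(row_ids, dtype=np.int64)         # e.g. the SQLite row IDs
    index.add_with_ids(vecs.astype(np.float32), ids)
    return index                                      # search() then returns your IDs directly

With this variant, get_chunk_by_ids could query WHERE id=? without the +1 adjustment.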
6) Search (embed the query → find the nearest neighbors)
# searcher.py
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from sklearn.preprocessing import normalize
from settings import CFG
from store import read_faiss_index, get_chunk_by_ids

_qmodel = None

def _qembed(q: str) -> np.ndarray:
    global _qmodel
    if _qmodel is None:
        _qmodel = SentenceTransformer(CFG.MODEL_NAME)
    qv = _qmodel.encode([q], convert_to_numpy=True)
    return normalize(qv)  # cosine, consistent with the corpus embeddings

def search(query: str, k: int = CFG.TOP_K):
    index: faiss.Index = read_faiss_index()
    qv = _qembed(query)
    D, I = index.search(qv.astype(np.float32), k)
    ids = I[0].tolist()
    return get_chunk_by_ids(ids)
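Retrieval is worth testing on its own before wiring up the answerer. A quick command-line check (the query below is just an example):

# hypothetical check; run after build_index.py
from searcher import search

if __name__ == "__main__":
    for hit in search("What is attention in transformers?"):
        print(f"{hit['filename']} [{hit['start']}-{hit['end']}]: {hit['text'][:80]}...")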
7) Answer generation (optional LLM + extractive fallback)
# answerer.py
import os, re
from typing import List, Dict
from rank_bm25 import BM25Okapi
from settings import CFG

SYSTEM_PROMPT = (
    "Answer only from the provided context.\n"
    "Cite passages as [1], [2], ... in snippet order.\n"
    "If the information is insufficient, say so briefly.\n"
)

def _try_openai(question: str, snippets: List[str]) -> str:
    try:
        from openai import OpenAI
        client = OpenAI()  # requires the OPENAI_API_KEY environment variable
        ctx = "\n\n".join(f"[{i+1}] {s}" for i, s in enumerate(snippets))
        prompt = f"{SYSTEM_PROMPT}\nContext:\n{ctx}\n\nQuestion: {question}\nAnswer:"
        resp = client.chat.completions.create(
            model=os.getenv("LLM_MODEL", "gpt-4o-mini"),
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
            max_tokens=CFG.MAX_ANSWER_TOKENS
        )
        return resp.choices[0].message.content or ""
    except Exception:
        return ""

def _extractive_fallback(question: str, snippets: List[str]) -> str:
    # Score each sentence in the snippets with BM25 and stitch the best ones into a short summary
    sents, source_ids = [], []
    for i, s in enumerate(snippets):
        for sent in re.split(r"(?<=[.!?])\s+", s):
            if sent.strip():
                sents.append(sent.strip())
                source_ids.append(i)
    if not sents:
        return "I don't have enough information to answer that."
    tokenized = [st.lower().split() for st in sents]
    bm25 = BM25Okapi(tokenized)
    scores = bm25.get_scores(question.lower().split())
    ranked = sorted(zip(sents, source_ids, scores), key=lambda x: x[2], reverse=True)[:6]
    stitched = []
    used_sources = set()
    for sent, sid, _ in ranked:
        stitched.append(sent + f" [{sid+1}]")
        used_sources.add(sid + 1)
    return " ".join(stitched) or "I don't have enough information to answer that."

def answer(question: str, passages: List[Dict]) -> str:
    snippets = [p["text"] for p in passages]
    ans = _try_openai(question, snippets)
    return ans if ans.strip() else _extractive_fallback(question, snippets)
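The fallback path is easy to try without any API key set. With made-up passages it looks like this:

# hypothetical check; the passages are invented for the demo
from answerer import answer

passages = [
    {"text": "FAISS supports exact and approximate nearest-neighbor search on dense vectors."},
    {"text": "Sentence Transformers map sentences to dense embeddings for semantic search."},
]
print(answer("What does FAISS do?", passages))   # falls back to BM25 extraction if OPENAI_API_KEY is missing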
8) Build script (one-time indexing)
# build_index.py
from loader import load_pdfs
from chunker import chunk_all
from embedder import embed_texts
from store import init_db, save_chunks, build_faiss_index

if __name__ == "__main__":
    docs = load_pdfs()
    chunks = chunk_all(docs)
    init_db()
    save_chunks(chunks)
    vecs = embed_texts(chunks)
    build_faiss_index(vecs)
    print("✅ Index built! Start asking questions.")
Run it:
python build_index.py
9) A clean web app (Gradio)
# app.py
import gradio as gr
from searcher import search
from answerer import answer

def ask(query: str):
    if not query.strip():
        return "Type a question to get started.", ""
    results = search(query, k=4)
    ctx = "\n\n---\n\n".join([r["text"] for r in results])
    ans = answer(query, results)
    cites = "\n".join(f"[{i+1}] {r['filename']} ({r['start']}–{r['end']})" for i, r in enumerate(results))
    return ans, cites

with gr.Blocks(title="PDF Q&A") as demo:
    gr.Markdown("## 📚 PDF Q&A — Ask anything about your documents")
    inp = gr.Textbox(label="Your question", placeholder="e.g. explain attention in transformers")
    btn = gr.Button("Search & Answer")
    out = gr.Markdown(label="Answer")
    refs = gr.Markdown(label="Citations")
    btn.click(fn=ask, inputs=inp, outputs=[out, refs])

if __name__ == "__main__":
    demo.launch()
Run it:
python app.py
Lessons learned (so you don't repeat my mistakes)
• Retrieval quality = answer quality. Accurate top-k matters more than fancy prompts.
• Chunking is a balancing act. Paragraph-aware merging plus a small overlap stays readable and keeps context.
• Cosine + normalization matters. L2-normalize the embeddings and use IndexFlatIP so cosine is exact in FAISS.
• Make the LLM optional and the fallback mandatory. Don't let the tool depend on an API key; the extractive approach with BM25 sentence ranking works surprisingly well.
• Metadata saves time. Store (filename, start, end) and you can deep-link or show citations instantly.
Next upgrades
• Semantic chunking (heading- and table-of-contents-aware)
• Rerankers (Cohere Rerank or a BGE cross-encoder) to refine the final list; see the sketch after this list
• Conversation memory (for follow-up questions)
• Persistence in a vector DB (Weaviate, Qdrant, etc.)
• Server deployment (Docker + a small FastAPI wrapper)
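For the reranker idea above, a minimal sketch with a sentence-transformers CrossEncoder; the model name and the rerank() helper are my own assumptions, not part of the project:

# hypothetical reranker sketch
from typing import Dict, List
from sentence_transformers import CrossEncoder

def rerank(query: str, passages: List[Dict], top_n: int = 4) -> List[Dict]:
    ce = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")     # assumed model choice
    scores = ce.predict([(query, p["text"]) for p in passages])   # one relevance score per (query, text) pair
    ranked = sorted(zip(passages, scores), key=lambda x: x[1], reverse=True)
    return [p for p, _ in ranked[:top_n]]

In practice you would retrieve a wider candidate set first, e.g. rerank(query, search(query, k=20)).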
Mini FAQ
Can it run fully offline? Yes. Embeddings + FAISS + the extractive fallback are all local; the LLM is optional.
Can it handle thousands of chunks? Yes. FAISS scales well on CPU. If the corpus gets really large, switch to an IVF or HNSW index (see the sketch below).
Why Gradio instead of Streamlit? Gradio is lightweight and quick to wire up. Use whatever you prefer; porting is easy.
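To make the IVF suggestion concrete, here is a hedged sketch of a variant of build_faiss_index for large corpora; the nlist and nprobe values are assumptions you should tune:

# hypothetical IVF variant for large corpora
import faiss
import numpy as np

def build_ivf_index(vecs: np.ndarray, nlist: int = 256) -> faiss.Index:
    dim = vecs.shape[1]
    quantizer = faiss.IndexFlatIP(dim)                  # coarse quantizer for the clusters
    index = faiss.IndexIVFFlat(quantizer, dim, nlist, faiss.METRIC_INNER_PRODUCT)
    index.train(vecs.astype(np.float32))                # IVF must be trained before add()
    index.add(vecs.astype(np.float32))
    index.nprobe = 16                                   # more probes = better recall, slower queries
    return index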
A few useful official docs
• Sentence Transformers: https://www.sbert.net/
• FAISS: https://github.com/facebookresearch/faiss
• Gradio: https://www.gradio.app/
Reprinted from PyTorch研习社; author: AI研究生.
