import json
from pathlib import Path
from typing import List, Dict, Any
import argparse
import chromadb
from sentence_transformers import SentenceTransformer
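# Dependencies: pip install chromadb sentence-transformers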
# -------- Configuration --------
INPUT_JSON = "data/emb_chunks.json"  # { "fileA.json": [ {chunk, content, vector}, ... ], ... }
PERSIST_DIR = "chroma_store"
COLLECTION = "docs"
DISTANCE = "cosine"  # cosine / l2 / ip
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # for query embedding (document vectors are used as-is from the JSON)
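# A minimal sketch of the expected input file, following the shape noted above
# (the filenames, texts, and vector values here are hypothetical):
# {
#   "fileA.json": [
#     {"chunk": 0, "content": "first chunk text ...", "vector": [0.01, -0.42, ...]},
#     {"chunk": 1, "content": "second chunk text ...", "vector": [0.33, 0.08, ...]}
#   ]
# }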
# -------- Data loading --------
def load_preembedded(path: str):
    """
    Read the input JSON and flatten each chunk into (id, document, metadata, embedding).
    id: "{filename}::chunk::{chunk}"
    """
    raw = json.loads(Path(path).read_text(encoding="utf-8"))
    ids, docs, metas, embs = [], [], [], []
    for fname, chunks in raw.items():
        for ch in chunks:
            cid = ch.get("chunk")
            text = ch.get("content") or ""
            vec = ch.get("vector") or []
            if not vec:  # skip chunks that carry no embedding
                continue
            ids.append(f"{fname}::chunk::{cid}")
            docs.append(text)
            metas.append({"filename": fname, "chunk": cid})
            embs.append(vec)
    return ids, docs, metas, embs
# -------- Indexing / loading --------
def index_to_chroma(ids: List[str], docs: List[str], metas: List[Dict[str, Any]], embs: List[List[float]]):
    client = chromadb.PersistentClient(path=PERSIST_DIR)
    col = client.get_or_create_collection(
        name=COLLECTION,
        metadata={"hnsw:space": DISTANCE},  # cosine / Euclidean / inner product
    )
    # If needed, delete existing entries with the same ids first: col.delete(ids=ids)
    col.add(ids=ids, documents=docs, metadatas=metas, embeddings=embs)
    return col
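# Note: calling col.add() again with ids that already exist in the collection
# will fail or warn, depending on the chromadb version. A minimal sketch of an
# idempotent alternative, assuming a chromadb version that exposes upsert:
#   col.upsert(ids=ids, documents=docs, metadatas=metas, embeddings=embs)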
# -------- Retrieval --------
def retrieve(col, query_text: str, top_k: int = 4):
    # Document vectors are already in the DB -> only the query needs embedding.
    model = SentenceTransformer(EMBED_MODEL)
    qvec = model.encode([query_text])[0].tolist()
    res = col.query(query_embeddings=[qvec], n_results=top_k)
    # Reshape the nested result into a flat list of hits.
    hits = []
    for i in range(len(res["ids"][0])):
        hits.append({
            "id": res["ids"][0][i],
            "text": res["documents"][0][i],
            "metadata": res["metadatas"][0][i],
            "distance": res["distances"][0][i] if "distances" in res else None,
        })
    return hits
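# With "hnsw:space" set to "cosine", Chroma reports cosine *distance*, so a
# smaller value means a closer match; if a similarity score is needed,
# 1 - distance recovers it.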
# -------- CLI --------
if __name__ == "__main__":
    ap = argparse.ArgumentParser(description="Pre-embedded JSON → Chroma index & retrieval (no normalization).")
    ap.add_argument("--json", default=INPUT_JSON, help="pre-embedded JSON path")
    ap.add_argument("--query", default=None, help="run retrieval with this text query")
    ap.add_argument("--k", type=int, default=4, help="top-k")
    args = ap.parse_args()

    # 1) Index / load
    ids, docs, metas, embs = load_preembedded(args.json)
    if not ids:
        raise SystemExit("No vectors to load from the input JSON.")
    col = index_to_chroma(ids, docs, metas, embs)
    print(f"[OK] Indexed {len(ids)} chunks into Chroma @ {PERSIST_DIR}/{COLLECTION}")

    # 2) Retrieval (optional)
    if args.query:
        hits = retrieve(col, args.query, args.k)
        print(f"\n=== TOP-{args.k} for: {args.query} ===")
        for i, h in enumerate(hits, 1):
            meta = h["metadata"] or {}
            preview = (h["text"][:200] + "...") if len(h["text"]) > 200 else h["text"]
            print(f"[{i}] {h['id']} ({meta.get('filename')} #chunk={meta.get('chunk')}) dist={h['distance']}")
            print(preview.replace("\n", " "))
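# Example invocations (the script filename is hypothetical):
#   python index_chroma.py --json data/emb_chunks.json
#   python index_chroma.py --query "what is a vector database" --k 4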