
chroma vectordb

by 후이 (hui) 2025. 10. 15.
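
A small script that loads pre-embedded chunks from a JSON file into a persistent Chroma collection and, optionally, runs a retrieval query against it (document vectors are stored as-is, with no normalization).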

import json
from pathlib import Path
from typing import List, Dict, Any
import argparse

import chromadb
from sentence_transformers import SentenceTransformer

# -------- Configuration --------
INPUT_JSON = "data/emb_chunks.json"   # { "fileA.json": [ {chunk, content, vector}, ... ], ... }
PERSIST_DIR = "chroma_store"
COLLECTION = "docs"
DISTANCE = "cosine"                   # cosine / l2 / ip
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # for query embedding (document vectors come straight from the JSON)
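
# For reference, a minimal sketch of the expected input JSON (keys match what
# load_preembedded() reads below; the values are illustrative, not from the post):
# {
#   "fileA.json": [
#     {"chunk": 0, "content": "first chunk text ...", "vector": [0.012, -0.034, ...]},
#     {"chunk": 1, "content": "second chunk text ...", "vector": [...]}
#   ]
# }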

# -------- Data loading --------
def load_preembedded(path: str):
    """
    입력 JSON을 읽어 각 청크를 (id, document, metadata, embedding)로 평탄화.
    id: "{filename}::chunk::{chunk}"
    """
    raw = json.loads(Path(path).read_text(encoding="utf-8"))
    ids, docs, metas, embs = [], [], [], []
    for fname, chunks in raw.items():
        for ch in chunks:
            cid = ch.get("chunk")
            text = ch.get("content") or ""
            vec = ch.get("vector") or []
            if not vec:
                continue
            ids.append(f"{fname}::chunk::{cid}")
            docs.append(text)
            metas.append({"filename": fname, "chunk": cid})
            embs.append(vec)
    return ids, docs, metas, embs
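
# Example of one flattened record (illustrative):
#   id       = "fileA.json::chunk::0"
#   document = "first chunk text ..."
#   metadata = {"filename": "fileA.json", "chunk": 0}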

# -------- Indexing / loading --------
def index_to_chroma(ids: List[str], docs: List[str], metas: List[Dict[str, Any]], embs: List[List[float]]):
    client = chromadb.PersistentClient(path=PERSIST_DIR)
    col = client.get_or_create_collection(
        name=COLLECTION,
        metadata={"hnsw:space": DISTANCE},  # cosine / euclidean (l2) / inner product (ip)
    )
    # When re-running, existing ids can be deleted first (col.delete(ids=ids)); col.upsert(...) also overwrites in place.
    col.add(ids=ids, documents=docs, metadatas=metas, embeddings=embs)
    return col
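
# Optional sanity check: after add(), col.count() should be at least len(ids)
# (larger if the collection already held other documents).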

# -------- Retrieval --------
def retrieve(col, query_text: str, top_k: int = 4):
    # Document vectors are already in the DB, so only the query text needs embedding.
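    # NOTE: the query model must match the one that produced the stored document
    # vectors (same model, hence same dimensionality and vector space).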
    model = SentenceTransformer(EMBED_MODEL)
    qvec = model.encode([query_text])[0].tolist()
    res = col.query(query_embeddings=[qvec], n_results=top_k)
    # Reshape the raw query result into a flat list of hits
    hits = []
    for i in range(len(res["ids"][0])):
        hits.append({
            "id": res["ids"][0][i],
            "text": res["documents"][0][i],
            "metadata": res["metadatas"][0][i],
            "distance": res["distances"][0][i] if "distances" in res else None,
        })
    return hits
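
# Note: with hnsw:space="cosine", Chroma reports cosine *distance*
# (1 - cosine similarity), so smaller values mean closer matches.
# query() also accepts a metadata filter to restrict hits to one source file,
# e.g. col.query(query_embeddings=[qvec], n_results=top_k, where={"filename": "fileA.json"}).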

# -------- CLI --------
if __name__ == "__main__":
    ap = argparse.ArgumentParser(description="Pre-embedded JSON → Chroma index & retrieval (no normalization).")
    ap.add_argument("--json", default=INPUT_JSON, help="pre-embedded JSON path")
    ap.add_argument("--query", default=None, help="run retrieval with this text query")
    ap.add_argument("--k", type=int, default=4, help="top-k")
    args = ap.parse_args()

    # 1) Index / load
    ids, docs, metas, embs = load_preembedded(args.json)
    if not ids:
        raise SystemExit("No vectors to load from the input JSON.")
    col = index_to_chroma(ids, docs, metas, embs)
    print(f"[OK] Indexed {len(ids)} chunks into Chroma @ {PERSIST_DIR}/{COLLECTION}")

    # 2) Retrieval (optional)
    if args.query:
        hits = retrieve(col, args.query, args.k)
        print(f"\n=== TOP-{args.k} for: {args.query} ===")
        for i, h in enumerate(hits, 1):
            meta = h["metadata"] or {}
            preview = (h["text"][:200] + "...") if len(h["text"]) > 200 else h["text"]
            print(f"[{i}] {h['id']}  ({meta.get('filename')} #chunk={meta.get('chunk')})  dist={h['distance']}")
            print(preview.replace("\n", " "))
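
To index and run a query in one step (assuming the script is saved as, say, index_chroma.py; the filename is not given in the post):

    python index_chroma.py --json data/emb_chunks.json --query "your question here" --k 4

Running it without --query only builds (or appends to) the persistent index under chroma_store/.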
