diff --git a/scripts/populate_db.py b/scripts/populate_db.py
index a84be5c..ac12640 100644
--- a/scripts/populate_db.py
+++ b/scripts/populate_db.py
@@ -1,5 +1,4 @@
 import requests
-import ast
 import hashlib
 from datasets import load_dataset
 from typing import Dict, List
@@ -26,18 +25,20 @@ def upsert_point(payload: Dict) -> None:
 
 
 def load_and_format_dataset() -> List[Dict]:
-    # Load dataset from HuggingFace
-    dataset = load_dataset("grafanalabs/promql-test-data")
-    data = dataset["test"]
-    data = [{**row, "embeddings": ast.literal_eval(row["embeddings"])} for row in data]
-    return data
+    # Load dataset with embeddings from HuggingFace
+    dataset = load_dataset(
+        "grafanalabs/promql-templates",
+        data_files="promql-templates-bge-small-en-embeddings.parquet",
+        split="train",
+    )
+    return dataset
 
 
 def generate_payload(data: List[Dict]) -> List[Dict]:
     return [
         {
-            "id": hashlib.sha256(row["promql"].encode("utf-8")).hexdigest(),
-            "embedding": row["embeddings"],
+            "id": hashlib.md5(row["promql"].encode("utf-8")).hexdigest(),
+            "embedding": row["embedding"],
             "metadata": {
                 "promql": row["promql"],
                 "description": row["description"],
@@ -49,14 +50,17 @@ def generate_payload(data: List[Dict]) -> List[Dict]:
 
 
 if __name__ == "__main__":
+    print("Loading dataset from HuggingFace...")
     data = load_and_format_dataset()
 
     # Create vector collection
-    create_vector_collection(COLLECTION, len(data[0]["embeddings"]))
+    create_vector_collection(COLLECTION, len(data[0]["embedding"]))
 
     # Generate payloads
     payloads = generate_payload(data)
 
     # Upsert each payload point
+    print("Upserting points...")
     for payload in tqdm(payloads):
         upsert_point(payload)
+    print("Done!")