Skip to content

Commit

Permalink
Merge pull request #56 from grafana/update-populate-db-script
Browse files Browse the repository at this point in the history
update dataset
  • Loading branch information
ioanarm authored Oct 23, 2023
2 parents 79e8e00 + d02fd4d commit 46e8d9e
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions scripts/populate_db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import requests
import ast
import hashlib
from datasets import load_dataset
from typing import Dict, List
Expand All @@ -26,18 +25,20 @@ def upsert_point(payload: Dict) -> None:


def load_and_format_dataset() -> List[Dict]:
    """Load the PromQL templates dataset, with its embeddings, from HuggingFace.

    Reads the ``train`` split of the
    ``promql-templates-bge-small-en-embeddings.parquet`` file in the
    ``grafanalabs/promql-templates`` dataset repository. No post-processing
    is applied to the rows here.

    Returns:
        The object returned by ``load_dataset`` for the requested split.
        Callers in this script index it (``data[0]``) and iterate over it,
        reading the ``"embedding"``, ``"promql"`` and ``"description"``
        fields of each row.
        NOTE(review): this is presumably a ``datasets.Dataset`` rather than
        a literal ``list`` -- confirm before relying on list-only methods.
    """
    # Load dataset with embeddings from HuggingFace
    dataset = load_dataset(
        "grafanalabs/promql-templates",
        data_files="promql-templates-bge-small-en-embeddings.parquet",
        split="train",
    )
    return dataset


def generate_payload(data: List[Dict]) -> List[Dict]:
return [
{
"id": hashlib.sha256(row["promql"].encode("utf-8")).hexdigest(),
"embedding": row["embeddings"],
"id": hashlib.md5(row["promql"].encode("utf-8")).hexdigest(),
"embedding": row["embedding"],
"metadata": {
"promql": row["promql"],
"description": row["description"],
Expand All @@ -49,14 +50,17 @@ def generate_payload(data: List[Dict]) -> List[Dict]:


if __name__ == "__main__":
    # Script entry point: load the dataset, create the vector collection,
    # then upsert one point per dataset row.
    print("Loading dataset from HuggingFace...")
    data = load_and_format_dataset()

    # Create vector collection
    # The vector size passed here is the length of the first row's
    # "embedding" field (the same key generate_payload reads).
    create_vector_collection(COLLECTION, len(data[0]["embedding"]))

    # Generate payloads
    payloads = generate_payload(data)

    # Upsert each payload point
    print("Upserting points...")
    # Points are sent one at a time; tqdm only wraps the loop for progress.
    for payload in tqdm(payloads):
        upsert_point(payload)
    print("Done!")

0 comments on commit 46e8d9e

Please sign in to comment.