Skip to content

Commit

Permalink
chore: merge pull request #21 from ninpnin/dev
Browse files · Browse the repository at this point in the history
JSON loading and bugfix
  • Loading branch information
ninpnin authored Sep 19, 2024
2 parents 8000c6a + f7f4f43 commit 1fb1596
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 18 deletions.
37 changes: 27 additions & 10 deletions probabilistic_word_embeddings/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import tensorflow as tf
import networkx as nx
import random, pickle
import random, pickle, json
import progressbar
from .utils import dict_to_tf
import warnings
Expand Down Expand Up @@ -39,14 +39,18 @@ def __init__(self, vocabulary=None, dimensionality=100, lambda0=1.0, shared_cont
else:
if type(saved_model_path) != str:
raise TypeError("saved_model_path must be a str")
with open(saved_model_path, "rb") as f:
d = pickle.load(f)
d = None
if saved_model_path.split(".")[-1] == "json":
with open(saved_model_path, "r") as f:
d = json.load(f)
else:
with open(saved_model_path, "rb") as f:
d = pickle.load(f)
self.vocabulary = d["vocabulary"]
self.tf_vocabulary = dict_to_tf(self.vocabulary)
self.theta = tf.Variable(d["theta"])
self.theta = tf.Variable(d["theta"], dtype=tf.float64)
self.lambda0 = d["lambda0"]

@tf.function
def _get_embeddings(self, item):
if type(item) == str:
return self.theta[self.vocabulary[item]]
Expand Down Expand Up @@ -126,8 +130,16 @@ def save(self, path):
if hasattr(self, 'graph'):
d["graph"] = self.graph

with open(path, "wb") as f:
pickle.dump(d, f, protocol=4)
if path.split(".")[-1] == "json":
d["theta"] = theta.tolist()
if "graph" in d:
d["graph"] = nx.readwrite.json_graph.adjacency_data(self.graph)

with open(path, 'w') as f:
json.dump(d, f, indent=2, ensure_ascii=False)
else:
with open(path, "wb") as f:
pickle.dump(d, f, protocol=4)

class LaplacianEmbedding(Embedding):
"""
Expand All @@ -147,11 +159,16 @@ def __init__(self, vocabulary=None, dimensionality=100, graph=None, lambda0=1.0,
self.graph = graph
self.edges_i = None
else:
with open(saved_model_path, "rb") as f:
d = pickle.load(f)
d = None
if saved_model_path.split(".")[-1] == "json":
with open(saved_model_path, "r") as f:
d = json.load(f)
else:
with open(saved_model_path, "rb") as f:
d = pickle.load(f)
self.vocabulary = d["vocabulary"]
self.tf_vocabulary = dict_to_tf(self.vocabulary)
self.theta = tf.Variable(d["theta"])
self.theta = tf.Variable(d["theta"], dtype=tf.float64)
self.lambda0 = d["lambda0"]
self.lambda1 = d["lambda1"]
self.graph = d["graph"]
Expand Down
6 changes: 3 additions & 3 deletions probabilistic_word_embeddings/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,9 @@ def words_in_e(row):
r = len(df)
target_words = list(df[columns[-1]])

X1 = embedding[df[columns[0]]]
X2 = embedding[df[columns[1]]]
X3 = embedding[df[columns[2]]]
X1 = embedding[list(df[columns[0]])]
X2 = embedding[list(df[columns[1]])]
X3 = embedding[list(df[columns[2]])]
X = X1 - X2 + X3

inv_vocab = {v: k for k, v in e.vocabulary.items()}
Expand Down
9 changes: 5 additions & 4 deletions probabilistic_word_embeddings/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,11 @@ def add_subscript(t, subscript):
if labels is not None:
print("Add partition labels to words...")
texts = [add_subscript(text, label) for text, label in zip(texts, progressbar.progressbar(labels))]
vocabs = [set(text) for text in progressbar.progressbar(texts)]
empty = set()
vocabulary = empty.union(*vocabs)


vocabulary = set()
for text in progressbar.progressbar(texts):
for wd in text:
vocabulary.add(wd)

def _remove_subscript(wd):
s = wd.split("_")
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "probabilistic-word-embeddings"
version = "1.13.7"
version = "1.15.1"
description = "Probabilistic Word Embeddings for Python"
authors = ["Your Name <[email protected]>"]
license = "MIT"
Expand Down

0 comments on commit 1fb1596

Please sign in to comment.