From a7ce3ab15fa40369cd9f28cfd615ad77fdb077f8 Mon Sep 17 00:00:00 2001 From: Douglas Blank Date: Fri, 19 May 2023 16:32:19 -0700 Subject: [PATCH] Fixed include for pca; pass in kwargs for projection algorithm; version 2.4.1 --- backend/kangas/_version.py | 2 +- backend/kangas/datatypes/embedding.py | 55 +++++++++++++------ .../Visualizing_embeddings_in_Kangas.ipynb | 40 +++++++++----- 3 files changed, 67 insertions(+), 30 deletions(-) diff --git a/backend/kangas/_version.py b/backend/kangas/_version.py index 3b5950b..0d7d1d5 100644 --- a/backend/kangas/_version.py +++ b/backend/kangas/_version.py @@ -11,5 +11,5 @@ # All rights reserved # ###################################################### -version_info = (2, 4, 0) +version_info = (2, 4, 1) __version__ = ".".join(map(str, version_info)) diff --git a/backend/kangas/datatypes/embedding.py b/backend/kangas/datatypes/embedding.py index 5d92526..838018b 100644 --- a/backend/kangas/datatypes/embedding.py +++ b/backend/kangas/datatypes/embedding.py @@ -62,6 +62,7 @@ def __init__( source=None, unserialize=False, dimensions=PROJECTION_DIMENSIONS, + **kwargs ): """ Create an embedding vector. @@ -79,6 +80,7 @@ def __init__( projection. Useful if you want to see one part of the datagrid in the project of another. dimensions: (int) maximum number of dimensions + kwargs: (dict) optional keyword arguments for projection algorithm Example: @@ -91,6 +93,11 @@ def __init__( >>> dg.save("embeddings.datagrid") ``` """ + if not include and projection not in ["pca"]: + raise Exception( + "projection '%s' does not allow embeddings to be excluded; change projection or set include=True" + ) + super().__init__(source) if unserialize: self._unserialize = unserialize @@ -115,6 +122,7 @@ def __init__( self.metadata["projection"] = projection self.metadata["include"] = include self.metadata["dimensions"] = dimensions + self.metadata["kwargs"] = kwargs if file_name: if is_valid_file_path(file_name): @@ -174,7 +182,9 @@ def get_statistics(cls, datagrid, col_name, field_name): projection = None batch = [] - asset_ids = [] + batch_asset_ids = [] + not_included = [] + not_included_asset_ids = [] for row in datagrid.conn.execute( """SELECT {field_name} as assetId, asset_data, asset_metadata from datagrid JOIN assets ON assetId = assets.asset_id;""".format( @@ -186,20 +196,11 @@ def get_statistics(cls, datagrid, col_name, field_name): continue asset_metdata = json.loads(asset_metadata_json) + projection = asset_metdata["projection"] include = asset_metdata["include"] dimensions = asset_metdata["dimensions"] - - # Skip if explicitly False - if not include: - continue - - asset_data = json.loads(asset_data_json) - vector = prepare_embedding(asset_data["vector"], dimensions, seed) - - # Save asset_id to update assets next - batch.append(vector) - asset_ids.append(asset_id) + kwargs = asset_metdata["kwargs"] if projection == "pca": projection_name = "pca" @@ -208,19 +209,37 @@ def get_statistics(cls, datagrid, col_name, field_name): elif projection == "umap": projection_name = "umap" else: - raise Exception("projection not found") + raise Exception("projection not found for %s" % asset_id) + + asset_data = json.loads(asset_data_json) + vector = prepare_embedding(asset_data["vector"], dimensions, seed) + + if include: + batch.append(vector) + batch_asset_ids.append(asset_id) + else: + not_included.append(vector) + not_included_asset_ids.append(asset_id) if projection_name == "pca": from sklearn.decomposition import PCA - projection = PCA(n_components=2) + if "n_components" not in kwargs: + kwargs["n_components"] = 2 + + projection = PCA(**kwargs) transformed = projection.fit_transform(np.array(batch)) + if not_included: + transformed_not_included = projection.transform(np.array(not_included)) + else: + transformed_not_included = np.array([]) elif projection_name == "t-sne": from sklearn.manifold import TSNE - projection = TSNE() + projection = TSNE(**kwargs) transformed = projection.fit_transform(np.array(batch)) + transformed_not_included = np.array([]) elif projection_name == "umap": pass # TODO @@ -244,7 +263,11 @@ def get_statistics(cls, datagrid, col_name, field_name): # update assets with transformed cursor = datagrid.conn.cursor() - for asset_id, tran in zip(asset_ids, transformed): + if not_included_asset_ids: + batch_asset_ids = batch_asset_ids + not_included_asset_ids + transformed = np.concatenate((transformed, transformed_not_included)) + + for asset_id, tran in zip(batch_asset_ids, transformed): sql = """SELECT asset_data from assets WHERE asset_id = ?;""" asset_data_json = datagrid.conn.execute(sql, (asset_id,)).fetchone()[0] asset_data = json.loads(asset_data_json) diff --git a/notebooks/Visualizing_embeddings_in_Kangas.ipynb b/notebooks/Visualizing_embeddings_in_Kangas.ipynb index 28b1a9c..a2709c1 100644 --- a/notebooks/Visualizing_embeddings_in_Kangas.ipynb +++ b/notebooks/Visualizing_embeddings_in_Kangas.ipynb @@ -92,8 +92,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "1001it [00:00, 3257.63it/s]\n", - "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 4498.83it/s]\n" + "1001it [00:00, 2097.76it/s]\n", + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2324.51it/s]\n" ] } ], @@ -237,7 +237,10 @@ " row[8] = kg.Embedding(\n", " ast.literal_eval(row[8]), \n", " name=str(row[3]), \n", - " text=\"%s - %.10s\" % (row[3], row[4])\n", + " text=\"%s - %.10s\" % (row[3], row[4]),\n", + " projection=\"t-sne\",\n", + " learning_rate=10.0,\n", + " n_iter=500,\n", " )\n", " dg.append(row)" ] @@ -309,7 +312,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 34767.11it/s]" + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 28791.81it/s]" ] }, { @@ -337,7 +340,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:01<00:00, 636.93it/s]\n" + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:02<00:00, 377.84it/s]\n" ] }, { @@ -351,7 +354,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:01<00:00, 7.00it/s]\n" + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:08<00:00, 1.23it/s]\n" ] } ], @@ -365,7 +368,11 @@ "source": [ "### 3. Render 2D Projections\n", "\n", - "To render the data directly in the notebook, simply show it. Note that each row contains an embedding projection. Group by \"Score\" to see rows of each group." + "To render the data directly in the notebook, simply show it. Note that each row contains an embedding projection. \n", + "\n", + "Scroll to far right to see embeddings projection per row.\n", + "\n", + "The color of the point in projection space represents the Score." ] }, { @@ -387,7 +394,7 @@ " " + "" ] }, "metadata": {}, @@ -406,6 +413,13 @@ "dg.show()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Group by \"Score\" to see rows of each group. Again, scroll right to see groups of embeddings." + ] + }, { "cell_type": "code", "execution_count": 10, @@ -418,7 +432,7 @@ " " + "" ] }, "metadata": {}, @@ -434,7 +448,7 @@ } ], "source": [ - "dg.show(group=\"Score\", sort=\"Score\")" + "dg.show(group=\"Score\", sort=\"Score\", rows=5)" ] }, { @@ -470,7 +484,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.10.11" }, "vscode": { "interpreter": {