From 641a8da8eb6d9e4b606a282aa5da6a5366a65e65 Mon Sep 17 00:00:00 2001 From: Faisal Riaz <46712242+faisalriazz@users.noreply.github.com> Date: Sat, 31 Dec 2022 15:14:30 +0500 Subject: [PATCH] Create googleColabNotebook.ipynb --- googleColabNotebook.ipynb | 1052 +++++++++++++++++++++++++++++++++++++ 1 file changed, 1052 insertions(+) create mode 100644 googleColabNotebook.ipynb diff --git a/googleColabNotebook.ipynb b/googleColabNotebook.ipynb new file mode 100644 index 0000000..aeb8529 --- /dev/null +++ b/googleColabNotebook.ipynb @@ -0,0 +1,1052 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "TPU", + "gpuClass": "standard" + }, + "cells": [ + { + "cell_type": "code", + "source": [ + "#@title checking if directory exist or not deleting perceiver \n", + "import os\n", + "import shutil\n", + "dirname = \"perceiver-ar\"\n", + "isdir = os.path.isdir(dirname)\n", + "if isdir:\n", + " shutil.rmtree(dirname)\n", + " print(\"Deleting {} and subdirectories\".format(dirname))\n", + "else:\n", + " print(\"{} does not exist\".format(dirname))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "LvrtiwNmsFM2", + "outputId": "b40b38e4-575b-4d3d-ff61-51aef8595417" + }, + "execution_count": 12, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Deleting perceiver-ar and subdirectories\n" + ] + } + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ojyzJvW2bv4R", + "outputId": "be87b211-37a8-46fe-82ac-9468d535635a" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Cloning into 'perceiver-ar'...\n", + "remote: Enumerating objects: 83, done.\u001b[K\n", + "remote: Counting objects: 100% (50/50), done.\u001b[K\n", + "remote: Compressing objects: 100% (38/38), done.\u001b[K\n", + "remote: Total 83 (delta 30), reused 12 (delta 12), pack-reused 33\u001b[K\n", + "Unpacking objects: 100% (83/83), done.\n" + ] + } + ], + "source": [ + "#@title Cloning Perceiver-AR Reposistory\n", + "! git clone https://github.com/google-research/perceiver-ar.git" + ] + }, + { + "cell_type": "code", + "source": [ + "!pwd" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "S6_Om0ELbVnT", + "outputId": "db7b1d71-c2d1-475c-a2c2-f5b4ddf78643" + }, + "execution_count": 14, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "/content\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "#@title Chaning the directory to clone directory\n", + "%cd /content/perceiver-ar" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "hsnLiTH0FvoK", + "outputId": "57a77985-d65e-4e4a-e694-006f25163281" + }, + "execution_count": 15, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "/content/perceiver-ar\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "! pip install -r requirements.txt " + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "We0e5GZacgjE", + "outputId": "48eeafe4-0ba8-433c-c720-9059afd24646" + }, + "execution_count": 16, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Requirement already satisfied: absl-py==1.1.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 1)) (1.1.0)\n", + "Requirement already satisfied: argon2-cffi==21.3.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 2)) (21.3.0)\n", + "Requirement already satisfied: argon2-cffi-bindings==21.2.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 3)) (21.2.0)\n", + "Requirement already satisfied: astunparse==1.6.3 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 4)) (1.6.3)\n", + "Requirement already satisfied: attrs==21.4.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 5)) (21.4.0)\n", + "Requirement already satisfied: backcall==0.2.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 6)) (0.2.0)\n", + "Requirement already satisfied: beautifulsoup4==4.11.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 7)) (4.11.1)\n", + "Requirement already satisfied: bleach==5.0.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 8)) (5.0.0)\n", + "Requirement already satisfied: cachetools==5.2.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 9)) (5.2.0)\n", + "Requirement already satisfied: certifi==2022.5.18.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 10)) (2022.5.18.1)\n", + "Requirement already satisfied: cffi==1.15.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 11)) (1.15.0)\n", + "Requirement already satisfied: charset-normalizer==2.0.12 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 12)) (2.0.12)\n", + "Requirement already satisfied: chex==0.1.3 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 13)) (0.1.3)\n", + "Requirement already satisfied: contextlib2==21.6.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 14)) (21.6.0)\n", + "Requirement already satisfied: cycler==0.11.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 15)) (0.11.0)\n", + "Requirement already satisfied: debugpy==1.6.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 16)) (1.6.0)\n", + "Requirement already satisfied: decorator==5.1.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 17)) (5.1.1)\n", + "Requirement already satisfied: defusedxml==0.7.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 18)) (0.7.1)\n", + "Requirement already satisfied: dill==0.3.5.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 19)) (0.3.5.1)\n", + "Requirement already satisfied: dm-haiku==0.0.6 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 20)) (0.0.6)\n", + "Requirement already satisfied: dm-tree==0.1.7 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 21)) (0.1.7)\n", + "Requirement already satisfied: entrypoints==0.4 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 22)) (0.4)\n", + "Requirement already satisfied: etils==0.6.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 23)) (0.6.0)\n", + "Requirement already satisfied: fastjsonschema==2.15.3 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 24)) (2.15.3)\n", + "Requirement already satisfied: flatbuffers==1.12 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 25)) (1.12)\n", + "Requirement already satisfied: flax==0.5.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 26)) (0.5.0)\n", + "Requirement already satisfied: fonttools==4.33.3 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 27)) (4.33.3)\n", + "Requirement already satisfied: gast==0.4.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 28)) (0.4.0)\n", + "Requirement already satisfied: google-auth==2.6.6 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 29)) (2.6.6)\n", + "Requirement already satisfied: google-auth-oauthlib==0.4.6 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 30)) (0.4.6)\n", + "Requirement already satisfied: google-pasta==0.2.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 31)) (0.2.0)\n", + "Requirement already satisfied: googleapis-common-protos==1.56.2 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 32)) (1.56.2)\n", + "Requirement already satisfied: grpcio==1.46.3 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 33)) (1.46.3)\n", + "Requirement already satisfied: h5py==3.7.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 34)) (3.7.0)\n", + "Requirement already satisfied: idna==3.3 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 35)) (3.3)\n", + "Requirement already satisfied: importlib-metadata==4.11.4 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 36)) (4.11.4)\n", + "Requirement already satisfied: importlib-resources==5.7.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 37)) (5.7.1)\n", + "Requirement already satisfied: iniconfig==1.1.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 38)) (1.1.1)\n", + "Requirement already satisfied: ipykernel==6.13.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 39)) (6.13.1)\n", + "Requirement already satisfied: ipython==7.34.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 40)) (7.34.0)\n", + "Requirement already satisfied: ipython-genutils==0.2.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 41)) (0.2.0)\n", + "Requirement already satisfied: ipywidgets==7.7.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 42)) (7.7.0)\n", + "Requirement already satisfied: jax==0.3.13 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 43)) (0.3.13)\n", + "Requirement already satisfied: jaxlib==0.3.10 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 44)) (0.3.10)\n", + "Requirement already satisfied: jaxline==0.0.5 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 45)) (0.0.5)\n", + "Requirement already satisfied: jedi==0.18.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 46)) (0.18.1)\n", + "Requirement already satisfied: Jinja2==3.1.2 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 47)) (3.1.2)\n", + "Requirement already satisfied: jmp==0.0.2 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 48)) (0.0.2)\n", + "Requirement already satisfied: jsonschema==4.6.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 49)) (4.6.0)\n", + "Requirement already satisfied: jupyter==1.0.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 50)) (1.0.0)\n", + "Requirement already satisfied: jupyter-client==7.3.2 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 51)) (7.3.2)\n", + "Requirement already satisfied: jupyter-console==6.4.3 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 52)) (6.4.3)\n", + "Requirement already satisfied: jupyter-core==4.10.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 53)) (4.10.0)\n", + "Requirement already satisfied: jupyterlab-pygments==0.2.2 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 54)) (0.2.2)\n", + "Requirement already satisfied: jupyterlab-widgets==1.1.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 55)) (1.1.0)\n", + "Requirement already satisfied: keras==2.9.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 56)) (2.9.0)\n", + "Requirement already satisfied: Keras-Preprocessing==1.1.2 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 57)) (1.1.2)\n", + "Requirement already satisfied: kiwisolver==1.4.2 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 58)) (1.4.2)\n", + "Requirement already satisfied: libclang==14.0.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 59)) (14.0.1)\n", + "Requirement already satisfied: Markdown==3.3.7 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 60)) (3.3.7)\n", + "Requirement already satisfied: MarkupSafe==2.1.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 61)) (2.1.1)\n", + "Requirement already satisfied: matplotlib==3.5.2 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 62)) (3.5.2)\n", + "Requirement already satisfied: matplotlib-inline==0.1.3 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 63)) (0.1.3)\n", + "Requirement already satisfied: mistune==0.8.4 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 64)) (0.8.4)\n", + "Requirement already satisfied: ml-collections==0.1.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 65)) (0.1.1)\n", + "Requirement already satisfied: msgpack==1.0.4 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 66)) (1.0.4)\n", + "Requirement already satisfied: nbclient==0.6.4 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 67)) (0.6.4)\n", + "Requirement already satisfied: nbconvert==6.5.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 68)) (6.5.0)\n", + "Requirement already satisfied: nbformat==5.4.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 69)) (5.4.0)\n", + "Requirement already satisfied: nest-asyncio==1.5.5 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 70)) (1.5.5)\n", + "Requirement already satisfied: notebook==6.4.11 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 71)) (6.4.11)\n", + "Requirement already satisfied: numpy==1.21.6 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 72)) (1.21.6)\n", + "Requirement already satisfied: oauthlib==3.2.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 73)) (3.2.0)\n", + "Requirement already satisfied: opt-einsum==3.3.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 74)) (3.3.0)\n", + "Requirement already satisfied: optax==0.1.2 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 75)) (0.1.2)\n", + "Requirement already satisfied: packaging==21.3 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 76)) (21.3)\n", + "Requirement already satisfied: pandocfilters==1.5.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 77)) (1.5.0)\n", + "Requirement already satisfied: parso==0.8.3 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 78)) (0.8.3)\n", + "Requirement already satisfied: pexpect==4.8.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 79)) (4.8.0)\n", + "Requirement already satisfied: pickleshare==0.7.5 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 80)) (0.7.5)\n", + "Requirement already satisfied: Pillow==9.1.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 81)) (9.1.1)\n", + "Requirement already satisfied: pluggy==1.0.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 82)) (1.0.0)\n", + "Requirement already satisfied: prometheus-client==0.14.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 83)) (0.14.1)\n", + "Requirement already satisfied: promise==2.3 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 84)) (2.3)\n", + "Requirement already satisfied: prompt-toolkit==3.0.29 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 85)) (3.0.29)\n", + "Requirement already satisfied: protobuf==3.19.4 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 86)) (3.19.4)\n", + "Requirement already satisfied: psutil==5.9.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 87)) (5.9.1)\n", + "Requirement already satisfied: ptyprocess==0.7.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 88)) (0.7.0)\n", + "Requirement already satisfied: py==1.11.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 89)) (1.11.0)\n", + "Requirement already satisfied: pyasn1==0.4.8 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 90)) (0.4.8)\n", + "Requirement already satisfied: pyasn1-modules==0.2.8 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 91)) (0.2.8)\n", + "Requirement already satisfied: pycparser==2.21 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 92)) (2.21)\n", + "Requirement already satisfied: Pygments==2.12.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 93)) (2.12.0)\n", + "Requirement already satisfied: pyparsing==3.0.9 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 94)) (3.0.9)\n", + "Requirement already satisfied: pyrsistent==0.18.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 95)) (0.18.1)\n", + "Requirement already satisfied: pytest==7.1.2 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 96)) (7.1.2)\n", + "Requirement already satisfied: python-dateutil==2.8.2 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 97)) (2.8.2)\n", + "Requirement already satisfied: PyYAML==6.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 98)) (6.0)\n", + "Requirement already satisfied: pyzmq==23.1.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 99)) (23.1.0)\n", + "Requirement already satisfied: qtconsole==5.3.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 100)) (5.3.1)\n", + "Requirement already satisfied: QtPy==2.1.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 101)) (2.1.0)\n", + "Requirement already satisfied: requests==2.27.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 102)) (2.27.1)\n", + "Requirement already satisfied: requests-oauthlib==1.3.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 103)) (1.3.1)\n", + "Requirement already satisfied: rsa==4.8 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 104)) (4.8)\n", + "Requirement already satisfied: scipy==1.7.3 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 105)) (1.7.3)\n", + "Requirement already satisfied: Send2Trash==1.8.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 106)) (1.8.0)\n", + "Requirement already satisfied: six==1.16.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 107)) (1.16.0)\n", + "Requirement already satisfied: soupsieve==2.3.2.post1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 108)) (2.3.2.post1)\n", + "Requirement already satisfied: tabulate==0.8.9 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 109)) (0.8.9)\n", + "Requirement already satisfied: tensorboard==2.9.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 110)) (2.9.0)\n", + "Requirement already satisfied: tensorboard-data-server==0.6.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 111)) (0.6.1)\n", + "Requirement already satisfied: tensorboard-plugin-wit==1.8.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 112)) (1.8.1)\n", + "Requirement already satisfied: tensorflow==2.9.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 113)) (2.9.1)\n", + "Requirement already satisfied: tensorflow-datasets==4.6.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 114)) (4.6.0)\n", + "Requirement already satisfied: tensorflow-estimator==2.9.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 115)) (2.9.0)\n", + "Requirement already satisfied: tensorflow-io-gcs-filesystem==0.26.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 116)) (0.26.0)\n", + "Requirement already satisfied: tensorflow-metadata==1.8.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 117)) (1.8.0)\n", + "Requirement already satisfied: termcolor==1.1.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 118)) (1.1.0)\n", + "Requirement already satisfied: terminado==0.15.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 119)) (0.15.0)\n", + "Requirement already satisfied: tinycss2==1.1.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 120)) (1.1.1)\n", + "Requirement already satisfied: toml==0.10.2 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 121)) (0.10.2)\n", + "Requirement already satisfied: tomli==2.0.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 122)) (2.0.1)\n", + "Requirement already satisfied: toolz==0.11.2 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 123)) (0.11.2)\n", + "Requirement already satisfied: tornado==6.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 124)) (6.1)\n", + "Requirement already satisfied: tqdm==4.64.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 125)) (4.64.0)\n", + "Requirement already satisfied: traitlets==5.2.2.post1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 126)) (5.2.2.post1)\n", + "Requirement already satisfied: typing_extensions==4.2.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 127)) (4.2.0)\n", + "Requirement already satisfied: urllib3==1.26.9 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 128)) (1.26.9)\n", + "Requirement already satisfied: wcwidth==0.2.5 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 129)) (0.2.5)\n", + "Requirement already satisfied: webencodings==0.5.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 130)) (0.5.1)\n", + "Requirement already satisfied: Werkzeug==2.1.2 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 131)) (2.1.2)\n", + "Requirement already satisfied: widgetsnbextension==3.6.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 132)) (3.6.0)\n", + "Requirement already satisfied: wrapt==1.14.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 133)) (1.14.1)\n", + "Requirement already satisfied: zipp==3.8.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 134)) (3.8.0)\n", + "Requirement already satisfied: wheel<1.0,>=0.23.0 in /usr/local/lib/python3.8/dist-packages (from astunparse==1.6.3->-r requirements.txt (line 4)) (0.38.4)\n", + "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.8/dist-packages (from ipython==7.34.0->-r requirements.txt (line 40)) (57.4.0)\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "#@title Chaning the directory to clone directory\n", + "%cd /content/perceiver-ar # required after restart runtime" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "6ededd77-1551-4a19-a5be-981d937ed786", + "id": "lTqzcpoffrCQ" + }, + "execution_count": 17, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[Errno 2] No such file or directory: '/content/perceiver-ar # required after restart runtime'\n", + "/content/perceiver-ar\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "! echo $PYTHONPATH" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "wLglil-55hak", + "outputId": "65ed58ff-99dd-4398-f6c4-c9d616a24f4b" + }, + "execution_count": 18, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "/env/python:/content/perceiver-ar/perceiver_ar/experiment.py --config=perceiver_ar/experiment.py:random_mirrored_32\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "#@title adding the perceiver-ar directory to python path\n", + "import os\n", + "os.environ['PYTHONPATH'] += \":/content/perceiver-ar/perceiver_ar/experiment.py --config=perceiver_ar/experiment.py:random_mirrored_32\"\n", + "! echo $PYTHONPATH" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "XV1ChF0o5nwZ", + "outputId": "9da595bb-7d40-4574-97a3-a2fba23632b0" + }, + "execution_count": 19, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "/env/python:/content/perceiver-ar/perceiver_ar/experiment.py --config=perceiver_ar/experiment.py:random_mirrored_32:/content/perceiver-ar/perceiver_ar/experiment.py --config=perceiver_ar/experiment.py:random_mirrored_32\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "#@title importing required libraries\n", + "from pathlib import Path\n", + "import functools\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "import jaxline\n", + "from typing import Generator, Mapping, Sequence, Text\n", + "import tensorflow as tf\n", + "import tensorflow_datasets as tfds\n", + "import time\n", + "import haiku as hk\n", + "import IPython\n", + "from PIL import Image\n", + "import datetime\n", + "\n", + "from perceiver_ar import experiment\n", + "from perceiver_ar import perceiver_ar_model\n", + "from perceiver_ar import dataset\n", + "from perceiver_ar import sample_utils\n", + "import matplotlib.pyplot as plt" + ], + "metadata": { + "id": "mGXgJwUwDbf7", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "77a1b4f8-91ef-4efb-d973-55b1872edaf3" + }, + "execution_count": 20, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "2022-12-31 10:05:31.167039: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "jax.devices()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2hw1uYiiJyvR", + "outputId": "daca14ba-8581-413c-897f-5c97e42fca9b" + }, + "execution_count": 22, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[CpuDevice(id=0)]" + ] + }, + "metadata": {}, + "execution_count": 22 + } + ] + }, + { + "cell_type": "code", + "source": [ + "TF_CPP_MIN_LOG_LEVEL=0" + ], + "metadata": { + "id": "UjVm6vnEcsXh" + }, + "execution_count": 23, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!pwd" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "jB1ze63kdl6G", + "outputId": "48b79d74-ea2c-4126-e31f-1c697969f9ca" + }, + "execution_count": 24, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "/content/perceiver-ar\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "## Synthetic Copy Task, 32 positions (suitable for local CPU training)\n", + "# modality = 'raw'\n", + "# input_sequence_init = 'mirror_input'\n", + "# sweep_name = 'random_mirrored_32'\n", + "# checkpoint_base = Path('/tmp/perceiver_ar')\n", + "\n", + "## Synthetic Copy Task, 131k positions\n", + "!mkdir perceiver-ar-checkpoints\n", + "!gsutil cp gs://perceiver-ar/checkpoints/random_mirrored_131072.zip perceiver-ar-checkpoints\n", + "!unzip perceiver-ar-checkpoints/random_mirrored_131072.zip -d perceiver-ar-checkpoints\n", + "modality = 'raw' # modality = raw/image\n", + "input_sequence_init = 'mirror_input'\n", + "sweep_name = 'random_mirrored_131072'\n", + "checkpoint_base = Path('/content/perceiver-ar/perceiver-ar-checkpoints/random_mirrored_131072')" + ], + "metadata": { + "id": "zpiY631HDrLK", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "38a6f699-45b9-4b53-88f4-5ffc472bd7aa" + }, + "execution_count": 25, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Copying gs://perceiver-ar/checkpoints/random_mirrored_131072.zip...\n", + "| [1 files][784.4 MiB/784.4 MiB] 61.4 MiB/s \n", + "Operation completed over 1 objects/784.4 MiB. \n", + "Archive: perceiver-ar-checkpoints/random_mirrored_131072.zip\n", + " creating: perceiver-ar-checkpoints/random_mirrored_131072/\n", + " creating: perceiver-ar-checkpoints/random_mirrored_131072/models/\n", + " creating: perceiver-ar-checkpoints/random_mirrored_131072/models/latest/\n", + " creating: perceiver-ar-checkpoints/random_mirrored_131072/models/latest/step_25000/\n", + " inflating: perceiver-ar-checkpoints/random_mirrored_131072/models/latest/step_25000/checkpoint.dill \n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "print(checkpoint_base)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "9SSDFuDjHUG8", + "outputId": "358a7bc3-1a34-46a5-c719-acfcc8546e4a" + }, + "execution_count": 26, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "/content/perceiver-ar/perceiver-ar-checkpoints/random_mirrored_131072\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "#@title loading model from checkpoint\n", + "checkpoint_dir = sorted((checkpoint_base / 'models/latest').iterdir(), key=lambda x: x.stat().st_mtime)[-1]\n", + "\n", + "print(f'Model will be loaded from: {checkpoint_dir}')" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "yQHfaeA1DzWq", + "outputId": "eb9fd68b-2c33-4386-b8a7-5bed587fe4cb" + }, + "execution_count": 27, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Model will be loaded from: /content/perceiver-ar/perceiver-ar-checkpoints/random_mirrored_131072/models/latest/step_25000\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "#@title setting the configuration to sweep name\n", + "config = experiment.get_config(sweep_name)\n", + "experiment.restore_state_to_in_memory_checkpointer(checkpoint_dir, config)" + ], + "metadata": { + "id": "1hRIhuHVe5OR" + }, + "execution_count": 28, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "checkpointer = jaxline.platform.create_checkpointer(config, 'eval')\n", + "state = checkpointer.get_experiment_state('latest')\n", + "\n", + "# Add the fields you want to restore here.\n", + "# Must include experiment_module.\n", + "state.global_step = 0\n", + "state.experiment_module = experiment.Experiment(\n", + " 'eval', jax.random.PRNGKey(config.random_seed),\n", + " **config.experiment_kwargs)\n", + "\n", + "checkpointer.restore('latest')\n", + "exp_params = jaxline.utils.get_first(state.experiment_module._params)\n", + "exp_state = jaxline.utils.get_first(state.experiment_module._state)\n", + "\n", + "max_context_length = config.experiment_kwargs.config.data.max_context_length\n", + "# We want to store max_context_length plus the final prediction.\n", + "max_events_length = max_context_length + 1\n", + "\n", + "events = np.zeros([1, max_events_length], np.int32)\n", + "events[:, 0] = dataset.SOS_ID\n", + "\n", + "print('Restored step', state.global_step)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "lVP-Cg-tfAQh", + "outputId": "ef756820-0198-4448-c91e-28223b8e0b5a" + }, + "execution_count": 29, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Restored step 25000\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "#@title Set up the input sequence (NB: start_step ignored for some init types)\n", + "\n", + "batch_size = 1#@param {type:\"integer\"}\n", + "start_step = 1#@param {type:\"integer\"}\n", + "\n", + "device_count = jax.local_device_count()\n", + "print('device count', device_count)\n", + "\n", + "max_context_length = config.experiment_kwargs.config.data.max_context_length\n", + "# We want to store max_context_length plus the final prediction.\n", + "max_events_length = max_context_length + 1\n", + "\n", + "if input_sequence_init == 'zeros':\n", + " def gen_initial_events():\n", + " events = np.zeros([device_count, batch_size, max_events_length], np.int32)\n", + " # Add expected SOS prompt.\n", + " events[:, :, 0] = dataset.SOS_ID\n", + " return events\n", + "elif input_sequence_init == 'mirror_input':\n", + " # Account for the SOS\n", + " seq_len = config.experiment_kwargs.config.data.max_context_length - 2\n", + " seq_len = seq_len // 2\n", + "\n", + " start_step = seq_len + 1\n", + " print('Using input_sequence_init `mirror_input`. Setting '\n", + " f'start_step to {start_step}.')\n", + "\n", + " def gen_initial_events():\n", + " # Initialize with a random MirroredDataset sequence.\n", + " events = np.zeros([device_count, batch_size, max_events_length], np.int32)\n", + " rng = jax.random.PRNGKey(0)\n", + " forward_sequence = jax.random.randint(\n", + " rng, [device_count, batch_size, seq_len], \n", + " minval=dataset.NUM_RESERVED_TOKENS, \n", + " maxval=256 + dataset.NUM_RESERVED_TOKENS, \n", + " dtype=jnp.int32)\n", + " \n", + " # Force start_step to half the sequence length:\n", + " events[:, :, 1:seq_len+1] = forward_sequence\n", + " # Add expected SOS prompt.\n", + " events[:, :, 0] = dataset.SOS_ID\n", + " return events\n", + "\n", + "if start_step < 1:\n", + " raise ValueError('start_step must be >= 1 to account for the SOS token.')\n", + "\n", + "# Make sure start_step doesn't exceed the maximum context of the model.\n", + "if start_step > max_context_length:\n", + " print(f'Warning: start_step {start_step} exceeds '\n", + " f'max_context_length used at training {max_context_length}.')" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ShK9KC9KfCHj", + "outputId": "85f56b3d-26c3-4c4c-edc3-9863ff936c0e" + }, + "execution_count": 30, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "device count 1\n", + "Using input_sequence_init `mirror_input`. Setting start_step to 65536.\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "#@title Memory parameters (set to 0 to use model defaults)\n", + "use_memory = True #@param {type:\"boolean\"}\n", + "max_context_length_memory = 0 #@param {type:\"integer\"}\n", + "z_index_dim_memory = 0#@param {type: \"integer\"}\n", + "\n", + "model_kwargs = config.experiment_kwargs.config.model.perceiver_ar_kwargs\n", + "# These values can be adjusted, but set to defaults if not specified\n", + "if use_memory:\n", + " if max_context_length_memory == 0:\n", + " print('Using default max_context_length for memory: '\n", + " f'{config.max_context_length}')\n", + " max_context_length_memory = config.max_context_length\n", + " if z_index_dim_memory == 0:\n", + " print(f'Using default z_index_dim for memory: {model_kwargs.z_index_dim}')\n", + " z_index_dim_memory = model_kwargs.z_index_dim" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0XF4yQnHfLHj", + "outputId": "5e8de3b9-bc78-4495-bacc-db059cfecd30" + }, + "execution_count": 31, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Using default max_context_length for memory: 131072\n", + "Using default z_index_dim for memory: 1024\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "#@title Set up sampling\n", + "\n", + "@functools.partial(jax.jit, static_argnums=(2, 3, 4, 5))\n", + "def sample_sequences(\n", + " events, rng, \n", + " start_step=1, \n", + " num_steps=-1,\n", + " temperature=1.0, \n", + " top_p=1.0):\n", + " if num_steps < 0:\n", + " num_steps = events.shape[1] - start_step + 1\n", + "\n", + " # Start at position 1 because SOS must be at position 0.\n", + " upper = min(num_steps + start_step, events.shape[1])\n", + " print('start_step', start_step)\n", + " print('upper', upper)\n", + "\n", + " if modality == 'image_w_positions':\n", + " x_event_idxs = jnp.reshape(\n", + " jnp.broadcast_to(jnp.arange(64)[None, :, None] + 1, [64, 64, 3]), [-1])\n", + " y_event_idxs = jnp.reshape(\n", + " jnp.broadcast_to(jnp.arange(64)[:, None, None] + 1, [64, 64, 3]), [-1])\n", + " channel_event_idxs = jnp.reshape(\n", + " jnp.broadcast_to(jnp.array([1, 2, 3]), [64, 64, 3]), [-1])\n", + "\n", + " event_idxs = jnp.stack(\n", + " [x_event_idxs, y_event_idxs, channel_event_idxs], axis=1)\n", + "\n", + " # Account for SOS.\n", + " event_idxs = jnp.concatenate(\n", + " [jnp.ones([1, 3], dtype=jnp.int32), event_idxs + 1], axis=0)\n", + "\n", + " # Pad remaining positions.\n", + " event_idxs = jnp.pad(\n", + " event_idxs, [[0, events.shape[1] - event_idxs.shape[0]], [0, 0]])\n", + "\n", + " event_idxs = jnp.broadcast_to(event_idxs, events.shape + (3,))\n", + " else:\n", + " # Otherwise, assume linear event indices.\n", + " event_idxs = jnp.arange(start=1, stop=events.shape[1] + 1)\n", + " event_idxs = jnp.expand_dims(event_idxs, axis=-1)\n", + " event_idxs = jnp.broadcast_to(event_idxs, events.shape + (1,))\n", + " \n", + " model_kwargs = config.experiment_kwargs.config.model.perceiver_ar_kwargs\n", + "\n", + " if use_memory:\n", + " # Zero-initialize the memory.\n", + " memory = sample_utils.initialize_memory(\n", + " batch_size=batch_size,\n", + " num_transformers_per_block=model_kwargs.num_transformers_per_block,\n", + " num_cross_attend_heads=model_kwargs.num_cross_attend_heads,\n", + " num_transformer_heads=model_kwargs.num_transformer_heads,\n", + " num_z_channels=model_kwargs.num_z_channels,\n", + " max_context_length_memory=max_context_length_memory,\n", + " z_index_dim_memory=z_index_dim_memory,\n", + " position_encoding_type=model_kwargs.position_encoding_type,\n", + " memory_type='fixed_size_kv')\n", + "\n", + " # Build the parameters reused between model calls. \n", + " sample_position_args = dict(\n", + " event_idxs=event_idxs,\n", + " modality=modality,\n", + " temperature=temperature,\n", + " top_p=top_p,\n", + " use_memory=use_memory,\n", + " )\n", + "\n", + " # Package the (constant) params and state with the model.\n", + " forward_fn = functools.partial(\n", + " state.experiment_module.forward.apply,\n", + " exp_params,\n", + " exp_state,\n", + " )\n", + "\n", + " if use_memory and start_step > 1:\n", + " # Run forward the model with long context for one step to initialize \n", + " # the memory and get the first sample.\n", + " events, rng, memory = sample_utils.sample_position(\n", + " i=start_step,\n", + " events_rng_memory=(events, rng, memory),\n", + " forward_fn=forward_fn,\n", + " condition_on='all_previous',\n", + " **sample_position_args)\n", + " start_step += 1\n", + "\n", + " if use_memory:\n", + " condition_on_loop = 'most_recent_only'\n", + " else:\n", + " condition_on_loop = 'all'\n", + " sample_positions_loop = functools.partial(\n", + " sample_utils.sample_position,\n", + " # i, events_rng_memory supplied by caller.\n", + " forward_fn=forward_fn,\n", + " condition_on=condition_on_loop,\n", + " **sample_position_args)\n", + "\n", + " if use_memory:\n", + " inputs = (events, rng, memory)\n", + " else:\n", + " inputs = (events, rng)\n", + " \n", + " outputs = jax.lax.fori_loop(\n", + " start_step, upper, sample_positions_loop, inputs)\n", + " \n", + " return outputs\n", + "\n", + "sample_sequences = jax.pmap(\n", + " sample_sequences,\n", + " static_broadcasted_argnums=(2, 3, 4, 5))\n" + ], + "metadata": { + "id": "2WxyLsBqfnXV" + }, + "execution_count": 32, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Sample the sequence\n", + "\n", + "#@markdown `num_steps`=-1 will fill the entire buffer.\n", + "num_steps = 1#@param {type:\"integer\"}\n", + "temperature = 1.0#@param {type:\"number\"}\n", + "top_p = 1.0#@param {type:\"number\"}\n", + "random_seed = 0#@param{type:\"integer\"}\n", + "\n", + "rng = jax.random.PRNGKey(random_seed)\n", + "rng = jax.random.split(rng, device_count)\n", + "\n", + "events = gen_initial_events()\n", + "\n", + "if 'outputs' in locals():\n", + " del outputs\n", + "if 'memory' in locals():\n", + " del memory\n", + "if 'seq' in locals():\n", + " del seq\n", + "\n", + "tick = time.time()\n", + "print('starting generation', tick)\n", + "\n", + "outputs = sample_sequences(\n", + " events, rng, start_step, num_steps, temperature, top_p)\n", + "\n", + "outputs = jax.tree_map(lambda x: x.block_until_ready(), outputs)\n", + "tock = time.time()\n", + "print('generation complete', tock)\n", + "print(tock - tick, 'seconds')\n", + "print((tock - tick) / 60, 'minutes')\n", + "\n", + "if use_memory:\n", + " events, rng, memory = outputs\n", + "else:\n", + " events, rng = outputs\n", + "\n", + "# reshape to remove device axis\n", + "events = events.reshape((np.prod(events.shape[:2]),) + events.shape[2:])" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "wcdm1gl2ft31", + "outputId": "0bf9e1c1-0f4e-48d1-855d-3f95a07534a6" + }, + "execution_count": 33, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "starting generation 1672481283.2725222\n", + "start_step 65536\n", + "upper 65537\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "tcmalloc: large alloc 1073750016 bytes == 0x5f04e000 @ 0x7f1788a581e7 0x7f175f8d4a50 0x7f175f7f9837 0x7f175f7faa69 0x7f175f800c5a 0x7f175f42103a 0x7f175f421086 0x7f175f808197 0x7f175f808a5b 0x7f175f80921b 0x7f175f86a5e6 0x7f175f71afab 0x7f175f71b2a6 0x7f175f808197 0x7f175f808a5b 0x7f175f80921b 0x7f175f83f96d 0x7f175f83f9c6 0x7f175f808197 0x7f175f8095ba 0x7f175f80e2cb 0x7f175f713a78 0x7f175fe9fde5 0x7f175fed46ec 0x7f175fe5299f 0x5d80be 0x5d8d8c 0x4fedd4 0x49abe4 0x5d8868 0x4997c7\n", + "tcmalloc: large alloc 1793105920 bytes == 0xe9888000 @ 0x7f1788a581e7 0x4d30a0 0x57f16a 0x574a98 0x648d6d 0x608583 0x6085dc 0x43eaa7 0x5aac95 0x5d8506 0x7f175fea6bf1 0x7f175f709b27 0x7f175f8e4194 0x7f175f8da083 0x7f175f807c06 0x7f175f87f0b0 0x7f175f808197 0x7f175f808a5b 0x7f175f80921b 0x7f175f86a5e6 0x7f175f71afab 0x7f175f71b2a6 0x7f175f808197 0x7f175f808a5b 0x7f175f80921b 0x7f175f83f96d 0x7f175f83f9c6 0x7f175f808197 0x7f175f8095ba 0x7f175f80e2cb 0x7f175f713a78\n", + "tcmalloc: large alloc 1793400832 bytes == 0x154692000 @ 0x7f1788a581e7 0x4d30a0 0x57f16a 0x574a98 0x648d6d 0x608583 0x608694 0x5f0353 0x4f7699 0x4997a2 0x5d8868 0x4997c7 0x5d8868 0x4997c7 0x55d078 0x5d8941 0x5d8416 0x55f797 0x5d8868 0x5d8506 0x55f797 0x55d078 0x5d8941 0x4997a2 0x55cd91 0x5d8941 0x5d8506 0x55f797 0x55d078 0x5d8941 0x5d8506\n", + "tcmalloc: large alloc 1793400832 bytes == 0xd5778000 @ 0x7f1788a5a887 0x7f17606ba1de 0x7f17606bc979 0x7f17606f2533 0x7f17606d1991 0x5d80be 0x5d8d8c 0x4fedd4 0x49abe4 0x5d8868 0x5d8506 0x55f797 0x55d078 0x5d8941 0x4990ca 0x5d8868 0x4997c7 0x55d078 0x5d8941 0x5d8416 0x55f797 0x5d8868 0x5d8506 0x55f797 0x55d078 0x5d8941 0x4997a2 0x55cd91 0x5d8941 0x5d8506 0x55f797\n", + "2022-12-31 10:08:44.310417: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:61] Constant folding an instruction is taking > 1s:\n", + "\n", + " %gather.6 = f32[1,65536,1024]{2,1,0} gather(f32[131073,1024]{1,0} %constant.126, s32[1,65536,1]{2,1,0} %constant.580), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,1024}, metadata={op_name=\"pmap(sample_sequences)/jit(main)/jit(sample_sequences)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(2,), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1, 1024) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]\" source_file=\"/usr/local/lib/python3.8/dist-packages/haiku/_src/embed.py\" source_line=163}\n", + "\n", + "This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.\n", + "\n", + "If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "generation complete 1672481353.0767028\n", + "69.80418062210083 seconds\n", + "1.1634030103683473 minutes\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "#@title Do the inputs and outputs of the mirrored_input test match?\n", + "if input_sequence_init == 'mirror_input':\n", + " if num_steps < 0:\n", + " num_steps = events.shape[1] - start_step + 1\n", + "\n", + " start_idx = max(start_step - num_steps, 1)\n", + " end_idx = min(start_step+num_steps, max_context_length-1) \n", + " last_inputs = events[:, start_idx:start_step][:, ::-1]\n", + " first_outputs = events[:, start_step:end_idx]\n", + " print(f'Last inputs (reversed):\\n {last_inputs}')\n", + " print(f'First outputs:\\n {first_outputs}')\n", + " print('Number of matches', (first_outputs == last_inputs).sum(axis=-1))\n", + " print('All match?', np.all(first_outputs == last_inputs))" + ], + "metadata": { + "id": "_qRqJ9VSgPWE", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "69e43ffd-cd0d-4787-afe4-452f90d3ff76" + }, + "execution_count": 34, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Last inputs (reversed):\n", + " [[71]]\n", + "First outputs:\n", + " [[71]]\n", + "Number of matches [1]\n", + "All match? True\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "#@title Visualize or play the sampled sequence\n", + "\n", + "inference_dir = checkpoint_base / 'inference' / datetime.datetime.now().isoformat().replace(\":\", \"\")\n", + "im_timestamp = time.time()\n", + "\n", + "for i, seq in enumerate(events):\n", + " eos_idx = jnp.argmax(seq == dataset.EOS_ID)\n", + " if eos_idx:\n", + " print(f'Found EOS at index {eos_idx}, truncating.')\n", + " seq = seq[:eos_idx]\n", + " else:\n", + " print('No EOS token found.')\n", + "\n", + " if seq[0] == dataset.SOS_ID:\n", + " print(f'Found SOS at index 0 as expected, removing.')\n", + " seq = seq[1:]\n", + " else:\n", + " print(f'WARNING: SOS not found at index 0. This should not happen.')\n", + "\n", + " if modality in ['image', 'image_w_positions']:\n", + " seq = seq - dataset.NUM_RESERVED_TOKENS\n", + " rem = len(seq) % 3\n", + " if rem > 0:\n", + " print(f'Truncating {rem} position(s) to ensure multiple of 3')\n", + " seq = seq[:-rem]\n", + " seq = seq.reshape(-1, 3)\n", + " if modality == 'image':\n", + " seq -= jnp.broadcast_to([0, 256, 512], seq.shape)\n", + " rem = len(seq) % 64\n", + " if rem > 0:\n", + " print(f'Truncating {rem} tuple(s) to ensure multiple of 64')\n", + " seq = seq[:-rem]\n", + " seq = seq.reshape(-1, 64, 3)\n", + "\n", + " seq = jnp.where(seq >= 0, seq, 0)\n", + " print('-----')\n", + " print(seq.shape)\n", + " plt.imshow(seq)\n", + " plt.show()\n", + "\n", + " inference_dir.mkdir(parents=True, exist_ok=True)\n", + " im_filename = f'im_{i}.png'\n", + " im_path = inference_dir / im_filename\n", + " img = Image.fromarray(np.array(seq, dtype=np.uint8), mode='RGB')\n", + " print('saving to', im_path)\n", + " IPython.display.display(img)\n", + " img.save(im_path)\n", + " elif modality == 'raw':\n", + " print('Shape:', seq.shape)\n", + " print(seq)\n", + " print('#####')\n", + " else:\n", + " raise ValueError(f'Unknown modality: {modality}')" + ], + "metadata": { + "id": "Vu3-69lngQpc", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "1ec7b696-c5b1-4c65-e3e4-d3185f295bde" + }, + "execution_count": 35, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "No EOS token found.\n", + "Found SOS at index 0 as expected, removing.\n", + "Shape: (131072,)\n", + "[ 65 84 176 ... 0 0 0]\n", + "#####\n" + ] + } + ] + } + ] +}