-
Notifications
You must be signed in to change notification settings - Fork 157
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(examples): Add a streamlit mnist example (#795)
* feat(examples): Add a streamlit mnist example Signed-off-by: Ce Gao <[email protected]> * fix: Update Signed-off-by: Ce Gao <[email protected]> * fix: Remove requirements.txt Signed-off-by: Ce Gao <[email protected]> * fix: Clean output Signed-off-by: Ce Gao <[email protected]> * fix: Update Signed-off-by: Ce Gao <[email protected]> * fix: Remove run in app.py Signed-off-by: Ce Gao <[email protected]> Signed-off-by: Ce Gao <[email protected]>
- Loading branch information
Showing
8 changed files
with
262 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
model/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Streamlit MNIST demo (drawable) | ||
|
||
> The code is available [here](https://github.com/rahulsrma26/streamlit-mnist-drawable). | ||
A simple digit recognition demo using [keras](https://www.tensorflow.org/overview) and [streamlit](https://www.streamlit.io/). It uses [streamlit-drawable-canvas](https://github.com/andfanilo/streamlit-drawable-canvas) for drawing on canvas. | ||
|
||
![demo](img/demo.gif) | ||
|
||
[streamlit](https://www.streamlit.io/) is an open-source app framework, which is the easiest way for data scientists and machine learning engineers to create beautiful, performant apps. All in pure Python, no longer fiddling with javascript. | ||
|
||
This demo contains two parts: training a simple digit recognition model using mnist dataset and a webapp to live demo that model. | ||
|
||
## Running demo | ||
|
||
1. First install all the dependencies | ||
|
||
``` | ||
envd up | ||
``` | ||
|
||
2. Train model | ||
|
||
Run all the cells of [train.ipynb](train.ipynb) manually. | ||
|
||
3. Run demo web-app | ||
|
||
``` | ||
envd up -f build.envd:serve | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import os | ||
import numpy as np | ||
import cv2 | ||
from tensorflow.keras.models import load_model | ||
import streamlit as st | ||
from streamlit_drawable_canvas import st_canvas | ||
|
||
model = load_model('model') | ||
|
||
st.title('My Digit Recognizer') | ||
st.markdown(''' | ||
Try to write a digit! | ||
''') | ||
|
||
SIZE = 192 | ||
mode = st.checkbox("Draw (or Delete)?", True) | ||
canvas_result = st_canvas( | ||
fill_color='#000000', | ||
stroke_width=20, | ||
stroke_color='#FFFFFF', | ||
background_color='#000000', | ||
width=SIZE, | ||
height=SIZE, | ||
drawing_mode="freedraw" if mode else "transform", | ||
key='canvas') | ||
|
||
if canvas_result.image_data is not None: | ||
img = cv2.resize(canvas_result.image_data.astype('uint8'), (28, 28)) | ||
rescaled = cv2.resize(img, (SIZE, SIZE), interpolation=cv2.INTER_NEAREST) | ||
st.write('Model Input') | ||
st.image(rescaled) | ||
|
||
if st.button('Predict'): | ||
test_x = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) | ||
val = model.predict(test_x.reshape(1, 28, 28)) | ||
st.write(f'result: {np.argmax(val[0])}') | ||
st.bar_chart(val[0]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
def build(): | ||
base(os="ubuntu20.04", language="python3") | ||
install.vscode_extensions([ | ||
"ms-python.python", | ||
]) | ||
|
||
configure_mnist() | ||
# Configure jupyter notebooks. | ||
config.jupyter() | ||
# Configure zsh. | ||
shell("zsh") | ||
|
||
def serve(): | ||
base(os="ubuntu20.04", language="python3") | ||
configure_streamlit(8501) | ||
configure_mnist() | ||
|
||
def configure_streamlit(port): | ||
install.python_packages([ | ||
"streamlit", | ||
"streamlit_drawable_canvas", | ||
]) | ||
runtime.expose(envd_port=port, host_port=port, service="streamlit") | ||
runtime.daemon(commands=[ | ||
["streamlit", "run", "~/streamlit-mnist/app.py"] | ||
]) | ||
|
||
def configure_mnist(): | ||
# config.pip_index(url = "https://pypi.tuna.tsinghua.edu.cn/simple") | ||
install.system_packages([ | ||
"libgl1", | ||
]) | ||
install.python_packages([ | ||
"tensorflow", | ||
"numpy", | ||
"opencv-python", | ||
"matplotlib", | ||
]) |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"import matplotlib.pyplot as plt" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import tensorflow as tf\n", | ||
"print(tf.__version__)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"print(f'Training samples {len(x_train):,}')\n", | ||
"print(f'Test samples {len(x_test):,}')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def show(idx):\n", | ||
" print(y_train[idx])\n", | ||
" plt.imshow(x_train[idx])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"show(2)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"np.random.seed(23)\n", | ||
"tf.random.set_seed(23)\n", | ||
"model = tf.keras.Sequential()\n", | ||
"model.add(tf.keras.layers.Flatten(input_shape=(28,28,1)))\n", | ||
"model.add(tf.keras.layers.Dense(300, activation='relu'))\n", | ||
"model.add(tf.keras.layers.Dropout(0.2))\n", | ||
"model.add(tf.keras.layers.Dense(50, activation='relu'))\n", | ||
"model.add(tf.keras.layers.Dropout(0.3))\n", | ||
"model.add(tf.keras.layers.Dense(10, activation='softmax'))\n", | ||
"model.compile(loss='sparse_categorical_crossentropy',\n", | ||
" optimizer=tf.keras.optimizers.Adam(0.0003),\n", | ||
" metrics=['accuracy'])\n", | ||
"model.summary()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"model.fit(x_train, y_train, batch_size=32, epochs=20)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"model.evaluate(x_test, y_test)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"model.save('model')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.7" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters