diff --git a/README.md b/README.md
index dc9af111..a3373b70 100644
--- a/README.md
+++ b/README.md
@@ -27,6 +27,7 @@ model.
 * [Evaluation and Inference](#evaluation-and-inference)
 * [Create Your Own Dataset Files](#create-your-own-dataset-files)
 * [Training without this Starter Code](#training-without-this-starter-code)
+* [Export Your Model for MediaPipe Inference](#export-your-model-for-mediapipe-inference)
 * [More Documents](#more-documents)
 * [About This Project](#about-this-project)
 
@@ -321,6 +322,16 @@ and the following for the inference code:
 num examples processed: 8192
 elapsed seconds: 14.85
 ```
+## Export Your Model for MediaPipe Inference
+To run inference with your model in [MediaPipe inference
+demo](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/youtube8m#steps-to-run-the-youtube-8m-inference-graph-with-the-yt8m-dataset), you need to export your checkpoint to a SavedModel.
+
+Example command:
+```sh
+python export_model_mediapipe.py --checkpoint_file ~/yt8m/models/frame/sample_model/inference_model/segment_inference_model --output_dir /tmp/mediapipe/saved_model/
+```
+
+
 ## Create Your Own Dataset Files
 
 You can create your dataset files from your own videos. Our
diff --git a/export_model_mediapipe.py b/export_model_mediapipe.py
new file mode 100644
index 00000000..d81cd9c6
--- /dev/null
+++ b/export_model_mediapipe.py
@@ -0,0 +1,64 @@
+# Lint as: python3
+"""Exports a trained YT8M checkpoint to a SavedModel for MediaPipe inference."""
+import numpy as np
+import tensorflow as tf
+from tensorflow import app
+from tensorflow import flags
+
+FLAGS = flags.FLAGS
+
+
+def main(unused_argv):
+  # Load the checkpoint's meta graph once, only to look up the names of the
+  # input tensors that must be remapped onto fresh placeholders below.
+  tf.reset_default_graph()
+  meta_graph_location = FLAGS.checkpoint_file + ".meta"
+  tf.train.import_meta_graph(meta_graph_location, clear_devices=True)
+
+  input_tensor_name = tf.get_collection("input_batch_raw")[0].name
+  num_frames_tensor_name = tf.get_collection("num_frames")[0].name
+
+  # Create output graph.
+  tf.reset_default_graph()
+
+  input_feature_placeholder = tf.placeholder(
+      tf.float32, shape=(None, None, 1152))
+  num_frames_placeholder = tf.placeholder(tf.int32, shape=(None, 1))
+
+  # Re-import the graph with its inputs rewired to the placeholders above;
+  # the returned saver restores the checkpoint variables into this graph.
+  saver = tf.train.import_meta_graph(
+      meta_graph_location,
+      input_map={
+          input_tensor_name: input_feature_placeholder,
+          num_frames_tensor_name: tf.squeeze(num_frames_placeholder, axis=1)
+      },
+      clear_devices=True)
+  predictions_tensor = tf.get_collection("predictions")[0]
+
+  with tf.Session() as sess:
+    print("restoring variables from " + FLAGS.checkpoint_file)
+    saver.restore(sess, FLAGS.checkpoint_file)
+    tf.saved_model.simple_save(
+        sess,
+        FLAGS.output_dir,
+        inputs={'rgb_and_audio': input_feature_placeholder,
+                'num_frames': num_frames_placeholder},
+        outputs={'predictions': predictions_tensor})
+
+    # Try running inference. The num_frames batch dimension must match the
+    # feature batch dimension (3 examples here).
+    predictions = sess.run(
+        [predictions_tensor],
+        feed_dict={
+            input_feature_placeholder: np.zeros((3, 7, 1152), dtype=np.float32),
+            num_frames_placeholder: np.array([[7], [7], [7]], dtype=np.int32)})
+    print('Test inference:', predictions)
+
+    print('Model saved to ', FLAGS.output_dir)
+
+
+if __name__ == '__main__':
+  flags.DEFINE_string('checkpoint_file', None, 'Path to the checkpoint file.')
+  flags.DEFINE_string('output_dir', None, 'SavedModel output directory.')
+  app.run(main)