From 438ca87618c2d19c7d2503565e839df8eaf72f3d Mon Sep 17 00:00:00 2001 From: Marcel Rieger Date: Tue, 2 Jan 2024 11:32:41 +0100 Subject: [PATCH] Update tf2 inference docs. --- .../inference/code/tensorflow/tensorflow2_mt_plugin.cpp | 5 ++++- content/inference/tensorflow1.md | 3 --- content/inference/tensorflow2.md | 8 +++++--- mkdocs.yml | 1 - 4 files changed, 9 insertions(+), 8 deletions(-) delete mode 100644 content/inference/tensorflow1.md diff --git a/content/inference/code/tensorflow/tensorflow2_mt_plugin.cpp b/content/inference/code/tensorflow/tensorflow2_mt_plugin.cpp index 4da7766..a972e28 100644 --- a/content/inference/code/tensorflow/tensorflow2_mt_plugin.cpp +++ b/content/inference/code/tensorflow/tensorflow2_mt_plugin.cpp @@ -20,8 +20,9 @@ class MyPlugin : public edm::stream::EDAnalyzer initializeGlobalCache(const edm::ParameterSet&); + static void globalEndJob(const tensorflow::SessionCache*); private: void beginJob(); @@ -47,6 +48,8 @@ std::unique_ptr MyPlugin::initializeGlobalCache(const return std::make_unique(graphPath, options); } +void MyPlugin::globalEndJob(const tensorflow::SessionCache* cache) {} + void MyPlugin::fillDescriptions(edm::ConfigurationDescriptions& descriptions) { // defining this function will lead to a *_cfi file being generated when compiling edm::ParameterSetDescription desc; diff --git a/content/inference/tensorflow1.md b/content/inference/tensorflow1.md deleted file mode 100644 index 2831e15..0000000 --- a/content/inference/tensorflow1.md +++ /dev/null @@ -1,3 +0,0 @@ -# Direct inference with TensorFlow 1 - -While it is technically still possible to use TensorFlow 1, this version of TensorFlow is quite old and is no longer supported by CMSSW. We highly recommend that you update your model to TensorFlow 2 and follow the integration guide in the [Inference/Direct inference/TensorFlow 2](tensorflow2.md) documentation. 
\ No newline at end of file
diff --git a/content/inference/tensorflow2.md b/content/inference/tensorflow2.md
index 57bb4c4..b36fca7 100644
--- a/content/inference/tensorflow2.md
+++ b/content/inference/tensorflow2.md
@@ -369,20 +369,22 @@ public:
   explicit GraphLoadingMT(const edm::ParameterSet&, const tensorflow::SessionCache*);
   ~GraphLoadingMT();
 
-  // an additional static method for initializing the global cache
+  // additional static methods for initializing and closing the global cache
   static std::unique_ptr<tensorflow::SessionCache> initializeGlobalCache(const edm::ParameterSet&);
-  static void globalEndJob(const CacheData*);
+  static void globalEndJob(const tensorflow::SessionCache*);
 ...
 ```
 
 Implement `initializeGlobalCache` to control the behavior of how the cache object is created.
-The destructor of `tensorflow::SessionCache` already handles the closing of the session itself and the deletion of all objects.
+You also need to implement `globalEndJob`, however, it can remain empty as the destructor of `tensorflow::SessionCache` already handles the closing of the session itself and the deletion of all objects.
 
 ```cpp
 std::unique_ptr<tensorflow::SessionCache> MyPlugin::initializeGlobalCache(const edm::ParameterSet& config) {
   std::string graphPath = edm::FileInPath(params.getParameter<std::string>("graphPath")).fullPath();
   return std::make_unique<tensorflow::SessionCache>(graphPath);
 }
+
+void MyPlugin::globalEndJob(const tensorflow::SessionCache* cache) {}
 ```
 
 ??? hint "Custom cache struct"
diff --git a/mkdocs.yml b/mkdocs.yml
index 1189b07..0dc44e6 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -143,7 +143,6 @@ nav:
       - After training: general_advice/after/after.md
   - Inference:
     - Direct inference:
-      - TensorFlow 1: inference/tensorflow1.md
       - TensorFlow 2: inference/tensorflow2.md
       - PyTorch: inference/pytorch.md
       - PyTorch Geometric: inference/pyg.md