Skip to content

Commit

Permalink
Merge pull request #71 from cms-ml/feature/update_tf2_inference
Browse files Browse the repository at this point in the history
Update tf2 inference docs.
  • Loading branch information
valsdav authored Jan 8, 2024
2 parents d3e84b6 + 438ca87 commit 33d0789
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 8 deletions.
5 changes: 4 additions & 1 deletion content/inference/code/tensorflow/tensorflow2_mt_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ class MyPlugin : public edm::stream::EDAnalyzer<edm::GlobalCache<tensorflow::Ses

static void fillDescriptions(edm::ConfigurationDescriptions&);

// an additional static method for initializing the global cache
// additional static methods for initializing and closing the global cache
static std::unique_ptr<tensorflow::SessionCache> initializeGlobalCache(const edm::ParameterSet&);
static void globalEndJob(const tensorflow::SessionCache*);

private:
void beginJob();
Expand All @@ -47,6 +48,8 @@ std::unique_ptr<tensorflow::SessionCache> MyPlugin::initializeGlobalCache(const
return std::make_unique<tensorflow::SessionCache>(graphPath, options);
}

void MyPlugin::globalEndJob(const tensorflow::SessionCache* cache) {}

void MyPlugin::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
// defining this function will lead to a *_cfi file being generated when compiling
edm::ParameterSetDescription desc;
Expand Down
3 changes: 0 additions & 3 deletions content/inference/tensorflow1.md

This file was deleted.

8 changes: 5 additions & 3 deletions content/inference/tensorflow2.md
Original file line number Diff line number Diff line change
Expand Up @@ -369,20 +369,22 @@ public:
explicit GraphLoadingMT(const edm::ParameterSet&, const tensorflow::SessionCache*);
~GraphLoadingMT();

// an additional static method for initializing the global cache
// additional static methods for initializing and closing the global cache
static std::unique_ptr<tensorflow::SessionCache> initializeGlobalCache(const edm::ParameterSet&);
static void globalEndJob(const CacheData*);
static void globalEndJob(const tensorflow::SessionCache*);
...
```
Implement `initializeGlobalCache` to control the behavior of how the cache object is created.
The destructor of `tensorflow::SessionCache` already handles the closing of the session itself and the deletion of all objects.
You also need to implement `globalEndJob`, however, it can remain empty as the destructor of `tensorflow::SessionCache` already handles the closing of the session itself and the deletion of all objects.
```cpp
std::unique_ptr<tensorflow::SessionCache> MyPlugin::initializeGlobalCache(const edm::ParameterSet& config) {
std::string graphPath = edm::FileInPath(params.getParameter<std::string>("graphPath")).fullPath();
return std::make_unique<tensorflow::SessionCache>(graphPath);
}
void MyPlugin::globalEndJob(const tensorflow::SessionCache* cache) {}
```

??? hint "Custom cache struct"
Expand Down
1 change: 0 additions & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ nav:
- After training: general_advice/after/after.md
- Inference:
- Direct inference:
- TensorFlow 1: inference/tensorflow1.md
- TensorFlow 2: inference/tensorflow2.md
- PyTorch: inference/pytorch.md
- PyTorch Geometric: inference/pyg.md
Expand Down

0 comments on commit 33d0789

Please sign in to comment.