-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathArmnnPreparedModel_1_3.hpp
148 lines (119 loc) · 6.88 KB
/
ArmnnPreparedModel_1_3.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
#include "ArmnnDriver.hpp"
#include "ArmnnDriverImpl.hpp"
#include "RequestThread_1_3.hpp"
#include "ModelToINetworkConverter.hpp"
#include <NeuralNetworks.h>
#include <armnn/ArmNN.hpp>
#include <string>
#include <vector>
namespace armnn_driver
{
// Completion callback for an asynchronous execution request.
// Arguments: the NN HAL 1.3 error status, the output shapes produced by the
// execution, the timing information, and the name of the calling function
// (used for diagnostics/logging by the invoker).
using CallbackAsync_1_3 = std::function<
void(V1_3::ErrorStatus errorStatus,
std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes,
const ::android::hardware::neuralnetworks::V1_2::Timing& timing,
std::string callingFunction)>;
// Per-request execution context: records whether the caller asked for timing
// measurement, and holds the timestamps taken around the driver- and
// device-side portions of the execution.
struct ExecutionContext_1_3
{
// Whether execution durations should be measured; defaults to NO.
::android::hardware::neuralnetworks::V1_2::MeasureTiming measureTimings =
::android::hardware::neuralnetworks::V1_2::MeasureTiming::NO;
// Timestamps bracketing the driver-side and device-side work.
// NOTE(review): TimePoint is declared elsewhere in the project — presumably a
// std::chrono time point; confirm against its definition.
TimePoint driverStart;
TimePoint driverEnd;
TimePoint deviceStart;
TimePoint deviceEnd;
};
// Pairs the async completion callback with its execution context so both can
// be handed to the request thread as a single object.
using CallbackContext_1_3 = CallbackContext<CallbackAsync_1_3, ExecutionContext_1_3>;
// Callback type used by executeFenced to report: the error status, a sync
// fence handle for the caller to wait on, and a callback object for querying
// information about the fenced execution.
using executeFenced_cb = std::function<void(::android::hardware::neuralnetworks::V1_3::ErrorStatus status,
const ::android::hardware::hidl_handle& syncFence,
const ::android::sp<::android::hardware::neuralnetworks::V1_3::IFencedExecutionCallback>& callback)>;
/// NN HAL 1.3 IPreparedModel implementation backed by an Arm NN network that
/// has already been loaded into an Arm NN runtime. It services the HAL
/// execute* entry points (asynchronous, synchronous and fenced) by preparing
/// request memory pools as Arm NN input/output tensors and running the graph
/// identified by the stored NetworkId.
template <typename HalVersion>
class ArmnnPreparedModel_1_3 : public V1_3::IPreparedModel
{
public:
using HalModel = typename V1_3::Model;
/// @param networkId                       Identifier of the network already loaded into @p runtime.
/// @param runtime                         Arm NN runtime used to run the network.
///                                        NOTE(review): stored as a raw pointer — presumably
///                                        non-owning and outliving this object; confirm.
/// @param model                           The HAL model this prepared model was created from.
/// @param requestInputsAndOutputsDumpDir  Directory for dumping request tensors (see
///                                        DumpTensorsIfRequired).
/// @param gpuProfilingEnabled             Whether GPU profiling was enabled for this network.
/// @param priority                        Execution priority reported via GetModelPriority.
ArmnnPreparedModel_1_3(armnn::NetworkId networkId,
armnn::IRuntime* runtime,
const HalModel& model,
const std::string& requestInputsAndOutputsDumpDir,
const bool gpuProfilingEnabled,
V1_3::Priority priority = V1_3::Priority::MEDIUM);
virtual ~ArmnnPreparedModel_1_3();
/// HAL 1.0 asynchronous execution entry point; completion is reported via @p callback.
Return<V1_0::ErrorStatus> execute(const V1_0::Request& request,
const sp<V1_0::IExecutionCallback>& callback) override;
/// HAL 1.2 asynchronous execution entry point with optional timing measurement.
Return<V1_0::ErrorStatus> execute_1_2(const V1_0::Request& request, V1_2::MeasureTiming measure,
const sp<V1_2::IExecutionCallback>& callback) override;
/// HAL 1.3 asynchronous execution entry point. The unnamed deadline and loop
/// timeout parameters are accepted per the interface but ignored here
/// (unnamed in the declaration).
Return<V1_3::ErrorStatus> execute_1_3(const V1_3::Request& request,
V1_2::MeasureTiming measure,
const V1_3::OptionalTimePoint&,
const V1_3::OptionalTimeoutDuration&,
const sp<V1_3::IExecutionCallback>& callback) override;
/// HAL 1.2-style synchronous execution: runs the request and reports the
/// result through @p cb before returning.
Return<void> executeSynchronously(const V1_0::Request &request,
V1_2::MeasureTiming measure,
V1_3::IPreparedModel::executeSynchronously_cb cb) override;
/// HAL 1.3 synchronous execution with deadline and loop-timeout parameters.
Return<void> executeSynchronously_1_3(const V1_3::Request &request,
V1_2::MeasureTiming measure,
const V1_3::OptionalTimePoint& deadline,
const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
V1_3::IPreparedModel::executeSynchronously_1_3_cb cb) override;
/// HAL 1.3 fenced execution: waits on @p fenceWaitFor before running, and
/// reports a sync fence plus a fenced-execution callback through @p callback.
Return<void> executeFenced(const V1_3::Request& request,
const android::hardware::hidl_vec<android::hardware::hidl_handle>& fenceWaitFor,
V1_2::MeasureTiming measure,
const V1_3::OptionalTimePoint& deadline,
const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
const V1_3::OptionalTimeoutDuration& duration,
executeFenced_cb callback) override;
/// HAL 1.2 burst-execution setup over fast message queues.
Return<void> configureExecutionBurst(
const sp<V1_2::IBurstCallback>& callback,
const android::hardware::MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
const android::hardware::MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
configureExecutionBurst_cb cb) override;
/// Shared synchronous-execution implementation used by the HAL
/// executeSynchronously* overrides above.
template<typename CallbackContext>
Return<void> ExecuteSynchronously(const V1_3::Request& request, CallbackContext cbCtx);
/// execute the graph prepared from the request
/// @param pMemPools     Memory pools backing the request; kept alive in a shared_ptr
///                      for the duration of the (possibly asynchronous) execution.
/// @param inputTensors  Arm NN input bindings prepared from the request.
/// @param outputTensors Arm NN output bindings prepared from the request.
/// @param callback      Invoked with the execution result.
template<typename CallbackContext>
Return <V1_3::ErrorStatus> ExecuteGraph(
std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
armnn::InputTensors& inputTensors,
armnn::OutputTensors& outputTensors,
CallbackContext callback);
/// Executes this model with dummy inputs (e.g. all zeroes).
/// \return false on failure, otherwise true
bool ExecuteWithDummyInputs();
/// Returns the priority this prepared model was constructed with.
V1_3::Priority GetModelPriority();
private:
/// Common asynchronous-execution path shared by the HAL execute* overrides.
Return <V1_3::ErrorStatus> Execute(const V1_3::Request& request,
V1_2::MeasureTiming measureTiming,
CallbackAsync_1_3 callback);
/// Binds the request's input pools to Arm NN input tensors.
Return<V1_3::ErrorStatus> PrepareMemoryForInputs(
armnn::InputTensors& inputs,
const V1_3::Request& request,
const std::vector<android::nn::RunTimePoolInfo>& memPools);
/// Binds the request's output pools to Arm NN output tensors and records
/// the resulting output shapes.
Return<V1_3::ErrorStatus> PrepareMemoryForOutputs(
armnn::OutputTensors& outputs,
std::vector<V1_2::OutputShape> &outputShapes,
const V1_3::Request& request,
const std::vector<android::nn::RunTimePoolInfo>& memPools);
/// Prepares both inputs and outputs for a request; on failure the tuple
/// carries the error status, output shapes, timing and a message string.
std::tuple<V1_3::ErrorStatus, android::hardware::hidl_vec<V1_2::OutputShape>, V1_2::Timing, std::string> PrepareMemoryForIO(
armnn::InputTensors& inputs,
armnn::OutputTensors& outputs,
std::vector<android::nn::RunTimePoolInfo>& memPools,
const V1_3::Request& request);
/// Dumps the given tensor bindings (prefixing each with @p tensorNamePrefix)
/// when a dump directory was configured.
template <typename TensorBindingCollection>
void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
// Identifier of the loaded network within m_Runtime.
armnn::NetworkId m_NetworkId;
// Runtime that executes the network (raw pointer; see constructor note).
armnn::IRuntime* m_Runtime;
// Copy of the HAL model this prepared model was built from.
V1_3::Model m_Model;
// There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
// It is specific to this class, so it is declared as static here
static RequestThread_1_3<ArmnnPreparedModel_1_3, HalVersion, CallbackContext_1_3> m_RequestThread;
// Count of requests handled by this instance (diagnostics).
uint32_t m_RequestCount;
// NOTE(review): reference member bound to the string passed to the
// constructor — the referenced string must outlive this object or this
// dangles; confirm the owning driver keeps it alive.
const std::string& m_RequestInputsAndOutputsDumpDir;
const bool m_GpuProfilingEnabled;
// Priority supplied at construction; reported by GetModelPriority.
V1_3::Priority m_ModelPriority;
};
}