SourceXtractorPlusPlus 1.0.3
SourceXtractor++, the next generation SExtractor
Loading...
Searching...
No Matches
OnnxModel.h
Go to the documentation of this file.
1/*
2 * OnnxModel.h
3 *
4 * Created on: Feb 16, 2021
5 * Author: mschefer
6 */
7
8#ifndef _SEIMPLEMENTATION_COMMON_ONNXMODEL_H_
9#define _SEIMPLEMENTATION_COMMON_ONNXMODEL_H_
10
#include <cstdint>
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <numeric>
#include <string>
#include <vector>

#include <onnxruntime_cxx_api.h>

#include <ElementsKernel/Exception.h>
20
21namespace SourceXtractor {
22
23class OnnxModel {
24public:
25
26 explicit OnnxModel(const std::string& model_path);
27
28 template<typename T, typename U>
29 void run(std::vector<T>& input_data, std::vector<U>& output_data) const {
30 Ort::RunOptions run_options;
31 auto mem_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
32
33 // Allocate memory
35 input_shape[0] = 1;
36 size_t input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1u, std::multiplies<size_t>());
37
38 std::vector<int64_t> output_shape(m_output_shape.begin(), m_output_shape.end());
39 output_shape[0] = 1;
40 size_t output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1u, std::multiplies<size_t>());
41
42 // Check input and output size are OK
43 if (input_data.size() < input_size || output_data.size() < output_size) {
44 throw Elements::Exception() << "OnnxModel: Insufficient buffer size ";
45 }
46
47 // Setup input/output tensors
48 auto input_tensor = Ort::Value::CreateTensor<T>(
49 mem_info, input_data.data(), input_data.size(), input_shape.data(), input_shape.size());
50 auto output_tensor = Ort::Value::CreateTensor<U>(
51 mem_info, output_data.data(), output_data.size(), output_shape.data(), output_shape.size());
52
53 // Run the model
54 const char *input_name = m_input_names[0].c_str();
55 const char *output_name = m_output_name.c_str();
56
57 m_session->Run(run_options, &input_name, &input_tensor, 1, &output_name, &output_tensor, 1);
58 }
59
60 template<typename T, typename U>
61 void runMultiInput(std::map<std::string, std::vector<T>>& input_data, std::vector<U>& output_data) const {
62 Ort::RunOptions run_options;
63 auto mem_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
64
65 std::vector<const char *> input_names;
66 std::vector<Ort::Value> input_tensors;
67
68 int inputs_nb = m_input_names.size();
69 for (int i=0; i<inputs_nb; i++) {
70 input_names.emplace_back(m_input_names[i].c_str());
71
72 // Allocate memory
74 input_shape[0] = 1;
75 size_t input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1u, std::multiplies<size_t>());
76
77 // Check input size is OK
78 if (input_data[m_input_names[i]].size() < input_size) {
79 throw Elements::Exception() << "OnnxModel: Insufficient buffer size ";
80 }
81
82 input_tensors.emplace_back(Ort::Value::CreateTensor<T>(
83 mem_info, input_data[m_input_names[i]].data(), input_data[m_input_names[i]].size(),
84 input_shape.data(), input_shape.size()));
85 }
86
87 // Output name and shape
88 const char *output_name = m_output_name.c_str();
89 std::vector<int64_t> output_shape(m_output_shape.begin(), m_output_shape.end());
90 output_shape[0] = 1;
91
92 // Setup output tensor
93 size_t output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1u, std::multiplies<size_t>());
94
95 // Check output and output size are OK
96 if (output_data.size() < output_size) {
97 throw Elements::Exception() << "OnnxModel: Insufficient buffer size ";
98 }
99
100 auto output_tensor = Ort::Value::CreateTensor<U>(
101 mem_info, output_data.data(), output_data.size(), output_shape.data(), output_shape.size());
102
103 // Run the model
104 m_session->Run(run_options, &input_names[0], &input_tensors[0], inputs_nb, &output_name, &output_tensor, 1);
105 }
106
107
108 ONNXTensorElementDataType getInputType() const {
109 return m_input_types[0];
110 }
111
112 ONNXTensorElementDataType getOutputType() const {
113 return m_output_type;
114 }
115
117 return m_input_shapes[0];
118 }
119
121 return m_output_shape;
122 }
123
125 return m_domain_name;
126 }
127
129 return m_graph_name;
130 }
131
133 return m_input_names[0];
134 }
135
137 return m_output_name;
138 }
139
141 return m_model_path;
142 }
143
144 size_t getInputNb() const {
145 return m_input_names.size();
146 }
147
148 size_t getOutputNb() const {
149 return 1U;
150 }
151
152private:
158 ONNXTensorElementDataType m_output_type;
163};
164
165}
166
167
168#endif /* _SEIMPLEMENTATION_COMMON_ONNXMODEL_H_ */
T accumulate(T... args)
T begin(T... args)
void run(std::vector< T > &input_data, std::vector< U > &output_data) const
Definition OnnxModel.h:29
ONNXTensorElementDataType getInputType() const
Definition OnnxModel.h:108
ONNXTensorElementDataType getOutputType() const
Definition OnnxModel.h:112
std::vector< ONNXTensorElementDataType > m_input_types
Input type.
Definition OnnxModel.h:157
std::unique_ptr< Ort::Session > m_session
Session, one per model. In theory, it is thread-safe.
Definition OnnxModel.h:162
std::string getGraphName() const
Definition OnnxModel.h:128
std::string getDomain() const
Definition OnnxModel.h:124
std::string m_output_name
Output tensor name.
Definition OnnxModel.h:156
size_t getOutputNb() const
Definition OnnxModel.h:148
const std::vector< std::int64_t > & getOutputShape() const
Definition OnnxModel.h:120
std::string getOutputName() const
Definition OnnxModel.h:136
ONNXTensorElementDataType m_output_type
Output type.
Definition OnnxModel.h:158
std::vector< std::string > m_input_names
Input tensor name.
Definition OnnxModel.h:155
OnnxModel(const std::string &model_path)
Definition OnnxModel.cpp:17
std::string getInputName() const
Definition OnnxModel.h:132
std::vector< std::int64_t > m_output_shape
Output tensor shape.
Definition OnnxModel.h:160
std::string m_graph_name
graph name
Definition OnnxModel.h:154
void runMultiInput(std::map< std::string, std::vector< T > > &input_data, std::vector< U > &output_data) const
Definition OnnxModel.h:61
std::string m_domain_name
domain name
Definition OnnxModel.h:153
const std::vector< std::int64_t > & getInputShape() const
Definition OnnxModel.h:116
std::string m_model_path
Path to the ONNX model.
Definition OnnxModel.h:161
size_t getInputNb() const
Definition OnnxModel.h:144
std::vector< std::vector< std::int64_t > > m_input_shapes
Input tensor shape.
Definition OnnxModel.h:159
std::string getModelPath() const
Definition OnnxModel.h:140
T data(T... args)
T emplace_back(T... args)
T end(T... args)
T size(T... args)