SourceXtractorPlusPlus 1.0.3
SourceXtractor++, the next generation SExtractor
Loading...
Searching...
No Matches
OnnxSourceTask.cpp
Go to the documentation of this file.
1
17
22#include <NdArray/NdArray.h>
23#include <AlexandriaKernel/memory_tools.h>
24#include <onnxruntime_cxx_api.h>
25
26namespace NdArray = Euclid::NdArray;
27
28namespace SourceXtractor {
29
30
31template<typename T>
32static void fillCutout(const Image<T>& image, int center_x, int center_y, int width, int height, std::vector<T>& out) {
33 int x_start = center_x - width / 2;
34 int y_start = center_y - height / 2;
35 int x_end = x_start + width;
36 int y_end = y_start + height;
37
38 ImageAccessor<T> accessor(image);
39
40 int index = 0;
41 for (int iy = y_start; iy < y_end; iy++) {
42 for (int ix = x_start; ix < x_end; ix++, index++) {
43 if (ix >= 0 && iy >= 0 && ix < image.getWidth() && iy < image.getHeight()) {
44 out[index] = accessor.getValue(ix, iy);
45 }
46 }
47 }
48}
49
51
59template<typename O>
61computePropertiesSpecialized(const OnnxModel& model, const DetectionFrameImages& detection_frame_images,
62 const PixelCentroid& centroid) {
63 Ort::RunOptions run_options;
64 auto mem_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
65
66 const int center_x = static_cast<int>(centroid.getCentroidX() + 0.5);
67 const int center_y = static_cast<int>(centroid.getCentroidY() + 0.5);
68
69 // Allocate memory
70 std::vector<int64_t> input_shape(model.getInputShape().begin(), model.getInputShape().end());
71 input_shape[0] = 1;
72 size_t input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1u, std::multiplies<size_t>());
73 std::vector<float> input_data(input_size);
74
75 std::vector<int64_t> output_shape(model.getOutputShape().begin(), model.getOutputShape().end());
76 output_shape[0] = 1;
77 size_t output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1u, std::multiplies<size_t>());
78 std::vector<O> output_data(output_size);
79
80 // Cut the needed area
81 {
82 const auto& image = detection_frame_images.getLockedImage(LayerSubtractedImage);
83 fillCutout(*image, center_x, center_y, input_shape[2], input_shape[3], input_data);
84 }
85
86 model.run<float, O>(input_data, output_data);
87
88 // Set the output
89 std::vector<size_t> catalog_shape{model.getOutputShape().begin() + 1, model.getOutputShape().end()};
90 return Euclid::make_unique<OnnxProperty::NdWrapper<O>>(catalog_shape, output_data);
91}
92
// NOTE(review): the header of this function and the declarations of `output_dict` and
// `result` are not visible in this listing.
// Runs every configured model on the source and stores all outputs, keyed by
// model_info.prop_name, as a single OnnxProperty on the source.
// Inputs shared by all models: the detection-frame images and the pixel centroid.
94 const auto& detection_frame_images = source.getProperty<DetectionFrameImages>();
95 const auto& centroid = source.getProperty<PixelCentroid>();
96
98
99 for (const auto& model_info : m_model_infos) {
101
// Dispatch on the model's output element type: only FLOAT and INT32 outputs are
// handled; any other type throws.
102 switch (model_info.model->getOutputType()) {
103 case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
104 result = computePropertiesSpecialized<float>(*model_info.model, detection_frame_images, centroid);
105 break;
106 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
107 result = computePropertiesSpecialized<int32_t>(*model_info.model, detection_frame_images, centroid);
108 break;
109 default:
// NOTE(review): the message and the numeric type code are streamed with no separator
// between them; consider adding one for readability.
110 throw Elements::Exception() << "This should have not happened!" << model_info.model->getOutputType();
111 }
112
// NOTE(review): if output_dict is a std::map/unordered_map, emplace keeps the first entry
// when two models share a prop_name — confirm property names are unique.
113 output_dict.emplace(model_info.prop_name, std::move(result));
114 }
115
116 source.setProperty<OnnxProperty>(std::move(output_dict));
117}
118
119} // end of namespace SourceXtractor
T accumulate(T... args)
T begin(T... args)
std::shared_ptr< ImageAccessor< SeFloat > > getLockedImage(FrameImageLayer layer) const
Interface representing an image.
Definition Image.h:44
virtual int getHeight() const =0
Returns the height of the image in pixels.
virtual int getWidth() const =0
Returns the width of the image in pixels.
void run(std::vector< T > &input_data, std::vector< U > &output_data) const
Definition OnnxModel.h:29
const std::vector< std::int64_t > & getOutputShape() const
Definition OnnxModel.h:120
const std::vector< std::int64_t > & getInputShape() const
Definition OnnxModel.h:116
void computeProperties(SourceInterface &source) const override
Computes one or more properties for the Source.
OnnxSourceTask(const std::vector< OnnxModelInfo > &model_infos)
const std::vector< OnnxModelInfo > & m_model_infos
The centroid of all the pixels in the source, weighted by their DetectionImage pixel values.
SeFloat getCentroidX() const
X coordinate of centroid.
SeFloat getCentroidY() const
Y coordinate of centroid.
The SourceInterface is an abstract "source" that has properties attached to it.
const PropertyType & getProperty(unsigned int index=0) const
Convenience template method to call getProperty() with a more user-friendly syntax.
T end(T... args)
T move(T... args)
std::unique_ptr< T > make_unique(Args &&... args)
@ LayerSubtractedImage
Definition Frame.h:39
static void fillCutout(const Image< T > &image, int center_x, int center_y, int width, int height, std::vector< T > &out)
static std::unique_ptr< OnnxProperty::NdWrapperBase > computePropertiesSpecialized(const OnnxModel &model, const DetectionFrameImages &detection_frame_images, const PixelCentroid &centroid)