SourceXtractorPlusPlus 1.0.3
SourceXtractor++, the next generation SExtractor
Loading...
Searching...
No Matches
OnnxCompactModel.h
Go to the documentation of this file.
1/*
2 * OnnxCompactModel.h
3 *
4 * Created on: Sep 3, 2021
5 * Author: mschefer
6 */
7
8#ifndef _SEIMPLEMENTATION_PLUGIN_FLEXIBLEMODELFITTING_ONNXCOMPACTMODEL_H_
9#define _SEIMPLEMENTATION_PLUGIN_FLEXIBLEMODELFITTING_ONNXCOMPACTMODEL_H_
10
11#include <numeric>
12
13#include <ElementsKernel/Logging.h>
14
16
19
20namespace ModelFitting {
21
// Logger used by this model, e.g. for the "no large enough ONNX model" warning.
// NOTE(review): a namespace-scope `static` in a header has internal linkage, so every
// translation unit including this file gets its own `logger`; another header that also
// defines `ModelFitting::logger` would clash in the same TU. Consider a function-local
// static or (if the project is on C++17) an `inline` variable.
22static auto logger = Elements::Logging::getLogger("FlexibleModelFitting");
23
// Compact model whose image is produced by an ONNX network rather than by
// point-wise evaluation (getValue() is an unused stub). It holds one or more
// candidate networks (m_models). getRasterizedImage() uses the first one whose
// output side exceeds the requested stamp size, feeds it the named parameters
// (m_params) and the transformed pixel coordinates, and renormalizes the result
// with the flux parameter (m_flux).
24template <typename ImageType>
25class OnnxCompactModel : public CompactModelBase<ImageType> {
26public:
38
// Defaulted virtual destructor; no manual cleanup is performed here.
39 virtual ~OnnxCompactModel() = default;
40
41 double getValue(double, double) const override {
42 return 0.0; // unused
43 }
44
// Renders this model into a size_x x size_y image by evaluating an ONNX network.
//
// Steps:
//  1. pick the first entry of m_models whose output_shape[2] is strictly greater
//     than max(size_x, size_y);
//  2. feed every entry of m_params as a 1-element float input named after its key,
//     plus "x"/"y" coordinate grids of render_size * render_size floats;
//  3. run the network and copy the top-left size_x x size_y window of its output;
//  4. rescale the image with renormalize(image, m_flux->getValue()).
//
// pixel_scale: not referenced in the lines shown here; presumably consumed where
//   `transform` is computed (that statement is not visible in this extract) - confirm.
// Returns the image created by Traits::factory(size_x, size_y). If no model is large
//   enough, it is returned without being filled (TODO confirm factory zero-fills).
// Throws Elements::Exception if the selected model's output element type is not float.
45 ImageType getRasterizedImage(double pixel_scale, std::size_t size_x, std::size_t size_y) const override {
46 using Traits = ImageTraits<ImageType>;
47 ImageType image = Traits::factory(size_x, size_y);
48
// NOTE(review): std::max of two std::size_t values, narrowed to int here.
49 int largest_size = std::max(size_x, size_y);
50
// `selected_model` is declared just above this loop in the full source (not visible here).
// Takes the FIRST model whose output side is strictly larger than the stamp.
// NOTE(review): this presumably relies on m_models being ordered by ascending size;
// otherwise a larger-than-necessary network may be chosen. The strict `<` also rejects
// a model whose side exactly equals largest_size - confirm that is intended.
// `auto model` copies the shared_ptr (atomic refcount) each iteration; `const auto&` would not.
52 for (auto model : m_models) {
53 auto shape = model->getOutputShape();
54 if (largest_size < shape[2]) {
55 selected_model = model;
56 break;
57 }
58 }
59
// No usable model: log (on every call) and return the unfilled image.
60 if (selected_model == nullptr) {
61 logger.warn() << "No large enough ONNX model could be found, skipping...";
62 return image;
63 }
64
// NOTE(review): input_shape is not used in this function. render_size is taken from
// output dimension 2 only and used for both axes, so the output is assumed square -
// verify against how the models are exported.
65 auto input_shape = selected_model->getInputShape();
66 auto output_shape = selected_model->getOutputShape();
67 int render_size = output_shape[2];
68
// Only float outputs are accepted, matching the std::vector<float> output buffer below.
69 if (selected_model->getOutputType() != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
70 throw Elements::Exception() << "Only ONNX models with float output are supported";
71 }
72
73 // allocate memory
// `input_data_arrays` (input name -> float buffer) is declared here in the full source (not visible in this extract).
75 std::vector<float> output_data(render_size * render_size);
76
// Each parameter becomes a single-element input keyed by its name.
77 for (auto const& it : m_params) {
78 input_data_arrays[it.first] = std::vector<float>( { static_cast<float>(it.second->getValue()) } );
79 }
80
// Coordinate grids are value-initialized to 0; only the top-left size_x x size_y
// cells are filled below, using a row stride of render_size.
81 input_data_arrays["x"] = std::vector<float>(render_size * render_size);
82 input_data_arrays["y"] = std::vector<float>(render_size * render_size);
83
// `transform` is defined here in the full source (not visible in this extract). It is
// indexed as the four coefficients of a 2x2 matrix, presumably from
// getCombinedTransform(pixel_scale) - confirm.
85
// Pixel offsets from the stamp centre (size / 2) are mapped through the 2x2 transform.
// NOTE(review): `y - size_y / 2` (and `x - size_x / 2`) mix int and std::size_t, so they
// are evaluated as unsigned and converted back to int. Casting the sizes to int first
// would keep the arithmetic signed.
86 for (int y=0; y<(int)size_y; ++y) {
87 int dy = y - size_y / 2;
88 for (int x=0; x<(int)size_x; ++x) {
89 int dx = x - size_x / 2;
90
91 float x2 = dx * transform[0] + dy * transform[1];
92 float y2 = dx * transform[2] + dy * transform[3];
93
94 input_data_arrays["x"][x + y * render_size] = x2;
95 input_data_arrays["y"][x + y * render_size] = y2;
96 }
97 }
98
// Run the network. output_data is read back row-major with side render_size, matching
// the input indexing.
99 selected_model->runMultiInput<float, float>(input_data_arrays, output_data);
100
// Copy the top-left size_x x size_y window of the render_size-wide output into the image.
101 for (int y = 0; y < (int) size_y; ++y) {
102 for (int x = 0; x < (int) size_x; ++x) {
103 Traits::at(image, x, y) = output_data[x + y * render_size];
104 }
105 }
106
// Rescale via CompactModelBase::renormalize using the current flux parameter value.
107 renormalize(image, m_flux->getValue());
108 return image;
109 }
110
111private:
112 using CompactModelBase<ImageType>::getMaxRadiusSqr;
114 using CompactModelBase<ImageType>::m_jacobian;
115 using CompactModelBase<ImageType>::samplePixel;
117 using CompactModelBase<ImageType>::renormalize;
118
120
121 // parameters
124};
125
126}
127
128
129
130#endif /* _SEIMPLEMENTATION_PLUGIN_FLEXIBLEMODELFITTING_ONNXCOMPACTMODEL_H_ */
const double pixel_scale
Definition TestImage.cpp:74
static Logging getLogger(const std::string &name="")
float samplePixel(const ModelEvaluator &model_eval, int x, int y, unsigned int subsampling) const
CompactModelBase(std::shared_ptr< BasicParameter > x_scale, std::shared_ptr< BasicParameter > y_scale, std::shared_ptr< BasicParameter > rotation, double width, double height, std::shared_ptr< BasicParameter > x, std::shared_ptr< BasicParameter > y, std::tuple< double, double, double, double > transform)
void renormalize(ImageType &image, double flux) const
double getMaxRadiusSqr(std::size_t size_x, std::size_t size_y, const Mat22 &transform) const
Mat22 getCombinedTransform(double pixel_scale) const
float adaptiveSamplePixel(const ModelEvaluator &model_eval, int x, int y, unsigned int max_subsampling, float threshold=1.1) const
ImageType getRasterizedImage(double pixel_scale, std::size_t size_x, std::size_t size_y) const override
std::map< std::string, std::shared_ptr< BasicParameter > > m_params
OnnxCompactModel(std::vector< std::shared_ptr< SourceXtractor::OnnxModel > > models, std::shared_ptr< BasicParameter > x_scale, std::shared_ptr< BasicParameter > y_scale, std::shared_ptr< BasicParameter > rotation, double width, double height, std::shared_ptr< BasicParameter > x, std::shared_ptr< BasicParameter > y, std::shared_ptr< BasicParameter > flux, std::map< std::string, std::shared_ptr< BasicParameter > > params, std::tuple< double, double, double, double > transform)
std::vector< std::shared_ptr< SourceXtractor::OnnxModel > > m_models
virtual ~OnnxCompactModel()=default
double getValue(double, double) const override
std::shared_ptr< BasicParameter > m_flux
T max(T... args)
static Elements::Logging logger
T transform(T... args)