FCCAnalyses
Loading...
Searching...
No Matches
WeaverInterface.h
Go to the documentation of this file.
1#ifndef ONNXRuntime_WeaverInterface_h
2#define ONNXRuntime_WeaverInterface_h
3
5#include "ROOT/RVec.hxx"
6
7namespace rv = ROOT::VecOps;
8
10public:
11 using ConstituentVars = rv::RVec<float>;
12
15 explicit WeaverInterface(const std::string& onnx_filename = "",
16 const std::string& json_filename = "",
17 const rv::RVec<std::string>& vars = {});
18
20 rv::RVec<float> run(const rv::RVec<ConstituentVars>&);
21
22private:
24 struct VarInfo {
26 VarInfo(float imedian,
27 float inorm_factor,
28 float ireplace_inf_value,
29 float ilower_bound,
30 float iupper_bound,
31 float ipad)
32 : center(imedian),
33 norm_factor(inorm_factor),
34 replace_inf_value(ireplace_inf_value),
35 lower_bound(ilower_bound),
36 upper_bound(iupper_bound),
37 pad(ipad) {}
38
39 float center{0.};
40 float norm_factor{1.};
42 float lower_bound{-5.};
43 float upper_bound{5.};
44 float pad{0.};
45 };
46 std::string name;
47 size_t min_length{0}, max_length{0};
48 std::vector<std::string> var_names;
49 std::unordered_map<std::string, VarInfo> var_info_map;
50 VarInfo info(const std::string& name) const { return var_info_map.at(name); }
51 void dumpVars() const;
52 };
53 std::vector<float> center_norm_pad(const rv::RVec<float>& input,
54 float center,
55 float scale,
56 size_t min_length,
57 size_t max_length,
58 float pad_value = 0,
59 float replace_inf_value = 0,
60 float min = 0,
61 float max = -1);
62 size_t variablePos(const std::string&) const;
63
64 std::unique_ptr<ONNXRuntime> onnx_;
65 std::vector<std::string> variables_names_;
67 std::vector<unsigned int> input_sizes_;
68 std::unordered_map<std::string, PreprocessParams> prep_info_map_;
70};
71
72#endif
std::vector< std::vector< T > > Tensor
Definition ONNXRuntime.h:17
Definition WeaverInterface.h:9
ONNXRuntime::Tensor< long > input_shapes_
Definition WeaverInterface.h:66
size_t variablePos(const std::string &) const
Definition WeaverInterface.cc:86
std::vector< unsigned int > input_sizes_
Definition WeaverInterface.h:67
std::vector< float > center_norm_pad(const rv::RVec< float > &input, float center, float scale, size_t min_length, size_t max_length, float pad_value=0, float replace_inf_value=0, float min=0, float max=-1)
Definition WeaverInterface.cc:59
std::unordered_map< std::string, PreprocessParams > prep_info_map_
Definition WeaverInterface.h:68
std::vector< std::string > variables_names_
Definition WeaverInterface.h:65
std::unique_ptr< ONNXRuntime > onnx_
Definition WeaverInterface.h:64
ONNXRuntime::Tensor< float > data_
Definition WeaverInterface.h:69
rv::RVec< float > ConstituentVars
Definition WeaverInterface.h:11
rv::RVec< float > run(const rv::RVec< ConstituentVars > &)
Run inference given a list of jet-constituent variables.
Definition WeaverInterface.cc:94
WeaverInterface(const std::string &onnx_filename="", const std::string &json_filename="", const rv::RVec< std::string > &vars={})
Initialise an inference model from Weaver output ONNX/JSON files and a list of variables to be provid...
Definition WeaverInterface.cc:7
Definition WeaverInterface.h:24
VarInfo(float imedian, float inorm_factor, float ireplace_inf_value, float ilower_bound, float iupper_bound, float ipad)
Definition WeaverInterface.h:26
VarInfo()
Definition WeaverInterface.h:25
float center
Definition WeaverInterface.h:39
float replace_inf_value
Definition WeaverInterface.h:41
float norm_factor
Definition WeaverInterface.h:40
float pad
Definition WeaverInterface.h:44
float upper_bound
Definition WeaverInterface.h:43
float lower_bound
Definition WeaverInterface.h:42
Definition WeaverInterface.h:23
std::vector< std::string > var_names
Definition WeaverInterface.h:48
std::unordered_map< std::string, VarInfo > var_info_map
Definition WeaverInterface.h:49
VarInfo info(const std::string &name) const
Definition WeaverInterface.h:50
size_t min_length
Definition WeaverInterface.h:47
void dumpVars() const
Definition WeaverInterface.cc:133
size_t max_length
Definition WeaverInterface.h:47
std::string name
Definition WeaverInterface.h:46