ACTS
Experiment-independent tracking
ActsPlugins::MLTrackClassifier Class Reference

#include <ActsPlugins/Onnx/MLTrackClassifier.hpp>

Inheritance: ActsPlugins::MLTrackClassifier derives from ActsPlugins::OnnxRuntimeBase.

Public Types

enum class  TrackLabels { eGood , eDuplicate , eFake }
 The labels for track quality.

Public Member Functions

bool isDuplicate (std::vector< float > &inputFeatures, double decisionThreshProb) const
 Check if the predicted track label is 'duplicate'.
TrackLabels predictTrackLabel (std::vector< float > &inputFeatures, double decisionThreshProb) const
 Predict the track label.
Public Member Functions inherited from ActsPlugins::OnnxRuntimeBase
 OnnxRuntimeBase ()=default
 Default constructor.
 OnnxRuntimeBase (Ort::Env &env, const char *modelPath)
 Parametrized constructor.
 ~OnnxRuntimeBase ()=default
 Default destructor.
std::vector< std::vector< float > > runONNXInference (NetworkBatchInput &inputTensorValues) const
 Run the ONNX inference function for a batch of inputs.
std::vector< float > runONNXInference (std::vector< float > &inputTensorValues) const
 Run the ONNX inference function.
std::vector< std::vector< std::vector< float > > > runONNXInferenceMultiOutput (NetworkBatchInput &inputTensorValues) const
 Run the multi-output ONNX inference function for a batch of inputs.

Member Enumeration Documentation

◆ TrackLabels

The labels for track quality.

Enumerator
eGood 
eDuplicate 
eFake 
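When labels are logged or used as histogram bin names, a small mapping from enumerator to string can help. The snippet below is a minimal sketch: `toString` is a hypothetical helper (not part of ACTS), and the enum is re-declared locally with the three documented enumerators so the snippet compiles without ACTS. In real code, use ActsPlugins::MLTrackClassifier::TrackLabels instead.

```cpp
#include <string_view>

// Local stand-in mirroring the documented enumerators so this compiles
// without ACTS; in real code use ActsPlugins::MLTrackClassifier::TrackLabels.
enum class TrackLabels { eGood, eDuplicate, eFake };

// Hypothetical helper: human-readable name for logging or histogram labels.
constexpr std::string_view toString(TrackLabels label) {
  switch (label) {
    case TrackLabels::eGood:
      return "good";
    case TrackLabels::eDuplicate:
      return "duplicate";
    case TrackLabels::eFake:
      return "fake";
  }
  return "unknown";  // not reached for the three enumerators above
}
```

Because TrackLabels is an enum class, the enumerators must be qualified (TrackLabels::eGood) and do not convert implicitly to integers. That is why an explicit mapping like this is needed.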

Member Function Documentation

◆ isDuplicate()

bool ActsPlugins::MLTrackClassifier::isDuplicate ( std::vector< float > & inputFeatures,
double decisionThreshProb ) const

Check if the predicted track label is 'duplicate'.

Parameters
inputFeatures: The vector of input features for the trajectory to be classified
decisionThreshProb: The probability threshold used to predict the track label
Returns
true if the predicted track label is 'duplicate', false otherwise
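A typical caller applies isDuplicate to each track and keeps the ones that are not flagged. The following is a minimal caller-side sketch. `selectNonDuplicates` and `MockClassifier` are hypothetical names introduced for illustration. The mock stands in for a loaded ONNX model so the example runs without ONNX Runtime, and its rule (first feature compared against the threshold) is an assumption, not the library's behavior. The feature layout expected by a real model is not specified on this page.

```cpp
#include <cstddef>
#include <vector>

// Sketch: return the indices of tracks the classifier does NOT flag as
// duplicates. `Classifier` is any type exposing
//   bool isDuplicate(std::vector<float>&, double) const
// such as ActsPlugins::MLTrackClassifier. Each inner vector holds the
// input features of one track.
template <typename Classifier>
std::vector<std::size_t> selectNonDuplicates(
    const Classifier& classifier,
    std::vector<std::vector<float>>& trackFeatures,
    double decisionThreshProb) {
  std::vector<std::size_t> kept;
  for (std::size_t i = 0; i < trackFeatures.size(); ++i) {
    // isDuplicate takes a non-const reference, so the features must be mutable.
    if (!classifier.isDuplicate(trackFeatures[i], decisionThreshProb)) {
      kept.push_back(i);
    }
  }
  return kept;
}

// Hypothetical mock used only to exercise the sketch without ONNX Runtime:
// it treats the first feature as a precomputed duplicate probability.
struct MockClassifier {
  bool isDuplicate(std::vector<float>& features, double thresh) const {
    return !features.empty() && features[0] > thresh;
  }
};
```

Returning indices rather than copying feature vectors keeps the sketch independent of how tracks are stored, so the indices can be used to select entries from whatever track container the features were derived from.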

◆ predictTrackLabel()

TrackLabels ActsPlugins::MLTrackClassifier::predictTrackLabel ( std::vector< float > & inputFeatures,
double decisionThreshProb ) const

Predict the track label.

Parameters
inputFeatures: The vector of input features for the trajectory to be classified
decisionThreshProb: The probability threshold used to predict the track label
Returns
The predicted track label of the trajectory
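How decisionThreshProb turns a network output into a label is not spelled out on this page. The sketch below shows one plausible decision rule under stated assumptions. It assumes the model has a single output, for example the first element of the vector returned by runONNXInference, and that this output is the probability that the track is a duplicate. Scores strictly above the threshold become eDuplicate, and everything else becomes eGood. `labelFromDuplicateProbability` is a hypothetical name, and the enum is a local stand-in for ActsPlugins::MLTrackClassifier::TrackLabels.

```cpp
// Local stand-in for ActsPlugins::MLTrackClassifier::TrackLabels.
enum class TrackLabels { eGood, eDuplicate, eFake };

// Hypothetical decision rule. ASSUMPTION: duplicateProb is the network's
// single output, interpreted as the probability that the track is a
// duplicate. Scores strictly above the threshold are labeled eDuplicate;
// all others are eGood. Under this assumption eFake is never returned.
constexpr TrackLabels labelFromDuplicateProbability(float duplicateProb,
                                                    double decisionThreshProb) {
  return duplicateProb > decisionThreshProb ? TrackLabels::eDuplicate
                                            : TrackLabels::eGood;
}
```

Under this rule, raising decisionThreshProb makes the classifier more conservative about flagging duplicates (fewer tracks removed), and lowering it flags more tracks. Whether the actual implementation can also return eFake depends on the model it loads.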