#include <ActsPlugins/Gnn/TensorRTEdgeClassifier.hpp>
◆ TensorRTEdgeClassifier()
ActsPlugins::TensorRTEdgeClassifier::TensorRTEdgeClassifier(const Config &cfg, std::unique_ptr<const Acts::Logger> logger)
◆ ~TensorRTEdgeClassifier()
ActsPlugins::TensorRTEdgeClassifier::~TensorRTEdgeClassifier()
◆ config()
Config ActsPlugins::TensorRTEdgeClassifier::config() const
◆ operator()()
Perform edge classification.
Parameters:
  tensors      Input pipeline tensors
  execContext  Device and stream information
Returns:
  (node_features, edge_features, edge_index, edge_scores)
Implements ActsPlugins::EdgeClassificationBase.