Merge pull request #45122 from DEKHTIARJonathan:jdekhtiar/tftrt_deterministic_tactic_selection
PiperOrigin-RevId: 344190703 Change-Id: Ib00a3e3802c7c1caa48690096410176349b6a676
This commit is contained in:
commit
dc5faccddb
@ -1373,6 +1373,38 @@ Status Converter::RenameAndMarkOutputTensors(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#if IS_TRT_VERSION_GE(7, 1, 3, 0)
|
||||
// An algorithm selector that always returns a specific ID for selectAlgorithms.
|
||||
// This is used to support the implementation of using environment variable
|
||||
// `TF_TRT_FIXED_ALGORITHM_ID` for debugging TensorRT.
|
||||
class StaticAlgorithmSelector : public nvinfer1::IAlgorithmSelector {
|
||||
private:
|
||||
int32_t algorithm_id_;
|
||||
|
||||
public:
|
||||
StaticAlgorithmSelector(int32_t algorithm_id) : algorithm_id_(algorithm_id) {}
|
||||
|
||||
// Returns value in [0, nbChoices] for a valid algorithm.
|
||||
int32_t selectAlgorithms(const nvinfer1::IAlgorithmContext& algoContext,
|
||||
const nvinfer1::IAlgorithm* const* algoChoices,
|
||||
int32_t nbChoices, int32_t* selection) override {
|
||||
// TensorRT always provides more than zero number of algorithms
|
||||
// in selectAlgorithms.
|
||||
assert(nbChoices > 0);
|
||||
|
||||
// making sure that the requested TRT algorithm ID doesn't go above the
|
||||
// max value accepted.
|
||||
selection[0] = std::min(algorithm_id_, nbChoices);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Called by TensorRT to report choices it made.
|
||||
void reportAlgorithms(const nvinfer1::IAlgorithmContext* const* algoContexts,
|
||||
const nvinfer1::IAlgorithm* const* algoChoices,
|
||||
int32_t nbAlgorithms) override {} // do nothing
|
||||
};
|
||||
#endif
|
||||
|
||||
Status Converter::BuildCudaEngine(
|
||||
TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, int max_batch_size,
|
||||
size_t max_workspace_size_bytes, nvinfer1::IGpuAllocator* allocator,
|
||||
@ -1385,6 +1417,23 @@ Status Converter::BuildCudaEngine(
|
||||
TrtUniquePtrType<nvinfer1::IBuilderConfig> builder_config(
|
||||
trt_builder_->createBuilderConfig());
|
||||
builder_config->setMaxWorkspaceSize(max_workspace_size_bytes);
|
||||
|
||||
#if IS_TRT_VERSION_GE(7, 1, 3, 0)
|
||||
static int32_t trt_algorithm_id = [] {
|
||||
int64 trt_algorithm_id;
|
||||
TF_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_TRT_FIXED_ALGORITHM_ID",
|
||||
/*default_val=*/-1,
|
||||
&trt_algorithm_id));
|
||||
return static_cast<int32_t>(trt_algorithm_id);
|
||||
}();
|
||||
|
||||
if (trt_algorithm_id >= 0) {
|
||||
VLOG(1) << "Forcing TRT algorithm selection to: ID=" << trt_algorithm_id;
|
||||
StaticAlgorithmSelector trt_algorithm_selector(trt_algorithm_id);
|
||||
builder_config->setAlgorithmSelector(&trt_algorithm_selector);
|
||||
}
|
||||
#endif
|
||||
|
||||
if (precision_mode_ == TrtPrecisionMode::FP16) {
|
||||
builder_config->setFlag(nvinfer1::BuilderFlag::kFP16);
|
||||
} else if (precision_mode_ == TrtPrecisionMode::INT8) {
|
||||
|
Loading…
Reference in New Issue
Block a user