Merge pull request #45122 from DEKHTIARJonathan:jdekhtiar/tftrt_deterministic_tactic_selection
PiperOrigin-RevId: 344190703 Change-Id: Ib00a3e3802c7c1caa48690096410176349b6a676
This commit is contained in:
commit
dc5faccddb
@ -1373,6 +1373,38 @@ Status Converter::RenameAndMarkOutputTensors(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if IS_TRT_VERSION_GE(7, 1, 3, 0)
|
||||||
|
// An algorithm selector that always returns a specific ID for selectAlgorithms.
|
||||||
|
// This is used to support the implementation of using environment variable
|
||||||
|
// `TF_TRT_FIXED_ALGORITHM_ID` for debugging TensorRT.
|
||||||
|
class StaticAlgorithmSelector : public nvinfer1::IAlgorithmSelector {
|
||||||
|
private:
|
||||||
|
int32_t algorithm_id_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
StaticAlgorithmSelector(int32_t algorithm_id) : algorithm_id_(algorithm_id) {}
|
||||||
|
|
||||||
|
// Returns value in [0, nbChoices] for a valid algorithm.
|
||||||
|
int32_t selectAlgorithms(const nvinfer1::IAlgorithmContext& algoContext,
|
||||||
|
const nvinfer1::IAlgorithm* const* algoChoices,
|
||||||
|
int32_t nbChoices, int32_t* selection) override {
|
||||||
|
// TensorRT always provides more than zero number of algorithms
|
||||||
|
// in selectAlgorithms.
|
||||||
|
assert(nbChoices > 0);
|
||||||
|
|
||||||
|
// making sure that the requested TRT algorithm ID doesn't go above the
|
||||||
|
// max value accepted.
|
||||||
|
selection[0] = std::min(algorithm_id_, nbChoices);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Called by TensorRT to report choices it made.
|
||||||
|
void reportAlgorithms(const nvinfer1::IAlgorithmContext* const* algoContexts,
|
||||||
|
const nvinfer1::IAlgorithm* const* algoChoices,
|
||||||
|
int32_t nbAlgorithms) override {} // do nothing
|
||||||
|
};
|
||||||
|
#endif
|
||||||
|
|
||||||
Status Converter::BuildCudaEngine(
|
Status Converter::BuildCudaEngine(
|
||||||
TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, int max_batch_size,
|
TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, int max_batch_size,
|
||||||
size_t max_workspace_size_bytes, nvinfer1::IGpuAllocator* allocator,
|
size_t max_workspace_size_bytes, nvinfer1::IGpuAllocator* allocator,
|
||||||
@ -1385,6 +1417,23 @@ Status Converter::BuildCudaEngine(
|
|||||||
TrtUniquePtrType<nvinfer1::IBuilderConfig> builder_config(
|
TrtUniquePtrType<nvinfer1::IBuilderConfig> builder_config(
|
||||||
trt_builder_->createBuilderConfig());
|
trt_builder_->createBuilderConfig());
|
||||||
builder_config->setMaxWorkspaceSize(max_workspace_size_bytes);
|
builder_config->setMaxWorkspaceSize(max_workspace_size_bytes);
|
||||||
|
|
||||||
|
#if IS_TRT_VERSION_GE(7, 1, 3, 0)
|
||||||
|
static int32_t trt_algorithm_id = [] {
|
||||||
|
int64 trt_algorithm_id;
|
||||||
|
TF_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_TRT_FIXED_ALGORITHM_ID",
|
||||||
|
/*default_val=*/-1,
|
||||||
|
&trt_algorithm_id));
|
||||||
|
return static_cast<int32_t>(trt_algorithm_id);
|
||||||
|
}();
|
||||||
|
|
||||||
|
if (trt_algorithm_id >= 0) {
|
||||||
|
VLOG(1) << "Forcing TRT algorithm selection to: ID=" << trt_algorithm_id;
|
||||||
|
StaticAlgorithmSelector trt_algorithm_selector(trt_algorithm_id);
|
||||||
|
builder_config->setAlgorithmSelector(&trt_algorithm_selector);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
if (precision_mode_ == TrtPrecisionMode::FP16) {
|
if (precision_mode_ == TrtPrecisionMode::FP16) {
|
||||||
builder_config->setFlag(nvinfer1::BuilderFlag::kFP16);
|
builder_config->setFlag(nvinfer1::BuilderFlag::kFP16);
|
||||||
} else if (precision_mode_ == TrtPrecisionMode::INT8) {
|
} else if (precision_mode_ == TrtPrecisionMode::INT8) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user