Merge pull request #45122 from DEKHTIARJonathan:jdekhtiar/tftrt_deterministic_tactic_selection

PiperOrigin-RevId: 344190703
Change-Id: Ib00a3e3802c7c1caa48690096410176349b6a676
This commit is contained in:
TensorFlower Gardener 2020-11-24 22:30:38 -08:00
commit dc5faccddb

View File

@ -1373,6 +1373,38 @@ Status Converter::RenameAndMarkOutputTensors(
return Status::OK();
}
#if IS_TRT_VERSION_GE(7, 1, 3, 0)
// An algorithm selector that always picks one fixed algorithm index in
// selectAlgorithms. This is used to support the implementation of using
// environment variable `TF_TRT_FIXED_ALGORITHM_ID` for debugging TensorRT.
//
// NOTE(review): callers hand a raw pointer to this object to TensorRT;
// presumably the instance has to stay alive for as long as TensorRT may call
// back into it -- confirm at the call site.
class StaticAlgorithmSelector : public nvinfer1::IAlgorithmSelector {
 private:
  // Requested algorithm index; expected to be non-negative.
  int32_t algorithm_id_;

 public:
  StaticAlgorithmSelector(int32_t algorithm_id) : algorithm_id_(algorithm_id) {}

  // Writes the chosen algorithm index into `selection[0]` and returns the
  // number of entries written to `selection`.
  //
  // `algoChoices` holds `nbChoices` candidates, so the valid indices are
  // [0, nbChoices - 1]. The requested ID is clamped to that upper bound.
  int32_t selectAlgorithms(const nvinfer1::IAlgorithmContext& algoContext,
                           const nvinfer1::IAlgorithm* const* algoChoices,
                           int32_t nbChoices, int32_t* selection) override {
    // TensorRT always provides more than zero number of algorithms
    // in selectAlgorithms.
    assert(nbChoices > 0);
    // The assert above is compiled out in release builds; without this guard
    // a non-positive `nbChoices` would produce a negative index below.
    // NOTE(review): returning 0 reports "no selection made" -- confirm that
    // TensorRT then falls back to its own choice.
    if (nbChoices <= 0) {
      return 0;
    }
    // Making sure that the requested TRT algorithm ID doesn't go above the
    // largest valid index, which is nbChoices - 1 (not nbChoices).
    selection[0] = std::min(algorithm_id_, nbChoices - 1);
    return 1;
  }

  // Called by TensorRT to report choices it made.
  void reportAlgorithms(const nvinfer1::IAlgorithmContext* const* algoContexts,
                        const nvinfer1::IAlgorithm* const* algoChoices,
                        int32_t nbAlgorithms) override {}  // do nothing
};
#endif
Status Converter::BuildCudaEngine(
TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, int max_batch_size,
size_t max_workspace_size_bytes, nvinfer1::IGpuAllocator* allocator,
@ -1385,6 +1417,23 @@ Status Converter::BuildCudaEngine(
TrtUniquePtrType<nvinfer1::IBuilderConfig> builder_config(
trt_builder_->createBuilderConfig());
builder_config->setMaxWorkspaceSize(max_workspace_size_bytes);
#if IS_TRT_VERSION_GE(7, 1, 3, 0)
static int32_t trt_algorithm_id = [] {
int64 trt_algorithm_id;
TF_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_TRT_FIXED_ALGORITHM_ID",
/*default_val=*/-1,
&trt_algorithm_id));
return static_cast<int32_t>(trt_algorithm_id);
}();
if (trt_algorithm_id >= 0) {
VLOG(1) << "Forcing TRT algorithm selection to: ID=" << trt_algorithm_id;
StaticAlgorithmSelector trt_algorithm_selector(trt_algorithm_id);
builder_config->setAlgorithmSelector(&trt_algorithm_selector);
}
#endif
if (precision_mode_ == TrtPrecisionMode::FP16) {
builder_config->setFlag(nvinfer1::BuilderFlag::kFP16);
} else if (precision_mode_ == TrtPrecisionMode::INT8) {