From dc72ab94e0b9f34d08e9c60c15e0114b3a774b07 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Tue, 12 Mar 2019 17:20:08 -0700 Subject: [PATCH] Add helper macro to check TRT version. Use to disable call to setType for 5.1.3 since the bug is fixed --- .../compiler/tf2tensorrt/convert/convert_nodes.cc | 14 +++++++------- .../compiler/tf2tensorrt/convert/convert_nodes.h | 6 ++++++ .../tf2tensorrt/convert/convert_nodes_test.cc | 12 ++++++------ 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 0be54a65b7c..45c58d2259e 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -367,11 +367,13 @@ nvinfer1::ITensor* Converter::CreateConstantLayer( if (!layer) return nullptr; const nvinfer1::DataType trt_dtype = trt_weights.type; nvinfer1::ITensor* trt_tensor = layer->getOutput(0); +#if !IS_TRT_VERSION_GE(5, 1, 3) // TODO(laigd): there is a bug in TensorRT 5.0 library that, if we don't set // the data type below, it will always be kFLOAT regardless what the data type // of the weights is. Once NVIDIA fixes this bug, we should remove the data // type setting logic below and test should still pass. trt_tensor->setType(trt_dtype); +#endif return trt_tensor; } @@ -574,13 +576,13 @@ class TRT_TensorOrWeights::SimpleITensor : public nvinfer1::ITensor { void setLocation(nvinfer1::TensorLocation location) override {} -#if NV_TENSORRT_MAJOR >= 5 +#if IS_TRT_VERSION_GE(5, 0, 0) bool setDynamicRange(float min, float max) override { return true; } float getDynamicRange() const override { return 0; } #endif -#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) +#if IS_TRT_VERSION_GE(5, 1, 0) bool dynamicRangeIsSet() const override { return true; } void resetDynamicRange() override {} @@ -1281,7 +1283,7 @@ void Converter::MaybeApplyQuantizationRanges() { // Infer ranges across marked ops. PropagateQuantizationRanges(); // Apply ranges. -#if NV_TENSORRT_MAJOR >= 5 +#if IS_TRT_VERSION_GE(5, 0, 0) for (auto pair : quantization_ranges_) { nvinfer1::ITensor* tensor = pair.first; const float range = pair.second; @@ -2297,9 +2299,7 @@ Status ConvertStridedSliceHelper(OpConverterParams* params, } // TRT 5.1 adds a slice layer. For older versions, we attempt to use the // padding layer with negative padding. -#if (NV_TENSORRT_MAJOR > 5 || \ - (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1)) && \ - 0 +#if IS_TRT_VERSION_GE(5, 1, 0) && 0 // TODO(laigd): TRT 5.1 RC has a bug when ISliceLayer is used along with // IConcatenationLayer, so disable ISliceLayer for now until it's fixed. // Use ISliceLayer. @@ -3220,7 +3220,7 @@ UnaryOperationMap() { {"Sqrt", nvinfer1::UnaryOperation::kSQRT}, {"Abs", nvinfer1::UnaryOperation::kABS}, {"Reciprocal", nvinfer1::UnaryOperation::kRECIP}, -#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) +#if IS_TRT_VERSION_GE(5, 1, 0) {"Sin", nvinfer1::UnaryOperation::kSIN}, {"Cos", nvinfer1::UnaryOperation::kCOS}, {"Tan", nvinfer1::UnaryOperation::kTAN}, diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index 05630bfb55c..068482a3f64 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -43,6 +43,12 @@ extern const char* const kOutputPHName; namespace convert { +#define IS_TRT_VERSION_GE(major, minor, patch) \ + ((NV_TENSORRT_MAJOR > major) || \ + (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR > minor) || \ + (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \ + NV_TENSORRT_PATCH >= patch)) + struct EngineConnection { // Constructs a non-control edge. EngineConnection(const string& outside, int out_id, int out_port, diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 1f5d2a69874..bd656b0e836 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -233,7 +233,7 @@ class FakeITensor : public nvinfer1::ITensor { location_ = location; } -#if NV_TENSORRT_MAJOR >= 5 +#if IS_TRT_VERSION_GE(5, 0, 0) bool setDynamicRange(float min, float max) override { dynamic_range_ = std::max(std::abs(min), std::abs(max)); return true; @@ -242,7 +242,7 @@ class FakeITensor : public nvinfer1::ITensor { float getDynamicRange() const override { return dynamic_range_; } #endif -#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) +#if IS_TRT_VERSION_GE(5, 1, 0) bool dynamicRangeIsSet() const override { return true; } void resetDynamicRange() override {} @@ -845,7 +845,7 @@ TEST_F(ConverterTest, MaybeApplyQuantizationRanges) { // Input range should be inferred along the chain and applied to tensors. int8_converter.MaybeApplyQuantizationRanges(); -#if NV_TENSORRT_MAJOR >= 5 +#if IS_TRT_VERSION_GE(5, 0, 0) EXPECT_EQ(input.getDynamicRange(), 5.0f); EXPECT_EQ(infer_1.getDynamicRange(), 5.0f); EXPECT_EQ(infer_2.getDynamicRange(), 5.0f); @@ -2672,7 +2672,7 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { RunValidationAndConversion(node_def); } // TRT 5.1+ supports strides -#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) +#if IS_TRT_VERSION_GE(5, 1, 0) { // Negative strides, should fail. Reset(); @@ -2735,7 +2735,7 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { // Same input is used for all tests. const std::vector ok_input = {1, 2, 3, 4, 5, 6}; -#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) +#if IS_TRT_VERSION_GE(5, 1, 0) const int kStridedSliceOKCases = 23; #else const int kStridedSliceOKCases = 19; @@ -2862,7 +2862,7 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 2, 3}, /*expected_output=*/{1, 2, 3, 4, 5, 6}}, -#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) +#if IS_TRT_VERSION_GE(5, 1, 0) // Strides TestParams{/*input_dims=*/{6}, /*begin=*/{0, 0}, /*end=*/{0, 5}, /*strides=*/{1, 2},