From 7ac340afaa3f18f817af6957b72953fcda3791fa Mon Sep 17 00:00:00 2001
From: Pooya Davoodi
Date: Mon, 17 Jun 2019 15:47:56 -0700
Subject: [PATCH] Support INT32 for calibration plus tests

---
 .../tf2tensorrt/kernels/trt_engine_op.cc      |  1 +
 .../compiler/tensorrt/test/int32_test.py      | 20 +++++++++++++++++++
 .../test/tf_trt_integration_test_base.py      |  3 +--
 3 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
index 006793b8fcd..25c8147a61f 100644
--- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
@@ -160,6 +160,7 @@ void* GetTensorAddress(const Tensor* tensor_ptr) {
     TYPECASE(DT_FLOAT, tensor_ptr, dest_ptr);
     TYPECASE(DT_HALF, tensor_ptr, dest_ptr);
     TYPECASE(DT_INT8, tensor_ptr, dest_ptr);
+    TYPECASE(DT_INT32, tensor_ptr, dest_ptr);
     default: {
       LOG(ERROR) << "Unsupported Data type " << DataTypeString(tensor_type);
       return nullptr;
diff --git a/tensorflow/python/compiler/tensorrt/test/int32_test.py b/tensorflow/python/compiler/tensorrt/test/int32_test.py
index 5653ff1f9be..992b74243f7 100644
--- a/tensorflow/python/compiler/tensorrt/test/int32_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/int32_test.py
@@ -61,6 +61,26 @@ class ExcludeUnsupportedInt32Test(trt_test.TfTrtIntegrationTestBase):
     """Return the expected engines to build."""
     return []
 
+class CalibrationInt32Support(trt_test.TfTrtIntegrationTestBase):
+  """Test execution of calibration with int32 input."""
+
+  def GraphFn(self, inp):
+    # Any op that is converted to TRT and accepts int32 inputs works here.
+    inp_transposed = array_ops.transpose(inp, [0, 3, 2, 1],
+                                         name='transpose_0')
+    return array_ops.identity(inp_transposed, name='output_0')
+
+  def GetParams(self):
+    return self.BuildParams(self.GraphFn, dtypes.int32, [[3, 4, 5, 6]],
+                            [[3, 6, 5, 4]])
+
+  def ShouldRunTest(self, run_params):
+    # The test passes with all configurations, but only run INT8 with
+    # use_calibration=True, since calibration is the purpose of this test.
+    return trt_test.IsQuantizationWithCalibration(run_params)
+
+  def ExpectedEnginesToBuild(self, run_params):
+    return ['TRTEngineOp_0']
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py
index 1cc381a3449..438ff41a973 100644
--- a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py
+++ b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py
@@ -232,8 +232,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
 
   def ShouldRunTest(self, run_params):
     """Whether to run the test."""
-    # This setting combination requires quantization nodes to be present in
-    # order to build the engine.
+    # Require use_calibration=True when running with INT8 precision.
     return (run_params.use_calibration or
             not IsQuantizationMode(run_params.precision_mode))
 
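--
Note on the helper the new test relies on (not part of the patch itself):
trt_test.IsQuantizationWithCalibration() is expected to return True only for
INT8 runs that enable calibration, which is what restricts
CalibrationInt32Support to the configuration it is meant to exercise. A
minimal sketch of such a helper, assuming run_params carries precision_mode
and use_calibration fields and that the IsQuantizationMode() predicate seen in
tf_trt_integration_test_base.py is available; the actual definition in the
TensorFlow tree may differ:

  def IsQuantizationWithCalibration(run_params):
    # True only when INT8 quantization is requested and calibration is on.
    return (IsQuantizationMode(run_params.precision_mode) and
            run_params.use_calibration)

Under that assumption, ShouldRunTest above skips FP32/FP16 runs and INT8 runs
without calibration, so only the calibration path executes with int32 inputs.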