diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc index 006793b8fcd..25c8147a61f 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc @@ -160,6 +160,7 @@ void* GetTensorAddress(const Tensor* tensor_ptr) { TYPECASE(DT_FLOAT, tensor_ptr, dest_ptr); TYPECASE(DT_HALF, tensor_ptr, dest_ptr); TYPECASE(DT_INT8, tensor_ptr, dest_ptr); + TYPECASE(DT_INT32, tensor_ptr, dest_ptr); default: { LOG(ERROR) << "Unsupported Data type " << DataTypeString(tensor_type); return nullptr; diff --git a/tensorflow/python/compiler/tensorrt/test/int32_test.py b/tensorflow/python/compiler/tensorrt/test/int32_test.py index 5653ff1f9be..63a72288d36 100644 --- a/tensorflow/python/compiler/tensorrt/test/int32_test.py +++ b/tensorflow/python/compiler/tensorrt/test/int32_test.py @@ -62,5 +62,27 @@ class ExcludeUnsupportedInt32Test(trt_test.TfTrtIntegrationTestBase): return [] +class CalibrationInt32Support(trt_test.TfTrtIntegrationTestBase): + """Test execution of calibration with int32 input.""" + + def GraphFn(self, inp): + # Can use any op that is converted to TRT with int32 inputs + inp_transposed = array_ops.transpose(inp, [0, 3, 2, 1], name='transpose_0') + return array_ops.identity(inp_transposed, name='output_0') + + def GetParams(self): + return self.BuildParams(self.GraphFn, dtypes.int32, [[3, 4, 5, 6]], + [[3, 6, 5, 4]]) + + def ShouldRunTest(self, run_params): + # The test passes with all configurations, but we only + # execute INT8 with use_calibration=True because + # exercising calibration is the purpose of this test.
+ return trt_test.IsQuantizationWithCalibration(run_params) + + def ExpectedEnginesToBuild(self, run_params): + return ['TRTEngineOp_0'] + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py index ddcf1d5faf2..8e8844a2066 100644 --- a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py @@ -232,8 +232,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): def ShouldRunTest(self, run_params): """Whether to run the test.""" - # This setting combination requires quantization nodes to be present in - # order to build the engine. + # For INT8 precision, ensure use_calibration=True. return (run_params.use_calibration or not IsQuantizationMode(run_params.precision_mode))