Support INT32 for calibration, plus tests

This commit is contained in:
Pooya Davoodi 2019-06-17 15:47:56 -07:00
parent 92144e5bd6
commit 7ac340afaa
3 changed files with 22 additions and 2 deletions

View File

@ -160,6 +160,7 @@ void* GetTensorAddress(const Tensor* tensor_ptr) {
TYPECASE(DT_FLOAT, tensor_ptr, dest_ptr);
TYPECASE(DT_HALF, tensor_ptr, dest_ptr);
TYPECASE(DT_INT8, tensor_ptr, dest_ptr);
TYPECASE(DT_INT32, tensor_ptr, dest_ptr);
default: {
LOG(ERROR) << "Unsupported Data type " << DataTypeString(tensor_type);
return nullptr;

View File

@ -61,6 +61,26 @@ class ExcludeUnsupportedInt32Test(trt_test.TfTrtIntegrationTestBase):
"""Return the expected engines to build."""
return []
class CalibrationInt32Support(trt_test.TfTrtIntegrationTestBase):
  """Exercises TF-TRT INT8 calibration when the engine input dtype is int32."""

  def GraphFn(self, inp):
    """Builds the test graph: a transpose followed by an identity output.

    Any op that TF-TRT converts with int32 inputs would do here; transpose
    is simply a convenient choice.
    """
    transposed = array_ops.transpose(inp, [0, 3, 2, 1], name='transpose_0')
    return array_ops.identity(transposed, name='output_0')

  def GetParams(self):
    """Returns test parameters: one int32 input [3, 4, 5, 6] -> [3, 6, 5, 4]."""
    return self.BuildParams(self.GraphFn, dtypes.int32,
                            [[3, 4, 5, 6]], [[3, 6, 5, 4]])

  def ShouldRunTest(self, run_params):
    """Restricts the test to the INT8 + calibration configuration."""
    # The graph also converts under other precision modes, but calibration
    # with an int32 input is the behavior this test exists to cover, so
    # only the use_calibration=True INT8 configuration is executed.
    return trt_test.IsQuantizationWithCalibration(run_params)

  def ExpectedEnginesToBuild(self, run_params):
    """Returns the names of the engines expected to be built."""
    return ['TRTEngineOp_0']
# Standard TensorFlow test entry point: run this module's tests directly.
if __name__ == '__main__':
  test.main()

View File

@ -232,8 +232,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
def ShouldRunTest(self, run_params):
  """Whether to run the test."""
  # A quantization (INT8) engine can only be built when calibration is
  # enabled, so quantization-mode runs without calibration are skipped.
  if run_params.use_calibration:
    return run_params.use_calibration
  return not IsQuantizationMode(run_params.precision_mode)