Merge pull request #29900 from pooyadavoodi:support_calibration_int32
PiperOrigin-RevId: 254047339
This commit is contained in:
commit
ee8de6c410
@ -160,6 +160,7 @@ void* GetTensorAddress(const Tensor* tensor_ptr) {
|
|||||||
TYPECASE(DT_FLOAT, tensor_ptr, dest_ptr);
|
TYPECASE(DT_FLOAT, tensor_ptr, dest_ptr);
|
||||||
TYPECASE(DT_HALF, tensor_ptr, dest_ptr);
|
TYPECASE(DT_HALF, tensor_ptr, dest_ptr);
|
||||||
TYPECASE(DT_INT8, tensor_ptr, dest_ptr);
|
TYPECASE(DT_INT8, tensor_ptr, dest_ptr);
|
||||||
|
TYPECASE(DT_INT32, tensor_ptr, dest_ptr);
|
||||||
default: {
|
default: {
|
||||||
LOG(ERROR) << "Unsupported Data type " << DataTypeString(tensor_type);
|
LOG(ERROR) << "Unsupported Data type " << DataTypeString(tensor_type);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -62,5 +62,27 @@ class ExcludeUnsupportedInt32Test(trt_test.TfTrtIntegrationTestBase):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class CalibrationInt32Support(trt_test.TfTrtIntegrationTestBase):
|
||||||
|
"""Test execution of calibration with int32 input"""
|
||||||
|
|
||||||
|
def GraphFn(self, inp):
|
||||||
|
# Can use any op that is converted to TRT with int32 inputs
|
||||||
|
inp_transposed = array_ops.transpose(inp, [0, 3, 2, 1], name='transpose_0')
|
||||||
|
return array_ops.identity(inp_transposed, name='output_0')
|
||||||
|
|
||||||
|
def GetParams(self):
|
||||||
|
return self.BuildParams(self.GraphFn, dtypes.int32, [[3, 4, 5, 6]],
|
||||||
|
[[3, 6, 5, 4]])
|
||||||
|
|
||||||
|
def ShouldRunTest(self, run_params):
|
||||||
|
# Although test passes with all configurations but only
|
||||||
|
# execute INT8 with use_calibration=True because
|
||||||
|
# that is the purpose of the test.
|
||||||
|
return trt_test.IsQuantizationWithCalibration(run_params)
|
||||||
|
|
||||||
|
def ExpectedEnginesToBuild(self, run_params):
|
||||||
|
return ['TRTEngineOp_0']
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -232,8 +232,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
def ShouldRunTest(self, run_params):
|
def ShouldRunTest(self, run_params):
|
||||||
"""Whether to run the test."""
|
"""Whether to run the test."""
|
||||||
# This setting combination requires quantization nodes to be present in
|
# Ensure use_calibration=True in case of INT8 precision
|
||||||
# order to build the engine.
|
|
||||||
return (run_params.use_calibration or
|
return (run_params.use_calibration or
|
||||||
not IsQuantizationMode(run_params.precision_mode))
|
not IsQuantizationMode(run_params.precision_mode))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user