Add calibration to TrtGraphConverterV2.convert
Add TrtGraphConverterV2.build Also do not return function from convert. Convert dict_values to list for python3 Fix tests as well Fix pylint errors
This commit is contained in:
parent
163f9df4c7
commit
ec5c02724a
@ -878,6 +878,7 @@ class TrtGraphConverterV2(object):
|
||||
if (self._need_calibration and not conversion_params.is_dynamic_op):
|
||||
raise ValueError("INT8 precision mode with calibration is not supported "
|
||||
"with static TensorRT ops. Set is_dynamic_op to True.")
|
||||
self._calibration_data_collected = False
|
||||
|
||||
self._converted = False
|
||||
|
||||
@ -900,13 +901,41 @@ class TrtGraphConverterV2(object):
|
||||
|
||||
# TODO(laigd): provide a utility function to optimize a ConcreteFunction and
|
||||
# use it here (b/124792963).
|
||||
def convert(self):
|
||||
def convert(self,
|
||||
num_calibration_runs=None,
|
||||
calibration_input_map_fn=None):
|
||||
"""Convert the input SavedModel in 2.0 format.
|
||||
|
||||
Args:
|
||||
num_calibration_runs: number of runs of the graph during calibration.
|
||||
calibration_input_map_fn: a function that returns inputs for the converted
|
||||
tf_function to be calibrated.
|
||||
Example:
|
||||
```
|
||||
def input_map_fn():
|
||||
return input1, input2, input3
|
||||
```
|
||||
Raises:
|
||||
ValueError: if the input combination is invalid.
|
||||
|
||||
Returns:
|
||||
The TF-TRT converted Function.
|
||||
"""
|
||||
assert not self._converted
|
||||
|
||||
if (self._need_calibration and
|
||||
(not num_calibration_runs or
|
||||
not calibration_input_map_fn)):
|
||||
raise ValueError(
|
||||
"Should specify num_calibration_runs and calibration_input_map_fn"
|
||||
"because INT8 calibration is needed")
|
||||
if (not self._need_calibration and
|
||||
(num_calibration_runs or
|
||||
calibration_input_map_fn)):
|
||||
raise ValueError(
|
||||
"Should not specify num_calibration_runs or calibration_input_map_fn"
|
||||
"because INT8 calibration is not needed")
|
||||
|
||||
self._saved_model = load.load(self._input_saved_model_dir,
|
||||
self._input_saved_model_tags)
|
||||
func = self._saved_model.signatures[self._input_saved_model_signature_key]
|
||||
@ -934,13 +963,49 @@ class TrtGraphConverterV2(object):
|
||||
|
||||
self._converted = True
|
||||
|
||||
# Wrap the converted ConcreteFunction in a Function so it can accept numpy
|
||||
# arrays as input.
|
||||
@def_function.function
|
||||
def wrapper_func(*args, **kwargs):
|
||||
return self._converted_func(*args, **kwargs)
|
||||
if self._need_calibration and not self._calibration_data_collected:
|
||||
self._calibrate(num_runs=num_calibration_runs,
|
||||
input_map_fn=calibration_input_map_fn)
|
||||
|
||||
return wrapper_func
|
||||
def _calibrate(self,
|
||||
num_runs=None,
|
||||
input_map_fn=None):
|
||||
"""Run calibration.
|
||||
|
||||
Args:
|
||||
num_runs: number of runs of the graph during calibration.
|
||||
input_map_fn: a function that returns inputs for the converted
|
||||
tf_function to be calibrated.
|
||||
Example:
|
||||
```
|
||||
def input_map_fn():
|
||||
return input1, input2, input3
|
||||
```
|
||||
"""
|
||||
assert self._converted
|
||||
assert self._need_calibration
|
||||
assert num_runs
|
||||
assert input_map_fn
|
||||
|
||||
for _ in range(num_runs):
|
||||
self._converted_func(*map(ops.convert_to_tensor, input_map_fn()))
|
||||
|
||||
self._calibration_data_collected = True
|
||||
|
||||
def build(self, *args, **kwargs):
|
||||
"""Run inference on graph in order to build a TensorRT engine
|
||||
in the cahce of TRTEngineOp.
|
||||
|
||||
Returns:
|
||||
The output of the converted Function for the given inputs.
|
||||
"""
|
||||
args_tensor = [ops.convert_to_tensor(arg) for arg in args]
|
||||
kwargs_tensor = {k: ops.convert_to_tensor(v) for k, v in kwargs.items()}
|
||||
try:
|
||||
return self._converted_func(*args_tensor, **kwargs_tensor)
|
||||
except OpError:
|
||||
print('Failure in execution of function with input args {}'
|
||||
'and kwargs {}'.format(args_tensor, kwargs_tensor))
|
||||
|
||||
def save(self, output_saved_model_dir):
|
||||
"""Save the converted SavedModel.
|
||||
|
@ -117,22 +117,13 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def _GetGraph(cls, inp, var):
|
||||
def _GetGraph(cls, inp1, inp2, var):
|
||||
"""Get the graph for testing."""
|
||||
# The graph computes (input+1)^2, it looks like:
|
||||
#
|
||||
# input (Placeholder) v1 (Variable)
|
||||
# | \ /
|
||||
# \ +
|
||||
# \ / \
|
||||
# * |
|
||||
# \ /
|
||||
# +
|
||||
# |
|
||||
# output (Identity)
|
||||
add = inp + var
|
||||
mul = inp * add
|
||||
# The graph computes: inp1^2 + inp1*var + inp1 + inp2 + var
|
||||
add = inp1 + var
|
||||
mul = inp1 * add
|
||||
add = mul + add
|
||||
add = add + inp2
|
||||
out = array_ops.identity(add, name="output")
|
||||
return out
|
||||
|
||||
@ -144,12 +135,13 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
self.v = None
|
||||
|
||||
@def_function.function(input_signature=[
|
||||
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32),
|
||||
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32)
|
||||
])
|
||||
def run(self, inp):
|
||||
def run(self, inp1, inp2):
|
||||
if self.v is None:
|
||||
self.v = variables.Variable([[[1.0]]], dtype=dtypes.float32)
|
||||
return TrtConvertTest._GetGraph(inp, self.v)
|
||||
return TrtConvertTest._GetGraph(inp1, inp2, self.v)
|
||||
|
||||
return SimpleModel()
|
||||
|
||||
@ -157,15 +149,17 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
with g.device("/GPU:0"):
|
||||
inp = array_ops.placeholder(
|
||||
dtype=dtypes.float32, shape=[None, 1, 1], name="input")
|
||||
inp1 = array_ops.placeholder(
|
||||
dtype=dtypes.float32, shape=[None, 1, 1], name="input1")
|
||||
inp2 = array_ops.placeholder(
|
||||
dtype=dtypes.float32, shape=[None, 1, 1], name="input2")
|
||||
var = variables.Variable([[[1.0]]], dtype=dtypes.float32, name="v1")
|
||||
out = TrtConvertTest._GetGraph(inp, var)
|
||||
return g, var, inp, out
|
||||
out = TrtConvertTest._GetGraph(inp1, inp2, var)
|
||||
return g, var, inp1, inp2, out
|
||||
|
||||
def _GetGraphDef(self):
|
||||
"""Get the graph def for testing."""
|
||||
g, var, _, _ = self._GetGraphForV1()
|
||||
g, var, _, _, _ = self._GetGraphForV1()
|
||||
with self.session(graph=g, config=self._GetConfigProto()) as sess:
|
||||
sess.run(var.initializer)
|
||||
graph_def = graph_util.convert_variables_to_constants(
|
||||
@ -175,19 +169,22 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
{
|
||||
"v1": "Const",
|
||||
"add/ReadVariableOp": "Identity",
|
||||
"input": "Placeholder",
|
||||
"input1": "Placeholder",
|
||||
"input2": "Placeholder",
|
||||
"add": "AddV2",
|
||||
"mul": "Mul",
|
||||
"add_1": "AddV2",
|
||||
"add_2": "AddV2",
|
||||
"output": "Identity"
|
||||
}, node_name_to_op)
|
||||
return graph_def
|
||||
|
||||
def _WriteInputSavedModel(self, input_saved_model_dir):
|
||||
"""Write the saved model as an input for testing."""
|
||||
g, var, inp, out = self._GetGraphForV1()
|
||||
g, var, inp1, inp2, out = self._GetGraphForV1()
|
||||
signature_def = signature_def_utils.build_signature_def(
|
||||
inputs={"myinput": utils.build_tensor_info(inp)},
|
||||
inputs={"myinput1": utils.build_tensor_info(inp1),
|
||||
"myinput2": utils.build_tensor_info(inp2)},
|
||||
outputs={"myoutput": utils.build_tensor_info(out)},
|
||||
method_name=signature_constants.PREDICT_METHOD_NAME)
|
||||
saved_model_builder = builder.SavedModelBuilder(input_saved_model_dir)
|
||||
@ -231,7 +228,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def next(self):
|
||||
self._data += 1
|
||||
return {"input:0": [[[self._data]]]}
|
||||
return {"input1:0": [[[self._data]]],
|
||||
"input2:0": [[[self._data]]]}
|
||||
|
||||
output_graph_def = converter.calibrate(
|
||||
fetch_names=["output:0"],
|
||||
@ -265,7 +263,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
node_name_to_op = {node.name: node.op for node in graph_def.node}
|
||||
self.assertEqual(
|
||||
{
|
||||
"input": "Placeholder",
|
||||
"input1": "Placeholder",
|
||||
"input2": "Placeholder",
|
||||
"TRTEngineOp_0": "TRTEngineOp",
|
||||
"output": "Identity"
|
||||
}, node_name_to_op)
|
||||
@ -284,8 +283,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
with self.session(config=self._GetConfigProto()) as sess:
|
||||
for test_data in range(10):
|
||||
self.assertEqual(
|
||||
(test_data + 1.0)**2,
|
||||
sess.run("output:0", feed_dict={"input:0": [[[test_data]]]}))
|
||||
(test_data + 1.0)**2 + test_data,
|
||||
sess.run("output:0", feed_dict={"input1:0": [[[test_data]]],
|
||||
"input2:0": [[[test_data]]]}))
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
def testTrtGraphConverter_BasicConversion(self):
|
||||
@ -345,24 +345,19 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
if not is_tensorrt_enabled():
|
||||
return
|
||||
|
||||
np_input = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||
np_input1 = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||
np_input2 = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||
|
||||
# Create a model and save it.
|
||||
input_saved_model_dir = self.mkdtemp()
|
||||
root = self._GetModelForV2()
|
||||
expected_output = root.run(np_input)
|
||||
expected_output = root.run(np_input1, np_input2)
|
||||
save.save(root, input_saved_model_dir,
|
||||
{_SAVED_MODEL_SIGNATURE_KEY: root.run})
|
||||
|
||||
# Run TRT conversion.
|
||||
converter = self._CreateConverterV2(input_saved_model_dir)
|
||||
converted_func = converter.convert()
|
||||
|
||||
# Verify the converted GraphDef and ConcreteFunction.
|
||||
self.assertIsInstance(converted_func, def_function.Function)
|
||||
converted_concrete_func = converted_func.get_concrete_function(
|
||||
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32))
|
||||
self._CheckTrtOps(converted_concrete_func)
|
||||
converter.convert()
|
||||
|
||||
# Save the converted model without any TRT engine cache.
|
||||
output_saved_model_dir = self.mkdtemp()
|
||||
@ -372,7 +367,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
self.assertFalse(os.path.exists(unexpected_asset_file))
|
||||
|
||||
# Run the converted function to populate the engine cache.
|
||||
output_with_trt = converted_func(np_input)
|
||||
output_with_trt = converter.build(np_input1, np_input2)
|
||||
self.assertEqual(1, len(output_with_trt))
|
||||
self.assertAllClose(
|
||||
expected_output,
|
||||
@ -402,7 +397,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
# self._CheckTrtOps(root_with_trt.run.get_concrete_function())
|
||||
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
|
||||
self._CheckTrtOps(converted_signature)
|
||||
output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
|
||||
output_with_trt = converted_signature(inp1=ops.convert_to_tensor(np_input1),
|
||||
inp2=ops.convert_to_tensor(np_input2))
|
||||
# The output of running the converted signature is a dict due to
|
||||
# compatibility reasons with V1 SavedModel signature mechanism.
|
||||
output_with_trt = output_with_trt.values()[0]
|
||||
@ -417,28 +413,23 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
if not is_tensorrt_enabled():
|
||||
return
|
||||
|
||||
np_input = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||
np_input1 = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||
np_input2 = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||
|
||||
# Create a model and save it.
|
||||
input_saved_model_dir = self.mkdtemp()
|
||||
root = self._GetModelForV2()
|
||||
expected_output = root.run(np_input)
|
||||
expected_output = root.run(np_input1, np_input2)
|
||||
save.save(root, input_saved_model_dir,
|
||||
{_SAVED_MODEL_SIGNATURE_KEY: root.run})
|
||||
|
||||
# Run TRT conversion.
|
||||
converter = self._CreateConverterV2(input_saved_model_dir, max_batch_size=4)
|
||||
converted_func = converter.convert()
|
||||
converter.convert()
|
||||
|
||||
def _CheckFn(node):
|
||||
self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
|
||||
|
||||
# Verify the converted GraphDef and ConcreteFunction.
|
||||
self.assertIsInstance(converted_func, def_function.Function)
|
||||
converted_concrete_func = converted_func.get_concrete_function(
|
||||
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32))
|
||||
self._CheckTrtOps(converted_concrete_func, _CheckFn)
|
||||
|
||||
# Save the converted model with the statically-built engine inlined.
|
||||
output_saved_model_dir = self.mkdtemp()
|
||||
converter.save(output_saved_model_dir)
|
||||
@ -453,7 +444,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
root_with_trt = load.load(output_saved_model_dir)
|
||||
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
|
||||
self._CheckTrtOps(converted_signature, _CheckFn)
|
||||
output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
|
||||
output_with_trt = converted_signature(inp1=ops.convert_to_tensor(np_input1),
|
||||
inp2=ops.convert_to_tensor(np_input2))
|
||||
# The output of running the converted signature is a dict due to
|
||||
# compatibility reasons with V1 SavedModel signature mechanism.
|
||||
self.assertAllClose(
|
||||
@ -470,28 +462,25 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
if not is_tensorrt_enabled():
|
||||
return
|
||||
|
||||
np_input = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||
np_input1 = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||
np_input2 = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||
|
||||
# Create a model and save it.
|
||||
input_saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
|
||||
root = self._GetModelForV2()
|
||||
expected_output = root.run(np_input)
|
||||
expected_output = root.run(np_input1, np_input2)
|
||||
save.save(root, input_saved_model_dir,
|
||||
{_SAVED_MODEL_SIGNATURE_KEY: root.run})
|
||||
|
||||
# Run TRT conversion.
|
||||
converter = self._CreateConverterV2(
|
||||
input_saved_model_dir, precision_mode=trt_convert.TrtPrecisionMode.INT8)
|
||||
converted_func = converter.convert()
|
||||
|
||||
# Run the converted function for INT8 calibration.
|
||||
calibration_output = converted_func(np_input)
|
||||
self.assertEqual(1, len(calibration_output))
|
||||
self.assertAllClose(
|
||||
expected_output,
|
||||
list(calibration_output.values())[0],
|
||||
atol=1e-6,
|
||||
rtol=1e-6)
|
||||
# Convert and perform INT8 calibration
|
||||
def input_map_fn():
|
||||
return np_input1, np_input2
|
||||
converter.convert(num_calibration_runs=1,
|
||||
calibration_input_map_fn=input_map_fn)
|
||||
|
||||
# Save the converted model again with serialized engine cache.
|
||||
output_saved_model_dir = self.mkdtemp()
|
||||
@ -511,7 +500,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
root_with_trt = load.load(output_saved_model_dir)
|
||||
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
|
||||
self._CheckTrtOps(converted_signature, _CheckFn)
|
||||
output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
|
||||
output_with_trt = converted_signature(inp1=ops.convert_to_tensor(np_input1),
|
||||
inp2=ops.convert_to_tensor(np_input2))
|
||||
self.assertEqual(1, len(output_with_trt))
|
||||
# The output of running the converted signature is a dict due to
|
||||
# compatibility reasons with V1 SavedModel signature mechanism.
|
||||
@ -523,8 +513,10 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
|
||||
# Run with an input of different batch size. It should build a new engine
|
||||
# using calibration table.
|
||||
np_input = np.random.random_sample([5, 1, 1]).astype(np.float32)
|
||||
converted_signature(ops.convert_to_tensor(np_input))
|
||||
np_input1 = np.random.random_sample([5, 1, 1]).astype(np.float32)
|
||||
np_input2 = np.random.random_sample([5, 1, 1]).astype(np.float32)
|
||||
output_with_trt = converted_signature(inp1=ops.convert_to_tensor(np_input1),
|
||||
inp2=ops.convert_to_tensor(np_input2))
|
||||
|
||||
del root_with_trt
|
||||
gc.collect() # Force GC to destroy the TRT engine cache.
|
||||
@ -535,7 +527,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
if not is_tensorrt_enabled():
|
||||
return
|
||||
|
||||
np_input = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||
np_input1 = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||
np_input2 = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||
|
||||
# Create a model and save it.
|
||||
input_saved_model_dir = self.mkdtemp()
|
||||
@ -545,8 +538,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
|
||||
# Run TRT conversion.
|
||||
converter = self._CreateConverterV2(input_saved_model_dir)
|
||||
converted_func = converter.convert()
|
||||
converted_func(np_input) # Populate the TRT engine cache.
|
||||
converter.convert()
|
||||
converter.build(np_input1, np_input2) # Populate the TRT engine cache.
|
||||
output_saved_model_dir = self.mkdtemp()
|
||||
converter.save(output_saved_model_dir)
|
||||
|
||||
@ -696,22 +689,26 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
self._CompareSavedModel(_Model)
|
||||
|
||||
def _TestRun(self, sess, batch_size, expect_engine_is_run=True):
|
||||
result = sess.run("output:0", feed_dict={"input:0": [[[1.0]]] * batch_size})
|
||||
self.assertAllEqual([[[4.0]]] * batch_size, result)
|
||||
result = sess.run("output:0",
|
||||
feed_dict={"input1:0": [[[1.0]]] * batch_size,
|
||||
"input2:0": [[[1.0]]] * batch_size})
|
||||
self.assertAllEqual([[[5.0]]] * batch_size, result)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
def testTrtGraphConverter_MinimumSegmentSize(self):
|
||||
if not is_tensorrt_enabled():
|
||||
return
|
||||
output_graph_def = self._ConvertGraph(minimum_segment_size=5)
|
||||
output_graph_def = self._ConvertGraph(minimum_segment_size=7)
|
||||
node_name_to_op = {node.name: node.op for node in output_graph_def.node}
|
||||
self.assertEqual(
|
||||
{
|
||||
"add/ReadVariableOp": "Const",
|
||||
"input": "Placeholder",
|
||||
"input1": "Placeholder",
|
||||
"input2": "Placeholder",
|
||||
"add": "AddV2",
|
||||
"mul": "Mul",
|
||||
"add_1": "AddV2",
|
||||
"add_2": "AddV2",
|
||||
"output": "Identity"
|
||||
}, node_name_to_op)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user