Add calibration to TrtGraphConverterV2.convert
Add TrtGraphConverterV2.build. Also do not return function from convert. Convert dict_values to list for python3. Fix tests as well. Fix pylint errors.
This commit is contained in:
parent
163f9df4c7
commit
ec5c02724a
@ -878,6 +878,7 @@ class TrtGraphConverterV2(object):
|
|||||||
if (self._need_calibration and not conversion_params.is_dynamic_op):
|
if (self._need_calibration and not conversion_params.is_dynamic_op):
|
||||||
raise ValueError("INT8 precision mode with calibration is not supported "
|
raise ValueError("INT8 precision mode with calibration is not supported "
|
||||||
"with static TensorRT ops. Set is_dynamic_op to True.")
|
"with static TensorRT ops. Set is_dynamic_op to True.")
|
||||||
|
self._calibration_data_collected = False
|
||||||
|
|
||||||
self._converted = False
|
self._converted = False
|
||||||
|
|
||||||
@ -900,13 +901,41 @@ class TrtGraphConverterV2(object):
|
|||||||
|
|
||||||
# TODO(laigd): provide a utility function to optimize a ConcreteFunction and
|
# TODO(laigd): provide a utility function to optimize a ConcreteFunction and
|
||||||
# use it here (b/124792963).
|
# use it here (b/124792963).
|
||||||
def convert(self):
|
def convert(self,
|
||||||
|
num_calibration_runs=None,
|
||||||
|
calibration_input_map_fn=None):
|
||||||
"""Convert the input SavedModel in 2.0 format.
|
"""Convert the input SavedModel in 2.0 format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_calibration_runs: number of runs of the graph during calibration.
|
||||||
|
calibration_input_map_fn: a function that returns inputs for the converted
|
||||||
|
tf_function to be calibrated.
|
||||||
|
Example:
|
||||||
|
```
|
||||||
|
def input_map_fn():
|
||||||
|
return input1, input2, input3
|
||||||
|
```
|
||||||
|
Raises:
|
||||||
|
ValueError: if the input combination is invalid.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The TF-TRT converted Function.
|
The TF-TRT converted Function.
|
||||||
"""
|
"""
|
||||||
assert not self._converted
|
assert not self._converted
|
||||||
|
|
||||||
|
if (self._need_calibration and
|
||||||
|
(not num_calibration_runs or
|
||||||
|
not calibration_input_map_fn)):
|
||||||
|
raise ValueError(
|
||||||
|
"Should specify num_calibration_runs and calibration_input_map_fn"
|
||||||
|
"because INT8 calibration is needed")
|
||||||
|
if (not self._need_calibration and
|
||||||
|
(num_calibration_runs or
|
||||||
|
calibration_input_map_fn)):
|
||||||
|
raise ValueError(
|
||||||
|
"Should not specify num_calibration_runs or calibration_input_map_fn"
|
||||||
|
"because INT8 calibration is not needed")
|
||||||
|
|
||||||
self._saved_model = load.load(self._input_saved_model_dir,
|
self._saved_model = load.load(self._input_saved_model_dir,
|
||||||
self._input_saved_model_tags)
|
self._input_saved_model_tags)
|
||||||
func = self._saved_model.signatures[self._input_saved_model_signature_key]
|
func = self._saved_model.signatures[self._input_saved_model_signature_key]
|
||||||
@ -934,13 +963,49 @@ class TrtGraphConverterV2(object):
|
|||||||
|
|
||||||
self._converted = True
|
self._converted = True
|
||||||
|
|
||||||
# Wrap the converted ConcreteFunction in a Function so it can accept numpy
|
if self._need_calibration and not self._calibration_data_collected:
|
||||||
# arrays as input.
|
self._calibrate(num_runs=num_calibration_runs,
|
||||||
@def_function.function
|
input_map_fn=calibration_input_map_fn)
|
||||||
def wrapper_func(*args, **kwargs):
|
|
||||||
return self._converted_func(*args, **kwargs)
|
|
||||||
|
|
||||||
return wrapper_func
|
def _calibrate(self,
|
||||||
|
num_runs=None,
|
||||||
|
input_map_fn=None):
|
||||||
|
"""Run calibration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_runs: number of runs of the graph during calibration.
|
||||||
|
input_map_fn: a function that returns inputs for the converted
|
||||||
|
tf_function to be calibrated.
|
||||||
|
Example:
|
||||||
|
```
|
||||||
|
def input_map_fn():
|
||||||
|
return input1, input2, input3
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
assert self._converted
|
||||||
|
assert self._need_calibration
|
||||||
|
assert num_runs
|
||||||
|
assert input_map_fn
|
||||||
|
|
||||||
|
for _ in range(num_runs):
|
||||||
|
self._converted_func(*map(ops.convert_to_tensor, input_map_fn()))
|
||||||
|
|
||||||
|
self._calibration_data_collected = True
|
||||||
|
|
||||||
|
def build(self, *args, **kwargs):
|
||||||
|
"""Run inference on graph in order to build a TensorRT engine
|
||||||
|
in the cache of TRTEngineOp.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The output of the converted Function for the given inputs.
|
||||||
|
"""
|
||||||
|
args_tensor = [ops.convert_to_tensor(arg) for arg in args]
|
||||||
|
kwargs_tensor = {k: ops.convert_to_tensor(v) for k, v in kwargs.items()}
|
||||||
|
try:
|
||||||
|
return self._converted_func(*args_tensor, **kwargs_tensor)
|
||||||
|
except OpError:
|
||||||
|
print('Failure in execution of function with input args {}'
|
||||||
|
'and kwargs {}'.format(args_tensor, kwargs_tensor))
|
||||||
|
|
||||||
def save(self, output_saved_model_dir):
|
def save(self, output_saved_model_dir):
|
||||||
"""Save the converted SavedModel.
|
"""Save the converted SavedModel.
|
||||||
|
@ -117,22 +117,13 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _GetGraph(cls, inp, var):
|
def _GetGraph(cls, inp1, inp2, var):
|
||||||
"""Get the graph for testing."""
|
"""Get the graph for testing."""
|
||||||
# The graph computes (input+1)^2, it looks like:
|
# The graph computes: inp1^2 + inp1*var + inp1 + inp2 + var
|
||||||
#
|
add = inp1 + var
|
||||||
# input (Placeholder) v1 (Variable)
|
mul = inp1 * add
|
||||||
# | \ /
|
|
||||||
# \ +
|
|
||||||
# \ / \
|
|
||||||
# * |
|
|
||||||
# \ /
|
|
||||||
# +
|
|
||||||
# |
|
|
||||||
# output (Identity)
|
|
||||||
add = inp + var
|
|
||||||
mul = inp * add
|
|
||||||
add = mul + add
|
add = mul + add
|
||||||
|
add = add + inp2
|
||||||
out = array_ops.identity(add, name="output")
|
out = array_ops.identity(add, name="output")
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@ -144,12 +135,13 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
self.v = None
|
self.v = None
|
||||||
|
|
||||||
@def_function.function(input_signature=[
|
@def_function.function(input_signature=[
|
||||||
|
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32),
|
||||||
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32)
|
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32)
|
||||||
])
|
])
|
||||||
def run(self, inp):
|
def run(self, inp1, inp2):
|
||||||
if self.v is None:
|
if self.v is None:
|
||||||
self.v = variables.Variable([[[1.0]]], dtype=dtypes.float32)
|
self.v = variables.Variable([[[1.0]]], dtype=dtypes.float32)
|
||||||
return TrtConvertTest._GetGraph(inp, self.v)
|
return TrtConvertTest._GetGraph(inp1, inp2, self.v)
|
||||||
|
|
||||||
return SimpleModel()
|
return SimpleModel()
|
||||||
|
|
||||||
@ -157,15 +149,17 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
g = ops.Graph()
|
g = ops.Graph()
|
||||||
with g.as_default():
|
with g.as_default():
|
||||||
with g.device("/GPU:0"):
|
with g.device("/GPU:0"):
|
||||||
inp = array_ops.placeholder(
|
inp1 = array_ops.placeholder(
|
||||||
dtype=dtypes.float32, shape=[None, 1, 1], name="input")
|
dtype=dtypes.float32, shape=[None, 1, 1], name="input1")
|
||||||
|
inp2 = array_ops.placeholder(
|
||||||
|
dtype=dtypes.float32, shape=[None, 1, 1], name="input2")
|
||||||
var = variables.Variable([[[1.0]]], dtype=dtypes.float32, name="v1")
|
var = variables.Variable([[[1.0]]], dtype=dtypes.float32, name="v1")
|
||||||
out = TrtConvertTest._GetGraph(inp, var)
|
out = TrtConvertTest._GetGraph(inp1, inp2, var)
|
||||||
return g, var, inp, out
|
return g, var, inp1, inp2, out
|
||||||
|
|
||||||
def _GetGraphDef(self):
|
def _GetGraphDef(self):
|
||||||
"""Get the graph def for testing."""
|
"""Get the graph def for testing."""
|
||||||
g, var, _, _ = self._GetGraphForV1()
|
g, var, _, _, _ = self._GetGraphForV1()
|
||||||
with self.session(graph=g, config=self._GetConfigProto()) as sess:
|
with self.session(graph=g, config=self._GetConfigProto()) as sess:
|
||||||
sess.run(var.initializer)
|
sess.run(var.initializer)
|
||||||
graph_def = graph_util.convert_variables_to_constants(
|
graph_def = graph_util.convert_variables_to_constants(
|
||||||
@ -175,19 +169,22 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
{
|
{
|
||||||
"v1": "Const",
|
"v1": "Const",
|
||||||
"add/ReadVariableOp": "Identity",
|
"add/ReadVariableOp": "Identity",
|
||||||
"input": "Placeholder",
|
"input1": "Placeholder",
|
||||||
|
"input2": "Placeholder",
|
||||||
"add": "AddV2",
|
"add": "AddV2",
|
||||||
"mul": "Mul",
|
"mul": "Mul",
|
||||||
"add_1": "AddV2",
|
"add_1": "AddV2",
|
||||||
|
"add_2": "AddV2",
|
||||||
"output": "Identity"
|
"output": "Identity"
|
||||||
}, node_name_to_op)
|
}, node_name_to_op)
|
||||||
return graph_def
|
return graph_def
|
||||||
|
|
||||||
def _WriteInputSavedModel(self, input_saved_model_dir):
|
def _WriteInputSavedModel(self, input_saved_model_dir):
|
||||||
"""Write the saved model as an input for testing."""
|
"""Write the saved model as an input for testing."""
|
||||||
g, var, inp, out = self._GetGraphForV1()
|
g, var, inp1, inp2, out = self._GetGraphForV1()
|
||||||
signature_def = signature_def_utils.build_signature_def(
|
signature_def = signature_def_utils.build_signature_def(
|
||||||
inputs={"myinput": utils.build_tensor_info(inp)},
|
inputs={"myinput1": utils.build_tensor_info(inp1),
|
||||||
|
"myinput2": utils.build_tensor_info(inp2)},
|
||||||
outputs={"myoutput": utils.build_tensor_info(out)},
|
outputs={"myoutput": utils.build_tensor_info(out)},
|
||||||
method_name=signature_constants.PREDICT_METHOD_NAME)
|
method_name=signature_constants.PREDICT_METHOD_NAME)
|
||||||
saved_model_builder = builder.SavedModelBuilder(input_saved_model_dir)
|
saved_model_builder = builder.SavedModelBuilder(input_saved_model_dir)
|
||||||
@ -231,7 +228,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
def next(self):
|
def next(self):
|
||||||
self._data += 1
|
self._data += 1
|
||||||
return {"input:0": [[[self._data]]]}
|
return {"input1:0": [[[self._data]]],
|
||||||
|
"input2:0": [[[self._data]]]}
|
||||||
|
|
||||||
output_graph_def = converter.calibrate(
|
output_graph_def = converter.calibrate(
|
||||||
fetch_names=["output:0"],
|
fetch_names=["output:0"],
|
||||||
@ -265,7 +263,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
node_name_to_op = {node.name: node.op for node in graph_def.node}
|
node_name_to_op = {node.name: node.op for node in graph_def.node}
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
{
|
{
|
||||||
"input": "Placeholder",
|
"input1": "Placeholder",
|
||||||
|
"input2": "Placeholder",
|
||||||
"TRTEngineOp_0": "TRTEngineOp",
|
"TRTEngineOp_0": "TRTEngineOp",
|
||||||
"output": "Identity"
|
"output": "Identity"
|
||||||
}, node_name_to_op)
|
}, node_name_to_op)
|
||||||
@ -284,8 +283,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
with self.session(config=self._GetConfigProto()) as sess:
|
with self.session(config=self._GetConfigProto()) as sess:
|
||||||
for test_data in range(10):
|
for test_data in range(10):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
(test_data + 1.0)**2,
|
(test_data + 1.0)**2 + test_data,
|
||||||
sess.run("output:0", feed_dict={"input:0": [[[test_data]]]}))
|
sess.run("output:0", feed_dict={"input1:0": [[[test_data]]],
|
||||||
|
"input2:0": [[[test_data]]]}))
|
||||||
|
|
||||||
@test_util.deprecated_graph_mode_only
|
@test_util.deprecated_graph_mode_only
|
||||||
def testTrtGraphConverter_BasicConversion(self):
|
def testTrtGraphConverter_BasicConversion(self):
|
||||||
@ -345,24 +345,19 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
if not is_tensorrt_enabled():
|
if not is_tensorrt_enabled():
|
||||||
return
|
return
|
||||||
|
|
||||||
np_input = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
np_input1 = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||||
|
np_input2 = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||||
|
|
||||||
# Create a model and save it.
|
# Create a model and save it.
|
||||||
input_saved_model_dir = self.mkdtemp()
|
input_saved_model_dir = self.mkdtemp()
|
||||||
root = self._GetModelForV2()
|
root = self._GetModelForV2()
|
||||||
expected_output = root.run(np_input)
|
expected_output = root.run(np_input1, np_input2)
|
||||||
save.save(root, input_saved_model_dir,
|
save.save(root, input_saved_model_dir,
|
||||||
{_SAVED_MODEL_SIGNATURE_KEY: root.run})
|
{_SAVED_MODEL_SIGNATURE_KEY: root.run})
|
||||||
|
|
||||||
# Run TRT conversion.
|
# Run TRT conversion.
|
||||||
converter = self._CreateConverterV2(input_saved_model_dir)
|
converter = self._CreateConverterV2(input_saved_model_dir)
|
||||||
converted_func = converter.convert()
|
converter.convert()
|
||||||
|
|
||||||
# Verify the converted GraphDef and ConcreteFunction.
|
|
||||||
self.assertIsInstance(converted_func, def_function.Function)
|
|
||||||
converted_concrete_func = converted_func.get_concrete_function(
|
|
||||||
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32))
|
|
||||||
self._CheckTrtOps(converted_concrete_func)
|
|
||||||
|
|
||||||
# Save the converted model without any TRT engine cache.
|
# Save the converted model without any TRT engine cache.
|
||||||
output_saved_model_dir = self.mkdtemp()
|
output_saved_model_dir = self.mkdtemp()
|
||||||
@ -372,7 +367,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertFalse(os.path.exists(unexpected_asset_file))
|
self.assertFalse(os.path.exists(unexpected_asset_file))
|
||||||
|
|
||||||
# Run the converted function to populate the engine cache.
|
# Run the converted function to populate the engine cache.
|
||||||
output_with_trt = converted_func(np_input)
|
output_with_trt = converter.build(np_input1, np_input2)
|
||||||
self.assertEqual(1, len(output_with_trt))
|
self.assertEqual(1, len(output_with_trt))
|
||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
expected_output,
|
expected_output,
|
||||||
@ -402,7 +397,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
# self._CheckTrtOps(root_with_trt.run.get_concrete_function())
|
# self._CheckTrtOps(root_with_trt.run.get_concrete_function())
|
||||||
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
|
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
|
||||||
self._CheckTrtOps(converted_signature)
|
self._CheckTrtOps(converted_signature)
|
||||||
output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
|
output_with_trt = converted_signature(inp1=ops.convert_to_tensor(np_input1),
|
||||||
|
inp2=ops.convert_to_tensor(np_input2))
|
||||||
# The output of running the converted signature is a dict due to
|
# The output of running the converted signature is a dict due to
|
||||||
# compatibility reasons with V1 SavedModel signature mechanism.
|
# compatibility reasons with V1 SavedModel signature mechanism.
|
||||||
output_with_trt = output_with_trt.values()[0]
|
output_with_trt = output_with_trt.values()[0]
|
||||||
@ -417,28 +413,23 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
if not is_tensorrt_enabled():
|
if not is_tensorrt_enabled():
|
||||||
return
|
return
|
||||||
|
|
||||||
np_input = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
np_input1 = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||||
|
np_input2 = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||||
|
|
||||||
# Create a model and save it.
|
# Create a model and save it.
|
||||||
input_saved_model_dir = self.mkdtemp()
|
input_saved_model_dir = self.mkdtemp()
|
||||||
root = self._GetModelForV2()
|
root = self._GetModelForV2()
|
||||||
expected_output = root.run(np_input)
|
expected_output = root.run(np_input1, np_input2)
|
||||||
save.save(root, input_saved_model_dir,
|
save.save(root, input_saved_model_dir,
|
||||||
{_SAVED_MODEL_SIGNATURE_KEY: root.run})
|
{_SAVED_MODEL_SIGNATURE_KEY: root.run})
|
||||||
|
|
||||||
# Run TRT conversion.
|
# Run TRT conversion.
|
||||||
converter = self._CreateConverterV2(input_saved_model_dir, max_batch_size=4)
|
converter = self._CreateConverterV2(input_saved_model_dir, max_batch_size=4)
|
||||||
converted_func = converter.convert()
|
converter.convert()
|
||||||
|
|
||||||
def _CheckFn(node):
|
def _CheckFn(node):
|
||||||
self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
|
self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
|
||||||
|
|
||||||
# Verify the converted GraphDef and ConcreteFunction.
|
|
||||||
self.assertIsInstance(converted_func, def_function.Function)
|
|
||||||
converted_concrete_func = converted_func.get_concrete_function(
|
|
||||||
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32))
|
|
||||||
self._CheckTrtOps(converted_concrete_func, _CheckFn)
|
|
||||||
|
|
||||||
# Save the converted model with the statically-built engine inlined.
|
# Save the converted model with the statically-built engine inlined.
|
||||||
output_saved_model_dir = self.mkdtemp()
|
output_saved_model_dir = self.mkdtemp()
|
||||||
converter.save(output_saved_model_dir)
|
converter.save(output_saved_model_dir)
|
||||||
@ -453,7 +444,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
root_with_trt = load.load(output_saved_model_dir)
|
root_with_trt = load.load(output_saved_model_dir)
|
||||||
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
|
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
|
||||||
self._CheckTrtOps(converted_signature, _CheckFn)
|
self._CheckTrtOps(converted_signature, _CheckFn)
|
||||||
output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
|
output_with_trt = converted_signature(inp1=ops.convert_to_tensor(np_input1),
|
||||||
|
inp2=ops.convert_to_tensor(np_input2))
|
||||||
# The output of running the converted signature is a dict due to
|
# The output of running the converted signature is a dict due to
|
||||||
# compatibility reasons with V1 SavedModel signature mechanism.
|
# compatibility reasons with V1 SavedModel signature mechanism.
|
||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
@ -470,28 +462,25 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
if not is_tensorrt_enabled():
|
if not is_tensorrt_enabled():
|
||||||
return
|
return
|
||||||
|
|
||||||
np_input = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
np_input1 = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||||
|
np_input2 = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||||
|
|
||||||
# Create a model and save it.
|
# Create a model and save it.
|
||||||
input_saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
|
input_saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
|
||||||
root = self._GetModelForV2()
|
root = self._GetModelForV2()
|
||||||
expected_output = root.run(np_input)
|
expected_output = root.run(np_input1, np_input2)
|
||||||
save.save(root, input_saved_model_dir,
|
save.save(root, input_saved_model_dir,
|
||||||
{_SAVED_MODEL_SIGNATURE_KEY: root.run})
|
{_SAVED_MODEL_SIGNATURE_KEY: root.run})
|
||||||
|
|
||||||
# Run TRT conversion.
|
# Run TRT conversion.
|
||||||
converter = self._CreateConverterV2(
|
converter = self._CreateConverterV2(
|
||||||
input_saved_model_dir, precision_mode=trt_convert.TrtPrecisionMode.INT8)
|
input_saved_model_dir, precision_mode=trt_convert.TrtPrecisionMode.INT8)
|
||||||
converted_func = converter.convert()
|
|
||||||
|
|
||||||
# Run the converted function for INT8 calibration.
|
# Convert and perform INT8 calibration
|
||||||
calibration_output = converted_func(np_input)
|
def input_map_fn():
|
||||||
self.assertEqual(1, len(calibration_output))
|
return np_input1, np_input2
|
||||||
self.assertAllClose(
|
converter.convert(num_calibration_runs=1,
|
||||||
expected_output,
|
calibration_input_map_fn=input_map_fn)
|
||||||
list(calibration_output.values())[0],
|
|
||||||
atol=1e-6,
|
|
||||||
rtol=1e-6)
|
|
||||||
|
|
||||||
# Save the converted model again with serialized engine cache.
|
# Save the converted model again with serialized engine cache.
|
||||||
output_saved_model_dir = self.mkdtemp()
|
output_saved_model_dir = self.mkdtemp()
|
||||||
@ -511,7 +500,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
root_with_trt = load.load(output_saved_model_dir)
|
root_with_trt = load.load(output_saved_model_dir)
|
||||||
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
|
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
|
||||||
self._CheckTrtOps(converted_signature, _CheckFn)
|
self._CheckTrtOps(converted_signature, _CheckFn)
|
||||||
output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
|
output_with_trt = converted_signature(inp1=ops.convert_to_tensor(np_input1),
|
||||||
|
inp2=ops.convert_to_tensor(np_input2))
|
||||||
self.assertEqual(1, len(output_with_trt))
|
self.assertEqual(1, len(output_with_trt))
|
||||||
# The output of running the converted signature is a dict due to
|
# The output of running the converted signature is a dict due to
|
||||||
# compatibility reasons with V1 SavedModel signature mechanism.
|
# compatibility reasons with V1 SavedModel signature mechanism.
|
||||||
@ -523,8 +513,10 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
# Run with an input of different batch size. It should build a new engine
|
# Run with an input of different batch size. It should build a new engine
|
||||||
# using calibration table.
|
# using calibration table.
|
||||||
np_input = np.random.random_sample([5, 1, 1]).astype(np.float32)
|
np_input1 = np.random.random_sample([5, 1, 1]).astype(np.float32)
|
||||||
converted_signature(ops.convert_to_tensor(np_input))
|
np_input2 = np.random.random_sample([5, 1, 1]).astype(np.float32)
|
||||||
|
output_with_trt = converted_signature(inp1=ops.convert_to_tensor(np_input1),
|
||||||
|
inp2=ops.convert_to_tensor(np_input2))
|
||||||
|
|
||||||
del root_with_trt
|
del root_with_trt
|
||||||
gc.collect() # Force GC to destroy the TRT engine cache.
|
gc.collect() # Force GC to destroy the TRT engine cache.
|
||||||
@ -535,7 +527,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
if not is_tensorrt_enabled():
|
if not is_tensorrt_enabled():
|
||||||
return
|
return
|
||||||
|
|
||||||
np_input = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
np_input1 = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||||
|
np_input2 = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||||
|
|
||||||
# Create a model and save it.
|
# Create a model and save it.
|
||||||
input_saved_model_dir = self.mkdtemp()
|
input_saved_model_dir = self.mkdtemp()
|
||||||
@ -545,8 +538,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
# Run TRT conversion.
|
# Run TRT conversion.
|
||||||
converter = self._CreateConverterV2(input_saved_model_dir)
|
converter = self._CreateConverterV2(input_saved_model_dir)
|
||||||
converted_func = converter.convert()
|
converter.convert()
|
||||||
converted_func(np_input) # Populate the TRT engine cache.
|
converter.build(np_input1, np_input2) # Populate the TRT engine cache.
|
||||||
output_saved_model_dir = self.mkdtemp()
|
output_saved_model_dir = self.mkdtemp()
|
||||||
converter.save(output_saved_model_dir)
|
converter.save(output_saved_model_dir)
|
||||||
|
|
||||||
@ -696,22 +689,26 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
self._CompareSavedModel(_Model)
|
self._CompareSavedModel(_Model)
|
||||||
|
|
||||||
def _TestRun(self, sess, batch_size, expect_engine_is_run=True):
|
def _TestRun(self, sess, batch_size, expect_engine_is_run=True):
|
||||||
result = sess.run("output:0", feed_dict={"input:0": [[[1.0]]] * batch_size})
|
result = sess.run("output:0",
|
||||||
self.assertAllEqual([[[4.0]]] * batch_size, result)
|
feed_dict={"input1:0": [[[1.0]]] * batch_size,
|
||||||
|
"input2:0": [[[1.0]]] * batch_size})
|
||||||
|
self.assertAllEqual([[[5.0]]] * batch_size, result)
|
||||||
|
|
||||||
@test_util.deprecated_graph_mode_only
|
@test_util.deprecated_graph_mode_only
|
||||||
def testTrtGraphConverter_MinimumSegmentSize(self):
|
def testTrtGraphConverter_MinimumSegmentSize(self):
|
||||||
if not is_tensorrt_enabled():
|
if not is_tensorrt_enabled():
|
||||||
return
|
return
|
||||||
output_graph_def = self._ConvertGraph(minimum_segment_size=5)
|
output_graph_def = self._ConvertGraph(minimum_segment_size=7)
|
||||||
node_name_to_op = {node.name: node.op for node in output_graph_def.node}
|
node_name_to_op = {node.name: node.op for node in output_graph_def.node}
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
{
|
{
|
||||||
"add/ReadVariableOp": "Const",
|
"add/ReadVariableOp": "Const",
|
||||||
"input": "Placeholder",
|
"input1": "Placeholder",
|
||||||
|
"input2": "Placeholder",
|
||||||
"add": "AddV2",
|
"add": "AddV2",
|
||||||
"mul": "Mul",
|
"mul": "Mul",
|
||||||
"add_1": "AddV2",
|
"add_1": "AddV2",
|
||||||
|
"add_2": "AddV2",
|
||||||
"output": "Identity"
|
"output": "Identity"
|
||||||
}, node_name_to_op)
|
}, node_name_to_op)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user