Terminate calibration in TrtGraphConverterV2.convert() and improve the test to cover that.
This commit is contained in:
parent
61de424923
commit
fee3d86071
@ -5191,7 +5191,11 @@ Status ConvertGraphDefToEngine(
|
||||
}
|
||||
|
||||
// Build the network
|
||||
VLOG(1) << "Starting engine conversion ";
|
||||
if (VLOG_IS_ON(1)) {
|
||||
string mode_str;
|
||||
TF_RETURN_IF_ERROR(TrtPrecisionModeToName(precision_mode, &mode_str));
|
||||
VLOG(1) << "Starting engine conversion, precision mode: " << mode_str;
|
||||
}
|
||||
Converter converter(trt_network.get(), precision_mode, use_calibration);
|
||||
std::vector<Converter::EngineOutputInfo> output_tensors;
|
||||
// Graph nodes are already topologically sorted during construction
|
||||
|
@ -57,9 +57,8 @@ from tensorflow.python.util.lazy_loader import LazyLoader
|
||||
# Lazily load the op, since it's not available in cpu-only builds. Importing
|
||||
# this at top will cause tests that import TF-TRT to fail when they're built
|
||||
# and run without CUDA/GPU.
|
||||
gen_trt_ops = LazyLoader(
|
||||
"gen_trt_ops", globals(),
|
||||
"tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops")
|
||||
gen_trt_ops = LazyLoader("gen_trt_ops", globals(),
|
||||
"tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops")
|
||||
|
||||
# Register TRT ops in python, so that when users import this module they can
|
||||
# execute a TRT-converted graph without calling any of the methods in this
|
||||
@ -898,6 +897,16 @@ class TrtGraphConverterV2(object):
|
||||
return tf_optimizer.OptimizeGraph(
|
||||
grappler_session_config, meta_graph_def, graph_id=b"tf_graph")
|
||||
|
||||
def _for_each_trt_node(self, graph_def, fn):
  """Applies `fn` to every TRT engine node found in `graph_def`.

  Nodes are matched by comparing their `op` field with
  `_TRT_ENGINE_OP_NAME`. The top-level nodes of the graph are visited first,
  followed by the nodes of each function in the graph's function library.

  Args:
    graph_def: the GraphDef whose nodes are scanned.
    fn: a callable taking a single node; invoked once per matching node. Its
      return value is ignored.
  """

  def _visit(candidates):
    # Forward only the nodes whose op type is the TRT engine op.
    for candidate in candidates:
      if candidate.op == _TRT_ENGINE_OP_NAME:
        fn(candidate)

  _visit(graph_def.node)
  for library_function in graph_def.library.function:
    _visit(library_function.node_def)
|
||||
|
||||
# TODO(laigd): provide a utility function to optimize a ConcreteFunction and
|
||||
# use it here (b/124792963).
|
||||
def convert(self, num_calibration_runs=None, calibration_input_map_fn=None):
|
||||
@ -907,11 +916,8 @@ class TrtGraphConverterV2(object):
|
||||
num_calibration_runs: number of runs of the graph during calibration.
|
||||
calibration_input_map_fn: a function that returns inputs for the converted
|
||||
tf_function to be calibrated.
|
||||
Example:
|
||||
```
|
||||
def input_map_fn():
|
||||
return input1, input2, input3
|
||||
```
|
||||
Example: ```
|
||||
def input_map_fn(): return input1, input2, input3 ```
|
||||
|
||||
Raises:
|
||||
ValueError: if the input combination is invalid.
|
||||
@ -962,6 +968,24 @@ class TrtGraphConverterV2(object):
|
||||
self._converted_func(
|
||||
*map(ops.convert_to_tensor, calibration_input_map_fn()))
|
||||
|
||||
def _save_calibration_table(node):
|
||||
calibration_table = gen_trt_ops.get_calibration_data_op(
|
||||
_get_canonical_engine_name(node.name))
|
||||
node.attr["calibration_data"].s = calibration_table.numpy()
|
||||
|
||||
self._for_each_trt_node(self._converted_graph_def,
|
||||
_save_calibration_table)
|
||||
|
||||
# Rebuild the function since calibration has changed the graph.
|
||||
calibrated_func = wrap_function.function_from_graph_def(
|
||||
self._converted_graph_def,
|
||||
[tensor.name for tensor in self._converted_func.inputs],
|
||||
[tensor.name for tensor in self._converted_func.outputs])
|
||||
calibrated_func.graph.structured_outputs = nest.pack_sequence_as(
|
||||
self._converted_func.graph.structured_outputs,
|
||||
calibrated_func.graph.structured_outputs)
|
||||
self._converted_func = calibrated_func
|
||||
|
||||
self._converted = True
|
||||
|
||||
def build(self, *args, **kwargs):
|
||||
@ -1002,10 +1026,6 @@ class TrtGraphConverterV2(object):
|
||||
|
||||
filename = os.path.join(engine_asset_dir,
|
||||
"trt-serialized-engine." + canonical_engine_name)
|
||||
if self._need_calibration:
|
||||
calibration_table = gen_trt_ops.get_calibration_data_op(
|
||||
canonical_engine_name)
|
||||
node.attr["calibration_data"].s = calibration_table.numpy()
|
||||
|
||||
try:
|
||||
gen_trt_ops.serialize_trt_resource(
|
||||
@ -1022,30 +1042,15 @@ class TrtGraphConverterV2(object):
|
||||
canonical_engine_name, filename,
|
||||
self._conversion_params.maximum_cached_engines)
|
||||
|
||||
for node in self._converted_graph_def.node:
|
||||
if node.op == _TRT_ENGINE_OP_NAME:
|
||||
_serialize_and_track_engine(node)
|
||||
for func in self._converted_graph_def.library.function:
|
||||
for node in func.node_def:
|
||||
if node.op == _TRT_ENGINE_OP_NAME:
|
||||
_serialize_and_track_engine(node)
|
||||
|
||||
self._for_each_trt_node(self._converted_graph_def,
|
||||
_serialize_and_track_engine)
|
||||
self._saved_model.trt_engine_resources = resource_map
|
||||
|
||||
# Rebuild the function since calibration may change the graph.
|
||||
func_to_save = wrap_function.function_from_graph_def(
|
||||
self._converted_graph_def,
|
||||
[tensor.name for tensor in self._converted_func.inputs],
|
||||
[tensor.name for tensor in self._converted_func.outputs])
|
||||
func_to_save.graph.structured_outputs = nest.pack_sequence_as(
|
||||
self._converted_func.graph.structured_outputs,
|
||||
func_to_save.graph.structured_outputs)
|
||||
|
||||
# Rewrite the signature map using the optimized ConcreteFunction.
|
||||
signatures = {
|
||||
key: value for key, value in self._saved_model.signatures.items()
|
||||
}
|
||||
signatures[self._input_saved_model_signature_key] = func_to_save
|
||||
signatures[self._input_saved_model_signature_key] = self._converted_func
|
||||
save.save(self._saved_model, output_saved_model_dir, signatures)
|
||||
|
||||
|
||||
|
@ -57,9 +57,8 @@ from tensorflow.python.util.lazy_loader import LazyLoader
|
||||
|
||||
_SAVED_MODEL_SIGNATURE_KEY = "mypredict"
|
||||
|
||||
gen_trt_ops = LazyLoader(
|
||||
"gen_trt_ops", globals(),
|
||||
"tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops")
|
||||
gen_trt_ops = LazyLoader("gen_trt_ops", globals(),
|
||||
"tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops")
|
||||
|
||||
|
||||
class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
@ -183,8 +182,10 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
"""Write the saved model as an input for testing."""
|
||||
g, var, inp1, inp2, out = self._GetGraphForV1()
|
||||
signature_def = signature_def_utils.build_signature_def(
|
||||
inputs={"myinput1": utils.build_tensor_info(inp1),
|
||||
"myinput2": utils.build_tensor_info(inp2)},
|
||||
inputs={
|
||||
"myinput1": utils.build_tensor_info(inp1),
|
||||
"myinput2": utils.build_tensor_info(inp2)
|
||||
},
|
||||
outputs={"myoutput": utils.build_tensor_info(out)},
|
||||
method_name=signature_constants.PREDICT_METHOD_NAME)
|
||||
saved_model_builder = builder.SavedModelBuilder(input_saved_model_dir)
|
||||
@ -228,8 +229,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def next(self):
|
||||
self._data += 1
|
||||
return {"input1:0": [[[self._data]]],
|
||||
"input2:0": [[[self._data]]]}
|
||||
return {"input1:0": [[[self._data]]], "input2:0": [[[self._data]]]}
|
||||
|
||||
output_graph_def = converter.calibrate(
|
||||
fetch_names=["output:0"],
|
||||
@ -282,10 +282,13 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
importer.import_graph_def(graph_def, name="")
|
||||
with self.session(config=self._GetConfigProto()) as sess:
|
||||
for test_data in range(10):
|
||||
self.assertEqual(
|
||||
(test_data + 1.0)**2 + test_data,
|
||||
sess.run("output:0", feed_dict={"input1:0": [[[test_data]]],
|
||||
"input2:0": [[[test_data]]]}))
|
||||
self.assertEqual((test_data + 1.0)**2 + test_data,
|
||||
sess.run(
|
||||
"output:0",
|
||||
feed_dict={
|
||||
"input1:0": [[[test_data]]],
|
||||
"input2:0": [[[test_data]]]
|
||||
}))
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
def testTrtGraphConverter_BasicConversion(self):
|
||||
@ -311,7 +314,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
input_saved_model_dir,
|
||||
input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
|
||||
precision_mode=trt_convert.TrtPrecisionMode.FP32,
|
||||
max_batch_size=None):
|
||||
max_batch_size=None,
|
||||
maximum_cached_engines=2):
|
||||
return trt_convert.TrtGraphConverterV2(
|
||||
input_saved_model_dir=input_saved_model_dir,
|
||||
input_saved_model_signature_key=input_saved_model_signature_key,
|
||||
@ -319,7 +323,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
max_workspace_size_bytes=10 << 20, # Use a smaller workspace.
|
||||
precision_mode=precision_mode,
|
||||
is_dynamic_op=False if max_batch_size else True,
|
||||
maximum_cached_engines=2,
|
||||
maximum_cached_engines=maximum_cached_engines,
|
||||
max_batch_size=max_batch_size if max_batch_size else 1))
|
||||
|
||||
def _CheckTrtOps(self, concrete_func, check_fn=None):
|
||||
@ -339,14 +343,18 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(1, len(trt_op_names))
|
||||
self.assertIn("TRTEngineOp_0", trt_op_names[0])
|
||||
|
||||
def _RandomInput(self, shape, dtype=np.float32):
  """Returns two independently sampled random arrays.

  Args:
    shape: shape of each generated array, as accepted by
      `np.random.random_sample`.
    dtype: dtype each sampled array is cast to.

  Returns:
    A tuple `(first, second)` of arrays with the requested shape and dtype.
  """
  # The generator is consumed in order, so the first array is drawn before
  # the second one.
  first, second = (
      np.random.random_sample(shape).astype(dtype) for _ in range(2))
  return first, second
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testTrtGraphConverter_DynamicConversion_v2(self):
|
||||
"""Test case for trt_convert.TrtGraphConverter()."""
|
||||
if not is_tensorrt_enabled():
|
||||
return
|
||||
|
||||
np_input1 = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||
np_input2 = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||
np_input1, np_input2 = self._RandomInput([4, 1, 1])
|
||||
|
||||
# Create a model and save it.
|
||||
input_saved_model_dir = self.mkdtemp()
|
||||
@ -360,7 +368,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
converter.convert()
|
||||
|
||||
# Verify the converted GraphDef and ConcreteFunction.
|
||||
self._CheckTrtOps(converter._converted_func)
|
||||
self._CheckTrtOps(converter._converted_func) # pylint: disable=protected-access
|
||||
|
||||
# Save the converted model without any TRT engine cache.
|
||||
output_saved_model_dir = self.mkdtemp()
|
||||
@ -400,8 +408,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
# self._CheckTrtOps(root_with_trt.run.get_concrete_function())
|
||||
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
|
||||
self._CheckTrtOps(converted_signature)
|
||||
output_with_trt = converted_signature(inp1=ops.convert_to_tensor(np_input1),
|
||||
inp2=ops.convert_to_tensor(np_input2))
|
||||
output_with_trt = converted_signature(
|
||||
inp1=ops.convert_to_tensor(np_input1),
|
||||
inp2=ops.convert_to_tensor(np_input2))
|
||||
# The output of running the converted signature is a dict due to
|
||||
# compatibility reasons with V1 SavedModel signature mechanism.
|
||||
output_with_trt = output_with_trt.values()[0]
|
||||
@ -416,8 +425,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
if not is_tensorrt_enabled():
|
||||
return
|
||||
|
||||
np_input1 = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||
np_input2 = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||
np_input1, np_input2 = self._RandomInput([4, 1, 1])
|
||||
|
||||
# Create a model and save it.
|
||||
input_saved_model_dir = self.mkdtemp()
|
||||
@ -433,8 +441,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
def _CheckFn(node):
|
||||
self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
|
||||
|
||||
# Verify the converted GraphDef and ConcreteFunction.
|
||||
self._CheckTrtOps(converter._converted_func, _CheckFn)
|
||||
# Verify the converted GraphDef.
|
||||
self._CheckTrtOps(converter._converted_func, _CheckFn) # pylint: disable=protected-access
|
||||
|
||||
# Save the converted model with the statically-built engine inlined.
|
||||
output_saved_model_dir = self.mkdtemp()
|
||||
@ -450,8 +458,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
root_with_trt = load.load(output_saved_model_dir)
|
||||
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
|
||||
self._CheckTrtOps(converted_signature, _CheckFn)
|
||||
output_with_trt = converted_signature(inp1=ops.convert_to_tensor(np_input1),
|
||||
inp2=ops.convert_to_tensor(np_input2))
|
||||
output_with_trt = converted_signature(
|
||||
inp1=ops.convert_to_tensor(np_input1),
|
||||
inp2=ops.convert_to_tensor(np_input2))
|
||||
# The output of running the converted signature is a dict due to
|
||||
# compatibility reasons with V1 SavedModel signature mechanism.
|
||||
self.assertAllClose(
|
||||
@ -468,8 +477,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
if not is_tensorrt_enabled():
|
||||
return
|
||||
|
||||
np_input1 = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||
np_input2 = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||
np_input1, np_input2 = self._RandomInput([4, 1, 1])
|
||||
|
||||
# Create a model and save it.
|
||||
input_saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
|
||||
@ -480,15 +488,26 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
|
||||
# Run TRT conversion.
|
||||
converter = self._CreateConverterV2(
|
||||
input_saved_model_dir, precision_mode=trt_convert.TrtPrecisionMode.INT8)
|
||||
input_saved_model_dir,
|
||||
precision_mode=trt_convert.TrtPrecisionMode.INT8,
|
||||
maximum_cached_engines=3)
|
||||
|
||||
# Convert and perform INT8 calibration
|
||||
def input_map_fn():
|
||||
return np_input1, np_input2
|
||||
converter.convert(num_calibration_runs=1,
|
||||
calibration_input_map_fn=input_map_fn)
|
||||
input_map_fn = lambda: (np_input1, np_input2)
|
||||
converter.convert(
|
||||
num_calibration_runs=2, calibration_input_map_fn=input_map_fn)
|
||||
|
||||
# Save the converted model again with serialized engine cache.
|
||||
def _CheckFn(node):
|
||||
self.assertTrue(len(node.attr["calibration_data"].s), node.name)
|
||||
|
||||
# Verify the converted GraphDef.
|
||||
self._CheckTrtOps(converter._converted_func, _CheckFn) # pylint: disable=protected-access
|
||||
|
||||
# Build another engine with different batch size.
|
||||
converter.build(*self._RandomInput([5, 1, 1]))
|
||||
|
||||
# Save the converted model.
|
||||
# TODO(laigd): check that it should contain two engines.
|
||||
output_saved_model_dir = self.mkdtemp()
|
||||
converter.save(output_saved_model_dir)
|
||||
expected_asset_file = os.path.join(
|
||||
@ -499,15 +518,13 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
del converter
|
||||
gc.collect() # Force GC to destroy the TRT engine cache.
|
||||
|
||||
def _CheckFn(node):
|
||||
self.assertTrue(len(node.attr["calibration_data"].s), node.name)
|
||||
|
||||
# Load and verify the converted model.
|
||||
root_with_trt = load.load(output_saved_model_dir)
|
||||
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
|
||||
self._CheckTrtOps(converted_signature, _CheckFn)
|
||||
output_with_trt = converted_signature(inp1=ops.convert_to_tensor(np_input1),
|
||||
inp2=ops.convert_to_tensor(np_input2))
|
||||
output_with_trt = converted_signature(
|
||||
inp1=ops.convert_to_tensor(np_input1),
|
||||
inp2=ops.convert_to_tensor(np_input2))
|
||||
self.assertEqual(1, len(output_with_trt))
|
||||
# The output of running the converted signature is a dict due to
|
||||
# compatibility reasons with V1 SavedModel signature mechanism.
|
||||
@ -519,10 +536,11 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
|
||||
# Run with an input of different batch size. It should build a new engine
|
||||
# using calibration table.
|
||||
np_input1 = np.random.random_sample([5, 1, 1]).astype(np.float32)
|
||||
np_input2 = np.random.random_sample([5, 1, 1]).astype(np.float32)
|
||||
output_with_trt = converted_signature(inp1=ops.convert_to_tensor(np_input1),
|
||||
inp2=ops.convert_to_tensor(np_input2))
|
||||
# TODO(laigd): check that it should contain three engines.
|
||||
np_input1, np_input2 = self._RandomInput([6, 1, 1])
|
||||
converted_signature(
|
||||
inp1=ops.convert_to_tensor(np_input1),
|
||||
inp2=ops.convert_to_tensor(np_input2))
|
||||
|
||||
del root_with_trt
|
||||
gc.collect() # Force GC to destroy the TRT engine cache.
|
||||
@ -533,8 +551,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
if not is_tensorrt_enabled():
|
||||
return
|
||||
|
||||
np_input1 = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||
np_input2 = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||
np_input1, np_input2 = self._RandomInput([4, 1, 1])
|
||||
|
||||
# Create a model and save it.
|
||||
input_saved_model_dir = self.mkdtemp()
|
||||
@ -695,9 +712,12 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
self._CompareSavedModel(_Model)
|
||||
|
||||
def _TestRun(self, sess, batch_size, expect_engine_is_run=True):
|
||||
result = sess.run("output:0",
|
||||
feed_dict={"input1:0": [[[1.0]]] * batch_size,
|
||||
"input2:0": [[[1.0]]] * batch_size})
|
||||
result = sess.run(
|
||||
"output:0",
|
||||
feed_dict={
|
||||
"input1:0": [[[1.0]]] * batch_size,
|
||||
"input2:0": [[[1.0]]] * batch_size
|
||||
})
|
||||
self.assertAllEqual([[[5.0]]] * batch_size, result)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
|
Loading…
Reference in New Issue
Block a user