From fee3d860717075742c2d30f160beb598c57f2a47 Mon Sep 17 00:00:00 2001 From: Guangda Lai <31743510+aaroey@users.noreply.github.com> Date: Sat, 17 Aug 2019 22:37:05 -0700 Subject: [PATCH] Terminate calibration in TrtGraphConverterV2.convert() and improve the test to cover that. --- .../tf2tensorrt/convert/convert_nodes.cc | 6 +- .../python/compiler/tensorrt/trt_convert.py | 65 +++++----- .../compiler/tensorrt/trt_convert_test.py | 112 +++++++++++------- 3 files changed, 106 insertions(+), 77 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 746fdf17d22..13ac187b566 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -5191,7 +5191,11 @@ Status ConvertGraphDefToEngine( } // Build the network - VLOG(1) << "Starting engine conversion "; + if (VLOG_IS_ON(1)) { + string mode_str; + TF_RETURN_IF_ERROR(TrtPrecisionModeToName(precision_mode, &mode_str)); + VLOG(1) << "Starting engine conversion, precision mode: " << mode_str; + } Converter converter(trt_network.get(), precision_mode, use_calibration); std::vector output_tensors; // Graph nodes are already topologically sorted during construction diff --git a/tensorflow/python/compiler/tensorrt/trt_convert.py b/tensorflow/python/compiler/tensorrt/trt_convert.py index ef629a9bf2b..a4e5d329b61 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert.py @@ -57,9 +57,8 @@ from tensorflow.python.util.lazy_loader import LazyLoader # Lazily load the op, since it's not available in cpu-only builds. Importing # this at top will cause tests that imports TF-TRT fail when they're built # and run without CUDA/GPU. -gen_trt_ops = LazyLoader( - "gen_trt_ops", globals(), - "tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops") +gen_trt_ops = LazyLoader("gen_trt_ops", globals(), + "tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops") # Register TRT ops in python, so that when users import this module they can # execute a TRT-converted graph without calling any of the methods in this @@ -898,6 +897,16 @@ class TrtGraphConverterV2(object): return tf_optimizer.OptimizeGraph( grappler_session_config, meta_graph_def, graph_id=b"tf_graph") + def _for_each_trt_node(self, graph_def, fn): + """Helper method to manipulate all TRTEngineOps in a GraphDef.""" + for node in graph_def.node: + if node.op == _TRT_ENGINE_OP_NAME: + fn(node) + for func in graph_def.library.function: + for node in func.node_def: + if node.op == _TRT_ENGINE_OP_NAME: + fn(node) + # TODO(laigd): provide a utility function to optimize a ConcreteFunction and # use it here (b/124792963). def convert(self, num_calibration_runs=None, calibration_input_map_fn=None): @@ -907,11 +916,8 @@ class TrtGraphConverterV2(object): num_calibration_runs: number of runs of the graph during calibration. calibration_input_map_fn: a function that returns inputs for the converted tf_function to be calibrated. - Example: - ``` - def input_map_fn(): - return input1, input2, input3 - ``` + Example: ``` + def input_map_fn(): return input1, input2, input3 ``` Raises: ValueError: if the input combination is invalid. @@ -962,6 +968,24 @@ class TrtGraphConverterV2(object): self._converted_func( *map(ops.convert_to_tensor, calibration_input_map_fn())) + def _save_calibration_table(node): + calibration_table = gen_trt_ops.get_calibration_data_op( + _get_canonical_engine_name(node.name)) + node.attr["calibration_data"].s = calibration_table.numpy() + + self._for_each_trt_node(self._converted_graph_def, + _save_calibration_table) + + # Rebuild the function since calibration has changed the graph. + calibrated_func = wrap_function.function_from_graph_def( + self._converted_graph_def, + [tensor.name for tensor in self._converted_func.inputs], + [tensor.name for tensor in self._converted_func.outputs]) + calibrated_func.graph.structured_outputs = nest.pack_sequence_as( + self._converted_func.graph.structured_outputs, + calibrated_func.graph.structured_outputs) + self._converted_func = calibrated_func + self._converted = True def build(self, *args, **kwargs): @@ -1002,10 +1026,6 @@ class TrtGraphConverterV2(object): filename = os.path.join(engine_asset_dir, "trt-serialized-engine." + canonical_engine_name) - if self._need_calibration: - calibration_table = gen_trt_ops.get_calibration_data_op( - canonical_engine_name) - node.attr["calibration_data"].s = calibration_table.numpy() try: gen_trt_ops.serialize_trt_resource( @@ -1022,30 +1042,15 @@ class TrtGraphConverterV2(object): canonical_engine_name, filename, self._conversion_params.maximum_cached_engines) - for node in self._converted_graph_def.node: - if node.op == _TRT_ENGINE_OP_NAME: - _serialize_and_track_engine(node) - for func in self._converted_graph_def.library.function: - for node in func.node_def: - if node.op == _TRT_ENGINE_OP_NAME: - _serialize_and_track_engine(node) - + self._for_each_trt_node(self._converted_graph_def, + _serialize_and_track_engine) self._saved_model.trt_engine_resources = resource_map - # Rebuild the function since calibration may change the graph. - func_to_save = wrap_function.function_from_graph_def( - self._converted_graph_def, - [tensor.name for tensor in self._converted_func.inputs], - [tensor.name for tensor in self._converted_func.outputs]) - func_to_save.graph.structured_outputs = nest.pack_sequence_as( - self._converted_func.graph.structured_outputs, - func_to_save.graph.structured_outputs) - # Rewrite the signature map using the optimized ConcreteFunction. signatures = { key: value for key, value in self._saved_model.signatures.items() } - signatures[self._input_saved_model_signature_key] = func_to_save + signatures[self._input_saved_model_signature_key] = self._converted_func save.save(self._saved_model, output_saved_model_dir, signatures) diff --git a/tensorflow/python/compiler/tensorrt/trt_convert_test.py b/tensorflow/python/compiler/tensorrt/trt_convert_test.py index a1a55acb19d..4319d3ee4b9 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert_test.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert_test.py @@ -57,9 +57,8 @@ from tensorflow.python.util.lazy_loader import LazyLoader _SAVED_MODEL_SIGNATURE_KEY = "mypredict" -gen_trt_ops = LazyLoader( - "gen_trt_ops", globals(), - "tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops") +gen_trt_ops = LazyLoader("gen_trt_ops", globals(), + "tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops") class TrtConvertTest(test_util.TensorFlowTestCase): @@ -183,8 +182,10 @@ class TrtConvertTest(test_util.TensorFlowTestCase): """Write the saved model as an input for testing.""" g, var, inp1, inp2, out = self._GetGraphForV1() signature_def = signature_def_utils.build_signature_def( - inputs={"myinput1": utils.build_tensor_info(inp1), - "myinput2": utils.build_tensor_info(inp2)}, + inputs={ + "myinput1": utils.build_tensor_info(inp1), + "myinput2": utils.build_tensor_info(inp2) + }, outputs={"myoutput": utils.build_tensor_info(out)}, method_name=signature_constants.PREDICT_METHOD_NAME) saved_model_builder = builder.SavedModelBuilder(input_saved_model_dir) @@ -228,8 +229,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase): def next(self): self._data += 1 - return {"input1:0": [[[self._data]]], - "input2:0": [[[self._data]]]} + return {"input1:0": [[[self._data]]], "input2:0": [[[self._data]]]} output_graph_def = converter.calibrate( fetch_names=["output:0"], @@ -282,10 +282,13 @@ class TrtConvertTest(test_util.TensorFlowTestCase): importer.import_graph_def(graph_def, name="") with self.session(config=self._GetConfigProto()) as sess: for test_data in range(10): - self.assertEqual( - (test_data + 1.0)**2 + test_data, - sess.run("output:0", feed_dict={"input1:0": [[[test_data]]], - "input2:0": [[[test_data]]]})) + self.assertEqual((test_data + 1.0)**2 + test_data, + sess.run( + "output:0", + feed_dict={ + "input1:0": [[[test_data]]], + "input2:0": [[[test_data]]] + })) @test_util.deprecated_graph_mode_only def testTrtGraphConverter_BasicConversion(self): @@ -311,7 +314,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase): input_saved_model_dir, input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY, precision_mode=trt_convert.TrtPrecisionMode.FP32, - max_batch_size=None): + max_batch_size=None, + maximum_cached_engines=2): return trt_convert.TrtGraphConverterV2( input_saved_model_dir=input_saved_model_dir, input_saved_model_signature_key=input_saved_model_signature_key, @@ -319,7 +323,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase): max_workspace_size_bytes=10 << 20, # Use a smaller workspace. precision_mode=precision_mode, is_dynamic_op=False if max_batch_size else True, - maximum_cached_engines=2, + maximum_cached_engines=maximum_cached_engines, max_batch_size=max_batch_size if max_batch_size else 1)) def _CheckTrtOps(self, concrete_func, check_fn=None): @@ -339,14 +343,18 @@ class TrtConvertTest(test_util.TensorFlowTestCase): self.assertEqual(1, len(trt_op_names)) self.assertIn("TRTEngineOp_0", trt_op_names[0]) + def _RandomInput(self, shape, dtype=np.float32): + inp1 = np.random.random_sample(shape).astype(dtype) + inp2 = np.random.random_sample(shape).astype(dtype) + return inp1, inp2 + @test_util.run_v2_only def testTrtGraphConverter_DynamicConversion_v2(self): """Test case for trt_convert.TrtGraphConverter().""" if not is_tensorrt_enabled(): return - np_input1 = np.random.random_sample([4, 1, 1]).astype(np.float32) - np_input2 = np.random.random_sample([4, 1, 1]).astype(np.float32) + np_input1, np_input2 = self._RandomInput([4, 1, 1]) # Create a model and save it. input_saved_model_dir = self.mkdtemp() @@ -360,7 +368,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase): converter.convert() # Verify the converted GraphDef and ConcreteFunction. - self._CheckTrtOps(converter._converted_func) + self._CheckTrtOps(converter._converted_func) # pylint: disable=protected-access # Save the converted model without any TRT engine cache. output_saved_model_dir = self.mkdtemp() @@ -400,8 +408,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase): # self._CheckTrtOps(root_with_trt.run.get_concrete_function()) converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY] self._CheckTrtOps(converted_signature) - output_with_trt = converted_signature(inp1=ops.convert_to_tensor(np_input1), - inp2=ops.convert_to_tensor(np_input2)) + output_with_trt = converted_signature( + inp1=ops.convert_to_tensor(np_input1), + inp2=ops.convert_to_tensor(np_input2)) # The output of running the converted signature is a dict due to # compatibility reasons with V1 SavedModel signature mechanism. output_with_trt = output_with_trt.values()[0] @@ -416,8 +425,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase): if not is_tensorrt_enabled(): return - np_input1 = np.random.random_sample([4, 1, 1]).astype(np.float32) - np_input2 = np.random.random_sample([4, 1, 1]).astype(np.float32) + np_input1, np_input2 = self._RandomInput([4, 1, 1]) # Create a model and save it. input_saved_model_dir = self.mkdtemp() @@ -433,8 +441,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase): def _CheckFn(node): self.assertTrue(len(node.attr["serialized_segment"].s), node.name) - # Verify the converted GraphDef and ConcreteFunction. - self._CheckTrtOps(converter._converted_func, _CheckFn) + # Verify the converted GraphDef. + self._CheckTrtOps(converter._converted_func, _CheckFn) # pylint: disable=protected-access # Save the converted model with the statically-built engine inlined. output_saved_model_dir = self.mkdtemp() @@ -450,8 +458,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase): root_with_trt = load.load(output_saved_model_dir) converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY] self._CheckTrtOps(converted_signature, _CheckFn) - output_with_trt = converted_signature(inp1=ops.convert_to_tensor(np_input1), - inp2=ops.convert_to_tensor(np_input2)) + output_with_trt = converted_signature( + inp1=ops.convert_to_tensor(np_input1), + inp2=ops.convert_to_tensor(np_input2)) # The output of running the converted signature is a dict due to # compatibility reasons with V1 SavedModel signature mechanism. self.assertAllClose( @@ -468,8 +477,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase): if not is_tensorrt_enabled(): return - np_input1 = np.random.random_sample([4, 1, 1]).astype(np.float32) - np_input2 = np.random.random_sample([4, 1, 1]).astype(np.float32) + np_input1, np_input2 = self._RandomInput([4, 1, 1]) # Create a model and save it. input_saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir()) @@ -480,15 +488,26 @@ class TrtConvertTest(test_util.TensorFlowTestCase): # Run TRT conversion. converter = self._CreateConverterV2( - input_saved_model_dir, precision_mode=trt_convert.TrtPrecisionMode.INT8) + input_saved_model_dir, + precision_mode=trt_convert.TrtPrecisionMode.INT8, + maximum_cached_engines=3) # Convert and perform INT8 calibration - def input_map_fn(): - return np_input1, np_input2 - converter.convert(num_calibration_runs=1, - calibration_input_map_fn=input_map_fn) + input_map_fn = lambda: (np_input1, np_input2) + converter.convert( + num_calibration_runs=2, calibration_input_map_fn=input_map_fn) - # Save the converted model again with serialized engine cache. + def _CheckFn(node): + self.assertTrue(len(node.attr["calibration_data"].s), node.name) + + # Verify the converted GraphDef. + self._CheckTrtOps(converter._converted_func, _CheckFn) # pylint: disable=protected-access + + # Build another engine with different batch size. + converter.build(*self._RandomInput([5, 1, 1])) + + # Save the converted model. + # TODO(laigd): check that it should contain two engines. output_saved_model_dir = self.mkdtemp() converter.save(output_saved_model_dir) expected_asset_file = os.path.join( @@ -499,15 +518,13 @@ class TrtConvertTest(test_util.TensorFlowTestCase): del converter gc.collect() # Force GC to destroy the TRT engine cache. - def _CheckFn(node): - self.assertTrue(len(node.attr["calibration_data"].s), node.name) - # Load and verify the converted model. root_with_trt = load.load(output_saved_model_dir) converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY] self._CheckTrtOps(converted_signature, _CheckFn) - output_with_trt = converted_signature(inp1=ops.convert_to_tensor(np_input1), - inp2=ops.convert_to_tensor(np_input2)) + output_with_trt = converted_signature( + inp1=ops.convert_to_tensor(np_input1), + inp2=ops.convert_to_tensor(np_input2)) self.assertEqual(1, len(output_with_trt)) # The output of running the converted signature is a dict due to # compatibility reasons with V1 SavedModel signature mechanism. @@ -519,10 +536,11 @@ class TrtConvertTest(test_util.TensorFlowTestCase): # Run with an input of different batch size. It should build a new engine # using calibration table. - np_input1 = np.random.random_sample([5, 1, 1]).astype(np.float32) - np_input2 = np.random.random_sample([5, 1, 1]).astype(np.float32) - output_with_trt = converted_signature(inp1=ops.convert_to_tensor(np_input1), - inp2=ops.convert_to_tensor(np_input2)) + # TODO(laigd): check that it should contain three engines. + np_input1, np_input2 = self._RandomInput([6, 1, 1]) + converted_signature( + inp1=ops.convert_to_tensor(np_input1), + inp2=ops.convert_to_tensor(np_input2)) del root_with_trt gc.collect() # Force GC to destroy the TRT engine cache. @@ -533,8 +551,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase): if not is_tensorrt_enabled(): return - np_input1 = np.random.random_sample([4, 1, 1]).astype(np.float32) - np_input2 = np.random.random_sample([4, 1, 1]).astype(np.float32) + np_input1, np_input2 = self._RandomInput([4, 1, 1]) # Create a model and save it. input_saved_model_dir = self.mkdtemp() @@ -695,9 +712,12 @@ class TrtConvertTest(test_util.TensorFlowTestCase): self._CompareSavedModel(_Model) def _TestRun(self, sess, batch_size, expect_engine_is_run=True): - result = sess.run("output:0", - feed_dict={"input1:0": [[[1.0]]] * batch_size, - "input2:0": [[[1.0]]] * batch_size}) + result = sess.run( + "output:0", + feed_dict={ + "input1:0": [[[1.0]]] * batch_size, + "input2:0": [[[1.0]]] * batch_size + }) self.assertAllEqual([[[5.0]]] * batch_size, result) @test_util.deprecated_graph_mode_only