From 4455956ce866acfa3ae997cb1c2a396e240dcbda Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Thu, 2 Apr 2020 16:23:58 -0700 Subject: [PATCH] Make TRTEngineOp node names unique. Add a unique graph sequence number to TRTEngineOp node names to avoid name collision. Since the TRTEngineOp node names are used as the cache keys for the resource cache objects for the operation, this can avoid mapping two different TRTEngineOp nodes to the same cache objects. Fix affected tests. PiperOrigin-RevId: 304500590 Change-Id: Ibea1e71d57a8a4f16d3710cf176b4ae443aa3815 --- .../tf2tensorrt/convert/convert_graph.cc | 9 +++- .../test/tf_trt_integration_test_base.py | 35 ++++++++++++-- .../compiler/tensorrt/trt_convert_test.py | 48 ++++++++++++++++--- 3 files changed, 80 insertions(+), 12 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index c9d46251069..3e9a7954b03 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -617,6 +617,11 @@ std::pair GetDeviceAndAllocator(const ConversionParams& params, return std::make_pair(cuda_device_id, dev_allocator); } +int64 GetNextGraphSequenceNumber() { + static std::atomic<int64> graph_sequence_num; + return graph_sequence_num++; +} + // Entry function from optimization pass. Status ConvertAfterShapes(const ConversionParams& params) { // Sanity checks. 
@@ -666,10 +671,12 @@ Status ConvertAfterShapes(const ConversionParams& params) { std::vector engine_bytes_size; segment::SegmentNodesVector converted_segments; converted_segments.reserve(initial_segments.size()); + string engine_name_prefix = + StrCat("TRTEngineOp_", GetNextGraphSequenceNumber(), "_"); for (size_t t = 0; t < initial_segments.size(); t++) { auto& curr_segment = initial_segments.at(t); EngineInfo curr_engine; - curr_engine.engine_name = StrCat("TRTEngineOp_", t); + curr_engine.engine_name = StrCat(engine_name_prefix, t); Status status = GetEngineInfo(&graph, *params.graph_properties, curr_segment, node_map, reverse_topo_order, &curr_engine); diff --git a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py index 3245a100265..773061d57a7 100644 --- a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py @@ -522,6 +522,25 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): logging.info("Writing graph to %s/%s", temp_dir, graph_name) graph_io.write_graph(gdef, temp_dir, graph_name) + # Remove the graph sequence number prefix from the name only if the name has + # a prefix TRTEngineOp_n_. When expecting_prefix is true, assert such a + # prefix exists. 
+ def _RemoveGraphSequenceNumberImpl(self, name, expecting_prefix): + match = re.search(r"TRTEngineOp_\d+_", name) + has_prefix = match and name.startswith(match.group(0)) + assert (not expecting_prefix) or has_prefix + if has_prefix: + parts = name.split("_", maxsplit=2) + assert len(parts) == 3 + return parts[0] + "_" + parts[2] + return name + + def _RemoveGraphSequenceNumber(self, name): + return self._RemoveGraphSequenceNumberImpl(name, True) + + def _MayRemoveGraphSequenceNumber(self, name): + return self._RemoveGraphSequenceNumberImpl(name, False) + def _VerifyConnections(self, expected_engines, original_gdef, converted_gdef): old_to_new_node_map = { self._ToString(node.name): self._ToString(node.name) @@ -579,11 +598,14 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): # Compute the actual mapping from each node to its input nodes. actual_input_map = {} for node in converted_gdef.node: - name_str = self._ToString(node.name) + name_str = node.name + if node.op == "TRTEngineOp": + name_str = self._RemoveGraphSequenceNumber(name_str) actual_input_map[name_str] = set() input_set = actual_input_map[name_str] for inp in node.input: (prefix, node_name) = _InputName(inp) + node_name = self._MayRemoveGraphSequenceNumber(node_name) input_set.add(prefix + node_name) self.assertEqual( @@ -628,7 +650,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): self.assertIn(function_name, functions) if not IsQuantizationWithCalibration and not is_dynamic_engine: self.assertTrue(len(node.attr["serialized_segment"].s), node.name) - self.assertIn(node.name, expected_engines) + self.assertIn( + self._RemoveGraphSequenceNumber(node.name), expected_engines) self.assertEqual( self._ToBytes(run_params.precision_mode), node.attr["precision_mode"].s, node.name) @@ -662,7 +685,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): node.name for node in gdef_to_verify.node if node.op == "TRTEngineOp" ] for func in gdef_to_verify.library.function: - 
if not re.search(r"TRTEngineOp_\d+_native_segment", func.signature.name): + if not re.search(r"TRTEngineOp_\d+_\d+_native_segment", + func.signature.name): for node in func.node_def: all_op_names.append(node.name) if node.op == "TRTEngineOp": @@ -670,9 +694,12 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): # Remove the function name prefix. def _Canonicalize(names): return set(self._ToString(name.split("/")[-1]) for name in names) + # Remove the graph sequence number prefix from all the names. + def _RemoveGraphSequenceNumber(names): + return set(self._RemoveGraphSequenceNumber(name) for name in names) all_op_names = _Canonicalize(all_op_names) - trt_op_names = _Canonicalize(trt_op_names) + trt_op_names = _RemoveGraphSequenceNumber(_Canonicalize(trt_op_names)) if isinstance(expected_engines, dict): # For simplicity we don't verify the connections inside the engine in diff --git a/tensorflow/python/compiler/tensorrt/trt_convert_test.py b/tensorflow/python/compiler/tensorrt/trt_convert_test.py index fbe312fc4d6..df21e93f836 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert_test.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import gc import os +import re import tempfile from absl.testing import parameterized @@ -310,6 +311,24 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): converter.save(output_saved_model_dir=output_saved_model_dir) return output_graph_def + # Remove the graph sequence number prefix from the name only if the name has + # a prefix TRTEngineOp_n_. + def _MayRemoveGraphSequenceNumber(self, name): + prefix = re.search(r"TRTEngineOp_\d+_", name) + if prefix and name.startswith(prefix.group(0)): + parts = name.split("_", maxsplit=2) + assert len(parts) == 3 + return parts[0] + "_" + parts[2] + return name + + # Return the unique TRTEngineOp in the given graph def. 
+ def _GetUniqueTRTEngineOp(self, graph_def): + trt_engine_nodes = [ + node for node in graph_def.node if node.op == "TRTEngineOp" + ] + assert len(trt_engine_nodes) == 1 + return trt_engine_nodes[0] + def _TestTrtGraphConverter(self, device, output_saved_model_dir=None, @@ -330,7 +349,10 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): graph_defs_to_verify.append(saved_model_graph_def) for graph_def in graph_defs_to_verify: - node_name_to_op = {node.name: node.op for node in graph_def.node} + node_name_to_op = { + self._MayRemoveGraphSequenceNumber(node.name): node.op + for node in graph_def.node + } self.assertEqual( { "input1": "Placeholder", @@ -434,13 +456,13 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): trt_op_names = [] for node in graph_def.node: if node.op == "TRTEngineOp": - trt_op_names.append(node.name) + trt_op_names.append(self._MayRemoveGraphSequenceNumber(node.name)) if check_fn: check_fn(node) for func in graph_def.library.function: for node in func.node_def: if node.op == "TRTEngineOp": - trt_op_names.append(node.name) + trt_op_names.append(self._MayRemoveGraphSequenceNumber(node.name)) if check_fn: check_fn(node) self.assertEqual(1, len(trt_op_names)) @@ -473,11 +495,15 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): # Verify the converted GraphDef and ConcreteFunction. self._CheckTrtOps(converter._converted_func) # pylint: disable=protected-access + trt_engine_name = self._GetUniqueTRTEngineOp( + converter._converted_graph_def).name + # Save the converted model without any TRT engine cache. output_saved_model_dir = self.mkdtemp() converter.save(output_saved_model_dir) unexpected_asset_file = os.path.join( - output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0") + output_saved_model_dir, + "assets/trt-serialized-engine." 
+ trt_engine_name) self.assertFalse(os.path.exists(unexpected_asset_file)) # Run the converted function to populate the engine cache. @@ -490,7 +516,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): output_saved_model_dir = self.mkdtemp() converter.save(output_saved_model_dir) expected_asset_file = os.path.join( - output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0") + output_saved_model_dir, + "assets/trt-serialized-engine." + trt_engine_name) self.assertTrue(os.path.exists(expected_asset_file)) self.assertTrue(os.path.getsize(expected_asset_file)) @@ -566,6 +593,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): converter.convert(calibration_input_fn=_CalibrationInputFn) + trt_engine_name = self._GetUniqueTRTEngineOp( + converter._converted_graph_def).name + def _CheckFn(node): self.assertTrue(len(node.attr["calibration_data"].s), node.name) @@ -583,7 +613,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): output_saved_model_dir = self.mkdtemp() converter.save(output_saved_model_dir) expected_asset_file = os.path.join( - output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0") + output_saved_model_dir, + "assets/trt-serialized-engine." 
+ trt_engine_name) self.assertTrue(os.path.exists(expected_asset_file)) self.assertTrue(os.path.getsize(expected_asset_file)) @@ -635,6 +666,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): converter = self._CreateConverterV2(input_saved_model_dir) converter.convert() + trt_engine_name = self._GetUniqueTRTEngineOp( + converter._converted_graph_def).name + def _InputFn(): yield np_input1, np_input2 @@ -645,7 +679,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): def _DestroyCache(): with ops.device("GPU:0"): handle = gen_trt_ops.create_trt_resource_handle( - resource_name="TRTEngineOp_0") + resource_name=trt_engine_name) gen_resource_variable_ops.destroy_resource_op( handle, ignore_lookup_error=False)