diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index c9d46251069..3e9a7954b03 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -617,6 +617,11 @@ std::pair<int, Allocator*> GetDeviceAndAllocator(const ConversionParams& params, return std::make_pair(cuda_device_id, dev_allocator); } +int64 GetNextGraphSequenceNumber() { + static std::atomic<int64> graph_sequence_num; + return graph_sequence_num++; +} + // Entry function from optimization pass. Status ConvertAfterShapes(const ConversionParams& params) { // Sanity checks. @@ -666,10 +671,12 @@ Status ConvertAfterShapes(const ConversionParams& params) { std::vector<size_t> engine_bytes_size; segment::SegmentNodesVector converted_segments; converted_segments.reserve(initial_segments.size()); + string engine_name_prefix = + StrCat("TRTEngineOp_", GetNextGraphSequenceNumber(), "_"); for (size_t t = 0; t < initial_segments.size(); t++) { auto& curr_segment = initial_segments.at(t); EngineInfo curr_engine; - curr_engine.engine_name = StrCat("TRTEngineOp_", t); + curr_engine.engine_name = StrCat(engine_name_prefix, t); Status status = GetEngineInfo(&graph, *params.graph_properties, curr_segment, node_map, reverse_topo_order, &curr_engine); diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc index 1646749ad9c..2cfefd27a67 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h" +#include <regex> // NOLINT + #include <gmock/gmock.h> #include <gtest/gtest.h> #include "tensorflow/cc/framework/ops.h" @@ -203,15 +205,22 @@ TEST_F(ConvertAfterShapesTest, DirectlyConnectedEngines) { GraphDef output_graph_def; TF_EXPECT_OK(RunConvertAfterShape(s, &output_graph_def)); + auto remove_graph_sequence_number = [](std::string node_name) { + const std::regex pattern("TRTEngineOp_[0-9]+_"); + return std::regex_replace(node_name, pattern, "TRTEngineOp_"); + }; int num_trt_ops = 0; for (const NodeDef& node : output_graph_def.node()) { - if (node.name() == "TRTEngineOp_1") { + std::string node_name = node.name(); + if (node.op() != "TRTEngineOp") continue; + node_name = remove_graph_sequence_number(node_name); + if (node_name == "TRTEngineOp_1") { EXPECT_EQ(1, node.input_size()); EXPECT_EQ("input", node.input(0)); ++num_trt_ops; - } else if (node.name() == "TRTEngineOp_0") { + } else if (node_name == "TRTEngineOp_0") { EXPECT_EQ(2, node.input_size()); - EXPECT_EQ("TRTEngineOp_1", node.input(0)); + EXPECT_EQ("TRTEngineOp_1", remove_graph_sequence_number(node.input(0))); EXPECT_EQ("reshape2", node.input(1)); ++num_trt_ops; } diff --git a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py index 3245a100265..773061d57a7 100644 --- a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py @@ -522,6 +522,25 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): logging.info("Writing graph to %s/%s", temp_dir, graph_name) graph_io.write_graph(gdef, temp_dir, graph_name) + # Remove the graph sequence number prefix from the name only if the name has + # a prefix TRTEngineOp_n_. When expecting_prefix is true, assert such a + # prefix exists. + def _RemoveGraphSequenceNumberImpl(self, name, expecting_prefix): + match = re.search(r"TRTEngineOp_\d+_", name) + has_prefix = match and name.startswith(match.group(0)) + assert (not expecting_prefix) or has_prefix + if has_prefix: + parts = name.split("_", maxsplit=2) + assert len(parts) == 3 + return parts[0] + "_" + parts[2] + return name + + def _RemoveGraphSequenceNumber(self, name): + return self._RemoveGraphSequenceNumberImpl(name, True) + + def _MayRemoveGraphSequenceNumber(self, name): + return self._RemoveGraphSequenceNumberImpl(name, False) + def _VerifyConnections(self, expected_engines, original_gdef, converted_gdef): old_to_new_node_map = { self._ToString(node.name): self._ToString(node.name) @@ -579,11 +598,14 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): # Compute the actual mapping from each node to its input nodes. actual_input_map = {} for node in converted_gdef.node: - name_str = self._ToString(node.name) + name_str = node.name + if node.op == "TRTEngineOp": + name_str = self._RemoveGraphSequenceNumber(name_str) actual_input_map[name_str] = set() input_set = actual_input_map[name_str] for inp in node.input: (prefix, node_name) = _InputName(inp) + node_name = self._MayRemoveGraphSequenceNumber(node_name) input_set.add(prefix + node_name) self.assertEqual( @@ -628,7 +650,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): self.assertIn(function_name, functions) if not IsQuantizationWithCalibration and not is_dynamic_engine: self.assertTrue(len(node.attr["serialized_segment"].s), node.name) - self.assertIn(node.name, expected_engines) + self.assertIn( + self._RemoveGraphSequenceNumber(node.name), expected_engines) self.assertEqual( self._ToBytes(run_params.precision_mode), node.attr["precision_mode"].s, node.name) @@ -662,7 +685,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): node.name for node in gdef_to_verify.node if node.op == "TRTEngineOp" ] for func in gdef_to_verify.library.function: - if not re.search(r"TRTEngineOp_\d+_native_segment", func.signature.name): + if not re.search(r"TRTEngineOp_\d+_\d+_native_segment", + func.signature.name): for node in func.node_def: all_op_names.append(node.name) if node.op == "TRTEngineOp": @@ -670,9 +694,12 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): # Remove the function name prefix. def _Canonicalize(names): return set(self._ToString(name.split("/")[-1]) for name in names) + # Remove the graph sequence number prefix from all the names. + def _RemoveGraphSequenceNumber(names): + return set(self._RemoveGraphSequenceNumber(name) for name in names) all_op_names = _Canonicalize(all_op_names) - trt_op_names = _Canonicalize(trt_op_names) + trt_op_names = _RemoveGraphSequenceNumber(_Canonicalize(trt_op_names)) if isinstance(expected_engines, dict): # For simplicity we don't verify the connections inside the engine in diff --git a/tensorflow/python/compiler/tensorrt/trt_convert_test.py b/tensorflow/python/compiler/tensorrt/trt_convert_test.py index fbe312fc4d6..df21e93f836 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert_test.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import gc import os +import re import tempfile from absl.testing import parameterized @@ -310,6 +311,24 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): converter.save(output_saved_model_dir=output_saved_model_dir) return output_graph_def + # Remove the graph sequence number prefix from the name only if the name has + # a prefix TRTEngineOp_n_. + def _MayRemoveGraphSequenceNumber(self, name): + prefix = re.search(r"TRTEngineOp_\d+_", name) + if prefix and name.startswith(prefix.group(0)): + parts = name.split("_", maxsplit=2) + assert len(parts) == 3 + return parts[0] + "_" + parts[2] + return name + + # Return the unique TRTEngineOp in the given graph def. + def _GetUniqueTRTEngineOp(self, graph_def): + trt_engine_nodes = [ + node for node in graph_def.node if node.op == "TRTEngineOp" + ] + assert len(trt_engine_nodes) == 1 + return trt_engine_nodes[0] + def _TestTrtGraphConverter(self, device, output_saved_model_dir=None, @@ -330,7 +349,10 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): graph_defs_to_verify.append(saved_model_graph_def) for graph_def in graph_defs_to_verify: - node_name_to_op = {node.name: node.op for node in graph_def.node} + node_name_to_op = { + self._MayRemoveGraphSequenceNumber(node.name): node.op + for node in graph_def.node + } self.assertEqual( { "input1": "Placeholder", @@ -434,13 +456,13 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): trt_op_names = [] for node in graph_def.node: if node.op == "TRTEngineOp": - trt_op_names.append(node.name) + trt_op_names.append(self._MayRemoveGraphSequenceNumber(node.name)) if check_fn: check_fn(node) for func in graph_def.library.function: for node in func.node_def: if node.op == "TRTEngineOp": - trt_op_names.append(node.name) + trt_op_names.append(self._MayRemoveGraphSequenceNumber(node.name)) if check_fn: check_fn(node) self.assertEqual(1, len(trt_op_names)) @@ -473,11 +495,15 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): # Verify the converted GraphDef and ConcreteFunction. self._CheckTrtOps(converter._converted_func) # pylint: disable=protected-access + trt_engine_name = self._GetUniqueTRTEngineOp( + converter._converted_graph_def).name + # Save the converted model without any TRT engine cache. output_saved_model_dir = self.mkdtemp() converter.save(output_saved_model_dir) unexpected_asset_file = os.path.join( - output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0") + output_saved_model_dir, + "assets/trt-serialized-engine." + trt_engine_name) self.assertFalse(os.path.exists(unexpected_asset_file)) # Run the converted function to populate the engine cache. @@ -490,7 +516,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): output_saved_model_dir = self.mkdtemp() converter.save(output_saved_model_dir) expected_asset_file = os.path.join( - output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0") + output_saved_model_dir, + "assets/trt-serialized-engine." + trt_engine_name) self.assertTrue(os.path.exists(expected_asset_file)) self.assertTrue(os.path.getsize(expected_asset_file)) @@ -566,6 +593,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): converter.convert(calibration_input_fn=_CalibrationInputFn) + trt_engine_name = self._GetUniqueTRTEngineOp( + converter._converted_graph_def).name + def _CheckFn(node): self.assertTrue(len(node.attr["calibration_data"].s), node.name) @@ -583,7 +613,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): output_saved_model_dir = self.mkdtemp() converter.save(output_saved_model_dir) expected_asset_file = os.path.join( - output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0") + output_saved_model_dir, + "assets/trt-serialized-engine." + trt_engine_name) self.assertTrue(os.path.exists(expected_asset_file)) self.assertTrue(os.path.getsize(expected_asset_file)) @@ -635,6 +666,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): converter = self._CreateConverterV2(input_saved_model_dir) converter.convert() + trt_engine_name = self._GetUniqueTRTEngineOp( + converter._converted_graph_def).name + def _InputFn(): yield np_input1, np_input2 @@ -645,7 +679,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): def _DestroyCache(): with ops.device("GPU:0"): handle = gen_trt_ops.create_trt_resource_handle( - resource_name="TRTEngineOp_0") + resource_name=trt_engine_name) gen_resource_variable_ops.destroy_resource_op( handle, ignore_lookup_error=False)