From 21de77485f9230eff364a4f0c8e0a55857965223 Mon Sep 17 00:00:00 2001
From: Bixia Zheng <bixia@google.com>
Date: Fri, 3 Apr 2020 17:25:57 -0700
Subject: [PATCH] Rollback changelist 304624399 with a fix to the failing test.

Remove the graph sequence number from TRTEngineOp node names before comparing
the node names with the expected names.

PiperOrigin-RevId: 304723970
Change-Id: If9c5508ad0655f31283945e7f3cf689a58f85431
---
 .../tf2tensorrt/convert/convert_graph.cc      |  9 +++-
 .../tf2tensorrt/convert/convert_graph_test.cc | 15 ++++--
 .../test/tf_trt_integration_test_base.py      | 35 ++++++++++++--
 .../compiler/tensorrt/trt_convert_test.py     | 48 ++++++++++++++++---
 4 files changed, 92 insertions(+), 15 deletions(-)

diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc
index c9d46251069..3e9a7954b03 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc
@@ -617,6 +617,11 @@ std::pair<int, Allocator*> GetDeviceAndAllocator(const ConversionParams& params,
   return std::make_pair(cuda_device_id, dev_allocator);
 }
 
+int64 GetNextGraphSequenceNumber() {
+  static std::atomic<int64> graph_sequence_num{0};  // Explicit start at 0.
+  return graph_sequence_num++;  // Atomic post-increment: unique per call.
+}
+
 // Entry function from optimization pass.
 Status ConvertAfterShapes(const ConversionParams& params) {
   // Sanity checks.
@@ -666,10 +671,12 @@ Status ConvertAfterShapes(const ConversionParams& params) {
   std::vector<size_t> engine_bytes_size;
   segment::SegmentNodesVector converted_segments;
   converted_segments.reserve(initial_segments.size());
+  string engine_name_prefix =
+      StrCat("TRTEngineOp_", GetNextGraphSequenceNumber(), "_");
   for (size_t t = 0; t < initial_segments.size(); t++) {
     auto& curr_segment = initial_segments.at(t);
     EngineInfo curr_engine;
-    curr_engine.engine_name = StrCat("TRTEngineOp_", t);
+    curr_engine.engine_name = StrCat(engine_name_prefix, t);
     Status status =
         GetEngineInfo(&graph, *params.graph_properties, curr_segment, node_map,
                       reverse_topo_order, &curr_engine);
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc
index 1646749ad9c..2cfefd27a67 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc
@@ -15,6 +15,8 @@ limitations under the License.
 
 #include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"
 
+#include <regex>  // NOLINT
+
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "tensorflow/cc/framework/ops.h"
@@ -203,15 +205,22 @@ TEST_F(ConvertAfterShapesTest, DirectlyConnectedEngines) {
   GraphDef output_graph_def;
   TF_EXPECT_OK(RunConvertAfterShape(s, &output_graph_def));
 
+  auto remove_graph_sequence_number = [](const std::string& node_name) {
+    static const std::regex pattern("TRTEngineOp_[0-9]+_");  // Compiled once.
+    return std::regex_replace(node_name, pattern, "TRTEngineOp_");
+  };
   int num_trt_ops = 0;
   for (const NodeDef& node : output_graph_def.node()) {
-    if (node.name() == "TRTEngineOp_1") {
+    std::string node_name = node.name();
+    if (node.op() != "TRTEngineOp") continue;
+    node_name = remove_graph_sequence_number(node_name);
+    if (node_name == "TRTEngineOp_1") {
       EXPECT_EQ(1, node.input_size());
       EXPECT_EQ("input", node.input(0));
       ++num_trt_ops;
-    } else if (node.name() == "TRTEngineOp_0") {
+    } else if (node_name == "TRTEngineOp_0") {
       EXPECT_EQ(2, node.input_size());
-      EXPECT_EQ("TRTEngineOp_1", node.input(0));
+      EXPECT_EQ("TRTEngineOp_1", remove_graph_sequence_number(node.input(0)));
       EXPECT_EQ("reshape2", node.input(1));
       ++num_trt_ops;
     }
diff --git a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py
index 3245a100265..773061d57a7 100644
--- a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py
+++ b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py
@@ -522,6 +522,25 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
     logging.info("Writing graph to %s/%s", temp_dir, graph_name)
     graph_io.write_graph(gdef, temp_dir, graph_name)
 
+  # Strip the graph sequence number from a name of the form
+  # TRTEngineOp_<n>_<rest>, giving TRTEngineOp_<rest>. A name without that
+  # leading prefix is returned unchanged. When expecting_prefix is true,
+  # assert that the prefix exists.
+  def _RemoveGraphSequenceNumberImpl(self, name, expecting_prefix):
+    # re.match anchors at the start of name, so a TRTEngineOp_<n>_ that only
+    # appears later in the name is not treated as a prefix.
+    match = re.match(r"TRTEngineOp_\d+_", name)
+    assert match or not expecting_prefix, name
+    if match:
+      return "TRTEngineOp_" + name[match.end():]
+    return name
+
+  def _RemoveGraphSequenceNumber(self, name):
+    return self._RemoveGraphSequenceNumberImpl(name, expecting_prefix=True)
+
+  def _MayRemoveGraphSequenceNumber(self, name):
+    return self._RemoveGraphSequenceNumberImpl(name, expecting_prefix=False)
+
   def _VerifyConnections(self, expected_engines, original_gdef, converted_gdef):
     old_to_new_node_map = {
         self._ToString(node.name): self._ToString(node.name)
@@ -579,11 +598,14 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
     # Compute the actual mapping from each node to its input nodes.
     actual_input_map = {}
     for node in converted_gdef.node:
-      name_str = self._ToString(node.name)
+      name_str = node.name
+      if node.op == "TRTEngineOp":
+        name_str = self._RemoveGraphSequenceNumber(name_str)
       actual_input_map[name_str] = set()
       input_set = actual_input_map[name_str]
       for inp in node.input:
         (prefix, node_name) = _InputName(inp)
+        node_name = self._MayRemoveGraphSequenceNumber(node_name)
         input_set.add(prefix + node_name)
 
     self.assertEqual(
@@ -628,7 +650,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
         self.assertIn(function_name, functions)
         if not IsQuantizationWithCalibration and not is_dynamic_engine:
           self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
-        self.assertIn(node.name, expected_engines)
+        self.assertIn(
+            self._RemoveGraphSequenceNumber(node.name), expected_engines)
         self.assertEqual(
             self._ToBytes(run_params.precision_mode),
             node.attr["precision_mode"].s, node.name)
@@ -662,7 +685,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
         node.name for node in gdef_to_verify.node if node.op == "TRTEngineOp"
     ]
     for func in gdef_to_verify.library.function:
-      if not re.search(r"TRTEngineOp_\d+_native_segment", func.signature.name):
+      if not re.search(r"TRTEngineOp_\d+_\d+_native_segment",
+                       func.signature.name):
         for node in func.node_def:
           all_op_names.append(node.name)
           if node.op == "TRTEngineOp":
@@ -670,9 +694,12 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
     # Remove the function name prefix.
     def _Canonicalize(names):
       return set(self._ToString(name.split("/")[-1]) for name in names)
+    # Remove the graph sequence number prefix from all the names.
+    def _RemoveGraphSequenceNumber(names):
+      return set(self._RemoveGraphSequenceNumber(name) for name in names)
 
     all_op_names = _Canonicalize(all_op_names)
-    trt_op_names = _Canonicalize(trt_op_names)
+    trt_op_names = _RemoveGraphSequenceNumber(_Canonicalize(trt_op_names))
 
     if isinstance(expected_engines, dict):
       # For simplicity we don't verify the connections inside the engine in
diff --git a/tensorflow/python/compiler/tensorrt/trt_convert_test.py b/tensorflow/python/compiler/tensorrt/trt_convert_test.py
index fbe312fc4d6..df21e93f836 100644
--- a/tensorflow/python/compiler/tensorrt/trt_convert_test.py
+++ b/tensorflow/python/compiler/tensorrt/trt_convert_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import gc
 import os
+import re
 import tempfile
 
 from absl.testing import parameterized
@@ -310,6 +311,24 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       converter.save(output_saved_model_dir=output_saved_model_dir)
     return output_graph_def
 
+  # Strip the graph sequence number from a name of the form
+  # TRTEngineOp_<n>_<rest>, giving TRTEngineOp_<rest>. A name without that
+  # leading prefix is returned unchanged.
+  def _MayRemoveGraphSequenceNumber(self, name):
+    # re.match anchors at the start of name, so only a leading prefix counts.
+    match = re.match(r"TRTEngineOp_\d+_", name)
+    if match:
+      return "TRTEngineOp_" + name[match.end():]
+    return name
+
+  # Return the only TRTEngineOp node in graph_def; fail if there is not one.
+  def _GetUniqueTRTEngineOp(self, graph_def):
+    trt_engine_nodes = [
+        node for node in graph_def.node if node.op == "TRTEngineOp"
+    ]
+    self.assertEqual(1, len(trt_engine_nodes))
+    return trt_engine_nodes[0]
+
   def _TestTrtGraphConverter(self,
                              device,
                              output_saved_model_dir=None,
@@ -330,7 +349,10 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       graph_defs_to_verify.append(saved_model_graph_def)
 
     for graph_def in graph_defs_to_verify:
-      node_name_to_op = {node.name: node.op for node in graph_def.node}
+      node_name_to_op = {
+          self._MayRemoveGraphSequenceNumber(node.name): node.op
+          for node in graph_def.node
+      }
       self.assertEqual(
           {
               "input1": "Placeholder",
@@ -434,13 +456,13 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     trt_op_names = []
     for node in graph_def.node:
       if node.op == "TRTEngineOp":
-        trt_op_names.append(node.name)
+        trt_op_names.append(self._MayRemoveGraphSequenceNumber(node.name))
         if check_fn:
           check_fn(node)
     for func in graph_def.library.function:
       for node in func.node_def:
         if node.op == "TRTEngineOp":
-          trt_op_names.append(node.name)
+          trt_op_names.append(self._MayRemoveGraphSequenceNumber(node.name))
           if check_fn:
             check_fn(node)
     self.assertEqual(1, len(trt_op_names))
@@ -473,11 +495,15 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     # Verify the converted GraphDef and ConcreteFunction.
     self._CheckTrtOps(converter._converted_func)  # pylint: disable=protected-access
 
+    trt_engine_name = self._GetUniqueTRTEngineOp(
+        converter._converted_graph_def).name
+
     # Save the converted model without any TRT engine cache.
     output_saved_model_dir = self.mkdtemp()
     converter.save(output_saved_model_dir)
     unexpected_asset_file = os.path.join(
-        output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0")
+        output_saved_model_dir,
+        "assets/trt-serialized-engine." + trt_engine_name)
     self.assertFalse(os.path.exists(unexpected_asset_file))
 
     # Run the converted function to populate the engine cache.
@@ -490,7 +516,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     output_saved_model_dir = self.mkdtemp()
     converter.save(output_saved_model_dir)
     expected_asset_file = os.path.join(
-        output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0")
+        output_saved_model_dir,
+        "assets/trt-serialized-engine." + trt_engine_name)
     self.assertTrue(os.path.exists(expected_asset_file))
     self.assertTrue(os.path.getsize(expected_asset_file))
 
@@ -566,6 +593,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
 
     converter.convert(calibration_input_fn=_CalibrationInputFn)
 
+    trt_engine_name = self._GetUniqueTRTEngineOp(
+        converter._converted_graph_def).name
+
     def _CheckFn(node):
       self.assertTrue(len(node.attr["calibration_data"].s), node.name)
 
@@ -583,7 +613,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     output_saved_model_dir = self.mkdtemp()
     converter.save(output_saved_model_dir)
     expected_asset_file = os.path.join(
-        output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0")
+        output_saved_model_dir,
+        "assets/trt-serialized-engine." + trt_engine_name)
     self.assertTrue(os.path.exists(expected_asset_file))
     self.assertTrue(os.path.getsize(expected_asset_file))
 
@@ -635,6 +666,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     converter = self._CreateConverterV2(input_saved_model_dir)
     converter.convert()
 
+    trt_engine_name = self._GetUniqueTRTEngineOp(
+        converter._converted_graph_def).name
+
     def _InputFn():
       yield np_input1, np_input2
 
@@ -645,7 +679,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     def _DestroyCache():
       with ops.device("GPU:0"):
         handle = gen_trt_ops.create_trt_resource_handle(
-            resource_name="TRTEngineOp_0")
+            resource_name=trt_engine_name)
         gen_resource_variable_ops.destroy_resource_op(
             handle, ignore_lookup_error=False)