Rollback changelist 304624399 with a fix to the failing test.

Remove the graph sequence number from TRTEngineOp node names before comparing
the node names with the expected names.

PiperOrigin-RevId: 304723970
Change-Id: If9c5508ad0655f31283945e7f3cf689a58f85431
This commit is contained in:
Bixia Zheng 2020-04-03 17:25:57 -07:00 committed by TensorFlower Gardener
parent fc2d7fdacb
commit 21de77485f
4 changed files with 92 additions and 15 deletions

View File

@ -617,6 +617,11 @@ std::pair<int, Allocator*> GetDeviceAndAllocator(const ConversionParams& params,
return std::make_pair(cuda_device_id, dev_allocator);
}
int64 GetNextGraphSequenceNumber() {
static std::atomic<int64> graph_sequence_num;
return graph_sequence_num++;
}
// Entry function from optimization pass.
Status ConvertAfterShapes(const ConversionParams& params) {
// Sanity checks.
@ -666,10 +671,12 @@ Status ConvertAfterShapes(const ConversionParams& params) {
std::vector<size_t> engine_bytes_size;
segment::SegmentNodesVector converted_segments;
converted_segments.reserve(initial_segments.size());
string engine_name_prefix =
StrCat("TRTEngineOp_", GetNextGraphSequenceNumber(), "_");
for (size_t t = 0; t < initial_segments.size(); t++) {
auto& curr_segment = initial_segments.at(t);
EngineInfo curr_engine;
curr_engine.engine_name = StrCat("TRTEngineOp_", t);
curr_engine.engine_name = StrCat(engine_name_prefix, t);
Status status =
GetEngineInfo(&graph, *params.graph_properties, curr_segment, node_map,
reverse_topo_order, &curr_engine);

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"
#include <regex> // NOLINT
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/cc/framework/ops.h"
@ -203,15 +205,22 @@ TEST_F(ConvertAfterShapesTest, DirectlyConnectedEngines) {
GraphDef output_graph_def;
TF_EXPECT_OK(RunConvertAfterShape(s, &output_graph_def));
auto remove_graph_sequence_number = [](std::string node_name) {
const std::regex pattern("TRTEngineOp_[0-9]+_");
return std::regex_replace(node_name, pattern, "TRTEngineOp_");
};
int num_trt_ops = 0;
for (const NodeDef& node : output_graph_def.node()) {
if (node.name() == "TRTEngineOp_1") {
std::string node_name = node.name();
if (node.op() != "TRTEngineOp") continue;
node_name = remove_graph_sequence_number(node_name);
if (node_name == "TRTEngineOp_1") {
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("input", node.input(0));
++num_trt_ops;
} else if (node.name() == "TRTEngineOp_0") {
} else if (node_name == "TRTEngineOp_0") {
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("TRTEngineOp_1", node.input(0));
EXPECT_EQ("TRTEngineOp_1", remove_graph_sequence_number(node.input(0)));
EXPECT_EQ("reshape2", node.input(1));
++num_trt_ops;
}

View File

@ -522,6 +522,25 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
logging.info("Writing graph to %s/%s", temp_dir, graph_name)
graph_io.write_graph(gdef, temp_dir, graph_name)
# Remove the graph sequence number prefix from the name only if the name has
# a prefix TRTEngineOp_n_. When expecting_prefix is true, assert such a
# prefix exists.
def _RemoveGraphSequenceNumberImpl(self, name, expecting_prefix):
match = re.search(r"TRTEngineOp_\d+_", name)
has_prefix = match and name.startswith(match.group(0))
assert (not expecting_prefix) or has_prefix
if has_prefix:
parts = name.split("_", maxsplit=2)
assert len(parts) == 3
return parts[0] + "_" + parts[2]
return name
def _RemoveGraphSequenceNumber(self, name):
return self._RemoveGraphSequenceNumberImpl(name, True)
def _MayRemoveGraphSequenceNumber(self, name):
return self._RemoveGraphSequenceNumberImpl(name, False)
def _VerifyConnections(self, expected_engines, original_gdef, converted_gdef):
old_to_new_node_map = {
self._ToString(node.name): self._ToString(node.name)
@ -579,11 +598,14 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
# Compute the actual mapping from each node to its input nodes.
actual_input_map = {}
for node in converted_gdef.node:
name_str = self._ToString(node.name)
name_str = node.name
if node.op == "TRTEngineOp":
name_str = self._RemoveGraphSequenceNumber(name_str)
actual_input_map[name_str] = set()
input_set = actual_input_map[name_str]
for inp in node.input:
(prefix, node_name) = _InputName(inp)
node_name = self._MayRemoveGraphSequenceNumber(node_name)
input_set.add(prefix + node_name)
self.assertEqual(
@ -628,7 +650,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
self.assertIn(function_name, functions)
if not IsQuantizationWithCalibration and not is_dynamic_engine:
self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
self.assertIn(node.name, expected_engines)
self.assertIn(
self._RemoveGraphSequenceNumber(node.name), expected_engines)
self.assertEqual(
self._ToBytes(run_params.precision_mode),
node.attr["precision_mode"].s, node.name)
@ -662,7 +685,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
node.name for node in gdef_to_verify.node if node.op == "TRTEngineOp"
]
for func in gdef_to_verify.library.function:
if not re.search(r"TRTEngineOp_\d+_native_segment", func.signature.name):
if not re.search(r"TRTEngineOp_\d+_\d+_native_segment",
func.signature.name):
for node in func.node_def:
all_op_names.append(node.name)
if node.op == "TRTEngineOp":
@ -670,9 +694,12 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
# Remove the function name prefix.
def _Canonicalize(names):
return set(self._ToString(name.split("/")[-1]) for name in names)
# Remove the graph sequence number prefix from all the names.
def _RemoveGraphSequenceNumber(names):
return set(self._RemoveGraphSequenceNumber(name) for name in names)
all_op_names = _Canonicalize(all_op_names)
trt_op_names = _Canonicalize(trt_op_names)
trt_op_names = _RemoveGraphSequenceNumber(_Canonicalize(trt_op_names))
if isinstance(expected_engines, dict):
# For simplicity we don't verify the connections inside the engine in

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import gc
import os
import re
import tempfile
from absl.testing import parameterized
@ -310,6 +311,24 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
converter.save(output_saved_model_dir=output_saved_model_dir)
return output_graph_def
# Remove the graph sequence number prefix from the name only if the name has
# a prefix TRTEngineOp_n_.
def _MayRemoveGraphSequenceNumber(self, name):
prefix = re.search(r"TRTEngineOp_\d+_", name)
if prefix and name.startswith(prefix.group(0)):
parts = name.split("_", maxsplit=2)
assert len(parts) == 3
return parts[0] + "_" + parts[2]
return name
# Return the unique TRTEngineOp in the given graph def.
def _GetUniqueTRTEngineOp(self, graph_def):
trt_engine_nodes = [
node for node in graph_def.node if node.op == "TRTEngineOp"
]
assert len(trt_engine_nodes) == 1
return trt_engine_nodes[0]
def _TestTrtGraphConverter(self,
device,
output_saved_model_dir=None,
@ -330,7 +349,10 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
graph_defs_to_verify.append(saved_model_graph_def)
for graph_def in graph_defs_to_verify:
node_name_to_op = {node.name: node.op for node in graph_def.node}
node_name_to_op = {
self._MayRemoveGraphSequenceNumber(node.name): node.op
for node in graph_def.node
}
self.assertEqual(
{
"input1": "Placeholder",
@ -434,13 +456,13 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
trt_op_names = []
for node in graph_def.node:
if node.op == "TRTEngineOp":
trt_op_names.append(node.name)
trt_op_names.append(self._MayRemoveGraphSequenceNumber(node.name))
if check_fn:
check_fn(node)
for func in graph_def.library.function:
for node in func.node_def:
if node.op == "TRTEngineOp":
trt_op_names.append(node.name)
trt_op_names.append(self._MayRemoveGraphSequenceNumber(node.name))
if check_fn:
check_fn(node)
self.assertEqual(1, len(trt_op_names))
@ -473,11 +495,15 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
# Verify the converted GraphDef and ConcreteFunction.
self._CheckTrtOps(converter._converted_func) # pylint: disable=protected-access
trt_engine_name = self._GetUniqueTRTEngineOp(
converter._converted_graph_def).name
# Save the converted model without any TRT engine cache.
output_saved_model_dir = self.mkdtemp()
converter.save(output_saved_model_dir)
unexpected_asset_file = os.path.join(
output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0")
output_saved_model_dir,
"assets/trt-serialized-engine." + trt_engine_name)
self.assertFalse(os.path.exists(unexpected_asset_file))
# Run the converted function to populate the engine cache.
@ -490,7 +516,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
output_saved_model_dir = self.mkdtemp()
converter.save(output_saved_model_dir)
expected_asset_file = os.path.join(
output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0")
output_saved_model_dir,
"assets/trt-serialized-engine." + trt_engine_name)
self.assertTrue(os.path.exists(expected_asset_file))
self.assertTrue(os.path.getsize(expected_asset_file))
@ -566,6 +593,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
converter.convert(calibration_input_fn=_CalibrationInputFn)
trt_engine_name = self._GetUniqueTRTEngineOp(
converter._converted_graph_def).name
def _CheckFn(node):
self.assertTrue(len(node.attr["calibration_data"].s), node.name)
@ -583,7 +613,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
output_saved_model_dir = self.mkdtemp()
converter.save(output_saved_model_dir)
expected_asset_file = os.path.join(
output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0")
output_saved_model_dir,
"assets/trt-serialized-engine." + trt_engine_name)
self.assertTrue(os.path.exists(expected_asset_file))
self.assertTrue(os.path.getsize(expected_asset_file))
@ -635,6 +666,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
converter = self._CreateConverterV2(input_saved_model_dir)
converter.convert()
trt_engine_name = self._GetUniqueTRTEngineOp(
converter._converted_graph_def).name
def _InputFn():
yield np_input1, np_input2
@ -645,7 +679,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
def _DestroyCache():
with ops.device("GPU:0"):
handle = gen_trt_ops.create_trt_resource_handle(
resource_name="TRTEngineOp_0")
resource_name=trt_engine_name)
gen_resource_variable_ops.destroy_resource_op(
handle, ignore_lookup_error=False)