Rollback changelist 304624399 with a fix to the failing test.
Remove the graph sequence number from TRTEngineOp node names before comparing the node names with the expected names. PiperOrigin-RevId: 304723970 Change-Id: If9c5508ad0655f31283945e7f3cf689a58f85431
This commit is contained in:
parent
fc2d7fdacb
commit
21de77485f
@ -617,6 +617,11 @@ std::pair<int, Allocator*> GetDeviceAndAllocator(const ConversionParams& params,
|
||||
return std::make_pair(cuda_device_id, dev_allocator);
|
||||
}
|
||||
|
||||
int64 GetNextGraphSequenceNumber() {
|
||||
static std::atomic<int64> graph_sequence_num;
|
||||
return graph_sequence_num++;
|
||||
}
|
||||
|
||||
// Entry function from optimization pass.
|
||||
Status ConvertAfterShapes(const ConversionParams& params) {
|
||||
// Sanity checks.
|
||||
@ -666,10 +671,12 @@ Status ConvertAfterShapes(const ConversionParams& params) {
|
||||
std::vector<size_t> engine_bytes_size;
|
||||
segment::SegmentNodesVector converted_segments;
|
||||
converted_segments.reserve(initial_segments.size());
|
||||
string engine_name_prefix =
|
||||
StrCat("TRTEngineOp_", GetNextGraphSequenceNumber(), "_");
|
||||
for (size_t t = 0; t < initial_segments.size(); t++) {
|
||||
auto& curr_segment = initial_segments.at(t);
|
||||
EngineInfo curr_engine;
|
||||
curr_engine.engine_name = StrCat("TRTEngineOp_", t);
|
||||
curr_engine.engine_name = StrCat(engine_name_prefix, t);
|
||||
Status status =
|
||||
GetEngineInfo(&graph, *params.graph_properties, curr_segment, node_map,
|
||||
reverse_topo_order, &curr_engine);
|
||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"
|
||||
|
||||
#include <regex> // NOLINT
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
@ -203,15 +205,22 @@ TEST_F(ConvertAfterShapesTest, DirectlyConnectedEngines) {
|
||||
GraphDef output_graph_def;
|
||||
TF_EXPECT_OK(RunConvertAfterShape(s, &output_graph_def));
|
||||
|
||||
auto remove_graph_sequence_number = [](std::string node_name) {
|
||||
const std::regex pattern("TRTEngineOp_[0-9]+_");
|
||||
return std::regex_replace(node_name, pattern, "TRTEngineOp_");
|
||||
};
|
||||
int num_trt_ops = 0;
|
||||
for (const NodeDef& node : output_graph_def.node()) {
|
||||
if (node.name() == "TRTEngineOp_1") {
|
||||
std::string node_name = node.name();
|
||||
if (node.op() != "TRTEngineOp") continue;
|
||||
node_name = remove_graph_sequence_number(node_name);
|
||||
if (node_name == "TRTEngineOp_1") {
|
||||
EXPECT_EQ(1, node.input_size());
|
||||
EXPECT_EQ("input", node.input(0));
|
||||
++num_trt_ops;
|
||||
} else if (node.name() == "TRTEngineOp_0") {
|
||||
} else if (node_name == "TRTEngineOp_0") {
|
||||
EXPECT_EQ(2, node.input_size());
|
||||
EXPECT_EQ("TRTEngineOp_1", node.input(0));
|
||||
EXPECT_EQ("TRTEngineOp_1", remove_graph_sequence_number(node.input(0)));
|
||||
EXPECT_EQ("reshape2", node.input(1));
|
||||
++num_trt_ops;
|
||||
}
|
||||
|
@ -522,6 +522,25 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
||||
logging.info("Writing graph to %s/%s", temp_dir, graph_name)
|
||||
graph_io.write_graph(gdef, temp_dir, graph_name)
|
||||
|
||||
# Remove the graph sequence number prefix from the name only if the name has
|
||||
# a prefix TRTEngineOp_n_. When expecting_prefix is true, assert such a
|
||||
# prefix exists.
|
||||
def _RemoveGraphSequenceNumberImpl(self, name, expecting_prefix):
|
||||
match = re.search(r"TRTEngineOp_\d+_", name)
|
||||
has_prefix = match and name.startswith(match.group(0))
|
||||
assert (not expecting_prefix) or has_prefix
|
||||
if has_prefix:
|
||||
parts = name.split("_", maxsplit=2)
|
||||
assert len(parts) == 3
|
||||
return parts[0] + "_" + parts[2]
|
||||
return name
|
||||
|
||||
def _RemoveGraphSequenceNumber(self, name):
|
||||
return self._RemoveGraphSequenceNumberImpl(name, True)
|
||||
|
||||
def _MayRemoveGraphSequenceNumber(self, name):
|
||||
return self._RemoveGraphSequenceNumberImpl(name, False)
|
||||
|
||||
def _VerifyConnections(self, expected_engines, original_gdef, converted_gdef):
|
||||
old_to_new_node_map = {
|
||||
self._ToString(node.name): self._ToString(node.name)
|
||||
@ -579,11 +598,14 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
||||
# Compute the actual mapping from each node to its input nodes.
|
||||
actual_input_map = {}
|
||||
for node in converted_gdef.node:
|
||||
name_str = self._ToString(node.name)
|
||||
name_str = node.name
|
||||
if node.op == "TRTEngineOp":
|
||||
name_str = self._RemoveGraphSequenceNumber(name_str)
|
||||
actual_input_map[name_str] = set()
|
||||
input_set = actual_input_map[name_str]
|
||||
for inp in node.input:
|
||||
(prefix, node_name) = _InputName(inp)
|
||||
node_name = self._MayRemoveGraphSequenceNumber(node_name)
|
||||
input_set.add(prefix + node_name)
|
||||
|
||||
self.assertEqual(
|
||||
@ -628,7 +650,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
||||
self.assertIn(function_name, functions)
|
||||
if not IsQuantizationWithCalibration and not is_dynamic_engine:
|
||||
self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
|
||||
self.assertIn(node.name, expected_engines)
|
||||
self.assertIn(
|
||||
self._RemoveGraphSequenceNumber(node.name), expected_engines)
|
||||
self.assertEqual(
|
||||
self._ToBytes(run_params.precision_mode),
|
||||
node.attr["precision_mode"].s, node.name)
|
||||
@ -662,7 +685,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
||||
node.name for node in gdef_to_verify.node if node.op == "TRTEngineOp"
|
||||
]
|
||||
for func in gdef_to_verify.library.function:
|
||||
if not re.search(r"TRTEngineOp_\d+_native_segment", func.signature.name):
|
||||
if not re.search(r"TRTEngineOp_\d+_\d+_native_segment",
|
||||
func.signature.name):
|
||||
for node in func.node_def:
|
||||
all_op_names.append(node.name)
|
||||
if node.op == "TRTEngineOp":
|
||||
@ -670,9 +694,12 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
||||
# Remove the function name prefix.
|
||||
def _Canonicalize(names):
|
||||
return set(self._ToString(name.split("/")[-1]) for name in names)
|
||||
# Remove the graph sequence number prefix from all the names.
|
||||
def _RemoveGraphSequenceNumber(names):
|
||||
return set(self._RemoveGraphSequenceNumber(name) for name in names)
|
||||
|
||||
all_op_names = _Canonicalize(all_op_names)
|
||||
trt_op_names = _Canonicalize(trt_op_names)
|
||||
trt_op_names = _RemoveGraphSequenceNumber(_Canonicalize(trt_op_names))
|
||||
|
||||
if isinstance(expected_engines, dict):
|
||||
# For simplicity we don't verify the connections inside the engine in
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import gc
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
|
||||
from absl.testing import parameterized
|
||||
@ -310,6 +311,24 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
converter.save(output_saved_model_dir=output_saved_model_dir)
|
||||
return output_graph_def
|
||||
|
||||
# Remove the graph sequence number prefix from the name only if the name has
|
||||
# a prefix TRTEngineOp_n_.
|
||||
def _MayRemoveGraphSequenceNumber(self, name):
|
||||
prefix = re.search(r"TRTEngineOp_\d+_", name)
|
||||
if prefix and name.startswith(prefix.group(0)):
|
||||
parts = name.split("_", maxsplit=2)
|
||||
assert len(parts) == 3
|
||||
return parts[0] + "_" + parts[2]
|
||||
return name
|
||||
|
||||
# Return the unique TRTEngineOp in the given graph def.
|
||||
def _GetUniqueTRTEngineOp(self, graph_def):
|
||||
trt_engine_nodes = [
|
||||
node for node in graph_def.node if node.op == "TRTEngineOp"
|
||||
]
|
||||
assert len(trt_engine_nodes) == 1
|
||||
return trt_engine_nodes[0]
|
||||
|
||||
def _TestTrtGraphConverter(self,
|
||||
device,
|
||||
output_saved_model_dir=None,
|
||||
@ -330,7 +349,10 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
graph_defs_to_verify.append(saved_model_graph_def)
|
||||
|
||||
for graph_def in graph_defs_to_verify:
|
||||
node_name_to_op = {node.name: node.op for node in graph_def.node}
|
||||
node_name_to_op = {
|
||||
self._MayRemoveGraphSequenceNumber(node.name): node.op
|
||||
for node in graph_def.node
|
||||
}
|
||||
self.assertEqual(
|
||||
{
|
||||
"input1": "Placeholder",
|
||||
@ -434,13 +456,13 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
trt_op_names = []
|
||||
for node in graph_def.node:
|
||||
if node.op == "TRTEngineOp":
|
||||
trt_op_names.append(node.name)
|
||||
trt_op_names.append(self._MayRemoveGraphSequenceNumber(node.name))
|
||||
if check_fn:
|
||||
check_fn(node)
|
||||
for func in graph_def.library.function:
|
||||
for node in func.node_def:
|
||||
if node.op == "TRTEngineOp":
|
||||
trt_op_names.append(node.name)
|
||||
trt_op_names.append(self._MayRemoveGraphSequenceNumber(node.name))
|
||||
if check_fn:
|
||||
check_fn(node)
|
||||
self.assertEqual(1, len(trt_op_names))
|
||||
@ -473,11 +495,15 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
# Verify the converted GraphDef and ConcreteFunction.
|
||||
self._CheckTrtOps(converter._converted_func) # pylint: disable=protected-access
|
||||
|
||||
trt_engine_name = self._GetUniqueTRTEngineOp(
|
||||
converter._converted_graph_def).name
|
||||
|
||||
# Save the converted model without any TRT engine cache.
|
||||
output_saved_model_dir = self.mkdtemp()
|
||||
converter.save(output_saved_model_dir)
|
||||
unexpected_asset_file = os.path.join(
|
||||
output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0")
|
||||
output_saved_model_dir,
|
||||
"assets/trt-serialized-engine." + trt_engine_name)
|
||||
self.assertFalse(os.path.exists(unexpected_asset_file))
|
||||
|
||||
# Run the converted function to populate the engine cache.
|
||||
@ -490,7 +516,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
output_saved_model_dir = self.mkdtemp()
|
||||
converter.save(output_saved_model_dir)
|
||||
expected_asset_file = os.path.join(
|
||||
output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0")
|
||||
output_saved_model_dir,
|
||||
"assets/trt-serialized-engine." + trt_engine_name)
|
||||
self.assertTrue(os.path.exists(expected_asset_file))
|
||||
self.assertTrue(os.path.getsize(expected_asset_file))
|
||||
|
||||
@ -566,6 +593,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
converter.convert(calibration_input_fn=_CalibrationInputFn)
|
||||
|
||||
trt_engine_name = self._GetUniqueTRTEngineOp(
|
||||
converter._converted_graph_def).name
|
||||
|
||||
def _CheckFn(node):
|
||||
self.assertTrue(len(node.attr["calibration_data"].s), node.name)
|
||||
|
||||
@ -583,7 +613,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
output_saved_model_dir = self.mkdtemp()
|
||||
converter.save(output_saved_model_dir)
|
||||
expected_asset_file = os.path.join(
|
||||
output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0")
|
||||
output_saved_model_dir,
|
||||
"assets/trt-serialized-engine." + trt_engine_name)
|
||||
self.assertTrue(os.path.exists(expected_asset_file))
|
||||
self.assertTrue(os.path.getsize(expected_asset_file))
|
||||
|
||||
@ -635,6 +666,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
converter = self._CreateConverterV2(input_saved_model_dir)
|
||||
converter.convert()
|
||||
|
||||
trt_engine_name = self._GetUniqueTRTEngineOp(
|
||||
converter._converted_graph_def).name
|
||||
|
||||
def _InputFn():
|
||||
yield np_input1, np_input2
|
||||
|
||||
@ -645,7 +679,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
def _DestroyCache():
|
||||
with ops.device("GPU:0"):
|
||||
handle = gen_trt_ops.create_trt_resource_handle(
|
||||
resource_name="TRTEngineOp_0")
|
||||
resource_name=trt_engine_name)
|
||||
gen_resource_variable_ops.destroy_resource_op(
|
||||
handle, ignore_lookup_error=False)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user