Make TRTEngineOp node names unique.

Add a unique graph sequence number to TRTEngineOp node names to avoid name
collisions. Since TRTEngineOp node names are used as the cache keys for the
operation's resource cache objects, this prevents two different TRTEngineOp
nodes from being mapped to the same cache objects.

Fix affected tests.

PiperOrigin-RevId: 304500590
Change-Id: Ibea1e71d57a8a4f16d3710cf176b4ae443aa3815
This commit is contained in:
Bixia Zheng 2020-04-02 16:23:58 -07:00 committed by TensorFlower Gardener
parent ff966217f2
commit 4455956ce8
3 changed files with 80 additions and 12 deletions

View File

@ -617,6 +617,11 @@ std::pair<int, Allocator*> GetDeviceAndAllocator(const ConversionParams& params,
return std::make_pair(cuda_device_id, dev_allocator);
}
// Returns a process-wide, monotonically increasing sequence number. It is
// prepended to TRTEngineOp node names so that engines created from different
// graphs never share a name (and therefore never share a resource-cache key).
int64 GetNextGraphSequenceNumber() {
  // Explicitly zero-initialize: before C++20, std::atomic's default
  // constructor leaves the value uninitialized. A function-local static is
  // zero-initialized anyway, but the explicit initializer makes the starting
  // value obvious and keeps the code safe if it is ever moved to a non-static
  // context.
  static std::atomic<int64> graph_sequence_num{0};
  return graph_sequence_num++;
}
// Entry function from optimization pass.
Status ConvertAfterShapes(const ConversionParams& params) {
// Sanity checks.
@ -666,10 +671,12 @@ Status ConvertAfterShapes(const ConversionParams& params) {
std::vector<size_t> engine_bytes_size;
segment::SegmentNodesVector converted_segments;
converted_segments.reserve(initial_segments.size());
string engine_name_prefix =
StrCat("TRTEngineOp_", GetNextGraphSequenceNumber(), "_");
for (size_t t = 0; t < initial_segments.size(); t++) {
auto& curr_segment = initial_segments.at(t);
EngineInfo curr_engine;
curr_engine.engine_name = StrCat("TRTEngineOp_", t);
curr_engine.engine_name = StrCat(engine_name_prefix, t);
Status status =
GetEngineInfo(&graph, *params.graph_properties, curr_segment, node_map,
reverse_topo_order, &curr_engine);

View File

@ -522,6 +522,25 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
logging.info("Writing graph to %s/%s", temp_dir, graph_name)
graph_io.write_graph(gdef, temp_dir, graph_name)
# Remove the graph sequence number prefix from the name only if the name
# starts with a TRTEngineOp_<n>_ prefix. When expecting_prefix is True,
# assert that such a prefix exists.
def _RemoveGraphSequenceNumberImpl(self, name, expecting_prefix):
match = re.search(r"TRTEngineOp_\d+_", name)
has_prefix = match and name.startswith(match.group(0))
assert (not expecting_prefix) or has_prefix
if has_prefix:
parts = name.split("_", maxsplit=2)
assert len(parts) == 3
return parts[0] + "_" + parts[2]
return name
def _RemoveGraphSequenceNumber(self, name):
  """Strips the graph sequence number prefix, asserting that it exists."""
  return self._RemoveGraphSequenceNumberImpl(name, expecting_prefix=True)
def _MayRemoveGraphSequenceNumber(self, name):
  """Strips the graph sequence number prefix if one is present."""
  return self._RemoveGraphSequenceNumberImpl(name, expecting_prefix=False)
def _VerifyConnections(self, expected_engines, original_gdef, converted_gdef):
old_to_new_node_map = {
self._ToString(node.name): self._ToString(node.name)
@ -579,11 +598,14 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
# Compute the actual mapping from each node to its input nodes.
actual_input_map = {}
for node in converted_gdef.node:
name_str = self._ToString(node.name)
name_str = node.name
if node.op == "TRTEngineOp":
name_str = self._RemoveGraphSequenceNumber(name_str)
actual_input_map[name_str] = set()
input_set = actual_input_map[name_str]
for inp in node.input:
(prefix, node_name) = _InputName(inp)
node_name = self._MayRemoveGraphSequenceNumber(node_name)
input_set.add(prefix + node_name)
self.assertEqual(
@ -628,7 +650,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
self.assertIn(function_name, functions)
if not IsQuantizationWithCalibration and not is_dynamic_engine:
self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
self.assertIn(node.name, expected_engines)
self.assertIn(
self._RemoveGraphSequenceNumber(node.name), expected_engines)
self.assertEqual(
self._ToBytes(run_params.precision_mode),
node.attr["precision_mode"].s, node.name)
@ -662,7 +685,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
node.name for node in gdef_to_verify.node if node.op == "TRTEngineOp"
]
for func in gdef_to_verify.library.function:
if not re.search(r"TRTEngineOp_\d+_native_segment", func.signature.name):
if not re.search(r"TRTEngineOp_\d+_\d+_native_segment",
func.signature.name):
for node in func.node_def:
all_op_names.append(node.name)
if node.op == "TRTEngineOp":
@ -670,9 +694,12 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
# Remove the function name prefix.
def _Canonicalize(names):
return set(self._ToString(name.split("/")[-1]) for name in names)
# Remove the graph sequence number prefix from all the names.
def _RemoveGraphSequenceNumber(names):
return set(self._RemoveGraphSequenceNumber(name) for name in names)
all_op_names = _Canonicalize(all_op_names)
trt_op_names = _Canonicalize(trt_op_names)
trt_op_names = _RemoveGraphSequenceNumber(_Canonicalize(trt_op_names))
if isinstance(expected_engines, dict):
# For simplicity we don't verify the connections inside the engine in

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import gc
import os
import re
import tempfile
from absl.testing import parameterized
@ -310,6 +311,24 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
converter.save(output_saved_model_dir=output_saved_model_dir)
return output_graph_def
# Remove the graph sequence number prefix from the name only if the name
# starts with a TRTEngineOp_<n>_ prefix.
def _MayRemoveGraphSequenceNumber(self, name):
prefix = re.search(r"TRTEngineOp_\d+_", name)
if prefix and name.startswith(prefix.group(0)):
parts = name.split("_", maxsplit=2)
assert len(parts) == 3
return parts[0] + "_" + parts[2]
return name
# Return the unique TRTEngineOp in the given graph def.
def _GetUniqueTRTEngineOp(self, graph_def):
trt_engine_nodes = [
node for node in graph_def.node if node.op == "TRTEngineOp"
]
assert len(trt_engine_nodes) == 1
return trt_engine_nodes[0]
def _TestTrtGraphConverter(self,
device,
output_saved_model_dir=None,
@ -330,7 +349,10 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
graph_defs_to_verify.append(saved_model_graph_def)
for graph_def in graph_defs_to_verify:
node_name_to_op = {node.name: node.op for node in graph_def.node}
node_name_to_op = {
self._MayRemoveGraphSequenceNumber(node.name): node.op
for node in graph_def.node
}
self.assertEqual(
{
"input1": "Placeholder",
@ -434,13 +456,13 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
trt_op_names = []
for node in graph_def.node:
if node.op == "TRTEngineOp":
trt_op_names.append(node.name)
trt_op_names.append(self._MayRemoveGraphSequenceNumber(node.name))
if check_fn:
check_fn(node)
for func in graph_def.library.function:
for node in func.node_def:
if node.op == "TRTEngineOp":
trt_op_names.append(node.name)
trt_op_names.append(self._MayRemoveGraphSequenceNumber(node.name))
if check_fn:
check_fn(node)
self.assertEqual(1, len(trt_op_names))
@ -473,11 +495,15 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
# Verify the converted GraphDef and ConcreteFunction.
self._CheckTrtOps(converter._converted_func) # pylint: disable=protected-access
trt_engine_name = self._GetUniqueTRTEngineOp(
converter._converted_graph_def).name
# Save the converted model without any TRT engine cache.
output_saved_model_dir = self.mkdtemp()
converter.save(output_saved_model_dir)
unexpected_asset_file = os.path.join(
output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0")
output_saved_model_dir,
"assets/trt-serialized-engine." + trt_engine_name)
self.assertFalse(os.path.exists(unexpected_asset_file))
# Run the converted function to populate the engine cache.
@ -490,7 +516,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
output_saved_model_dir = self.mkdtemp()
converter.save(output_saved_model_dir)
expected_asset_file = os.path.join(
output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0")
output_saved_model_dir,
"assets/trt-serialized-engine." + trt_engine_name)
self.assertTrue(os.path.exists(expected_asset_file))
self.assertTrue(os.path.getsize(expected_asset_file))
@ -566,6 +593,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
converter.convert(calibration_input_fn=_CalibrationInputFn)
trt_engine_name = self._GetUniqueTRTEngineOp(
converter._converted_graph_def).name
def _CheckFn(node):
self.assertTrue(len(node.attr["calibration_data"].s), node.name)
@ -583,7 +613,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
output_saved_model_dir = self.mkdtemp()
converter.save(output_saved_model_dir)
expected_asset_file = os.path.join(
output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0")
output_saved_model_dir,
"assets/trt-serialized-engine." + trt_engine_name)
self.assertTrue(os.path.exists(expected_asset_file))
self.assertTrue(os.path.getsize(expected_asset_file))
@ -635,6 +666,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
converter = self._CreateConverterV2(input_saved_model_dir)
converter.convert()
trt_engine_name = self._GetUniqueTRTEngineOp(
converter._converted_graph_def).name
def _InputFn():
yield np_input1, np_input2
@ -645,7 +679,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
def _DestroyCache():
with ops.device("GPU:0"):
handle = gen_trt_ops.create_trt_resource_handle(
resource_name="TRTEngineOp_0")
resource_name=trt_engine_name)
gen_resource_variable_ops.destroy_resource_op(
handle, ignore_lookup_error=False)