diff --git a/tensorflow/python/compiler/tensorrt/trt_convert.py b/tensorflow/python/compiler/tensorrt/trt_convert.py index 5a8100a1a80..bb78163b93d 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert.py @@ -878,6 +878,7 @@ class TrtGraphConverterV2(object): if (self._need_calibration and not conversion_params.is_dynamic_op): raise ValueError("INT8 precision mode with calibration is not supported " "with static TensorRT ops. Set is_dynamic_op to True.") + self._calibration_data_collected = False self._converted = False @@ -900,13 +901,41 @@ class TrtGraphConverterV2(object): # TODO(laigd): provide a utility function to optimize a ConcreteFunction and # use it here (b/124792963). - def convert(self): + def convert(self, + num_calibration_runs=None, + calibration_input_map_fn=None): """Convert the input SavedModel in 2.0 format. + Args: + num_calibration_runs: number of runs of the graph during calibration. + calibration_input_map_fn: a function that returns inputs for the converted + tf_function to be calibrated. + Example: + ``` + def input_map_fn(): + return input1, input2, input3 + ``` + Raises: + ValueError: if the input combination is invalid. + Returns: The TF-TRT converted Function. """ assert not self._converted + + if (self._need_calibration and + (not num_calibration_runs or + not calibration_input_map_fn)): + raise ValueError( + "Should specify num_calibration_runs and calibration_input_map_fn" + "because INT8 calibration is needed") + if (not self._need_calibration and + (num_calibration_runs or + calibration_input_map_fn)): + raise ValueError( + "Should not specify num_calibration_runs or calibration_input_map_fn" + "because INT8 calibration is not needed") + self._saved_model = load.load(self._input_saved_model_dir, self._input_saved_model_tags) func = self._saved_model.signatures[self._input_saved_model_signature_key] @@ -934,13 +963,49 @@ class TrtGraphConverterV2(object): self._converted = True - # Wrap the converted ConcreteFunction in a Function so it can accept numpy - # arrays as input. - @def_function.function - def wrapper_func(*args, **kwargs): - return self._converted_func(*args, **kwargs) + if self._need_calibration and not self._calibration_data_collected: + self._calibrate(num_runs=num_calibration_runs, + input_map_fn=calibration_input_map_fn) - return wrapper_func + def _calibrate(self, + num_runs=None, + input_map_fn=None): + """Run calibration. + + Args: + num_runs: number of runs of the graph during calibration. + input_map_fn: a function that returns inputs for the converted + tf_function to be calibrated. + Example: + ``` + def input_map_fn(): + return input1, input2, input3 + ``` + """ + assert self._converted + assert self._need_calibration + assert num_runs + assert input_map_fn + + for _ in range(num_runs): + self._converted_func(*map(ops.convert_to_tensor, input_map_fn())) + + self._calibration_data_collected = True + + def build(self, *args, **kwargs): + """Run inference on graph in order to build a TensorRT engine + in the cahce of TRTEngineOp. + + Returns: + The output of the converted Function for the given inputs. + """ + args_tensor = [ops.convert_to_tensor(arg) for arg in args] + kwargs_tensor = {k: ops.convert_to_tensor(v) for k, v in kwargs.items()} + try: + return self._converted_func(*args_tensor, **kwargs_tensor) + except OpError: + print('Failure in execution of function with input args {}' + 'and kwargs {}'.format(args_tensor, kwargs_tensor)) def save(self, output_saved_model_dir): """Save the converted SavedModel. diff --git a/tensorflow/python/compiler/tensorrt/trt_convert_test.py b/tensorflow/python/compiler/tensorrt/trt_convert_test.py index eb06fccffff..5239f2a2a47 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert_test.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert_test.py @@ -117,22 +117,13 @@ class TrtConvertTest(test_util.TensorFlowTestCase): return config @classmethod - def _GetGraph(cls, inp, var): + def _GetGraph(cls, inp1, inp2, var): """Get the graph for testing.""" - # The graph computes (input+1)^2, it looks like: - # - # input (Placeholder) v1 (Variable) - # | \ / - # \ + - # \ / \ - # * | - # \ / - # + - # | - # output (Identity) - add = inp + var - mul = inp * add + # The graph computes: inp1^2 + inp1*var + inp1 + inp2 + var + add = inp1 + var + mul = inp1 * add add = mul + add + add = add + inp2 out = array_ops.identity(add, name="output") return out @@ -144,12 +135,13 @@ class TrtConvertTest(test_util.TensorFlowTestCase): self.v = None @def_function.function(input_signature=[ + tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32), tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32) ]) - def run(self, inp): + def run(self, inp1, inp2): if self.v is None: self.v = variables.Variable([[[1.0]]], dtype=dtypes.float32) - return TrtConvertTest._GetGraph(inp, self.v) + return TrtConvertTest._GetGraph(inp1, inp2, self.v) return SimpleModel() @@ -157,15 +149,17 @@ class TrtConvertTest(test_util.TensorFlowTestCase): g = ops.Graph() with g.as_default(): with g.device("/GPU:0"): - inp = array_ops.placeholder( - dtype=dtypes.float32, shape=[None, 1, 1], name="input") + inp1 = array_ops.placeholder( + dtype=dtypes.float32, shape=[None, 1, 1], name="input1") + inp2 = array_ops.placeholder( + dtype=dtypes.float32, shape=[None, 1, 1], name="input2") var = variables.Variable([[[1.0]]], dtype=dtypes.float32, name="v1") - out = TrtConvertTest._GetGraph(inp, var) - return g, var, inp, out + out = TrtConvertTest._GetGraph(inp1, inp2, var) + return g, var, inp1, inp2, out def _GetGraphDef(self): """Get the graph def for testing.""" - g, var, _, _ = self._GetGraphForV1() + g, var, _, _, _ = self._GetGraphForV1() with self.session(graph=g, config=self._GetConfigProto()) as sess: sess.run(var.initializer) graph_def = graph_util.convert_variables_to_constants( @@ -175,19 +169,22 @@ class TrtConvertTest(test_util.TensorFlowTestCase): { "v1": "Const", "add/ReadVariableOp": "Identity", - "input": "Placeholder", + "input1": "Placeholder", + "input2": "Placeholder", "add": "AddV2", "mul": "Mul", "add_1": "AddV2", + "add_2": "AddV2", "output": "Identity" }, node_name_to_op) return graph_def def _WriteInputSavedModel(self, input_saved_model_dir): """Write the saved model as an input for testing.""" - g, var, inp, out = self._GetGraphForV1() + g, var, inp1, inp2, out = self._GetGraphForV1() signature_def = signature_def_utils.build_signature_def( - inputs={"myinput": utils.build_tensor_info(inp)}, + inputs={"myinput1": utils.build_tensor_info(inp1), + "myinput2": utils.build_tensor_info(inp2)}, outputs={"myoutput": utils.build_tensor_info(out)}, method_name=signature_constants.PREDICT_METHOD_NAME) saved_model_builder = builder.SavedModelBuilder(input_saved_model_dir) @@ -231,7 +228,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase): def next(self): self._data += 1 - return {"input:0": [[[self._data]]]} + return {"input1:0": [[[self._data]]], + "input2:0": [[[self._data]]]} output_graph_def = converter.calibrate( fetch_names=["output:0"], @@ -265,7 +263,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase): node_name_to_op = {node.name: node.op for node in graph_def.node} self.assertEqual( { - "input": "Placeholder", + "input1": "Placeholder", + "input2": "Placeholder", "TRTEngineOp_0": "TRTEngineOp", "output": "Identity" }, node_name_to_op) @@ -284,8 +283,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase): with self.session(config=self._GetConfigProto()) as sess: for test_data in range(10): self.assertEqual( - (test_data + 1.0)**2, - sess.run("output:0", feed_dict={"input:0": [[[test_data]]]})) + (test_data + 1.0)**2 + test_data, + sess.run("output:0", feed_dict={"input1:0": [[[test_data]]], + "input2:0": [[[test_data]]]})) @test_util.deprecated_graph_mode_only def testTrtGraphConverter_BasicConversion(self): @@ -345,24 +345,19 @@ class TrtConvertTest(test_util.TensorFlowTestCase): if not is_tensorrt_enabled(): return - np_input = np.random.random_sample([4, 1, 1]).astype(np.float32) + np_input1 = np.random.random_sample([4, 1, 1]).astype(np.float32) + np_input2 = np.random.random_sample([4, 1, 1]).astype(np.float32) # Create a model and save it. input_saved_model_dir = self.mkdtemp() root = self._GetModelForV2() - expected_output = root.run(np_input) + expected_output = root.run(np_input1, np_input2) save.save(root, input_saved_model_dir, {_SAVED_MODEL_SIGNATURE_KEY: root.run}) # Run TRT conversion. converter = self._CreateConverterV2(input_saved_model_dir) - converted_func = converter.convert() - - # Verify the converted GraphDef and ConcreteFunction. - self.assertIsInstance(converted_func, def_function.Function) - converted_concrete_func = converted_func.get_concrete_function( - tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32)) - self._CheckTrtOps(converted_concrete_func) + converter.convert() # Save the converted model without any TRT engine cache. output_saved_model_dir = self.mkdtemp() @@ -372,7 +367,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase): self.assertFalse(os.path.exists(unexpected_asset_file)) # Run the converted function to populate the engine cache. - output_with_trt = converted_func(np_input) + output_with_trt = converter.build(np_input1, np_input2) self.assertEqual(1, len(output_with_trt)) self.assertAllClose( expected_output, @@ -402,7 +397,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase): # self._CheckTrtOps(root_with_trt.run.get_concrete_function()) converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY] self._CheckTrtOps(converted_signature) - output_with_trt = converted_signature(ops.convert_to_tensor(np_input)) + output_with_trt = converted_signature(inp1=ops.convert_to_tensor(np_input1), + inp2=ops.convert_to_tensor(np_input2)) # The output of running the converted signature is a dict due to # compatibility reasons with V1 SavedModel signature mechanism. output_with_trt = output_with_trt.values()[0] @@ -417,28 +413,23 @@ class TrtConvertTest(test_util.TensorFlowTestCase): if not is_tensorrt_enabled(): return - np_input = np.random.random_sample([4, 1, 1]).astype(np.float32) + np_input1 = np.random.random_sample([4, 1, 1]).astype(np.float32) + np_input2 = np.random.random_sample([4, 1, 1]).astype(np.float32) # Create a model and save it. input_saved_model_dir = self.mkdtemp() root = self._GetModelForV2() - expected_output = root.run(np_input) + expected_output = root.run(np_input1, np_input2) save.save(root, input_saved_model_dir, {_SAVED_MODEL_SIGNATURE_KEY: root.run}) # Run TRT conversion. converter = self._CreateConverterV2(input_saved_model_dir, max_batch_size=4) - converted_func = converter.convert() + converter.convert() def _CheckFn(node): self.assertTrue(len(node.attr["serialized_segment"].s), node.name) - # Verify the converted GraphDef and ConcreteFunction. - self.assertIsInstance(converted_func, def_function.Function) - converted_concrete_func = converted_func.get_concrete_function( - tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32)) - self._CheckTrtOps(converted_concrete_func, _CheckFn) - # Save the converted model with the statically-built engine inlined. output_saved_model_dir = self.mkdtemp() converter.save(output_saved_model_dir) @@ -453,7 +444,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase): root_with_trt = load.load(output_saved_model_dir) converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY] self._CheckTrtOps(converted_signature, _CheckFn) - output_with_trt = converted_signature(ops.convert_to_tensor(np_input)) + output_with_trt = converted_signature(inp1=ops.convert_to_tensor(np_input1), + inp2=ops.convert_to_tensor(np_input2)) # The output of running the converted signature is a dict due to # compatibility reasons with V1 SavedModel signature mechanism. self.assertAllClose( @@ -470,28 +462,25 @@ class TrtConvertTest(test_util.TensorFlowTestCase): if not is_tensorrt_enabled(): return - np_input = np.random.random_sample([4, 1, 1]).astype(np.float32) + np_input1 = np.random.random_sample([4, 1, 1]).astype(np.float32) + np_input2 = np.random.random_sample([4, 1, 1]).astype(np.float32) # Create a model and save it. input_saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir()) root = self._GetModelForV2() - expected_output = root.run(np_input) + expected_output = root.run(np_input1, np_input2) save.save(root, input_saved_model_dir, {_SAVED_MODEL_SIGNATURE_KEY: root.run}) # Run TRT conversion. converter = self._CreateConverterV2( input_saved_model_dir, precision_mode=trt_convert.TrtPrecisionMode.INT8) - converted_func = converter.convert() - # Run the converted function for INT8 calibration. - calibration_output = converted_func(np_input) - self.assertEqual(1, len(calibration_output)) - self.assertAllClose( - expected_output, - list(calibration_output.values())[0], - atol=1e-6, - rtol=1e-6) + # Convert and perform INT8 calibration + def input_map_fn(): + return np_input1, np_input2 + converter.convert(num_calibration_runs=1, + calibration_input_map_fn=input_map_fn) # Save the converted model again with serialized engine cache. output_saved_model_dir = self.mkdtemp() @@ -511,7 +500,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase): root_with_trt = load.load(output_saved_model_dir) converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY] self._CheckTrtOps(converted_signature, _CheckFn) - output_with_trt = converted_signature(ops.convert_to_tensor(np_input)) + output_with_trt = converted_signature(inp1=ops.convert_to_tensor(np_input1), + inp2=ops.convert_to_tensor(np_input2)) self.assertEqual(1, len(output_with_trt)) # The output of running the converted signature is a dict due to # compatibility reasons with V1 SavedModel signature mechanism. @@ -523,8 +513,10 @@ class TrtConvertTest(test_util.TensorFlowTestCase): # Run with an input of different batch size. It should build a new engine # using calibration table. - np_input = np.random.random_sample([5, 1, 1]).astype(np.float32) - converted_signature(ops.convert_to_tensor(np_input)) + np_input1 = np.random.random_sample([5, 1, 1]).astype(np.float32) + np_input2 = np.random.random_sample([5, 1, 1]).astype(np.float32) + output_with_trt = converted_signature(inp1=ops.convert_to_tensor(np_input1), + inp2=ops.convert_to_tensor(np_input2)) del root_with_trt gc.collect() # Force GC to destroy the TRT engine cache. @@ -535,7 +527,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase): if not is_tensorrt_enabled(): return - np_input = np.random.random_sample([4, 1, 1]).astype(np.float32) + np_input1 = np.random.random_sample([4, 1, 1]).astype(np.float32) + np_input2 = np.random.random_sample([4, 1, 1]).astype(np.float32) # Create a model and save it. input_saved_model_dir = self.mkdtemp() @@ -545,8 +538,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase): # Run TRT conversion. converter = self._CreateConverterV2(input_saved_model_dir) - converted_func = converter.convert() - converted_func(np_input) # Populate the TRT engine cache. + converter.convert() + converter.build(np_input1, np_input2) # Populate the TRT engine cache. output_saved_model_dir = self.mkdtemp() converter.save(output_saved_model_dir) @@ -696,22 +689,26 @@ class TrtConvertTest(test_util.TensorFlowTestCase): self._CompareSavedModel(_Model) def _TestRun(self, sess, batch_size, expect_engine_is_run=True): - result = sess.run("output:0", feed_dict={"input:0": [[[1.0]]] * batch_size}) - self.assertAllEqual([[[4.0]]] * batch_size, result) + result = sess.run("output:0", + feed_dict={"input1:0": [[[1.0]]] * batch_size, + "input2:0": [[[1.0]]] * batch_size}) + self.assertAllEqual([[[5.0]]] * batch_size, result) @test_util.deprecated_graph_mode_only def testTrtGraphConverter_MinimumSegmentSize(self): if not is_tensorrt_enabled(): return - output_graph_def = self._ConvertGraph(minimum_segment_size=5) + output_graph_def = self._ConvertGraph(minimum_segment_size=7) node_name_to_op = {node.name: node.op for node in output_graph_def.node} self.assertEqual( { "add/ReadVariableOp": "Const", - "input": "Placeholder", + "input1": "Placeholder", + "input2": "Placeholder", "add": "AddV2", "mul": "Mul", "add_1": "AddV2", + "add_2": "AddV2", "output": "Identity" }, node_name_to_op)