diff --git a/tensorflow/compiler/tf2tensorrt/kernels/get_calibration_data_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/get_calibration_data_op.cc
index 83a16892816..4154ffe4e2a 100644
--- a/tensorflow/compiler/tf2tensorrt/kernels/get_calibration_data_op.cc
+++ b/tensorflow/compiler/tf2tensorrt/kernels/get_calibration_data_op.cc
@@ -48,12 +48,10 @@ class GetCalibrationDataOp : public OpKernel {
                                 &resource));
     core::ScopedUnref sc(resource);
 
-    auto* calib_ctx = resource->calib_ctx_.get();
-
-    // Serialize the resource as output.
-    string serialized_resource;
-    OP_REQUIRES_OK(context, calib_ctx->SerializeToString(&serialized_resource));
-    resource->calib_ctx_.reset();
+    string serialized_resource = resource->calib_ctx_->TerminateCalibration();
+    OP_REQUIRES(context, !serialized_resource.empty(),
+                errors::Unknown("Calibration table is empty."));
 
     Tensor* output = nullptr;
     OP_REQUIRES_OK(context,
diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
index fb9e257b8af..646a44f1405 100644
--- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
@@ -837,19 +837,18 @@ Status TRTEngineOp::AllocateCalibrationResources(
         if (!s.ok()) {
           LOG(ERROR) << "Calibration failed: " << s;
           cres->calibrator_->setDone();  // Ignore further pushes
+        } else {
+          // Transfer the ownership of the engine to the engine cache, so we can
+          // dump it out during conversion for TF 2.0.
+          mutex_lock lock(this->engine_mutex_);
+          this->calibrator_ = std::move(cres->calibrator_);
+          TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
+              cres->engine_->createExecutionContext());
+          cache_res->cache_.emplace(
+              shapes, absl::make_unique<EngineContext>(std::move(cres->engine_),
+                                                       std::move(exec_context)));
         }
 
-        // Transfer the ownership of the engine to the engine cache, so we can
-        // dump it out during conversion for TF 2.0.
-        mutex_lock lock(this->engine_mutex_);
-        cres->SetCalibrationTable();
-        this->calibrator_ = std::move(cres->calibrator_);
-        TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
-            cres->engine_->createExecutionContext());
-        cache_res->cache_.emplace(
-            shapes, absl::make_unique<EngineContext>(std::move(cres->engine_),
-                                                     std::move(exec_context)));
-
         VLOG(1) << "Calibration loop terminated " << this->name();
       }));
   VLOG(1) << "initialized calibrator resource";
diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc
index e28dcc1cbba..c25281b8645 100644
--- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc
+++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc
@@ -184,13 +184,7 @@ class SerializeTRTResource : public OpKernel {
     core::ScopedUnref unref_me(resource);
 
     // Terminate the calibration if any.
-    if (resource->calib_ctx_) {
-      // We don't save the calibration_table for TF 2.0 at the moment, it's used
-      // in 1.x environment.
-      string calibration_table;
-      OP_REQUIRES_OK(
-          ctx, resource->calib_ctx_->SerializeToString(&calibration_table));
-    }
+    if (resource->calib_ctx_) resource->calib_ctx_->TerminateCalibration();
 
     // Serialize the engines and write them to file.
     std::unique_ptr<WritableFile> file;
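Both kernels above now funnel calibration shutdown through the new CalibrationContext::TerminateCalibration(), so the serialized table can be fetched from Python with a single op call once calibration has run. A minimal sketch, assuming eager mode, that the generated op module is importable at the path shown (an assumption), and that "TRTEngineOp_0" is a placeholder canonical engine name:

    # Sketch only: fetch the calibration table for an engine by its canonical
    # resource name, after calibration has run in the same eager context.
    from tensorflow.compiler.tf2tensorrt.ops import gen_trt_ops  # assumed path

    serialized_table = gen_trt_ops.get_calibration_data_op("TRTEngineOp_0")
    print(len(serialized_table.numpy()))  # non-empty once calibration finished

trt_convert.py (further down in this patch) uses the same op to embed the table into the TRTEngineOp's calibration_data attribute at save time.
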
diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc
index f9306d563d7..5ab6bf1a317 100644
--- a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc
+++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc
@@ -30,6 +30,26 @@ limitations under the License.
 namespace tensorflow {
 namespace tensorrt {
 
+string CalibrationContext::TerminateCalibration() {
+  mutex_lock l(mu_);
+  if (terminated_) return calibration_table_;
+
+  TRTInt8Calibrator* raw_calibrator = calibrator_.get();
+  raw_calibrator->waitAndSetDone();
+  terminated_ = true;
+
+  // At this point the calibration thread `thr_` is woken up and can
+  // transfer the ownership of `calibrator_` and `engine_` at any time, so
+  // it's not safe to use `calibrator_` below, but we can still access it
+  // using raw pointer.
+  // TODO(laigd): make TRTEngineOp::AllocateCalibrationResources() a member
+  // function of this class instead.
+
+  thr_->join();
+  calibration_table_ = raw_calibrator->getCalibrationTableAsString();
+  return calibration_table_;
+}
+
 const absl::string_view kTfTrtContainerName = "TF-TRT";
 
 Logger& TRTEngineCacheResource::GetLogger() {
diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h
index 9c29d56d6da..8d603ac4d55 100644
--- a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h
+++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h
@@ -142,19 +142,7 @@ struct EngineContext {
 // Contains the context required to build the calibration data.
 class CalibrationContext {
  public:
-  void SetCalibrationTable() {
-    calibration_table_ = calibrator_->getCalibrationTableAsString();
-  }
-
-  Status SerializeToString(string* serialized) {
-    calibrator_->waitAndSetDone();
-    thr_->join();
-    *serialized = calibration_table_;
-    if (serialized->empty()) {
-      return errors::Unknown("Calibration table is empty.");
-    }
-    return Status::OK();
-  }
+  string TerminateCalibration();
 
   // Lookup table for temporary staging areas of input tensors for calibration.
   std::unordered_map<string, std::pair<void*, size_t>> device_buffers_;
@@ -162,12 +150,16 @@ class CalibrationContext {
   // Temporary staging areas for calibration inputs.
   std::vector<PersistentTensor> device_tensors_;
 
-  string calibration_table_;
   std::unique_ptr<TRTInt8Calibrator> calibrator_;
   TrtUniquePtrType<nvinfer1::IBuilder> builder_;
  TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
   // TODO(sami): Use threadpool threads!
   std::unique_ptr<std::thread> thr_;
+
+ private:
+  mutex mu_;
+  bool terminated_ GUARDED_BY(mu_) = false;
+  std::string calibration_table_ GUARDED_BY(mu_);
 };
 
 ABSL_CONST_INIT extern const absl::string_view kTfTrtContainerName;
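TerminateCalibration() is deliberately idempotent: the first caller wakes the calibrator, joins thr_, and caches the table under mu_; later callers (GetCalibrationDataOp or SerializeTRTResource) just read the cached string. A rough Python illustration of that terminate-once pattern; everything below is made-up illustration code, not TensorFlow API:

    import threading

    class CalibrationDriver(object):
      """Terminate-once pattern: wake the worker, join it, cache the result."""

      def __init__(self, calibrator, worker_thread):
        self._lock = threading.Lock()        # plays the role of mu_
        self._terminated = False             # plays the role of terminated_
        self._table = b""                    # plays the role of calibration_table_
        self._calibrator = calibrator        # plays the role of calibrator_
        self._worker = worker_thread         # plays the role of thr_

      def terminate_calibration(self):
        with self._lock:
          if self._terminated:
            return self._table               # later callers reuse the cached table
          calibrator = self._calibrator      # keep a local reference; the worker
          calibrator.wait_and_set_done()     # may hand off ownership once unblocked
          self._terminated = True
          self._worker.join()
          self._table = calibrator.get_table_as_string()
          return self._table
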
diff --git a/tensorflow/python/compiler/tensorrt/trt_convert.py b/tensorflow/python/compiler/tensorrt/trt_convert.py
index 2db71858409..5a8100a1a80 100644
--- a/tensorflow/python/compiler/tensorrt/trt_convert.py
+++ b/tensorflow/python/compiler/tensorrt/trt_convert.py
@@ -100,6 +100,7 @@ class TrtPrecisionMode(object):
     ]
     return precisions + [p.lower() for p in precisions]
 
+
 # Use a large enough number as the default max_workspace_size for TRT engines,
 # so it can produce reasonable performance results with the default.
 DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30
@@ -263,8 +264,7 @@ def get_tensorrt_rewriter_config(
       "maximum_cached_engines"].i = conversion_params.maximum_cached_engines
   optimizer.parameter_map[
       "use_calibration"].b = conversion_params.use_calibration
-  optimizer.parameter_map[
-      "max_batch_size"].i = conversion_params.max_batch_size
+  optimizer.parameter_map["max_batch_size"].i = conversion_params.max_batch_size
   optimizer.parameter_map["is_dynamic_op"].b = conversion_params.is_dynamic_op
 
   return rewriter_config_with_trt
@@ -955,14 +955,20 @@ class TrtGraphConverterV2(object):
     engine_asset_dir = tempfile.mkdtemp()
     resource_map = {}
 
-    def _serialize_and_track_engine(canonical_engine_name):
+    def _serialize_and_track_engine(node):
       """Serialize TRT engines in the cache and track them."""
       # Don't dump the same cache twice.
+      canonical_engine_name = _get_canonical_engine_name(node.name)
       if canonical_engine_name in resource_map:
         return
 
       filename = os.path.join(engine_asset_dir,
                               "trt-serialized-engine." + canonical_engine_name)
+      if self._need_calibration:
+        calibration_table = gen_trt_ops.get_calibration_data_op(
+            canonical_engine_name)
+        node.attr["calibration_data"].s = calibration_table.numpy()
+
       try:
         gen_trt_ops.serialize_trt_resource(
             resource_name=canonical_engine_name,
@@ -980,19 +986,28 @@ class TrtGraphConverterV2(object):
 
     for node in self._converted_graph_def.node:
       if node.op == _TRT_ENGINE_OP_NAME:
-        _serialize_and_track_engine(_get_canonical_engine_name(node.name))
+        _serialize_and_track_engine(node)
     for func in self._converted_graph_def.library.function:
       for node in func.node_def:
         if node.op == _TRT_ENGINE_OP_NAME:
-          _serialize_and_track_engine(_get_canonical_engine_name(node.name))
+          _serialize_and_track_engine(node)
 
     self._saved_model.trt_engine_resources = resource_map
 
+    # Rebuild the function since calibration may change the graph.
+    func_to_save = wrap_function.function_from_graph_def(
+        self._converted_graph_def,
+        [tensor.name for tensor in self._converted_func.inputs],
+        [tensor.name for tensor in self._converted_func.outputs])
+    func_to_save.graph.structured_outputs = nest.pack_sequence_as(
+        self._converted_func.graph.structured_outputs,
+        func_to_save.graph.structured_outputs)
+
     # Rewrite the signature map using the optimized ConcreteFunction.
     signatures = {
         key: value for key, value in self._saved_model.signatures.items()
     }
-    signatures[self._input_saved_model_signature_key] = self._converted_func
+    signatures[self._input_saved_model_signature_key] = func_to_save
 
     save.save(self._saved_model, output_saved_model_dir, signatures)
 
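With save() now rebuilding the function and embedding calibration tables, the TF 2.0 INT8 flow is: convert, run representative inputs to calibrate, then save. A minimal sketch mirroring testTrtGraphConverter_Int8Conversion_v2 in the test changes below; the SavedModel paths and the [None, 1, 1] input shape are placeholders, and the parameter names are assumed to match TrtConversionParams in this file:

    import numpy as np

    from tensorflow.python.compiler.tensorrt import trt_convert
    from tensorflow.python.framework import ops

    params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
        precision_mode=trt_convert.TrtPrecisionMode.INT8,
        use_calibration=True,
        is_dynamic_op=True)
    converter = trt_convert.TrtGraphConverterV2(
        input_saved_model_dir="/tmp/my_saved_model",  # placeholder path
        conversion_params=params)
    converted_func = converter.convert()

    # Feeding representative inputs through the converted function runs INT8
    # calibration inside TRTEngineOp (dynamic mode).
    for _ in range(10):
      converted_func(ops.convert_to_tensor(
          np.random.random_sample([4, 1, 1]).astype(np.float32)))

    # save() terminates calibration, embeds the table into each TRTEngineOp's
    # calibration_data attribute, and writes the engine cache as an asset.
    converter.save("/tmp/my_tftrt_int8_model")  # placeholder path
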
diff --git a/tensorflow/python/compiler/tensorrt/trt_convert_test.py b/tensorflow/python/compiler/tensorrt/trt_convert_test.py
index 72bee53bf97..eb06fccffff 100644
--- a/tensorflow/python/compiler/tensorrt/trt_convert_test.py
+++ b/tensorflow/python/compiler/tensorrt/trt_convert_test.py
@@ -322,6 +322,23 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
         maximum_cached_engines=2,
         max_batch_size=max_batch_size if max_batch_size else 1))
 
+  def _CheckTrtOps(self, concrete_func, check_fn=None):
+    graph_def = concrete_func.graph.as_graph_def()
+    trt_op_names = []
+    for node in graph_def.node:
+      if node.op == "TRTEngineOp":
+        trt_op_names.append(node.name)
+        if check_fn:
+          check_fn(node)
+    for func in graph_def.library.function:
+      for node in func.node_def:
+        if node.op == "TRTEngineOp":
+          trt_op_names.append(node.name)
+          if check_fn:
+            check_fn(node)
+    self.assertEqual(1, len(trt_op_names))
+    self.assertIn("TRTEngineOp_0", trt_op_names[0])
+
   @test_util.run_v2_only
   def testTrtGraphConverter_DynamicConversion_v2(self):
     """Test case for trt_convert.TrtGraphConverter()."""
@@ -341,22 +358,11 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
     converter = self._CreateConverterV2(input_saved_model_dir)
     converted_func = converter.convert()
 
-    def _CheckTrtOps(graph_def):
-      trt_op_names = [
-          node.name for node in graph_def.node if node.op == "TRTEngineOp"
-      ]
-      for func in graph_def.library.function:
-        for node in func.node_def:
-          if node.op == "TRTEngineOp":
-            trt_op_names.append(node.name)
-      self.assertEqual(1, len(trt_op_names))
-      self.assertIn("TRTEngineOp_0", trt_op_names[0])
-
     # Verify the converted GraphDef and ConcreteFunction.
     self.assertIsInstance(converted_func, def_function.Function)
     converted_concrete_func = converted_func.get_concrete_function(
         tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32))
-    _CheckTrtOps(converted_concrete_func.graph.as_graph_def())
+    self._CheckTrtOps(converted_concrete_func)
 
     # Save the converted model without any TRT engine cache.
     output_saved_model_dir = self.mkdtemp()
@@ -382,6 +388,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
     self.assertTrue(os.path.exists(expected_asset_file))
     self.assertTrue(os.path.getsize(expected_asset_file))
 
+    del converter
+    gc.collect()  # Force GC to destroy the TRT engine cache.
+
     # Load and verify the converted model.
     #
     # TODO(laigd): the name of the new input_signature of the
@@ -390,16 +399,18 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
     root_with_trt = load.load(output_saved_model_dir)
     # TODO(laigd): `root_with_trt.run` is still using the original graph without
     # trt. Consider changing that.
-    # _CheckTrtOps(
-    #     root_with_trt.run.get_concrete_function().graph.as_graph_def())
+    # self._CheckTrtOps(root_with_trt.run.get_concrete_function())
     converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
-    _CheckTrtOps(converted_signature.graph.as_graph_def())
+    self._CheckTrtOps(converted_signature)
     output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
     # The output of running the converted signature is a dict due to
     # compatibility reasons with V1 SavedModel signature mechanism.
     output_with_trt = output_with_trt.values()[0]
     self.assertAllClose(expected_output, output_with_trt, atol=1e-6, rtol=1e-6)
 
+    del root_with_trt
+    gc.collect()  # Force GC to destroy the TRT engine cache.
+
   @test_util.run_v2_only
   def testTrtGraphConverter_StaticConversion_v2(self):
     """Test case for trt_convert.TrtGraphConverter() using static mode."""
@@ -419,23 +430,14 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
     converter = self._CreateConverterV2(input_saved_model_dir, max_batch_size=4)
     converted_func = converter.convert()
 
-    def _CheckTrtOps(graph_def):
-      trt_op_names = [
-          node.name for node in graph_def.node if node.op == "TRTEngineOp"
-      ]
-      for func in graph_def.library.function:
-        for node in func.node_def:
-          if node.op == "TRTEngineOp":
-            trt_op_names.append(node.name)
-            self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
-      self.assertEqual(1, len(trt_op_names))
-      self.assertIn("TRTEngineOp_0", trt_op_names[0])
+    def _CheckFn(node):
+      self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
 
     # Verify the converted GraphDef and ConcreteFunction.
     self.assertIsInstance(converted_func, def_function.Function)
     converted_concrete_func = converted_func.get_concrete_function(
         tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32))
-    _CheckTrtOps(converted_concrete_func.graph.as_graph_def())
+    self._CheckTrtOps(converted_concrete_func, _CheckFn)
 
     # Save the converted model with the statically-built engine inlined.
     output_saved_model_dir = self.mkdtemp()
@@ -444,10 +446,13 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
         output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0")
     self.assertFalse(os.path.exists(unexpected_asset_file))
 
+    del converter
+    gc.collect()  # Force GC to destroy the TRT engine cache.
+
     # Load and verify the converted model.
     root_with_trt = load.load(output_saved_model_dir)
     converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
-    _CheckTrtOps(converted_signature.graph.as_graph_def())
+    self._CheckTrtOps(converted_signature, _CheckFn)
     output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
     # The output of running the converted signature is a dict due to
     # compatibility reasons with V1 SavedModel signature mechanism.
@@ -457,6 +462,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
         atol=1e-6,
         rtol=1e-6)
 
+    del root_with_trt
+    gc.collect()  # Force GC to destroy the TRT engine cache.
+
   @test_util.run_v2_only
   def testTrtGraphConverter_Int8Conversion_v2(self):
     if not is_tensorrt_enabled():
@@ -493,12 +501,18 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
     self.assertTrue(os.path.exists(expected_asset_file))
     self.assertTrue(os.path.getsize(expected_asset_file))
 
+    del converter
+    gc.collect()  # Force GC to destroy the TRT engine cache.
+
+    def _CheckFn(node):
+      self.assertTrue(len(node.attr["calibration_data"].s), node.name)
+
     # Load and verify the converted model.
     root_with_trt = load.load(output_saved_model_dir)
     converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
+    self._CheckTrtOps(converted_signature, _CheckFn)
     output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
     self.assertEqual(1, len(output_with_trt))
-
     # The output of running the converted signature is a dict due to
     # compatibility reasons with V1 SavedModel signature mechanism.
     self.assertAllClose(
@@ -507,6 +521,14 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
         atol=1e-6,
         rtol=1e-6)
 
+    # Run with an input of different batch size. It should build a new engine
+    # using calibration table.
+    np_input = np.random.random_sample([5, 1, 1]).astype(np.float32)
+    converted_signature(ops.convert_to_tensor(np_input))
+
+    del root_with_trt
+    gc.collect()  # Force GC to destroy the TRT engine cache.
+
   @test_util.run_v2_only
   def testTrtGraphConverter_DestroyEngineCache(self):
     """Test case for trt_convert.TrtGraphConverter()."""
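Outside the test class, the same verification the new _CheckTrtOps/_CheckFn helpers perform can be applied to any saved INT8 model by inspecting the signature's GraphDef. A small sketch, assuming the model was saved by the converter flow above and that "serving_default" is a placeholder signature key:

    from tensorflow.python.saved_model import load

    def check_calibration_data(saved_model_dir, signature_key="serving_default"):
      """Returns the TRTEngineOp names and asserts each has calibration data."""
      root = load.load(saved_model_dir)
      graph_def = root.signatures[signature_key].graph.as_graph_def()
      nodes = list(graph_def.node)
      for func in graph_def.library.function:
        nodes.extend(func.node_def)
      trt_nodes = [n for n in nodes if n.op == "TRTEngineOp"]
      assert trt_nodes, "no TRTEngineOp found in the signature graph"
      for node in trt_nodes:
        assert node.attr["calibration_data"].s, node.name
      return [n.name for n in trt_nodes]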