diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
index ac4a331041d..98d199ca9ab 100644
--- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
@@ -45,6 +45,7 @@ limitations under the License.
 #include "tensorflow/core/platform/stream_executor.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/env_var.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
 
 #if GOOGLE_CUDA
@@ -521,6 +522,17 @@ Status TRTEngineOp::VerifyInputShapes(
   return Status::OK();
 }
 
+static bool AllowEngineNativeSegmentExecution() {
+  bool value;
+  Status status =
+      ReadBoolFromEnvVar("TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION",
+                         /*default_value=*/true, &value);
+  if (!status.ok()) {
+    LOG(ERROR) << status;
+  }
+  return value;
+}
+
 void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
                                AsyncOpKernel::DoneCallback done) {
   auto helper = new AsyncHelper(done);
@@ -605,17 +617,31 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
   EngineContext* engine_context = status.ValueOrDie().first;
   int trt_context_idx = status.ValueOrDie().second;
+  auto may_execute_native_segment = [&] {
+    if (!AllowEngineNativeSegmentExecution()) {
+      ctx->CtxFailure(
+          errors::Aborted("User disallowed engine native segment execution"));
+      return false;
+    }
+    return true;
+  };
   if (!engine_context->cuda_engine) {
-    VLOG(1) << "Engine retrieval for input shapes: "
-            << TensorShapeUtils::ShapeListString(input_concrete_shapes)
-            << " failed. Running native segment for " << name();
-    ExecuteNativeSegment(ctx, helper);
+    LOG_WARNING_WITH_PREFIX
+        << "Engine retrieval for input shapes: "
+        << TensorShapeUtils::ShapeListString(input_concrete_shapes)
+        << " failed. Running native segment for " << name();
+    if (may_execute_native_segment()) {
+      ExecuteNativeSegment(ctx, helper);
+    }
     return;
   }
 
   Status stat = ExecuteTrtEngine(ctx, engine_context, trt_context_idx);
   if (!stat.ok()) {
     LOG_WARNING_WITH_PREFIX << "Failed to execute engine: " << stat
                             << " Retrying with native segment for " << name();
+    if (!may_execute_native_segment()) {
+      return;
+    }
     // Release any outputs that are allocated, ExecuteNativeSegment will
     // re-allocate them and fail if they are currently allocated.
     for (int i = 0; i < ctx->num_outputs(); i++) {
diff --git a/tensorflow/python/compiler/tensorrt/trt_convert_test.py b/tensorflow/python/compiler/tensorrt/trt_convert_test.py
index df21e93f836..05ff6fcaebe 100644
--- a/tensorflow/python/compiler/tensorrt/trt_convert_test.py
+++ b/tensorflow/python/compiler/tensorrt/trt_convert_test.py
@@ -439,6 +439,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       self,
       input_saved_model_dir,
       input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
+      max_workspace_size_bytes=10 << 20,  # Use a smaller workspace.
       precision_mode=trt_convert.TrtPrecisionMode.FP32,
       is_dynamic_op=True,
       maximum_cached_engines=2):
@@ -446,7 +447,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     return trt_convert.TrtGraphConverterV2(
         input_saved_model_dir=input_saved_model_dir,
         input_saved_model_signature_key=input_saved_model_signature_key,
         conversion_params=trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
-            max_workspace_size_bytes=10 << 20,  # Use a smaller workspace.
+            max_workspace_size_bytes=max_workspace_size_bytes,
             precision_mode=precision_mode,
             is_dynamic_op=is_dynamic_op,
             maximum_cached_engines=maximum_cached_engines))
@@ -924,6 +925,36 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       # to fall back to TF function.
       self._TestRun(sess, 2)
 
+  @test_util.run_v2_only
+  def testTrtGraphConverter_AllowEngineNativeSegmentExecution(self):
+    if not is_tensorrt_enabled():
+      return
+
+    np_input1, np_input2 = self._RandomInput([4, 1, 1])
+
+    # Create a model and save it.
+    input_saved_model_dir = self.mkdtemp()
+    root = self._GetModelForV2()
+    save.save(root, input_saved_model_dir,
+              {_SAVED_MODEL_SIGNATURE_KEY: root.run})
+
+    def _InputFn():
+      yield np_input1, np_input2
+
+    # Run TRT conversion and request an unreasonably large workspace.
+    converter = self._CreateConverterV2(
+        input_saved_model_dir, max_workspace_size_bytes=10 << 40)
+    converter.convert()
+
+    os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "False"
+    with self.assertRaisesRegex(
+        errors.AbortedError,
+        r"User disallowed engine native segment execution"):
+      converter.build(input_fn=_InputFn)
+
+    os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True"
+    converter.build(input_fn=_InputFn)
+
   @test_util.run_v2_only
   def testBackwardCompatibility(self):
     """Load and execute a model that was saved in TF2.0."""