[TF:TRT] Add flag TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION.

The default value of the flag is True. When the flag value is false, the
bridge will report an error when the native segment of a TRTEngineOp is
executed.

Add test cases.

PiperOrigin-RevId: 317340632
Change-Id: Iacded09b38e63442bbd93076a079d385fb8a77e6
This commit is contained in:
Bixia Zheng 2020-06-19 11:04:49 -07:00 committed by TensorFlower Gardener
parent 8088eddf20
commit 9152edc1f0
2 changed files with 62 additions and 5 deletions

View File

@ -45,6 +45,7 @@ limitations under the License.
#include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/lib/statusor.h"
#if GOOGLE_CUDA #if GOOGLE_CUDA
@ -521,6 +522,17 @@ Status TRTEngineOp::VerifyInputShapes(
return Status::OK(); return Status::OK();
} }
static bool AllowEngineNativeSegmentExecution() {
bool value;
Status status =
ReadBoolFromEnvVar("TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION",
/*default_value=*/true, &value);
if (!status.ok()) {
LOG(ERROR) << status;
}
return value;
}
void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
AsyncOpKernel::DoneCallback done) { AsyncOpKernel::DoneCallback done) {
auto helper = new AsyncHelper(done); auto helper = new AsyncHelper(done);
@ -605,17 +617,31 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
EngineContext* engine_context = status.ValueOrDie().first; EngineContext* engine_context = status.ValueOrDie().first;
int trt_context_idx = status.ValueOrDie().second; int trt_context_idx = status.ValueOrDie().second;
auto may_execute_native_segment = [&] {
if (!AllowEngineNativeSegmentExecution()) {
ctx->CtxFailure(
errors::Aborted("User disallowed engine native segment execution"));
return false;
}
return true;
};
if (!engine_context->cuda_engine) { if (!engine_context->cuda_engine) {
VLOG(1) << "Engine retrieval for input shapes: " LOG_WARNING_WITH_PREFIX
<< "Engine retrieval for input shapes: "
<< TensorShapeUtils::ShapeListString(input_concrete_shapes) << TensorShapeUtils::ShapeListString(input_concrete_shapes)
<< " failed. Running native segment for " << name(); << " failed. Running native segment for " << name();
if (may_execute_native_segment()) {
ExecuteNativeSegment(ctx, helper); ExecuteNativeSegment(ctx, helper);
}
return; return;
} }
Status stat = ExecuteTrtEngine(ctx, engine_context, trt_context_idx); Status stat = ExecuteTrtEngine(ctx, engine_context, trt_context_idx);
if (!stat.ok()) { if (!stat.ok()) {
LOG_WARNING_WITH_PREFIX << "Failed to execute engine: " << stat LOG_WARNING_WITH_PREFIX << "Failed to execute engine: " << stat
<< " Retrying with native segment for " << name(); << " Retrying with native segment for " << name();
if (!may_execute_native_segment()) {
return;
}
// Release any outputs that are allocated, ExecuteNativeSegment will // Release any outputs that are allocated, ExecuteNativeSegment will
// re-allocate them and fail if they are currently allocated. // re-allocate them and fail if they are currently allocated.
for (int i = 0; i < ctx->num_outputs(); i++) { for (int i = 0; i < ctx->num_outputs(); i++) {

View File

@ -439,6 +439,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
self, self,
input_saved_model_dir, input_saved_model_dir,
input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY, input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
max_workspace_size_bytes=10 << 20, # Use a smaller workspace.
precision_mode=trt_convert.TrtPrecisionMode.FP32, precision_mode=trt_convert.TrtPrecisionMode.FP32,
is_dynamic_op=True, is_dynamic_op=True,
maximum_cached_engines=2): maximum_cached_engines=2):
@ -446,7 +447,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
input_saved_model_dir=input_saved_model_dir, input_saved_model_dir=input_saved_model_dir,
input_saved_model_signature_key=input_saved_model_signature_key, input_saved_model_signature_key=input_saved_model_signature_key,
conversion_params=trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace( conversion_params=trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
max_workspace_size_bytes=10 << 20, # Use a smaller workspace. max_workspace_size_bytes=max_workspace_size_bytes,
precision_mode=precision_mode, precision_mode=precision_mode,
is_dynamic_op=is_dynamic_op, is_dynamic_op=is_dynamic_op,
maximum_cached_engines=maximum_cached_engines)) maximum_cached_engines=maximum_cached_engines))
@ -924,6 +925,36 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
# to fall back to TF function. # to fall back to TF function.
self._TestRun(sess, 2) self._TestRun(sess, 2)
@test_util.run_v2_only
def testTrtGraphConverter_AllowEngineNativeSegmentExecution(self):
if not is_tensorrt_enabled():
return
np_input1, np_input2 = self._RandomInput([4, 1, 1])
# Create a model and save it.
input_saved_model_dir = self.mkdtemp()
root = self._GetModelForV2()
save.save(root, input_saved_model_dir,
{_SAVED_MODEL_SIGNATURE_KEY: root.run})
def _InputFn():
yield np_input1, np_input2
# Run TRT conversion and request an unreasonably large workspace.
converter = self._CreateConverterV2(
input_saved_model_dir, max_workspace_size_bytes=10 << 40)
converter.convert()
os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "False"
with self.assertRaisesRegex(
errors.AbortedError,
r"User disallowed engine native segment execution"):
converter.build(input_fn=_InputFn)
os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True"
converter.build(input_fn=_InputFn)
@test_util.run_v2_only @test_util.run_v2_only
def testBackwardCompatibility(self): def testBackwardCompatibility(self):
"""Load and execute a model that was saved in TF2.0.""" """Load and execute a model that was saved in TF2.0."""