[TF:TRT] Add flag TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION.
The default value of the flag is True. When the flag value is false, the bridge will report an error when the native segment of a TRTEngineOp is executed. Add test cases. PiperOrigin-RevId: 317340632 Change-Id: Iacded09b38e63442bbd93076a079d385fb8a77e6
This commit is contained in:
parent
8088eddf20
commit
9152edc1f0
|
@ -45,6 +45,7 @@ limitations under the License.
|
||||||
#include "tensorflow/core/platform/stream_executor.h"
|
#include "tensorflow/core/platform/stream_executor.h"
|
||||||
#include "tensorflow/core/platform/thread_annotations.h"
|
#include "tensorflow/core/platform/thread_annotations.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/core/util/env_var.h"
|
||||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
|
@ -521,6 +522,17 @@ Status TRTEngineOp::VerifyInputShapes(
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool AllowEngineNativeSegmentExecution() {
|
||||||
|
bool value;
|
||||||
|
Status status =
|
||||||
|
ReadBoolFromEnvVar("TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION",
|
||||||
|
/*default_value=*/true, &value);
|
||||||
|
if (!status.ok()) {
|
||||||
|
LOG(ERROR) << status;
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
|
void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
|
||||||
AsyncOpKernel::DoneCallback done) {
|
AsyncOpKernel::DoneCallback done) {
|
||||||
auto helper = new AsyncHelper(done);
|
auto helper = new AsyncHelper(done);
|
||||||
|
@ -605,17 +617,31 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
|
||||||
|
|
||||||
EngineContext* engine_context = status.ValueOrDie().first;
|
EngineContext* engine_context = status.ValueOrDie().first;
|
||||||
int trt_context_idx = status.ValueOrDie().second;
|
int trt_context_idx = status.ValueOrDie().second;
|
||||||
|
auto may_execute_native_segment = [&] {
|
||||||
|
if (!AllowEngineNativeSegmentExecution()) {
|
||||||
|
ctx->CtxFailure(
|
||||||
|
errors::Aborted("User disallowed engine native segment execution"));
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
};
|
||||||
if (!engine_context->cuda_engine) {
|
if (!engine_context->cuda_engine) {
|
||||||
VLOG(1) << "Engine retrieval for input shapes: "
|
LOG_WARNING_WITH_PREFIX
|
||||||
<< TensorShapeUtils::ShapeListString(input_concrete_shapes)
|
<< "Engine retrieval for input shapes: "
|
||||||
<< " failed. Running native segment for " << name();
|
<< TensorShapeUtils::ShapeListString(input_concrete_shapes)
|
||||||
ExecuteNativeSegment(ctx, helper);
|
<< " failed. Running native segment for " << name();
|
||||||
|
if (may_execute_native_segment()) {
|
||||||
|
ExecuteNativeSegment(ctx, helper);
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
Status stat = ExecuteTrtEngine(ctx, engine_context, trt_context_idx);
|
Status stat = ExecuteTrtEngine(ctx, engine_context, trt_context_idx);
|
||||||
if (!stat.ok()) {
|
if (!stat.ok()) {
|
||||||
LOG_WARNING_WITH_PREFIX << "Failed to execute engine: " << stat
|
LOG_WARNING_WITH_PREFIX << "Failed to execute engine: " << stat
|
||||||
<< " Retrying with native segment for " << name();
|
<< " Retrying with native segment for " << name();
|
||||||
|
if (!may_execute_native_segment()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
// Release any outputs that are allocated, ExecuteNativeSegment will
|
// Release any outputs that are allocated, ExecuteNativeSegment will
|
||||||
// re-allocate them and fail if they are currently allocated.
|
// re-allocate them and fail if they are currently allocated.
|
||||||
for (int i = 0; i < ctx->num_outputs(); i++) {
|
for (int i = 0; i < ctx->num_outputs(); i++) {
|
||||||
|
|
|
@ -439,6 +439,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||||
self,
|
self,
|
||||||
input_saved_model_dir,
|
input_saved_model_dir,
|
||||||
input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
|
input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
|
||||||
|
max_workspace_size_bytes=10 << 20, # Use a smaller workspace.
|
||||||
precision_mode=trt_convert.TrtPrecisionMode.FP32,
|
precision_mode=trt_convert.TrtPrecisionMode.FP32,
|
||||||
is_dynamic_op=True,
|
is_dynamic_op=True,
|
||||||
maximum_cached_engines=2):
|
maximum_cached_engines=2):
|
||||||
|
@ -446,7 +447,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||||
input_saved_model_dir=input_saved_model_dir,
|
input_saved_model_dir=input_saved_model_dir,
|
||||||
input_saved_model_signature_key=input_saved_model_signature_key,
|
input_saved_model_signature_key=input_saved_model_signature_key,
|
||||||
conversion_params=trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
|
conversion_params=trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
|
||||||
max_workspace_size_bytes=10 << 20, # Use a smaller workspace.
|
max_workspace_size_bytes=max_workspace_size_bytes,
|
||||||
precision_mode=precision_mode,
|
precision_mode=precision_mode,
|
||||||
is_dynamic_op=is_dynamic_op,
|
is_dynamic_op=is_dynamic_op,
|
||||||
maximum_cached_engines=maximum_cached_engines))
|
maximum_cached_engines=maximum_cached_engines))
|
||||||
|
@ -924,6 +925,36 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||||
# to fall back to TF function.
|
# to fall back to TF function.
|
||||||
self._TestRun(sess, 2)
|
self._TestRun(sess, 2)
|
||||||
|
|
||||||
|
@test_util.run_v2_only
|
||||||
|
def testTrtGraphConverter_AllowEngineNativeSegmentExecution(self):
|
||||||
|
if not is_tensorrt_enabled():
|
||||||
|
return
|
||||||
|
|
||||||
|
np_input1, np_input2 = self._RandomInput([4, 1, 1])
|
||||||
|
|
||||||
|
# Create a model and save it.
|
||||||
|
input_saved_model_dir = self.mkdtemp()
|
||||||
|
root = self._GetModelForV2()
|
||||||
|
save.save(root, input_saved_model_dir,
|
||||||
|
{_SAVED_MODEL_SIGNATURE_KEY: root.run})
|
||||||
|
|
||||||
|
def _InputFn():
|
||||||
|
yield np_input1, np_input2
|
||||||
|
|
||||||
|
# Run TRT conversion and request an unreasonably large workspace.
|
||||||
|
converter = self._CreateConverterV2(
|
||||||
|
input_saved_model_dir, max_workspace_size_bytes=10 << 40)
|
||||||
|
converter.convert()
|
||||||
|
|
||||||
|
os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "False"
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
errors.AbortedError,
|
||||||
|
r"User disallowed engine native segment execution"):
|
||||||
|
converter.build(input_fn=_InputFn)
|
||||||
|
|
||||||
|
os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True"
|
||||||
|
converter.build(input_fn=_InputFn)
|
||||||
|
|
||||||
@test_util.run_v2_only
|
@test_util.run_v2_only
|
||||||
def testBackwardCompatibility(self):
|
def testBackwardCompatibility(self):
|
||||||
"""Load and execute a model that was saved in TF2.0."""
|
"""Load and execute a model that was saved in TF2.0."""
|
||||||
|
|
Loading…
Reference in New Issue