[TF:TRT] Add flag TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION.
The default value of the flag is True. When the flag value is false, the bridge will report an error when the native segment of a TRTEngineOp is executed. Add test cases. PiperOrigin-RevId: 317340632 Change-Id: Iacded09b38e63442bbd93076a079d385fb8a77e6
This commit is contained in:
parent
8088eddf20
commit
9152edc1f0
|
@ -45,6 +45,7 @@ limitations under the License.
|
|||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/env_var.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
@ -521,6 +522,17 @@ Status TRTEngineOp::VerifyInputShapes(
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
static bool AllowEngineNativeSegmentExecution() {
|
||||
bool value;
|
||||
Status status =
|
||||
ReadBoolFromEnvVar("TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION",
|
||||
/*default_value=*/true, &value);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << status;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
|
||||
AsyncOpKernel::DoneCallback done) {
|
||||
auto helper = new AsyncHelper(done);
|
||||
|
@ -605,17 +617,31 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
|
|||
|
||||
EngineContext* engine_context = status.ValueOrDie().first;
|
||||
int trt_context_idx = status.ValueOrDie().second;
|
||||
auto may_execute_native_segment = [&] {
|
||||
if (!AllowEngineNativeSegmentExecution()) {
|
||||
ctx->CtxFailure(
|
||||
errors::Aborted("User disallowed engine native segment execution"));
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
if (!engine_context->cuda_engine) {
|
||||
VLOG(1) << "Engine retrieval for input shapes: "
|
||||
<< TensorShapeUtils::ShapeListString(input_concrete_shapes)
|
||||
<< " failed. Running native segment for " << name();
|
||||
ExecuteNativeSegment(ctx, helper);
|
||||
LOG_WARNING_WITH_PREFIX
|
||||
<< "Engine retrieval for input shapes: "
|
||||
<< TensorShapeUtils::ShapeListString(input_concrete_shapes)
|
||||
<< " failed. Running native segment for " << name();
|
||||
if (may_execute_native_segment()) {
|
||||
ExecuteNativeSegment(ctx, helper);
|
||||
}
|
||||
return;
|
||||
}
|
||||
Status stat = ExecuteTrtEngine(ctx, engine_context, trt_context_idx);
|
||||
if (!stat.ok()) {
|
||||
LOG_WARNING_WITH_PREFIX << "Failed to execute engine: " << stat
|
||||
<< " Retrying with native segment for " << name();
|
||||
if (!may_execute_native_segment()) {
|
||||
return;
|
||||
}
|
||||
// Release any outputs that are allocated, ExecuteNativeSegment will
|
||||
// re-allocate them and fail if they are currently allocated.
|
||||
for (int i = 0; i < ctx->num_outputs(); i++) {
|
||||
|
|
|
@ -439,6 +439,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||
self,
|
||||
input_saved_model_dir,
|
||||
input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
|
||||
max_workspace_size_bytes=10 << 20, # Use a smaller workspace.
|
||||
precision_mode=trt_convert.TrtPrecisionMode.FP32,
|
||||
is_dynamic_op=True,
|
||||
maximum_cached_engines=2):
|
||||
|
@ -446,7 +447,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||
input_saved_model_dir=input_saved_model_dir,
|
||||
input_saved_model_signature_key=input_saved_model_signature_key,
|
||||
conversion_params=trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
|
||||
max_workspace_size_bytes=10 << 20, # Use a smaller workspace.
|
||||
max_workspace_size_bytes=max_workspace_size_bytes,
|
||||
precision_mode=precision_mode,
|
||||
is_dynamic_op=is_dynamic_op,
|
||||
maximum_cached_engines=maximum_cached_engines))
|
||||
|
@ -924,6 +925,36 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||
# to fall back to TF function.
|
||||
self._TestRun(sess, 2)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testTrtGraphConverter_AllowEngineNativeSegmentExecution(self):
|
||||
if not is_tensorrt_enabled():
|
||||
return
|
||||
|
||||
np_input1, np_input2 = self._RandomInput([4, 1, 1])
|
||||
|
||||
# Create a model and save it.
|
||||
input_saved_model_dir = self.mkdtemp()
|
||||
root = self._GetModelForV2()
|
||||
save.save(root, input_saved_model_dir,
|
||||
{_SAVED_MODEL_SIGNATURE_KEY: root.run})
|
||||
|
||||
def _InputFn():
|
||||
yield np_input1, np_input2
|
||||
|
||||
# Run TRT conversion and request an unreasonably large workspace.
|
||||
converter = self._CreateConverterV2(
|
||||
input_saved_model_dir, max_workspace_size_bytes=10 << 40)
|
||||
converter.convert()
|
||||
|
||||
os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "False"
|
||||
with self.assertRaisesRegex(
|
||||
errors.AbortedError,
|
||||
r"User disallowed engine native segment execution"):
|
||||
converter.build(input_fn=_InputFn)
|
||||
|
||||
os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True"
|
||||
converter.build(input_fn=_InputFn)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testBackwardCompatibility(self):
|
||||
"""Load and execute a model that was saved in TF2.0."""
|
||||
|
|
Loading…
Reference in New Issue