[TF-TRT] Enable allow_soft_placement for executing native segments in dynamic engine mode.

Previously, the function handle of a TRTEngineOp was used for two purposes:
importing the function GraphDef and executing the function with native
TensorFlow. In dynamic engine mode, we construct the function handle for
importing the function GraphDef, with allow_soft_placement disabled, during
OpKernel construction. Then, when the kernel is executed, we reuse that
function handle to execute the function with native TensorFlow, still with
allow_soft_placement disabled. This is a problem because a TRTEngineOp is
assigned to GPU devices but may include operations that don't have a TensorFlow
GPU implementation.
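
To illustrate what the placement mode changes, here is a minimal sketch of how
the instantiation options drive placement; it mirrors the
set_allow_soft_placement call in ConstructFunctionHandle in the diff below,
with lib, func_, device_name, and allow_soft_placement standing in for the
surrounding kernel state:

  // Sketch: instantiate the native segment for a target device. With soft
  // placement disabled, every op in the function must have a kernel for that
  // device, so a GPU-pinned segment fails if an op lacks a GPU kernel. With
  // soft placement enabled, such ops fall back to a CPU device.
  FunctionLibraryRuntime::InstantiateOptions inst_ops;
  inst_ops.target = device_name;
  if (allow_soft_placement) {
    inst_ops.config_proto.set_allow_soft_placement(true);
  }
  FunctionLibraryRuntime::Handle handle;
  TF_RETURN_IF_ERROR(lib->Instantiate(func_.name(), AttrSlice(&func_.attr()),
                                      inst_ops, &handle));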

To fix the problem, we now use the function handle only for native TensorFlow
execution. Since we import the function GraphDef only once, it is not necessary
to keep that function handle around.
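
The import path now builds a temporary handle, converts the function to a
GraphDef, and simply lets the handle go out of scope; this is the pattern of
the new ImportSegmentGraphDef in the diff below, shown here in isolation:

  // Sketch: import-once path; no handle is cached in a member afterwards.
  TF_ASSIGN_OR_RETURN(FunctionLibraryRuntime::Handle func_handle,
                      ConstructFunctionHandle(lib, device_name));
  return FunctionDefToGraphDef(func_handle, lib, &segment_graph_def_);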

Enabling allow_soft_placement requires a CPU device to be available. On the
other hand, the TRTEngineOpTestBase test only allows a GPU device to be created
by the device manager. To avoid breaking this test, we use the attribute
_allow_soft_placement to selectively disable allow_soft_placement when
executing the graph for a TRTEngineOp.
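
The attribute lookup tolerates NOT_FOUND, so graphs without the attribute keep
the new default. A minimal sketch of the lookup, mirroring the
OpKernel-construction code in the diff below:

  // Sketch: default to soft placement unless the node carries an explicit
  // _allow_soft_placement attribute (the unit test sets it to false).
  bool allow_soft_placement = true;
  Status status =
      context->GetAttr("_allow_soft_placement", &allow_soft_placement);
  if (status.code() != tensorflow::error::NOT_FOUND) {
    OP_REQUIRES_OK(context, status);
  }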

Enable the CombinedNonMaxSuppression test for INT8, which is fixed by this CL.

PiperOrigin-RevId: 343906225
Change-Id: I5383df96f8aa758a4c24128277c4de93190e234a
Author: Bixia Zheng, 2020-11-23 12:32:40 -08:00 (committed by TensorFlower Gardener)
Parent: 40df25db82
Commit: 1dcb38c020
3 changed files with 64 additions and 39 deletions

File: tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc

@@ -103,13 +103,16 @@ class TRTEngineOp : public AsyncOpKernel {
                           TRTEngineCacheResource* cache_res,
                           AsyncHelper* helper);
 
-  // Construct a function handle for executing native funcdef graph
-  // These are the exact same function.
+  // Constructs a function handle for the segment of the TRTEngineOp.
+  StatusOr<FunctionLibraryRuntime::Handle> ConstructFunctionHandle(
+      FunctionLibraryRuntime* lib, const string& device_name,
+      bool allow_soft_placement = false, size_t num_inputs = 0,
+      size_t num_outputs = 0);
-  Status ConstructFunctionHandle(FunctionLibraryRuntime* lib,
-                                 const string& device_name,
-                                 bool allow_soft_placement = false,
-                                 size_t num_inputs = 0, size_t num_outputs = 0);
+
+  // Imports the GraphDef for the segment of the TRTEngineOp to
+  // segment_graph_def_.
+  Status ImportSegmentGraphDef(FunctionLibraryRuntime* lib,
+                               const string& device_name);
 
   // Executes replaced native segment as function Op.
   void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper);
@@ -175,12 +178,16 @@ class TRTEngineOp : public AsyncOpKernel {
   // Whether to build TensorRT engines at runtime.
   bool allow_build_at_runtime_;
 
+  // Whether to allow soft placement when the graph is executed with native
+  // TensorFlow.
+  bool allow_soft_placement_;
+
   // Maximum number of cached engines.
   int max_cached_engines_;
 
   int64 workspace_size_;
   mutex engine_mutex_;
-  FunctionLibraryRuntime::Handle func_handle_;
+  FunctionLibraryRuntime::Handle native_execution_func_handle_;
 
   // The finalized calibrator for inference.
   std::unique_ptr<TRTInt8Calibrator> calibrator_;
@@ -260,11 +267,9 @@ static Status FunctionDefToGraphDef(FunctionLibraryRuntime::Handle handle,
   return Status::OK();
 }
 
-Status TRTEngineOp::ConstructFunctionHandle(FunctionLibraryRuntime* lib,
-                                            const string& device_name,
-                                            bool allow_soft_placement,
-                                            size_t num_inputs,
-                                            size_t num_outputs) {
+StatusOr<FunctionLibraryRuntime::Handle> TRTEngineOp::ConstructFunctionHandle(
+    FunctionLibraryRuntime* lib, const string& device_name,
+    bool allow_soft_placement, size_t num_inputs, size_t num_outputs) {
   VLOG(1) << "Constructing function handle";
   if (lib == nullptr) {
     return errors::Internal("Context function library is null");
@@ -298,8 +303,20 @@ Status TRTEngineOp::ConstructFunctionHandle(FunctionLibraryRuntime* lib,
       inst_ops.config_proto.set_allow_soft_placement(true);
     }
   }
-  return lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), inst_ops,
-                          &func_handle_);
+  FunctionLibraryRuntime::Handle func_handle;
+  Status status = lib->Instantiate(func_.name(), AttrSlice(&func_.attr()),
+                                   inst_ops, &func_handle);
+  if (status.ok()) {
+    return func_handle;
+  }
+  return status;
+}
+
+Status TRTEngineOp::ImportSegmentGraphDef(FunctionLibraryRuntime* lib,
+                                          const string& device_name) {
+  TF_ASSIGN_OR_RETURN(FunctionLibraryRuntime::Handle func_handle,
+                      ConstructFunctionHandle(lib, device_name));
+  return FunctionDefToGraphDef(func_handle, lib, &segment_graph_def_);
 }
 
 TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
@@ -335,14 +352,21 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
             << context->device()->name()
             << ", thus setting _allow_build_at_runtime=true";
     allow_build_at_runtime_ = true;
   } else {
     OP_REQUIRES_OK(context, status);
   }
-  func_handle_ = kInvalidHandle;
+
+  status = context->GetAttr("_allow_soft_placement", &allow_soft_placement_);
+  if (status.code() == tensorflow::error::NOT_FOUND) {
+    allow_soft_placement_ = true;
+  } else {
+    OP_REQUIRES_OK(context, status);
+  }
+
+  native_execution_func_handle_ = kInvalidHandle;
   if (!static_engine_) {
-    FunctionLibraryRuntime* lib = context->function_library();
-    OP_REQUIRES_OK(context,
-                   ConstructFunctionHandle(lib, context->device()->name()));
-    OP_REQUIRES_OK(
-        context, FunctionDefToGraphDef(func_handle_, lib, &segment_graph_def_));
+    OP_REQUIRES_OK(context, ImportSegmentGraphDef(context->function_library(),
+                                                  context->device()->name()));
   }
 
   // TODO(laigd): calibration_data is used in TF v1.x and we keep it only for
   // backward compatibility reasons. Remove it once all known users switch to
@@ -411,13 +435,13 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
                                        AsyncHelper* helper) {
   std::vector<Tensor> inputs;
   std::vector<Tensor>* outputs = new std::vector<Tensor>();
-  if (func_handle_ == kInvalidHandle) {
-    OP_REQUIRES_OK_ASYNC(
-        ctx,
+  if (native_execution_func_handle_ == kInvalidHandle) {
+    StatusOr<FunctionLibraryRuntime::Handle> status_or_handle =
         ConstructFunctionHandle(ctx->function_library(), ctx->device()->name(),
-                                /*allow_soft_placement=*/true,
-                                ctx->num_inputs(), ctx->num_outputs()),
-        *helper);
+                                allow_soft_placement_, ctx->num_inputs(),
+                                ctx->num_outputs());
+    OP_REQUIRES_OK_ASYNC(ctx, status_or_handle.status(), *helper);
+    native_execution_func_handle_ = status_or_handle.ValueOrDie();
   }
   auto lib = ctx->function_library();
   FunctionLibraryRuntime::Options opts;
@@ -430,7 +454,7 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
   }
   helper->Ref();  // Increment count for calculating native graph
   VLOG(1) << "Executing native segment: " << name();
-  lib->Run(opts, func_handle_, inputs, outputs,
+  lib->Run(opts, native_execution_func_handle_, inputs, outputs,
            [this, ctx, outputs, helper](const Status& s) {
              core::ScopedUnref sc(helper);
              OP_REQUIRES_OK_ASYNC(ctx, s, *helper);
@@ -854,12 +878,8 @@ StatusOr<std::pair<EngineContext*, int>> TRTEngineOp::GetEngine(
     return std::pair<EngineContext*, int>(&empty_context, 0);
   }
   if (segment_graph_def_.node().empty()) {
-    FunctionLibraryRuntime* lib = ctx->function_library();
-    auto status = ConstructFunctionHandle(lib, ctx->device()->name());
-    if (status.ok()) {
-      status =
-          FunctionDefToGraphDef(func_handle_, lib, &segment_graph_def_);
-    }
+    Status status = ImportSegmentGraphDef(ctx->function_library(),
+                                          ctx->device()->name());
     if (!status.ok()) {
       LOG_FIRST_FEW_WARNING_WITH_PREFIX << "Getting segment graph for "
                                         << name() << " failed. "

File: tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc

@@ -91,6 +91,13 @@ class TRTEngineOpTestBase : public OpsTestBase {
     OpsTestBase::SetDevice(DEVICE_GPU, std::move(device));
     NameAttrList function;
     function.set_name(StrCat(op_name, "_native_segment"));
+    // We disable allow_soft_placement when executing the native segment of
+    // the TRTEngineOp for the following reasons:
+    //   OpsTestBase only allows one device in the device manager.
+    //   We need to define the GPU device to test TRTEngineOp.
+    //   When allow_soft_placement is true, the TensorFlow runtime produces
+    //   an error if a CPU device is not defined
+    //   (see ProcessFunctionLibraryRuntime::InstantiateMultiDevice).
     TF_ASSERT_OK(NodeDefBuilder(op_name, "TRTEngineOp")
                      .Input(FakeInput(1, dtype))
                      .Attr("input_shapes", {shape})
@@ -105,6 +112,7 @@ class TRTEngineOpTestBase : public OpsTestBase {
                      .Attr("use_calibration", false)
                      .Attr("_use_implicit_batch", use_implicit_batch)
                      .Attr("_allow_build_at_runtime", allow_build_at_runtime)
+                     .Attr("_allow_soft_placement", false)
                      .Attr("OutT", {dtype})
                      .Finalize(OpsTestBase::node_def()));
     TF_ASSERT_OK(InitOpWithFunctionLibrary());

File: tensorflow/python/compiler/tensorrt/test/combined_nms_test.py

@@ -91,13 +91,10 @@ class CombinedNmsTest(trt_test.TfTrtIntegrationTestBase):
     }
 
   def ShouldRunTest(self, run_params):
-    # There is no CombinedNonMaxSuppression op for GPU at the moment, so
-    # calibration will fail.
-    # TODO(laigd): fix this.
+    should_run, reason = super().ShouldRunTest(run_params)
     # Only run for TRT 5.1 and above.
-    return trt_test.IsTensorRTVersionGreaterEqual(
-        5, 1) and not trt_test.IsQuantizationMode(
-            run_params.precision_mode), 'test >=TRT5.1 and non-INT8'
+    return should_run and trt_test.IsTensorRTVersionGreaterEqual(
+        5, 1), reason + ' and >=TRT5.1'
 
 
 class CombinedNmsExecuteNativeSegmentTest(CombinedNmsTest):