[TF-TRT] Enable allow_soft_placement for executing native segments in dynamic
engine mode. Previously, the function handle of an TRTEngineOp is used for two purposes: importing the function GraphDef and executing the function using native TensorFlow. In dynamic engine mode, we construct the function handle for importing the function GraphDef with allow_soft_placement disabled during OpKernel construction. Then when the kernel is executed we reuse such a function handle to execute the function using native TensorFlow with allow_soft_placement disable. This is a problem because a TRTEngineOp is assigned to GPU devices and may include operations that don't have a TensorFlow GPU implementation. To fix the problem, we use the function handle only for native TensorFlow execution. Since we only import the function GraphDef once, it is not necessary to keep such a function handle around. Enabling allow_soft_placement requires a CPU device to be available. On the other hand, the TRTEngineOpTestBase test only allows a GPU device to be created by the device manager. To not break this test, we use attribute _allow_soft_placement to selectively disable allow_soft_placement when executing the graph for TRTEngineOp. Enable CombinedNonMaxSuppresion test for INT8, which is fixed by this CL. PiperOrigin-RevId: 343906225 Change-Id: I5383df96f8aa758a4c24128277c4de93190e234a
This commit is contained in:
parent
40df25db82
commit
1dcb38c020
@ -103,13 +103,16 @@ class TRTEngineOp : public AsyncOpKernel {
|
||||
TRTEngineCacheResource* cache_res,
|
||||
AsyncHelper* helper);
|
||||
|
||||
// Construct a function handle for executing native funcdef graph
|
||||
// These are the exact same function.
|
||||
// Constructs a function handle for the segment of the TRTEngineOp.
|
||||
StatusOr<FunctionLibraryRuntime::Handle> ConstructFunctionHandle(
|
||||
FunctionLibraryRuntime* lib, const string& device_name,
|
||||
bool allow_soft_placement = false, size_t num_inputs = 0,
|
||||
size_t num_outputs = 0);
|
||||
|
||||
Status ConstructFunctionHandle(FunctionLibraryRuntime* lib,
|
||||
const string& device_name,
|
||||
bool allow_soft_placement = false,
|
||||
size_t num_inputs = 0, size_t num_outputs = 0);
|
||||
// Imports the GraphDef for the segment of the TRTEngineOp to
|
||||
// segment_graph_def_.
|
||||
Status ImportSegmentGraphDef(FunctionLibraryRuntime* lib,
|
||||
const string& device_name);
|
||||
|
||||
// Executes replaced native segment as function Op.
|
||||
void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper);
|
||||
@ -175,12 +178,16 @@ class TRTEngineOp : public AsyncOpKernel {
|
||||
// Whether to build TensorRT engines at runtime.
|
||||
bool allow_build_at_runtime_;
|
||||
|
||||
// Whether to allow soft placement when the graph is executed with native
|
||||
// TensorFlow.
|
||||
bool allow_soft_placement_;
|
||||
|
||||
// Maximum number of cached engines.
|
||||
int max_cached_engines_;
|
||||
|
||||
int64 workspace_size_;
|
||||
mutex engine_mutex_;
|
||||
FunctionLibraryRuntime::Handle func_handle_;
|
||||
FunctionLibraryRuntime::Handle native_execution_func_handle_;
|
||||
|
||||
// The finalized calibrator for inference.
|
||||
std::unique_ptr<TRTInt8Calibrator> calibrator_;
|
||||
@ -260,11 +267,9 @@ static Status FunctionDefToGraphDef(FunctionLibraryRuntime::Handle handle,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TRTEngineOp::ConstructFunctionHandle(FunctionLibraryRuntime* lib,
|
||||
const string& device_name,
|
||||
bool allow_soft_placement,
|
||||
size_t num_inputs,
|
||||
size_t num_outputs) {
|
||||
StatusOr<FunctionLibraryRuntime::Handle> TRTEngineOp::ConstructFunctionHandle(
|
||||
FunctionLibraryRuntime* lib, const string& device_name,
|
||||
bool allow_soft_placement, size_t num_inputs, size_t num_outputs) {
|
||||
VLOG(1) << "Constructing function handle";
|
||||
if (lib == nullptr) {
|
||||
return errors::Internal("Context function library is null");
|
||||
@ -298,8 +303,20 @@ Status TRTEngineOp::ConstructFunctionHandle(FunctionLibraryRuntime* lib,
|
||||
inst_ops.config_proto.set_allow_soft_placement(true);
|
||||
}
|
||||
}
|
||||
return lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), inst_ops,
|
||||
&func_handle_);
|
||||
FunctionLibraryRuntime::Handle func_handle;
|
||||
Status status = lib->Instantiate(func_.name(), AttrSlice(&func_.attr()),
|
||||
inst_ops, &func_handle);
|
||||
if (status.ok()) {
|
||||
return func_handle;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
Status TRTEngineOp::ImportSegmentGraphDef(FunctionLibraryRuntime* lib,
|
||||
const string& device_name) {
|
||||
TF_ASSIGN_OR_RETURN(FunctionLibraryRuntime::Handle func_handle,
|
||||
ConstructFunctionHandle(lib, device_name));
|
||||
return FunctionDefToGraphDef(func_handle, lib, &segment_graph_def_);
|
||||
}
|
||||
|
||||
TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
|
||||
@ -335,14 +352,21 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
|
||||
<< context->device()->name()
|
||||
<< ", thus setting _allow_build_at_runtime=true";
|
||||
allow_build_at_runtime_ = true;
|
||||
} else {
|
||||
OP_REQUIRES_OK(context, status);
|
||||
}
|
||||
func_handle_ = kInvalidHandle;
|
||||
|
||||
status = context->GetAttr("_allow_soft_placement", &allow_soft_placement_);
|
||||
if (status.code() == tensorflow::error::NOT_FOUND) {
|
||||
allow_soft_placement_ = true;
|
||||
} else {
|
||||
OP_REQUIRES_OK(context, status);
|
||||
}
|
||||
|
||||
native_execution_func_handle_ = kInvalidHandle;
|
||||
if (!static_engine_) {
|
||||
FunctionLibraryRuntime* lib = context->function_library();
|
||||
OP_REQUIRES_OK(context,
|
||||
ConstructFunctionHandle(lib, context->device()->name()));
|
||||
OP_REQUIRES_OK(
|
||||
context, FunctionDefToGraphDef(func_handle_, lib, &segment_graph_def_));
|
||||
OP_REQUIRES_OK(context, ImportSegmentGraphDef(context->function_library(),
|
||||
context->device()->name()));
|
||||
}
|
||||
// TODO(laigd): calibration_data is used in TF v1.x and we keep it only for
|
||||
// backward compatibility reasons. Remove it once all known users switch to
|
||||
@ -411,13 +435,13 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
|
||||
AsyncHelper* helper) {
|
||||
std::vector<Tensor> inputs;
|
||||
std::vector<Tensor>* outputs = new std::vector<Tensor>();
|
||||
if (func_handle_ == kInvalidHandle) {
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
if (native_execution_func_handle_ == kInvalidHandle) {
|
||||
StatusOr<FunctionLibraryRuntime::Handle> status_or_handle =
|
||||
ConstructFunctionHandle(ctx->function_library(), ctx->device()->name(),
|
||||
/*allow_soft_placement=*/true,
|
||||
ctx->num_inputs(), ctx->num_outputs()),
|
||||
*helper);
|
||||
allow_soft_placement_, ctx->num_inputs(),
|
||||
ctx->num_outputs());
|
||||
OP_REQUIRES_OK_ASYNC(ctx, status_or_handle.status(), *helper);
|
||||
native_execution_func_handle_ = status_or_handle.ValueOrDie();
|
||||
}
|
||||
auto lib = ctx->function_library();
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
@ -430,7 +454,7 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
|
||||
}
|
||||
helper->Ref(); // Increment count for calculating native graph
|
||||
VLOG(1) << "Executing native segment: " << name();
|
||||
lib->Run(opts, func_handle_, inputs, outputs,
|
||||
lib->Run(opts, native_execution_func_handle_, inputs, outputs,
|
||||
[this, ctx, outputs, helper](const Status& s) {
|
||||
core::ScopedUnref sc(helper);
|
||||
OP_REQUIRES_OK_ASYNC(ctx, s, *helper);
|
||||
@ -854,12 +878,8 @@ StatusOr<std::pair<EngineContext*, int>> TRTEngineOp::GetEngine(
|
||||
return std::pair<EngineContext*, int>(&empty_context, 0);
|
||||
}
|
||||
if (segment_graph_def_.node().empty()) {
|
||||
FunctionLibraryRuntime* lib = ctx->function_library();
|
||||
auto status = ConstructFunctionHandle(lib, ctx->device()->name());
|
||||
if (status.ok()) {
|
||||
status =
|
||||
FunctionDefToGraphDef(func_handle_, lib, &segment_graph_def_);
|
||||
}
|
||||
Status status = ImportSegmentGraphDef(ctx->function_library(),
|
||||
ctx->device()->name());
|
||||
if (!status.ok()) {
|
||||
LOG_FIRST_FEW_WARNING_WITH_PREFIX << "Getting segment graph for "
|
||||
<< name() << " failed. "
|
||||
|
||||
@ -91,6 +91,13 @@ class TRTEngineOpTestBase : public OpsTestBase {
|
||||
OpsTestBase::SetDevice(DEVICE_GPU, std::move(device));
|
||||
NameAttrList function;
|
||||
function.set_name(StrCat(op_name, "_native_segment"));
|
||||
// We disable allow_soft_placement when executing the native segment of the
|
||||
// TRTEngineOp for the following reasons:
|
||||
// OpsTestBase only allow one device in the device manager.
|
||||
// We need to define the GPU device to test TRTEngineOp.
|
||||
// When allow_soft_placement is true, the TensorFlow runtime produces an
|
||||
// error if a CPU device is not defined
|
||||
// (see ProcessFunctionLibraryRuntime::InstantiateMultiDevice).
|
||||
TF_ASSERT_OK(NodeDefBuilder(op_name, "TRTEngineOp")
|
||||
.Input(FakeInput(1, dtype))
|
||||
.Attr("input_shapes", {shape})
|
||||
@ -105,6 +112,7 @@ class TRTEngineOpTestBase : public OpsTestBase {
|
||||
.Attr("use_calibration", false)
|
||||
.Attr("_use_implicit_batch", use_implicit_batch)
|
||||
.Attr("_allow_build_at_runtime", allow_build_at_runtime)
|
||||
.Attr("_allow_soft_placement", false)
|
||||
.Attr("OutT", {dtype})
|
||||
.Finalize(OpsTestBase::node_def()));
|
||||
TF_ASSERT_OK(InitOpWithFunctionLibrary());
|
||||
|
||||
@ -91,13 +91,10 @@ class CombinedNmsTest(trt_test.TfTrtIntegrationTestBase):
|
||||
}
|
||||
|
||||
def ShouldRunTest(self, run_params):
|
||||
# There is no CombinedNonMaxSuppression op for GPU at the moment, so
|
||||
# calibration will fail.
|
||||
# TODO(laigd): fix this.
|
||||
should_run, reason = super().ShouldRunTest(run_params)
|
||||
# Only run for TRT 5.1 and above.
|
||||
return trt_test.IsTensorRTVersionGreaterEqual(
|
||||
5, 1) and not trt_test.IsQuantizationMode(
|
||||
run_params.precision_mode), 'test >=TRT5.1 and non-INT8'
|
||||
return should_run and trt_test.IsTensorRTVersionGreaterEqual(
|
||||
5, 1), reason + ' and >=TRT5.1'
|
||||
|
||||
|
||||
class CombinedNmsExecuteNativeSegmentTest(CombinedNmsTest):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user