[TF:TRT] Enable allow_soft_placement to execute TRTEngineOp graph natively.

When a TRTEngineOp falls back to native TensorFlow execution, we should run the
graph with allow_soft_placement=true. This allows operations that are supported
by TensorRT but have no TensorFlow GPU implementation to run on the CPU.
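For context, a minimal sketch (not part of this commit) of what soft placement means at the Python level, using the public tf.config.set_soft_device_placement API; combined_non_max_suppression is the motivating op, and the shapes below are illustrative:

# An op pinned to the GPU that has no GPU kernel runs on the CPU instead of
# failing with a placement error once soft placement is enabled.
import tensorflow as tf

tf.config.set_soft_device_placement(True)

with tf.device('/GPU:0'):
  boxes = tf.zeros([1, 200, 1, 4])    # [batch, num_boxes, q, 4]
  scores = tf.zeros([1, 200, 2])      # [batch, num_boxes, num_classes]
  # CombinedNonMaxSuppression has no GPU kernel; with soft placement it
  # silently executes on the CPU rather than raising an error.
  nmsed = tf.image.combined_non_max_suppression(
      boxes, scores, max_output_size_per_class=3, max_total_size=3)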

Previously, we set kIntsOnDeviceAttr=true for the TRTEngineOp function. This
attribute is not compatible with allow_soft_placement. In order to support
SavedModels generated by an older TF-TRT runtime, we issue a warning and
run the TRTEngineOp functions with allow_soft_placement=false.
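As a hedged illustration (not part of this commit), one way to check whether a SavedModel carries that attribute on its library functions; kIntsOnDeviceAttr is the constant behind the "experimental_ints_on_device" function attribute, and the path below is a placeholder:

# Sketch: scan a SavedModel for functions that set
# experimental_ints_on_device (FunctionLibraryDefinition::kIntsOnDeviceAttr).
from tensorflow.core.protobuf import saved_model_pb2

sm = saved_model_pb2.SavedModel()
with open('/tmp/saved_model/saved_model.pb', 'rb') as f:  # placeholder path
  sm.ParseFromString(f.read())

for func in sm.meta_graphs[0].graph_def.library.function:
  if ('experimental_ints_on_device' in func.attr and
      func.attr['experimental_ints_on_device'].b):
    print(func.signature.name, 'sets kIntsOnDeviceAttr=true')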

PiperOrigin-RevId: 342206089
Change-Id: I0ce7fb8a7778e1982e357db9197644bdd2b92359
Author: Bixia Zheng
Date: 2020-11-12 23:19:12 -08:00
Committed by: TensorFlower Gardener
Parent: 0886101a9d
Commit: 62ab208a35
3 changed files, 67 insertions(+), 9 deletions(-)

--- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc

@@ -618,11 +618,6 @@ Status RegisterGraphToFunctionLibrary(const GraphDef& segment_graph_def,
   auto segment_func = library.add_function();
   TF_RETURN_IF_ERROR(GraphToFunctionDef(
       segment_graph, StrCat(engine_name, "_native_segment"), segment_func));
-  // Set kIntsOnDeviceAttr to true so that all TRTEngineOp outputs are always on
-  // a GPU device as expected. Otherwise, some of the tensors of type DT_INT32
-  // would be on host if the op generating the tensor has host memory tag set.
-  (*segment_func->mutable_attr())[FunctionLibraryDefinition::kIntsOnDeviceAttr]
-      .set_b(true);
   if (VLOG_IS_ON(7)) {
     VLOG(7) << engine_name << " Function_Def ";
     VLOG(7) << segment_func->DebugString();

--- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc

@@ -107,7 +107,9 @@ class TRTEngineOp : public AsyncOpKernel {
   // These are the exact same function.
   Status ConstructFunctionHandle(FunctionLibraryRuntime* lib,
-                                 const string& device_name);
+                                 const string& device_name,
+                                 bool allow_soft_placement = false,
+                                 size_t num_inputs = 0, size_t num_outputs = 0);
 
   // Executes replaced native segment as function Op.
   void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper);
@@ -259,7 +261,10 @@ static Status FunctionDefToGraphDef(FunctionLibraryRuntime::Handle handle,
 }
 
 Status TRTEngineOp::ConstructFunctionHandle(FunctionLibraryRuntime* lib,
-                                            const string& device_name) {
+                                            const string& device_name,
+                                            bool allow_soft_placement,
+                                            size_t num_inputs,
+                                            size_t num_outputs) {
   VLOG(1) << "Constructing function handle";
   if (lib == nullptr) {
     return errors::Internal("Context function library is null");
@@ -267,6 +272,32 @@ Status TRTEngineOp::ConstructFunctionHandle(FunctionLibraryRuntime* lib,
   FunctionLibraryRuntime::InstantiateOptions inst_ops;
   inst_ops.state_handle = "";
   inst_ops.target = device_name;
+  if (allow_soft_placement) {
+    const FunctionDef* fdef =
+        lib->GetFunctionLibraryDefinition()->Find(func_.name());
+    if (!fdef) {
+      return errors::Internal(
+          StrCat("Can't find FunctionDef for ", func_.name()));
+    }
+    bool ints_on_device =
+        fdef->attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
+        fdef->attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b();
+    // kIntsOnDeviceAttr is not compatible with is_multi_device_function,
+    // which is needed to support allow_soft_placement.
+    if (ints_on_device) {
+      LOG_FIRST_FEW_WARNING_WITH_PREFIX
+          << "Function " << name()
+          << " has attribute kIntsOnDeviceAttr=true "
+             "and will be executed natively with allow_soft_placement=false. "
+             "If this is a problem, please re-generate your SavedModel with "
+             "the TF-TRT runtime you are using.";
+    } else {
+      inst_ops.is_multi_device_function = true;
+      inst_ops.input_devices.resize(num_inputs, device_name);
+      inst_ops.output_devices.resize(num_outputs, device_name);
+      inst_ops.config_proto.set_allow_soft_placement(true);
+    }
+  }
   return lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), inst_ops,
                           &func_handle_);
 }
@@ -383,7 +414,9 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
   if (func_handle_ == kInvalidHandle) {
     OP_REQUIRES_OK_ASYNC(
        ctx,
-        ConstructFunctionHandle(ctx->function_library(), ctx->device()->name()),
+        ConstructFunctionHandle(ctx->function_library(), ctx->device()->name(),
+                                /*allow_soft_placement=*/true,
+                                ctx->num_inputs(), ctx->num_outputs()),
        *helper);
   }
   auto lib = ctx->function_library();
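The inst_ops.config_proto.set_allow_soft_placement(true) call above is the C++ counterpart of a familiar Python knob; a rough, illustrative analog:

# Illustrative only: the ConfigProto setting the kernel now applies when
# instantiating the native segment as a multi-device function.
import tensorflow as tf

config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
with tf.compat.v1.Session(config=config) as sess:
  pass  # ops without a kernel on their assigned device fall back to CPU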

--- a/tensorflow/python/compiler/tensorrt/test/combined_nms_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/combined_nms_test.py

@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
+
 from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -61,7 +63,7 @@ class CombinedNmsTest(trt_test.TfTrtIntegrationTestBase):
   def GetParams(self):
     # Parameters
     q = 1
-    batch_size = 1
+    batch_size = 2
     num_boxes = 200
     num_classes = 2
     max_total_size = 3
@@ -98,5 +100,33 @@ class CombinedNmsTest(trt_test.TfTrtIntegrationTestBase):
             run_params.precision_mode), 'test >=TRT5.1 and non-INT8'
 
 
+class CombinedNmsExecuteNativeSegmentTest(CombinedNmsTest):
+
+  def setUp(self):
+    super().setUp()
+    os.environ['TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION'] = 'True'
+
+  def tearDown(self):
+    super().tearDown()
+    os.environ['TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION'] = 'False'
+
+  def GetConversionParams(self, run_params):
+    conversion_param = super().GetConversionParams(run_params)
+    # Build the engine with an allowed max_batch_size smaller than the actual
+    # max_batch_size, to force the runtime to execute the native segment. This
+    # is to test that combined_non_max_suppression, which doesn't have a TF GPU
+    # implementation, can be executed natively even though it is in the graph
+    # for the TRTEngineOp with a GPU as a default device.
+    return conversion_param._replace(
+        max_batch_size=conversion_param.max_batch_size - 1)
+
+  def ShouldRunTest(self, run_params):
+    should_run, reason = super().ShouldRunTest(run_params)
+    # max_batch_size is only useful for selecting static engines. As such,
+    # we shouldn't run the test for dynamic engines.
+    return should_run and \
+        not run_params.dynamic_engine, reason + ' and static engines'
+
+
 if __name__ == '__main__':
   test.main()
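Outside the test harness, the fallback path this commit exercises can be reproduced roughly as follows (a hedged sketch: the directory names are placeholders, and TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION is the same switch the test above flips):

# Sketch: convert a SavedModel with TF-TRT and permit native-segment
# execution at runtime.
import os
os.environ['TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION'] = 'True'

from tensorflow.python.compiler.tensorrt import trt_convert as trt

converter = trt.TrtGraphConverterV2(input_saved_model_dir='/tmp/saved_model')
converter.convert()
converter.save('/tmp/trt_saved_model')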