[TF:TRT] Enable allow_soft_placement to execute TRTEngineOp graph natively.
When a TRTEngineOp falls back to run with native Tensorflow, we should run the graph with allow_soft_placement=true. This is to allow operations that are supported by TensorRT but otherwise do not have a Tensorflow GPU implementation to run on CPUs. Previously, we set kIntsOnDeviceAttr=true for the TRTEngineOp function. This attribute is not compatible with allow_soft_placement. In order to support SavedModels generated by an older TF-TRT runtime, we issue a warning and run the TRTEngineOp functions with allow_soft_placement=false. PiperOrigin-RevId: 342206089 Change-Id: I0ce7fb8a7778e1982e357db9197644bdd2b92359
This commit is contained in:
parent
0886101a9d
commit
62ab208a35
@ -618,11 +618,6 @@ Status RegisterGraphToFunctionLibrary(const GraphDef& segment_graph_def,
|
||||
auto segment_func = library.add_function();
|
||||
TF_RETURN_IF_ERROR(GraphToFunctionDef(
|
||||
segment_graph, StrCat(engine_name, "_native_segment"), segment_func));
|
||||
// Set kIntsonDeviceAttr to true so that all TRTEngineOp outputs are always on
|
||||
// a GPU device as expected. Otherwise, some of the tensors of type DT_INT32
|
||||
// would be on host if the op generating the tensor has host memory tag set.
|
||||
(*segment_func->mutable_attr())[FunctionLibraryDefinition::kIntsOnDeviceAttr]
|
||||
.set_b(true);
|
||||
if (VLOG_IS_ON(7)) {
|
||||
VLOG(7) << engine_name << " Function_Def ";
|
||||
VLOG(7) << segment_func->DebugString();
|
||||
|
@ -107,7 +107,9 @@ class TRTEngineOp : public AsyncOpKernel {
|
||||
// These are the exact same function.
|
||||
|
||||
Status ConstructFunctionHandle(FunctionLibraryRuntime* lib,
|
||||
const string& device_name);
|
||||
const string& device_name,
|
||||
bool allow_soft_placement = false,
|
||||
size_t num_inputs = 0, size_t num_outputs = 0);
|
||||
|
||||
// Executes replaced native segment as function Op.
|
||||
void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper);
|
||||
@ -259,7 +261,10 @@ static Status FunctionDefToGraphDef(FunctionLibraryRuntime::Handle handle,
|
||||
}
|
||||
|
||||
Status TRTEngineOp::ConstructFunctionHandle(FunctionLibraryRuntime* lib,
|
||||
const string& device_name) {
|
||||
const string& device_name,
|
||||
bool allow_soft_placement,
|
||||
size_t num_inputs,
|
||||
size_t num_outputs) {
|
||||
VLOG(1) << "Constructing function handle";
|
||||
if (lib == nullptr) {
|
||||
return errors::Internal("Context function library is null");
|
||||
@ -267,6 +272,32 @@ Status TRTEngineOp::ConstructFunctionHandle(FunctionLibraryRuntime* lib,
|
||||
FunctionLibraryRuntime::InstantiateOptions inst_ops;
|
||||
inst_ops.state_handle = "";
|
||||
inst_ops.target = device_name;
|
||||
if (allow_soft_placement) {
|
||||
const FunctionDef* fdef =
|
||||
lib->GetFunctionLibraryDefinition()->Find(func_.name());
|
||||
if (!fdef) {
|
||||
return errors::Internal(
|
||||
StrCat("Cann't find FunctionDef for", func_.name()));
|
||||
}
|
||||
bool ints_on_device =
|
||||
fdef->attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
|
||||
fdef->attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b();
|
||||
// kIntsOnDeviceAttr is not compatible with is_multi_device_function which
|
||||
// is needed to support allow_soft_placement.
|
||||
if (ints_on_device) {
|
||||
LOG_FIRST_FEW_WARNING_WITH_PREFIX
|
||||
<< "Function " << name()
|
||||
<< " has attribute kIntsOnDeviceAttr=true "
|
||||
"and will be executed natively with allow_soft_placement=false. "
|
||||
"If this is a problem, please re-generate your SavedModel with "
|
||||
"the TF-TRT runtime you are using.";
|
||||
} else {
|
||||
inst_ops.is_multi_device_function = true;
|
||||
inst_ops.input_devices.resize(num_inputs, device_name);
|
||||
inst_ops.output_devices.resize(num_outputs, device_name);
|
||||
inst_ops.config_proto.set_allow_soft_placement(true);
|
||||
}
|
||||
}
|
||||
return lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), inst_ops,
|
||||
&func_handle_);
|
||||
}
|
||||
@ -383,7 +414,9 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
|
||||
if (func_handle_ == kInvalidHandle) {
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
ConstructFunctionHandle(ctx->function_library(), ctx->device()->name()),
|
||||
ConstructFunctionHandle(ctx->function_library(), ctx->device()->name(),
|
||||
/*allow_soft_placement=*/true,
|
||||
ctx->num_inputs(), ctx->num_outputs()),
|
||||
*helper);
|
||||
}
|
||||
auto lib = ctx->function_library();
|
||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -61,7 +63,7 @@ class CombinedNmsTest(trt_test.TfTrtIntegrationTestBase):
|
||||
def GetParams(self):
|
||||
# Parameters
|
||||
q = 1
|
||||
batch_size = 1
|
||||
batch_size = 2
|
||||
num_boxes = 200
|
||||
num_classes = 2
|
||||
max_total_size = 3
|
||||
@ -98,5 +100,33 @@ class CombinedNmsTest(trt_test.TfTrtIntegrationTestBase):
|
||||
run_params.precision_mode), 'test >=TRT5.1 and non-INT8'
|
||||
|
||||
|
||||
class CombinedNmsExecuteNativeSegmentTest(CombinedNmsTest):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
os.environ['TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION'] = 'True'
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
os.environ['TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION'] = 'False'
|
||||
|
||||
def GetConversionParams(self, run_params):
|
||||
conversion_param = super().GetConversionParams(run_params)
|
||||
# Build the engine with the allowed max_batch_size less than the actual
|
||||
# max_batch_size, to fore the runtime to execute the native segment. This
|
||||
# is to test that combined_non_max_suppression, which doesn't have a TF GPU
|
||||
# implementation, can be executed natively even though the it is in the
|
||||
# the graph for the TRTEngineOp with a GPU as a default device.
|
||||
return conversion_param._replace(
|
||||
max_batch_size=conversion_param.max_batch_size - 1)
|
||||
|
||||
def ShouldRunTest(self, run_params):
|
||||
should_run, reason = super().ShouldRunTest(run_params)
|
||||
# max_batch_size is only useful for selecting static engines. As such,
|
||||
# we shouldn't run the test for dynamic engines.
|
||||
return should_run and \
|
||||
not run_params.dynamic_engine, reason + ' and static engines'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user