[TF:TRT] Enable allow_soft_placement when executing a TRTEngineOp graph natively.
When a TRTEngineOp falls back to native TensorFlow execution, we should run the graph with allow_soft_placement=true. This allows operations that are supported by TensorRT, but that do not have a TensorFlow GPU implementation, to run on the CPU. Previously, we set kIntsOnDeviceAttr=true for the TRTEngineOp function; this attribute is not compatible with allow_soft_placement. To support SavedModels generated by an older TF-TRT runtime, we issue a warning and run those TRTEngineOp functions with allow_soft_placement=false.

PiperOrigin-RevId: 342206089
Change-Id: I0ce7fb8a7778e1982e357db9197644bdd2b92359
parent 0886101a9d
commit 62ab208a35
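For context on how this code path is reached, here is a minimal, hypothetical Python sketch. The TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION environment variable is taken from the test added below; the converter calls and SavedModel paths are illustrative assumptions, not part of this commit.

# Hypothetical sketch (not part of this commit): exercise the native-segment
# fallback that this change makes soft-placement-aware.
import os

# Allow a TRTEngineOp to fall back to native TensorFlow execution at runtime.
os.environ['TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION'] = 'True'

from tensorflow.python.compiler.tensorrt import trt_convert as trt

converter = trt.TrtGraphConverterV2(
    input_saved_model_dir='/tmp/saved_model')  # assumed path
converter.convert()
converter.save('/tmp/trt_saved_model')  # assumed path

# When the native segment runs, ops that TensorRT supports but that lack a
# TensorFlow GPU kernel (e.g. CombinedNonMaxSuppression) are soft-placed on
# the CPU instead of failing placement.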
@@ -618,11 +618,6 @@ Status RegisterGraphToFunctionLibrary(const GraphDef& segment_graph_def,
   auto segment_func = library.add_function();
   TF_RETURN_IF_ERROR(GraphToFunctionDef(
       segment_graph, StrCat(engine_name, "_native_segment"), segment_func));
-  // Set kIntsonDeviceAttr to true so that all TRTEngineOp outputs are always on
-  // a GPU device as expected. Otherwise, some of the tensors of type DT_INT32
-  // would be on host if the op generating the tensor has host memory tag set.
-  (*segment_func->mutable_attr())[FunctionLibraryDefinition::kIntsOnDeviceAttr]
-      .set_b(true);
   if (VLOG_IS_ON(7)) {
     VLOG(7) << engine_name << " Function_Def ";
     VLOG(7) << segment_func->DebugString();
@@ -107,7 +107,9 @@ class TRTEngineOp : public AsyncOpKernel {
   // These are the exact same function.
 
   Status ConstructFunctionHandle(FunctionLibraryRuntime* lib,
-                                 const string& device_name);
+                                 const string& device_name,
+                                 bool allow_soft_placement = false,
+                                 size_t num_inputs = 0, size_t num_outputs = 0);
 
   // Executes replaced native segment as function Op.
   void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper);
@@ -259,7 +261,10 @@ static Status FunctionDefToGraphDef(FunctionLibraryRuntime::Handle handle,
 }
 
 Status TRTEngineOp::ConstructFunctionHandle(FunctionLibraryRuntime* lib,
-                                            const string& device_name) {
+                                            const string& device_name,
+                                            bool allow_soft_placement,
+                                            size_t num_inputs,
+                                            size_t num_outputs) {
   VLOG(1) << "Constructing function handle";
   if (lib == nullptr) {
     return errors::Internal("Context function library is null");
@@ -267,6 +272,32 @@ Status TRTEngineOp::ConstructFunctionHandle(FunctionLibraryRuntime* lib,
   FunctionLibraryRuntime::InstantiateOptions inst_ops;
   inst_ops.state_handle = "";
   inst_ops.target = device_name;
+  if (allow_soft_placement) {
+    const FunctionDef* fdef =
+        lib->GetFunctionLibraryDefinition()->Find(func_.name());
+    if (!fdef) {
+      return errors::Internal(
+          StrCat("Can't find FunctionDef for ", func_.name()));
+    }
+    bool ints_on_device =
+        fdef->attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
+        fdef->attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b();
+    // kIntsOnDeviceAttr is not compatible with is_multi_device_function,
+    // which is needed to support allow_soft_placement.
+    if (ints_on_device) {
+      LOG_FIRST_FEW_WARNING_WITH_PREFIX
+          << "Function " << name()
+          << " has attribute kIntsOnDeviceAttr=true "
+             "and will be executed natively with allow_soft_placement=false. "
+             "If this is a problem, please re-generate your SavedModel with "
+             "the TF-TRT runtime you are using.";
+    } else {
+      inst_ops.is_multi_device_function = true;
+      inst_ops.input_devices.resize(num_inputs, device_name);
+      inst_ops.output_devices.resize(num_outputs, device_name);
+      inst_ops.config_proto.set_allow_soft_placement(true);
+    }
+  }
   return lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), inst_ops,
                           &func_handle_);
 }
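The decision implemented by the new branch above can be summarized in a short Python sketch. This is illustrative only, not TensorFlow's API; it assumes FunctionLibraryDefinition::kIntsOnDeviceAttr resolves to the attribute name 'experimental_ints_on_device'.

# Illustrative sketch of the decision in ConstructFunctionHandle above.
def resolve_soft_placement(allow_soft_placement, fdef_attrs):
  """Returns the effective soft-placement setting for a native segment."""
  # Assumed attribute key behind kIntsOnDeviceAttr.
  ints_on_device = bool(fdef_attrs.get('experimental_ints_on_device', False))
  if allow_soft_placement and ints_on_device:
    # Old SavedModels carry kIntsOnDeviceAttr=true, which is incompatible
    # with multi-device instantiation; warn and disable soft placement.
    print('warning: executing natively with allow_soft_placement=false')
    return False
  return allow_soft_placement

assert resolve_soft_placement(True, {})
assert not resolve_soft_placement(True, {'experimental_ints_on_device': True})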
@@ -383,7 +414,9 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
   if (func_handle_ == kInvalidHandle) {
     OP_REQUIRES_OK_ASYNC(
         ctx,
-        ConstructFunctionHandle(ctx->function_library(), ctx->device()->name()),
+        ConstructFunctionHandle(ctx->function_library(), ctx->device()->name(),
+                                /*allow_soft_placement=*/true,
+                                ctx->num_inputs(), ctx->num_outputs()),
         *helper);
   }
   auto lib = ctx->function_library();
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
+
 from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -61,7 +63,7 @@ class CombinedNmsTest(trt_test.TfTrtIntegrationTestBase):
   def GetParams(self):
     # Parameters
     q = 1
-    batch_size = 1
+    batch_size = 2
     num_boxes = 200
     num_classes = 2
     max_total_size = 3
@@ -98,5 +100,33 @@ class CombinedNmsTest(trt_test.TfTrtIntegrationTestBase):
         run_params.precision_mode), 'test >=TRT5.1 and non-INT8'
 
 
+class CombinedNmsExecuteNativeSegmentTest(CombinedNmsTest):
+
+  def setUp(self):
+    super().setUp()
+    os.environ['TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION'] = 'True'
+
+  def tearDown(self):
+    super().tearDown()
+    os.environ['TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION'] = 'False'
+
+  def GetConversionParams(self, run_params):
+    conversion_param = super().GetConversionParams(run_params)
+    # Build the engine with an allowed max_batch_size smaller than the actual
+    # max_batch_size, to force the runtime to execute the native segment. This
+    # is to test that combined_non_max_suppression, which doesn't have a TF GPU
+    # implementation, can be executed natively even though it is in the graph
+    # for the TRTEngineOp with a GPU as a default device.
+    return conversion_param._replace(
+        max_batch_size=conversion_param.max_batch_size - 1)
+
+  def ShouldRunTest(self, run_params):
+    should_run, reason = super().ShouldRunTest(run_params)
+    # max_batch_size is only useful for selecting static engines. As such,
+    # we shouldn't run the test for dynamic engines.
+    return should_run and \
+        not run_params.dynamic_engine, reason + ' and static engines'
+
+
 if __name__ == '__main__':
   test.main()