diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc
index 8fbe0f4ceb9..b3120312e3b 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc
@@ -618,11 +618,6 @@ Status RegisterGraphToFunctionLibrary(const GraphDef& segment_graph_def,
   auto segment_func = library.add_function();
   TF_RETURN_IF_ERROR(GraphToFunctionDef(
       segment_graph, StrCat(engine_name, "_native_segment"), segment_func));
-  // Set kIntsonDeviceAttr to true so that all TRTEngineOp outputs are always on
-  // a GPU device as expected. Otherwise, some of the tensors of type DT_INT32
-  // would be on host if the op generating the tensor has host memory tag set.
-  (*segment_func->mutable_attr())[FunctionLibraryDefinition::kIntsOnDeviceAttr]
-      .set_b(true);
   if (VLOG_IS_ON(7)) {
     VLOG(7) << engine_name << " Function_Def ";
     VLOG(7) << segment_func->DebugString();
diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
index 23fd0095da1..e0a731e502e 100644
--- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
@@ -107,7 +107,9 @@ class TRTEngineOp : public AsyncOpKernel {
 
   // These are the exact same function.
   Status ConstructFunctionHandle(FunctionLibraryRuntime* lib,
-                                 const string& device_name);
+                                 const string& device_name,
+                                 bool allow_soft_placement = false,
+                                 size_t num_inputs = 0, size_t num_outputs = 0);
 
   // Executes replaced native segment as function Op.
   void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper);
@@ -259,7 +261,10 @@ static Status FunctionDefToGraphDef(FunctionLibraryRuntime::Handle handle,
 }
 
 Status TRTEngineOp::ConstructFunctionHandle(FunctionLibraryRuntime* lib,
-                                            const string& device_name) {
+                                            const string& device_name,
+                                            bool allow_soft_placement,
+                                            size_t num_inputs,
+                                            size_t num_outputs) {
   VLOG(1) << "Constructing function handle";
   if (lib == nullptr) {
     return errors::Internal("Context function library is null");
@@ -267,6 +272,32 @@ Status TRTEngineOp::ConstructFunctionHandle(FunctionLibraryRuntime* lib,
   FunctionLibraryRuntime::InstantiateOptions inst_ops;
   inst_ops.state_handle = "";
   inst_ops.target = device_name;
+  if (allow_soft_placement) {
+    const FunctionDef* fdef =
+        lib->GetFunctionLibraryDefinition()->Find(func_.name());
+    if (!fdef) {
+      return errors::Internal(
+          StrCat("Can't find FunctionDef for ", func_.name()));
+    }
+    bool ints_on_device =
+        fdef->attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
+        fdef->attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b();
+    // kIntsOnDeviceAttr is not compatible with is_multi_device_function which
+    // is needed to support allow_soft_placement.
+    if (ints_on_device) {
+      LOG_FIRST_FEW_WARNING_WITH_PREFIX
+          << "Function " << name()
+          << " has attribute kIntsOnDeviceAttr=true "
+             "and will be executed natively with allow_soft_placement=false. "
+             "If this is a problem, please re-generate your SavedModel with "
+             "the TF-TRT runtime you are using.";
+    } else {
+      inst_ops.is_multi_device_function = true;
+      inst_ops.input_devices.resize(num_inputs, device_name);
+      inst_ops.output_devices.resize(num_outputs, device_name);
+      inst_ops.config_proto.set_allow_soft_placement(true);
+    }
+  }
   return lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), inst_ops,
                           &func_handle_);
 }
@@ -383,7 +414,9 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
   if (func_handle_ == kInvalidHandle) {
     OP_REQUIRES_OK_ASYNC(
         ctx,
-        ConstructFunctionHandle(ctx->function_library(), ctx->device()->name()),
+        ConstructFunctionHandle(ctx->function_library(), ctx->device()->name(),
+                                /*allow_soft_placement=*/true,
+                                ctx->num_inputs(), ctx->num_outputs()),
         *helper);
   }
   auto lib = ctx->function_library();
diff --git a/tensorflow/python/compiler/tensorrt/test/combined_nms_test.py b/tensorflow/python/compiler/tensorrt/test/combined_nms_test.py
index 3f2a5469ae6..23397d76cc3 100644
--- a/tensorflow/python/compiler/tensorrt/test/combined_nms_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/combined_nms_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
+
 from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -61,7 +63,7 @@ class CombinedNmsTest(trt_test.TfTrtIntegrationTestBase):
   def GetParams(self):
     # Parameters
     q = 1
-    batch_size = 1
+    batch_size = 2
     num_boxes = 200
     num_classes = 2
     max_total_size = 3
@@ -98,5 +100,33 @@ class CombinedNmsTest(trt_test.TfTrtIntegrationTestBase):
         run_params.precision_mode), 'test >=TRT5.1 and non-INT8'
 
 
+class CombinedNmsExecuteNativeSegmentTest(CombinedNmsTest):
+
+  def setUp(self):
+    super().setUp()
+    os.environ['TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION'] = 'True'
+
+  def tearDown(self):
+    super().tearDown()
+    os.environ['TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION'] = 'False'
+
+  def GetConversionParams(self, run_params):
+    conversion_param = super().GetConversionParams(run_params)
+    # Build the engine with a max_batch_size that is smaller than the actual
+    # max_batch_size, to force the runtime to execute the native segment.
+    # This is to test that combined_non_max_suppression, which doesn't have a
+    # TF GPU implementation, can be executed natively even though it is in the
+    # graph for the TRTEngineOp with a GPU as the default device.
+    return conversion_param._replace(
+        max_batch_size=conversion_param.max_batch_size - 1)
+
+  def ShouldRunTest(self, run_params):
+    should_run, reason = super().ShouldRunTest(run_params)
+    # max_batch_size is only useful for selecting static engines. As such,
+    # we shouldn't run the test for dynamic engines.
+    return should_run and \
+        not run_params.dynamic_engine, reason + ' and static engines'
+
+
 if __name__ == '__main__':
   test.main()