From 669dc0c76a6b271a98047d522cf131eebfca1d08 Mon Sep 17 00:00:00 2001
From: Tamas Bela Feher
Date: Tue, 5 Nov 2019 16:32:20 -0800
Subject: [PATCH 1/2] Add allow_build_at_runtime option

---
 tensorflow/compiler/tf2tensorrt/BUILD         |  2 +
 .../tf2tensorrt/convert/convert_graph.cc      |  2 +
 .../tf2tensorrt/convert/convert_graph.h       |  1 +
 .../tf2tensorrt/convert/convert_nodes.h       |  4 +-
 .../convert/trt_optimization_pass.cc          |  4 ++
 .../convert/trt_optimization_pass.h           |  4 +-
 .../tf2tensorrt/kernels/trt_engine_op.cc      | 23 +++++++++-
 .../tf2tensorrt/kernels/trt_engine_op_test.cc | 32 +++++++++++++-
 .../python/compiler/tensorrt/trt_convert.py   | 43 +++++++++++++++++--
 9 files changed, 106 insertions(+), 9 deletions(-)

diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD
index a55ca56e551..82b682ed7a4 100644
--- a/tensorflow/compiler/tf2tensorrt/BUILD
+++ b/tensorflow/compiler/tf2tensorrt/BUILD
@@ -189,6 +189,8 @@ tf_cuda_cc_test(
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
         "//tensorflow/core/kernels:ops_testutil",
+        "//tensorflow/core/kernels:function_ops",
+        "//tensorflow/core/kernels:array",
     ] + if_tensorrt([
         "@local_config_cuda//cuda:cuda_headers",
     ]),
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc
index 0131d45f815..f17361fb211 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc
@@ -468,6 +468,7 @@ Status CreateTRTNode(const ConversionParams& params,
           .Attr("precision_mode", prec_string)
           .Attr("use_calibration", info.use_calibration)
           .Attr("_use_implicit_batch", params.use_implicit_batch)
+          .Attr("_allow_build_at_runtime", info.allow_build_at_runtime)
           .Attr("OutT", out_types)
           .Finalize(&trt_node);
   if (!status.ok()) {
@@ -671,6 +672,7 @@ Status ConvertAfterShapes(const ConversionParams& params) {
                                : EngineInfo::EngineType::TRTStatic);
     curr_engine.use_calibration = params.use_calibration;
     curr_engine.maximum_cached_engines = params.max_cached_engines;
+    curr_engine.allow_build_at_runtime = params.allow_build_at_runtime;
     status = RegisterGraphToFunctionLibrary(curr_engine.segment_graph_def,
                                             &graph, curr_engine.engine_name);
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h
index 00dc4c72f43..2bfaa2a786c 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h
@@ -49,6 +49,7 @@ struct ConversionParams {
   int max_cached_engines = 1;
   bool use_calibration = true;
   bool use_implicit_batch = true;
+  bool allow_build_at_runtime = true;
 };
 
 // Method to call from optimization pass
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
index d295f074a98..4375af8ad3f 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
@@ -92,7 +92,8 @@ struct EngineInfo {
       : engine_type(EngineType::TRTStatic),
         max_workspace_size_bytes(0),
         precision_mode(TrtPrecisionMode::FP32),
-        use_calibration(true) {}
+        use_calibration(true),
+        allow_build_at_runtime(true) {}
 
   string engine_name;
   string device;
@@ -109,6 +110,7 @@ struct EngineInfo {
   int maximum_cached_engines;
   TrtPrecisionMode precision_mode;
   bool use_calibration;
+  bool allow_build_at_runtime;
 };
 
 // Constructs a graphdef from the segment in the given graph. Adds _Arg
diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc
index 757ddd159c9..7995163ed44 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc
@@ -70,6 +70,9 @@ Status TRTOptimizationPass::Init(
   if (params.count("trt_logger")) {
     trt_logger_name_ = params.at("trt_logger").s();
   }
+  if (params.count("allow_build_at_runtime")) {
+    allow_build_at_runtime_ = params.at("allow_build_at_runtime").b();
+  }
   if (params.count("use_implicit_batch")) {
     use_implicit_batch_ = params.at("use_implicit_batch").b();
   }
@@ -265,6 +268,7 @@ Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster,
   cp.max_cached_engines = max_cached_batches_;
   cp.use_calibration = use_calibration_;
   cp.use_implicit_batch = use_implicit_batch_;
+  cp.allow_build_at_runtime = allow_build_at_runtime_;
   auto status = ConvertAfterShapes(cp);
   VLOG(1) << "Returning from " << name_;
   return status;
diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h
index 3ce0d09b7c0..f79048bb5f6 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h
+++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h
@@ -42,7 +42,8 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer {
         max_cached_batches_(1),
         max_workspace_size_bytes_(256LL << 20),
         use_calibration_(true),
-        use_implicit_batch_(true) {
+        use_implicit_batch_(true),
+        allow_build_at_runtime_(true) {
     VLOG(1) << "Constructing " << name_;
   }
 
@@ -75,6 +76,7 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer {
   int64_t max_workspace_size_bytes_;
   bool use_calibration_;
   bool use_implicit_batch_;
+  bool allow_build_at_runtime_;
 };
 
 }  // namespace convert
diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
index 909e3e11006..b98e75527cc 100644
--- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
@@ -157,6 +157,9 @@ class TRTEngineOp : public AsyncOpKernel {
   // Whether to use implicit batch dimension for TensorRT
   bool use_implicit_batch_;
 
+  // Whether to build TensorRT engines at runtime
+  bool allow_build_at_runtime_;
+
   // Maximum number of cached engines
   int max_cached_engines_;
 
@@ -281,6 +284,14 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
                  context->GetAttr("use_calibration", &use_calibration_));
   OP_REQUIRES_OK(context,
                  context->GetAttr("input_shapes", &input_partial_shapes_));
+  auto status =
+      context->GetAttr("_allow_build_at_runtime", &allow_build_at_runtime_);
+  if (status.code() == tensorflow::error::NOT_FOUND) {
+    VLOG(2) << "Not found _allow_build_at_runtime in "
+            << context->device()->name()
+            << ", thus setting _allow_build_at_runtime=true";
+    allow_build_at_runtime_ = true;
+  }
   func_handle_ = kInvalidHandle;
   if (!static_engine_) {
     FunctionLibraryRuntime* lib = context->function_library();
@@ -302,7 +313,7 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
 
   OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count",
                                             &max_cached_engines_));
-  auto status = context->GetAttr("_use_implicit_batch", &use_implicit_batch_);
+  status = context->GetAttr("_use_implicit_batch", &use_implicit_batch_);
   if (status.code() == tensorflow::error::NOT_FOUND) {
     VLOG(2) << "Not found _use_implicit_batch in " << context->device()->name()
             << ", thus setting _use_implicit_batch=true";
@@ -957,6 +968,16 @@ StatusOr<EngineContext*> TRTEngineOp::GetEngine(
   // If matched, use that engine. Otherwise, we will look in cache for that
   // exact shape and possibly create a new engine if it is not in cache.
   if (!cache.count(engine_input_shapes)) {
+    if (!allow_build_at_runtime_) {
+      LOG(WARNING) << "Found no engine in cache matching input shapes. "
+                   << "Not building a new engine because "
+                   << "allow_build_at_runtime=False. "
+                   << "The native segment will be used instead.";
+      // Store an empty engine in the cache for these input shapes so we don't
+      // try to build the same failing engine again.
+      cache.emplace(engine_input_shapes, absl::make_unique<EngineContext>());
+      return &empty_context;
+    }
     TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
     bool convert_successfully = false;
     LOG(INFO) << "Building a new TensorRT engine for " << name()
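Note on the plumbing shown above: TRTOptimizationPass::Init() reads allow_build_at_runtime from the grappler parameter_map, ConvertAfterShapes() copies it into each EngineInfo, and CreateTRTNode() stamps it onto the TRTEngineOp as the _allow_build_at_runtime attribute that the kernel constructor reads. For TF 1.x-style session configuration the flag can therefore be toggled through the rewriter's parameter_map. A minimal sketch, not part of this patch; it assumes the TF-TRT pass is registered under the name "TensorRTOptimizer" (as done in trt_convert.py) and only uses parameters that appear in this patch:

    # Sketch: toggling allow_build_at_runtime through a RewriterConfig
    # custom optimizer entry (TF 1.x style session configuration).
    from tensorflow.core.protobuf import config_pb2

    config = config_pb2.ConfigProto()
    trt_optimizer = config.graph_options.rewrite_options.custom_optimizers.add()
    trt_optimizer.name = "TensorRTOptimizer"  # assumed registration name
    trt_optimizer.parameter_map["is_dynamic_op"].b = True
    trt_optimizer.parameter_map["use_implicit_batch"].b = True
    # New in this patch: do not build engines at inference time; inputs with
    # no matching cached engine fall back to the native TF segment.
    trt_optimizer.parameter_map["allow_build_at_runtime"].b = False

In TF 2.x the same parameter is filled in by get_tensorrt_rewriter_config() from TrtConversionParams, as shown in the trt_convert.py hunks below.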
diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc
index a88f2b5e29e..2cf20e443fb 100644
--- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc
+++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc
@@ -23,7 +23,6 @@ limitations under the License.
 #include "absl/container/inlined_vector.h"
 #include "absl/strings/str_cat.h"
 #include "absl/types/span.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
 #include "tensorflow/cc/framework/scope.h"
 #include "tensorflow/cc/ops/function_ops.h"
 #include "tensorflow/cc/ops/math_ops.h"
@@ -49,6 +48,7 @@ limitations under the License.
 #include "tensorflow/core/platform/refcount.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/public/version.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
 
 #if GOOGLE_CUDA
 #if GOOGLE_TENSORRT
@@ -62,7 +62,8 @@ class TRTEngineOpTestBase : public OpsTestBase {
  public:
   void AddSimpleTrtOp(DataType dtype, int max_cached_engines_count = 1,
                       PartialTensorShape shape = PartialTensorShape({-1, -1}),
-                      bool use_implicit_batch = true) {
+                      bool use_implicit_batch = true,
+                      bool allow_build_at_runtime = true) {
     // Create the GPU device.
     std::unique_ptr<Device> device(
         DeviceFactory::NewDevice("GPU", {}, "/job:worker/replica:0/task:0"));
@@ -104,6 +105,7 @@ class TRTEngineOpTestBase : public OpsTestBase {
                      .Attr("precision_mode", "FP32")
                      .Attr("use_calibration", false)
                      .Attr("_use_implicit_batch", use_implicit_batch)
+                     .Attr("_allow_build_at_runtime", allow_build_at_runtime)
                      .Attr("OutT", {dtype})
                      .Finalize(OpsTestBase::node_def()));
     TF_ASSERT_OK(InitOpWithFunctionLibrary());
@@ -186,6 +188,32 @@ TEST_F(TRTEngineOpTestBase, DynamicEngines) {
   EXPECT_EQ(1, cache->count({TensorShape({10, 10})}));
 }
 
+TEST_F(TRTEngineOpTestBase, AllowBuildAtRuntime) {
+  TRTEngineOpTestBase::AddSimpleTrtOp(DT_FLOAT, /*max_cached_engines_count=*/1,
+                                      PartialTensorShape({-1, -1}),
+                                      /*use_implicit_batch=*/true,
+                                      /*allow_build_at_runtime=*/false);
+
+  // Execute the op
+  TensorShape input_shape({2, 2});
+  TRTEngineOpTestBase::AddSimpleInput<float>(input_shape);
+  TF_ASSERT_OK(OpsTestBase::RunOpKernel());
+
+  // Get the engine cache.
+  TRTEngineCacheResource* cache_resource = nullptr;
+  TF_ASSERT_OK(
+      device_->resource_manager()->Lookup("TF-TRT", "myop", &cache_resource));
+  core::ScopedUnref sc(cache_resource);
+
+  // It should contain a placeholder with an empty cuda_engine (to mark that
+  // engine creation was not successful for the given input shape).
+  auto cache = &cache_resource->cache_;
+  EXPECT_EQ(1, cache->size());
+  ASSERT_EQ(1, cache->count({input_shape}));
+  EngineContext* ectx = cache->at({input_shape}).get();
+  EXPECT_EQ(ectx->cuda_engine, nullptr);
+}
+
 TEST_F(TRTEngineOpTestBase, ExplicitBatch) {
   // Test inference in explicit batch mode with static input shapes. Static
   // shapes in this context means that the TensorRT knows all the input shapes
diff --git a/tensorflow/python/compiler/tensorrt/trt_convert.py b/tensorflow/python/compiler/tensorrt/trt_convert.py
index 2ea22ebba49..f56f7a9b5d0 100644
--- a/tensorflow/python/compiler/tensorrt/trt_convert.py
+++ b/tensorflow/python/compiler/tensorrt/trt_convert.py
@@ -116,7 +116,7 @@ DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30
 class TrtConversionParams(collections.namedtuple("TrtConversionParams", [
     "rewriter_config_template", "max_workspace_size_bytes", "precision_mode",
     "minimum_segment_size", "is_dynamic_op", "maximum_cached_engines",
-    "use_calibration", "max_batch_size"])):
+    "use_calibration", "max_batch_size", "allow_build_at_runtime"])):
   """Parameters that are used for TF-TRT conversion.
 
   Fields:
@@ -151,6 +151,11 @@ class TrtConversionParams(collections.namedtuple("TrtConversionParams", [
       tensors were trained with fake quantization.
     max_batch_size: max size for the input batch. This parameter is only
       effective when is_dynamic_op=False which is not supported in TF 2.0.
+    allow_build_at_runtime: whether to build TensorRT engines during runtime.
+      If no TensorRT engine can be found in cache that can handle the given
+      inputs during runtime, then a new TensorRT engine is built at runtime if
+      allow_build_at_runtime=True, and otherwise native TF is used. This
+      argument is only effective if is_dynamic_op=True.
   """
 
   def __new__(cls,
@@ -161,11 +166,12 @@ class TrtConversionParams(collections.namedtuple("TrtConversionParams", [
               is_dynamic_op=True,
               maximum_cached_engines=1,
               use_calibration=True,
-              max_batch_size=1):
+              max_batch_size=1,
+              allow_build_at_runtime=True):
     return super(TrtConversionParams, cls).__new__(
         cls, rewriter_config_template, max_workspace_size_bytes, precision_mode,
         minimum_segment_size, is_dynamic_op, maximum_cached_engines,
-        use_calibration, max_batch_size)
+        use_calibration, max_batch_size, allow_build_at_runtime)
 
 
 DEFAULT_TRT_CONVERSION_PARAMS = TrtConversionParams()
@@ -228,6 +234,13 @@ def _check_conversion_params(conversion_params, is_v2=False):
         not trt_optimizer.parameter_map["is_dynamic_op"]):
       raise ValueError("Option is_dynamic_op=False is not supported "
                        "in TF 2.0, please set it to True instead.")
+  if (conversion_params.allow_build_at_runtime and
+      not conversion_params.is_dynamic_op):
+    tf_logging.warn((
+        "Building TensorRT engines at runtime is not supported "
+        "if is_dynamic_op=False, therefore assuming "
+        "allow_build_at_runtime=False. If building TensorRT engines "
+        "at runtime is desired, set is_dynamic_op=True."))
 
 
 def _check_trt_version_compatibility():
@@ -320,6 +333,8 @@ def get_tensorrt_rewriter_config(conversion_params,
   optimizer.parameter_map[
       "use_calibration"].b = conversion_params.use_calibration
   optimizer.parameter_map["is_dynamic_op"].b = conversion_params.is_dynamic_op
+  optimizer.parameter_map[
+      "allow_build_at_runtime"].b = conversion_params.allow_build_at_runtime
   if not is_v2:
     optimizer.parameter_map[
         "max_batch_size"].i = conversion_params.max_batch_size
@@ -505,7 +520,8 @@ class TrtGraphConverter(object):
         is_dynamic_op=is_dynamic_op,
         maximum_cached_engines=maximum_cached_engines,
         use_calibration=use_calibration,
-        max_batch_size=max_batch_size)
+        max_batch_size=max_batch_size,
+        allow_build_at_runtime=True)
     _check_conversion_params(self._conversion_params)
 
   def _run_conversion(self):
@@ -1165,6 +1181,25 @@ class TrtGraphConverterV2(object):
     signatures = {
         key: value for key, value in self._saved_model.signatures.items()
     }
+
+    # Set allow_build_at_runtime=False if asked by user.
+    # This attribute is set here because build() needs it to be True
+    # in order to build engines.
+    if not self._conversion_params.allow_build_at_runtime:
+      def _reset_allow_build_at_runtime(node):
+        node.attr["allow_build_at_runtime"].b = False
+      self._for_each_trt_node(self._converted_graph_def,
+                              _reset_allow_build_at_runtime)
+      # Rebuild the function since a node attribute changed above
+      reset_converted_func = wrap_function.function_from_graph_def(
+          self._converted_graph_def,
+          [tensor.name for tensor in self._converted_func.inputs],
+          [tensor.name for tensor in self._converted_func.outputs])
+      reset_converted_func.graph.structured_outputs = nest.pack_sequence_as(
+          self._converted_func.graph.structured_outputs,
+          reset_converted_func.graph.structured_outputs)
+      self._converted_func = reset_converted_func
+
     signatures[self._input_saved_model_signature_key] = self._converted_func
 
     save.save(self._saved_model, output_saved_model_dir, signatures)
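Usage note for the Python API touched above (a minimal sketch, not part of the patch; the SavedModel paths and the input function below are placeholders). Engines are pre-built with converter.build(), and save() is where the TRTEngineOp attribute is reset when allow_build_at_runtime=False:

    from tensorflow.python.compiler.tensorrt import trt_convert as trt
    import numpy as np

    params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
        # New field added by this patch: at inference time only pre-built
        # engines are used; unmatched input shapes run the native segment.
        allow_build_at_runtime=False)

    converter = trt.TrtGraphConverterV2(
        input_saved_model_dir="/tmp/saved_model",  # placeholder path
        conversion_params=params)
    converter.convert()

    def input_fn():
      # Representative input shapes so build() can pre-build the engines.
      yield (np.zeros((1, 28, 28, 1), dtype=np.float32),)

    converter.build(input_fn=input_fn)
    converter.save("/tmp/saved_model_trt")  # placeholder path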
From 3050e7ddd10ad1b09dca3b30d6fcf2441ca6cf4f Mon Sep 17 00:00:00 2001
From: Tamas Bela Feher
Date: Mon, 17 Feb 2020 21:08:06 +0100
Subject: [PATCH 2/2] Fix bad_function_call

---
 tensorflow/core/kernels/ops_testutil.cc | 6 +++++-
 tensorflow/core/kernels/ops_testutil.h  | 2 ++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/tensorflow/core/kernels/ops_testutil.cc b/tensorflow/core/kernels/ops_testutil.cc
index 3dab8bf2f50..614e184b0b2 100644
--- a/tensorflow/core/kernels/ops_testutil.cc
+++ b/tensorflow/core/kernels/ops_testutil.cc
@@ -71,6 +71,9 @@ OpsTestBase::OpsTestBase() : device_type_(DEVICE_CPU) {
   auto device = DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
   CHECK(device) << "Could not create CPU device";
 
+  thread_pool_ = absl::make_unique<thread::ThreadPool>(
+      Env::Default(), /*name=*/"default", /*num_threads=*/1);
+
   device_ = device.get();
   device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(device));
 
@@ -104,7 +107,8 @@ void OpsTestBase::SetDevice(const DeviceType& device_type,
   device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(device));
   pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
       device_mgr_.get(), Env::Default(), /*config=*/nullptr,
-      TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions());
+      TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions(),
+      thread_pool_.get());
   device_type_ = device_type;
 
 #ifdef GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/ops_testutil.h b/tensorflow/core/kernels/ops_testutil.h
index ab7b994d9d2..f6821e3c49c 100644
--- a/tensorflow/core/kernels/ops_testutil.h
+++ b/tensorflow/core/kernels/ops_testutil.h
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
+#include "tensorflow/core/platform/threadpool.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/device_base.h"
 #include "tensorflow/core/framework/function.h"
@@ -183,6 +184,7 @@ class OpsTestBase : public ::testing::Test {
   std::unique_ptr<FunctionLibraryDefinition> flib_def_;
   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+  std::unique_ptr<thread::ThreadPool> thread_pool_;
 
  private:
   TF_DISALLOW_COPY_AND_ASSIGN(OpsTestBase);