Merge pull request #36852 from tfeher:allow_build_at_runtime
PiperOrigin-RevId: 296283355 Change-Id: I1161707e70e26c1f348104958bbc08cd6354e7b2
Commit 2a91c19008
@@ -189,6 +189,8 @@ tf_cuda_cc_test(
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
         "//tensorflow/core/kernels:ops_testutil",
+        "//tensorflow/core/kernels:function_ops",
+        "//tensorflow/core/kernels:array",
     ] + if_tensorrt([
         "@local_config_cuda//cuda:cuda_headers",
     ]),
@@ -469,6 +469,7 @@ Status CreateTRTNode(const ConversionParams& params,
           .Attr("precision_mode", prec_string)
           .Attr("use_calibration", info.use_calibration)
           .Attr("_use_implicit_batch", params.use_implicit_batch)
+          .Attr("_allow_build_at_runtime", info.allow_build_at_runtime)
           .Attr("OutT", out_types)
           .Finalize(&trt_node);
   if (!status.ok()) {
@@ -672,6 +673,7 @@ Status ConvertAfterShapes(const ConversionParams& params) {
                                : EngineInfo::EngineType::TRTStatic);
     curr_engine.use_calibration = params.use_calibration;
     curr_engine.maximum_cached_engines = params.max_cached_engines;
+    curr_engine.allow_build_at_runtime = params.allow_build_at_runtime;

     status = RegisterGraphToFunctionLibrary(curr_engine.segment_graph_def,
                                             &graph, curr_engine.engine_name);
@@ -49,6 +49,7 @@ struct ConversionParams {
   int max_cached_engines = 1;
   bool use_calibration = true;
   bool use_implicit_batch = true;
+  bool allow_build_at_runtime = true;
 };

 // Method to call from optimization pass
@@ -93,7 +93,8 @@ struct EngineInfo {
       : engine_type(EngineType::TRTStatic),
         max_workspace_size_bytes(0),
         precision_mode(TrtPrecisionMode::FP32),
-        use_calibration(true) {}
+        use_calibration(true),
+        allow_build_at_runtime(true) {}

   string engine_name;
   string device;
@@ -110,6 +111,7 @@ struct EngineInfo {
   int maximum_cached_engines;
   TrtPrecisionMode precision_mode;
   bool use_calibration;
+  bool allow_build_at_runtime;
 };

 // Constructs a graphdef from the segment in the given graph. Adds _Arg
@@ -70,6 +70,9 @@ Status TRTOptimizationPass::Init(
   if (params.count("trt_logger")) {
     trt_logger_name_ = params.at("trt_logger").s();
   }
+  if (params.count("allow_build_at_runtime")) {
+    allow_build_at_runtime_ = params.at("allow_build_at_runtime").b();
+  }
   if (params.count("use_implicit_batch")) {
     use_implicit_batch_ = params.at("use_implicit_batch").b();
   }
@@ -265,6 +268,7 @@ Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster,
   cp.max_cached_engines = max_cached_batches_;
   cp.use_calibration = use_calibration_;
   cp.use_implicit_batch = use_implicit_batch_;
+  cp.allow_build_at_runtime = allow_build_at_runtime_;
   auto status = ConvertAfterShapes(cp);
   VLOG(1) << "Returning from " << name_;
   return status;
@@ -42,7 +42,8 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer {
         max_cached_batches_(1),
         max_workspace_size_bytes_(256LL << 20),
         use_calibration_(true),
-        use_implicit_batch_(true) {
+        use_implicit_batch_(true),
+        allow_build_at_runtime_(true) {
     VLOG(1) << "Constructing " << name_;
   }
@@ -75,6 +76,7 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer {
   int64_t max_workspace_size_bytes_;
   bool use_calibration_;
   bool use_implicit_batch_;
+  bool allow_build_at_runtime_;
 };

 }  // namespace convert
@@ -159,6 +159,9 @@ class TRTEngineOp : public AsyncOpKernel {
   // Whether to use implicit batch dimension for TensorRT
   bool use_implicit_batch_;

+  // Whether to build TensorRT engines at runtime
+  bool allow_build_at_runtime_;
+
   // Maximum number of cached engines
   int max_cached_engines_;
@@ -283,6 +286,14 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
                  context->GetAttr("use_calibration", &use_calibration_));
   OP_REQUIRES_OK(context,
                  context->GetAttr("input_shapes", &input_partial_shapes_));
+  auto status =
+      context->GetAttr("_allow_build_at_runtime", &allow_build_at_runtime_);
+  if (status.code() == tensorflow::error::NOT_FOUND) {
+    VLOG(2) << "Not found _allow_build_at_runtime in "
+            << context->device()->name()
+            << ", thus setting _allow_build_at_runtime=true";
+    allow_build_at_runtime_ = true;
+  }
   func_handle_ = kInvalidHandle;
   if (!static_engine_) {
     FunctionLibraryRuntime* lib = context->function_library();
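The constructor hunk above reads the new `_allow_build_at_runtime` attribute but tolerates NOT_FOUND, so graphs converted before this change still load with the old behavior. A minimal sketch of that optional-attribute pattern, assuming a hypothetical helper name (`get_bool_attr` is not TF API):

```python
# Sketch of the optional-attr fallback used in TRTEngineOp's constructor.
# get_bool_attr is a hypothetical helper, not part of TensorFlow.
def get_bool_attr(node_def, name, default=True):
  # NodeDef.attr is a map<string, AttrValue>; graphs converted before this
  # change simply lack the "_allow_build_at_runtime" key.
  if name in node_def.attr:
    return node_def.attr[name].b
  return default


# Usage: a missing attr yields the default True, mirroring the NOT_FOUND
# branch above, e.g.:
# allow_build = get_bool_attr(node_def, "_allow_build_at_runtime")
```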
@@ -304,7 +315,7 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
   OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count",
                                            &max_cached_engines_));

-  auto status = context->GetAttr("_use_implicit_batch", &use_implicit_batch_);
+  status = context->GetAttr("_use_implicit_batch", &use_implicit_batch_);
   if (status.code() == tensorflow::error::NOT_FOUND) {
     VLOG(2) << "Not found _use_implicit_batch in " << context->device()->name()
             << ", thus setting _use_implicit_batch=true";
@@ -979,6 +990,16 @@ StatusOr<EngineContext*> TRTEngineOp::GetEngine(
   // If matched, use that engine. Otherwise, we will look in cache for that
   // exact shape and possibly create a new engine if it is not in cache.
   if (!cache.count(engine_input_shapes)) {
+    if (!allow_build_at_runtime_) {
+      LOG(WARNING) << "Found no engine in cache matching input shapes. "
+                   << "Not building a new engine because "
+                   << "allow_build_at_runtime=False. "
+                   << "The native segment will be used instead.";
+      // Store an empty engine in the cache for these input shapes so we don't
+      // try to build the same failing engine again.
+      cache.emplace(engine_input_shapes, absl::make_unique<EngineContext>());
+      return &empty_context;
+    }
     TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
     bool convert_successfully = false;
     LOG(INFO) << "Building a new TensorRT engine for " << name()
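This hunk carries the core runtime behavior: on a cache miss with building disabled, the op logs a warning, stores an empty EngineContext so the same miss is cheap on later calls, and signals fallback to the native TF segment. A hedged Python sketch of that policy, with a plain dict standing in for the engine cache (none of these names are actual TF-TRT API):

```python
# Illustrative model of the cache policy added to TRTEngineOp::GetEngine.
# cache, build_engine, and the return contract are stand-ins, not TF APIs.
def get_engine(cache, input_shapes, allow_build_at_runtime, build_engine):
  key = tuple(input_shapes)
  if key not in cache:
    if not allow_build_at_runtime:
      # Cache an empty placeholder so this miss is not re-evaluated on
      # every call; None tells the caller to run the native TF segment.
      cache[key] = None
      return None
    cache[key] = build_engine(input_shapes)  # potentially expensive build
  return cache[key]
```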
@@ -62,7 +62,8 @@ class TRTEngineOpTestBase : public OpsTestBase {
  public:
   void AddSimpleTrtOp(DataType dtype, int max_cached_engines_count = 1,
                       PartialTensorShape shape = PartialTensorShape({-1, -1}),
-                      bool use_implicit_batch = true) {
+                      bool use_implicit_batch = true,
+                      bool allow_build_at_runtime = true) {
     // Create the GPU device.
     std::unique_ptr<Device> device(
         DeviceFactory::NewDevice("GPU", {}, "/job:worker/replica:0/task:0"));
@@ -104,6 +105,7 @@ class TRTEngineOpTestBase : public OpsTestBase {
                    .Attr("precision_mode", "FP32")
                    .Attr("use_calibration", false)
                    .Attr("_use_implicit_batch", use_implicit_batch)
+                   .Attr("_allow_build_at_runtime", allow_build_at_runtime)
                    .Attr("OutT", {dtype})
                    .Finalize(OpsTestBase::node_def()));
     TF_ASSERT_OK(InitOpWithFunctionLibrary());
@@ -191,6 +193,32 @@ TEST_F(TRTEngineOpTestBase, DynamicEngines) {
   EXPECT_EQ(1, cache->count({TensorShape({10, 10})}));
 }

+TEST_F(TRTEngineOpTestBase, AllowBuildAtRuntime) {
+  TRTEngineOpTestBase::AddSimpleTrtOp(DT_FLOAT, /*max_cached_engines_count=*/1,
+                                      PartialTensorShape({-1, -1}),
+                                      /*use_implicit_batch=*/true,
+                                      /*allow_build_at_runtime=*/false);
+
+  // Execute the op
+  TensorShape input_shape({2, 2});
+  TRTEngineOpTestBase::AddSimpleInput<float>(input_shape);
+  TF_ASSERT_OK(OpsTestBase::RunOpKernel());
+
+  // Get the engine cache.
+  TRTEngineCacheResource* cache_resource = nullptr;
+  TF_ASSERT_OK(
+      device_->resource_manager()->Lookup("TF-TRT", "myop", &cache_resource));
+  core::ScopedUnref sc(cache_resource);
+
+  // It should contain a placeholder with an empty cuda_engine (to mark that
+  // engine creation was not successful for the given input shape).
+  auto cache = &cache_resource->cache_;
+  EXPECT_EQ(1, cache->size());
+  ASSERT_EQ(1, cache->count({input_shape}));
+  EngineContext* ectx = cache->at({input_shape}).get();
+  EXPECT_EQ(ectx->cuda_engine, nullptr);
+}
+
 #if IS_TRT_VERSION_GE(6, 0, 0, 0)
 TEST_F(TRTEngineOpTestBase, ExplicitBatch) {
   // Test inference in explicit batch mode with static input shapes. Static
@@ -71,6 +71,9 @@ OpsTestBase::OpsTestBase() : device_type_(DEVICE_CPU) {
   auto device = DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
   CHECK(device) << "Could not create CPU device";

+  thread_pool_ = absl::make_unique<thread::ThreadPool>(
+      Env::Default(), /*name=*/"default", /*num_threads=*/1);
+
   device_ = device.get();
   device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(device));
@@ -104,7 +107,8 @@ void OpsTestBase::SetDevice(const DeviceType& device_type,
   device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(device));
   pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
       device_mgr_.get(), Env::Default(), /*config=*/nullptr,
-      TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions());
+      TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions(),
+      thread_pool_.get());

   device_type_ = device_type;
 #ifdef GOOGLE_CUDA
@@ -49,6 +49,7 @@ limitations under the License.
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/threadpool.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/public/session_options.h"
 #include "tensorflow/core/public/version.h"
@@ -183,6 +184,7 @@ class OpsTestBase : public ::testing::Test {

   std::unique_ptr<FunctionLibraryDefinition> flib_def_;
   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+  std::unique_ptr<thread::ThreadPool> thread_pool_;

  private:
   TF_DISALLOW_COPY_AND_ASSIGN(OpsTestBase);
@@ -113,10 +113,13 @@ DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30


 @tf_export("experimental.tensorrt.ConversionParams", v1=[])
-class TrtConversionParams(collections.namedtuple("TrtConversionParams", [
-    "rewriter_config_template", "max_workspace_size_bytes", "precision_mode",
-    "minimum_segment_size", "is_dynamic_op", "maximum_cached_engines",
-    "use_calibration", "max_batch_size"])):
+class TrtConversionParams(
+    collections.namedtuple("TrtConversionParams", [
+        "rewriter_config_template", "max_workspace_size_bytes",
+        "precision_mode", "minimum_segment_size", "is_dynamic_op",
+        "maximum_cached_engines", "use_calibration", "max_batch_size",
+        "allow_build_at_runtime"
+    ])):
   """Parameters that are used for TF-TRT conversion.

   Fields:
@@ -151,6 +154,11 @@ class TrtConversionParams(collections.namedtuple("TrtConversionParams", [
       tensors were trained with fake quantization.
     max_batch_size: max size for the input batch. This parameter is only
       effective when is_dynamic_op=False which is not supported in TF 2.0.
+    allow_build_at_runtime: whether to build TensorRT engines during runtime.
+      If no TensorRT engine can be found in cache that can handle the given
+      inputs during runtime, then a new TensorRT engine is built at runtime if
+      allow_build_at_runtime=True, and otherwise native TF is used. This
+      argument is only effective if is_dynamic_op=True.
   """

   def __new__(cls,
@@ -161,11 +169,14 @@ class TrtConversionParams(collections.namedtuple("TrtConversionParams", [
               is_dynamic_op=True,
               maximum_cached_engines=1,
               use_calibration=True,
-              max_batch_size=1):
-    return super(TrtConversionParams, cls).__new__(
-        cls, rewriter_config_template, max_workspace_size_bytes, precision_mode,
-        minimum_segment_size, is_dynamic_op, maximum_cached_engines,
-        use_calibration, max_batch_size)
+              max_batch_size=1,
+              allow_build_at_runtime=True):
+    return super(TrtConversionParams,
+                 cls).__new__(cls, rewriter_config_template,
+                              max_workspace_size_bytes, precision_mode,
+                              minimum_segment_size, is_dynamic_op,
+                              maximum_cached_engines, use_calibration,
+                              max_batch_size, allow_build_at_runtime)


 DEFAULT_TRT_CONVERSION_PARAMS = TrtConversionParams()
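Given the extended namedtuple, a typical opt-out from runtime building would look like the sketch below; the import path matches this module, while the SavedModel directory is a placeholder:

```python
# Hedged usage sketch for the new allow_build_at_runtime field; the
# SavedModel path is a placeholder, not part of this change.
from tensorflow.python.compiler.tensorrt import trt_convert as trt

params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
    allow_build_at_runtime=False)  # engines must exist before inference

converter = trt.TrtGraphConverterV2(
    input_saved_model_dir="/tmp/my_saved_model",  # placeholder path
    conversion_params=params)
converter.convert()
```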
@@ -228,6 +239,13 @@ def _check_conversion_params(conversion_params, is_v2=False):
           not trt_optimizer.parameter_map["is_dynamic_op"]):
         raise ValueError("Option is_dynamic_op=False is not supported "
                          "in TF 2.0, please set it to True instead.")
+  if (conversion_params.allow_build_at_runtime and
+      not conversion_params.is_dynamic_op):
+    tf_logging.warn(
+        ("Building TensorRT engines at runtime is not supported "
+         "if is_dynamic_op=False, therefore assuming "
+         "allow_build_at_runtime=False. If building TensorRT engines "
+         "at runtime is desired, set is_dynamic_op=True."))


 def _check_trt_version_compatibility():
@@ -320,6 +338,8 @@ def get_tensorrt_rewriter_config(conversion_params,
   optimizer.parameter_map[
       "use_calibration"].b = conversion_params.use_calibration
   optimizer.parameter_map["is_dynamic_op"].b = conversion_params.is_dynamic_op
+  optimizer.parameter_map[
+      "allow_build_at_runtime"].b = conversion_params.allow_build_at_runtime
   if not is_v2:
     optimizer.parameter_map[
         "max_batch_size"].i = conversion_params.max_batch_size
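The parameter written here is what TRTOptimizationPass::Init reads back on the C++ side via params.at("allow_build_at_runtime").b(). A sketch of the same plumbing done by hand, assuming the pass is registered under the name "TensorRTOptimizer":

```python
# Sketch of the RewriterConfig plumbing performed by
# get_tensorrt_rewriter_config above; "TensorRTOptimizer" is assumed to be
# the registered name of the TF-TRT grappler pass.
from tensorflow.core.protobuf import rewriter_config_pb2

rewriter_config = rewriter_config_pb2.RewriterConfig()
optimizer = rewriter_config.custom_optimizers.add()
optimizer.name = "TensorRTOptimizer"
optimizer.parameter_map["allow_build_at_runtime"].b = False
# TRTOptimizationPass::Init then picks this value up on the C++ side.
```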
@@ -505,7 +525,8 @@ class TrtGraphConverter(object):
         is_dynamic_op=is_dynamic_op,
         maximum_cached_engines=maximum_cached_engines,
         use_calibration=use_calibration,
-        max_batch_size=max_batch_size)
+        max_batch_size=max_batch_size,
+        allow_build_at_runtime=True)
     _check_conversion_params(self._conversion_params)

   def _run_conversion(self):
@@ -1165,6 +1186,28 @@ class TrtGraphConverterV2(object):
     signatures = {
         key: value for key, value in self._saved_model.signatures.items()
     }

+    # Set allow_build_at_runtime=False if asked by user.
+    #
+    # This attribute is set here because build() needs it to be True in order
+    # to build engines.
+    if not self._conversion_params.allow_build_at_runtime:
+
+      def _reset_allow_build_at_runtime(node):
+        node.attr["allow_build_at_runtime"].b = False
+
+      self._for_each_trt_node(self._converted_graph_def,
+                              _reset_allow_build_at_runtime)
+      # Rebuild the function since a node attribute changed above
+      reset_converted_func = wrap_function.function_from_graph_def(
+          self._converted_graph_def,
+          [tensor.name for tensor in self._converted_func.inputs],
+          [tensor.name for tensor in self._converted_func.outputs])
+      reset_converted_func.graph.structured_outputs = nest.pack_sequence_as(
+          self._converted_func.graph.structured_outputs,
+          reset_converted_func.graph.structured_outputs)
+      self._converted_func = reset_converted_func
+
     signatures[self._input_saved_model_signature_key] = self._converted_func
     save.save(self._saved_model, output_saved_model_dir, signatures)
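The comment in save() implies a specific workflow: convert and pre-build engines while building is still allowed, and only stamp allow_build_at_runtime=False into the TRTEngineOp nodes at save time. A hedged end-to-end sketch (paths, shapes, and the input generator are placeholders):

```python
# Convert -> build -> save flow enabled by this hunk; paths, shapes, and the
# input generator are placeholders, not from the diff.
import numpy as np
from tensorflow.python.compiler.tensorrt import trt_convert as trt

params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
    allow_build_at_runtime=False)
converter = trt.TrtGraphConverterV2(
    input_saved_model_dir="/tmp/my_saved_model",  # placeholder
    conversion_params=params)
converter.convert()

def input_fn():
  # One representative input per distinct shape to pre-build engines for.
  yield (np.random.normal(size=(1, 28, 28, 1)).astype(np.float32),)

converter.build(input_fn)         # engines built here, while still allowed
converter.save("/tmp/trt_model")  # saved with allow_build_at_runtime=False
```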
@@ -3,6 +3,10 @@ tf_class {
   is_instance: "<class \'tensorflow.python.compiler.tensorrt.trt_convert.TrtConversionParams\'>"
   is_instance: "<type \'tuple\'>"
+  member {
+    name: "allow_build_at_runtime"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "is_dynamic_op"
     mtype: "<type \'property\'>"