Add TRT engine resource ops for serializing and deserializing the TRT engine cache.

PiperOrigin-RevId: 243750961
commit 7109ac9ac6
parent 178fa6b8b9

Changes under tensorflow/compiler/tf2tensorrt:
tensorflow/compiler/tf2tensorrt/BUILD

@@ -23,9 +23,13 @@ load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
 load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
 load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
 load(
-    "@local_config_tensorrt//:build_defs.bzl",
-    "if_tensorrt",
+    "//tensorflow/core:platform/default/build_config.bzl",
+    "tf_additional_all_protos",
+    "tf_proto_library",
 )
+load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt")
 
 # Google-internal targets go here (must be at the end).
 
 tf_cuda_cc_test(
     name = "tensorrt_test_cc",
@@ -74,13 +78,67 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "trt_engine_resource_op_kernels",
+    srcs = ["kernels/trt_engine_resource_ops.cc"],
+    copts = tf_copts(),
+    visibility = ["//visibility:private"],
+    deps = [
+        ":trt_allocator",
+        ":trt_engine_instance_proto_cc",
+        ":trt_logging",
+        ":trt_plugins",
+        ":trt_resources",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:gpu_headers_lib",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:lib_proto_parsing",
+    ] + if_tensorrt([
+        "@local_config_tensorrt//:tensorrt",
+    ]) + tf_custom_op_library_additional_deps(),
+    alwayslink = 1,
+)
+
+tf_cuda_cc_test(
+    name = "trt_engine_resource_ops_test",
+    size = "small",
+    srcs = ["kernels/trt_engine_resource_ops_test.cc"],
+    tags = [
+        "no_cuda_on_cpu_tap",
+        "no_windows",
+        "nomac",
+    ],
+    deps = [
+        ":trt_engine_instance_proto_cc",
+        ":trt_engine_resource_op_kernels",
+        ":trt_engine_resource_ops_op_lib",
+        ":trt_logging",
+        ":trt_resources",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core/kernels:ops_testutil",
+        "//tensorflow/core/kernels:resource_variable_ops",
+        "@com_google_absl//absl/memory",
+    ],
+)
+
 tf_cc_shared_object(
     name = "python/ops/libtftrt.so",
     copts = tf_copts(is_external = True),
     linkopts = ["-lm"],
     deps = [
         ":trt_op_kernels",
+        ":trt_engine_resource_op_kernels",
         ":trt_op_libs",
+        ":trt_engine_resource_ops_op_lib",
         "//tensorflow/core:lib_proto_parsing",
     ] + if_tensorrt([
         "@local_config_tensorrt//:tensorrt",
@@ -145,6 +203,7 @@ tf_gen_op_libs(
     op_lib_names = [
         "trt_engine_op",
         "get_serialized_resource_op",
+        "trt_engine_resource_ops",
     ],
 )
 
@@ -171,6 +230,7 @@ tf_cuda_library(
 tf_gen_op_wrapper_py(
     name = "trt_ops",
     deps = [
+        ":trt_engine_resource_ops_op_lib",
         ":trt_op_libs",
     ],
 )
@@ -185,7 +245,9 @@ tf_custom_op_py_library(
     ]),
     kernels = [
         ":trt_op_kernels",
+        ":trt_engine_resource_op_kernels",
         ":trt_op_libs",
+        ":trt_engine_resource_ops_op_lib",
     ],
     srcs_version = "PY2AND3",
     deps = [
@@ -470,6 +532,13 @@ cc_library(
     ],
 )
 
+tf_proto_library(
+    name = "trt_engine_instance_proto",
+    srcs = ["utils/trt_engine_instance.proto"],
+    cc_api_version = 2,
+    protodeps = tf_additional_all_protos(),
+)
+
 cc_library(
     name = "py_utils",
     srcs = ["utils/py_utils.cc"],

tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc (new file)

@@ -0,0 +1,223 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <memory>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.pb.h"  // NOLINT
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/platform/logging.h"

#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "tensorrt/include/NvInfer.h"

namespace tensorflow {
namespace tensorrt {
using ::nvinfer1::IRuntime;

class CreateTRTEngineCache : public OpKernel {
 public:
  explicit CreateTRTEngineCache(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("resource_name", &resource_name_));
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr("max_cached_engines_count", &max_cached_engines_));
  }

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "Creating TRT engine cache resource in container " << container_
            << " for op " << resource_name_ << " on device "
            << ctx->device()->name();
    OP_REQUIRES_OK(ctx,
                   ctx->resource_manager()->Create(
                       container_, resource_name_,
                       new TRTEngineCacheResource(ctx, max_cached_engines_)));

    Tensor* handle;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
    handle->scalar<ResourceHandle>()() =
        MakeResourceHandle<TRTEngineCacheResource>(ctx, container_,
                                                   resource_name_);
  }

 private:
  string container_;
  string resource_name_;

  // Maximum number of cached engines
  int max_cached_engines_;

  TF_DISALLOW_COPY_AND_ASSIGN(CreateTRTEngineCache);
};

REGISTER_KERNEL_BUILDER(Name("CreateTRTEngineCache")
                            .Device(DEVICE_GPU)
                            .HostMemory("engine_cache_handle"),
                        CreateTRTEngineCache);

class PopulateTRTEngineCache : public OpKernel {
 public:
  explicit PopulateTRTEngineCache(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    ResourceHandle handle = HandleFromInput(ctx, 0);
    TRTEngineCacheResource* resource = nullptr;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, handle, &resource));
    core::ScopedUnref unref_me(resource);

    auto allocator = resource->allocator_.get();
    OP_REQUIRES(ctx, allocator != nullptr,
                errors::Internal("Not able to initialize TRT engine cache when "
                                 "GPU allocator is empty."));
    OP_REQUIRES(ctx, resource->cache_.size() == 0,
                errors::Internal("Expect engine cache to be empty, but got ",
                                 resource->cache_.size(), " entries."));

    // Get the file name.
    const string& filename = ctx->input(1).scalar<string>()();
    OP_REQUIRES(ctx, !filename.empty(),
                errors::InvalidArgument("filename cannot be empty."));

    // Parse the serialized engines and add them to the cache.
    std::unique_ptr<RandomAccessFile> file;
    OP_REQUIRES_OK(ctx, ctx->env()->NewRandomAccessFile(filename, &file));
    auto reader = absl::make_unique<io::RecordReader>(file.get());

    uint64 offset = 0;
    int num_loaded_engine = 0;
    do {
      string record;
      Status status = reader->ReadRecord(&offset, &record);
      if (errors::IsOutOfRange(status)) break;

      TRTEngineInstance engine_instance;
      engine_instance.ParseFromString(record);
      std::vector<TensorShape> engine_input_shapes;
      for (const TensorShapeProto& shape : engine_instance.input_shapes()) {
        engine_input_shapes.emplace_back(shape);
      }

      TrtUniquePtrType<IRuntime> infer(
          nvinfer1::createInferRuntime(TRTEngineCacheResource::GetLogger()));
      infer->setGpuAllocator(allocator);
      TrtUniquePtrType<nvinfer1::ICudaEngine> engine(
          infer->deserializeCudaEngine(
              engine_instance.serialized_engine().c_str(),
              engine_instance.serialized_engine().size(),
              PluginFactoryTensorRT::GetInstance()));
      auto raw_engine = engine.get();
      resource->cache_.emplace(
          engine_input_shapes,
          absl::make_unique<EngineContext>(
              std::move(engine), TrtUniquePtrType<nvinfer1::IExecutionContext>(
                                     raw_engine->createExecutionContext())));
      ++num_loaded_engine;
    } while (1);
    VLOG(1) << "Loaded " << num_loaded_engine << " TRT engines to container "
            << handle.container() << " for op " << handle.name()
            << " on device " << ctx->device()->name() << " from file "
            << filename;
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(PopulateTRTEngineCache);
};

REGISTER_KERNEL_BUILDER(Name("PopulateTRTEngineCache")
                            .Device(DEVICE_GPU)
                            .HostMemory("engine_cache_handle"),
                        PopulateTRTEngineCache);

class DumpTRTEngineCache : public OpKernel {
 public:
  explicit DumpTRTEngineCache(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("delete_cache_after_dump",
                                     &delete_cache_after_dump_));
  }

  void Compute(OpKernelContext* ctx) override {
    const string& container = ctx->input(0).scalar<string>()();
    const string& resource_name = ctx->input(1).scalar<string>()();
    const string& filename = ctx->input(2).scalar<string>()();
    OP_REQUIRES(ctx, !filename.empty(),
                errors::InvalidArgument("filename cannot be empty."));

    TRTEngineCacheResource* resource = nullptr;
    OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup(
                            container, resource_name, &resource));
    core::ScopedUnref unref_me(resource);

    // Serialize the engines and write them to file.
    std::unique_ptr<WritableFile> file;
    OP_REQUIRES_OK(ctx, ctx->env()->NewWritableFile(filename, &file));
    auto writer = absl::make_unique<io::RecordWriter>(file.get());

    for (const auto& pair : resource->cache_) {
      TRTEngineInstance engine_instance;
      // Add input shapes.
      const std::vector<TensorShape>& engine_input_shapes = pair.first;
      for (const TensorShape& shape : engine_input_shapes) {
        shape.AsProto(engine_instance.add_input_shapes());
      }
      // Add the serialized engine.
      const std::unique_ptr<EngineContext>& engine = pair.second;
      TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(
          engine->cuda_engine->serialize());
      engine_instance.set_serialized_engine(engine_data->data(),
                                            engine_data->size());

      OP_REQUIRES_OK(ctx,
                     writer->WriteRecord(engine_instance.SerializeAsString()));
    }
    VLOG(1) << "Serialized " << resource->cache_.size()
            << " TRT engines in container " << container << " for op "
            << resource_name << " on device " << ctx->device()->name()
            << " to file " << filename;

    if (delete_cache_after_dump_) {
      VLOG(1) << "Destroying TRT engine cache resource in container "
              << container << " for op " << resource_name << " on device "
              << ctx->device()->name();
      OP_REQUIRES_OK(ctx,
                     ctx->resource_manager()->Delete<TRTEngineCacheResource>(
                         container, resource_name));
    }
  }

 private:
  bool delete_cache_after_dump_ = false;

  TF_DISALLOW_COPY_AND_ASSIGN(DumpTRTEngineCache);
};

REGISTER_KERNEL_BUILDER(Name("DumpTRTEngineCache").Device(DEVICE_GPU),
                        DumpTRTEngineCache);

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_TENSORRT
#endif  // GOOGLE_CUDA
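
The file written by DumpTRTEngineCache is an ordinary TFRecord stream whose records are serialized TRTEngineInstance protos. As a minimal sketch (not part of this commit), a standalone tool could inspect such a file with the same RecordReader and proto APIs the kernels above use; the file path is a placeholder and error handling is reduced to CHECKs:

// Sketch only: offline inspection of a cache file produced by
// DumpTRTEngineCache. "/tmp/trt_engine_file" is a placeholder path.
#include <memory>
#include <string>

#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"

int main() {
  using tensorflow::tensorrt::TRTEngineInstance;

  std::unique_ptr<tensorflow::RandomAccessFile> file;
  TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(
      "/tmp/trt_engine_file", &file));
  tensorflow::io::RecordReader reader(file.get());

  tensorflow::uint64 offset = 0;
  std::string record;
  int index = 0;
  // Iterate until ReadRecord stops returning OK (OUT_OF_RANGE at end of
  // file), mirroring the loop in PopulateTRTEngineCache.
  while (reader.ReadRecord(&offset, &record).ok()) {
    TRTEngineInstance engine_instance;
    CHECK(engine_instance.ParseFromString(record));
    LOG(INFO) << "engine " << index++ << ": "
              << engine_instance.input_shapes_size() << " input shape(s), "
              << engine_instance.serialized_engine().size()
              << " bytes of serialized engine data";
  }
  return 0;
}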

tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops_test.cc (new file)

@@ -0,0 +1,205 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <dirent.h>
#include <string.h>

#include <fstream>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.pb.h"  // NOLINT
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"

#if GOOGLE_CUDA
#if GOOGLE_TENSORRT

namespace tensorflow {
namespace tensorrt {

class TRTEngineResourceOpsTest : public OpsTestBase {
 protected:
  void Reset() {
    inputs_.clear();
    gtl::STLDeleteElements(&tensors_);
    gtl::STLDeleteElements(&managed_outputs_);
  }

  TrtUniquePtrType<nvinfer1::ICudaEngine> CreateTRTEngine() {
    Logger logger;
    TrtUniquePtrType<nvinfer1::IBuilder> builder(
        nvinfer1::createInferBuilder(logger));
    TrtUniquePtrType<nvinfer1::INetworkDefinition> network(
        builder->createNetwork());

    // Add the input.
    nvinfer1::Dims dims;
    dims.nbDims = 1;
    dims.d[0] = 1;
    nvinfer1::ITensor* input =
        network->addInput("input", nvinfer1::DataType::kFLOAT, dims);
    EXPECT_NE(nullptr, input);

    // Add a unary layer.
    nvinfer1::IUnaryLayer* layer =
        network->addUnary(*input, nvinfer1::UnaryOperation::kEXP);
    EXPECT_NE(nullptr, layer);

    // Mark the output.
    nvinfer1::ITensor* output = layer->getOutput(0);
    output->setName("output");
    network->markOutput(*output);

    // Build the engine.
    builder->setMaxBatchSize(1);
    builder->setMaxWorkspaceSize(1 << 10);
    TrtUniquePtrType<nvinfer1::ICudaEngine> engine(
        builder->buildCudaEngine(*network));
    EXPECT_NE(nullptr, engine);
    return engine;
  }
};

TEST_F(TRTEngineResourceOpsTest, Basic) {
  // Create the GPU device.
  std::unique_ptr<Device> device(
      DeviceFactory::NewDevice("GPU", {}, "/job:worker/replica:0/task:0"));
  ResourceMgr* rm = device->resource_manager();
  SetDevice(DEVICE_GPU, std::move(device));

  // Create the resource.
  const string container = "mycontainer";
  const string resource_name = "myresource";
  Reset();
  TF_ASSERT_OK(NodeDefBuilder("op", "CreateTRTEngineCache")
                   .Attr("container", container)
                   .Attr("resource_name", resource_name)
                   .Attr("max_cached_engines_count", 1)
                   .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());
  TF_ASSERT_OK(RunOpKernel());
  ResourceHandle handle =
      context_->mutable_output(0)->scalar<ResourceHandle>()();

  TRTEngineCacheResource* resource = nullptr;
  EXPECT_TRUE(rm->Lookup(container, resource_name, &resource).ok());

  // Create a serialized TRT engine file.
  TrtUniquePtrType<nvinfer1::ICudaEngine> engine = CreateTRTEngine();
  TrtUniquePtrType<nvinfer1::IExecutionContext> context(
      engine->createExecutionContext());
  resource->cache_.emplace(
      std::vector<TensorShape>{TensorShape({1, 1})},
      absl::make_unique<EngineContext>(std::move(engine), std::move(context)));
  resource->Unref();

  // Serialize the engine using DumpTRTEngineCache op.
  Reset();
  TF_ASSERT_OK(NodeDefBuilder("op", "DumpTRTEngineCache")
                   .Attr("delete_cache_after_dump", true)
                   .Input(FakeInput(DT_STRING))
                   .Input(FakeInput(DT_STRING))
                   .Input(FakeInput(DT_STRING))
                   .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());
  AddInputFromArray<string>(TensorShape({}), {container});
  AddInputFromArray<string>(TensorShape({}), {resource_name});
  const string filename = io::JoinPath(testing::TmpDir(), "trt_engine_file");
  AddInputFromArray<string>(TensorShape({}), {filename});
  TF_ASSERT_OK(RunOpKernel());

  // Make sure the cache is deleted.
  Reset();
  TF_ASSERT_OK(NodeDefBuilder("op", "DestroyResourceOp")
                   .Attr("ignore_lookup_error", false)
                   .Input(FakeInput(DT_RESOURCE))
                   .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());
  AddInputFromArray<ResourceHandle>(TensorShape({}), {handle});
  EXPECT_TRUE(errors::IsNotFound(RunOpKernel()));

  // Verify the serialized engine file.
  Env* env = Env::Default();
  std::unique_ptr<RandomAccessFile> file;
  TF_ASSERT_OK(env->NewRandomAccessFile(filename, &file));
  auto reader = absl::make_unique<io::RecordReader>(file.get());
  uint64 offset = 0;
  string record;
  TF_ASSERT_OK(reader->ReadRecord(&offset, &record));
  TRTEngineInstance engine_instance;
  engine_instance.ParseFromString(record);
  EXPECT_EQ(1, engine_instance.input_shapes_size());
  EXPECT_EQ(2, engine_instance.input_shapes(0).dim_size());
  EXPECT_EQ(1, engine_instance.input_shapes(0).dim(0).size());
  EXPECT_EQ(1, engine_instance.input_shapes(0).dim(1).size());
  EXPECT_TRUE(errors::IsOutOfRange(reader->ReadRecord(&offset, &record)));

  // Recreate the cache resource.
  Reset();
  TF_ASSERT_OK(NodeDefBuilder("op", "CreateTRTEngineCache")
                   .Attr("container", container)
                   .Attr("resource_name", resource_name)
                   .Attr("max_cached_engines_count", 1)
                   .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());
  TF_ASSERT_OK(RunOpKernel());
  handle = context_->mutable_output(0)->scalar<ResourceHandle>()();
  EXPECT_TRUE(rm->Lookup(container, resource_name, &resource).ok());
  EXPECT_EQ(0, resource->cache_.size());
  resource->Unref();

  // Deserialize the engine using PopulateTRTEngineCache op.
  Reset();
  TF_ASSERT_OK(NodeDefBuilder("op", "PopulateTRTEngineCache")
                   .Input(FakeInput(DT_RESOURCE))
                   .Input(FakeInput(DT_STRING))
                   .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());
  AddInputFromArray<ResourceHandle>(TensorShape({}), {handle});
  AddInputFromArray<string>(TensorShape({}), {filename});
  TF_ASSERT_OK(RunOpKernel());
  EXPECT_TRUE(rm->Lookup(container, resource_name, &resource).ok());
  EXPECT_EQ(1, resource->cache_.size());
  resource->Unref();

  // Destroy the engine cache again.
  Reset();
  TF_ASSERT_OK(NodeDefBuilder("op", "DestroyResourceOp")
                   .Attr("ignore_lookup_error", false)
                   .Input(FakeInput(DT_RESOURCE))
                   .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());
  AddInputFromArray<ResourceHandle>(TensorShape({}), {handle});
  TF_ASSERT_OK(RunOpKernel());
  EXPECT_TRUE(errors::IsNotFound(RunOpKernel()));
}

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_TENSORRT
#endif  // GOOGLE_CUDA

tensorflow/compiler/tf2tensorrt/ops/trt_engine_resource_ops.cc (new file)

@@ -0,0 +1,52 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA
#if GOOGLE_TENSORRT

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {

REGISTER_OP("CreateTRTEngineCache")
    .Attr("container: string")
    .Attr("resource_name: string")
    .Attr("max_cached_engines_count: int = 1")
    .Output("engine_cache_handle: resource")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("PopulateTRTEngineCache")
    .Input("engine_cache_handle: resource")
    .Input("filename: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::NoOutputs);

REGISTER_OP("DumpTRTEngineCache")
    .Attr("delete_cache_after_dump: bool = false")
    .Input("container: string")
    .Input("resource_name: string")
    .Input("filename: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::NoOutputs);

}  // namespace tensorflow

#endif  // GOOGLE_TENSORRT
#endif  // GOOGLE_CUDA

tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.proto (new file)

@@ -0,0 +1,19 @@
syntax = "proto3";

package tensorflow.tensorrt;

import "tensorflow/core/framework/tensor_shape.proto";

// Contains the information for a serialized TensorRT engine.
message TRTEngineInstance {
  // The input shapes of the TRT engine.
  repeated TensorShapeProto input_shapes = 1;

  // The serialized TRT engine.
  //
  // TODO(laigd): consider using a more efficient in-memory representation
  // instead of string, which is the default here.
  bytes serialized_engine = 2;

  // TODO(laigd): consider adding calibration stats, precision_modes, etc.
}
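
For illustration (not part of this commit), the snippet below round-trips a TRTEngineInstance through the same TensorShape helpers the kernels rely on: DumpTRTEngineCache fills input_shapes via TensorShape::AsProto, and PopulateTRTEngineCache rebuilds TensorShape objects from the parsed protos. The engine payload here is dummy bytes, not a real serialized engine:

// Sketch only: TRTEngineInstance round-trip with a dummy engine payload.
#include <string>
#include <vector>

#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/logging.h"

int main() {
  using tensorflow::TensorShape;
  using tensorflow::tensorrt::TRTEngineInstance;

  TRTEngineInstance instance;
  // Input shapes are stored as TensorShapeProto, exactly as the dump kernel
  // writes them with TensorShape::AsProto.
  TensorShape({1, 28, 28}).AsProto(instance.add_input_shapes());
  instance.set_serialized_engine("not-a-real-engine");

  // Serialize and parse back, as the RecordWriter/RecordReader pair would.
  std::string record = instance.SerializeAsString();
  TRTEngineInstance parsed;
  CHECK(parsed.ParseFromString(record));

  // Rebuild TensorShape objects the way PopulateTRTEngineCache does.
  std::vector<TensorShape> shapes;
  for (const auto& shape_proto : parsed.input_shapes()) {
    shapes.emplace_back(shape_proto);
  }
  CHECK_EQ(1, shapes.size());
  LOG(INFO) << "restored input shape: " << shapes[0].DebugString();
  return 0;
}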

tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc

@@ -30,6 +30,11 @@ limitations under the License.
 namespace tensorflow {
 namespace tensorrt {
 
+Logger& TRTEngineCacheResource::GetLogger() {
+  static Logger* logger = new Logger();
+  return *logger;
+}
+
 TRTEngineCacheResource::TRTEngineCacheResource(OpKernelContext* ctx,
                                                size_t capacity)
     : cache_(capacity) {

tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h

@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
+#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/lib/core/errors.h"
 
@@ -141,6 +142,13 @@ struct EngineContext {
 
 class TRTEngineCacheResource : public ResourceBase {
  public:
+  // According to the TensorRT API, the logger is considered a singleton by the
+  // TensorRT library, and multiple instances of IRuntime and/or IBuilder must
+  // all use the same logger. So here we make it a singleton.
+  //
+  // TODO(laigd): use this logger in all places where conversion happens.
+  static Logger& GetLogger();
+
   TRTEngineCacheResource(OpKernelContext* ctx, size_t capacity);
 
   ~TRTEngineCacheResource() override;
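
The three kernels added by this commit are thin wrappers around the standard ResourceMgr lifecycle: Create registers the cache, Lookup retrieves it (adding a reference the caller must release), and Delete destroys it. A minimal standalone sketch of that pattern follows; it uses a dummy resource type in place of TRTEngineCacheResource (whose constructor needs an OpKernelContext), and assumes the non-const ResourceBase::DebugString signature of the TF version this commit targets:

// Sketch only: the ResourceMgr Create/Lookup/Delete lifecycle that
// CreateTRTEngineCache, PopulateTRTEngineCache, and DumpTRTEngineCache
// build on. DemoResource stands in for TRTEngineCacheResource.
#include <string>

#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

// Note: in the TF version of this commit, ResourceBase::DebugString is
// non-const; later versions declare it const.
struct DemoResource : public tensorflow::ResourceBase {
  std::string DebugString() override { return "DemoResource"; }
};

int main() {
  tensorflow::ResourceMgr rm;

  // CreateTRTEngineCache: the manager takes ownership of the new resource.
  TF_CHECK_OK(rm.Create("mycontainer", "myresource", new DemoResource));

  // PopulateTRTEngineCache/DumpTRTEngineCache: Lookup adds a reference, so
  // the caller must Unref (the kernels use core::ScopedUnref for this).
  DemoResource* resource = nullptr;
  TF_CHECK_OK(rm.Lookup("mycontainer", "myresource", &resource));
  LOG(INFO) << resource->DebugString();
  resource->Unref();

  // DumpTRTEngineCache with delete_cache_after_dump=true: remove the
  // resource from the manager.
  TF_CHECK_OK(rm.Delete<DemoResource>("mycontainer", "myresource"));
  return 0;
}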