From 7109ac9ac6d4cfdd249e148793fa7e3e43e640a9 Mon Sep 17 00:00:00 2001 From: Guangda Lai <laigd@google.com> Date: Mon, 15 Apr 2019 23:03:12 -0700 Subject: [PATCH] Add TRT engine resource ops for serialization/deserialization of TRT engine cache. PiperOrigin-RevId: 243750961 --- tensorflow/compiler/tf2tensorrt/BUILD | 73 +++++- .../kernels/trt_engine_resource_ops.cc | 223 ++++++++++++++++++ .../kernels/trt_engine_resource_ops_test.cc | 205 ++++++++++++++++ .../ops/trt_engine_resource_ops.cc | 52 ++++ .../utils/trt_engine_instance.proto | 19 ++ .../tf2tensorrt/utils/trt_lru_cache.cc | 5 + .../tf2tensorrt/utils/trt_lru_cache.h | 8 + 7 files changed, 583 insertions(+), 2 deletions(-) create mode 100644 tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc create mode 100644 tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops_test.cc create mode 100644 tensorflow/compiler/tf2tensorrt/ops/trt_engine_resource_ops.cc create mode 100644 tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.proto diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index c51dae1b4b7..6d27a5012bf 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -23,9 +23,13 @@ load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") load( - "@local_config_tensorrt//:build_defs.bzl", - "if_tensorrt", + "//tensorflow/core:platform/default/build_config.bzl", + "tf_additional_all_protos", + "tf_proto_library", ) +load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") + +# Google-internal targets go here (must be at the end). tf_cuda_cc_test( name = "tensorrt_test_cc", @@ -74,13 +78,67 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "trt_engine_resource_op_kernels", + srcs = ["kernels/trt_engine_resource_ops.cc"], + copts = tf_copts(), + visibility = ["//visibility:private"], + deps = [ + ":trt_allocator", + ":trt_engine_instance_proto_cc", + ":trt_logging", + ":trt_plugins", + ":trt_resources", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "//tensorflow/core:framework", + "//tensorflow/core:gpu_headers_lib", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]) + tf_custom_op_library_additional_deps(), + alwayslink = 1, +) + +tf_cuda_cc_test( + name = "trt_engine_resource_ops_test", + size = "small", + srcs = ["kernels/trt_engine_resource_ops_test.cc"], + tags = [ + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", + ], + deps = [ + ":trt_engine_instance_proto_cc", + ":trt_engine_resource_op_kernels", + ":trt_engine_resource_ops_op_lib", + ":trt_logging", + ":trt_resources", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", + "//tensorflow/core/kernels:resource_variable_ops", + "@com_google_absl//absl/memory", + ], +) + tf_cc_shared_object( name = "python/ops/libtftrt.so", copts = tf_copts(is_external = True), linkopts = ["-lm"], deps = [ ":trt_op_kernels", + ":trt_engine_resource_op_kernels", ":trt_op_libs", + ":trt_engine_resource_ops_op_lib", "//tensorflow/core:lib_proto_parsing", ] + if_tensorrt([ "@local_config_tensorrt//:tensorrt", @@ 
-145,6 +203,7 @@ tf_gen_op_libs( op_lib_names = [ "trt_engine_op", "get_serialized_resource_op", + "trt_engine_resource_ops", ], ) @@ -171,6 +230,7 @@ tf_cuda_library( tf_gen_op_wrapper_py( name = "trt_ops", deps = [ + ":trt_engine_resource_ops_op_lib", ":trt_op_libs", ], ) @@ -185,7 +245,9 @@ tf_custom_op_py_library( ]), kernels = [ ":trt_op_kernels", + ":trt_engine_resource_op_kernels", ":trt_op_libs", + ":trt_engine_resource_ops_op_lib", ], srcs_version = "PY2AND3", deps = [ @@ -470,6 +532,13 @@ cc_library( ], ) +tf_proto_library( + name = "trt_engine_instance_proto", + srcs = ["utils/trt_engine_instance.proto"], + cc_api_version = 2, + protodeps = tf_additional_all_protos(), +) + cc_library( name = "py_utils", srcs = ["utils/py_utils.cc"], diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc new file mode 100644 index 00000000000..06031bd731d --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc @@ -0,0 +1,223 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <algorithm> +#include <memory> +#include <vector> + +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.pb.h" // NOLINT +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/io/record_reader.h" +#include "tensorflow/core/lib/io/record_writer.h" +#include "tensorflow/core/platform/logging.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { +using ::nvinfer1::IRuntime; + +class CreateTRTEngineCache : public OpKernel { + public: + explicit CreateTRTEngineCache(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("resource_name", &resource_name_)); + OP_REQUIRES_OK( + ctx, ctx->GetAttr("max_cached_engines_count", &max_cached_engines_)); + } + + void Compute(OpKernelContext* ctx) override { + VLOG(1) << "Creating TRT engine cache resource in container " << container_ + << " for op " << resource_name_ << " on device " + << ctx->device()->name(); + OP_REQUIRES_OK(ctx, + ctx->resource_manager()->Create( + container_, resource_name_, + new TRTEngineCacheResource(ctx, max_cached_engines_))); + + Tensor* handle; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle)); + handle->scalar<ResourceHandle>()() = + 
MakeResourceHandle<TRTEngineCacheResource>(ctx, container_, + resource_name_); + } + + private: + string container_; + string resource_name_; + + // Maximum number of cached engines + int max_cached_engines_; + + TF_DISALLOW_COPY_AND_ASSIGN(CreateTRTEngineCache); +}; + +REGISTER_KERNEL_BUILDER(Name("CreateTRTEngineCache") + .Device(DEVICE_GPU) + .HostMemory("engine_cache_handle"), + CreateTRTEngineCache); + +class PopulateTRTEngineCache : public OpKernel { + public: + explicit PopulateTRTEngineCache(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + ResourceHandle handle = HandleFromInput(ctx, 0); + TRTEngineCacheResource* resource = nullptr; + OP_REQUIRES_OK(ctx, LookupResource(ctx, handle, &resource)); + core::ScopedUnref unref_me(resource); + + auto allocator = resource->allocator_.get(); + OP_REQUIRES(ctx, allocator != nullptr, + errors::Internal("Not able to initialize TRT engine cache when " + "GPU allocator is empty.")); + OP_REQUIRES(ctx, resource->cache_.size() == 0, + errors::Internal("Expect engine cache to be empty, but got ", + resource->cache_.size(), " entries.")); + + // Get the file name. + const string& filename = ctx->input(1).scalar<string>()(); + OP_REQUIRES(ctx, !filename.empty(), + errors::InvalidArgument("filename cannot be empty.")); + + // Parse the serialized engines and add them to the cache. + std::unique_ptr<RandomAccessFile> file; + OP_REQUIRES_OK(ctx, ctx->env()->NewRandomAccessFile(filename, &file)); + auto reader = absl::make_unique<io::RecordReader>(file.get()); + + uint64 offset = 0; + int num_loaded_engine = 0; + do { + string record; + Status status = reader->ReadRecord(&offset, &record); + if (errors::IsOutOfRange(status)) break; + + TRTEngineInstance engine_instance; + engine_instance.ParseFromString(record); + std::vector<TensorShape> engine_input_shapes; + for (const TensorShapeProto& shape : engine_instance.input_shapes()) { + engine_input_shapes.emplace_back(shape); + } + + TrtUniquePtrType<IRuntime> infer( + nvinfer1::createInferRuntime(TRTEngineCacheResource::GetLogger())); + infer->setGpuAllocator(allocator); + TrtUniquePtrType<nvinfer1::ICudaEngine> engine( + infer->deserializeCudaEngine( + engine_instance.serialized_engine().c_str(), + engine_instance.serialized_engine().size(), + PluginFactoryTensorRT::GetInstance())); + auto raw_engine = engine.get(); + resource->cache_.emplace( + engine_input_shapes, + absl::make_unique<EngineContext>( + std::move(engine), TrtUniquePtrType<nvinfer1::IExecutionContext>( + raw_engine->createExecutionContext()))); + ++num_loaded_engine; + } while (1); + VLOG(1) << "Loaded " << num_loaded_engine << " TRT engines to container " + << handle.container() << " for op " << handle.name() + << " on device " << ctx->device()->name() << " from file " + << filename; + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(PopulateTRTEngineCache); +}; + +REGISTER_KERNEL_BUILDER(Name("PopulateTRTEngineCache") + .Device(DEVICE_GPU) + .HostMemory("engine_cache_handle"), + PopulateTRTEngineCache); + +class DumpTRTEngineCache : public OpKernel { + public: + explicit DumpTRTEngineCache(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("delete_cache_after_dump", + &delete_cache_after_dump_)); + } + + void Compute(OpKernelContext* ctx) override { + const string& container = ctx->input(0).scalar<string>()(); + const string& resource_name = ctx->input(1).scalar<string>()(); + const string& filename = ctx->input(2).scalar<string>()(); + OP_REQUIRES(ctx, 
!filename.empty(), + errors::InvalidArgument("filename cannot be empty.")); + + TRTEngineCacheResource* resource = nullptr; + OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup( + container, resource_name, &resource)); + core::ScopedUnref unref_me(resource); + + // Serialize the engines and write them to file. + std::unique_ptr<WritableFile> file; + OP_REQUIRES_OK(ctx, ctx->env()->NewWritableFile(filename, &file)); + auto writer = absl::make_unique<io::RecordWriter>(file.get()); + + for (const auto& pair : resource->cache_) { + TRTEngineInstance engine_instance; + // Add input shapes. + const std::vector<TensorShape>& engine_input_shapes = pair.first; + for (const TensorShape& shape : engine_input_shapes) { + shape.AsProto(engine_instance.add_input_shapes()); + } + // Add the serialized engine. + const std::unique_ptr<EngineContext>& engine = pair.second; + TrtUniquePtrType<nvinfer1::IHostMemory> engine_data( + engine->cuda_engine->serialize()); + engine_instance.set_serialized_engine(engine_data->data(), + engine_data->size()); + + OP_REQUIRES_OK(ctx, + writer->WriteRecord(engine_instance.SerializeAsString())); + } + VLOG(1) << "Serialized " << resource->cache_.size() + << " TRT engines in container " << container << " for op " + << resource_name << " on device " << ctx->device()->name() + << " to file " << filename; + + if (delete_cache_after_dump_) { + VLOG(1) << "Destroying TRT engine cache resource in container " + << container << " for op " << resource_name << " on device " + << ctx->device()->name(); + OP_REQUIRES_OK(ctx, + ctx->resource_manager()->Delete<TRTEngineCacheResource>( + container, resource_name)); + } + } + + private: + bool delete_cache_after_dump_ = false; + + TF_DISALLOW_COPY_AND_ASSIGN(DumpTRTEngineCache); +}; + +REGISTER_KERNEL_BUILDER(Name("DumpTRTEngineCache").Device(DEVICE_GPU), + DumpTRTEngineCache); + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops_test.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops_test.cc new file mode 100644 index 00000000000..5281433ffc4 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops_test.cc @@ -0,0 +1,205 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <dirent.h> +#include <string.h> + +#include <fstream> +#include <vector> + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.pb.h" // NOLINT +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/io/record_reader.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +class TRTEngineResourceOpsTest : public OpsTestBase { + protected: + void Reset() { + inputs_.clear(); + gtl::STLDeleteElements(&tensors_); + gtl::STLDeleteElements(&managed_outputs_); + } + + TrtUniquePtrType<nvinfer1::ICudaEngine> CreateTRTEngine() { + Logger logger; + TrtUniquePtrType<nvinfer1::IBuilder> builder( + nvinfer1::createInferBuilder(logger)); + TrtUniquePtrType<nvinfer1::INetworkDefinition> network( + builder->createNetwork()); + + // Add the input. + nvinfer1::Dims dims; + dims.nbDims = 1; + dims.d[0] = 1; + nvinfer1::ITensor* input = + network->addInput("input", nvinfer1::DataType::kFLOAT, dims); + EXPECT_NE(nullptr, input); + + // Add a unary layer. + nvinfer1::IUnaryLayer* layer = + network->addUnary(*input, nvinfer1::UnaryOperation::kEXP); + EXPECT_NE(nullptr, layer); + + // Mark the output. + nvinfer1::ITensor* output = layer->getOutput(0); + output->setName("output"); + network->markOutput(*output); + + // Build the engine + builder->setMaxBatchSize(1); + builder->setMaxWorkspaceSize(1 << 10); + TrtUniquePtrType<nvinfer1::ICudaEngine> engine( + builder->buildCudaEngine(*network)); + EXPECT_NE(nullptr, engine); + return engine; + } +}; + +TEST_F(TRTEngineResourceOpsTest, Basic) { + // Create the GPU device. + std::unique_ptr<Device> device( + DeviceFactory::NewDevice("GPU", {}, "/job:worker/replica:0/task:0")); + ResourceMgr* rm = device->resource_manager(); + SetDevice(DEVICE_GPU, std::move(device)); + + // Create the resource. + const string container = "mycontainer"; + const string resource_name = "myresource"; + Reset(); + TF_ASSERT_OK(NodeDefBuilder("op", "CreateTRTEngineCache") + .Attr("container", container) + .Attr("resource_name", resource_name) + .Attr("max_cached_engines_count", 1) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + TF_ASSERT_OK(RunOpKernel()); + ResourceHandle handle = + context_->mutable_output(0)->scalar<ResourceHandle>()(); + + TRTEngineCacheResource* resource = nullptr; + EXPECT_TRUE(rm->Lookup(container, resource_name, &resource).ok()); + + // Create a serialized TRT engine file. + TrtUniquePtrType<nvinfer1::ICudaEngine> engine = CreateTRTEngine(); + TrtUniquePtrType<nvinfer1::IExecutionContext> context( + engine->createExecutionContext()); + resource->cache_.emplace( + std::vector<TensorShape>{TensorShape({1, 1})}, + absl::make_unique<EngineContext>(std::move(engine), std::move(context))); + resource->Unref(); + + // Serialize the engine using DumpTRTEngineCache op. 
+ Reset(); + TF_ASSERT_OK(NodeDefBuilder("op", "DumpTRTEngineCache") + .Attr("delete_cache_after_dump", true) + .Input(FakeInput(DT_STRING)) + .Input(FakeInput(DT_STRING)) + .Input(FakeInput(DT_STRING)) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + AddInputFromArray<string>(TensorShape({}), {container}); + AddInputFromArray<string>(TensorShape({}), {resource_name}); + const string filename = io::JoinPath(testing::TmpDir(), "trt_engine_file"); + AddInputFromArray<string>(TensorShape({}), {filename}); + TF_ASSERT_OK(RunOpKernel()); + + // Make sure the cache is deleted. + Reset(); + TF_ASSERT_OK(NodeDefBuilder("op", "DestroyResourceOp") + .Attr("ignore_lookup_error", false) + .Input(FakeInput(DT_RESOURCE)) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + AddInputFromArray<ResourceHandle>(TensorShape({}), {handle}); + EXPECT_TRUE(errors::IsNotFound(RunOpKernel())); + + // Verify the serialized engine file. + Env* env = Env::Default(); + std::unique_ptr<RandomAccessFile> file; + TF_ASSERT_OK(env->NewRandomAccessFile(filename, &file)); + auto reader = absl::make_unique<io::RecordReader>(file.get()); + uint64 offset = 0; + string record; + TF_ASSERT_OK(reader->ReadRecord(&offset, &record)); + TRTEngineInstance engine_instance; + engine_instance.ParseFromString(record); + EXPECT_EQ(1, engine_instance.input_shapes_size()); + EXPECT_EQ(2, engine_instance.input_shapes(0).dim_size()); + EXPECT_EQ(1, engine_instance.input_shapes(0).dim(0).size()); + EXPECT_EQ(1, engine_instance.input_shapes(0).dim(1).size()); + EXPECT_TRUE(errors::IsOutOfRange(reader->ReadRecord(&offset, &record))); + + // Recreate the cache resource. + Reset(); + TF_ASSERT_OK(NodeDefBuilder("op", "CreateTRTEngineCache") + .Attr("container", container) + .Attr("resource_name", resource_name) + .Attr("max_cached_engines_count", 1) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + TF_ASSERT_OK(RunOpKernel()); + handle = context_->mutable_output(0)->scalar<ResourceHandle>()(); + EXPECT_TRUE(rm->Lookup(container, resource_name, &resource).ok()); + EXPECT_EQ(0, resource->cache_.size()); + resource->Unref(); + + // Deserialize the engine using PopulateTRTEngineCache op. + Reset(); + TF_ASSERT_OK(NodeDefBuilder("op", "PopulateTRTEngineCache") + .Input(FakeInput(DT_RESOURCE)) + .Input(FakeInput(DT_STRING)) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + AddInputFromArray<ResourceHandle>(TensorShape({}), {handle}); + AddInputFromArray<string>(TensorShape({}), {filename}); + TF_ASSERT_OK(RunOpKernel()); + EXPECT_TRUE(rm->Lookup(container, resource_name, &resource).ok()); + EXPECT_EQ(1, resource->cache_.size()); + resource->Unref(); + + // Destroy the engine cache again. + Reset(); + TF_ASSERT_OK(NodeDefBuilder("op", "DestroyResourceOp") + .Attr("ignore_lookup_error", false) + .Input(FakeInput(DT_RESOURCE)) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + AddInputFromArray<ResourceHandle>(TensorShape({}), {handle}); + TF_ASSERT_OK(RunOpKernel()); + EXPECT_TRUE(errors::IsNotFound(RunOpKernel())); +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/compiler/tf2tensorrt/ops/trt_engine_resource_ops.cc b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_resource_ops.cc new file mode 100644 index 00000000000..cf1909a0b47 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_resource_ops.cc @@ -0,0 +1,52 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+
+namespace tensorflow {
+
+REGISTER_OP("CreateTRTEngineCache")
+    .Attr("container: string")
+    .Attr("resource_name: string")
+    .Attr("max_cached_engines_count: int = 1")
+    .Output("engine_cache_handle: resource")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("PopulateTRTEngineCache")
+    .Input("engine_cache_handle: resource")
+    .Input("filename: string")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::NoOutputs);
+
+REGISTER_OP("DumpTRTEngineCache")
+    .Attr("delete_cache_after_dump: bool = false")
+    .Input("container: string")
+    .Input("resource_name: string")
+    .Input("filename: string")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::NoOutputs);
+
+}  // namespace tensorflow
+
+#endif  // GOOGLE_TENSORRT
+#endif  // GOOGLE_CUDA
diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.proto b/tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.proto
new file mode 100644
index 00000000000..e8394974478
--- /dev/null
+++ b/tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.proto
@@ -0,0 +1,19 @@
+syntax = "proto3";
+
+package tensorflow.tensorrt;
+
+import "tensorflow/core/framework/tensor_shape.proto";
+
+// Contains the information needed to restore a serialized TensorRT engine.
+message TRTEngineInstance {
+  // The input shapes of the TRT engine.
+  repeated TensorShapeProto input_shapes = 1;
+
+  // The serialized TRT engine.
+  //
+  // TODO(laigd): consider using a more efficient in-memory representation
+  // instead of string which is the default here.
+  bytes serialized_engine = 2;
+
+  // TODO(laigd): consider adding calibration stats, precision_modes, etc.
+}
diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc
index ee677b76db8..8675f5d69f4 100644
--- a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc
+++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc
@@ -30,6 +30,11 @@ limitations under the License.
 namespace tensorflow {
 namespace tensorrt {
 
+Logger& TRTEngineCacheResource::GetLogger() {
+  static Logger* logger = new Logger();
+  return *logger;
+}
+
 TRTEngineCacheResource::TRTEngineCacheResource(OpKernelContext* ctx,
                                                size_t capacity)
     : cache_(capacity) {
diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h
index 112be459871..378fb1c5424 100644
--- a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h
+++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/errors.h" @@ -141,6 +142,13 @@ struct EngineContext { class TRTEngineCacheResource : public ResourceBase { public: + // According to the TensorRT API, the logger is considered a singleton by the + // TensorRT library, and multiple instances of IRuntime and/or IBuilder must + // all use the same logger. So here we make it a singleton. + // + // TODO(laigd): use this logger in all places where conversion happens. + static Logger& GetLogger(); + TRTEngineCacheResource(OpKernelContext* ctx, size_t capacity); ~TRTEngineCacheResource() override;
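The three ops added by this patch round-trip the engine cache through a TFRecord file: DumpTRTEngineCache serializes each cache entry as a TRTEngineInstance proto (input shapes plus the serialized engine bytes) and writes one record per engine, and PopulateTRTEngineCache reads the records back and rebuilds an EngineContext for each via IRuntime::deserializeCudaEngine. Below is a minimal sketch of an offline reader for that file format, using only APIs that already appear in the kernels above (Env, io::RecordReader, the TRTEngineInstance proto); the helper name InspectEngineCacheFile and the standalone-tool framing are assumptions for illustration, not part of the patch.

#include <iostream>
#include <memory>
#include <string>

#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {
namespace tensorrt {

// Reads the TFRecord file written by DumpTRTEngineCache and prints a summary
// of each serialized engine, mirroring the read loop in PopulateTRTEngineCache.
// InspectEngineCacheFile is a hypothetical helper, not an op in this patch.
Status InspectEngineCacheFile(const string& filename) {
  std::unique_ptr<RandomAccessFile> file;
  TF_RETURN_IF_ERROR(Env::Default()->NewRandomAccessFile(filename, &file));
  io::RecordReader reader(file.get());
  uint64 offset = 0;
  string record;
  int num_engines = 0;
  while (true) {
    Status status = reader.ReadRecord(&offset, &record);
    if (errors::IsOutOfRange(status)) break;  // Reached end of file.
    TF_RETURN_IF_ERROR(status);
    TRTEngineInstance engine_instance;
    // Unlike the kernel, which ignores the return value of ParseFromString,
    // this sketch surfaces malformed records as an error.
    if (!engine_instance.ParseFromString(record)) {
      return errors::DataLoss("Record ", num_engines,
                              " is not a valid TRTEngineInstance.");
    }
    std::cout << "engine " << num_engines++ << ": "
              << engine_instance.input_shapes_size() << " input shape(s), "
              << engine_instance.serialized_engine().size()
              << " bytes of serialized engine\n";
  }
  return Status::OK();
}

}  // namespace tensorrt
}  // namespace tensorflow

Built against the same dependencies as the trt_engine_resource_op_kernels target (notably :trt_engine_instance_proto_cc), running this over the file produced by the test above would report a single engine with one rank-2 input shape, matching the EXPECT_EQ checks on input_shapes in trt_engine_resource_ops_test.cc.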