From 3638d4b89b17c31151d6962d11c2cc80d16fc430 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Tue, 28 May 2019 10:03:02 -0700 Subject: [PATCH] Specialize GetSerializedResourceOp for calibration data This is in preparation to combine calibration data and engine into a single resource. Also, remove GetSerializedResourceOpTest that tests the generic op. PiperOrigin-RevId: 250306470 --- tensorflow/compiler/tf2tensorrt/BUILD | 37 ++------- .../tf2tensorrt/convert/convert_graph.cc | 2 +- .../tf2tensorrt/convert/convert_nodes.cc | 2 +- .../tf2tensorrt/convert/convert_nodes.h | 2 +- ...ource_op.cc => get_calibration_data_op.cc} | 20 ++--- .../get_serialized_resource_op_test.cc | 81 ------------------- .../tf2tensorrt/kernels/trt_engine_op.cc | 14 ++-- .../tf2tensorrt/kernels/trt_engine_op_test.cc | 2 +- ...ource_op.cc => get_calibration_data_op.cc} | 6 +- ...t_resources.cc => calibration_resource.cc} | 4 +- ...trt_resources.h => calibration_resource.h} | 9 +-- .../python/compiler/tensorrt/trt_convert.py | 18 ++--- 12 files changed, 41 insertions(+), 156 deletions(-) rename tensorflow/compiler/tf2tensorrt/kernels/{get_serialized_resource_op.cc => get_calibration_data_op.cc} (75%) delete mode 100644 tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op_test.cc rename tensorflow/compiler/tf2tensorrt/ops/{get_serialized_resource_op.cc => get_calibration_data_op.cc} (86%) rename tensorflow/compiler/tf2tensorrt/utils/{trt_resources.cc => calibration_resource.cc} (93%) rename tensorflow/compiler/tf2tensorrt/utils/{trt_resources.h => calibration_resource.h} (89%) diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index eb10b021349..03e38853e53 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -53,7 +53,7 @@ tf_cuda_cc_test( cc_library( name = "trt_op_kernels", srcs = [ - "kernels/get_serialized_resource_op.cc", + "kernels/get_calibration_data_op.cc", "kernels/trt_engine_op.cc", ], copts = tf_copts(), @@ -144,31 +144,6 @@ tf_cc_shared_object( ]) + tf_custom_op_library_additional_deps(), ) -tf_cuda_cc_test( - name = "get_serialized_resource_op_test", - size = "small", - srcs = ["kernels/get_serialized_resource_op_test.cc"], - tags = [ - "no_cuda_on_cpu_tap", - "no_windows", - "nomac", - ], - deps = [ - # TODO(laigd): consider splitting get_serialized_resource_op out from - # TF-TRT. - ":trt_op_kernels", - ":trt_op_libs", - ":trt_resources", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - "//tensorflow/core/kernels:ops_testutil", - ], -) - tf_cuda_cc_test( name = "trt_engine_op_test", size = "small", @@ -201,7 +176,7 @@ tf_cuda_cc_test( tf_gen_op_libs( op_lib_names = [ "trt_engine_op", - "get_serialized_resource_op", + "get_calibration_data_op", "trt_engine_resource_ops", ], ) @@ -209,7 +184,7 @@ tf_gen_op_libs( cc_library( name = "trt_op_libs", deps = [ - ":get_serialized_resource_op_op_lib", + ":get_calibration_data_op_op_lib", ":trt_engine_op_op_lib", ], ) @@ -262,14 +237,14 @@ tf_custom_op_py_library( tf_cuda_library( name = "trt_resources", srcs = [ + "utils/calibration_resource.cc", "utils/trt_int8_calibrator.cc", "utils/trt_lru_cache.cc", - "utils/trt_resources.cc", ], hdrs = [ + "utils/calibration_resource.h", "utils/trt_int8_calibrator.h", "utils/trt_lru_cache.h", - "utils/trt_resources.h", ], deps = [ ":trt_allocator", @@ -553,7 +528,7 @@ tf_py_wrap_cc( srcs = ["utils/py_utils.i"], copts = tf_copts(), deps = [ - "//tensorflow/compiler/tf2tensorrt:py_utils", + ":py_utils", "//third_party/python_runtime:headers", ], ) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index 34d4ee79542..0742f2ae2a9 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -29,7 +29,7 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/compiler/tf2tensorrt/segment/segment.h" -#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" +#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h" #include "tensorflow/core/common_runtime/gpu/gpu_id.h" #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index a1ccb3b3e6e..7017c9a1706 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -30,8 +30,8 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" -#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor.pb.h" // NOLINT diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index d0f6d5ef1d1..a046e4d5e96 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -23,10 +23,10 @@ limitations under the License. #include #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" -#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/grappler/costs/graph_properties.h" diff --git a/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/get_calibration_data_op.cc similarity index 75% rename from tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op.cc rename to tensorflow/compiler/tf2tensorrt/kernels/get_calibration_data_op.cc index e252f9111d6..4c4ae6f88da 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/get_calibration_data_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" +#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -28,24 +28,24 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -class GetSerializedResourceOp : public OpKernel { +class GetCalibrationDataOp : public OpKernel { public: - explicit GetSerializedResourceOp(OpKernelConstruction* context) + explicit GetCalibrationDataOp(OpKernelConstruction* context) : OpKernel(context) {} - ~GetSerializedResourceOp() override {} + ~GetCalibrationDataOp() override {} void Compute(OpKernelContext* context) override { // TODO(laigd): it will allocate the tensor on the device and copy the // serialized string to that tensor, and later sess.run() will copy it back // to host. We need to optimize this. - const string& container = context->input(0).scalar()(); - const string& resource_name = context->input(1).scalar()(); + const string& resource_name = context->input(0).scalar()(); // Get the resource. - SerializableResourceBase* resource = nullptr; + TRTCalibrationResource* resource = nullptr; OP_REQUIRES_OK(context, context->resource_manager()->Lookup( - container, resource_name, &resource)); + std::string(kCalibrationContainerName), + resource_name, &resource)); core::ScopedUnref sc(resource); // Serialize the resource as output. @@ -59,8 +59,8 @@ class GetSerializedResourceOp : public OpKernel { } }; -REGISTER_KERNEL_BUILDER(Name("GetSerializedResourceOp").Device(DEVICE_GPU), - GetSerializedResourceOp); +REGISTER_KERNEL_BUILDER(Name("GetCalibrationDataOp").Device(DEVICE_GPU), + GetCalibrationDataOp); } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op_test.cc b/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op_test.cc deleted file mode 100644 index d54cbf7836e..00000000000 --- a/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op_test.cc +++ /dev/null @@ -1,81 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include -#include - -#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" -#include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/ops_testutil.h" -#include "tensorflow/core/platform/test.h" - -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT - -namespace tensorflow { -namespace tensorrt { - -class GetSerializedResourceOpTest : public OpsTestBase {}; - -TEST_F(GetSerializedResourceOpTest, Basic) { - // Create the GPU device. - std::unique_ptr device( - DeviceFactory::NewDevice("GPU", {}, "/job:worker/replica:0/task:0")); - - // Create the resource. - class MySerializableResource : public SerializableResourceBase { - public: - string DebugString() const override { return ""; } - Status SerializeToString(string* serialized) override { - *serialized = "my_serialized_str"; - return Status::OK(); - } - }; - const string container = "mycontainer"; - const string resource_name = "myresource"; - SerializableResourceBase* resource = new MySerializableResource(); - ResourceMgr* rm = device->resource_manager(); - EXPECT_TRUE(rm->Create(container, resource_name, resource).ok()); - - // Create the op. - SetDevice(DEVICE_GPU, std::move(device)); - TF_ASSERT_OK(NodeDefBuilder("op", "GetSerializedResourceOp") - .Input(FakeInput(DT_STRING)) - .Input(FakeInput(DT_STRING)) - .Finalize(node_def())); - TF_ASSERT_OK(InitOp()); - - // Execute the op. - AddInputFromArray(TensorShape({}), {container}); - AddInputFromArray(TensorShape({}), {resource_name}); - TF_ASSERT_OK(RunOpKernel()); - - // Verify the result. - // string type output will remain on CPU, so we're not using GetOutput() here. - EXPECT_EQ("my_serialized_str", - context_->mutable_output(0)->scalar()()); -} - -} // namespace tensorrt -} // namespace tensorflow - -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc index 51ac9528864..616640f4bbd 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc @@ -22,10 +22,10 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h" -#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/op.h" @@ -96,7 +96,7 @@ class TRTEngineOp : public AsyncOpKernel { // Allocate necessary resources for calibration Status AllocateCalibrationResources(OpKernelContext* ctx, - SerializableResourceBase** cr); + TRTCalibrationResource** cr); // Get engine for the input shape EngineContext* GetEngine(const std::vector& input_shapes, @@ -286,8 +286,8 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, ctx, ctx->resource_manager()->LookupOrCreate( "TF-TRT-Calibration", name(), - reinterpret_cast(&calib_res), - {[ctx, this](SerializableResourceBase** cr) -> Status { + reinterpret_cast(&calib_res), + {[ctx, this](TRTCalibrationResource** cr) -> Status { return this->AllocateCalibrationResources(ctx, cr); }}), *helper); @@ -542,7 +542,7 @@ EngineContext* TRTEngineOp::GetEngine( // Get engine cache. TRTEngineCacheResource* cache_res = nullptr; auto status = ctx->resource_manager()->LookupOrCreate( - "TF-TRT-Engine-Cache", string(resource_name), &cache_res, + std::string(kCalibrationContainerName), string(resource_name), &cache_res, {[this, ctx](TRTEngineCacheResource** cr) -> Status { *cr = new TRTEngineCacheResource(ctx, this->max_cached_engines_); return Status::OK(); @@ -663,8 +663,8 @@ EngineContext* TRTEngineOp::GetEngine( return cache.at(engine_input_shapes).get(); } -Status TRTEngineOp::AllocateCalibrationResources( - OpKernelContext* ctx, SerializableResourceBase** cr) { +Status TRTEngineOp::AllocateCalibrationResources(OpKernelContext* ctx, + TRTCalibrationResource** cr) { auto cres = new TRTCalibrationResource(); *cr = cres; // Get the allocator. diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc index d4077692235..73d2742a4e9 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include #include #include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" +#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2tensorrt/ops/get_serialized_resource_op.cc b/tensorflow/compiler/tf2tensorrt/ops/get_calibration_data_op.cc similarity index 86% rename from tensorflow/compiler/tf2tensorrt/ops/get_serialized_resource_op.cc rename to tensorflow/compiler/tf2tensorrt/ops/get_calibration_data_op.cc index 59da73f5efc..573172b92e6 100644 --- a/tensorflow/compiler/tf2tensorrt/ops/get_serialized_resource_op.cc +++ b/tensorflow/compiler/tf2tensorrt/ops/get_calibration_data_op.cc @@ -23,15 +23,13 @@ limitations under the License. namespace tensorflow { -REGISTER_OP("GetSerializedResourceOp") - .Input("container: string") +REGISTER_OP("GetCalibrationDataOp") .Input("resource_name: string") .Output("serialized_resource: string") .SetShapeFn(shape_inference::ScalarShape) .SetIsStateful() .Doc(R"doc( -Gets a resource from a container managed by the resource manager and returns -its serialized representation. +Returns calibration data for the given resource name )doc"); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_resources.cc b/tensorflow/compiler/tf2tensorrt/utils/calibration_resource.cc similarity index 93% rename from tensorflow/compiler/tf2tensorrt/utils/trt_resources.cc rename to tensorflow/compiler/tf2tensorrt/utils/calibration_resource.cc index 534e59f06b7..0972217dbd3 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_resources.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/calibration_resource.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" +#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -21,6 +21,8 @@ limitations under the License. namespace tensorflow { namespace tensorrt { +const absl::string_view kCalibrationContainerName = "TF-TRT-Calibration"; + TRTCalibrationResource::~TRTCalibrationResource() { VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString(); builder_.reset(); diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_resources.h b/tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h similarity index 89% rename from tensorflow/compiler/tf2tensorrt/utils/trt_resources.h rename to tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h index 697cef5d788..e09e7f06318 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_resources.h +++ b/tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h @@ -37,18 +37,15 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -class SerializableResourceBase : public ResourceBase { - public: - virtual Status SerializeToString(string* serialized) = 0; -}; +ABSL_CONST_INIT extern const absl::string_view kCalibrationContainerName; -class TRTCalibrationResource : public SerializableResourceBase { +class TRTCalibrationResource : public ResourceBase { public: ~TRTCalibrationResource() override; string DebugString() const override; - Status SerializeToString(string* serialized) override; + Status SerializeToString(string* serialized); // Lookup table for temporary staging areas of input tensors for calibration. std::unordered_map> device_buffers_; diff --git a/tensorflow/python/compiler/tensorrt/trt_convert.py b/tensorflow/python/compiler/tensorrt/trt_convert.py index 7bc922bc8d6..4730628b39b 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert.py @@ -532,7 +532,6 @@ DEFAULT_TRT_CONVERSION_PARAMS = TrtConversionParams( max_batch_size=1, cached_engine_batches=None) -_TRT_CALIBRATION_RESOURCE_CONTAINER_NAME = "TF-TRT-Calibration" _TRT_ENGINE_CACHE_CONTAINER_NAME = "TF-TRT-Engine-Cache" _TRT_ENGINE_OP_NAME = "TRTEngineOp" @@ -784,37 +783,32 @@ class TrtGraphConverter(GraphConverter): assert not self._calibration_data_collected # TODO(laigd): a better way would be to use self._calibration_sess to list - # all the devices, add one get_serialized_resource_op for each device, and + # all the devices, add one get_calibration_data for each device, and # fetch each such op for every resource until its found. This can work # even when the device of the TRTEngineOp is empty or not fully specified. - # Maps device name to the corresponding get_serialized_resource_op. + # Maps device name to the corresponding get_calibration_data. device_to_get_resource_op_map = {} with self._calibration_graph.as_default(): - container_input = array_ops.placeholder(dtypes.string) resource_name_input = array_ops.placeholder(dtypes.string) for node in self._converted_graph_def.node: if node.op == _TRT_ENGINE_OP_NAME: - # Adds the get_serialized_resource_op for the device if not done - # before. We only add one such op for each device. + # Adds the get_calibration_data op for the device if not done before. + # We only add one such op for each device. # TODO(laigd): What if the device is empty????? if node.device not in device_to_get_resource_op_map: with self._calibration_graph.device(node.device): serialized_resources_output = ( - gen_trt_ops.get_serialized_resource_op( - container_input, resource_name_input)) + gen_trt_ops.get_calibration_data_op(resource_name_input)) device_to_get_resource_op_map[node.device] = ( serialized_resources_output) # Get the calibration resource. calibration_result = self._calibration_sess.run( device_to_get_resource_op_map[node.device], - feed_dict={ - container_input: _TRT_CALIBRATION_RESOURCE_CONTAINER_NAME, - resource_name_input: node.name - }) + feed_dict={resource_name_input: node.name}) node.attr["calibration_data"].s = calibration_result self._calibration_data_collected = True