Specialize GetSerializedResourceOp for calibration data

This is in preparation for combining the calibration data and the engine into a single resource. Also, remove GetSerializedResourceOpTest, which tests the generic op.

PiperOrigin-RevId: 250306470
Smit Hinsu 2019-05-28 10:03:02 -07:00 committed by TensorFlower Gardener
parent bb237cf743
commit 3638d4b89b
12 changed files with 41 additions and 156 deletions
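
In short: the generic GetSerializedResourceOp took two scalar string inputs, a container name and a resource name, while the specialized GetCalibrationDataOp takes only the resource name and hard-codes the calibration container in the kernel. A minimal Python-side sketch of the interface change, assuming the generated gen_trt_ops wrapper module is in scope as it is in trt_convert.py below; the placeholder, session, and node names are illustrative, not part of this commit:

    import tensorflow as tf

    resource_name_input = tf.placeholder(tf.string)

    # Before this change: container + resource name, both fed at run time.
    # serialized = gen_trt_ops.get_serialized_resource_op(
    #     container_input, resource_name_input)
    # sess.run(serialized, feed_dict={
    #     container_input: "TF-TRT-Calibration",
    #     resource_name_input: engine_node_name})

    # After this change: resource name only; the kernel looks it up in the
    # fixed "TF-TRT-Calibration" container (kCalibrationContainerName).
    serialized = gen_trt_ops.get_calibration_data_op(resource_name_input)
    # sess.run(serialized, feed_dict={resource_name_input: engine_node_name})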


@@ -53,7 +53,7 @@ tf_cuda_cc_test(
 cc_library(
     name = "trt_op_kernels",
     srcs = [
-        "kernels/get_serialized_resource_op.cc",
+        "kernels/get_calibration_data_op.cc",
         "kernels/trt_engine_op.cc",
     ],
     copts = tf_copts(),
@@ -144,31 +144,6 @@ tf_cc_shared_object(
     ]) + tf_custom_op_library_additional_deps(),
 )
 
-tf_cuda_cc_test(
-    name = "get_serialized_resource_op_test",
-    size = "small",
-    srcs = ["kernels/get_serialized_resource_op_test.cc"],
-    tags = [
-        "no_cuda_on_cpu_tap",
-        "no_windows",
-        "nomac",
-    ],
-    deps = [
-        # TODO(laigd): consider splitting get_serialized_resource_op out from
-        # TF-TRT.
-        ":trt_op_kernels",
-        ":trt_op_libs",
-        ":trt_resources",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:test",
-        "//tensorflow/core:test_main",
-        "//tensorflow/core:testlib",
-        "//tensorflow/core/kernels:ops_testutil",
-    ],
-)
-
 tf_cuda_cc_test(
     name = "trt_engine_op_test",
     size = "small",
@@ -201,7 +176,7 @@ tf_cuda_cc_test(
 tf_gen_op_libs(
     op_lib_names = [
         "trt_engine_op",
-        "get_serialized_resource_op",
+        "get_calibration_data_op",
         "trt_engine_resource_ops",
     ],
 )
@@ -209,7 +184,7 @@ tf_gen_op_libs(
 cc_library(
     name = "trt_op_libs",
     deps = [
-        ":get_serialized_resource_op_op_lib",
+        ":get_calibration_data_op_op_lib",
         ":trt_engine_op_op_lib",
     ],
 )
@@ -262,14 +237,14 @@ tf_custom_op_py_library(
 tf_cuda_library(
     name = "trt_resources",
     srcs = [
+        "utils/calibration_resource.cc",
         "utils/trt_int8_calibrator.cc",
         "utils/trt_lru_cache.cc",
-        "utils/trt_resources.cc",
     ],
     hdrs = [
+        "utils/calibration_resource.h",
         "utils/trt_int8_calibrator.h",
         "utils/trt_lru_cache.h",
-        "utils/trt_resources.h",
     ],
     deps = [
         ":trt_allocator",
@@ -553,7 +528,7 @@ tf_py_wrap_cc(
     srcs = ["utils/py_utils.i"],
     copts = tf_copts(),
     deps = [
-        "//tensorflow/compiler/tf2tensorrt:py_utils",
+        ":py_utils",
         "//third_party/python_runtime:headers",
     ],
 )


@@ -29,7 +29,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
 #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
 #include "tensorflow/compiler/tf2tensorrt/segment/segment.h"
-#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
+#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_id.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"


@@ -30,8 +30,8 @@ limitations under the License.
 #include "absl/strings/string_view.h"
 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
 #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
+#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
-#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
 #include "tensorflow/core/framework/node_def.pb.h"  // NOLINT
 #include "tensorflow/core/framework/node_def_builder.h"
 #include "tensorflow/core/framework/tensor.pb.h"  // NOLINT


@@ -23,10 +23,10 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
+#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
-#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/grappler/costs/graph_properties.h"


@@ -16,7 +16,7 @@ limitations under the License.
 #include <memory>
 #include <vector>
 
-#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
+#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_mgr.h"
@@ -28,24 +28,24 @@ limitations under the License.
 namespace tensorflow {
 namespace tensorrt {
 
-class GetSerializedResourceOp : public OpKernel {
+class GetCalibrationDataOp : public OpKernel {
  public:
-  explicit GetSerializedResourceOp(OpKernelConstruction* context)
+  explicit GetCalibrationDataOp(OpKernelConstruction* context)
       : OpKernel(context) {}
 
-  ~GetSerializedResourceOp() override {}
+  ~GetCalibrationDataOp() override {}
 
   void Compute(OpKernelContext* context) override {
     // TODO(laigd): it will allocate the tensor on the device and copy the
     // serialized string to that tensor, and later sess.run() will copy it back
     // to host. We need to optimize this.
-    const string& container = context->input(0).scalar<string>()();
-    const string& resource_name = context->input(1).scalar<string>()();
+    const string& resource_name = context->input(0).scalar<string>()();
 
     // Get the resource.
-    SerializableResourceBase* resource = nullptr;
+    TRTCalibrationResource* resource = nullptr;
     OP_REQUIRES_OK(context, context->resource_manager()->Lookup(
-                                container, resource_name, &resource));
+                                std::string(kCalibrationContainerName),
+                                resource_name, &resource));
     core::ScopedUnref sc(resource);
 
     // Serialize the resource as output.
@@ -59,8 +59,8 @@ class GetSerializedResourceOp : public OpKernel {
   }
 };
 
-REGISTER_KERNEL_BUILDER(Name("GetSerializedResourceOp").Device(DEVICE_GPU),
-                        GetSerializedResourceOp);
+REGISTER_KERNEL_BUILDER(Name("GetCalibrationDataOp").Device(DEVICE_GPU),
+                        GetCalibrationDataOp);
 
 }  // namespace tensorrt
 }  // namespace tensorflow


@@ -1,81 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <dirent.h>
-#include <string.h>
-
-#include <fstream>
-#include <vector>
-
-#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
-#include "tensorflow/core/framework/fake_input.h"
-#include "tensorflow/core/framework/node_def_builder.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/kernels/ops_testutil.h"
-#include "tensorflow/core/platform/test.h"
-
-#if GOOGLE_CUDA
-#if GOOGLE_TENSORRT
-
-namespace tensorflow {
-namespace tensorrt {
-
-class GetSerializedResourceOpTest : public OpsTestBase {};
-
-TEST_F(GetSerializedResourceOpTest, Basic) {
-  // Create the GPU device.
-  std::unique_ptr<Device> device(
-      DeviceFactory::NewDevice("GPU", {}, "/job:worker/replica:0/task:0"));
-
-  // Create the resource.
-  class MySerializableResource : public SerializableResourceBase {
-   public:
-    string DebugString() const override { return ""; }
-    Status SerializeToString(string* serialized) override {
-      *serialized = "my_serialized_str";
-      return Status::OK();
-    }
-  };
-  const string container = "mycontainer";
-  const string resource_name = "myresource";
-  SerializableResourceBase* resource = new MySerializableResource();
-  ResourceMgr* rm = device->resource_manager();
-  EXPECT_TRUE(rm->Create(container, resource_name, resource).ok());
-
-  // Create the op.
-  SetDevice(DEVICE_GPU, std::move(device));
-  TF_ASSERT_OK(NodeDefBuilder("op", "GetSerializedResourceOp")
-                   .Input(FakeInput(DT_STRING))
-                   .Input(FakeInput(DT_STRING))
-                   .Finalize(node_def()));
-  TF_ASSERT_OK(InitOp());
-
-  // Execute the op.
-  AddInputFromArray<string>(TensorShape({}), {container});
-  AddInputFromArray<string>(TensorShape({}), {resource_name});
-  TF_ASSERT_OK(RunOpKernel());
-
-  // Verify the result.
-  // string type output will remain on CPU, so we're not using GetOutput() here.
-  EXPECT_EQ("my_serialized_str",
-            context_->mutable_output(0)->scalar<string>()());
-}
-
-}  // namespace tensorrt
-}  // namespace tensorflow
-
-#endif  // GOOGLE_TENSORRT
-#endif  // GOOGLE_CUDA


@@ -22,10 +22,10 @@ limitations under the License.
 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
 #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
+#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
-#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/graph_to_functiondef.h"
 #include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op.h"
@@ -96,7 +96,7 @@ class TRTEngineOp : public AsyncOpKernel {
 
   // Allocate necessary resources for calibration
   Status AllocateCalibrationResources(OpKernelContext* ctx,
-                                      SerializableResourceBase** cr);
+                                      TRTCalibrationResource** cr);
 
   // Get engine for the input shape
   EngineContext* GetEngine(const std::vector<TensorShape>& input_shapes,
@@ -286,8 +286,8 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
       ctx,
       ctx->resource_manager()->LookupOrCreate(
           "TF-TRT-Calibration", name(),
-          reinterpret_cast<SerializableResourceBase**>(&calib_res),
-          {[ctx, this](SerializableResourceBase** cr) -> Status {
+          reinterpret_cast<TRTCalibrationResource**>(&calib_res),
+          {[ctx, this](TRTCalibrationResource** cr) -> Status {
             return this->AllocateCalibrationResources(ctx, cr);
           }}),
       *helper);
@@ -542,7 +542,7 @@ EngineContext* TRTEngineOp::GetEngine(
   // Get engine cache.
   TRTEngineCacheResource* cache_res = nullptr;
   auto status = ctx->resource_manager()->LookupOrCreate(
-      "TF-TRT-Engine-Cache", string(resource_name), &cache_res,
+      std::string(kCalibrationContainerName), string(resource_name), &cache_res,
       {[this, ctx](TRTEngineCacheResource** cr) -> Status {
         *cr = new TRTEngineCacheResource(ctx, this->max_cached_engines_);
         return Status::OK();
@@ -663,8 +663,8 @@ EngineContext* TRTEngineOp::GetEngine(
   return cache.at(engine_input_shapes).get();
 }
 
-Status TRTEngineOp::AllocateCalibrationResources(
-    OpKernelContext* ctx, SerializableResourceBase** cr) {
+Status TRTEngineOp::AllocateCalibrationResources(OpKernelContext* ctx,
+                                                 TRTCalibrationResource** cr) {
   auto cres = new TRTCalibrationResource();
   *cr = cres;
   // Get the allocator.


@@ -22,7 +22,7 @@ limitations under the License.
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "tensorflow/cc/ops/standard_ops.h"
-#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
+#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h"
 #include "tensorflow/core/framework/fake_input.h"
 #include "tensorflow/core/framework/node_def_builder.h"
 #include "tensorflow/core/framework/tensor.h"


@@ -23,15 +23,13 @@ limitations under the License.
 
 namespace tensorflow {
 
-REGISTER_OP("GetSerializedResourceOp")
-    .Input("container: string")
+REGISTER_OP("GetCalibrationDataOp")
     .Input("resource_name: string")
     .Output("serialized_resource: string")
     .SetShapeFn(shape_inference::ScalarShape)
     .SetIsStateful()
     .Doc(R"doc(
-Gets a resource from a container managed by the resource manager and returns
-its serialized representation.
+Returns calibration data for the given resource name
 )doc");
 
 }  // namespace tensorflow
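
As a quick sanity check, the registered interface can be inspected from the op registry on the Python side. A sketch only, assuming the TF-TRT op library has been loaded into the process; the op and argument names come straight from the REGISTER_OP call above:

    from tensorflow.python.framework import op_def_registry

    # Look up the OpDef produced by the registration above.
    op_def = op_def_registry.get_registered_ops()["GetCalibrationDataOp"]
    assert [arg.name for arg in op_def.input_arg] == ["resource_name"]
    assert [arg.name for arg in op_def.output_arg] == ["serialized_resource"]
    assert op_def.is_stateful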


@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
+#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h"
 
 #if GOOGLE_CUDA
 #if GOOGLE_TENSORRT
@@ -21,6 +21,8 @@ limitations under the License.
 namespace tensorflow {
 namespace tensorrt {
 
+const absl::string_view kCalibrationContainerName = "TF-TRT-Calibration";
+
 TRTCalibrationResource::~TRTCalibrationResource() {
   VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString();
   builder_.reset();


@@ -37,18 +37,15 @@ limitations under the License.
 namespace tensorflow {
 namespace tensorrt {
 
-class SerializableResourceBase : public ResourceBase {
- public:
-  virtual Status SerializeToString(string* serialized) = 0;
-};
+ABSL_CONST_INIT extern const absl::string_view kCalibrationContainerName;
 
-class TRTCalibrationResource : public SerializableResourceBase {
+class TRTCalibrationResource : public ResourceBase {
  public:
   ~TRTCalibrationResource() override;
 
   string DebugString() const override;
 
-  Status SerializeToString(string* serialized) override;
+  Status SerializeToString(string* serialized);
 
   // Lookup table for temporary staging areas of input tensors for calibration.
   std::unordered_map<string, std::pair<void*, size_t>> device_buffers_;


@@ -532,7 +532,6 @@ DEFAULT_TRT_CONVERSION_PARAMS = TrtConversionParams(
     max_batch_size=1,
     cached_engine_batches=None)
 
-_TRT_CALIBRATION_RESOURCE_CONTAINER_NAME = "TF-TRT-Calibration"
 _TRT_ENGINE_CACHE_CONTAINER_NAME = "TF-TRT-Engine-Cache"
 _TRT_ENGINE_OP_NAME = "TRTEngineOp"
 
@@ -784,37 +783,32 @@ class TrtGraphConverter(GraphConverter):
     assert not self._calibration_data_collected
 
     # TODO(laigd): a better way would be to use self._calibration_sess to list
-    # all the devices, add one get_serialized_resource_op for each device, and
+    # all the devices, add one get_calibration_data for each device, and
     # fetch each such op for every resource until its found. This can work
     # even when the device of the TRTEngineOp is empty or not fully specified.
 
-    # Maps device name to the corresponding get_serialized_resource_op.
+    # Maps device name to the corresponding get_calibration_data.
     device_to_get_resource_op_map = {}
 
     with self._calibration_graph.as_default():
-      container_input = array_ops.placeholder(dtypes.string)
       resource_name_input = array_ops.placeholder(dtypes.string)
 
      for node in self._converted_graph_def.node:
        if node.op == _TRT_ENGINE_OP_NAME:
-          # Adds the get_serialized_resource_op for the device if not done
-          # before. We only add one such op for each device.
+          # Adds the get_calibration_data op for the device if not done before.
+          # We only add one such op for each device.
           # TODO(laigd): What if the device is empty?????
           if node.device not in device_to_get_resource_op_map:
             with self._calibration_graph.device(node.device):
               serialized_resources_output = (
-                  gen_trt_ops.get_serialized_resource_op(
-                      container_input, resource_name_input))
+                  gen_trt_ops.get_calibration_data_op(resource_name_input))
               device_to_get_resource_op_map[node.device] = (
                   serialized_resources_output)
 
           # Get the calibration resource.
           calibration_result = self._calibration_sess.run(
               device_to_get_resource_op_map[node.device],
-              feed_dict={
-                  container_input: _TRT_CALIBRATION_RESOURCE_CONTAINER_NAME,
-                  resource_name_input: node.name
-              })
+              feed_dict={resource_name_input: node.name})
           node.attr["calibration_data"].s = calibration_result
 
     self._calibration_data_collected = True
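
The TODO(laigd) above describes a possible follow-up rather than this change: build one get_calibration_data_op per device up front, then probe every device for each resource, so fetching works even when a TRTEngineOp has an empty or partial device assignment. A rough sketch of that idea with hypothetical names (sess, graph, engine_node_names), assuming a failed resource-manager lookup surfaces in Python as NotFoundError:

    fetch_by_device = {}
    with graph.as_default():
      resource_name_input = tf.placeholder(tf.string)
      # One fetch op per device, created once.
      for device in sess.list_devices():
        with graph.device(device.name):
          fetch_by_device[device.name] = gen_trt_ops.get_calibration_data_op(
              resource_name_input)

    for name in engine_node_names:
      for fetch in fetch_by_device.values():
        try:
          calibration_data = sess.run(
              fetch, feed_dict={resource_name_input: name})
          break  # resource found on this device
        except tf.errors.NotFoundError:
          continue  # try the next device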