Specialize GetSerializedResourceOp for calibration data
This is in preparation to combine calibration data and engine into a single resource. Also, remove GetSerializedResourceOpTest that tests the generic op. PiperOrigin-RevId: 250306470
This commit is contained in:
parent
bb237cf743
commit
3638d4b89b
@ -53,7 +53,7 @@ tf_cuda_cc_test(
|
||||
cc_library(
|
||||
name = "trt_op_kernels",
|
||||
srcs = [
|
||||
"kernels/get_serialized_resource_op.cc",
|
||||
"kernels/get_calibration_data_op.cc",
|
||||
"kernels/trt_engine_op.cc",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
@ -144,31 +144,6 @@ tf_cc_shared_object(
|
||||
]) + tf_custom_op_library_additional_deps(),
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "get_serialized_resource_op_test",
|
||||
size = "small",
|
||||
srcs = ["kernels/get_serialized_resource_op_test.cc"],
|
||||
tags = [
|
||||
"no_cuda_on_cpu_tap",
|
||||
"no_windows",
|
||||
"nomac",
|
||||
],
|
||||
deps = [
|
||||
# TODO(laigd): consider splitting get_serialized_resource_op out from
|
||||
# TF-TRT.
|
||||
":trt_op_kernels",
|
||||
":trt_op_libs",
|
||||
":trt_resources",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/kernels:ops_testutil",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "trt_engine_op_test",
|
||||
size = "small",
|
||||
@ -201,7 +176,7 @@ tf_cuda_cc_test(
|
||||
tf_gen_op_libs(
|
||||
op_lib_names = [
|
||||
"trt_engine_op",
|
||||
"get_serialized_resource_op",
|
||||
"get_calibration_data_op",
|
||||
"trt_engine_resource_ops",
|
||||
],
|
||||
)
|
||||
@ -209,7 +184,7 @@ tf_gen_op_libs(
|
||||
cc_library(
|
||||
name = "trt_op_libs",
|
||||
deps = [
|
||||
":get_serialized_resource_op_op_lib",
|
||||
":get_calibration_data_op_op_lib",
|
||||
":trt_engine_op_op_lib",
|
||||
],
|
||||
)
|
||||
@ -262,14 +237,14 @@ tf_custom_op_py_library(
|
||||
tf_cuda_library(
|
||||
name = "trt_resources",
|
||||
srcs = [
|
||||
"utils/calibration_resource.cc",
|
||||
"utils/trt_int8_calibrator.cc",
|
||||
"utils/trt_lru_cache.cc",
|
||||
"utils/trt_resources.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"utils/calibration_resource.h",
|
||||
"utils/trt_int8_calibrator.h",
|
||||
"utils/trt_lru_cache.h",
|
||||
"utils/trt_resources.h",
|
||||
],
|
||||
deps = [
|
||||
":trt_allocator",
|
||||
@ -553,7 +528,7 @@ tf_py_wrap_cc(
|
||||
srcs = ["utils/py_utils.i"],
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
"//tensorflow/compiler/tf2tensorrt:py_utils",
|
||||
":py_utils",
|
||||
"//third_party/python_runtime:headers",
|
||||
],
|
||||
)
|
||||
|
||||
@ -29,7 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/segment/segment.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h"
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
|
||||
|
||||
@ -30,8 +30,8 @@ limitations under the License.
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h" // NOLINT
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
|
||||
|
||||
@ -23,10 +23,10 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||
|
||||
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
@ -28,24 +28,24 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
|
||||
class GetSerializedResourceOp : public OpKernel {
|
||||
class GetCalibrationDataOp : public OpKernel {
|
||||
public:
|
||||
explicit GetSerializedResourceOp(OpKernelConstruction* context)
|
||||
explicit GetCalibrationDataOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {}
|
||||
|
||||
~GetSerializedResourceOp() override {}
|
||||
~GetCalibrationDataOp() override {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
// TODO(laigd): it will allocate the tensor on the device and copy the
|
||||
// serialized string to that tensor, and later sess.run() will copy it back
|
||||
// to host. We need to optimize this.
|
||||
const string& container = context->input(0).scalar<string>()();
|
||||
const string& resource_name = context->input(1).scalar<string>()();
|
||||
const string& resource_name = context->input(0).scalar<string>()();
|
||||
|
||||
// Get the resource.
|
||||
SerializableResourceBase* resource = nullptr;
|
||||
TRTCalibrationResource* resource = nullptr;
|
||||
OP_REQUIRES_OK(context, context->resource_manager()->Lookup(
|
||||
container, resource_name, &resource));
|
||||
std::string(kCalibrationContainerName),
|
||||
resource_name, &resource));
|
||||
core::ScopedUnref sc(resource);
|
||||
|
||||
// Serialize the resource as output.
|
||||
@ -59,8 +59,8 @@ class GetSerializedResourceOp : public OpKernel {
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("GetSerializedResourceOp").Device(DEVICE_GPU),
|
||||
GetSerializedResourceOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("GetCalibrationDataOp").Device(DEVICE_GPU),
|
||||
GetCalibrationDataOp);
|
||||
|
||||
} // namespace tensorrt
|
||||
} // namespace tensorflow
|
||||
@ -1,81 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <dirent.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_TENSORRT
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
|
||||
class GetSerializedResourceOpTest : public OpsTestBase {};
|
||||
|
||||
TEST_F(GetSerializedResourceOpTest, Basic) {
|
||||
// Create the GPU device.
|
||||
std::unique_ptr<Device> device(
|
||||
DeviceFactory::NewDevice("GPU", {}, "/job:worker/replica:0/task:0"));
|
||||
|
||||
// Create the resource.
|
||||
class MySerializableResource : public SerializableResourceBase {
|
||||
public:
|
||||
string DebugString() const override { return ""; }
|
||||
Status SerializeToString(string* serialized) override {
|
||||
*serialized = "my_serialized_str";
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
const string container = "mycontainer";
|
||||
const string resource_name = "myresource";
|
||||
SerializableResourceBase* resource = new MySerializableResource();
|
||||
ResourceMgr* rm = device->resource_manager();
|
||||
EXPECT_TRUE(rm->Create(container, resource_name, resource).ok());
|
||||
|
||||
// Create the op.
|
||||
SetDevice(DEVICE_GPU, std::move(device));
|
||||
TF_ASSERT_OK(NodeDefBuilder("op", "GetSerializedResourceOp")
|
||||
.Input(FakeInput(DT_STRING))
|
||||
.Input(FakeInput(DT_STRING))
|
||||
.Finalize(node_def()));
|
||||
TF_ASSERT_OK(InitOp());
|
||||
|
||||
// Execute the op.
|
||||
AddInputFromArray<string>(TensorShape({}), {container});
|
||||
AddInputFromArray<string>(TensorShape({}), {resource_name});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Verify the result.
|
||||
// string type output will remain on CPU, so we're not using GetOutput() here.
|
||||
EXPECT_EQ("my_serialized_str",
|
||||
context_->mutable_output(0)->scalar<string>()());
|
||||
}
|
||||
|
||||
} // namespace tensorrt
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_TENSORRT
|
||||
#endif // GOOGLE_CUDA
|
||||
@ -22,10 +22,10 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph_to_functiondef.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
@ -96,7 +96,7 @@ class TRTEngineOp : public AsyncOpKernel {
|
||||
|
||||
// Allocate necessary resources for calibration
|
||||
Status AllocateCalibrationResources(OpKernelContext* ctx,
|
||||
SerializableResourceBase** cr);
|
||||
TRTCalibrationResource** cr);
|
||||
|
||||
// Get engine for the input shape
|
||||
EngineContext* GetEngine(const std::vector<TensorShape>& input_shapes,
|
||||
@ -286,8 +286,8 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
|
||||
ctx,
|
||||
ctx->resource_manager()->LookupOrCreate(
|
||||
"TF-TRT-Calibration", name(),
|
||||
reinterpret_cast<SerializableResourceBase**>(&calib_res),
|
||||
{[ctx, this](SerializableResourceBase** cr) -> Status {
|
||||
reinterpret_cast<TRTCalibrationResource**>(&calib_res),
|
||||
{[ctx, this](TRTCalibrationResource** cr) -> Status {
|
||||
return this->AllocateCalibrationResources(ctx, cr);
|
||||
}}),
|
||||
*helper);
|
||||
@ -542,7 +542,7 @@ EngineContext* TRTEngineOp::GetEngine(
|
||||
// Get engine cache.
|
||||
TRTEngineCacheResource* cache_res = nullptr;
|
||||
auto status = ctx->resource_manager()->LookupOrCreate(
|
||||
"TF-TRT-Engine-Cache", string(resource_name), &cache_res,
|
||||
std::string(kCalibrationContainerName), string(resource_name), &cache_res,
|
||||
{[this, ctx](TRTEngineCacheResource** cr) -> Status {
|
||||
*cr = new TRTEngineCacheResource(ctx, this->max_cached_engines_);
|
||||
return Status::OK();
|
||||
@ -663,8 +663,8 @@ EngineContext* TRTEngineOp::GetEngine(
|
||||
return cache.at(engine_input_shapes).get();
|
||||
}
|
||||
|
||||
Status TRTEngineOp::AllocateCalibrationResources(
|
||||
OpKernelContext* ctx, SerializableResourceBase** cr) {
|
||||
Status TRTEngineOp::AllocateCalibrationResources(OpKernelContext* ctx,
|
||||
TRTCalibrationResource** cr) {
|
||||
auto cres = new TRTCalibrationResource();
|
||||
*cr = cres;
|
||||
// Get the allocator.
|
||||
|
||||
@ -22,7 +22,7 @@ limitations under the License.
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
|
||||
@ -23,15 +23,13 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_OP("GetSerializedResourceOp")
|
||||
.Input("container: string")
|
||||
REGISTER_OP("GetCalibrationDataOp")
|
||||
.Input("resource_name: string")
|
||||
.Output("serialized_resource: string")
|
||||
.SetShapeFn(shape_inference::ScalarShape)
|
||||
.SetIsStateful()
|
||||
.Doc(R"doc(
|
||||
Gets a resource from a container managed by the resource manager and returns
|
||||
its serialized representation.
|
||||
Returns calibration data for the given resource name
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
|
||||
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_TENSORRT
|
||||
@ -21,6 +21,8 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
|
||||
const absl::string_view kCalibrationContainerName = "TF-TRT-Calibration";
|
||||
|
||||
TRTCalibrationResource::~TRTCalibrationResource() {
|
||||
VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString();
|
||||
builder_.reset();
|
||||
@ -37,18 +37,15 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
|
||||
class SerializableResourceBase : public ResourceBase {
|
||||
public:
|
||||
virtual Status SerializeToString(string* serialized) = 0;
|
||||
};
|
||||
ABSL_CONST_INIT extern const absl::string_view kCalibrationContainerName;
|
||||
|
||||
class TRTCalibrationResource : public SerializableResourceBase {
|
||||
class TRTCalibrationResource : public ResourceBase {
|
||||
public:
|
||||
~TRTCalibrationResource() override;
|
||||
|
||||
string DebugString() const override;
|
||||
|
||||
Status SerializeToString(string* serialized) override;
|
||||
Status SerializeToString(string* serialized);
|
||||
|
||||
// Lookup table for temporary staging areas of input tensors for calibration.
|
||||
std::unordered_map<string, std::pair<void*, size_t>> device_buffers_;
|
||||
@ -532,7 +532,6 @@ DEFAULT_TRT_CONVERSION_PARAMS = TrtConversionParams(
|
||||
max_batch_size=1,
|
||||
cached_engine_batches=None)
|
||||
|
||||
_TRT_CALIBRATION_RESOURCE_CONTAINER_NAME = "TF-TRT-Calibration"
|
||||
_TRT_ENGINE_CACHE_CONTAINER_NAME = "TF-TRT-Engine-Cache"
|
||||
_TRT_ENGINE_OP_NAME = "TRTEngineOp"
|
||||
|
||||
@ -784,37 +783,32 @@ class TrtGraphConverter(GraphConverter):
|
||||
assert not self._calibration_data_collected
|
||||
|
||||
# TODO(laigd): a better way would be to use self._calibration_sess to list
|
||||
# all the devices, add one get_serialized_resource_op for each device, and
|
||||
# all the devices, add one get_calibration_data for each device, and
|
||||
# fetch each such op for every resource until its found. This can work
|
||||
# even when the device of the TRTEngineOp is empty or not fully specified.
|
||||
|
||||
# Maps device name to the corresponding get_serialized_resource_op.
|
||||
# Maps device name to the corresponding get_calibration_data.
|
||||
device_to_get_resource_op_map = {}
|
||||
|
||||
with self._calibration_graph.as_default():
|
||||
container_input = array_ops.placeholder(dtypes.string)
|
||||
resource_name_input = array_ops.placeholder(dtypes.string)
|
||||
|
||||
for node in self._converted_graph_def.node:
|
||||
if node.op == _TRT_ENGINE_OP_NAME:
|
||||
# Adds the get_serialized_resource_op for the device if not done
|
||||
# before. We only add one such op for each device.
|
||||
# Adds the get_calibration_data op for the device if not done before.
|
||||
# We only add one such op for each device.
|
||||
# TODO(laigd): What if the device is empty?????
|
||||
if node.device not in device_to_get_resource_op_map:
|
||||
with self._calibration_graph.device(node.device):
|
||||
serialized_resources_output = (
|
||||
gen_trt_ops.get_serialized_resource_op(
|
||||
container_input, resource_name_input))
|
||||
gen_trt_ops.get_calibration_data_op(resource_name_input))
|
||||
device_to_get_resource_op_map[node.device] = (
|
||||
serialized_resources_output)
|
||||
|
||||
# Get the calibration resource.
|
||||
calibration_result = self._calibration_sess.run(
|
||||
device_to_get_resource_op_map[node.device],
|
||||
feed_dict={
|
||||
container_input: _TRT_CALIBRATION_RESOURCE_CONTAINER_NAME,
|
||||
resource_name_input: node.name
|
||||
})
|
||||
feed_dict={resource_name_input: node.name})
|
||||
node.attr["calibration_data"].s = calibration_result
|
||||
|
||||
self._calibration_data_collected = True
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user