Move compression_utils to core/data.

This is in preparation for adding a CompressElementOp, which will use CompressElement to compress a dataset element in a tf.data service agnostic way.

PiperOrigin-RevId: 312197651
Change-Id: I3558b2f5036dcf4c91ed9059a7b896351c79da40
This commit is contained in:
Andrew Audibert 2020-05-18 18:54:25 -07:00 committed by TensorFlower Gardener
parent 714092f360
commit d3886d23d7
12 changed files with 106 additions and 86 deletions

View File

@ -1,5 +1,10 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/core/platform:build_config.bzl", "tf_protos_all") load(
"//tensorflow/core/platform:build_config.bzl",
"tf_additional_all_protos",
"tf_proto_library",
"tf_protos_all",
)
package( package(
default_visibility = [ default_visibility = [
@ -10,6 +15,46 @@ package(
exports_files(["LICENSE"]) exports_files(["LICENSE"])
cc_library(
name = "compression_utils",
srcs = ["compression_utils.cc"],
hdrs = [
"compression_utils.h",
],
deps = [
":dataset_proto_cc",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/memory",
],
)
tf_cc_test(
name = "compression_utils_test",
srcs = ["compression_utils_test.cc"],
deps = [
":compression_utils",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels/data:dataset_test_base",
],
)
tf_proto_library(
name = "dataset_proto",
srcs = ["dataset.proto"],
cc_api_version = 2,
protodeps = tf_additional_all_protos(),
)
cc_library( cc_library(
name = "standalone", name = "standalone",
srcs = ["standalone.cc"], srcs = ["standalone.cc"],

View File

@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/data/service/compression_utils.h" #include "tensorflow/core/data/compression_utils.h"
#include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor.pb.h"
@ -21,11 +21,11 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace data { namespace data {
namespace service_util {
Status Compress(const std::vector<Tensor>& element, CompressedElement* out) { Status CompressElement(const std::vector<Tensor>& element,
CompressedElement* out) {
tensorflow::profiler::TraceMe activity( tensorflow::profiler::TraceMe activity(
"Compress", tensorflow::profiler::TraceMeLevel::kInfo); "CompressElement", tensorflow::profiler::TraceMeLevel::kInfo);
// Step 1: Determine the total uncompressed size. This requires serializing // Step 1: Determine the total uncompressed size. This requires serializing
// non-memcopyable tensors, which we save to use again later. // non-memcopyable tensors, which we save to use again later.
@ -51,7 +51,8 @@ Status Compress(const std::vector<Tensor>& element, CompressedElement* out) {
char* position = uncompressed.mdata(); char* position = uncompressed.mdata();
int non_memcpy_component_index = 0; int non_memcpy_component_index = 0;
for (auto& component : element) { for (auto& component : element) {
ComponentMetadata* metadata = out->mutable_component_metadata()->Add(); CompressedComponentMetadata* metadata =
out->mutable_component_metadata()->Add();
metadata->set_dtype(component.dtype()); metadata->set_dtype(component.dtype());
component.shape().AsProto(metadata->mutable_tensor_shape()); component.shape().AsProto(metadata->mutable_tensor_shape());
if (DataTypeCanUseMemcpy(component.dtype())) { if (DataTypeCanUseMemcpy(component.dtype())) {
@ -74,10 +75,10 @@ Status Compress(const std::vector<Tensor>& element, CompressedElement* out) {
return Status::OK(); return Status::OK();
} }
Status Uncompress(const CompressedElement& compressed, Status UncompressElement(const CompressedElement& compressed,
std::vector<Tensor>* out) { std::vector<Tensor>* out) {
tensorflow::profiler::TraceMe activity( tensorflow::profiler::TraceMe activity(
"Uncompress", tensorflow::profiler::TraceMeLevel::kInfo); "UncompressElement", tensorflow::profiler::TraceMeLevel::kInfo);
int num_components = compressed.component_metadata_size(); int num_components = compressed.component_metadata_size();
out->clear(); out->clear();
out->reserve(num_components); out->reserve(num_components);
@ -92,7 +93,8 @@ Status Uncompress(const CompressedElement& compressed,
tensor_proto_strs.reserve(num_components); tensor_proto_strs.reserve(num_components);
int64 total_size = 0; int64 total_size = 0;
for (int i = 0; i < num_components; ++i) { for (int i = 0; i < num_components; ++i) {
const ComponentMetadata& metadata = compressed.component_metadata(i); const CompressedComponentMetadata& metadata =
compressed.component_metadata(i);
if (DataTypeCanUseMemcpy(metadata.dtype())) { if (DataTypeCanUseMemcpy(metadata.dtype())) {
out->emplace_back(metadata.dtype(), metadata.tensor_shape()); out->emplace_back(metadata.dtype(), metadata.tensor_shape());
TensorBuffer* buffer = DMAHelper::buffer(&out->back()); TensorBuffer* buffer = DMAHelper::buffer(&out->back());
@ -146,6 +148,5 @@ Status Uncompress(const CompressedElement& compressed,
return Status::OK(); return Status::OK();
} }
} // namespace service_util
} // namespace data } // namespace data
} // namespace tensorflow } // namespace tensorflow

View File

@ -16,24 +16,23 @@ limitations under the License.
#define TENSORFLOW_CORE_DATA_SERVICE_COMPRESSION_UTILS_H_ #define TENSORFLOW_CORE_DATA_SERVICE_COMPRESSION_UTILS_H_
#include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/dataset.pb.h"
#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/status.h"
namespace tensorflow { namespace tensorflow {
namespace data { namespace data {
namespace service_util {
// Compresses the components of `element` into the `CompressedElement` proto. // Compresses the components of `element` into the `CompressedElement` proto.
// //
// In addition to writing the actual compressed bytes, `Compress` fills // In addition to writing the actual compressed bytes, `Compress` fills
// out the per-component metadata for the `CompressedElement`. // out the per-component metadata for the `CompressedElement`.
Status Compress(const std::vector<Tensor>& element, CompressedElement* out); Status CompressElement(const std::vector<Tensor>& element,
CompressedElement* out);
// Uncompresses a `CompressedElement` into a vector of tensor components. // Uncompresses a `CompressedElement` into a vector of tensor components.
Status Uncompress(const CompressedElement& compressed, Status UncompressElement(const CompressedElement& compressed,
std::vector<Tensor>* out); std::vector<Tensor>* out);
} // namespace service_util
} // namespace data } // namespace data
} // namespace tensorflow } // namespace tensorflow

View File

@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/data/service/compression_utils.h" #include "tensorflow/core/data/compression_utils.h"
#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/kernels/data/dataset_test_base.h" #include "tensorflow/core/kernels/data/dataset_test_base.h"
@ -20,7 +20,6 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace data { namespace data {
namespace service_util {
class ParameterizedCompressionUtilsTest class ParameterizedCompressionUtilsTest
: public DatasetOpsTestBase, : public DatasetOpsTestBase,
@ -29,9 +28,9 @@ class ParameterizedCompressionUtilsTest
TEST_P(ParameterizedCompressionUtilsTest, RoundTrip) { TEST_P(ParameterizedCompressionUtilsTest, RoundTrip) {
std::vector<Tensor> element = GetParam(); std::vector<Tensor> element = GetParam();
CompressedElement compressed; CompressedElement compressed;
TF_ASSERT_OK(Compress(element, &compressed)); TF_ASSERT_OK(CompressElement(element, &compressed));
std::vector<Tensor> round_trip_element; std::vector<Tensor> round_trip_element;
TF_ASSERT_OK(Uncompress(compressed, &round_trip_element)); TF_ASSERT_OK(UncompressElement(compressed, &round_trip_element));
TF_EXPECT_OK( TF_EXPECT_OK(
ExpectEqual(element, round_trip_element, /*compare_order=*/true)); ExpectEqual(element, round_trip_element, /*compare_order=*/true));
} }
@ -50,6 +49,5 @@ std::vector<std::vector<Tensor>> TestCases() {
INSTANTIATE_TEST_SUITE_P(Instantiation, ParameterizedCompressionUtilsTest, INSTANTIATE_TEST_SUITE_P(Instantiation, ParameterizedCompressionUtilsTest,
::testing::ValuesIn(TestCases())); ::testing::ValuesIn(TestCases()));
} // namespace service_util
} // namespace data } // namespace data
} // namespace tensorflow } // namespace tensorflow

View File

@ -0,0 +1,27 @@
syntax = "proto3";
package tensorflow.data;
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
// This file contains protocol buffers for working with tf.data Datasets.
// Metadata describing a compressed component of a dataset element.
message CompressedComponentMetadata {
// The dtype of the component tensor.
.tensorflow.DataType dtype = 1;
// The shape of the component tensor.
.tensorflow.TensorShapeProto tensor_shape = 2;
// Size of the uncompressed tensor bytes. For tensors serialized as
// TensorProtos, this is TensorProto::BytesAllocatedLong(). For raw Tensors,
// this is the size of the buffer underlying the Tensor.
int64 tensor_size_bytes = 3;
}
message CompressedElement {
// Compressed tensor bytes for all components of the element.
bytes data = 1;
// Metadata for the components of the element.
repeated CompressedComponentMetadata component_metadata = 2;
}

View File

@ -44,6 +44,7 @@ tf_proto_library(
cc_api_version = 2, cc_api_version = 2,
protodeps = tf_additional_all_protos() + [ protodeps = tf_additional_all_protos() + [
":common_proto", ":common_proto",
"//tensorflow/core/data:dataset_proto",
], ],
) )
@ -84,7 +85,6 @@ cc_library(
], ],
deps = [ deps = [
":common_proto_cc", ":common_proto_cc",
":compression_utils",
":credentials_factory", ":credentials_factory",
":grpc_util", ":grpc_util",
":master_cc_grpc_proto", ":master_cc_grpc_proto",
@ -98,6 +98,7 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/data:compression_utils",
"//tensorflow/core/data:standalone", "//tensorflow/core/data:standalone",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
@ -129,39 +130,6 @@ tf_cc_test(
], ],
) )
cc_library(
name = "compression_utils",
srcs = ["compression_utils.cc"],
hdrs = [
"compression_utils.h",
],
deps = [
":common_proto_cc",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/memory",
],
)
tf_cc_test(
name = "compression_utils_test",
srcs = ["compression_utils_test.cc"],
deps = [
":compression_utils",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels/data:dataset_test_base",
],
)
cc_library( cc_library(
name = "credentials_factory", name = "credentials_factory",
srcs = ["credentials_factory.cc"], srcs = ["credentials_factory.cc"],
@ -317,7 +285,6 @@ tf_cc_test(
srcs = ["data_service_test.cc"], srcs = ["data_service_test.cc"],
tags = ["no_windows"], tags = ["no_windows"],
deps = [ deps = [
":compression_utils",
":data_service", ":data_service",
":grpc_master_impl", ":grpc_master_impl",
":grpc_util", ":grpc_util",
@ -333,6 +300,7 @@ tf_cc_test(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core/data:compression_utils",
"//tensorflow/core/kernels/data:dataset_test_base", "//tensorflow/core/kernels/data:dataset_test_base",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
tf_grpc_cc_dependency(), tf_grpc_cc_dependency(),

View File

@ -3,7 +3,6 @@ syntax = "proto3";
package tensorflow.data; package tensorflow.data;
import "tensorflow/core/framework/graph.proto"; import "tensorflow/core/framework/graph.proto";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto"; import "tensorflow/core/framework/types.proto";
message DatasetDef { message DatasetDef {
@ -12,24 +11,6 @@ message DatasetDef {
GraphDef graph = 1; GraphDef graph = 1;
} }
message ComponentMetadata {
// The dtype of the component tensor.
.tensorflow.DataType dtype = 1;
// The shape of the component tensor.
.tensorflow.TensorShapeProto tensor_shape = 2;
// Size of the uncompressed tensor bytes. For tensors serialized as
// TensorProtos, this is TensorProto::BytesAllocatedLong(). For raw Tensors,
// this is the size of the buffer underlying the Tensor.
int64 tensor_size_bytes = 3;
}
message CompressedElement {
// Compressed tensor bytes for all components of the element.
bytes data = 1;
// Metadata for the components of the element.
repeated ComponentMetadata component_metadata = 2;
}
message TaskDef { message TaskDef {
// The dataset to iterate over. // The dataset to iterate over.
// TODO(aaudibert): load the dataset from disk instead of passing it here. // TODO(aaudibert): load the dataset from disk instead of passing it here.

View File

@ -18,7 +18,7 @@ limitations under the License.
#include "grpcpp/create_channel.h" #include "grpcpp/create_channel.h"
#include "grpcpp/security/credentials.h" #include "grpcpp/security/credentials.h"
#include "absl/strings/str_split.h" #include "absl/strings/str_split.h"
#include "tensorflow/core/data/service/compression_utils.h" #include "tensorflow/core/data/compression_utils.h"
#include "tensorflow/core/data/service/grpc_util.h" #include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/master.grpc.pb.h" #include "tensorflow/core/data/service/master.grpc.pb.h"
#include "tensorflow/core/data/service/master.pb.h" #include "tensorflow/core/data/service/master.pb.h"
@ -74,7 +74,7 @@ Status CheckWorkerOutput(const std::string& worker_address, int64 task_id,
return errors::Internal("Reached end of sequence too early."); return errors::Internal("Reached end of sequence too early.");
} }
std::vector<Tensor> element; std::vector<Tensor> element;
TF_RETURN_IF_ERROR(service_util::Uncompress(compressed, &element)); TF_RETURN_IF_ERROR(UncompressElement(compressed, &element));
TF_RETURN_IF_ERROR(DatasetOpsTestBase::ExpectEqual(element, expected, TF_RETURN_IF_ERROR(DatasetOpsTestBase::ExpectEqual(element, expected,
/*compare_order=*/true)); /*compare_order=*/true));
} }

View File

@ -2,6 +2,7 @@ syntax = "proto3";
package tensorflow.data; package tensorflow.data;
import "tensorflow/core/data/dataset.proto";
import "tensorflow/core/data/service/common.proto"; import "tensorflow/core/data/service/common.proto";
message ProcessTaskRequest { message ProcessTaskRequest {

View File

@ -19,7 +19,7 @@ limitations under the License.
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/data/service/compression_utils.h" #include "tensorflow/core/data/compression_utils.h"
#include "tensorflow/core/data/service/credentials_factory.h" #include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/grpc_util.h" #include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/master.grpc.pb.h" #include "tensorflow/core/data/service/master.grpc.pb.h"
@ -135,8 +135,8 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
if (!end_of_sequence) { if (!end_of_sequence) {
VLOG(3) << "Producing an element for task " << request->task_id(); VLOG(3) << "Producing an element for task " << request->task_id();
TF_RETURN_IF_ERROR(service_util::Compress( TF_RETURN_IF_ERROR(
outputs, response->mutable_compressed_element())); CompressElement(outputs, response->mutable_compressed_element()));
} }
response->set_end_of_sequence(end_of_sequence); response->set_end_of_sequence(end_of_sequence);

View File

@ -131,8 +131,8 @@ tf_kernel_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/data/service:common_proto_cc", "//tensorflow/core/data:compression_utils",
"//tensorflow/core/data/service:compression_utils", "//tensorflow/core/data:dataset_proto_cc",
"//tensorflow/core/data/service:data_service", "//tensorflow/core/data/service:data_service",
"//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/kernels/data:dataset_utils", "//tensorflow/core/kernels/data:dataset_utils",

View File

@ -21,8 +21,8 @@ limitations under the License.
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h" #include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/compression_utils.h"
#include "tensorflow/core/data/service/compression_utils.h" #include "tensorflow/core/data/dataset.pb.h"
#include "tensorflow/core/data/service/data_service.h" #include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/dataset.h"
@ -496,7 +496,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
std::vector<Tensor> element; std::vector<Tensor> element;
if (!end_of_sequence) { if (!end_of_sequence) {
TF_RETURN_IF_ERROR(service_util::Uncompress(compressed, &element)); TF_RETURN_IF_ERROR(UncompressElement(compressed, &element));
} }
mutex_lock l(mu_); mutex_lock l(mu_);
if (end_of_sequence) { if (end_of_sequence) {