From d3886d23d7c5f423390b4c570842fe2c31f24ff5 Mon Sep 17 00:00:00 2001 From: Andrew Audibert Date: Mon, 18 May 2020 18:54:25 -0700 Subject: [PATCH] Move compression_utils to core/data. This is in preparation for adding a CompressElementOp, which will use CompressElement to compress a dataset element in a tf.data service agnostic way. PiperOrigin-RevId: 312197651 Change-Id: I3558b2f5036dcf4c91ed9059a7b896351c79da40 --- tensorflow/core/data/BUILD | 47 ++++++++++++++++++- .../data/{service => }/compression_utils.cc | 21 +++++---- .../data/{service => }/compression_utils.h | 11 ++--- .../{service => }/compression_utils_test.cc | 8 ++-- tensorflow/core/data/dataset.proto | 27 +++++++++++ tensorflow/core/data/service/BUILD | 38 ++------------- tensorflow/core/data/service/common.proto | 19 -------- .../core/data/service/data_service_test.cc | 4 +- tensorflow/core/data/service/worker.proto | 1 + tensorflow/core/data/service/worker_impl.cc | 6 +-- .../core/kernels/data/experimental/BUILD | 4 +- .../experimental/data_service_dataset_op.cc | 6 +-- 12 files changed, 106 insertions(+), 86 deletions(-) rename tensorflow/core/data/{service => }/compression_utils.cc (90%) rename tensorflow/core/data/{service => }/compression_utils.h (82%) rename tensorflow/core/data/{service => }/compression_utils_test.cc (89%) create mode 100644 tensorflow/core/data/dataset.proto diff --git a/tensorflow/core/data/BUILD b/tensorflow/core/data/BUILD index 9c58be108fc..e42c46d6348 100644 --- a/tensorflow/core/data/BUILD +++ b/tensorflow/core/data/BUILD @@ -1,5 +1,10 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load("//tensorflow/core/platform:build_config.bzl", "tf_protos_all") +load( + "//tensorflow/core/platform:build_config.bzl", + "tf_additional_all_protos", + "tf_proto_library", + "tf_protos_all", +) package( default_visibility = [ @@ -10,6 +15,46 @@ package( exports_files(["LICENSE"]) +cc_library( + name = "compression_utils", + srcs = ["compression_utils.cc"], + hdrs = [ + "compression_utils.h", + ], + deps = [ + ":dataset_proto_cc", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/profiler/lib:traceme", + "@com_google_absl//absl/memory", + ], +) + +tf_cc_test( + name = "compression_utils_test", + srcs = ["compression_utils_test.cc"], + deps = [ + ":compression_utils", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels/data:dataset_test_base", + ], +) + +tf_proto_library( + name = "dataset_proto", + srcs = ["dataset.proto"], + cc_api_version = 2, + protodeps = tf_additional_all_protos(), +) + cc_library( name = "standalone", srcs = ["standalone.cc"], diff --git a/tensorflow/core/data/service/compression_utils.cc b/tensorflow/core/data/compression_utils.cc similarity index 90% rename from tensorflow/core/data/service/compression_utils.cc rename to tensorflow/core/data/compression_utils.cc index c4a47e1b00e..ea06a082128 100644 --- a/tensorflow/core/data/service/compression_utils.cc +++ b/tensorflow/core/data/compression_utils.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/data/service/compression_utils.h" +#include "tensorflow/core/data/compression_utils.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/framework/tensor.pb.h" @@ -21,11 +21,11 @@ limitations under the License. namespace tensorflow { namespace data { -namespace service_util { -Status Compress(const std::vector& element, CompressedElement* out) { +Status CompressElement(const std::vector& element, + CompressedElement* out) { tensorflow::profiler::TraceMe activity( - "Compress", tensorflow::profiler::TraceMeLevel::kInfo); + "CompressElement", tensorflow::profiler::TraceMeLevel::kInfo); // Step 1: Determine the total uncompressed size. This requires serializing // non-memcopyable tensors, which we save to use again later. @@ -51,7 +51,8 @@ Status Compress(const std::vector& element, CompressedElement* out) { char* position = uncompressed.mdata(); int non_memcpy_component_index = 0; for (auto& component : element) { - ComponentMetadata* metadata = out->mutable_component_metadata()->Add(); + CompressedComponentMetadata* metadata = + out->mutable_component_metadata()->Add(); metadata->set_dtype(component.dtype()); component.shape().AsProto(metadata->mutable_tensor_shape()); if (DataTypeCanUseMemcpy(component.dtype())) { @@ -74,10 +75,10 @@ Status Compress(const std::vector& element, CompressedElement* out) { return Status::OK(); } -Status Uncompress(const CompressedElement& compressed, - std::vector* out) { +Status UncompressElement(const CompressedElement& compressed, + std::vector* out) { tensorflow::profiler::TraceMe activity( - "Uncompress", tensorflow::profiler::TraceMeLevel::kInfo); + "UncompressElement", tensorflow::profiler::TraceMeLevel::kInfo); int num_components = compressed.component_metadata_size(); out->clear(); out->reserve(num_components); @@ -92,7 +93,8 @@ Status Uncompress(const CompressedElement& compressed, tensor_proto_strs.reserve(num_components); int64 total_size = 0; for (int i = 0; i < num_components; ++i) { - const ComponentMetadata& metadata = compressed.component_metadata(i); + const CompressedComponentMetadata& metadata = + compressed.component_metadata(i); if (DataTypeCanUseMemcpy(metadata.dtype())) { out->emplace_back(metadata.dtype(), metadata.tensor_shape()); TensorBuffer* buffer = DMAHelper::buffer(&out->back()); @@ -146,6 +148,5 @@ Status Uncompress(const CompressedElement& compressed, return Status::OK(); } -} // namespace service_util } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/service/compression_utils.h b/tensorflow/core/data/compression_utils.h similarity index 82% rename from tensorflow/core/data/service/compression_utils.h rename to tensorflow/core/data/compression_utils.h index 96698aaaf09..5e033771272 100644 --- a/tensorflow/core/data/service/compression_utils.h +++ b/tensorflow/core/data/compression_utils.h @@ -16,24 +16,23 @@ limitations under the License. #define TENSORFLOW_CORE_DATA_SERVICE_COMPRESSION_UTILS_H_ #include "tensorflow/core/common_runtime/dma_helper.h" -#include "tensorflow/core/data/service/common.pb.h" +#include "tensorflow/core/data/dataset.pb.h" #include "tensorflow/core/platform/status.h" namespace tensorflow { namespace data { -namespace service_util { // Compresses the components of `element` into the `CompressedElement` proto. // // In addition to writing the actual compressed bytes, `Compress` fills // out the per-component metadata for the `CompressedElement`. -Status Compress(const std::vector& element, CompressedElement* out); +Status CompressElement(const std::vector& element, + CompressedElement* out); // Uncompresses a `CompressedElement` into a vector of tensor components. -Status Uncompress(const CompressedElement& compressed, - std::vector* out); +Status UncompressElement(const CompressedElement& compressed, + std::vector* out); -} // namespace service_util } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/service/compression_utils_test.cc b/tensorflow/core/data/compression_utils_test.cc similarity index 89% rename from tensorflow/core/data/service/compression_utils_test.cc rename to tensorflow/core/data/compression_utils_test.cc index b5da13efeed..eb220092f88 100644 --- a/tensorflow/core/data/service/compression_utils_test.cc +++ b/tensorflow/core/data/compression_utils_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/data/service/compression_utils.h" +#include "tensorflow/core/data/compression_utils.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/kernels/data/dataset_test_base.h" @@ -20,7 +20,6 @@ limitations under the License. namespace tensorflow { namespace data { -namespace service_util { class ParameterizedCompressionUtilsTest : public DatasetOpsTestBase, @@ -29,9 +28,9 @@ class ParameterizedCompressionUtilsTest TEST_P(ParameterizedCompressionUtilsTest, RoundTrip) { std::vector element = GetParam(); CompressedElement compressed; - TF_ASSERT_OK(Compress(element, &compressed)); + TF_ASSERT_OK(CompressElement(element, &compressed)); std::vector round_trip_element; - TF_ASSERT_OK(Uncompress(compressed, &round_trip_element)); + TF_ASSERT_OK(UncompressElement(compressed, &round_trip_element)); TF_EXPECT_OK( ExpectEqual(element, round_trip_element, /*compare_order=*/true)); } @@ -50,6 +49,5 @@ std::vector> TestCases() { INSTANTIATE_TEST_SUITE_P(Instantiation, ParameterizedCompressionUtilsTest, ::testing::ValuesIn(TestCases())); -} // namespace service_util } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/dataset.proto b/tensorflow/core/data/dataset.proto new file mode 100644 index 00000000000..27a36364e76 --- /dev/null +++ b/tensorflow/core/data/dataset.proto @@ -0,0 +1,27 @@ +syntax = "proto3"; + +package tensorflow.data; + +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; + +// This file contains protocol buffers for working with tf.data Datasets. + +// Metadata describing a compressed component of a dataset element. +message CompressedComponentMetadata { + // The dtype of the component tensor. + .tensorflow.DataType dtype = 1; + // The shape of the component tensor. + .tensorflow.TensorShapeProto tensor_shape = 2; + // Size of the uncompressed tensor bytes. For tensors serialized as + // TensorProtos, this is TensorProto::BytesAllocatedLong(). For raw Tensors, + // this is the size of the buffer underlying the Tensor. + int64 tensor_size_bytes = 3; +} + +message CompressedElement { + // Compressed tensor bytes for all components of the element. + bytes data = 1; + // Metadata for the components of the element. + repeated CompressedComponentMetadata component_metadata = 2; +} diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index 5413493cb78..b87f4f171cd 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -44,6 +44,7 @@ tf_proto_library( cc_api_version = 2, protodeps = tf_additional_all_protos() + [ ":common_proto", + "//tensorflow/core/data:dataset_proto", ], ) @@ -84,7 +85,6 @@ cc_library( ], deps = [ ":common_proto_cc", - ":compression_utils", ":credentials_factory", ":grpc_util", ":master_cc_grpc_proto", @@ -98,6 +98,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/data:compression_utils", "//tensorflow/core/data:standalone", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", @@ -129,39 +130,6 @@ tf_cc_test( ], ) -cc_library( - name = "compression_utils", - srcs = ["compression_utils.cc"], - hdrs = [ - "compression_utils.h", - ], - deps = [ - ":common_proto_cc", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/profiler/lib:traceme", - "@com_google_absl//absl/memory", - ], -) - -tf_cc_test( - name = "compression_utils_test", - srcs = ["compression_utils_test.cc"], - deps = [ - ":compression_utils", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - "//tensorflow/core/kernels/data:dataset_test_base", - ], -) - cc_library( name = "credentials_factory", srcs = ["credentials_factory.cc"], @@ -317,7 +285,6 @@ tf_cc_test( srcs = ["data_service_test.cc"], tags = ["no_windows"], deps = [ - ":compression_utils", ":data_service", ":grpc_master_impl", ":grpc_util", @@ -333,6 +300,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/data:compression_utils", "//tensorflow/core/kernels/data:dataset_test_base", "@com_google_absl//absl/strings", tf_grpc_cc_dependency(), diff --git a/tensorflow/core/data/service/common.proto b/tensorflow/core/data/service/common.proto index 6dfa698764b..4bde56fe1ca 100644 --- a/tensorflow/core/data/service/common.proto +++ b/tensorflow/core/data/service/common.proto @@ -3,7 +3,6 @@ syntax = "proto3"; package tensorflow.data; import "tensorflow/core/framework/graph.proto"; -import "tensorflow/core/framework/tensor_shape.proto"; import "tensorflow/core/framework/types.proto"; message DatasetDef { @@ -12,24 +11,6 @@ message DatasetDef { GraphDef graph = 1; } -message ComponentMetadata { - // The dtype of the component tensor. - .tensorflow.DataType dtype = 1; - // The shape of the component tensor. - .tensorflow.TensorShapeProto tensor_shape = 2; - // Size of the uncompressed tensor bytes. For tensors serialized as - // TensorProtos, this is TensorProto::BytesAllocatedLong(). For raw Tensors, - // this is the size of the buffer underlying the Tensor. - int64 tensor_size_bytes = 3; -} - -message CompressedElement { - // Compressed tensor bytes for all components of the element. - bytes data = 1; - // Metadata for the components of the element. - repeated ComponentMetadata component_metadata = 2; -} - message TaskDef { // The dataset to iterate over. // TODO(aaudibert): load the dataset from disk instead of passing it here. diff --git a/tensorflow/core/data/service/data_service_test.cc b/tensorflow/core/data/service/data_service_test.cc index 73a46bad3d0..bd01cb90a66 100644 --- a/tensorflow/core/data/service/data_service_test.cc +++ b/tensorflow/core/data/service/data_service_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "grpcpp/create_channel.h" #include "grpcpp/security/credentials.h" #include "absl/strings/str_split.h" -#include "tensorflow/core/data/service/compression_utils.h" +#include "tensorflow/core/data/compression_utils.h" #include "tensorflow/core/data/service/grpc_util.h" #include "tensorflow/core/data/service/master.grpc.pb.h" #include "tensorflow/core/data/service/master.pb.h" @@ -74,7 +74,7 @@ Status CheckWorkerOutput(const std::string& worker_address, int64 task_id, return errors::Internal("Reached end of sequence too early."); } std::vector element; - TF_RETURN_IF_ERROR(service_util::Uncompress(compressed, &element)); + TF_RETURN_IF_ERROR(UncompressElement(compressed, &element)); TF_RETURN_IF_ERROR(DatasetOpsTestBase::ExpectEqual(element, expected, /*compare_order=*/true)); } diff --git a/tensorflow/core/data/service/worker.proto b/tensorflow/core/data/service/worker.proto index 04b8f03474c..51c6899f540 100644 --- a/tensorflow/core/data/service/worker.proto +++ b/tensorflow/core/data/service/worker.proto @@ -2,6 +2,7 @@ syntax = "proto3"; package tensorflow.data; +import "tensorflow/core/data/dataset.proto"; import "tensorflow/core/data/service/common.proto"; message ProcessTaskRequest { diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc index 8d00825227b..b4be18ebccd 100644 --- a/tensorflow/core/data/service/worker_impl.cc +++ b/tensorflow/core/data/service/worker_impl.cc @@ -19,7 +19,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/core/data/service/compression_utils.h" +#include "tensorflow/core/data/compression_utils.h" #include "tensorflow/core/data/service/credentials_factory.h" #include "tensorflow/core/data/service/grpc_util.h" #include "tensorflow/core/data/service/master.grpc.pb.h" @@ -135,8 +135,8 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request, if (!end_of_sequence) { VLOG(3) << "Producing an element for task " << request->task_id(); - TF_RETURN_IF_ERROR(service_util::Compress( - outputs, response->mutable_compressed_element())); + TF_RETURN_IF_ERROR( + CompressElement(outputs, response->mutable_compressed_element())); } response->set_end_of_sequence(end_of_sequence); diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD index 4ddfd99951c..85f8af878ee 100644 --- a/tensorflow/core/kernels/data/experimental/BUILD +++ b/tensorflow/core/kernels/data/experimental/BUILD @@ -131,8 +131,8 @@ tf_kernel_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/data/service:common_proto_cc", - "//tensorflow/core/data/service:compression_utils", + "//tensorflow/core/data:compression_utils", + "//tensorflow/core/data:dataset_proto_cc", "//tensorflow/core/data/service:data_service", "//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/kernels/data:dataset_utils", diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc index 56077a671fb..3f8e778d1d8 100644 --- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc @@ -21,8 +21,8 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" -#include "tensorflow/core/data/service/common.pb.h" -#include "tensorflow/core/data/service/compression_utils.h" +#include "tensorflow/core/data/compression_utils.h" +#include "tensorflow/core/data/dataset.pb.h" #include "tensorflow/core/data/service/data_service.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/framework/dataset.h" @@ -496,7 +496,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { std::vector element; if (!end_of_sequence) { - TF_RETURN_IF_ERROR(service_util::Uncompress(compressed, &element)); + TF_RETURN_IF_ERROR(UncompressElement(compressed, &element)); } mutex_lock l(mu_); if (end_of_sequence) {