Move compression_utils to core/data.
This is in preparation for adding a CompressElementOp, which will use CompressElement to compress a dataset element in a tf.data service agnostic way. PiperOrigin-RevId: 312197651 Change-Id: I3558b2f5036dcf4c91ed9059a7b896351c79da40
This commit is contained in:
parent
714092f360
commit
d3886d23d7
@ -1,5 +1,10 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
load("//tensorflow/core/platform:build_config.bzl", "tf_protos_all")
|
||||
load(
|
||||
"//tensorflow/core/platform:build_config.bzl",
|
||||
"tf_additional_all_protos",
|
||||
"tf_proto_library",
|
||||
"tf_protos_all",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
@ -10,6 +15,46 @@ package(
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
cc_library(
|
||||
name = "compression_utils",
|
||||
srcs = ["compression_utils.cc"],
|
||||
hdrs = [
|
||||
"compression_utils.h",
|
||||
],
|
||||
deps = [
|
||||
":dataset_proto_cc",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "compression_utils_test",
|
||||
srcs = ["compression_utils_test.cc"],
|
||||
deps = [
|
||||
":compression_utils",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/kernels/data:dataset_test_base",
|
||||
],
|
||||
)
|
||||
|
||||
tf_proto_library(
|
||||
name = "dataset_proto",
|
||||
srcs = ["dataset.proto"],
|
||||
cc_api_version = 2,
|
||||
protodeps = tf_additional_all_protos(),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "standalone",
|
||||
srcs = ["standalone.cc"],
|
||||
|
@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/data/service/compression_utils.h"
|
||||
#include "tensorflow/core/data/compression_utils.h"
|
||||
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
@ -21,11 +21,11 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace service_util {
|
||||
|
||||
Status Compress(const std::vector<Tensor>& element, CompressedElement* out) {
|
||||
Status CompressElement(const std::vector<Tensor>& element,
|
||||
CompressedElement* out) {
|
||||
tensorflow::profiler::TraceMe activity(
|
||||
"Compress", tensorflow::profiler::TraceMeLevel::kInfo);
|
||||
"CompressElement", tensorflow::profiler::TraceMeLevel::kInfo);
|
||||
|
||||
// Step 1: Determine the total uncompressed size. This requires serializing
|
||||
// non-memcopyable tensors, which we save to use again later.
|
||||
@ -51,7 +51,8 @@ Status Compress(const std::vector<Tensor>& element, CompressedElement* out) {
|
||||
char* position = uncompressed.mdata();
|
||||
int non_memcpy_component_index = 0;
|
||||
for (auto& component : element) {
|
||||
ComponentMetadata* metadata = out->mutable_component_metadata()->Add();
|
||||
CompressedComponentMetadata* metadata =
|
||||
out->mutable_component_metadata()->Add();
|
||||
metadata->set_dtype(component.dtype());
|
||||
component.shape().AsProto(metadata->mutable_tensor_shape());
|
||||
if (DataTypeCanUseMemcpy(component.dtype())) {
|
||||
@ -74,10 +75,10 @@ Status Compress(const std::vector<Tensor>& element, CompressedElement* out) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Uncompress(const CompressedElement& compressed,
|
||||
std::vector<Tensor>* out) {
|
||||
Status UncompressElement(const CompressedElement& compressed,
|
||||
std::vector<Tensor>* out) {
|
||||
tensorflow::profiler::TraceMe activity(
|
||||
"Uncompress", tensorflow::profiler::TraceMeLevel::kInfo);
|
||||
"UncompressElement", tensorflow::profiler::TraceMeLevel::kInfo);
|
||||
int num_components = compressed.component_metadata_size();
|
||||
out->clear();
|
||||
out->reserve(num_components);
|
||||
@ -92,7 +93,8 @@ Status Uncompress(const CompressedElement& compressed,
|
||||
tensor_proto_strs.reserve(num_components);
|
||||
int64 total_size = 0;
|
||||
for (int i = 0; i < num_components; ++i) {
|
||||
const ComponentMetadata& metadata = compressed.component_metadata(i);
|
||||
const CompressedComponentMetadata& metadata =
|
||||
compressed.component_metadata(i);
|
||||
if (DataTypeCanUseMemcpy(metadata.dtype())) {
|
||||
out->emplace_back(metadata.dtype(), metadata.tensor_shape());
|
||||
TensorBuffer* buffer = DMAHelper::buffer(&out->back());
|
||||
@ -146,6 +148,5 @@ Status Uncompress(const CompressedElement& compressed,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace service_util
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
@ -16,24 +16,23 @@ limitations under the License.
|
||||
#define TENSORFLOW_CORE_DATA_SERVICE_COMPRESSION_UTILS_H_
|
||||
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
#include "tensorflow/core/data/service/common.pb.h"
|
||||
#include "tensorflow/core/data/dataset.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace service_util {
|
||||
|
||||
// Compresses the components of `element` into the `CompressedElement` proto.
|
||||
//
|
||||
// In addition to writing the actual compressed bytes, `Compress` fills
|
||||
// out the per-component metadata for the `CompressedElement`.
|
||||
Status Compress(const std::vector<Tensor>& element, CompressedElement* out);
|
||||
Status CompressElement(const std::vector<Tensor>& element,
|
||||
CompressedElement* out);
|
||||
|
||||
// Uncompresses a `CompressedElement` into a vector of tensor components.
|
||||
Status Uncompress(const CompressedElement& compressed,
|
||||
std::vector<Tensor>* out);
|
||||
Status UncompressElement(const CompressedElement& compressed,
|
||||
std::vector<Tensor>* out);
|
||||
|
||||
} // namespace service_util
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/data/service/compression_utils.h"
|
||||
#include "tensorflow/core/data/compression_utils.h"
|
||||
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_test_base.h"
|
||||
@ -20,7 +20,6 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace service_util {
|
||||
|
||||
class ParameterizedCompressionUtilsTest
|
||||
: public DatasetOpsTestBase,
|
||||
@ -29,9 +28,9 @@ class ParameterizedCompressionUtilsTest
|
||||
TEST_P(ParameterizedCompressionUtilsTest, RoundTrip) {
|
||||
std::vector<Tensor> element = GetParam();
|
||||
CompressedElement compressed;
|
||||
TF_ASSERT_OK(Compress(element, &compressed));
|
||||
TF_ASSERT_OK(CompressElement(element, &compressed));
|
||||
std::vector<Tensor> round_trip_element;
|
||||
TF_ASSERT_OK(Uncompress(compressed, &round_trip_element));
|
||||
TF_ASSERT_OK(UncompressElement(compressed, &round_trip_element));
|
||||
TF_EXPECT_OK(
|
||||
ExpectEqual(element, round_trip_element, /*compare_order=*/true));
|
||||
}
|
||||
@ -50,6 +49,5 @@ std::vector<std::vector<Tensor>> TestCases() {
|
||||
INSTANTIATE_TEST_SUITE_P(Instantiation, ParameterizedCompressionUtilsTest,
|
||||
::testing::ValuesIn(TestCases()));
|
||||
|
||||
} // namespace service_util
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
27
tensorflow/core/data/dataset.proto
Normal file
27
tensorflow/core/data/dataset.proto
Normal file
@ -0,0 +1,27 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow.data;
|
||||
|
||||
import "tensorflow/core/framework/tensor_shape.proto";
|
||||
import "tensorflow/core/framework/types.proto";
|
||||
|
||||
// This file contains protocol buffers for working with tf.data Datasets.
|
||||
|
||||
// Metadata describing a compressed component of a dataset element.
|
||||
message CompressedComponentMetadata {
|
||||
// The dtype of the component tensor.
|
||||
.tensorflow.DataType dtype = 1;
|
||||
// The shape of the component tensor.
|
||||
.tensorflow.TensorShapeProto tensor_shape = 2;
|
||||
// Size of the uncompressed tensor bytes. For tensors serialized as
|
||||
// TensorProtos, this is TensorProto::BytesAllocatedLong(). For raw Tensors,
|
||||
// this is the size of the buffer underlying the Tensor.
|
||||
int64 tensor_size_bytes = 3;
|
||||
}
|
||||
|
||||
message CompressedElement {
|
||||
// Compressed tensor bytes for all components of the element.
|
||||
bytes data = 1;
|
||||
// Metadata for the components of the element.
|
||||
repeated CompressedComponentMetadata component_metadata = 2;
|
||||
}
|
@ -44,6 +44,7 @@ tf_proto_library(
|
||||
cc_api_version = 2,
|
||||
protodeps = tf_additional_all_protos() + [
|
||||
":common_proto",
|
||||
"//tensorflow/core/data:dataset_proto",
|
||||
],
|
||||
)
|
||||
|
||||
@ -84,7 +85,6 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
":common_proto_cc",
|
||||
":compression_utils",
|
||||
":credentials_factory",
|
||||
":grpc_util",
|
||||
":master_cc_grpc_proto",
|
||||
@ -98,6 +98,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/data:compression_utils",
|
||||
"//tensorflow/core/data:standalone",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/memory",
|
||||
@ -129,39 +130,6 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "compression_utils",
|
||||
srcs = ["compression_utils.cc"],
|
||||
hdrs = [
|
||||
"compression_utils.h",
|
||||
],
|
||||
deps = [
|
||||
":common_proto_cc",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "compression_utils_test",
|
||||
srcs = ["compression_utils_test.cc"],
|
||||
deps = [
|
||||
":compression_utils",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/kernels/data:dataset_test_base",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "credentials_factory",
|
||||
srcs = ["credentials_factory.cc"],
|
||||
@ -317,7 +285,6 @@ tf_cc_test(
|
||||
srcs = ["data_service_test.cc"],
|
||||
tags = ["no_windows"],
|
||||
deps = [
|
||||
":compression_utils",
|
||||
":data_service",
|
||||
":grpc_master_impl",
|
||||
":grpc_util",
|
||||
@ -333,6 +300,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/data:compression_utils",
|
||||
"//tensorflow/core/kernels/data:dataset_test_base",
|
||||
"@com_google_absl//absl/strings",
|
||||
tf_grpc_cc_dependency(),
|
||||
|
@ -3,7 +3,6 @@ syntax = "proto3";
|
||||
package tensorflow.data;
|
||||
|
||||
import "tensorflow/core/framework/graph.proto";
|
||||
import "tensorflow/core/framework/tensor_shape.proto";
|
||||
import "tensorflow/core/framework/types.proto";
|
||||
|
||||
message DatasetDef {
|
||||
@ -12,24 +11,6 @@ message DatasetDef {
|
||||
GraphDef graph = 1;
|
||||
}
|
||||
|
||||
message ComponentMetadata {
|
||||
// The dtype of the component tensor.
|
||||
.tensorflow.DataType dtype = 1;
|
||||
// The shape of the component tensor.
|
||||
.tensorflow.TensorShapeProto tensor_shape = 2;
|
||||
// Size of the uncompressed tensor bytes. For tensors serialized as
|
||||
// TensorProtos, this is TensorProto::BytesAllocatedLong(). For raw Tensors,
|
||||
// this is the size of the buffer underlying the Tensor.
|
||||
int64 tensor_size_bytes = 3;
|
||||
}
|
||||
|
||||
message CompressedElement {
|
||||
// Compressed tensor bytes for all components of the element.
|
||||
bytes data = 1;
|
||||
// Metadata for the components of the element.
|
||||
repeated ComponentMetadata component_metadata = 2;
|
||||
}
|
||||
|
||||
message TaskDef {
|
||||
// The dataset to iterate over.
|
||||
// TODO(aaudibert): load the dataset from disk instead of passing it here.
|
||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
||||
#include "grpcpp/create_channel.h"
|
||||
#include "grpcpp/security/credentials.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "tensorflow/core/data/service/compression_utils.h"
|
||||
#include "tensorflow/core/data/compression_utils.h"
|
||||
#include "tensorflow/core/data/service/grpc_util.h"
|
||||
#include "tensorflow/core/data/service/master.grpc.pb.h"
|
||||
#include "tensorflow/core/data/service/master.pb.h"
|
||||
@ -74,7 +74,7 @@ Status CheckWorkerOutput(const std::string& worker_address, int64 task_id,
|
||||
return errors::Internal("Reached end of sequence too early.");
|
||||
}
|
||||
std::vector<Tensor> element;
|
||||
TF_RETURN_IF_ERROR(service_util::Uncompress(compressed, &element));
|
||||
TF_RETURN_IF_ERROR(UncompressElement(compressed, &element));
|
||||
TF_RETURN_IF_ERROR(DatasetOpsTestBase::ExpectEqual(element, expected,
|
||||
/*compare_order=*/true));
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ syntax = "proto3";
|
||||
|
||||
package tensorflow.data;
|
||||
|
||||
import "tensorflow/core/data/dataset.proto";
|
||||
import "tensorflow/core/data/service/common.proto";
|
||||
|
||||
message ProcessTaskRequest {
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/data/service/compression_utils.h"
|
||||
#include "tensorflow/core/data/compression_utils.h"
|
||||
#include "tensorflow/core/data/service/credentials_factory.h"
|
||||
#include "tensorflow/core/data/service/grpc_util.h"
|
||||
#include "tensorflow/core/data/service/master.grpc.pb.h"
|
||||
@ -135,8 +135,8 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
|
||||
|
||||
if (!end_of_sequence) {
|
||||
VLOG(3) << "Producing an element for task " << request->task_id();
|
||||
TF_RETURN_IF_ERROR(service_util::Compress(
|
||||
outputs, response->mutable_compressed_element()));
|
||||
TF_RETURN_IF_ERROR(
|
||||
CompressElement(outputs, response->mutable_compressed_element()));
|
||||
}
|
||||
response->set_end_of_sequence(end_of_sequence);
|
||||
|
||||
|
@ -131,8 +131,8 @@ tf_kernel_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/data/service:common_proto_cc",
|
||||
"//tensorflow/core/data/service:compression_utils",
|
||||
"//tensorflow/core/data:compression_utils",
|
||||
"//tensorflow/core/data:dataset_proto_cc",
|
||||
"//tensorflow/core/data/service:data_service",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||
"//tensorflow/core/kernels/data:dataset_utils",
|
||||
|
@ -21,8 +21,8 @@ limitations under the License.
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/core/data/service/common.pb.h"
|
||||
#include "tensorflow/core/data/service/compression_utils.h"
|
||||
#include "tensorflow/core/data/compression_utils.h"
|
||||
#include "tensorflow/core/data/dataset.pb.h"
|
||||
#include "tensorflow/core/data/service/data_service.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
@ -496,7 +496,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
std::vector<Tensor> element;
|
||||
if (!end_of_sequence) {
|
||||
TF_RETURN_IF_ERROR(service_util::Uncompress(compressed, &element));
|
||||
TF_RETURN_IF_ERROR(UncompressElement(compressed, &element));
|
||||
}
|
||||
mutex_lock l(mu_);
|
||||
if (end_of_sequence) {
|
||||
|
Loading…
Reference in New Issue
Block a user