Move compression_utils to core/data.
This is in preparation for adding a CompressElementOp, which will use CompressElement to compress a dataset element in a tf.data service agnostic way. PiperOrigin-RevId: 312197651 Change-Id: I3558b2f5036dcf4c91ed9059a7b896351c79da40
This commit is contained in:
parent
714092f360
commit
d3886d23d7
@ -1,5 +1,10 @@
|
|||||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||||
load("//tensorflow/core/platform:build_config.bzl", "tf_protos_all")
|
load(
|
||||||
|
"//tensorflow/core/platform:build_config.bzl",
|
||||||
|
"tf_additional_all_protos",
|
||||||
|
"tf_proto_library",
|
||||||
|
"tf_protos_all",
|
||||||
|
)
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_visibility = [
|
default_visibility = [
|
||||||
@ -10,6 +15,46 @@ package(
|
|||||||
|
|
||||||
exports_files(["LICENSE"])
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "compression_utils",
|
||||||
|
srcs = ["compression_utils.cc"],
|
||||||
|
hdrs = [
|
||||||
|
"compression_utils.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":dataset_proto_cc",
|
||||||
|
"//tensorflow/core:core_cpu",
|
||||||
|
"//tensorflow/core:core_cpu_internal",
|
||||||
|
"//tensorflow/core:framework_internal",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core/profiler/lib:traceme",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "compression_utils_test",
|
||||||
|
srcs = ["compression_utils_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":compression_utils",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core:testlib",
|
||||||
|
"//tensorflow/core/kernels/data:dataset_test_base",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_proto_library(
|
||||||
|
name = "dataset_proto",
|
||||||
|
srcs = ["dataset.proto"],
|
||||||
|
cc_api_version = 2,
|
||||||
|
protodeps = tf_additional_all_protos(),
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "standalone",
|
name = "standalone",
|
||||||
srcs = ["standalone.cc"],
|
srcs = ["standalone.cc"],
|
||||||
|
@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "tensorflow/core/data/service/compression_utils.h"
|
#include "tensorflow/core/data/compression_utils.h"
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||||
#include "tensorflow/core/framework/tensor.pb.h"
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
@ -21,11 +21,11 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace data {
|
namespace data {
|
||||||
namespace service_util {
|
|
||||||
|
|
||||||
Status Compress(const std::vector<Tensor>& element, CompressedElement* out) {
|
Status CompressElement(const std::vector<Tensor>& element,
|
||||||
|
CompressedElement* out) {
|
||||||
tensorflow::profiler::TraceMe activity(
|
tensorflow::profiler::TraceMe activity(
|
||||||
"Compress", tensorflow::profiler::TraceMeLevel::kInfo);
|
"CompressElement", tensorflow::profiler::TraceMeLevel::kInfo);
|
||||||
|
|
||||||
// Step 1: Determine the total uncompressed size. This requires serializing
|
// Step 1: Determine the total uncompressed size. This requires serializing
|
||||||
// non-memcopyable tensors, which we save to use again later.
|
// non-memcopyable tensors, which we save to use again later.
|
||||||
@ -51,7 +51,8 @@ Status Compress(const std::vector<Tensor>& element, CompressedElement* out) {
|
|||||||
char* position = uncompressed.mdata();
|
char* position = uncompressed.mdata();
|
||||||
int non_memcpy_component_index = 0;
|
int non_memcpy_component_index = 0;
|
||||||
for (auto& component : element) {
|
for (auto& component : element) {
|
||||||
ComponentMetadata* metadata = out->mutable_component_metadata()->Add();
|
CompressedComponentMetadata* metadata =
|
||||||
|
out->mutable_component_metadata()->Add();
|
||||||
metadata->set_dtype(component.dtype());
|
metadata->set_dtype(component.dtype());
|
||||||
component.shape().AsProto(metadata->mutable_tensor_shape());
|
component.shape().AsProto(metadata->mutable_tensor_shape());
|
||||||
if (DataTypeCanUseMemcpy(component.dtype())) {
|
if (DataTypeCanUseMemcpy(component.dtype())) {
|
||||||
@ -74,10 +75,10 @@ Status Compress(const std::vector<Tensor>& element, CompressedElement* out) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Uncompress(const CompressedElement& compressed,
|
Status UncompressElement(const CompressedElement& compressed,
|
||||||
std::vector<Tensor>* out) {
|
std::vector<Tensor>* out) {
|
||||||
tensorflow::profiler::TraceMe activity(
|
tensorflow::profiler::TraceMe activity(
|
||||||
"Uncompress", tensorflow::profiler::TraceMeLevel::kInfo);
|
"UncompressElement", tensorflow::profiler::TraceMeLevel::kInfo);
|
||||||
int num_components = compressed.component_metadata_size();
|
int num_components = compressed.component_metadata_size();
|
||||||
out->clear();
|
out->clear();
|
||||||
out->reserve(num_components);
|
out->reserve(num_components);
|
||||||
@ -92,7 +93,8 @@ Status Uncompress(const CompressedElement& compressed,
|
|||||||
tensor_proto_strs.reserve(num_components);
|
tensor_proto_strs.reserve(num_components);
|
||||||
int64 total_size = 0;
|
int64 total_size = 0;
|
||||||
for (int i = 0; i < num_components; ++i) {
|
for (int i = 0; i < num_components; ++i) {
|
||||||
const ComponentMetadata& metadata = compressed.component_metadata(i);
|
const CompressedComponentMetadata& metadata =
|
||||||
|
compressed.component_metadata(i);
|
||||||
if (DataTypeCanUseMemcpy(metadata.dtype())) {
|
if (DataTypeCanUseMemcpy(metadata.dtype())) {
|
||||||
out->emplace_back(metadata.dtype(), metadata.tensor_shape());
|
out->emplace_back(metadata.dtype(), metadata.tensor_shape());
|
||||||
TensorBuffer* buffer = DMAHelper::buffer(&out->back());
|
TensorBuffer* buffer = DMAHelper::buffer(&out->back());
|
||||||
@ -146,6 +148,5 @@ Status Uncompress(const CompressedElement& compressed,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace service_util
|
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
@ -16,24 +16,23 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_CORE_DATA_SERVICE_COMPRESSION_UTILS_H_
|
#define TENSORFLOW_CORE_DATA_SERVICE_COMPRESSION_UTILS_H_
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||||
#include "tensorflow/core/data/service/common.pb.h"
|
#include "tensorflow/core/data/dataset.pb.h"
|
||||||
#include "tensorflow/core/platform/status.h"
|
#include "tensorflow/core/platform/status.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace data {
|
namespace data {
|
||||||
namespace service_util {
|
|
||||||
|
|
||||||
// Compresses the components of `element` into the `CompressedElement` proto.
|
// Compresses the components of `element` into the `CompressedElement` proto.
|
||||||
//
|
//
|
||||||
// In addition to writing the actual compressed bytes, `Compress` fills
|
// In addition to writing the actual compressed bytes, `Compress` fills
|
||||||
// out the per-component metadata for the `CompressedElement`.
|
// out the per-component metadata for the `CompressedElement`.
|
||||||
Status Compress(const std::vector<Tensor>& element, CompressedElement* out);
|
Status CompressElement(const std::vector<Tensor>& element,
|
||||||
|
CompressedElement* out);
|
||||||
|
|
||||||
// Uncompresses a `CompressedElement` into a vector of tensor components.
|
// Uncompresses a `CompressedElement` into a vector of tensor components.
|
||||||
Status Uncompress(const CompressedElement& compressed,
|
Status UncompressElement(const CompressedElement& compressed,
|
||||||
std::vector<Tensor>* out);
|
std::vector<Tensor>* out);
|
||||||
|
|
||||||
} // namespace service_util
|
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "tensorflow/core/data/service/compression_utils.h"
|
#include "tensorflow/core/data/compression_utils.h"
|
||||||
|
|
||||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
#include "tensorflow/core/kernels/data/dataset_test_base.h"
|
#include "tensorflow/core/kernels/data/dataset_test_base.h"
|
||||||
@ -20,7 +20,6 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace data {
|
namespace data {
|
||||||
namespace service_util {
|
|
||||||
|
|
||||||
class ParameterizedCompressionUtilsTest
|
class ParameterizedCompressionUtilsTest
|
||||||
: public DatasetOpsTestBase,
|
: public DatasetOpsTestBase,
|
||||||
@ -29,9 +28,9 @@ class ParameterizedCompressionUtilsTest
|
|||||||
TEST_P(ParameterizedCompressionUtilsTest, RoundTrip) {
|
TEST_P(ParameterizedCompressionUtilsTest, RoundTrip) {
|
||||||
std::vector<Tensor> element = GetParam();
|
std::vector<Tensor> element = GetParam();
|
||||||
CompressedElement compressed;
|
CompressedElement compressed;
|
||||||
TF_ASSERT_OK(Compress(element, &compressed));
|
TF_ASSERT_OK(CompressElement(element, &compressed));
|
||||||
std::vector<Tensor> round_trip_element;
|
std::vector<Tensor> round_trip_element;
|
||||||
TF_ASSERT_OK(Uncompress(compressed, &round_trip_element));
|
TF_ASSERT_OK(UncompressElement(compressed, &round_trip_element));
|
||||||
TF_EXPECT_OK(
|
TF_EXPECT_OK(
|
||||||
ExpectEqual(element, round_trip_element, /*compare_order=*/true));
|
ExpectEqual(element, round_trip_element, /*compare_order=*/true));
|
||||||
}
|
}
|
||||||
@ -50,6 +49,5 @@ std::vector<std::vector<Tensor>> TestCases() {
|
|||||||
INSTANTIATE_TEST_SUITE_P(Instantiation, ParameterizedCompressionUtilsTest,
|
INSTANTIATE_TEST_SUITE_P(Instantiation, ParameterizedCompressionUtilsTest,
|
||||||
::testing::ValuesIn(TestCases()));
|
::testing::ValuesIn(TestCases()));
|
||||||
|
|
||||||
} // namespace service_util
|
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
27
tensorflow/core/data/dataset.proto
Normal file
27
tensorflow/core/data/dataset.proto
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package tensorflow.data;
|
||||||
|
|
||||||
|
import "tensorflow/core/framework/tensor_shape.proto";
|
||||||
|
import "tensorflow/core/framework/types.proto";
|
||||||
|
|
||||||
|
// This file contains protocol buffers for working with tf.data Datasets.
|
||||||
|
|
||||||
|
// Metadata describing a compressed component of a dataset element.
|
||||||
|
message CompressedComponentMetadata {
|
||||||
|
// The dtype of the component tensor.
|
||||||
|
.tensorflow.DataType dtype = 1;
|
||||||
|
// The shape of the component tensor.
|
||||||
|
.tensorflow.TensorShapeProto tensor_shape = 2;
|
||||||
|
// Size of the uncompressed tensor bytes. For tensors serialized as
|
||||||
|
// TensorProtos, this is TensorProto::BytesAllocatedLong(). For raw Tensors,
|
||||||
|
// this is the size of the buffer underlying the Tensor.
|
||||||
|
int64 tensor_size_bytes = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message CompressedElement {
|
||||||
|
// Compressed tensor bytes for all components of the element.
|
||||||
|
bytes data = 1;
|
||||||
|
// Metadata for the components of the element.
|
||||||
|
repeated CompressedComponentMetadata component_metadata = 2;
|
||||||
|
}
|
@ -44,6 +44,7 @@ tf_proto_library(
|
|||||||
cc_api_version = 2,
|
cc_api_version = 2,
|
||||||
protodeps = tf_additional_all_protos() + [
|
protodeps = tf_additional_all_protos() + [
|
||||||
":common_proto",
|
":common_proto",
|
||||||
|
"//tensorflow/core/data:dataset_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -84,7 +85,6 @@ cc_library(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":common_proto_cc",
|
":common_proto_cc",
|
||||||
":compression_utils",
|
|
||||||
":credentials_factory",
|
":credentials_factory",
|
||||||
":grpc_util",
|
":grpc_util",
|
||||||
":master_cc_grpc_proto",
|
":master_cc_grpc_proto",
|
||||||
@ -98,6 +98,7 @@ cc_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core/data:compression_utils",
|
||||||
"//tensorflow/core/data:standalone",
|
"//tensorflow/core/data:standalone",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
@ -129,39 +130,6 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "compression_utils",
|
|
||||||
srcs = ["compression_utils.cc"],
|
|
||||||
hdrs = [
|
|
||||||
"compression_utils.h",
|
|
||||||
],
|
|
||||||
deps = [
|
|
||||||
":common_proto_cc",
|
|
||||||
"//tensorflow/core:core_cpu",
|
|
||||||
"//tensorflow/core:core_cpu_internal",
|
|
||||||
"//tensorflow/core:framework_internal",
|
|
||||||
"//tensorflow/core:lib",
|
|
||||||
"//tensorflow/core:lib_internal",
|
|
||||||
"//tensorflow/core:protos_all_cc",
|
|
||||||
"//tensorflow/core/profiler/lib:traceme",
|
|
||||||
"@com_google_absl//absl/memory",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_cc_test(
|
|
||||||
name = "compression_utils_test",
|
|
||||||
srcs = ["compression_utils_test.cc"],
|
|
||||||
deps = [
|
|
||||||
":compression_utils",
|
|
||||||
"//tensorflow/core:lib",
|
|
||||||
"//tensorflow/core:lib_internal",
|
|
||||||
"//tensorflow/core:test",
|
|
||||||
"//tensorflow/core:test_main",
|
|
||||||
"//tensorflow/core:testlib",
|
|
||||||
"//tensorflow/core/kernels/data:dataset_test_base",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "credentials_factory",
|
name = "credentials_factory",
|
||||||
srcs = ["credentials_factory.cc"],
|
srcs = ["credentials_factory.cc"],
|
||||||
@ -317,7 +285,6 @@ tf_cc_test(
|
|||||||
srcs = ["data_service_test.cc"],
|
srcs = ["data_service_test.cc"],
|
||||||
tags = ["no_windows"],
|
tags = ["no_windows"],
|
||||||
deps = [
|
deps = [
|
||||||
":compression_utils",
|
|
||||||
":data_service",
|
":data_service",
|
||||||
":grpc_master_impl",
|
":grpc_master_impl",
|
||||||
":grpc_util",
|
":grpc_util",
|
||||||
@ -333,6 +300,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core/data:compression_utils",
|
||||||
"//tensorflow/core/kernels/data:dataset_test_base",
|
"//tensorflow/core/kernels/data:dataset_test_base",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
tf_grpc_cc_dependency(),
|
tf_grpc_cc_dependency(),
|
||||||
|
@ -3,7 +3,6 @@ syntax = "proto3";
|
|||||||
package tensorflow.data;
|
package tensorflow.data;
|
||||||
|
|
||||||
import "tensorflow/core/framework/graph.proto";
|
import "tensorflow/core/framework/graph.proto";
|
||||||
import "tensorflow/core/framework/tensor_shape.proto";
|
|
||||||
import "tensorflow/core/framework/types.proto";
|
import "tensorflow/core/framework/types.proto";
|
||||||
|
|
||||||
message DatasetDef {
|
message DatasetDef {
|
||||||
@ -12,24 +11,6 @@ message DatasetDef {
|
|||||||
GraphDef graph = 1;
|
GraphDef graph = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
message ComponentMetadata {
|
|
||||||
// The dtype of the component tensor.
|
|
||||||
.tensorflow.DataType dtype = 1;
|
|
||||||
// The shape of the component tensor.
|
|
||||||
.tensorflow.TensorShapeProto tensor_shape = 2;
|
|
||||||
// Size of the uncompressed tensor bytes. For tensors serialized as
|
|
||||||
// TensorProtos, this is TensorProto::BytesAllocatedLong(). For raw Tensors,
|
|
||||||
// this is the size of the buffer underlying the Tensor.
|
|
||||||
int64 tensor_size_bytes = 3;
|
|
||||||
}
|
|
||||||
|
|
||||||
message CompressedElement {
|
|
||||||
// Compressed tensor bytes for all components of the element.
|
|
||||||
bytes data = 1;
|
|
||||||
// Metadata for the components of the element.
|
|
||||||
repeated ComponentMetadata component_metadata = 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
message TaskDef {
|
message TaskDef {
|
||||||
// The dataset to iterate over.
|
// The dataset to iterate over.
|
||||||
// TODO(aaudibert): load the dataset from disk instead of passing it here.
|
// TODO(aaudibert): load the dataset from disk instead of passing it here.
|
||||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
|||||||
#include "grpcpp/create_channel.h"
|
#include "grpcpp/create_channel.h"
|
||||||
#include "grpcpp/security/credentials.h"
|
#include "grpcpp/security/credentials.h"
|
||||||
#include "absl/strings/str_split.h"
|
#include "absl/strings/str_split.h"
|
||||||
#include "tensorflow/core/data/service/compression_utils.h"
|
#include "tensorflow/core/data/compression_utils.h"
|
||||||
#include "tensorflow/core/data/service/grpc_util.h"
|
#include "tensorflow/core/data/service/grpc_util.h"
|
||||||
#include "tensorflow/core/data/service/master.grpc.pb.h"
|
#include "tensorflow/core/data/service/master.grpc.pb.h"
|
||||||
#include "tensorflow/core/data/service/master.pb.h"
|
#include "tensorflow/core/data/service/master.pb.h"
|
||||||
@ -74,7 +74,7 @@ Status CheckWorkerOutput(const std::string& worker_address, int64 task_id,
|
|||||||
return errors::Internal("Reached end of sequence too early.");
|
return errors::Internal("Reached end of sequence too early.");
|
||||||
}
|
}
|
||||||
std::vector<Tensor> element;
|
std::vector<Tensor> element;
|
||||||
TF_RETURN_IF_ERROR(service_util::Uncompress(compressed, &element));
|
TF_RETURN_IF_ERROR(UncompressElement(compressed, &element));
|
||||||
TF_RETURN_IF_ERROR(DatasetOpsTestBase::ExpectEqual(element, expected,
|
TF_RETURN_IF_ERROR(DatasetOpsTestBase::ExpectEqual(element, expected,
|
||||||
/*compare_order=*/true));
|
/*compare_order=*/true));
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ syntax = "proto3";
|
|||||||
|
|
||||||
package tensorflow.data;
|
package tensorflow.data;
|
||||||
|
|
||||||
|
import "tensorflow/core/data/dataset.proto";
|
||||||
import "tensorflow/core/data/service/common.proto";
|
import "tensorflow/core/data/service/common.proto";
|
||||||
|
|
||||||
message ProcessTaskRequest {
|
message ProcessTaskRequest {
|
||||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
|||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "tensorflow/c/c_api_internal.h"
|
#include "tensorflow/c/c_api_internal.h"
|
||||||
#include "tensorflow/c/tf_status_helper.h"
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
#include "tensorflow/core/data/service/compression_utils.h"
|
#include "tensorflow/core/data/compression_utils.h"
|
||||||
#include "tensorflow/core/data/service/credentials_factory.h"
|
#include "tensorflow/core/data/service/credentials_factory.h"
|
||||||
#include "tensorflow/core/data/service/grpc_util.h"
|
#include "tensorflow/core/data/service/grpc_util.h"
|
||||||
#include "tensorflow/core/data/service/master.grpc.pb.h"
|
#include "tensorflow/core/data/service/master.grpc.pb.h"
|
||||||
@ -135,8 +135,8 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
|
|||||||
|
|
||||||
if (!end_of_sequence) {
|
if (!end_of_sequence) {
|
||||||
VLOG(3) << "Producing an element for task " << request->task_id();
|
VLOG(3) << "Producing an element for task " << request->task_id();
|
||||||
TF_RETURN_IF_ERROR(service_util::Compress(
|
TF_RETURN_IF_ERROR(
|
||||||
outputs, response->mutable_compressed_element()));
|
CompressElement(outputs, response->mutable_compressed_element()));
|
||||||
}
|
}
|
||||||
response->set_end_of_sequence(end_of_sequence);
|
response->set_end_of_sequence(end_of_sequence);
|
||||||
|
|
||||||
|
@ -131,8 +131,8 @@ tf_kernel_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/data/service:common_proto_cc",
|
"//tensorflow/core/data:compression_utils",
|
||||||
"//tensorflow/core/data/service:compression_utils",
|
"//tensorflow/core/data:dataset_proto_cc",
|
||||||
"//tensorflow/core/data/service:data_service",
|
"//tensorflow/core/data/service:data_service",
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||||
"//tensorflow/core/kernels/data:dataset_utils",
|
"//tensorflow/core/kernels/data:dataset_utils",
|
||||||
|
@ -21,8 +21,8 @@ limitations under the License.
|
|||||||
#include "absl/container/flat_hash_map.h"
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/container/flat_hash_set.h"
|
#include "absl/container/flat_hash_set.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "tensorflow/core/data/service/common.pb.h"
|
#include "tensorflow/core/data/compression_utils.h"
|
||||||
#include "tensorflow/core/data/service/compression_utils.h"
|
#include "tensorflow/core/data/dataset.pb.h"
|
||||||
#include "tensorflow/core/data/service/data_service.h"
|
#include "tensorflow/core/data/service/data_service.h"
|
||||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
||||||
#include "tensorflow/core/framework/dataset.h"
|
#include "tensorflow/core/framework/dataset.h"
|
||||||
@ -496,7 +496,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
|||||||
|
|
||||||
std::vector<Tensor> element;
|
std::vector<Tensor> element;
|
||||||
if (!end_of_sequence) {
|
if (!end_of_sequence) {
|
||||||
TF_RETURN_IF_ERROR(service_util::Uncompress(compressed, &element));
|
TF_RETURN_IF_ERROR(UncompressElement(compressed, &element));
|
||||||
}
|
}
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
if (end_of_sequence) {
|
if (end_of_sequence) {
|
||||||
|
Loading…
Reference in New Issue
Block a user