Add dataset element compression ops.

These ops allow us to implement tf.data service compression/decompression as part of the tf.data pipeline.
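
For example, compression and decompression can be applied as ordinary
transformations in a pipeline via the new Python wrappers (a minimal sketch
mirroring the test added in this change):

  from tensorflow.python.data.experimental.ops import compression_ops
  from tensorflow.python.data.ops import dataset_ops

  dataset = dataset_ops.Dataset.from_tensors((1, "dog"))
  element_spec = dataset.element_spec
  # Compress each element into a single scalar variant, then restore it.
  dataset = dataset.map(lambda *x: compression_ops.compress(x))
  dataset = dataset.map(lambda x: compression_ops.uncompress(x, element_spec))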

PiperOrigin-RevId: 312605093
Change-Id: I4a833bc89e602c8fd78abc4c1a0026c2a397449f
Author: Andrew Audibert, 2020-05-20 20:04:33 -07:00 (committed by TensorFlower Gardener)
Parent: e56cf87b54
Commit: f24faa153a
15 changed files with 366 additions and 31 deletions

@ -0,0 +1,5 @@
op {
graph_op_name: "CompressElement"
visibility: HIDDEN
summary: "Compresses a dataset element."
}

@ -0,0 +1,5 @@
op {
graph_op_name: "UncompressElement"
visibility: HIDDEN
summary: "Uncompresses a compressed dataset element."
}

@ -468,6 +468,25 @@ Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format,
return Status::OK();
}
Status DatasetIteratorShape(shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
std::vector<PartialTensorShape> output_shapes;
TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
if (output_shapes.size() != c->num_outputs()) {
return errors::InvalidArgument(
"`output_shapes` must be the same length as `output_types` (",
output_shapes.size(), " vs. ", c->num_outputs());
}
for (size_t i = 0; i < output_shapes.size(); ++i) {
shape_inference::ShapeHandle output_shape_handle;
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
output_shapes[i], &output_shape_handle));
c->set_output(static_cast<int>(i), output_shape_handle);
}
return Status::OK();
}
Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
const std::vector<DimensionOrConstant>& spatial,
DimensionOrConstant C, ShapeHandle* out,

@ -92,6 +92,9 @@ inline Status MergeBothInputsShapeFn(InferenceContext* c) {
return Status::OK();
}
// Shape function for dataset iterators.
Status DatasetIteratorShape(shape_inference::InferenceContext* c);
// Returns a new shape with the specified dims arranged in the specified
// format. The returned value is owned by this context.
// Note: if format = "FORMAT_NCHW_VECT_C" then C represents the outer_depth.

@ -109,6 +109,20 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "compression_ops",
srcs = ["compression_ops.cc"],
hdrs = ["compression_ops.h"],
deps = [
"//tensorflow/core:experimental_dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/data:compression_utils",
"//tensorflow/core/data:dataset_proto_cc",
],
)
tf_kernel_library(
name = "csv_dataset_op",
srcs = ["csv_dataset_op.cc"],
@ -681,6 +695,7 @@ tf_kernel_library(
":auto_shard_dataset_op",
":choose_fastest_branch_dataset_op",
":choose_fastest_dataset_op",
":compression_ops",
":csv_dataset_op",
":dense_to_sparse_batch_dataset_op",
":directed_interleave_dataset_op",

@ -0,0 +1,76 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/experimental/compression_ops.h"
#include "tensorflow/core/data/compression_utils.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace data {
namespace experimental {
CompressElementOp::CompressElementOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {}
void CompressElementOp::Compute(OpKernelContext* ctx) {
std::vector<Tensor> components;
for (size_t i = 0; i < ctx->num_inputs(); ++i) {
components.push_back(ctx->input(i));
}
CompressedElement compressed;
OP_REQUIRES_OK(ctx, CompressElement(components, &compressed));
Tensor* output;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
output->scalar<Variant>()() = std::move(compressed);
}
UncompressElementOp::UncompressElementOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
}
void UncompressElementOp::Compute(OpKernelContext* ctx) {
Tensor tensor = ctx->input(0);
const Variant& variant = tensor.scalar<Variant>()();
const CompressedElement* compressed = variant.get<CompressedElement>();
std::vector<Tensor> components;
OP_REQUIRES_OK(ctx, UncompressElement(*compressed, &components));
OP_REQUIRES(ctx, components.size() == output_types_.size(),
errors::FailedPrecondition("Expected ", output_types_.size(),
" outputs from uncompress, but got ",
components.size()));
for (int i = 0; i < components.size(); ++i) {
OP_REQUIRES(
ctx, components[i].dtype() == output_types_[i],
errors::FailedPrecondition("Expected a tensor of type ",
DataTypeString(output_types_[i]),
" but got a tensor of type ",
DataTypeString(components[i].dtype())));
ctx->set_output(i, components[i]);
}
}
REGISTER_KERNEL_BUILDER(Name("CompressElement").Device(DEVICE_CPU),
CompressElementOp);
REGISTER_KERNEL_BUILDER(Name("UncompressElement").Device(DEVICE_CPU),
UncompressElementOp);
} // namespace experimental
} // namespace data
} // namespace tensorflow

@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_COMPRESSION_OPS_H_
#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_COMPRESSION_OPS_H_
#include "tensorflow/core/framework/dataset.h"
namespace tensorflow {
namespace data {
namespace experimental {
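// Kernel that compresses its input tensors into a single scalar variant
// tensor holding a CompressedElement.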
class CompressElementOp : public OpKernel {
public:
explicit CompressElementOp(OpKernelConstruction* ctx);
void Compute(OpKernelContext* ctx) override;
};
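// Kernel that unpacks a CompressedElement variant back into its component
// tensors, checking their number and dtypes against the `output_types` attr.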
class UncompressElementOp : public OpKernel {
public:
static constexpr const char* const kOutputTypes = "output_types";
static constexpr const char* const kOutputShapes = "output_shapes";
explicit UncompressElementOp(OpKernelConstruction* ctx);
void Compute(OpKernelContext* ctx) override;
private:
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
};
} // namespace experimental
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_COMPRESSION_OPS_H_

@ -731,42 +731,19 @@ REGISTER_OP("OneShotIterator")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape);
namespace {
Status IteratorGetNextShapeFn(shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
std::vector<PartialTensorShape> output_shapes;
TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
if (output_shapes.size() != c->num_outputs()) {
return errors::InvalidArgument(
"`output_shapes` must be the same length as `output_types` (",
output_shapes.size(), " vs. ", c->num_outputs());
}
for (size_t i = 0; i < output_shapes.size(); ++i) {
shape_inference::ShapeHandle output_shape_handle;
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
output_shapes[i], &output_shape_handle));
c->set_output(static_cast<int>(i), output_shape_handle);
}
return Status::OK();
}
} // namespace
REGISTER_OP("IteratorGetNext")
.Input("iterator: resource")
.Output("components: output_types")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(IteratorGetNextShapeFn);
.SetShapeFn(shape_inference::DatasetIteratorShape);
REGISTER_OP("IteratorGetNextSync")
.Input("iterator: resource")
.Output("components: output_types")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(IteratorGetNextShapeFn);
.SetShapeFn(shape_inference::DatasetIteratorShape);
// TODO(b/124308596): Instead of conservatively marking this op as stateful,
// implement a mechanism to determine whether `dataset` has a side-effect
@ -778,7 +755,7 @@ REGISTER_OP("DatasetToSingleElement")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful()
.SetShapeFn(IteratorGetNextShapeFn);
.SetShapeFn(shape_inference::DatasetIteratorShape);
// TODO(b/124308596): Instead of conservatively marking this op as stateful,
// implement a mechanism to determine whether `dataset` has a side-effect
@ -796,7 +773,7 @@ REGISTER_OP("ReduceDataset")
.Attr("output_shapes: list(shape) >= 1")
.Attr("use_inter_op_parallelism: bool = true")
.SetIsStateful()
.SetShapeFn(IteratorGetNextShapeFn);
.SetShapeFn(shape_inference::DatasetIteratorShape);
REGISTER_OP("IteratorToStringHandle")
.Input("resource_handle: resource")
@ -875,7 +852,7 @@ REGISTER_OP("OptionalGetValue")
.Output("components: output_types")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(IteratorGetNextShapeFn);
.SetShapeFn(shape_inference::DatasetIteratorShape);
REGISTER_OP("IteratorGetNextAsOptional")
.Input("iterator: resource")
@ -992,7 +969,7 @@ REGISTER_OP("MultiDeviceIteratorGetNextFromShard")
.Output("components: output_types")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(IteratorGetNextShapeFn);
.SetShapeFn(shape_inference::DatasetIteratorShape);
REGISTER_OP("MultiDeviceIteratorToStringHandle")
.Input("multi_device_iterator: resource")

@ -132,6 +132,19 @@ REGISTER_OP("ExperimentalChooseFastestDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("CompressElement")
.Input("components: input_types")
.Output("compressed: variant")
.Attr("input_types: list(type) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("UncompressElement")
.Input("compressed: variant")
.Output("components: output_types")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::DatasetIteratorShape);
REGISTER_OP("CSVDataset")
.Input("filenames: string")
.Input("compression_type: string")

@ -1,5 +1,5 @@
load("//tensorflow:tensorflow.bzl", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "tf_py_test") # buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "cuda_py_test") # buildifier: disable=same-origin-load
package(
default_visibility = ["//tensorflow:internal"],
@ -87,6 +87,17 @@ tf_py_test(
],
)
tf_py_test(
name = "compression_ops_test",
srcs = ["compression_ops_test.py"],
deps = [
"//tensorflow/python/data/experimental/ops:compression_ops",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
)
cuda_py_test(
name = "copy_to_device_test",
size = "small",

@ -0,0 +1,81 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for compression ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import compression_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import structure
from tensorflow.python.framework import combinations
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.platform import test
def _test_objects():
  return [
      combinations.NamedObject("int", 1),
      combinations.NamedObject("string", "dog"),
      combinations.NamedObject("tuple", (1, 1)),
      combinations.NamedObject("int_string_tuple", (1, "dog")),
      combinations.NamedObject(
          "sparse",
          sparse_tensor.SparseTensorValue(
              indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])),
      combinations.NamedObject(
          "sparse_structured", {
              "a":
                  sparse_tensor.SparseTensorValue(
                      indices=[[0, 0], [1, 2]],
                      values=[1, 2],
                      dense_shape=[3, 4]),
              "b": (1, 2, "dog")
          })
  ]


class CompressionOpsTest(test_base.DatasetTestBase, parameterized.TestCase):

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(element=_test_objects())))
  def testCompression(self, element):
    element = element._obj
    compressed = compression_ops.compress(element)
    uncompressed = compression_ops.uncompress(
        compressed, structure.type_spec_from_value(element))
    self.assertValuesEqual(element, self.evaluate(uncompressed))

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(element=_test_objects())))
  def testDatasetCompression(self, element):
    element = element._obj
    dataset = dataset_ops.Dataset.from_tensors(element)
    element_spec = dataset.element_spec
    dataset = dataset.map(lambda *x: compression_ops.compress(x))
    dataset = dataset.map(
        lambda x: compression_ops.uncompress(x, element_spec))
    self.assertDatasetProduces(dataset, [element])


if __name__ == "__main__":
  test.main()

@ -33,6 +33,15 @@ py_library(
],
)
py_library(
name = "compression_ops",
srcs = ["compression_ops.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:experimental_dataset_ops_gen",
],
)
py_library(
name = "counter",
srcs = ["counter.py"],
@ -475,6 +484,7 @@ py_library(
deps = [
":batching",
":cardinality",
":compression_ops",
":counter",
":data_service_ops",
":distribute",

@ -0,0 +1,55 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for compressing and uncompressing dataset elements."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.util import structure
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
def compress(element):
  """Compress a dataset element.

  Args:
    element: A nested structure of types supported by TensorFlow.

  Returns:
    A variant tensor representing the compressed element. This variant can be
    passed to `uncompress` to get back the original element.
  """
  element_spec = structure.type_spec_from_value(element)
  tensor_list = structure.to_tensor_list(element_spec, element)
  return ged_ops.compress_element(tensor_list)


def uncompress(element, output_spec):
  """Uncompress a compressed dataset element.

  Args:
    element: A scalar variant tensor to uncompress. The element should have
      been created by calling `compress`.
    output_spec: A nested structure of `tf.TypeSpec` representing the type(s)
      of the uncompressed element.

  Returns:
    The uncompressed element.
  """
  flat_types = structure.get_flat_tensor_types(output_spec)
  flat_shapes = structure.get_flat_tensor_shapes(output_spec)
  tensor_list = ged_ops.uncompress_element(
      element, output_types=flat_types, output_shapes=flat_shapes)
  return structure.from_tensor_list(output_spec, tensor_list)
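
# Example round-trip (an illustrative sketch mirroring the unit test added in
# this change):
#
#   element = (1, "dog")
#   compressed = compress(element)
#   restored = uncompress(compressed, structure.type_spec_from_value(element))
#   # `restored` has the same structure and values as `element`.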

@ -736,6 +736,10 @@ tf_module {
name: "ComplexAbs"
argspec: "args=[\'x\', \'Tout\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "CompressElement"
argspec: "args=[\'components\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "ComputeAccidentalHits"
argspec: "args=[\'true_classes\', \'sampled_candidates\', \'num_true\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], "
@ -4956,6 +4960,10 @@ tf_module {
name: "UnbatchGrad"
argspec: "args=[\'original_input\', \'batch_index\', \'grad\', \'id\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], "
}
member_method {
name: "UncompressElement"
argspec: "args=[\'compressed\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "UnicodeDecode"
argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \"<dtype: \'int64\'>\", \'None\'], "

@ -736,6 +736,10 @@ tf_module {
name: "ComplexAbs"
argspec: "args=[\'x\', \'Tout\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "CompressElement"
argspec: "args=[\'components\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "ComputeAccidentalHits"
argspec: "args=[\'true_classes\', \'sampled_candidates\', \'num_true\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], "
@ -4956,6 +4960,10 @@ tf_module {
name: "UnbatchGrad"
argspec: "args=[\'original_input\', \'batch_index\', \'grad\', \'id\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], "
}
member_method {
name: "UncompressElement"
argspec: "args=[\'compressed\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "UnicodeDecode"
argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \"<dtype: \'int64\'>\", \'None\'], "