From e9869ece182be721dc07fe8ecb7c7288f2fce90f Mon Sep 17 00:00:00 2001
From: Jiri Simsa
Date: Thu, 19 Jul 2018 12:19:37 -0700
Subject: [PATCH] [tf.data] Adding `tf.contrib.data.reduce_dataset` which can
 be used to reduce a dataset to a single element.

PiperOrigin-RevId: 205281140
---
 tensorflow/contrib/data/__init__.py           |  2 +
 .../contrib/data/python/kernel_tests/BUILD    |  2 +
 .../kernel_tests/get_single_element_test.py   | 78 +++++++++++++------
 tensorflow/contrib/data/python/ops/BUILD      |  3 +
 .../data/python/ops/get_single_element.py     | 30 +++++++
 5 files changed, 92 insertions(+), 23 deletions(-)

diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 675330716b2..7878e46e88b 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -52,6 +52,7 @@ See @{$guide/datasets$Importing Data} for an overview.
 @@prefetch_to_device
 @@read_batch_features
 @@rejection_resample
+@@reduce_dataset
 @@sample_from_datasets
 @@scan
 @@shuffle_and_repeat
@@ -77,6 +78,7 @@ from tensorflow.contrib.data.python.ops.counter import Counter
 from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset
 from tensorflow.contrib.data.python.ops.error_ops import ignore_errors
 from tensorflow.contrib.data.python.ops.get_single_element import get_single_element
+from tensorflow.contrib.data.python.ops.get_single_element import reduce_dataset
 from tensorflow.contrib.data.python.ops.grouping import bucket_by_sequence_length
 from tensorflow.contrib.data.python.ops.grouping import group_by_reducer
 from tensorflow.contrib.data.python.ops.grouping import group_by_window
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index f805027727c..036dc795bb4 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -121,6 +121,7 @@ py_test(
     srcs = ["get_single_element_test.py"],
     deps = [
         "//tensorflow/contrib/data/python/ops:get_single_element",
+        "//tensorflow/contrib/data/python/ops:grouping",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:constant_op",
@@ -128,6 +129,7 @@ py_test(
         "//tensorflow/python:errors",
         "//tensorflow/python:sparse_tensor",
         "//tensorflow/python/data/ops:dataset_ops",
+        "@absl_py//absl/testing:parameterized",
     ],
 )
diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
index 87b7c6ddb7a..e6883d53e02 100644
--- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
@@ -17,9 +17,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
+import numpy as np
+
 from tensorflow.contrib.data.python.ops import get_single_element
+from tensorflow.contrib.data.python.ops import grouping
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import sparse_tensor
@@ -27,40 +30,69 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
 
 
-class GetSingleElementTest(test.TestCase):
+class GetSingleElementTest(test.TestCase, parameterized.TestCase):
 
-  def testGetSingleElement(self):
-    skip_value = array_ops.placeholder(dtypes.int64, shape=[])
-    take_value = array_ops.placeholder_with_default(
-        constant_op.constant(1, dtype=dtypes.int64), shape=[])
+  @parameterized.named_parameters(
+      ("Zero", 0, 1),
+      ("Five", 5, 1),
+      ("Ten", 10, 1),
+      ("Empty", 100, 1, errors.InvalidArgumentError, "Dataset was empty."),
+      ("MoreThanOne", 0, 2, errors.InvalidArgumentError,
+       "Dataset had more than one element."),
+  )
+  def testGetSingleElement(self, skip, take, error=None, error_msg=None):
+    skip_t = array_ops.placeholder(dtypes.int64, shape=[])
+    take_t = array_ops.placeholder(dtypes.int64, shape=[])
 
     def make_sparse(x):
       x_1d = array_ops.reshape(x, [1])
       x_2d = array_ops.reshape(x, [1, 1])
       return sparse_tensor.SparseTensor(x_2d, x_1d, x_1d)
 
-    dataset = (dataset_ops.Dataset.range(100)
-               .skip(skip_value)
-               .map(lambda x: (x * x, make_sparse(x)))
-               .take(take_value))
-
+    dataset = dataset_ops.Dataset.range(100).skip(skip_t).map(
+        lambda x: (x * x, make_sparse(x))).take(take_t)
     element = get_single_element.get_single_element(dataset)
 
     with self.test_session() as sess:
-      for x in [0, 5, 10]:
-        dense_val, sparse_val = sess.run(element, feed_dict={skip_value: x})
-        self.assertEqual(x * x, dense_val)
-        self.assertAllEqual([[x]], sparse_val.indices)
-        self.assertAllEqual([x], sparse_val.values)
-        self.assertAllEqual([x], sparse_val.dense_shape)
+      if error is None:
+        dense_val, sparse_val = sess.run(
+            element, feed_dict={
+                skip_t: skip,
+                take_t: take
+            })
+        self.assertEqual(skip * skip, dense_val)
+        self.assertAllEqual([[skip]], sparse_val.indices)
+        self.assertAllEqual([skip], sparse_val.values)
+        self.assertAllEqual([skip], sparse_val.dense_shape)
+      else:
+        with self.assertRaisesRegexp(error, error_msg):
+          sess.run(element, feed_dict={skip_t: skip, take_t: take})
 
-      with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                   "Dataset was empty."):
-        sess.run(element, feed_dict={skip_value: 100})
+  @parameterized.named_parameters(
+      ("SumZero", 0),
+      ("SumOne", 1),
+      ("SumFive", 5),
+      ("SumTen", 10),
+  )
+  def testReduceDataset(self, stop):
+    def init_fn(_):
+      return np.int64(0)
 
-      with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                   "Dataset had more than one element."):
-        sess.run(element, feed_dict={skip_value: 0, take_value: 2})
+    def reduce_fn(state, value):
+      return state + value
+
+    def finalize_fn(state):
+      return state
+
+    sum_reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn)
+
+    stop_t = array_ops.placeholder(dtypes.int64, shape=[])
+    dataset = dataset_ops.Dataset.range(stop_t)
+    element = get_single_element.reduce_dataset(dataset, sum_reducer)
+
+    with self.test_session() as sess:
+      value = sess.run(element, feed_dict={stop_t: stop})
+      self.assertEqual(stop * (stop - 1) / 2, value)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 160d7fe22a9..1ad021ea037 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -28,10 +28,12 @@ py_library(
     srcs = ["get_single_element.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":grouping",
         "//tensorflow/python:dataset_ops_gen",
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/data/util:nest",
         "//tensorflow/python/data/util:sparse",
+        "//third_party/py/numpy",
     ],
 )
@@ -129,6 +131,7 @@ py_library(
         "//tensorflow/python/data/util:convert",
         "//tensorflow/python/data/util:nest",
         "//tensorflow/python/data/util:sparse",
+        "//third_party/py/numpy",
     ],
 )
diff --git a/tensorflow/contrib/data/python/ops/get_single_element.py b/tensorflow/contrib/data/python/ops/get_single_element.py
index 0f4cd8e20c5..ef9284456eb 100644
--- a/tensorflow/contrib/data/python/ops/get_single_element.py
+++ b/tensorflow/contrib/data/python/ops/get_single_element.py
@@ -17,6 +17,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
+
+from tensorflow.contrib.data.python.ops import grouping
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import sparse
@@ -68,3 +71,30 @@ def get_single_element(dataset):
   return sparse.deserialize_sparse_tensors(
       nested_ret, dataset.output_types, dataset.output_shapes,
       dataset.output_classes)
+
+
+def reduce_dataset(dataset, reducer):
+  """Returns the result of reducing the `dataset` using `reducer`.
+
+  Args:
+    dataset: A @{tf.data.Dataset} object.
+    reducer: A @{tf.contrib.data.Reducer} object representing the reduce logic.
+
+  Returns:
+    A nested structure of @{tf.Tensor} objects, corresponding to the result
+    of reducing `dataset` using `reducer`.
+
+  Raises:
+    TypeError: if `dataset` is not a `tf.data.Dataset` object.
+  """
+  if not isinstance(dataset, dataset_ops.Dataset):
+    raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
+
+  # The sentinel dataset is used in case the reduced dataset is empty.
+  sentinel_dataset = dataset_ops.Dataset.from_tensors(
+      reducer.finalize_func(reducer.init_func(np.int64(0))))
+  reduced_dataset = dataset.apply(
+      grouping.group_by_reducer(lambda x: np.int64(0), reducer))
+
+  return get_single_element(
+      reduced_dataset.concatenate(sentinel_dataset).take(1))
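
Usage sketch (illustrative, not part of the patch above): the new `reduce_dataset`
pairs a dataset with a `Reducer`, which bundles an init_func, a reduce_func, and a
finalize_func, and returns the reduction result as a single tensor (or nested
structure). The snippet below mirrors the `testReduceDataset` case, summing a small
range dataset in TF 1.x graph mode; the variable names (`sum_reducer`, `total`) are
illustrative, and the contrib import paths are taken directly from the patch.

    import numpy as np
    import tensorflow as tf

    from tensorflow.contrib.data.python.ops import get_single_element
    from tensorflow.contrib.data.python.ops import grouping

    # A Reducer bundles (init_func, reduce_func, finalize_func); this one sums
    # all elements of the dataset into a single int64 value.
    sum_reducer = grouping.Reducer(
        lambda _: np.int64(0),               # init_func: initial state
        lambda state, value: state + value,  # reduce_func: fold in one element
        lambda state: state)                 # finalize_func: identity

    dataset = tf.data.Dataset.range(10)
    total = get_single_element.reduce_dataset(dataset, sum_reducer)

    # Graph-mode evaluation, matching the style of the tests in this patch.
    with tf.Session() as sess:
      print(sess.run(total))  # prints 45, i.e. 0 + 1 + ... + 9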