[tf.data] Adding tf.contrib.data.reduce_dataset, which can be used to reduce a dataset to a single element.

PiperOrigin-RevId: 205281140
This commit is contained in:
Jiri Simsa 2018-07-19 12:19:37 -07:00 committed by TensorFlower Gardener
parent bbe9364b22
commit e9869ece18
5 changed files with 92 additions and 23 deletions
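A minimal usage sketch of the `reduce_dataset` API added in this commit (graph-mode, TF 1.x). It mirrors the sum reducer in the new test below and assumes `Reducer` is exported as `tf.contrib.data.Reducer`, as the docstring added in this change indicates; the session setup is illustrative only.

import numpy as np
import tensorflow as tf

# A Reducer bundles (init_func, reduce_func, finalize_func); this one sums the
# elements of a dataset. Arguments are passed positionally, as in the test below.
sum_reducer = tf.contrib.data.Reducer(
    lambda _: np.int64(0),               # init_func: initial state
    lambda state, value: state + value,  # reduce_func: fold one element in
    lambda state: state)                 # finalize_func: map state to result

dataset = tf.data.Dataset.range(10)      # 0, 1, ..., 9
total = tf.contrib.data.reduce_dataset(dataset, sum_reducer)

with tf.Session() as sess:
  print(sess.run(total))  # 45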

View File

@@ -52,6 +52,7 @@ See @{$guide/datasets$Importing Data} for an overview.
@@prefetch_to_device
@@read_batch_features
@@rejection_resample
@@reduce_dataset
@@sample_from_datasets
@@scan
@@shuffle_and_repeat
@@ -77,6 +78,7 @@ from tensorflow.contrib.data.python.ops.counter import Counter
from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset
from tensorflow.contrib.data.python.ops.error_ops import ignore_errors
from tensorflow.contrib.data.python.ops.get_single_element import get_single_element
from tensorflow.contrib.data.python.ops.get_single_element import reduce_dataset
from tensorflow.contrib.data.python.ops.grouping import bucket_by_sequence_length
from tensorflow.contrib.data.python.ops.grouping import group_by_reducer
from tensorflow.contrib.data.python.ops.grouping import group_by_window

View File

@@ -121,6 +121,7 @@ py_test(
srcs = ["get_single_element_test.py"],
deps = [
"//tensorflow/contrib/data/python/ops:get_single_element",
"//tensorflow/contrib/data/python/ops:grouping",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@@ -128,6 +129,7 @@ py_test(
"//tensorflow/python:errors",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
)

View File

@@ -17,9 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.data.python.ops import get_single_element
from tensorflow.contrib.data.python.ops import grouping
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
@@ -27,40 +30,69 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class GetSingleElementTest(test.TestCase):
class GetSingleElementTest(test.TestCase, parameterized.TestCase):
def testGetSingleElement(self):
skip_value = array_ops.placeholder(dtypes.int64, shape=[])
take_value = array_ops.placeholder_with_default(
constant_op.constant(1, dtype=dtypes.int64), shape=[])
@parameterized.named_parameters(
("Zero", 0, 1),
("Five", 5, 1),
("Ten", 10, 1),
("Empty", 100, 1, errors.InvalidArgumentError, "Dataset was empty."),
("MoreThanOne", 0, 2, errors.InvalidArgumentError,
"Dataset had more than one element."),
)
def testGetSingleElement(self, skip, take, error=None, error_msg=None):
skip_t = array_ops.placeholder(dtypes.int64, shape=[])
take_t = array_ops.placeholder(dtypes.int64, shape=[])
def make_sparse(x):
x_1d = array_ops.reshape(x, [1])
x_2d = array_ops.reshape(x, [1, 1])
return sparse_tensor.SparseTensor(x_2d, x_1d, x_1d)
dataset = (dataset_ops.Dataset.range(100)
.skip(skip_value)
.map(lambda x: (x * x, make_sparse(x)))
.take(take_value))
dataset = dataset_ops.Dataset.range(100).skip(skip_t).map(
lambda x: (x * x, make_sparse(x))).take(take_t)
element = get_single_element.get_single_element(dataset)
with self.test_session() as sess:
for x in [0, 5, 10]:
dense_val, sparse_val = sess.run(element, feed_dict={skip_value: x})
self.assertEqual(x * x, dense_val)
self.assertAllEqual([[x]], sparse_val.indices)
self.assertAllEqual([x], sparse_val.values)
self.assertAllEqual([x], sparse_val.dense_shape)
if error is None:
dense_val, sparse_val = sess.run(
element, feed_dict={
skip_t: skip,
take_t: take
})
self.assertEqual(skip * skip, dense_val)
self.assertAllEqual([[skip]], sparse_val.indices)
self.assertAllEqual([skip], sparse_val.values)
self.assertAllEqual([skip], sparse_val.dense_shape)
else:
with self.assertRaisesRegexp(error, error_msg):
sess.run(element, feed_dict={skip_t: skip, take_t: take})
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"Dataset was empty."):
sess.run(element, feed_dict={skip_value: 100})
@parameterized.named_parameters(
("SumZero", 0),
("SumOne", 1),
("SumFive", 5),
("SumTen", 10),
)
def testReduceDataset(self, stop):
def init_fn(_):
return np.int64(0)
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"Dataset had more than one element."):
sess.run(element, feed_dict={skip_value: 0, take_value: 2})
def reduce_fn(state, value):
return state + value
def finalize_fn(state):
return state
sum_reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn)
stop_t = array_ops.placeholder(dtypes.int64, shape=[])
dataset = dataset_ops.Dataset.range(stop_t)
element = get_single_element.reduce_dataset(dataset, sum_reducer)
with self.test_session() as sess:
value = sess.run(element, feed_dict={stop_t: stop})
self.assertEqual(stop * (stop - 1) / 2, value)
if __name__ == "__main__":

View File

@@ -28,10 +28,12 @@ py_library(
srcs = ["get_single_element.py"],
srcs_version = "PY2AND3",
deps = [
":grouping",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
"//third_party/py/numpy",
],
)
@@ -129,6 +131,7 @@ py_library(
"//tensorflow/python/data/util:convert",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
"//third_party/py/numpy",
],
)

View File

@@ -17,6 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.ops import grouping
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
@@ -68,3 +71,30 @@ def get_single_element(dataset):
return sparse.deserialize_sparse_tensors(
nested_ret, dataset.output_types, dataset.output_shapes,
dataset.output_classes)
def reduce_dataset(dataset, reducer):
"""Returns the result of reducing the `dataset` using `reducer`.
Args:
dataset: A @{tf.data.Dataset} object.
reducer: A @{tf.contrib.data.Reducer} object representing the reduce logic.
Returns:
A nested structure of @{tf.Tensor} objects, corresponding to the result
of reducing `dataset` using `reducer`.
Raises:
TypeError: if `dataset` is not a `tf.data.Dataset` object.
"""
if not isinstance(dataset, dataset_ops.Dataset):
raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
# The sentinel dataset is used in case the reduced dataset is empty.
sentinel_dataset = dataset_ops.Dataset.from_tensors(
reducer.finalize_func(reducer.init_func(np.int64(0))))
reduced_dataset = dataset.apply(
grouping.group_by_reducer(lambda x: np.int64(0), reducer))
return get_single_element(
reduced_dataset.concatenate(sentinel_dataset).take(1))
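Why concatenate a sentinel: `group_by_reducer` emits no elements when `dataset` is empty, so calling `get_single_element` on the reduced dataset alone would raise "Dataset was empty." A short illustrative sketch of the empty-dataset path, written against the module-level names in this file:

# With a sum reducer, reducing an empty dataset yields the sentinel element
# finalize_func(init_func(0)) == 0 instead of raising InvalidArgumentError.
sum_reducer = grouping.Reducer(
    lambda _: np.int64(0),               # init_func
    lambda state, value: state + value,  # reduce_func
    lambda state: state)                 # finalize_func
empty_sum = reduce_dataset(dataset_ops.Dataset.range(0), sum_reducer)
# Evaluating `empty_sum` in a session returns 0 (the sentinel), not an error.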