[tf.data] Add tf.contrib.data.reduce_dataset, which reduces a dataset to a single element using a tf.contrib.data.Reducer.

PiperOrigin-RevId: 205281140
commit e9869ece18 (parent bbe9364b22)
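For reference, a minimal usage sketch (not part of the commit) mirroring the sum reducer in the test below. It assumes the TF 1.x graph-mode API of this era and imports Reducer from the internal grouping module, exactly as the test does; tf.contrib.data.reduce_dataset is the endpoint this change exports.

    import numpy as np
    import tensorflow as tf

    from tensorflow.contrib.data.python.ops import grouping

    # Reducer(init_func, reduce_func, finalize_func) describes the fold:
    # init_func builds the initial state, reduce_func folds in each element,
    # finalize_func maps the final state to the returned value.
    sum_reducer = grouping.Reducer(
        lambda _: np.int64(0),               # initial state (the key is ignored)
        lambda state, value: state + value,  # accumulate each element
        lambda state: state)                 # identity finalization

    dataset = tf.data.Dataset.range(10)
    value = tf.contrib.data.reduce_dataset(dataset, sum_reducer)

    with tf.Session() as sess:  # TF 1.x graph mode, as in the test
      print(sess.run(value))    # 45, i.e. 0 + 1 + ... + 9

Unlike get_single_element, the input dataset may hold any number of elements; the reducer folds them into one.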
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -52,6 +52,7 @@ See @{$guide/datasets$Importing Data} for an overview.
 @@prefetch_to_device
 @@read_batch_features
 @@rejection_resample
+@@reduce_dataset
 @@sample_from_datasets
 @@scan
 @@shuffle_and_repeat
@@ -77,6 +78,7 @@ from tensorflow.contrib.data.python.ops.counter import Counter
 from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset
 from tensorflow.contrib.data.python.ops.error_ops import ignore_errors
 from tensorflow.contrib.data.python.ops.get_single_element import get_single_element
+from tensorflow.contrib.data.python.ops.get_single_element import reduce_dataset
 from tensorflow.contrib.data.python.ops.grouping import bucket_by_sequence_length
 from tensorflow.contrib.data.python.ops.grouping import group_by_reducer
 from tensorflow.contrib.data.python.ops.grouping import group_by_window
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -121,6 +121,7 @@ py_test(
     srcs = ["get_single_element_test.py"],
     deps = [
         "//tensorflow/contrib/data/python/ops:get_single_element",
+        "//tensorflow/contrib/data/python/ops:grouping",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:constant_op",
@@ -128,6 +129,7 @@ py_test(
         "//tensorflow/python:errors",
         "//tensorflow/python:sparse_tensor",
         "//tensorflow/python/data/ops:dataset_ops",
+        "@absl_py//absl/testing:parameterized",
     ],
 )
 
--- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
@@ -17,9 +17,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
+import numpy as np
 from tensorflow.contrib.data.python.ops import get_single_element
+from tensorflow.contrib.data.python.ops import grouping
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import sparse_tensor
@@ -27,40 +30,69 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
 
 
-class GetSingleElementTest(test.TestCase):
+class GetSingleElementTest(test.TestCase, parameterized.TestCase):
 
-  def testGetSingleElement(self):
-    skip_value = array_ops.placeholder(dtypes.int64, shape=[])
-    take_value = array_ops.placeholder_with_default(
-        constant_op.constant(1, dtype=dtypes.int64), shape=[])
+  @parameterized.named_parameters(
+      ("Zero", 0, 1),
+      ("Five", 5, 1),
+      ("Ten", 10, 1),
+      ("Empty", 100, 1, errors.InvalidArgumentError, "Dataset was empty."),
+      ("MoreThanOne", 0, 2, errors.InvalidArgumentError,
+       "Dataset had more than one element."),
+  )
+  def testGetSingleElement(self, skip, take, error=None, error_msg=None):
+    skip_t = array_ops.placeholder(dtypes.int64, shape=[])
+    take_t = array_ops.placeholder(dtypes.int64, shape=[])
 
     def make_sparse(x):
       x_1d = array_ops.reshape(x, [1])
       x_2d = array_ops.reshape(x, [1, 1])
       return sparse_tensor.SparseTensor(x_2d, x_1d, x_1d)
 
-    dataset = (dataset_ops.Dataset.range(100)
-               .skip(skip_value)
-               .map(lambda x: (x * x, make_sparse(x)))
-               .take(take_value))
+    dataset = dataset_ops.Dataset.range(100).skip(skip_t).map(
+        lambda x: (x * x, make_sparse(x))).take(take_t)
     element = get_single_element.get_single_element(dataset)
 
     with self.test_session() as sess:
-      for x in [0, 5, 10]:
-        dense_val, sparse_val = sess.run(element, feed_dict={skip_value: x})
-        self.assertEqual(x * x, dense_val)
-        self.assertAllEqual([[x]], sparse_val.indices)
-        self.assertAllEqual([x], sparse_val.values)
-        self.assertAllEqual([x], sparse_val.dense_shape)
+      if error is None:
+        dense_val, sparse_val = sess.run(
+            element, feed_dict={
+                skip_t: skip,
+                take_t: take
+            })
+        self.assertEqual(skip * skip, dense_val)
+        self.assertAllEqual([[skip]], sparse_val.indices)
+        self.assertAllEqual([skip], sparse_val.values)
+        self.assertAllEqual([skip], sparse_val.dense_shape)
+      else:
+        with self.assertRaisesRegexp(error, error_msg):
+          sess.run(element, feed_dict={skip_t: skip, take_t: take})
 
-      with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                   "Dataset was empty."):
-        sess.run(element, feed_dict={skip_value: 100})
+  @parameterized.named_parameters(
+      ("SumZero", 0),
+      ("SumOne", 1),
+      ("SumFive", 5),
+      ("SumTen", 10),
+  )
+  def testReduceDataset(self, stop):
+
+    def init_fn(_):
+      return np.int64(0)
 
-      with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                   "Dataset had more than one element."):
-        sess.run(element, feed_dict={skip_value: 0, take_value: 2})
+    def reduce_fn(state, value):
+      return state + value
+
+    def finalize_fn(state):
+      return state
+
+    sum_reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn)
+
+    stop_t = array_ops.placeholder(dtypes.int64, shape=[])
+    dataset = dataset_ops.Dataset.range(stop_t)
+    element = get_single_element.reduce_dataset(dataset, sum_reducer)
+
+    with self.test_session() as sess:
+      value = sess.run(element, feed_dict={stop_t: stop})
+      self.assertEqual(stop * (stop - 1) / 2, value)
 
 
 if __name__ == "__main__":
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -28,10 +28,12 @@ py_library(
     srcs = ["get_single_element.py"],
     srcs_version = "PY2AND3",
    deps = [
+        ":grouping",
         "//tensorflow/python:dataset_ops_gen",
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/data/util:nest",
         "//tensorflow/python/data/util:sparse",
+        "//third_party/py/numpy",
     ],
 )
 
@@ -129,6 +131,7 @@ py_library(
         "//tensorflow/python/data/util:convert",
         "//tensorflow/python/data/util:nest",
         "//tensorflow/python/data/util:sparse",
+        "//third_party/py/numpy",
     ],
 )
 
--- a/tensorflow/contrib/data/python/ops/get_single_element.py
+++ b/tensorflow/contrib/data/python/ops/get_single_element.py
@@ -17,6 +17,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
+
+from tensorflow.contrib.data.python.ops import grouping
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import sparse
@@ -68,3 +71,30 @@ def get_single_element(dataset):
   return sparse.deserialize_sparse_tensors(
       nested_ret, dataset.output_types, dataset.output_shapes,
       dataset.output_classes)
+
+
+def reduce_dataset(dataset, reducer):
+  """Returns the result of reducing the `dataset` using `reducer`.
+
+  Args:
+    dataset: A @{tf.data.Dataset} object.
+    reducer: A @{tf.contrib.data.Reducer} object representing the reduce logic.
+
+  Returns:
+    A nested structure of @{tf.Tensor} objects, corresponding to the result
+    of reducing `dataset` using `reducer`.
+
+  Raises:
+    TypeError: if `dataset` is not a `tf.data.Dataset` object.
+  """
+  if not isinstance(dataset, dataset_ops.Dataset):
+    raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
+
+  # The sentinel dataset is used in case the reduced dataset is empty.
+  sentinel_dataset = dataset_ops.Dataset.from_tensors(
+      reducer.finalize_func(reducer.init_func(np.int64(0))))
+  reduced_dataset = dataset.apply(
+      grouping.group_by_reducer(lambda x: np.int64(0), reducer))
+
+  return get_single_element(
+      reduced_dataset.concatenate(sentinel_dataset).take(1))
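A note on the sentinel logic above, as a hedged sketch under the same assumptions as the earlier example: all elements are mapped to the constant key 0, so group_by_reducer folds the whole dataset into (at most) one element. For an empty input it emits nothing, so take(1) on the concatenation falls through to the sentinel, and the result is finalize_func(init_func(0)), the reduction of zero elements.

    import numpy as np
    import tensorflow as tf

    from tensorflow.contrib.data.python.ops import grouping

    sum_reducer = grouping.Reducer(
        lambda _: np.int64(0), lambda s, v: s + v, lambda s: s)

    # Reducing an empty dataset: group_by_reducer yields nothing, so take(1)
    # selects the sentinel element instead.
    value = tf.contrib.data.reduce_dataset(tf.data.Dataset.range(0), sum_reducer)

    with tf.Session() as sess:
      print(sess.run(value))  # 0, i.e. finalize_func(init_func(0))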