[tf.data] Add tf.contrib.data.reduce_dataset, which reduces a dataset to a single element.
PiperOrigin-RevId: 205281140
Commit: e9869ece18
Parent: bbe9364b22
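
For context, a minimal usage sketch of the new API. This is not part of the commit itself; it assumes `Reducer` is already exported as `tf.contrib.data.Reducer` (as the new docstring below suggests) and mirrors the sum reducer used in the test added by this change:

import numpy as np
import tensorflow as tf

# A Reducer bundles the three pieces of a reduction: init_func produces the
# initial state, reduce_func folds each element into the state, and
# finalize_func maps the final state to the output value.
def init_fn(_):
  return np.int64(0)

def reduce_fn(state, value):
  return state + value

def finalize_fn(state):
  return state

sum_reducer = tf.contrib.data.Reducer(init_fn, reduce_fn, finalize_fn)

# reduce_dataset folds every element of `dataset` through `sum_reducer` and
# returns the result as a nested structure of tensors.
dataset = tf.data.Dataset.range(10)
result = tf.contrib.data.reduce_dataset(dataset, sum_reducer)

with tf.Session() as sess:
  print(sess.run(result))  # 45, i.e. 0 + 1 + ... + 9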
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -52,6 +52,7 @@ See @{$guide/datasets$Importing Data} for an overview.
 @@prefetch_to_device
 @@read_batch_features
 @@rejection_resample
+@@reduce_dataset
 @@sample_from_datasets
 @@scan
 @@shuffle_and_repeat
@@ -77,6 +78,7 @@ from tensorflow.contrib.data.python.ops.counter import Counter
 from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset
 from tensorflow.contrib.data.python.ops.error_ops import ignore_errors
 from tensorflow.contrib.data.python.ops.get_single_element import get_single_element
+from tensorflow.contrib.data.python.ops.get_single_element import reduce_dataset
 from tensorflow.contrib.data.python.ops.grouping import bucket_by_sequence_length
 from tensorflow.contrib.data.python.ops.grouping import group_by_reducer
 from tensorflow.contrib.data.python.ops.grouping import group_by_window
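
For readers unfamiliar with the convention: in TF 1.x contrib modules, symbols listed as @@name in the module docstring are collected by the API-documentation tooling, so the new @@reduce_dataset entry together with the new import is what makes the function public as tf.contrib.data.reduce_dataset.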
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -121,6 +121,7 @@ py_test(
     srcs = ["get_single_element_test.py"],
     deps = [
         "//tensorflow/contrib/data/python/ops:get_single_element",
+        "//tensorflow/contrib/data/python/ops:grouping",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:constant_op",
@@ -128,6 +129,7 @@ py_test(
         "//tensorflow/python:errors",
         "//tensorflow/python:sparse_tensor",
         "//tensorflow/python/data/ops:dataset_ops",
+        "@absl_py//absl/testing:parameterized",
     ],
 )
 
--- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
@@ -17,9 +17,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
+import numpy as np
+
 from tensorflow.contrib.data.python.ops import get_single_element
+from tensorflow.contrib.data.python.ops import grouping
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import sparse_tensor
@@ -27,40 +30,69 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
 
 
-class GetSingleElementTest(test.TestCase):
+class GetSingleElementTest(test.TestCase, parameterized.TestCase):
 
-  def testGetSingleElement(self):
-    skip_value = array_ops.placeholder(dtypes.int64, shape=[])
-    take_value = array_ops.placeholder_with_default(
-        constant_op.constant(1, dtype=dtypes.int64), shape=[])
+  @parameterized.named_parameters(
+      ("Zero", 0, 1),
+      ("Five", 5, 1),
+      ("Ten", 10, 1),
+      ("Empty", 100, 1, errors.InvalidArgumentError, "Dataset was empty."),
+      ("MoreThanOne", 0, 2, errors.InvalidArgumentError,
+       "Dataset had more than one element."),
+  )
+  def testGetSingleElement(self, skip, take, error=None, error_msg=None):
+    skip_t = array_ops.placeholder(dtypes.int64, shape=[])
+    take_t = array_ops.placeholder(dtypes.int64, shape=[])
 
     def make_sparse(x):
       x_1d = array_ops.reshape(x, [1])
       x_2d = array_ops.reshape(x, [1, 1])
       return sparse_tensor.SparseTensor(x_2d, x_1d, x_1d)
 
-    dataset = (dataset_ops.Dataset.range(100)
-               .skip(skip_value)
-               .map(lambda x: (x * x, make_sparse(x)))
-               .take(take_value))
-
+    dataset = dataset_ops.Dataset.range(100).skip(skip_t).map(
+        lambda x: (x * x, make_sparse(x))).take(take_t)
     element = get_single_element.get_single_element(dataset)
 
     with self.test_session() as sess:
-      for x in [0, 5, 10]:
-        dense_val, sparse_val = sess.run(element, feed_dict={skip_value: x})
-        self.assertEqual(x * x, dense_val)
-        self.assertAllEqual([[x]], sparse_val.indices)
-        self.assertAllEqual([x], sparse_val.values)
-        self.assertAllEqual([x], sparse_val.dense_shape)
-
-      with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                   "Dataset was empty."):
-        sess.run(element, feed_dict={skip_value: 100})
-
-      with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                   "Dataset had more than one element."):
-        sess.run(element, feed_dict={skip_value: 0, take_value: 2})
+      if error is None:
+        dense_val, sparse_val = sess.run(
+            element, feed_dict={
+                skip_t: skip,
+                take_t: take
+            })
+        self.assertEqual(skip * skip, dense_val)
+        self.assertAllEqual([[skip]], sparse_val.indices)
+        self.assertAllEqual([skip], sparse_val.values)
+        self.assertAllEqual([skip], sparse_val.dense_shape)
+      else:
+        with self.assertRaisesRegexp(error, error_msg):
+          sess.run(element, feed_dict={skip_t: skip, take_t: take})
+
+  @parameterized.named_parameters(
+      ("SumZero", 0),
+      ("SumOne", 1),
+      ("SumFive", 5),
+      ("SumTen", 10),
+  )
+  def testReduceDataset(self, stop):
+    def init_fn(_):
+      return np.int64(0)
+
+    def reduce_fn(state, value):
+      return state + value
+
+    def finalize_fn(state):
+      return state
+
+    sum_reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn)
+
+    stop_t = array_ops.placeholder(dtypes.int64, shape=[])
+    dataset = dataset_ops.Dataset.range(stop_t)
+    element = get_single_element.reduce_dataset(dataset, sum_reducer)
+
+    with self.test_session() as sess:
+      value = sess.run(element, feed_dict={stop_t: stop})
+      self.assertEqual(stop * (stop - 1) / 2, value)
 
 
 if __name__ == "__main__":
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -28,10 +28,12 @@ py_library(
     srcs = ["get_single_element.py"],
     srcs_version = "PY2AND3",
    deps = [
+        ":grouping",
         "//tensorflow/python:dataset_ops_gen",
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/data/util:nest",
         "//tensorflow/python/data/util:sparse",
+        "//third_party/py/numpy",
     ],
 )
 
@@ -129,6 +131,7 @@ py_library(
         "//tensorflow/python/data/util:convert",
         "//tensorflow/python/data/util:nest",
         "//tensorflow/python/data/util:sparse",
+        "//third_party/py/numpy",
     ],
 )
 
|
@ -17,6 +17,9 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.contrib.data.python.ops import grouping
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.data.util import nest
|
from tensorflow.python.data.util import nest
|
||||||
from tensorflow.python.data.util import sparse
|
from tensorflow.python.data.util import sparse
|
||||||
@ -68,3 +71,30 @@ def get_single_element(dataset):
|
|||||||
return sparse.deserialize_sparse_tensors(
|
return sparse.deserialize_sparse_tensors(
|
||||||
nested_ret, dataset.output_types, dataset.output_shapes,
|
nested_ret, dataset.output_types, dataset.output_shapes,
|
||||||
dataset.output_classes)
|
dataset.output_classes)
|
||||||
|
|
||||||
|
|
||||||
|
def reduce_dataset(dataset, reducer):
|
||||||
|
"""Returns the result of reducing the `dataset` using `reducer`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: A @{tf.data.Dataset} object.
|
||||||
|
reducer: A @{tf.contrib.data.Reducer} object representing the reduce logic.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A nested structure of @{tf.Tensor} objects, corresponding to the result
|
||||||
|
of reducing `dataset` using `reducer`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if `dataset` is not a `tf.data.Dataset` object.
|
||||||
|
"""
|
||||||
|
if not isinstance(dataset, dataset_ops.Dataset):
|
||||||
|
raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
|
||||||
|
|
||||||
|
# The sentinel dataset is used in case the reduced dataset is empty.
|
||||||
|
sentinel_dataset = dataset_ops.Dataset.from_tensors(
|
||||||
|
reducer.finalize_func(reducer.init_func(np.int64(0))))
|
||||||
|
reduced_dataset = dataset.apply(
|
||||||
|
grouping.group_by_reducer(lambda x: np.int64(0), reducer))
|
||||||
|
|
||||||
|
return get_single_element(
|
||||||
|
reduced_dataset.concatenate(sentinel_dataset).take(1))
|
||||||
|
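
A note on the mechanics of the implementation above: group_by_reducer emits one element per distinct key, so keying every input element to the constant np.int64(0) folds the entire dataset into at most one element, and for an empty input it emits nothing at all. The sentinel dataset holds finalize_func(init_func(0)), i.e. the reduction of zero elements, and is concatenated after the reduced dataset, so take(1) yields the genuine reduced value when one exists and falls back to the sentinel otherwise. Either way, get_single_element is guaranteed to see exactly one element.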