# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Dataset`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import warnings

from absl.testing import parameterized
import numpy as np

from tensorflow.core.framework import graph_pb2
from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import optional_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test


class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

  @combinations.generate(test_base.default_test_combinations())
  def testAsSerializedGraph(self):
    dataset = dataset_ops.Dataset.range(10)
    graph = graph_pb2.GraphDef().FromString(
        self.evaluate(dataset._as_serialized_graph()))
    self.assertTrue(any(node.op == "RangeDataset" for node in graph.node))

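  # A map function that calls a stateful op (here, random_uniform) cannot be
  # serialized when the external state policy is FAIL.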
  @combinations.generate(test_base.default_test_combinations())
  def testAsSerializedGraphStateful(self):
    dataset = dataset_ops.Dataset.range(10).map(
        lambda _: random_ops.random_uniform(()))
    with self.assertRaises(errors.FailedPreconditionError):
      self.evaluate(
          dataset._as_serialized_graph(external_state_policy=distribute_options
                                       .ExternalStatePolicy.FAIL))

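  # The next two tests round-trip a dataset through its variant tensor:
  # _trace_variant_creation() returns a function that produces the variant,
  # and _VariantDataset rebuilds an equivalent dataset from that variant and
  # the original element spec.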
  @combinations.generate(test_base.default_test_combinations())
  def testAsFunctionWithMap(self):
    if not context.executing_eagerly():
      self.skipTest("Only works executing eagerly")
    with ops.device("CPU"):
      original_dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
      fn = original_dataset._trace_variant_creation()
      variant = fn()
      revived_dataset = dataset_ops._VariantDataset(
          variant, original_dataset.element_spec)
      self.assertDatasetProduces(revived_dataset, range(0, 10, 2))

  @combinations.generate(test_base.default_test_combinations())
  def testAsFunctionWithMapInFlatMap(self):
    if not context.executing_eagerly():
      self.skipTest("Only works executing eagerly")
    with ops.device("CPU"):
      original_dataset = dataset_ops.Dataset.range(5).flat_map(
          lambda x: dataset_ops.Dataset.range(5).map(lambda x: x * 2))
      fn = original_dataset._trace_variant_creation()
      variant = fn()
      revived_dataset = dataset_ops._VariantDataset(
          variant, original_dataset.element_spec)
      self.assertDatasetProduces(revived_dataset, list(original_dataset))

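  # The tests below check how many upstream datasets `_inputs()` reports.
  # Most source datasets report none; as asserted below, from_generator and
  # TFRecordDataset are each implemented in terms of an inner dataset and so
  # report one.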
  def _testNumInputs(self, dataset, num_inputs):
    self.assertLen(dataset._inputs(), num_inputs)

  @combinations.generate(test_base.default_test_combinations())
  def testFixedLengthRecordInputs(self):
    dataset = readers.FixedLengthRecordDataset("", 42)
    self._testNumInputs(dataset, 0)

  @combinations.generate(test_base.default_test_combinations())
  def testFromGeneratorInputs(self):

    def gen():
      yield 42

    dataset = dataset_ops.Dataset.from_generator(gen, dtypes.int32)
    self._testNumInputs(dataset, 1)

  @combinations.generate(test_base.default_test_combinations())
  def testFromTensorsInputs(self):
    dataset = dataset_ops.Dataset.from_tensors([42])
    self._testNumInputs(dataset, 0)

  @combinations.generate(test_base.default_test_combinations())
  def testRangeInputs(self):
    dataset = dataset_ops.Dataset.range(10)
    self._testNumInputs(dataset, 0)

  @combinations.generate(test_base.default_test_combinations())
  def testTextLineInputs(self):
    dataset = readers.TextLineDataset("")
    self._testNumInputs(dataset, 0)

  @combinations.generate(test_base.default_test_combinations())
  def testTFRecordInputs(self):
    dataset = readers.TFRecordDataset("")
    self._testNumInputs(dataset, 1)

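  # `from_sparse_tensor_slices` exists only in the v1 API, hence the explicit
  # tf_api_version=1 combination.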
  @combinations.generate(
      combinations.combine(tf_api_version=1, mode=["eager", "graph"]))
  def testDatasetComplexSourceInputs(self):
    dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
        sparse_tensor.SparseTensor(
            indices=np.array([[0, 0], [1, 0], [2, 0]]),
            values=np.array([0, 0, 0]),
            dense_shape=np.array([3, 1])))
    self.assertEmpty(dataset_fn._inputs())

  def _testUnaryInputs(self, dataset_fn):
    input_dataset = dataset_ops.Dataset.range(0)
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @combinations.generate(test_base.default_test_combinations())
  def testBatchInputs(self):
    self._testUnaryInputs(lambda x: x.batch(10))

  @combinations.generate(test_base.default_test_combinations())
  def testCacheInputs(self):
    self._testUnaryInputs(lambda x: x.cache())

  @combinations.generate(test_base.default_test_combinations())
  def testFilterInputs(self):
    self._testUnaryInputs(lambda x: x.filter(lambda x: True))

  @combinations.generate(test_base.default_test_combinations())
  def testFlatMapInputs(self):
    self._testUnaryInputs(
        lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)))

  @combinations.generate(test_base.default_test_combinations())
  def testMapInputs(self):
    self._testUnaryInputs(lambda x: x.map(lambda x: x))

  @combinations.generate(test_base.default_test_combinations())
  def testPaddedBatchInputs(self):
    self._testUnaryInputs(lambda x: x.padded_batch(10, []))

  @combinations.generate(test_base.default_test_combinations())
  def testParallelMapInputs(self):
    self._testUnaryInputs(lambda x: x.map(lambda x: x, num_parallel_calls=2))

  @combinations.generate(test_base.default_test_combinations())
  def testRepeatInputs(self):
    self._testUnaryInputs(lambda x: x.repeat())

  @combinations.generate(test_base.default_test_combinations())
  def testShuffleInputs(self):
    self._testUnaryInputs(lambda x: x.shuffle(10))

  @combinations.generate(test_base.default_test_combinations())
  def testSkipInputs(self):
    self._testUnaryInputs(lambda x: x.skip(1))

  @combinations.generate(test_base.default_test_combinations())
  def testTakeInputs(self):
    self._testUnaryInputs(lambda x: x.take(1))

  @combinations.generate(test_base.default_test_combinations())
  def testWindowInputs(self):
    self._testUnaryInputs(lambda x: x.window(10))

  @combinations.generate(test_base.default_test_combinations())
  def testUnaryTransformationInputsApply(self):
    input_dataset = dataset_ops.Dataset.range(0)
    dataset = input_dataset.apply(lambda dataset: dataset.cache())
    self.assertEqual([input_dataset], dataset._inputs())

  def _testInputsWithInterleaveFn(self, dataset_fn, interleave_parallelism):
    input_dataset = dataset_ops.Dataset.range(0)
    dataset = input_dataset.interleave(
        lambda x: dataset_ops.Dataset.range(0),
        cycle_length=2,
        num_parallel_calls=interleave_parallelism)
    self.assertEqual([input_dataset], dataset._inputs())

  @combinations.generate(test_base.default_test_combinations())
  def testParallelInterleaveInputs(self):
    self._testInputsWithInterleaveFn(
        lambda: dataset_ops.Dataset.range(0), 2)

  @combinations.generate(test_base.default_test_combinations())
  def testInterleaveInputs(self):
    self._testInputsWithInterleaveFn(
        lambda: dataset_ops.Dataset.range(0), None)

  @combinations.generate(test_base.default_test_combinations())
  def testNoWarnings(self):
    with test.mock.patch.object(warnings, "warn") as mock_log:
      dataset_ops.Dataset.range(0).interleave(
          lambda x: dataset_ops.Dataset.range(0), cycle_length=2)
      self.assertEmpty(mock_log.call_args_list)

  def _testBinaryInputs(self, dataset_fn):
    input1 = dataset_ops.Dataset.range(0)
    input2 = dataset_ops.Dataset.range(1)
    self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())

  @combinations.generate(test_base.default_test_combinations())
  def testConcatenateInputs(self):
    self._testBinaryInputs(lambda x, y: x.concatenate(y))

  def _testVariadicInputs(self, dataset_fn, input_datasets):
    self.assertEqual(
        nest.flatten(input_datasets),
        dataset_fn(input_datasets)._inputs())

  @combinations.generate(test_base.default_test_combinations())
  def testZipOneInputs(self):
    input_datasets = dataset_ops.Dataset.range(0)
    self._testVariadicInputs(dataset_ops.Dataset.zip, input_datasets)

  @combinations.generate(test_base.default_test_combinations())
  def testZipNestInputs(self):
    input_datasets = (dataset_ops.Dataset.range(0),
                      (dataset_ops.Dataset.range(1),
                       dataset_ops.Dataset.range(2)))
    self._testVariadicInputs(dataset_ops.Dataset.zip, input_datasets)

  @combinations.generate(test_base.default_test_combinations())
  def testZipTupleInputs(self):
    input_datasets = (dataset_ops.Dataset.range(0),
                      dataset_ops.Dataset.range(1))
    self._testVariadicInputs(dataset_ops.Dataset.zip, input_datasets)

  @combinations.generate(test_base.default_test_combinations())
  def testFunctions(self):
    dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
    self.assertLen(dataset._functions(), 1)

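  # Walks the input graph breadth-first from ds3. Shared inputs are counted
  # once per edge, so ds1 is expected five times: once directly under ds3 and
  # twice under each of the two ds2 occurrences.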
  @combinations.generate(test_base.default_test_combinations())
  def testCollectInputs(self):
    ds1 = dataset_ops.Dataset.range(0)
    ds2 = ds1.concatenate(ds1)
    ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

    inputs = []
    queue = [ds3]
    while queue:
      ds = queue[0]
      queue = queue[1:]
      queue.extend(ds._inputs())
      inputs.append(ds)

    self.assertEqual(5, inputs.count(ds1))
    self.assertEqual(2, inputs.count(ds2))
    self.assertEqual(1, inputs.count(ds3))

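  # Checks that a dataset of `tf_value` elements reports the expected element
  # structure, is represented by a single scalar variant tensor, and survives
  # a round trip through _to_tensor_list()/_from_tensor_list().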
  def _testDatasetSpec(self, tf_value, expected_element_structure):
    dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value)
    dataset_structure = structure.type_spec_from_value(dataset)
    self.assertIsInstance(dataset_structure, dataset_ops.DatasetSpec)

    self.assertTrue(
        structure.are_compatible(
            dataset_ops.get_structure(dataset), expected_element_structure))
    self.assertEqual([dtypes.variant],
                     structure.get_flat_tensor_types(dataset_structure))
    self.assertEqual([tensor_shape.TensorShape([])],
                     structure.get_flat_tensor_shapes(dataset_structure))

    # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    round_trip_dataset = dataset_structure._from_tensor_list(
        dataset_structure._to_tensor_list(dataset))

    value = tf_value

    if isinstance(value, dataset_ops.Dataset):
      self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
    elif isinstance(value, optional_ops.Optional):
      self.assertDatasetProduces(
          round_trip_dataset.map(lambda opt: opt.get_value()),
          [self.evaluate(value.get_value())],
          requires_initialization=True)
    else:
      self.assertDatasetProduces(
          round_trip_dataset, [self.evaluate(tf_value)],
          requires_initialization=True)

  @combinations.generate(test_base.default_test_combinations())
  def testTensorDatasetSpec(self):
    self._testDatasetSpec(
        constant_op.constant(37.0), tensor_spec.TensorSpec([], dtypes.float32))

  @combinations.generate(test_base.default_test_combinations())
  def testSparseTensorDatasetSpec(self):
    self._testDatasetSpec(
        sparse_tensor.SparseTensor(
            indices=[[0]],
            values=constant_op.constant([0], dtype=dtypes.int32),
            dense_shape=[1]), sparse_tensor.SparseTensorSpec([1], dtypes.int32))

  @combinations.generate(test_base.default_test_combinations())
  def testNestDatasetSpec(self):
    self._testDatasetSpec(
        {
            "a": constant_op.constant(37.0),
            "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
        }, {
            "a":
                tensor_spec.TensorSpec([], dtypes.float32),
            "b": (
                tensor_spec.TensorSpec([1], dtypes.string),
                tensor_spec.TensorSpec([], dtypes.string),
            )
        })

  @combinations.generate(test_base.default_test_combinations())
  def testDatasetDatasetSpec(self):
    self._testDatasetSpec(
        dataset_ops.Dataset.from_tensor_slices(
            constant_op.constant([1, 2, 3])),
        dataset_ops.DatasetSpec(tensor_spec.TensorSpec([], dtypes.int32)))

  @combinations.generate(test_base.default_test_combinations())
  def testOptionalDatasetSpec(self):
    self._testDatasetSpec(
        optional_ops.Optional.from_value(37.0),
        optional_ops.OptionalSpec(tensor_spec.TensorSpec([], dtypes.float32)))

  @combinations.generate(test_base.graph_only_combinations())
  def testSameGraphError(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with self.assertRaisesRegex(ValueError, "must be from the same graph"):
        dataset = dataset.batch(2)

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph"]))
  def testSameGraphErrorOneShot(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with self.assertRaisesRegex(
          ValueError, "Please ensure that all datasets in the pipeline are "
          "created in the same graph as the iterator."):
        _ = dataset_ops.make_one_shot_iterator(dataset)

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph"]))
  def testSameGraphErrorInitializable(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with self.assertRaisesRegex(
          ValueError, "Please ensure that all datasets in the pipeline are "
          "created in the same graph as the iterator."):
        _ = dataset_ops.make_initializable_iterator(dataset)

  @combinations.generate(
      combinations.times(
          test_base.eager_only_combinations(),
          combinations.combine(execution_mode=[context.ASYNC, context.SYNC])))
  def testEagerIteration(self, execution_mode):
    with context.execution_mode(execution_mode):
      val = 0
      dataset = dataset_ops.Dataset.range(10)
      for foo in dataset:
        self.assertEqual(val, foo.numpy())
        val += 1

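  # A dataset argument to a tf.function is matched by its DatasetSpec, so it
  # should neither be captured by the traced graph nor trigger a retrace when
  # only its contents (not its structure) differ.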
  @combinations.generate(test_base.default_test_combinations())
  def testDatasetAsFunctionArgument(self):

    @def_function.function
    def _uses_dataset(d):
      accumulator = array_ops.zeros([], dtype=dtypes.int64)
      for value in d:
        accumulator += value
      return accumulator

    with ops.device("CPU"):
      first_dataset = dataset_ops.Dataset.range(10)
      self.assertEqual(45, self.evaluate(_uses_dataset(first_dataset)))
      second_dataset = dataset_ops.Dataset.range(11)
      self.assertEqual(55, self.evaluate(_uses_dataset(second_dataset)))
      first_concrete = _uses_dataset.get_concrete_function(first_dataset)
      # The dataset should not be a captured input.
      self.assertEmpty(first_concrete.graph.captures)
      # The two datasets have the same structure and so should re-use a trace.
      self.assertIs(first_concrete,
                    _uses_dataset.get_concrete_function(second_dataset))
      # With a different structure we should use a different trace.
      self.assertIsNot(
          first_concrete,
          _uses_dataset.get_concrete_function(
              dataset_ops.Dataset.zip((first_dataset, second_dataset))))

  @combinations.generate(test_base.default_test_combinations())
  def testLimitedRetracing(self):
    trace_count = [0]

    @def_function.function
    def f(ds):
      trace_count[0] += 1
      counter = np.int64(0)
      for elem in ds:
        counter += elem
      return counter

    dataset = dataset_ops.Dataset.range(5)
    dataset2 = dataset_ops.Dataset.range(10)

    for _ in range(10):
      self.assertEqual(self.evaluate(f(dataset)), 10)
      self.assertEqual(self.evaluate(f(dataset2)), 45)
      self.assertEqual(trace_count[0], 1)

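  # Exercises the legacy get_legacy_output_types/get_legacy_output_shapes
  # accessors across a chain of transformations, checking that each one
  # preserves or predictably rewrites the element structure.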
  # pylint: disable=g-long-lambda,unnecessary-lambda
  @combinations.generate(test_base.default_test_combinations())
  def testLegacyStructureAPI(self):
    components = (np.array([1, 2, 3], dtype=np.int64),
                  (np.array([4., 5.]), np.array([6., 7.])),
                  np.array([8, 9, 10], dtype=np.int64))

    dataset = dataset_ops.Dataset.from_tensors(components)
    self.assertEqual(
        (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64),
        dataset_ops.get_legacy_output_types(dataset))
    self.assertEqual(([3], ([2], [2]), [3]),
                     dataset_ops.get_legacy_output_shapes(dataset))

    dataset = dataset.shuffle(10, 10)
    self.assertEqual(
        (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64),
        dataset_ops.get_legacy_output_types(dataset))
    self.assertEqual(([3], ([2], [2]), [3]),
                     dataset_ops.get_legacy_output_shapes(dataset))

    dataset = dataset.repeat(-1)
    self.assertEqual(
        (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64),
        dataset_ops.get_legacy_output_types(dataset))
    self.assertEqual(([3], ([2], [2]), [3]),
                     dataset_ops.get_legacy_output_shapes(dataset))

    dataset = dataset.filter(lambda x, y, z: True)
    self.assertEqual(
        (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64),
        dataset_ops.get_legacy_output_types(dataset))
    self.assertEqual(([3], ([2], [2]), [3]),
                     dataset_ops.get_legacy_output_shapes(dataset))

    dataset = dataset.take(5)
    self.assertEqual(
        (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64),
        dataset_ops.get_legacy_output_types(dataset))
    self.assertEqual(([3], ([2], [2]), [3]),
                     dataset_ops.get_legacy_output_shapes(dataset))

    dataset = dataset.map(lambda x, y, z: ((x, z), (y[0], y[1])))
    self.assertEqual(
        ((dtypes.int64, dtypes.int64), (dtypes.float64, dtypes.float64)),
        dataset_ops.get_legacy_output_types(dataset))
    self.assertEqual((([3], [3]), ([2], [2])),
                     dataset_ops.get_legacy_output_shapes(dataset))

    dataset = dataset.flat_map(lambda x, y: dataset_ops.Dataset.from_tensors(
        ((x[0], x[1]), (y[0], y[1]))))
    self.assertEqual(
        ((dtypes.int64, dtypes.int64), (dtypes.float64, dtypes.float64)),
        dataset_ops.get_legacy_output_types(dataset))
    self.assertEqual((([3], [3]), ([2], [2])),
                     dataset_ops.get_legacy_output_shapes(dataset))

    dataset = dataset.batch(32)
    self.assertEqual(
        ((dtypes.int64, dtypes.int64), (dtypes.float64, dtypes.float64)),
        dataset_ops.get_legacy_output_types(dataset))
    dataset_output_shapes = dataset_ops.get_legacy_output_shapes(dataset)
    self.assertEqual(
        (([None, 3], [None, 3]), ([None, 2], [None, 2])),
        nest.pack_sequence_as(
            dataset_output_shapes,
            [s.as_list() for s in nest.flatten(dataset_output_shapes)]))

    # Define a separate set of components with matching leading
    # dimension for the from-slices constructor.
    components_for_slices = (np.array([1, 2, 3], dtype=np.int64),
                             (np.array([4., 5., 6.]), np.array([7., 8., 9.])),
                             np.array([10, 11, 12], dtype=np.int64))

    dataset = dataset_ops.Dataset.from_tensor_slices(components_for_slices)
    self.assertEqual(
        (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64),
        dataset_ops.get_legacy_output_types(dataset))
    self.assertEqual(([], ([], []), []),
                     dataset_ops.get_legacy_output_shapes(dataset))

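  # `None` is a valid element component: it has no tensor representation and
  # should pass through datasets and iterators unchanged.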
  @combinations.generate(test_base.default_test_combinations())
  def testNoneComponent(self):
    dataset = dataset_ops.Dataset.from_tensors((42, None))
    if context.executing_eagerly():
      self.assertDatasetProduces(dataset, expected_output=[(42, None)])
    else:
      iterator = dataset_ops.make_one_shot_iterator(dataset)
      next_first, next_second = iterator.get_next()
      self.assertEqual(next_second, None)
      with self.cached_session() as sess:
        self.assertEqual(sess.run(next_first), 42)

  @combinations.generate(test_base.default_test_combinations())
  def testNoneComponentInFunction(self):

    @def_function.function
    def fn(ds):
      total = 0
      it = iter(ds)
      for elem in it:
        x, _ = elem
        total += x
      return total

    dataset = dataset_ops.Dataset.range(
        10, output_type=dtypes.int32).map(lambda x: (x, None))
    self.assertEqual(self.evaluate(fn(dataset)), 45)

  @combinations.generate(test_base.default_test_combinations())
  def testIncorrectPythonStructure(self):
    # Tests that an exception is raised (as opposed to a segfault) when the
    # Python structure assigned to a dataset is incorrect.
    dataset = dataset_ops.Dataset.range(10)
    spec = tensor_spec.TensorSpec([], dtypes.int64)
    new_structure = (spec, spec)
    dataset = dataset_ops._RestructuredDataset(dataset, new_structure)
    dataset = dataset.map(lambda x, y: y)

    with self.assertRaisesOpError(""):
      self.getDatasetOutput(dataset)


if __name__ == "__main__":
  test.main()