Update Keras training logic for detecting whether an input dataset is shuffled.

It was previously using instance type checking, and heuristic is not fully accurate.
Update the logic to parse the dataset graph_def, which should be deterministic about the information we need.

This also allows keras to remove the private API symbol usage of tf.data.

PiperOrigin-RevId: 332087112
Change-Id: I6db0cfbe3ac00ace2d736b6a2b0a15cf04981185
This commit is contained in:
Scott Zhu 2020-09-16 14:20:52 -07:00 committed by TensorFlower Gardener
parent 0af5cc0c64
commit f96903abea
3 changed files with 40 additions and 87 deletions
tensorflow/python

View File

@ -105,7 +105,7 @@ EOF
# Test debugging of tf.keras, with non-debug runs included.
cat << EOF | ${DEBUG_KERAS_BIN} --debug --ui_type=readline --use_random_config_path
run -t 10
run -t 11
EOF
# Test offline_analyzer.

View File

@ -29,12 +29,12 @@ import numpy as np
import six
from six.moves import zip # pylint: disable=redefined-builtin
from tensorflow.core.framework import graph_pb2
from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor_utils
from tensorflow.python.framework import dtypes
@ -1567,60 +1567,12 @@ def is_eager_dataset_or_iterator(data):
# pylint: disable=protected-access
def assert_not_shuffled(dataset):
"""Asserts that `dataset` is not shuffled.
The algorithm used by this method is sound but not complete. In other words,
if the method fails to establish the assertion, it does not mean the dataset
is shuffled.
Example usage:
```python
try:
assert_not_shuffled(dataset)
# safe to assume `dataset` it not shuffled here
expect ValueError:
# make no assumptions about `dataset`
```
Args:
dataset: The dataset to analyze.
Raises:
ValueError: If the method cannot establish the assertion.
"""
if isinstance(dataset, dataset_ops.DatasetV1Adapter):
return assert_not_shuffled(dataset._dataset)
def get_dataset_graph_def(dataset):
if context.executing_eagerly():
graph_def_str = dataset._as_serialized_graph().numpy()
else:
allowed_types = [
dataset_ops._OptionsDataset,
dataset_ops.BatchDataset,
dataset_ops.ConcatenateDataset,
dataset_ops.CacheDataset,
dataset_ops.FilterDataset,
dataset_ops.MapDataset,
dataset_ops.PaddedBatchDataset,
dataset_ops.ParallelMapDataset,
dataset_ops.PrefetchDataset,
dataset_ops.RangeDataset,
dataset_ops.RepeatDataset,
dataset_ops.SkipDataset,
dataset_ops.SparseTensorSliceDataset,
dataset_ops.TakeDataset,
dataset_ops.TensorDataset,
dataset_ops.TensorSliceDataset,
dataset_ops.WindowDataset,
dataset_ops.ZipDataset,
readers.FixedLengthRecordDatasetV2,
readers.TextLineDatasetV2,
readers.TFRecordDatasetV2,
]
for ty in allowed_types:
if isinstance(dataset, ty):
for input_dataset in dataset._inputs():
assert_not_shuffled(input_dataset)
return
raise ValueError('Could not assert that dataset is not shuffled.')
graph_def_str = K.get_value(dataset._as_serialized_graph())
return graph_pb2.GraphDef().FromString(graph_def_str)
def verify_dataset_shuffled(x):
@ -1629,18 +1581,22 @@ def verify_dataset_shuffled(x):
Args:
x: Dataset passed as an input to the model.
Raises:
ValueError: if the dataset is not already shuffled.
Returns:
boolean, whether the input dataset is shuffled or not.
"""
assert isinstance(x, dataset_ops.DatasetV2)
try:
assert_not_shuffled(x)
except ValueError:
# Dataset may or may not be shuffled.
return
else:
logging.warning('Expected a shuffled dataset but input dataset `x` is '
'not shuffled. Please invoke `shuffle()` on input dataset.')
graph_def = get_dataset_graph_def(x)
for node in graph_def.node:
if node.op.startswith('ShuffleDataset'):
return True
# Also check graph_def.library.function for ds.interleave or ds.flat_map
for function in graph_def.library.function:
for node in function.node_def:
if node.op.startswith('ShuffleDataset'):
return True
logging.warning('Expected a shuffled dataset but input dataset `x` is '
'not shuffled. Please invoke `shuffle()` on input dataset.')
return False
def is_dataset_or_iterator(data):

View File

@ -140,7 +140,9 @@ class DatasetUtilsTest(test.TestCase, parameterized.TestCase):
('Concatenate', lambda: dataset_ops.Dataset.range(5).concatenate(
dataset_ops.Dataset.range(5))),
('FlatMap', lambda: dataset_ops.Dataset.range(5).flat_map(
lambda _: dataset_ops.Dataset.from_tensors(0)), ValueError),
lambda _: dataset_ops.Dataset.from_tensors(0))),
('FlatMap_Shuffle', lambda: dataset_ops.Dataset.range(5).flat_map(
lambda _: dataset_ops.Dataset.from_tensors(0).shuffle(1)), True),
('Filter', lambda: dataset_ops.Dataset.range(5).filter(lambda _: True)),
('FixedLengthRecordDatasetV2',
lambda: readers.FixedLengthRecordDatasetV2([], 42)),
@ -148,8 +150,10 @@ class DatasetUtilsTest(test.TestCase, parameterized.TestCase):
('FromTensorSlices',
lambda: dataset_ops.Dataset.from_tensor_slices([0, 0, 0])),
('Interleave', lambda: dataset_ops.Dataset.range(5).interleave(
lambda _: dataset_ops.Dataset.from_tensors(0), cycle_length=1),
ValueError),
lambda _: dataset_ops.Dataset.from_tensors(0), cycle_length=1)),
('Interleave_Shuffle', lambda: dataset_ops.Dataset.range(5).interleave(
lambda _: dataset_ops.Dataset.from_tensors(0).shuffle(1),
cycle_length=1), True),
('Map', lambda: dataset_ops.Dataset.range(5).map(lambda x: x)),
('Options',
lambda: dataset_ops.Dataset.range(5).with_options(dataset_ops.Options())
@ -158,13 +162,13 @@ class DatasetUtilsTest(test.TestCase, parameterized.TestCase):
('ParallelInterleave', lambda: dataset_ops.Dataset.range(5).interleave(
lambda _: dataset_ops.Dataset.from_tensors(0),
cycle_length=1,
num_parallel_calls=1), ValueError),
num_parallel_calls=1)),
('ParallelMap', lambda: dataset_ops.Dataset.range(5).map(
lambda x: x, num_parallel_calls=1)),
('Prefetch', lambda: dataset_ops.Dataset.range(5).prefetch(1)),
('Range', lambda: dataset_ops.Dataset.range(0)),
('Repeat', lambda: dataset_ops.Dataset.range(0).repeat(0)),
('Shuffle', lambda: dataset_ops.Dataset.range(5).shuffle(1), ValueError),
('Shuffle', lambda: dataset_ops.Dataset.range(5).shuffle(1), True),
('Skip', lambda: dataset_ops.Dataset.range(5).skip(2)),
('Take', lambda: dataset_ops.Dataset.range(5).take(2)),
('TextLineDataset', lambda: readers.TextLineDatasetV2([])),
@ -173,24 +177,17 @@ class DatasetUtilsTest(test.TestCase, parameterized.TestCase):
('Zip', lambda: dataset_ops.Dataset.zip(dataset_ops.Dataset.range(5))),
# pylint: enable=g-long-lambda
)
def test_assert_not_shuffled(self, dataset_fn, expected_error=None):
if expected_error is None:
training_utils.assert_not_shuffled(dataset_fn())
def test_verify_dataset_shuffled(self, dataset_fn, expect_shuffled=False):
dataset = dataset_fn()
if not expect_shuffled:
with test.mock.patch.object(logging, 'warning') as mock_log:
shuffled = training_utils.verify_dataset_shuffled(dataset)
self.assertRegex(
str(mock_log.call_args), 'input dataset `x` is not shuffled.')
self.assertFalse(shuffled)
else:
with self.assertRaises(expected_error):
training_utils.assert_not_shuffled(dataset_fn())
def test_verify_dataset_shuffled(self):
dataset = dataset_ops.Dataset.range(5)
training_utils.assert_not_shuffled(dataset)
with test.mock.patch.object(logging, 'warning') as mock_log:
training_utils.verify_dataset_shuffled(dataset)
self.assertRegex(
str(mock_log.call_args), 'input dataset `x` is not shuffled.')
shuffled_dataset = dataset.shuffle(10)
training_utils.verify_dataset_shuffled(shuffled_dataset)
self.assertTrue(training_utils.verify_dataset_shuffled(dataset))
class StandardizeWeightsTest(keras_parameterized.TestCase):