Update Keras training logic for detecting whether an input dataset is shuffled.
The previous implementation relied on instance type checking, a heuristic that is not fully accurate. The new logic parses the dataset's graph_def, which should determine deterministically whether the dataset contains a shuffle op. This also allows Keras to stop using private tf.data API symbols. PiperOrigin-RevId: 332087112 Change-Id: I6db0cfbe3ac00ace2d736b6a2b0a15cf04981185
This commit is contained in:
parent
0af5cc0c64
commit
f96903abea
tensorflow/python
@ -105,7 +105,7 @@ EOF
|
||||
|
||||
# Test debugging of tf.keras, with non-debug runs included.
|
||||
cat << EOF | ${DEBUG_KERAS_BIN} --debug --ui_type=readline --use_random_config_path
|
||||
run -t 10
|
||||
run -t 11
|
||||
EOF
|
||||
|
||||
# Test offline_analyzer.
|
||||
|
@ -29,12 +29,12 @@ import numpy as np
|
||||
import six
|
||||
from six.moves import zip # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.data.experimental.ops import cardinality
|
||||
from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.data.ops import readers
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import composite_tensor_utils
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -1567,60 +1567,12 @@ def is_eager_dataset_or_iterator(data):
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def assert_not_shuffled(dataset):
|
||||
"""Asserts that `dataset` is not shuffled.
|
||||
|
||||
The algorithm used by this method is sound but not complete. In other words,
|
||||
if the method fails to establish the assertion, it does not mean the dataset
|
||||
is shuffled.
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
try:
|
||||
assert_not_shuffled(dataset)
|
||||
# safe to assume `dataset` is not shuffled here
|
||||
except ValueError:
|
||||
# make no assumptions about `dataset`
|
||||
```
|
||||
|
||||
Args:
|
||||
dataset: The dataset to analyze.
|
||||
|
||||
Raises:
|
||||
ValueError: If the method cannot establish the assertion.
|
||||
"""
|
||||
if isinstance(dataset, dataset_ops.DatasetV1Adapter):
|
||||
return assert_not_shuffled(dataset._dataset)
|
||||
def get_dataset_graph_def(dataset):
|
||||
if context.executing_eagerly():
|
||||
graph_def_str = dataset._as_serialized_graph().numpy()
|
||||
else:
|
||||
allowed_types = [
|
||||
dataset_ops._OptionsDataset,
|
||||
dataset_ops.BatchDataset,
|
||||
dataset_ops.ConcatenateDataset,
|
||||
dataset_ops.CacheDataset,
|
||||
dataset_ops.FilterDataset,
|
||||
dataset_ops.MapDataset,
|
||||
dataset_ops.PaddedBatchDataset,
|
||||
dataset_ops.ParallelMapDataset,
|
||||
dataset_ops.PrefetchDataset,
|
||||
dataset_ops.RangeDataset,
|
||||
dataset_ops.RepeatDataset,
|
||||
dataset_ops.SkipDataset,
|
||||
dataset_ops.SparseTensorSliceDataset,
|
||||
dataset_ops.TakeDataset,
|
||||
dataset_ops.TensorDataset,
|
||||
dataset_ops.TensorSliceDataset,
|
||||
dataset_ops.WindowDataset,
|
||||
dataset_ops.ZipDataset,
|
||||
readers.FixedLengthRecordDatasetV2,
|
||||
readers.TextLineDatasetV2,
|
||||
readers.TFRecordDatasetV2,
|
||||
]
|
||||
for ty in allowed_types:
|
||||
if isinstance(dataset, ty):
|
||||
for input_dataset in dataset._inputs():
|
||||
assert_not_shuffled(input_dataset)
|
||||
return
|
||||
raise ValueError('Could not assert that dataset is not shuffled.')
|
||||
graph_def_str = K.get_value(dataset._as_serialized_graph())
|
||||
return graph_pb2.GraphDef().FromString(graph_def_str)
|
||||
|
||||
|
||||
def verify_dataset_shuffled(x):
|
||||
@ -1629,18 +1581,22 @@ def verify_dataset_shuffled(x):
|
||||
Args:
|
||||
x: Dataset passed as an input to the model.
|
||||
|
||||
Raises:
|
||||
ValueError: if the dataset is not already shuffled.
|
||||
Returns:
|
||||
boolean, whether the input dataset is shuffled or not.
|
||||
"""
|
||||
assert isinstance(x, dataset_ops.DatasetV2)
|
||||
try:
|
||||
assert_not_shuffled(x)
|
||||
except ValueError:
|
||||
# Dataset may or may not be shuffled.
|
||||
return
|
||||
else:
|
||||
logging.warning('Expected a shuffled dataset but input dataset `x` is '
|
||||
'not shuffled. Please invoke `shuffle()` on input dataset.')
|
||||
graph_def = get_dataset_graph_def(x)
|
||||
for node in graph_def.node:
|
||||
if node.op.startswith('ShuffleDataset'):
|
||||
return True
|
||||
# Also check graph_def.library.function for ds.interleave or ds.flat_map
|
||||
for function in graph_def.library.function:
|
||||
for node in function.node_def:
|
||||
if node.op.startswith('ShuffleDataset'):
|
||||
return True
|
||||
logging.warning('Expected a shuffled dataset but input dataset `x` is '
|
||||
'not shuffled. Please invoke `shuffle()` on input dataset.')
|
||||
return False
|
||||
|
||||
|
||||
def is_dataset_or_iterator(data):
|
||||
|
@ -140,7 +140,9 @@ class DatasetUtilsTest(test.TestCase, parameterized.TestCase):
|
||||
('Concatenate', lambda: dataset_ops.Dataset.range(5).concatenate(
|
||||
dataset_ops.Dataset.range(5))),
|
||||
('FlatMap', lambda: dataset_ops.Dataset.range(5).flat_map(
|
||||
lambda _: dataset_ops.Dataset.from_tensors(0)), ValueError),
|
||||
lambda _: dataset_ops.Dataset.from_tensors(0))),
|
||||
('FlatMap_Shuffle', lambda: dataset_ops.Dataset.range(5).flat_map(
|
||||
lambda _: dataset_ops.Dataset.from_tensors(0).shuffle(1)), True),
|
||||
('Filter', lambda: dataset_ops.Dataset.range(5).filter(lambda _: True)),
|
||||
('FixedLengthRecordDatasetV2',
|
||||
lambda: readers.FixedLengthRecordDatasetV2([], 42)),
|
||||
@ -148,8 +150,10 @@ class DatasetUtilsTest(test.TestCase, parameterized.TestCase):
|
||||
('FromTensorSlices',
|
||||
lambda: dataset_ops.Dataset.from_tensor_slices([0, 0, 0])),
|
||||
('Interleave', lambda: dataset_ops.Dataset.range(5).interleave(
|
||||
lambda _: dataset_ops.Dataset.from_tensors(0), cycle_length=1),
|
||||
ValueError),
|
||||
lambda _: dataset_ops.Dataset.from_tensors(0), cycle_length=1)),
|
||||
('Interleave_Shuffle', lambda: dataset_ops.Dataset.range(5).interleave(
|
||||
lambda _: dataset_ops.Dataset.from_tensors(0).shuffle(1),
|
||||
cycle_length=1), True),
|
||||
('Map', lambda: dataset_ops.Dataset.range(5).map(lambda x: x)),
|
||||
('Options',
|
||||
lambda: dataset_ops.Dataset.range(5).with_options(dataset_ops.Options())
|
||||
@ -158,13 +162,13 @@ class DatasetUtilsTest(test.TestCase, parameterized.TestCase):
|
||||
('ParallelInterleave', lambda: dataset_ops.Dataset.range(5).interleave(
|
||||
lambda _: dataset_ops.Dataset.from_tensors(0),
|
||||
cycle_length=1,
|
||||
num_parallel_calls=1), ValueError),
|
||||
num_parallel_calls=1)),
|
||||
('ParallelMap', lambda: dataset_ops.Dataset.range(5).map(
|
||||
lambda x: x, num_parallel_calls=1)),
|
||||
('Prefetch', lambda: dataset_ops.Dataset.range(5).prefetch(1)),
|
||||
('Range', lambda: dataset_ops.Dataset.range(0)),
|
||||
('Repeat', lambda: dataset_ops.Dataset.range(0).repeat(0)),
|
||||
('Shuffle', lambda: dataset_ops.Dataset.range(5).shuffle(1), ValueError),
|
||||
('Shuffle', lambda: dataset_ops.Dataset.range(5).shuffle(1), True),
|
||||
('Skip', lambda: dataset_ops.Dataset.range(5).skip(2)),
|
||||
('Take', lambda: dataset_ops.Dataset.range(5).take(2)),
|
||||
('TextLineDataset', lambda: readers.TextLineDatasetV2([])),
|
||||
@ -173,24 +177,17 @@ class DatasetUtilsTest(test.TestCase, parameterized.TestCase):
|
||||
('Zip', lambda: dataset_ops.Dataset.zip(dataset_ops.Dataset.range(5))),
|
||||
# pylint: enable=g-long-lambda
|
||||
)
|
||||
def test_assert_not_shuffled(self, dataset_fn, expected_error=None):
|
||||
if expected_error is None:
|
||||
training_utils.assert_not_shuffled(dataset_fn())
|
||||
def test_verify_dataset_shuffled(self, dataset_fn, expect_shuffled=False):
|
||||
dataset = dataset_fn()
|
||||
|
||||
if not expect_shuffled:
|
||||
with test.mock.patch.object(logging, 'warning') as mock_log:
|
||||
shuffled = training_utils.verify_dataset_shuffled(dataset)
|
||||
self.assertRegex(
|
||||
str(mock_log.call_args), 'input dataset `x` is not shuffled.')
|
||||
self.assertFalse(shuffled)
|
||||
else:
|
||||
with self.assertRaises(expected_error):
|
||||
training_utils.assert_not_shuffled(dataset_fn())
|
||||
|
||||
def test_verify_dataset_shuffled(self):
|
||||
dataset = dataset_ops.Dataset.range(5)
|
||||
training_utils.assert_not_shuffled(dataset)
|
||||
|
||||
with test.mock.patch.object(logging, 'warning') as mock_log:
|
||||
training_utils.verify_dataset_shuffled(dataset)
|
||||
self.assertRegex(
|
||||
str(mock_log.call_args), 'input dataset `x` is not shuffled.')
|
||||
|
||||
shuffled_dataset = dataset.shuffle(10)
|
||||
training_utils.verify_dataset_shuffled(shuffled_dataset)
|
||||
self.assertTrue(training_utils.verify_dataset_shuffled(dataset))
|
||||
|
||||
|
||||
class StandardizeWeightsTest(keras_parameterized.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user