Give a clearer error message if the user passes a nested Dataset

Fixes #44160

PiperOrigin-RevId: 341076418
Change-Id: I6d0614a193a61ba39fbb880eaaccf50133984d21
This commit is contained in:
Mark Daoust 2020-11-06 10:41:02 -08:00 committed by TensorFlower Gardener
parent 173e9bd169
commit e437569c7d
3 changed files with 19 additions and 0 deletions

View File

@ -30,6 +30,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import type_spec
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import training_generator_v1
from tensorflow.python.keras.engine.base_layer import Layer
@ -161,6 +162,12 @@ class CombinerPreprocessingLayer(PreprocessingLayer):
'got {}'.format(type(data)))
if isinstance(data, dataset_ops.DatasetV2):
# Validate that the dataset only contains single-tensor elements.
if not isinstance(data.element_spec, type_spec.TypeSpec):
raise TypeError(
'The dataset should yield single-Tensor elements. Use `dataset.map`'
'to select the element of interest.\n'
'Got dataset.element_spec=' + str(data.element_spec))
# Validate the datasets to try and ensure we haven't been passed one with
# infinite size. That would cause an infinite loop here.
if tf_utils.dataset_is_infinite(data):

View File

@ -248,6 +248,17 @@ class PreprocessingLayerTest(keras_parameterized.TestCase):
self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))
def test_adapt_dataset_of_tuples_fails(self):
"""Test that preproc layers can adapt() before build() is called."""
input_dataset = dataset_ops.Dataset.from_tensor_slices((
np.array([[1], [2], [3], [4], [5], [0]]),
np.array([[1], [2], [3], [4], [5], [0]])))
layer = get_layer()
with self.assertRaisesRegex(TypeError, "single-Tensor elements"):
layer.adapt(input_dataset)
def test_post_build_adapt_update_dataset(self):
"""Test that preproc layers can adapt() after build() is called."""
input_dataset = dataset_ops.Dataset.from_tensor_slices(

View File

@ -110,6 +110,7 @@ do_pylint() {
"^tensorflow/python/keras/callbacks\.py.*\[E1133.*not-an-iterable "\
"^tensorflow/python/keras/engine/base_layer.py.*\[E0203.*access-member-before-definition "\
"^tensorflow/python/keras/engine/base_layer.py.*\[E1102.*not-callable "\
"^tensorflow/python/keras/layers/preprocessing/.*\[E1102.*not-callable "\
"^tensorflow/python/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition "\
"^tensorflow/python/kernel_tests/constant_op_eager_test.py.*\[E0303.*invalid-length-returned "\
"^tensorflow/python/keras/utils/data_utils.py.*\[E1102.*not-callable "\