Give a clearer error message if the user passes a nested Dataset
Fixes #44160 PiperOrigin-RevId: 341076418 Change-Id: I6d0614a193a61ba39fbb880eaaccf50133984d21
This commit is contained in:
parent
173e9bd169
commit
e437569c7d
@ -30,6 +30,7 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import type_spec
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras.engine import training_generator_v1
|
||||
from tensorflow.python.keras.engine.base_layer import Layer
|
||||
@ -161,6 +162,12 @@ class CombinerPreprocessingLayer(PreprocessingLayer):
|
||||
'got {}'.format(type(data)))
|
||||
|
||||
if isinstance(data, dataset_ops.DatasetV2):
|
||||
# Validate that the dataset only contains single-tensor elements.
|
||||
if not isinstance(data.element_spec, type_spec.TypeSpec):
|
||||
raise TypeError(
|
||||
'The dataset should yield single-Tensor elements. Use `dataset.map`'
|
||||
'to select the element of interest.\n'
|
||||
'Got dataset.element_spec=' + str(data.element_spec))
|
||||
# Validate the datasets to try and ensure we haven't been passed one with
|
||||
# infinite size. That would cause an infinite loop here.
|
||||
if tf_utils.dataset_is_infinite(data):
|
||||
|
@ -248,6 +248,17 @@ class PreprocessingLayerTest(keras_parameterized.TestCase):
|
||||
|
||||
self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))
|
||||
|
||||
def test_adapt_dataset_of_tuples_fails(self):
|
||||
"""Test that preproc layers can adapt() before build() is called."""
|
||||
input_dataset = dataset_ops.Dataset.from_tensor_slices((
|
||||
np.array([[1], [2], [3], [4], [5], [0]]),
|
||||
np.array([[1], [2], [3], [4], [5], [0]])))
|
||||
|
||||
layer = get_layer()
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "single-Tensor elements"):
|
||||
layer.adapt(input_dataset)
|
||||
|
||||
def test_post_build_adapt_update_dataset(self):
|
||||
"""Test that preproc layers can adapt() after build() is called."""
|
||||
input_dataset = dataset_ops.Dataset.from_tensor_slices(
|
||||
|
@ -110,6 +110,7 @@ do_pylint() {
|
||||
"^tensorflow/python/keras/callbacks\.py.*\[E1133.*not-an-iterable "\
|
||||
"^tensorflow/python/keras/engine/base_layer.py.*\[E0203.*access-member-before-definition "\
|
||||
"^tensorflow/python/keras/engine/base_layer.py.*\[E1102.*not-callable "\
|
||||
"^tensorflow/python/keras/layers/preprocessing/.*\[E1102.*not-callable "\
|
||||
"^tensorflow/python/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition "\
|
||||
"^tensorflow/python/kernel_tests/constant_op_eager_test.py.*\[E0303.*invalid-length-returned "\
|
||||
"^tensorflow/python/keras/utils/data_utils.py.*\[E1102.*not-callable "\
|
||||
|
Loading…
Reference in New Issue
Block a user