Give a clearer error message if the user passes a nested Dataset
Fixes #44160 PiperOrigin-RevId: 341076418 Change-Id: I6d0614a193a61ba39fbb880eaaccf50133984d21
This commit is contained in:
parent
173e9bd169
commit
e437569c7d
@ -30,6 +30,7 @@ from tensorflow.python.framework import dtypes
|
|||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
|
from tensorflow.python.framework import type_spec
|
||||||
from tensorflow.python.keras import backend as K
|
from tensorflow.python.keras import backend as K
|
||||||
from tensorflow.python.keras.engine import training_generator_v1
|
from tensorflow.python.keras.engine import training_generator_v1
|
||||||
from tensorflow.python.keras.engine.base_layer import Layer
|
from tensorflow.python.keras.engine.base_layer import Layer
|
||||||
@ -161,6 +162,12 @@ class CombinerPreprocessingLayer(PreprocessingLayer):
|
|||||||
'got {}'.format(type(data)))
|
'got {}'.format(type(data)))
|
||||||
|
|
||||||
if isinstance(data, dataset_ops.DatasetV2):
|
if isinstance(data, dataset_ops.DatasetV2):
|
||||||
|
# Validate that the dataset only contains single-tensor elements.
|
||||||
|
if not isinstance(data.element_spec, type_spec.TypeSpec):
|
||||||
|
raise TypeError(
|
||||||
|
'The dataset should yield single-Tensor elements. Use `dataset.map`'
|
||||||
|
'to select the element of interest.\n'
|
||||||
|
'Got dataset.element_spec=' + str(data.element_spec))
|
||||||
# Validate the datasets to try and ensure we haven't been passed one with
|
# Validate the datasets to try and ensure we haven't been passed one with
|
||||||
# infinite size. That would cause an infinite loop here.
|
# infinite size. That would cause an infinite loop here.
|
||||||
if tf_utils.dataset_is_infinite(data):
|
if tf_utils.dataset_is_infinite(data):
|
||||||
|
@ -248,6 +248,17 @@ class PreprocessingLayerTest(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))
|
self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))
|
||||||
|
|
||||||
|
def test_adapt_dataset_of_tuples_fails(self):
|
||||||
|
"""Test that preproc layers can adapt() before build() is called."""
|
||||||
|
input_dataset = dataset_ops.Dataset.from_tensor_slices((
|
||||||
|
np.array([[1], [2], [3], [4], [5], [0]]),
|
||||||
|
np.array([[1], [2], [3], [4], [5], [0]])))
|
||||||
|
|
||||||
|
layer = get_layer()
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(TypeError, "single-Tensor elements"):
|
||||||
|
layer.adapt(input_dataset)
|
||||||
|
|
||||||
def test_post_build_adapt_update_dataset(self):
|
def test_post_build_adapt_update_dataset(self):
|
||||||
"""Test that preproc layers can adapt() after build() is called."""
|
"""Test that preproc layers can adapt() after build() is called."""
|
||||||
input_dataset = dataset_ops.Dataset.from_tensor_slices(
|
input_dataset = dataset_ops.Dataset.from_tensor_slices(
|
||||||
|
@ -110,6 +110,7 @@ do_pylint() {
|
|||||||
"^tensorflow/python/keras/callbacks\.py.*\[E1133.*not-an-iterable "\
|
"^tensorflow/python/keras/callbacks\.py.*\[E1133.*not-an-iterable "\
|
||||||
"^tensorflow/python/keras/engine/base_layer.py.*\[E0203.*access-member-before-definition "\
|
"^tensorflow/python/keras/engine/base_layer.py.*\[E0203.*access-member-before-definition "\
|
||||||
"^tensorflow/python/keras/engine/base_layer.py.*\[E1102.*not-callable "\
|
"^tensorflow/python/keras/engine/base_layer.py.*\[E1102.*not-callable "\
|
||||||
|
"^tensorflow/python/keras/layers/preprocessing/.*\[E1102.*not-callable "\
|
||||||
"^tensorflow/python/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition "\
|
"^tensorflow/python/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition "\
|
||||||
"^tensorflow/python/kernel_tests/constant_op_eager_test.py.*\[E0303.*invalid-length-returned "\
|
"^tensorflow/python/kernel_tests/constant_op_eager_test.py.*\[E0303.*invalid-length-returned "\
|
||||||
"^tensorflow/python/keras/utils/data_utils.py.*\[E1102.*not-callable "\
|
"^tensorflow/python/keras/utils/data_utils.py.*\[E1102.*not-callable "\
|
||||||
|
Loading…
Reference in New Issue
Block a user