Make the sparse ops with Keras functional models error more verbose to aid debugging.
PiperOrigin-RevId: 321829617 Change-Id: Ie2e0c131ca7632b37eb16aaeccbcd4894ee6bbd4
This commit is contained in:
parent
7bf203f9d6
commit
a1bcafd8d8
tensorflow/python/keras/engine
@ -211,18 +211,18 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers):
|
||||
# TODO(omalleyt): Resolve circular dependency.
|
||||
from tensorflow.python.keras.engine import base_layer # pylint: disable=g-import-not-at-top
|
||||
tensor_list = nest.flatten(tensors)
|
||||
sparse_ops = []
|
||||
ragged_tensors = []
|
||||
for tensor in tensor_list:
|
||||
if getattr(tensor, '_keras_history', None) is not None:
|
||||
continue
|
||||
if sparse_tensor.is_sparse(tensor) or ragged_tensor.is_ragged(tensor):
|
||||
example = """
|
||||
weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
|
||||
output = tf.keras.layers.Lambda(weights_mult)(input)
|
||||
"""
|
||||
raise ValueError('Tensorflow ops that generate ragged or sparse tensor '
|
||||
'outputs are currently not supported by Keras automatic '
|
||||
'op wrapping. Please wrap these ops in a Lambda layer: '
|
||||
'\n\n```\n{example}\n```\n'.format(example=example))
|
||||
if sparse_tensor.is_sparse(tensor):
|
||||
sparse_ops.append(tensor.op)
|
||||
continue
|
||||
if ragged_tensor.is_ragged(tensor):
|
||||
# Ragged tensors don't have an op property
|
||||
ragged_tensors.append(tensor)
|
||||
continue
|
||||
op = tensor.op # The Op that created this Tensor.
|
||||
if op not in processed_ops:
|
||||
# Recursively set `_keras_history`.
|
||||
@ -264,6 +264,21 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers):
|
||||
kwargs={},
|
||||
outputs=op.outputs)
|
||||
processed_ops.update([op])
|
||||
if sparse_ops or ragged_tensors:
|
||||
lambda_example = """
|
||||
weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
|
||||
output = tf.keras.layers.Lambda(weights_mult)(input)
|
||||
"""
|
||||
raise ValueError(
|
||||
'Tensorflow ops that generate ragged or sparse tensor '
|
||||
'outputs are currently not supported by Keras automatic '
|
||||
'op wrapping. Please wrap these ops in a Lambda layer: '
|
||||
'\n\n```\n{example}\n```\n'
|
||||
'Sparse ops encountered: {sparse_ops}\n'
|
||||
'Ragged tensors encountered: {ragged_tensors}\n'.format(
|
||||
example=lambda_example,
|
||||
sparse_ops=str(sparse_ops),
|
||||
ragged_tensors=str(ragged_tensors)))
|
||||
return processed_ops, created_layers
|
||||
|
||||
|
||||
|
@ -91,14 +91,21 @@ class OpLayerTest(keras_parameterized.TestCase):
|
||||
|
||||
def test_ragged_op_layer(self):
|
||||
with testing_utils.use_keras_tensors_scope(False):
|
||||
with self.assertRaisesRegex(ValueError, 'Keras automatic op wrapping'):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, '(?ms)Keras automatic op wrapping'
|
||||
'.*Ragged tensors encountered: '
|
||||
r'\[tf.RaggedTensor\(values=Tensor\("Cast:0", shape=\((\?|None),\), '
|
||||
r'dtype=float32\), row_splits=Tensor\("Placeholder_1:0", '
|
||||
r'shape=\((\?|None),\), dtype=int64\)\)\]'):
|
||||
int_values = keras.Input(shape=(None,), dtype=dtypes.int32, ragged=True)
|
||||
float_values = math_ops.cast(int_values, dtypes.float32)
|
||||
_ = keras.Model(int_values, float_values)
|
||||
|
||||
def test_sparse_op_layer(self):
|
||||
with testing_utils.use_keras_tensors_scope(False):
|
||||
with self.assertRaisesRegex(ValueError, 'Keras automatic op wrapping'):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "(?ms)Keras automatic op wrapping"
|
||||
r".*Sparse ops encountered: \[\<tf\.Operation 'Cast' type=Cast\>\]"):
|
||||
int_values = keras.Input(shape=(None,), dtype=dtypes.int32, sparse=True)
|
||||
float_values = math_ops.cast(int_values, dtypes.float32)
|
||||
_ = keras.Model(int_values, float_values)
|
||||
|
Loading…
Reference in New Issue
Block a user