Make the error for sparse ops with Keras functional models more verbose to aid debugging.

PiperOrigin-RevId: 321829617
Change-Id: Ie2e0c131ca7632b37eb16aaeccbcd4894ee6bbd4
R. Alex Hofer 2020-07-17 12:19:08 -07:00 committed by TensorFlower Gardener
parent 7bf203f9d6
commit a1bcafd8d8
2 changed files with 33 additions and 11 deletions
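
For context, a minimal sketch (not part of this commit) of the kind of functional-model code that hits this error path, assuming the legacy non-KerasTensor tracing that the updated tests pin with use_keras_tensors_scope(False): a raw TF op applied to a sparse Keras input cannot be auto-wrapped into a layer, so building the Model raises the ValueError this change makes more verbose.

import tensorflow as tf

# Hypothetical repro mirroring the updated tests below; names and shapes are
# illustrative. `tf.cast` on the sparse functional-API input is a raw op that
# Keras cannot auto-wrap, so Model construction is expected to raise the
# op-wrapping ValueError (which now lists the offending sparse ops).
int_values = tf.keras.Input(shape=(None,), dtype=tf.int32, sparse=True)
float_values = tf.cast(int_values, tf.float32)
model = tf.keras.Model(int_values, float_values)  # expected to raise ValueError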
tensorflow/python/keras/engine


@@ -211,18 +211,18 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers):
   # TODO(omalleyt): Resolve circular dependency.
   from tensorflow.python.keras.engine import base_layer  # pylint: disable=g-import-not-at-top
   tensor_list = nest.flatten(tensors)
+  sparse_ops = []
+  ragged_tensors = []
   for tensor in tensor_list:
     if getattr(tensor, '_keras_history', None) is not None:
       continue
-    if sparse_tensor.is_sparse(tensor) or ragged_tensor.is_ragged(tensor):
-      example = """
-      weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
-      output = tf.keras.layers.Lambda(weights_mult)(input)
-      """
-      raise ValueError('Tensorflow ops that generate ragged or sparse tensor '
-                       'outputs are currently not supported by Keras automatic '
-                       'op wrapping. Please wrap these ops in a Lambda layer: '
-                       '\n\n```\n{example}\n```\n'.format(example=example))
+    if sparse_tensor.is_sparse(tensor):
+      sparse_ops.append(tensor.op)
+      continue
+    if ragged_tensor.is_ragged(tensor):
+      # Ragged tensors don't have an op property
+      ragged_tensors.append(tensor)
+      continue
     op = tensor.op  # The Op that created this Tensor.
     if op not in processed_ops:
       # Recursively set `_keras_history`.
@@ -264,6 +264,21 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers):
           kwargs={},
           outputs=op.outputs)
       processed_ops.update([op])
+  if sparse_ops or ragged_tensors:
+    lambda_example = """
+    weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
+    output = tf.keras.layers.Lambda(weights_mult)(input)
+    """
+    raise ValueError(
+        'Tensorflow ops that generate ragged or sparse tensor '
+        'outputs are currently not supported by Keras automatic '
+        'op wrapping. Please wrap these ops in a Lambda layer: '
+        '\n\n```\n{example}\n```\n'
+        'Sparse ops encountered: {sparse_ops}\n'
+        'Ragged tensors encountered: {ragged_tensors}\n'.format(
+            example=lambda_example,
+            sparse_ops=str(sparse_ops),
+            ragged_tensors=str(ragged_tensors)))
   return processed_ops, created_layers
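
With this change the check runs once after the loop and names the offending ops instead of failing on the first one. As a reference point, here is a minimal sketch of the Lambda-layer workaround the message recommends; `weights` and the shapes are illustrative assumptions, not part of this commit.

import tensorflow as tf

# Hypothetical shapes: a [None, 16] sparse input multiplied by a [16, 4]
# dense matrix. Wrapping the sparse op in a Lambda layer means Keras never
# has to auto-wrap the raw sparse op itself.
weights = tf.random.normal([16, 4])
sparse_in = tf.keras.Input(shape=(16,), sparse=True)
weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
output = tf.keras.layers.Lambda(weights_mult)(sparse_in)
model = tf.keras.Model(sparse_in, output)  # builds without the op-wrapping error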


@@ -91,14 +91,21 @@ class OpLayerTest(keras_parameterized.TestCase):
   def test_ragged_op_layer(self):
     with testing_utils.use_keras_tensors_scope(False):
-      with self.assertRaisesRegex(ValueError, 'Keras automatic op wrapping'):
+      with self.assertRaisesRegex(
+          ValueError, '(?ms)Keras automatic op wrapping'
+          '.*Ragged tensors encountered: '
+          r'\[tf.RaggedTensor\(values=Tensor\("Cast:0", shape=\((\?|None),\), '
+          r'dtype=float32\), row_splits=Tensor\("Placeholder_1:0", '
+          r'shape=\((\?|None),\), dtype=int64\)\)\]'):
         int_values = keras.Input(shape=(None,), dtype=dtypes.int32, ragged=True)
         float_values = math_ops.cast(int_values, dtypes.float32)
         _ = keras.Model(int_values, float_values)
 
   def test_sparse_op_layer(self):
     with testing_utils.use_keras_tensors_scope(False):
-      with self.assertRaisesRegex(ValueError, 'Keras automatic op wrapping'):
+      with self.assertRaisesRegex(
+          ValueError, "(?ms)Keras automatic op wrapping"
+          r".*Sparse ops encountered: \[\<tf\.Operation 'Cast' type=Cast\>\]"):
         int_values = keras.Input(shape=(None,), dtype=dtypes.int32, sparse=True)
         float_values = math_ops.cast(int_values, dtypes.float32)
         _ = keras.Model(int_values, float_values)
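
A note on the new test patterns: assertRaisesRegex uses re.search, and the more verbose message spans several lines, so the patterns are prefixed with (?ms) so that '.*' can cross newlines to reach the trailing 'Sparse ops encountered' / 'Ragged tensors encountered' lines. A small self-contained illustration of that flag usage (not TensorFlow-specific; the message text here is made up):

import unittest

class DotallFlagExample(unittest.TestCase):

  def test_pattern_spans_lines(self):
    # Without the inline (?s) flag, '.' stops at newlines and the pattern
    # below would fail to match this multi-line message.
    with self.assertRaisesRegex(
        ValueError, '(?ms)Keras automatic op wrapping'
        '.*Sparse ops encountered'):
      raise ValueError('... not supported by Keras automatic op wrapping ...\n'
                       'Sparse ops encountered: [...]')

if __name__ == '__main__':
  unittest.main()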