diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py index 6d25995e4c2..d7bd3d5d372 100644 --- a/tensorflow/python/keras/engine/base_layer_utils.py +++ b/tensorflow/python/keras/engine/base_layer_utils.py @@ -25,6 +25,7 @@ from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend @@ -214,18 +215,17 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers): for tensor in tensor_list: if getattr(tensor, '_keras_history', None) is not None: continue + if sparse_tensor.is_sparse(tensor) or ragged_tensor.is_ragged(tensor): + example = """ + weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights) + output = tf.keras.layers.Lambda(weights_mult)(input) + """ + raise ValueError('Tensorflow ops that generate ragged or sparse tensor ' + 'outputs are currently not supported by Keras automatic ' + 'op wrapping. Please wrap these ops in a Lambda layer: ' + '\n\n```\n{example}\n```\n'.format(example=example)) op = tensor.op # The Op that created this Tensor. if op not in processed_ops: - if op.type.startswith('Sparse'): - lambda_example = """ - weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights) - output = tf.keras.layers.Lambda(weights_mult)(input) - """ - raise ValueError( - 'Sparse ops are not supported with functional models with built-in ' - 'layer wrapping. Please wrap the sparse ops in a Lambda layer like' - ': \n{lambda_example}\n'.format(lambda_example=lambda_example)) - # Recursively set `_keras_history`. op_inputs = list(op.inputs) constants = {} diff --git a/tensorflow/python/keras/engine/base_layer_utils_test.py b/tensorflow/python/keras/engine/base_layer_utils_test.py index 21f539d89c5..d27a4cd3297 100644 --- a/tensorflow/python/keras/engine/base_layer_utils_test.py +++ b/tensorflow/python/keras/engine/base_layer_utils_test.py @@ -16,12 +16,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + +from tensorflow.python import keras from tensorflow.python.framework import dtypes from tensorflow.python.keras import backend from tensorflow.python.keras import combinations from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -66,5 +70,32 @@ class TrackableWeightHandlerTest(keras_parameterized.TestCase): _ = backend.batch_get_value(table_handler.get_tensors()) +@combinations.generate(combinations.combine(mode=['graph', 'eager'])) +class OpLayerTest(keras_parameterized.TestCase): + + def test_tensor_op_layer(self): + int_values = keras.Input(shape=(2,), dtype=dtypes.int32) + float_values = math_ops.cast(int_values, dtypes.float32) + model = keras.Model(int_values, float_values) + model.compile(loss='mse') + + input_data = np.array([[1, 2], [3, 4]], dtype=np.int32) + expected = [[1.0, 2.0], [3.0, 4.0]] + output = model.predict(input_data) + self.assertAllClose(expected, output) + + def test_ragged_op_layer(self): + with self.assertRaisesRegexp(ValueError, 'Keras automatic op wrapping'): + int_values = keras.Input(shape=(None,), dtype=dtypes.int32, ragged=True) + float_values = math_ops.cast(int_values, dtypes.float32) + _ = keras.Model(int_values, float_values) + + def test_sparse_op_layer(self): + with self.assertRaisesRegexp(ValueError, 'Keras automatic op wrapping'): + int_values = keras.Input(shape=(None,), dtype=dtypes.int32, sparse=True) + float_values = math_ops.cast(int_values, dtypes.float32) + _ = keras.Model(int_values, float_values) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index c1c498b207b..e1180b5234b 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -1531,8 +1531,7 @@ class TestExceptionsAndWarnings(keras_parameterized.TestCase): output = sparse_ops.sparse_minimum(inputs, inputs) with self.assertRaisesRegexp( ValueError, - 'Sparse ops are not supported with functional models with built-in ' - 'layer wrapping' + 'not supported by Keras automatic op wrapping' ): training_module.Model([inputs], output)