Fix op wrapper code to correctly intercept CompositeTensors when doing op layer wrapping.

PiperOrigin-RevId: 313821947
Change-Id: I8e37ab0b38b1f861ca36ba3017e0068ed961e206
This commit is contained in:
A. Unique TensorFlower 2020-05-29 11:50:52 -07:00 committed by TensorFlower Gardener
parent 90a005a4e9
commit 0ccc0ed961
3 changed files with 42 additions and 12 deletions

View File

@ -25,6 +25,7 @@ from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
@ -214,18 +215,17 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers):
for tensor in tensor_list:
if getattr(tensor, '_keras_history', None) is not None:
continue
if sparse_tensor.is_sparse(tensor) or ragged_tensor.is_ragged(tensor):
example = """
weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
output = tf.keras.layers.Lambda(weights_mult)(input)
"""
raise ValueError('Tensorflow ops that generate ragged or sparse tensor '
'outputs are currently not supported by Keras automatic '
'op wrapping. Please wrap these ops in a Lambda layer: '
'\n\n```\n{example}\n```\n'.format(example=example))
op = tensor.op # The Op that created this Tensor.
if op not in processed_ops:
if op.type.startswith('Sparse'):
lambda_example = """
weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
output = tf.keras.layers.Lambda(weights_mult)(input)
"""
raise ValueError(
'Sparse ops are not supported with functional models with built-in '
'layer wrapping. Please wrap the sparse ops in a Lambda layer like'
': \n{lambda_example}\n'.format(lambda_example=lambda_example))
# Recursively set `_keras_history`.
op_inputs = list(op.inputs)
constants = {}

View File

@ -16,12 +16,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python import keras
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend
from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@ -66,5 +70,32 @@ class TrackableWeightHandlerTest(keras_parameterized.TestCase):
_ = backend.batch_get_value(table_handler.get_tensors())
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class OpLayerTest(keras_parameterized.TestCase):
def test_tensor_op_layer(self):
int_values = keras.Input(shape=(2,), dtype=dtypes.int32)
float_values = math_ops.cast(int_values, dtypes.float32)
model = keras.Model(int_values, float_values)
model.compile(loss='mse')
input_data = np.array([[1, 2], [3, 4]], dtype=np.int32)
expected = [[1.0, 2.0], [3.0, 4.0]]
output = model.predict(input_data)
self.assertAllClose(expected, output)
def test_ragged_op_layer(self):
with self.assertRaisesRegexp(ValueError, 'Keras automatic op wrapping'):
int_values = keras.Input(shape=(None,), dtype=dtypes.int32, ragged=True)
float_values = math_ops.cast(int_values, dtypes.float32)
_ = keras.Model(int_values, float_values)
def test_sparse_op_layer(self):
with self.assertRaisesRegexp(ValueError, 'Keras automatic op wrapping'):
int_values = keras.Input(shape=(None,), dtype=dtypes.int32, sparse=True)
float_values = math_ops.cast(int_values, dtypes.float32)
_ = keras.Model(int_values, float_values)
if __name__ == '__main__':
test.main()

View File

@ -1531,8 +1531,7 @@ class TestExceptionsAndWarnings(keras_parameterized.TestCase):
output = sparse_ops.sparse_minimum(inputs, inputs)
with self.assertRaisesRegexp(
ValueError,
'Sparse ops are not supported with functional models with built-in '
'layer wrapping'
'not supported by Keras automatic op wrapping'
):
training_module.Model([inputs], output)