Fix op wrapper code to correctly intercept CompositeTensors when doing op layer wrapping.
PiperOrigin-RevId: 313821947
Change-Id: I8e37ab0b38b1f861ca36ba3017e0068ed961e206
parent 90a005a4e9
commit 0ccc0ed961
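For context, "op layer wrapping" is the Keras functional-API behavior in which a raw TensorFlow op applied to a symbolic Keras input is automatically wrapped into an op layer when the Model is constructed. Below is a minimal sketch of the dense-tensor path that the new test exercises, written against the public tf.* API rather than the internal modules used in the diff; the shapes and values are illustrative only.

import numpy as np
import tensorflow as tf

# A raw TF op (tf.cast) applied to a symbolic Keras input is wrapped into an
# op layer automatically when the functional Model is built.
int_values = tf.keras.Input(shape=(2,), dtype=tf.int32)
float_values = tf.cast(int_values, tf.float32)  # plain op, not a Keras layer call
model = tf.keras.Model(int_values, float_values)

print(model.predict(np.array([[1, 2], [3, 4]], dtype=np.int32)))
# expected output: [[1. 2.], [3. 4.]]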
@@ -25,6 +25,7 @@ from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.keras import backend
@@ -214,18 +215,17 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers):
   for tensor in tensor_list:
     if getattr(tensor, '_keras_history', None) is not None:
       continue
+    if sparse_tensor.is_sparse(tensor) or ragged_tensor.is_ragged(tensor):
+      example = """
+      weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
+      output = tf.keras.layers.Lambda(weights_mult)(input)
+      """
+      raise ValueError('Tensorflow ops that generate ragged or sparse tensor '
+                       'outputs are currently not supported by Keras automatic '
+                       'op wrapping. Please wrap these ops in a Lambda layer: '
+                       '\n\n```\n{example}\n```\n'.format(example=example))
     op = tensor.op  # The Op that created this Tensor.
     if op not in processed_ops:
-      if op.type.startswith('Sparse'):
-        lambda_example = """
-        weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
-        output = tf.keras.layers.Lambda(weights_mult)(input)
-        """
-        raise ValueError(
-            'Sparse ops are not supported with functional models with built-in '
-            'layer wrapping. Please wrap the sparse ops in a Lambda layer like'
-            ': \n{lambda_example}\n'.format(lambda_example=lambda_example))
-
       # Recursively set `_keras_history`.
       op_inputs = list(op.inputs)
       constants = {}
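For reference, the workaround the new error message points to is to perform the sparse (or ragged) op inside an explicit Lambda layer instead of relying on automatic op wrapping. A hedged sketch of that pattern follows; the weights matrix, the shapes, and the public tf.* spellings are illustrative assumptions, not part of this commit.

import tensorflow as tf

# Illustrative dense weight matrix captured by the Lambda below (assumed shape).
weights = tf.random.normal((4, 3))

sparse_input = tf.keras.Input(shape=(4,), sparse=True)
# Wrapping the sparse op in a Lambda layer bypasses automatic op wrapping,
# which now raises a ValueError for ops that produce sparse/ragged outputs.
weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
output = tf.keras.layers.Lambda(weights_mult)(sparse_input)
model = tf.keras.Model(sparse_input, output)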
@@ -16,12 +16,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python import keras
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend
from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -66,5 +70,32 @@ class TrackableWeightHandlerTest(keras_parameterized.TestCase):
     _ = backend.batch_get_value(table_handler.get_tensors())
 
 
+@combinations.generate(combinations.combine(mode=['graph', 'eager']))
+class OpLayerTest(keras_parameterized.TestCase):
+
+  def test_tensor_op_layer(self):
+    int_values = keras.Input(shape=(2,), dtype=dtypes.int32)
+    float_values = math_ops.cast(int_values, dtypes.float32)
+    model = keras.Model(int_values, float_values)
+    model.compile(loss='mse')
+
+    input_data = np.array([[1, 2], [3, 4]], dtype=np.int32)
+    expected = [[1.0, 2.0], [3.0, 4.0]]
+    output = model.predict(input_data)
+    self.assertAllClose(expected, output)
+
+  def test_ragged_op_layer(self):
+    with self.assertRaisesRegexp(ValueError, 'Keras automatic op wrapping'):
+      int_values = keras.Input(shape=(None,), dtype=dtypes.int32, ragged=True)
+      float_values = math_ops.cast(int_values, dtypes.float32)
+      _ = keras.Model(int_values, float_values)
+
+  def test_sparse_op_layer(self):
+    with self.assertRaisesRegexp(ValueError, 'Keras automatic op wrapping'):
+      int_values = keras.Input(shape=(None,), dtype=dtypes.int32, sparse=True)
+      float_values = math_ops.cast(int_values, dtypes.float32)
+      _ = keras.Model(int_values, float_values)
+
+
 if __name__ == '__main__':
   test.main()
@@ -1531,8 +1531,7 @@ class TestExceptionsAndWarnings(keras_parameterized.TestCase):
     output = sparse_ops.sparse_minimum(inputs, inputs)
     with self.assertRaisesRegexp(
         ValueError,
-        'Sparse ops are not supported with functional models with built-in '
-        'layer wrapping'
+        'not supported by Keras automatic op wrapping'
     ):
       training_module.Model([inputs], output)