Fix op wrapper code to correctly intercept CompositeTensors when doing op layer wrapping.

PiperOrigin-RevId: 313821947
Change-Id: I8e37ab0b38b1f861ca36ba3017e0068ed961e206
This commit is contained in:
A. Unique TensorFlower 2020-05-29 11:50:52 -07:00 committed by TensorFlower Gardener
parent 90a005a4e9
commit 0ccc0ed961
3 changed files with 42 additions and 12 deletions

View File

@ -25,6 +25,7 @@ from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend from tensorflow.python.keras import backend
@ -214,18 +215,17 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers):
for tensor in tensor_list: for tensor in tensor_list:
if getattr(tensor, '_keras_history', None) is not None: if getattr(tensor, '_keras_history', None) is not None:
continue continue
op = tensor.op # The Op that created this Tensor. if sparse_tensor.is_sparse(tensor) or ragged_tensor.is_ragged(tensor):
if op not in processed_ops: example = """
if op.type.startswith('Sparse'):
lambda_example = """
weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights) weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
output = tf.keras.layers.Lambda(weights_mult)(input) output = tf.keras.layers.Lambda(weights_mult)(input)
""" """
raise ValueError( raise ValueError('Tensorflow ops that generate ragged or sparse tensor '
'Sparse ops are not supported with functional models with built-in ' 'outputs are currently not supported by Keras automatic '
'layer wrapping. Please wrap the sparse ops in a Lambda layer like' 'op wrapping. Please wrap these ops in a Lambda layer: '
': \n{lambda_example}\n'.format(lambda_example=lambda_example)) '\n\n```\n{example}\n```\n'.format(example=example))
op = tensor.op # The Op that created this Tensor.
if op not in processed_ops:
# Recursively set `_keras_history`. # Recursively set `_keras_history`.
op_inputs = list(op.inputs) op_inputs = list(op.inputs)
constants = {} constants = {}

View File

@ -16,12 +16,16 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy as np
from tensorflow.python import keras
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend from tensorflow.python.keras import backend
from tensorflow.python.keras import combinations from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -66,5 +70,32 @@ class TrackableWeightHandlerTest(keras_parameterized.TestCase):
_ = backend.batch_get_value(table_handler.get_tensors()) _ = backend.batch_get_value(table_handler.get_tensors())
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class OpLayerTest(keras_parameterized.TestCase):
def test_tensor_op_layer(self):
int_values = keras.Input(shape=(2,), dtype=dtypes.int32)
float_values = math_ops.cast(int_values, dtypes.float32)
model = keras.Model(int_values, float_values)
model.compile(loss='mse')
input_data = np.array([[1, 2], [3, 4]], dtype=np.int32)
expected = [[1.0, 2.0], [3.0, 4.0]]
output = model.predict(input_data)
self.assertAllClose(expected, output)
def test_ragged_op_layer(self):
with self.assertRaisesRegexp(ValueError, 'Keras automatic op wrapping'):
int_values = keras.Input(shape=(None,), dtype=dtypes.int32, ragged=True)
float_values = math_ops.cast(int_values, dtypes.float32)
_ = keras.Model(int_values, float_values)
def test_sparse_op_layer(self):
with self.assertRaisesRegexp(ValueError, 'Keras automatic op wrapping'):
int_values = keras.Input(shape=(None,), dtype=dtypes.int32, sparse=True)
float_values = math_ops.cast(int_values, dtypes.float32)
_ = keras.Model(int_values, float_values)
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()

View File

@ -1531,8 +1531,7 @@ class TestExceptionsAndWarnings(keras_parameterized.TestCase):
output = sparse_ops.sparse_minimum(inputs, inputs) output = sparse_ops.sparse_minimum(inputs, inputs)
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
ValueError, ValueError,
'Sparse ops are not supported with functional models with built-in ' 'not supported by Keras automatic op wrapping'
'layer wrapping'
): ):
training_module.Model([inputs], output) training_module.Model([inputs], output)