From 95776148da661b5e57c41cba3b9ef9546aab0ac2 Mon Sep 17 00:00:00 2001 From: Tomer Kaftan Date: Tue, 6 Oct 2020 11:30:10 -0700 Subject: [PATCH] Add internal registry for special-casing different ExtensionTypes with different KerasTensor subclasses, so that we can support unique instance methods/properties / shape/dtype inference of different (composite)tensor types when the generic approach is insufficient. PiperOrigin-RevId: 335683002 Change-Id: Ic05cfc7c4109176d249455edc78db2a950575820 --- tensorflow/python/keras/backend.py | 9 +- tensorflow/python/keras/engine/BUILD | 16 + .../python/keras/engine/keras_tensor.py | 381 ++++++++++++------ .../keras/engine/ragged_keras_tensor_test.py | 96 +++++ .../keras/layers/tensorflow_op_layer_test.py | 6 +- tensorflow/python/keras/utils/tf_utils.py | 4 + 6 files changed, 372 insertions(+), 140 deletions(-) create mode 100644 tensorflow/python/keras/engine/ragged_keras_tensor_test.py diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index b8ae91dff02..05391ea02fc 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -1207,11 +1207,10 @@ def placeholder(shape=None, if ndim: shape = (None,) * ndim if keras_tensor.keras_tensors_enabled(): - spec = tensor_spec.TensorSpec( - shape=shape, dtype=dtype, name=name) if sparse: spec = sparse_tensor.SparseTensorSpec( shape=shape, dtype=dtype) + x = keras_tensor.SparseKerasTensor(spec, name=name) elif ragged: ragged_rank = 0 for i in range(1, len(shape)): @@ -1224,7 +1223,11 @@ def placeholder(shape=None, spec = ragged_tensor.RaggedTensorSpec( shape=shape, dtype=dtype, ragged_rank=ragged_rank) - x = keras_tensor.KerasTensor(spec, name=name) + x = keras_tensor.RaggedKerasTensor(spec, name=name) + else: + spec = tensor_spec.TensorSpec( + shape=shape, dtype=dtype, name=name) + x = keras_tensor.KerasTensor(spec, name=name) else: with get_graph().as_default(): if sparse: diff --git a/tensorflow/python/keras/engine/BUILD b/tensorflow/python/keras/engine/BUILD index 470404acd31..88c2023bd68 100644 --- a/tensorflow/python/keras/engine/BUILD +++ b/tensorflow/python/keras/engine/BUILD @@ -348,6 +348,22 @@ tf_py_test( ], ) +tf_py_test( + name = "ragged_keras_tensor_test", + size = "small", + srcs = ["ragged_keras_tensor_test.py"], + python_version = "PY3", + tags = [ + "nomac", # TODO(mihaimaruseac): b/127695564 + ], + tfrt_enabled = True, + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python/keras", + "@absl_py//absl/testing:parameterized", + ], +) + tf_py_test( name = "input_spec_test", size = "small", diff --git a/tensorflow/python/keras/engine/keras_tensor.py b/tensorflow/python/keras/engine/keras_tensor.py index 3aa9b595d4f..0c2c8bfc44d 100644 --- a/tensorflow/python/keras/engine/keras_tensor.py +++ b/tensorflow/python/keras/engine/keras_tensor.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor @@ -25,6 +26,8 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import type_spec as type_spec_module from tensorflow.python.ops import array_ops +from tensorflow.python.ops.ragged import ragged_operators # pylint: disable=unused-import +from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.util import nest from tensorflow.python.util import object_identity @@ -50,6 +53,13 @@ def keras_tensors_enabled(): return _KERAS_TENSORS_ENABLED and ops.executing_eagerly_outside_functions() +# Tensorflow tensors have a maximum rank of 254 +# (See `MaxDimensions()` in //tensorflow/core/framework/tensor_shape.h ) +# So we do not try to infer values for int32 tensors larger than this, +# As they cannot represent shapes. +_MAX_TENSOR_RANK = 254 + + class KerasTensor(object): """A representation of a Keras in/output during Functional API construction. @@ -153,6 +163,75 @@ class KerasTensor(object): # it can't access shape or dtype return self._type_spec._shape # pylint: disable=protected-access + @classmethod + def from_tensor(cls, tensor): + """Convert a traced (composite)tensor to a representative KerasTensor.""" + if isinstance(tensor, ops.Tensor): + name = getattr(tensor, 'name', None) + type_spec = type_spec_module.type_spec_from_value(tensor) + inferred_value = None + if (type_spec.dtype == dtypes.int32 and type_spec.shape.rank < 2): + # If this tensor might be representing shape information, + # (dtype=int32, rank of 0 or 1, not too large to represent a shape) + # we attempt to capture any value information tensorflow's + # shape handling can extract from the current scratch graph. + # + # Even though keras layers each trace in their own scratch + # graph, this shape value info extraction allows us to capture + # a sizable and useful subset of the C++ shape value inference TF can do + # if all tf ops appear in the same graph when using shape ops. + # + # Examples of things this cannot infer concrete dimensions for + # that the full single-graph C++ shape inference sometimes can are: + # * cases where the shape tensor is cast out of int32 before being + # manipulated w/ floating point numbers then converted back + # * cases where int32 tensors w/ rank >= 2 are manipulated before being + # used as a shape tensor + # * cases where int32 tensors too large to represent shapes are + # manipulated to a smaller size before being used as a shape tensor + inferred_value = array_ops.ones(shape=tensor).shape + if inferred_value.dims: + inferred_value = inferred_value.as_list() + if len(inferred_value) > _MAX_TENSOR_RANK: + inferred_value = None + else: + inferred_value = None + + return KerasTensor(type_spec, inferred_value=inferred_value, name=name) + else: + # Fallback to the generic arbitrary-typespec KerasTensor + name = getattr(tensor, 'name', None) + type_spec = type_spec_module.type_spec_from_value(tensor) + return cls(type_spec, name=name) + + def _to_placeholder(self): + """Convert this KerasTensor to a placeholder in a graph.""" + # If there is an inferred value for this tensor, inject the inferred value + if self._inferred_value is not None: + # If we suspect this KerasTensor might be representing a shape tensor, + # and we were able to extract value information with TensorFlow's shape + # handling when making the KerasTensor, we construct the placeholder by + # re-injecting the inferred value information into the graph. We + # do this injection through the shape of a placeholder, because that + # allows us to specify partially-unspecified shape values. + # + # See the comment on value extraction inside `from_tensor` for more info. + inferred_value = array_ops.shape( + array_ops.placeholder( + shape=self._inferred_value, dtype=dtypes.int32)) + if self.type_spec.shape.rank == 0: + # `tf.shape` always returns a rank-1, we may need to turn it back to a + # scalar. + inferred_value = inferred_value[0] + return inferred_value + + # Use the generic conversion from typespec to a placeholder. + def component_to_placeholder(component): + return array_ops.placeholder(component.dtype, component.shape) + + return nest.map_structure( + component_to_placeholder, self.type_spec, expand_composites=True) + def get_shape(self): return self.shape @@ -298,26 +377,27 @@ class KerasTensor(object): return self._name @classmethod - def _overload_all_operators(cls): # pylint: disable=invalid-name + def _overload_all_operators(cls, tensor_class): # pylint: disable=invalid-name """Register overloads for all operators.""" for operator in ops.Tensor.OVERLOADABLE_OPERATORS: - cls._overload_operator(operator) + cls._overload_operator(tensor_class, operator) # We include `experimental_ref` for versions of TensorFlow that # still include the deprecated method in Tensors. - if hasattr(ops.Tensor, 'experimental_ref'): - cls._overload_operator('experimental_ref') + if hasattr(tensor_class, 'experimental_ref'): + cls._overload_operator(tensor_class, 'experimental_ref') @classmethod - def _overload_operator(cls, operator): # pylint: disable=invalid-name - """Overload an operator with the same overloading as `ops.Tensor`. + def _overload_operator(cls, tensor_class, operator): # pylint: disable=invalid-name + """Overload an operator with the same implementation as a base Tensor class. - We pull the operator out of ops.Tensor dynamically to avoid ordering issues. + We pull the operator out of the class dynamically to avoid ordering issues. Args: + tensor_class: The (Composite)Tensor to get the method from. operator: string. The operator name. """ - tensor_oper = getattr(ops.Tensor, operator) + tensor_oper = getattr(tensor_class, operator) # Compatibility with Python 2: # Python 2 unbound methods have type checks for the first arg, @@ -327,81 +407,91 @@ class KerasTensor(object): setattr(cls, operator, tensor_oper) -KerasTensor._overload_all_operators() # pylint: disable=protected-access +KerasTensor._overload_all_operators(ops.Tensor) # pylint: disable=protected-access -class _KerasTensorIterator(object): - """Iterates over the leading dim of a KerasTensor. Performs 0 error checks.""" +class SparseKerasTensor(KerasTensor): + """A specialized KerasTensor representation for `tf.sparse.SparseTensor`s. - def __init__(self, tensor, dim0): - self._tensor = tensor - self._index = 0 - self._limit = dim0 + Specifically, it specializes the conversion to a placeholder in order + to maintain dense shape information. + """ - def __iter__(self): - return self + def _to_placeholder(self): + spec = self.type_spec - def __next__(self): - if self._index == self._limit: - raise StopIteration - result = self._tensor[self._index] - self._index += 1 + # nest.map_structure loses dense shape information for sparse tensors. + # So, we special-case sparse placeholder creation. + # This only preserves shape information for top-level sparse tensors; + # not for sparse tensors that are nested inside another composite + # tensor. + return array_ops.sparse_placeholder(dtype=spec.dtype, shape=spec.shape) + + +class RaggedKerasTensor(KerasTensor): + """A specialized KerasTensor representation for `tf.RaggedTensor`s. + + Specifically, it: + + 1. Specializes the conversion to a placeholder in order + to maintain shape information for non-ragged dimensions. + 2. Overloads the KerasTensor's operators with the RaggedTensor versions + when they don't match the `tf.Tensor` versions + 3. Exposes some of the instance method/attribute that are unique to + the RaggedTensor API (such as ragged_rank). + """ + + def _to_placeholder(self): + ragged_spec = self.type_spec + if ragged_spec.ragged_rank == 0 or ragged_spec.shape.rank is None: + return super(RaggedKerasTensor, self)._to_placeholder() + + flat_shape = ragged_spec.shape[ragged_spec.ragged_rank:] + result = array_ops.placeholder(ragged_spec.dtype, flat_shape) + + known_num_splits = [] + prod = 1 + for axis_size in ragged_spec.shape: + if prod is not None: + if axis_size is None or ( + getattr(axis_size, 'value', True) is None): + prod = None + else: + prod = prod * axis_size + known_num_splits.append(prod) + + for axis in range(ragged_spec.ragged_rank, 0, -1): + axis_size = ragged_spec.shape[axis] + if axis_size is None or (getattr(axis_size, 'value', True) is None): + num_splits = known_num_splits[axis-1] + if num_splits is not None: + num_splits = num_splits + 1 + splits = array_ops.placeholder( + ragged_spec.row_splits_dtype, [num_splits]) + result = ragged_tensor.RaggedTensor.from_row_splits( + result, splits, validate=False) + else: + rowlen = constant_op.constant(axis_size, ragged_spec.row_splits_dtype) + result = ragged_tensor.RaggedTensor.from_uniform_row_length( + result, rowlen, validate=False) return result - next = __next__ # python2.x compatibility. - - -def keras_tensor_to_placeholder(x): - """Construct a graph placeholder to represent a KerasTensor when tracing.""" - if hasattr(x, '_user_registered_symbolic_object'): - return x._user_registered_symbolic_object # pylint: disable=protected-access - - if isinstance(x, KerasTensor): - spec = x.type_spec - - if x._inferred_value is not None: # pylint: disable=protected-access - # If we suspect this KerasTensor might be representing a shape tensor, - # and we were able to extract value information with TensorFlow's shape - # handling when making the KerasTensor, we construct the placeholder by - # re-injecting the inferred value information into the graph. - # Even though keras layers each trace in their own scratch - # graph, this shape value info injection allows us to capture - # a sizable and useful subset of the C++ shape value inference TF can do - # if all tf ops appear in the same graph when using shape ops. - # - # Examples of things this cannot infer concrete dimensions for - # that the full single-graph C++ shape inference sometimes can are: - # * cases where the shape tensor is cast out of int32 before being - # manipulated w/ floating point numbers then converted back - # * cases where int32 tensors w/ rank > 2 are manipulated before being - # used as a shape tensor - inferred_value = array_ops.shape( - array_ops.placeholder( - shape=x._inferred_value, dtype=dtypes.int32)) # pylint: disable=protected-access - if spec.shape.rank == 0: - # `tf.shape` always returns a rank-1, we may need to turn it back to a - # scalar. - inferred_value = inferred_value[0] - return inferred_value # pylint: disable=protected-access - - if isinstance(spec, sparse_tensor.SparseTensorSpec): - # nest.map_structure loses dense shape information for sparse tensors. - # So, we special-case sparse placeholder creation. - # This only preserves shape information for top-level sparse tensors; - # not for sparse tensors that are nested inside another composite - # tensor. - return array_ops.sparse_placeholder(dtype=spec.dtype, shape=spec.shape) - - def component_to_placeholder(component): - return array_ops.placeholder(component.dtype, component.shape) - - ph = nest.map_structure( - component_to_placeholder, spec, expand_composites=True) - return ph - else: - return x + @property + def ragged_rank(self): + return self.type_spec.ragged_rank + +RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__add__') # pylint: disable=protected-access +RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__radd__') # pylint: disable=protected-access +RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__mul__') # pylint: disable=protected-access +RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__rmul__') # pylint: disable=protected-access +# TODO(b/161487382): +# Special-case user-registered symbolic objects (registered by the +# private `register_symbolic_tensor_type` method) by passing them between +# scratch graphs directly. +# This is needed to not break Tensorflow probability +# while they finish migrating to composite tensors. class UserRegisteredSpec(type_spec_module.TypeSpec): """TypeSpec to represent user-registered symbolic objects.""" @@ -425,72 +515,95 @@ class UserRegisteredSpec(type_spec_module.TypeSpec): def value_type(self): raise NotImplementedError -# Tensorflow tensors have a maximum dimension of 254 -# (See //tensorflow/core/framework/tensor_shape.h ) -# So we do not try to infer values for int32 tensors larger than this, -# As they cannot represent shapes. -_MAX_TENSOR_DIMS = 254 +# TODO(b/161487382): +# Special-case user-registered symbolic objects (registered by the +# private `register_symbolic_tensor_type` method) by passing them between +# scratch graphs directly. +# This is needed to not break Tensorflow probability +# while they finish migrating to composite tensors. +class UserRegisteredTypeKerasTensor(KerasTensor): + """KerasTensor that represents legacy register_symbolic_tensor_type.""" -def keras_tensor_from_tensor(x): - """Convert a traced (composite)tensor to a representative KerasTensor.""" - name = getattr(x, 'name', None) - inferred_value = None - - # TODO(b/161487382): - # Special-case user-registered symbolic objects (registered by the - # private `register_symbolic_tensor_type` method) by passing them between - # scratch graphs directly. - # This is needed to not break Tensorflow probability - # while they finish migrating to composite tensors. - user_registered_symbolic = False - try: - from tensorflow.python.keras.utils import tf_utils # pylint: disable=g-import-not-at-top to prevent circular imports - if isinstance(x, tuple(tf_utils._user_convertible_tensor_types)): # pylint: disable=protected-access - user_registered_symbolic = True - except ImportError: - pass - if user_registered_symbolic: + def __init__(self, user_registered_symbolic_object): + x = user_registered_symbolic_object + self._user_registered_symbolic_object = x type_spec = UserRegisteredSpec(x.shape, x.dtype) + name = getattr(x, 'name', None) + + super(UserRegisteredTypeKerasTensor, self).__init__(type_spec, name) + + @classmethod + def from_tensor(cls, tensor): + return cls(tensor) + + def _to_placeholder(self): + return self._user_registered_symbolic_object + + +class _KerasTensorIterator(object): + """Iterates over the leading dim of a KerasTensor. Performs 0 error checks.""" + + def __init__(self, tensor, dim0): + self._tensor = tensor + self._index = 0 + self._limit = dim0 + + def __iter__(self): + return self + + def __next__(self): + if self._index == self._limit: + raise StopIteration + result = self._tensor[self._index] + self._index += 1 + return result + + next = __next__ # python2.x compatibility. + + +# Specify the mappings of tensor class to KerasTensor class. +# This is specifically a list instead of a dict for now because +# 1. we do a check w/ isinstance because a key lookup based on class +# would miss subclasses +# 2. a list allows us to control lookup ordering +# We include ops.Tensor -> KerasTensor in the first position as a fastpath, +# *and* include object -> KerasTensor at the end as a catch-all. +# We can re-visit these choices in the future as needed. +keras_tensor_classes = [ + (ops.Tensor, KerasTensor), + (sparse_tensor.SparseTensor, SparseKerasTensor), + (ragged_tensor.RaggedTensor, RaggedKerasTensor), + (object, KerasTensor) +] + + +def register_keras_tensor_specialization(cls, keras_tensor_subclass): + """Register a specialized KerasTensor subclass for a Tensor type.""" + # We always leave (object, KerasTensor) at the end as a generic fallback + keras_tensor_classes.insert(-1, (cls, keras_tensor_subclass)) + + +def keras_tensor_to_placeholder(x): + """Construct a graph placeholder to represent a KerasTensor when tracing.""" + if isinstance(x, KerasTensor): + return x._to_placeholder() # pylint: disable=protected-access else: - type_spec = type_spec_module.type_spec_from_value(x) + return x - if (isinstance(type_spec, tensor_spec.TensorSpec) - and type_spec.dtype == dtypes.int32 - and type_spec.shape.rank < 2): - # If this tensor might be representing shape information, - # (dtype=int32, rank of 0 or 1, not too large to represent a shape) - # we attempt to capture any value information tensorflow's - # shape handling can extract from the current scratch graph. - # - # Even though keras layers each trace in their own scratch - # graph, this shape value info extraction allows us to capture - # a sizable and useful subset of the C++ shape value inference TF can do - # if all tf ops appear in the same graph when using shape ops. - # - # Examples of things this cannot infer concrete dimensions for - # that the full single-graph C++ shape inference sometimes can are: - # * cases where the shape tensor is cast out of int32 before being - # manipulated w/ floating point numbers then converted back - # * cases where int32 tensors w/ rank > 2 are manipulated before being - # used as a shape tensor - # * cases where int32 tensors too large to represent shapes are manipulated - # to a smaller size before being used as a shape tensor - inferred_value = array_ops.ones(shape=x).shape - if inferred_value.dims: - inferred_value = inferred_value.as_list() - if len(inferred_value) > _MAX_TENSOR_DIMS: - inferred_value = None - else: - inferred_value = None - out = KerasTensor(type_spec, - inferred_value=inferred_value, name=name) - if user_registered_symbolic: - out._user_registered_symbolic_object = x # pylint: disable=protected-access +def keras_tensor_from_tensor(tensor): + """Convert a traced (composite)tensor to a representative KerasTensor.""" + # Create a specialized KerasTensor that supports instance methods, + # operators, and additional value inference if possible + keras_tensor_cls = None + for tensor_type, cls in keras_tensor_classes: + if isinstance(tensor, tensor_type): + keras_tensor_cls = cls + break - if hasattr(x, '_keras_mask'): - out._keras_mask = KerasTensor( # pylint: disable=protected-access - type_spec_module.type_spec_from_value(x._keras_mask)) # pylint: disable=protected-access + out = keras_tensor_cls.from_tensor(tensor) + if hasattr(tensor, '_keras_mask'): + out._keras_mask = keras_tensor_from_tensor(tensor._keras_mask) # pylint: disable=protected-access return out diff --git a/tensorflow/python/keras/engine/ragged_keras_tensor_test.py b/tensorflow/python/keras/engine/ragged_keras_tensor_test.py new file mode 100644 index 00000000000..92abdc82240 --- /dev/null +++ b/tensorflow/python/keras/engine/ragged_keras_tensor_test.py @@ -0,0 +1,96 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""RaggedKerasTensor tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.python.framework import func_graph +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras import keras_parameterized +from tensorflow.python.keras import layers +from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.engine import training +from tensorflow.python.ops.ragged import ragged_factory_ops +from tensorflow.python.platform import test + + +class RaggedKerasTensorTest(keras_parameterized.TestCase): + + @parameterized.parameters( + {'batch_size': None, 'shape': (None, 5), 'ragged_rank': 1}, + {'batch_size': None, 'shape': (None, 3, 5), 'ragged_rank': 1}, + {'batch_size': None, 'shape': (5, None), 'ragged_rank': 2}, + {'batch_size': None, 'shape': (3, 5, None), 'ragged_rank': 3}, + {'batch_size': None, 'shape': (None, 3, 5, None), 'ragged_rank': 4}, + {'batch_size': None, 'shape': (2, 3, None, 4, 5, None), 'ragged_rank': 6}, + {'batch_size': 8, 'shape': (None, 5), 'ragged_rank': 1}, + {'batch_size': 9, 'shape': (None, 3, 5), 'ragged_rank': 1}, + {'batch_size': 1, 'shape': (5, None), 'ragged_rank': 2}, + {'batch_size': 4, 'shape': (3, 5, None), 'ragged_rank': 3}, + {'batch_size': 7, 'shape': (None, 3, 5, None), 'ragged_rank': 4}, + {'batch_size': 12, 'shape': (2, 3, None, 4, 5, None), 'ragged_rank': 6}, + ) + def test_to_placeholder(self, shape, batch_size, ragged_rank): + with testing_utils.use_keras_tensors_scope(True): + inp = layers.Input(shape=shape, batch_size=batch_size, ragged=True) + self.assertEqual(inp.ragged_rank, ragged_rank) + self.assertAllEqual(inp.shape, [batch_size] + list(shape)) + with func_graph.FuncGraph('test').as_default(): + placeholder = inp._to_placeholder() + self.assertEqual(placeholder.ragged_rank, ragged_rank) + self.assertAllEqual(placeholder.shape, [batch_size] + list(shape)) + + def test_add(self): + inp = layers.Input(shape=[None], ragged=True) + out = inp + inp + model = training.Model(inp, out) + + x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]]) + self.assertAllEqual(model(x), x + x) + + def test_mul(self): + inp = layers.Input(shape=[None], ragged=True) + out = inp * inp + model = training.Model(inp, out) + + x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]]) + self.assertAllEqual(model(x), x * x) + + def test_sub(self): + inp = layers.Input(shape=[None], ragged=True) + out = inp - inp + model = training.Model(inp, out) + + x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]]) + self.assertAllEqual(model(x), x - x) + + def test_div(self): + inp = layers.Input(shape=[None], ragged=True) + out = inp / inp + model = training.Model(inp, out) + + x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]]) + self.assertAllEqual(model(x), x / x) + + +if __name__ == '__main__': + ops.enable_eager_execution() + tensor_shape.enable_v2_tensorshape() + test.main() diff --git a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py index bde6dd137d7..b89bafde8d2 100644 --- a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py +++ b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py @@ -158,15 +158,15 @@ def _int32_manipulation_at_max_shape_dims_limit(): # of the max tensor size Keras can try inferring values for. inputs = keras.Input(batch_size=2, shape=(10,)) batch_size = array_ops.shape(inputs)[0] - num_features = int(keras_tensor._MAX_TENSOR_DIMS / int(inputs.shape[0])) + num_features = int(keras_tensor._MAX_TENSOR_RANK / int(inputs.shape[0])) x = math_ops.range(batch_size * num_features, dtype='int32') - assert x.shape.as_list() == [keras_tensor._MAX_TENSOR_DIMS] + assert x.shape.as_list() == [keras_tensor._MAX_TENSOR_RANK] # Verify that a value was actually inferred for a tensor that *might* # represent the shape, bying checking that a value in # the range appears in the printed inferred value if keras_tensor.keras_tensors_enabled(): - assert str(keras_tensor._MAX_TENSOR_DIMS - 1) in str(x) + assert str(keras_tensor._MAX_TENSOR_RANK - 1) in str(x) x = array_ops.reshape(x, (batch_size, num_features)) x = math_ops.cast(x, dtype='float32') diff --git a/tensorflow/python/keras/utils/tf_utils.py b/tensorflow/python/keras/utils/tf_utils.py index 98971b3d1c3..3515dcc87a1 100644 --- a/tensorflow/python/keras/utils/tf_utils.py +++ b/tensorflow/python/keras/utils/tf_utils.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.framework import type_spec from tensorflow.python.keras import backend as K +from tensorflow.python.keras.engine import keras_tensor from tensorflow.python.keras.utils import tf_contextlib from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables @@ -363,6 +364,9 @@ def register_symbolic_tensor_type(cls): cls: A `class` type which shall be regarded as a symbolic `Tensor`. """ global _user_convertible_tensor_types + if cls not in _user_convertible_tensor_types: + keras_tensor.register_keras_tensor_specialization( + cls, keras_tensor.UserRegisteredTypeKerasTensor) _user_convertible_tensor_types.add(cls)