Add an internal registry for special-casing different ExtensionTypes with different KerasTensor subclasses, so that we can support unique instance methods/properties and shape/dtype inference for different (composite) tensor types when the generic approach is insufficient.

PiperOrigin-RevId: 335683002
Change-Id: Ic05cfc7c4109176d249455edc78db2a950575820
Tomer Kaftan 2020-10-06 11:30:10 -07:00 committed by TensorFlower Gardener
parent 31029280a2
commit 95776148da
6 changed files with 372 additions and 140 deletions
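For orientation before the per-file diffs: the registry introduced here is an ordered list of (tensor class, KerasTensor subclass) pairs scanned with isinstance, so subclasses are matched and a generic catch-all can sit at the end. A minimal standalone sketch of that dispatch pattern (the names below are illustrative, not TensorFlow API):

class GenericHandler:
  pass

class RaggedHandler:
  pass

# Ordered list rather than a dict: isinstance matching covers subclasses,
# and list order controls lookup priority. The (object, ...) catch-all
# always stays last.
_handlers = [(object, GenericHandler)]

def register_specialization(cls, handler):
  _handlers.insert(-1, (cls, handler))

def lookup(value):
  for cls, handler in _handlers:
    if isinstance(value, cls):
      return handler

register_specialization(list, RaggedHandler)
assert lookup([1, 2]) is RaggedHandler  # specialized match wins
assert lookup('abc') is GenericHandler  # falls through to the catch-all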


@@ -1207,11 +1207,10 @@ def placeholder(shape=None,
if ndim:
shape = (None,) * ndim
if keras_tensor.keras_tensors_enabled():
spec = tensor_spec.TensorSpec(
shape=shape, dtype=dtype, name=name)
if sparse:
spec = sparse_tensor.SparseTensorSpec(
shape=shape, dtype=dtype)
x = keras_tensor.SparseKerasTensor(spec, name=name)
elif ragged:
ragged_rank = 0
for i in range(1, len(shape)):
@@ -1224,7 +1223,11 @@
spec = ragged_tensor.RaggedTensorSpec(
shape=shape, dtype=dtype, ragged_rank=ragged_rank)
x = keras_tensor.KerasTensor(spec, name=name)
x = keras_tensor.RaggedKerasTensor(spec, name=name)
else:
spec = tensor_spec.TensorSpec(
shape=shape, dtype=dtype, name=name)
x = keras_tensor.KerasTensor(spec, name=name)
else:
with get_graph().as_default():
if sparse:
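
The user-visible effect of the backend.py change above: with Keras tensors enabled, sparse and ragged Inputs are now backed by the specialized subclasses rather than the generic KerasTensor. A hedged sketch (the concrete class names are internal, and this assumes post-commit TF 2.x behavior with Keras tensors on):

import tensorflow as tf

x = tf.keras.Input(shape=(None,), ragged=True)  # backed by a RaggedKerasTensor
s = tf.keras.Input(shape=(10,), sparse=True)    # backed by a SparseKerasTensor
print(type(x).__name__, type(s).__name__)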


@@ -348,6 +348,22 @@ tf_py_test(
],
)
tf_py_test(
name = "ragged_keras_tensor_test",
size = "small",
srcs = ["ragged_keras_tensor_test.py"],
python_version = "PY3",
tags = [
"nomac", # TODO(mihaimaruseac): b/127695564
],
tfrt_enabled = True,
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python/keras",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "input_spec_test",
size = "small",


@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
@@ -25,6 +26,8 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec as type_spec_module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_operators # pylint: disable=unused-import
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
@@ -50,6 +53,13 @@ def keras_tensors_enabled():
return _KERAS_TENSORS_ENABLED and ops.executing_eagerly_outside_functions()
# TensorFlow tensors have a maximum rank of 254
# (see `MaxDimensions()` in //tensorflow/core/framework/tensor_shape.h),
# so we do not try to infer values for int32 tensors larger than this,
# as they cannot represent shapes.
_MAX_TENSOR_RANK = 254
class KerasTensor(object):
"""A representation of a Keras in/output during Functional API construction.
@@ -153,6 +163,75 @@ class KerasTensor(object):
# it can't access shape or dtype
return self._type_spec._shape # pylint: disable=protected-access
@classmethod
def from_tensor(cls, tensor):
"""Convert a traced (composite)tensor to a representative KerasTensor."""
if isinstance(tensor, ops.Tensor):
name = getattr(tensor, 'name', None)
type_spec = type_spec_module.type_spec_from_value(tensor)
inferred_value = None
if (type_spec.dtype == dtypes.int32 and type_spec.shape.rank < 2):
# If this tensor might be representing shape information,
# (dtype=int32, rank of 0 or 1, not too large to represent a shape)
# we attempt to capture any value information tensorflow's
# shape handling can extract from the current scratch graph.
#
# Even though keras layers each trace in their own scratch
# graph, this shape value info extraction allows us to capture
# a sizable and useful subset of the C++ shape value inference TF can do
# if all tf ops appear in the same graph when using shape ops.
#
# Examples of things this cannot infer concrete dimensions for
# that the full single-graph C++ shape inference sometimes can are:
# * cases where the shape tensor is cast out of int32 before being
# manipulated w/ floating point numbers then converted back
# * cases where int32 tensors w/ rank >= 2 are manipulated before being
# used as a shape tensor
# * cases where int32 tensors too large to represent shapes are
# manipulated to a smaller size before being used as a shape tensor
inferred_value = array_ops.ones(shape=tensor).shape
if inferred_value.dims:
inferred_value = inferred_value.as_list()
if len(inferred_value) > _MAX_TENSOR_RANK:
inferred_value = None
else:
inferred_value = None
return KerasTensor(type_spec, inferred_value=inferred_value, name=name)
else:
# Fallback to the generic arbitrary-typespec KerasTensor
name = getattr(tensor, 'name', None)
type_spec = type_spec_module.type_spec_from_value(tensor)
return cls(type_spec, name=name)
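To make the comment block in `from_tensor` concrete, here is a small standalone demonstration of the value-extraction trick, using tf.function tracing as a stand-in for a Keras scratch graph (public API only; the printed result assumes TF's usual int32 shape-value inference):

import tensorflow as tf

@tf.function
def demo(x):
  shape_tensor = tf.stack([tf.shape(x)[0], 16])  # int32, rank 1
  # Static shape inference on `ones` recovers what is known of the *value*
  # of shape_tensor: the batch dim stays None, the constant 16 is captured.
  print(tf.ones(shape=shape_tensor).shape)  # (None, 16), printed at trace time
  return x

demo.get_concrete_function(tf.TensorSpec([None, 4], tf.float32))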
def _to_placeholder(self):
"""Convert this KerasTensor to a placeholder in a graph."""
# If there is an inferred value for this tensor, inject the inferred value
if self._inferred_value is not None:
# If we suspect this KerasTensor might be representing a shape tensor,
# and we were able to extract value information with TensorFlow's shape
# handling when making the KerasTensor, we construct the placeholder by
# re-injecting the inferred value information into the graph. We
# do this injection through the shape of a placeholder, because that
# allows us to specify partially-unspecified shape values.
#
# See the comment on value extraction inside `from_tensor` for more info.
inferred_value = array_ops.shape(
array_ops.placeholder(
shape=self._inferred_value, dtype=dtypes.int32))
if self.type_spec.shape.rank == 0:
# `tf.shape` always returns a rank-1 tensor, so we may need to turn it
# back into a scalar.
inferred_value = inferred_value[0]
return inferred_value
# Use the generic conversion from typespec to a placeholder.
def component_to_placeholder(component):
return array_ops.placeholder(component.dtype, component.shape)
return nest.map_structure(
component_to_placeholder, self.type_spec, expand_composites=True)
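The generic fallback above relies on tf.nest expanding a composite TypeSpec into its component TensorSpecs and repacking the per-component placeholders into a composite value. A sketch using public APIs (graph mode, since v1 placeholders are graph-only):

import tensorflow as tf

spec = tf.RaggedTensorSpec(shape=[None, None], dtype=tf.float32)

with tf.Graph().as_default():
  ph = tf.nest.map_structure(
      lambda c: tf.compat.v1.placeholder(c.dtype, c.shape),
      spec, expand_composites=True)
  print(type(ph).__name__)  # a RaggedTensor assembled from placeholders

This path loses dense-shape detail for sparse tensors, which is exactly why SparseKerasTensor below overrides _to_placeholder.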
def get_shape(self):
return self.shape
@@ -298,26 +377,27 @@
return self._name
@classmethod
def _overload_all_operators(cls): # pylint: disable=invalid-name
def _overload_all_operators(cls, tensor_class): # pylint: disable=invalid-name
"""Register overloads for all operators."""
for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
cls._overload_operator(operator)
cls._overload_operator(tensor_class, operator)
# We include `experimental_ref` for versions of TensorFlow that
# still include the deprecated method in Tensors.
if hasattr(ops.Tensor, 'experimental_ref'):
cls._overload_operator('experimental_ref')
if hasattr(tensor_class, 'experimental_ref'):
cls._overload_operator(tensor_class, 'experimental_ref')
@classmethod
def _overload_operator(cls, operator): # pylint: disable=invalid-name
"""Overload an operator with the same overloading as `ops.Tensor`.
def _overload_operator(cls, tensor_class, operator): # pylint: disable=invalid-name
"""Overload an operator with the same implementation as a base Tensor class.
We pull the operator out of ops.Tensor dynamically to avoid ordering issues.
We pull the operator out of the class dynamically to avoid ordering issues.
Args:
tensor_class: The (Composite)Tensor to get the method from.
operator: string. The operator name.
"""
tensor_oper = getattr(ops.Tensor, operator)
tensor_oper = getattr(tensor_class, operator)
# Compatibility with Python 2:
# Python 2 unbound methods have type checks for the first arg,
@@ -327,81 +407,91 @@
setattr(cls, operator, tensor_oper)
KerasTensor._overload_all_operators() # pylint: disable=protected-access
KerasTensor._overload_all_operators(ops.Tensor) # pylint: disable=protected-access
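The mechanics behind `_overload_operator` are plain Python: in Python 3, a method fetched from a class is an ordinary function, so it can be installed on an unrelated class and binds to whatever instance it is later called on (the Python 2 note above flags the old unbound-method type checks that made this trickier). A minimal sketch:

class Base:
  def __add__(self, other):
    return ('Base.__add__', other)

class Facade:
  pass

# Mirrors KerasTensor._overload_operator(tensor_class, operator):
setattr(Facade, '__add__', getattr(Base, '__add__'))

print(Facade() + 1)  # ('Base.__add__', 1), dispatched through the copied method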
class _KerasTensorIterator(object):
"""Iterates over the leading dim of a KerasTensor. Performs 0 error checks."""
class SparseKerasTensor(KerasTensor):
"""A specialized KerasTensor representation for `tf.sparse.SparseTensor`s.
def __init__(self, tensor, dim0):
self._tensor = tensor
self._index = 0
self._limit = dim0
Specifically, it specializes the conversion to a placeholder in order
to maintain dense shape information.
"""
def __iter__(self):
return self
def _to_placeholder(self):
spec = self.type_spec
def __next__(self):
if self._index == self._limit:
raise StopIteration
result = self._tensor[self._index]
self._index += 1
# nest.map_structure loses dense shape information for sparse tensors.
# So, we special-case sparse placeholder creation.
# This only preserves shape information for top-level sparse tensors;
# not for sparse tensors that are nested inside another composite
# tensor.
return array_ops.sparse_placeholder(dtype=spec.dtype, shape=spec.shape)
class RaggedKerasTensor(KerasTensor):
"""A specialized KerasTensor representation for `tf.RaggedTensor`s.
Specifically, it:
1. Specializes the conversion to a placeholder in order
to maintain shape information for non-ragged dimensions.
2. Overloads the KerasTensor's operators with the RaggedTensor versions
when they don't match the `tf.Tensor` versions.
3. Exposes some of the instance methods/attributes that are unique to
the RaggedTensor API (such as `ragged_rank`).
"""
def _to_placeholder(self):
ragged_spec = self.type_spec
if ragged_spec.ragged_rank == 0 or ragged_spec.shape.rank is None:
return super(RaggedKerasTensor, self)._to_placeholder()
flat_shape = ragged_spec.shape[ragged_spec.ragged_rank:]
result = array_ops.placeholder(ragged_spec.dtype, flat_shape)
known_num_splits = []
prod = 1
for axis_size in ragged_spec.shape:
if prod is not None:
if axis_size is None or (
getattr(axis_size, 'value', True) is None):
prod = None
else:
prod = prod * axis_size
known_num_splits.append(prod)
for axis in range(ragged_spec.ragged_rank, 0, -1):
axis_size = ragged_spec.shape[axis]
if axis_size is None or (getattr(axis_size, 'value', True) is None):
num_splits = known_num_splits[axis-1]
if num_splits is not None:
num_splits = num_splits + 1
splits = array_ops.placeholder(
ragged_spec.row_splits_dtype, [num_splits])
result = ragged_tensor.RaggedTensor.from_row_splits(
result, splits, validate=False)
else:
rowlen = constant_op.constant(axis_size, ragged_spec.row_splits_dtype)
result = ragged_tensor.RaggedTensor.from_uniform_row_length(
result, rowlen, validate=False)
return result
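A worked example of the reconstruction above, assuming a spec equivalent to keras.Input(shape=(None, 3), batch_size=2, ragged=True): full shape [2, None, 3], ragged_rank 1, int64 row splits. The product of the dims outside the ragged axis (2) is known, so the splits placeholder gets a fully static length of num_rows + 1:

import tensorflow as tf

with tf.Graph().as_default():
  # flat_shape = shape[ragged_rank:] = [None, 3]
  values = tf.compat.v1.placeholder(tf.float32, [None, 3])
  # known_num_splits for the ragged axis is 2, so length is 2 + 1 = 3.
  splits = tf.compat.v1.placeholder(tf.int64, [3])
  rt = tf.RaggedTensor.from_row_splits(values, splits, validate=False)
  print(rt.shape)  # (2, None, 3): the non-ragged dims survive the round trip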
next = __next__ # python2.x compatibility.
def keras_tensor_to_placeholder(x):
"""Construct a graph placeholder to represent a KerasTensor when tracing."""
if hasattr(x, '_user_registered_symbolic_object'):
return x._user_registered_symbolic_object # pylint: disable=protected-access
if isinstance(x, KerasTensor):
spec = x.type_spec
if x._inferred_value is not None: # pylint: disable=protected-access
# If we suspect this KerasTensor might be representing a shape tensor,
# and we were able to extract value information with TensorFlow's shape
# handling when making the KerasTensor, we construct the placeholder by
# re-injecting the inferred value information into the graph.
# Even though keras layers each trace in their own scratch
# graph, this shape value info injection allows us to capture
# a sizable and useful subset of the C++ shape value inference TF can do
# if all tf ops appear in the same graph when using shape ops.
#
# Examples of things this cannot infer concrete dimensions for
# that the full single-graph C++ shape inference sometimes can are:
# * cases where the shape tensor is cast out of int32 before being
# manipulated w/ floating point numbers then converted back
# * cases where int32 tensors w/ rank > 2 are manipulated before being
# used as a shape tensor
inferred_value = array_ops.shape(
array_ops.placeholder(
shape=x._inferred_value, dtype=dtypes.int32)) # pylint: disable=protected-access
if spec.shape.rank == 0:
# `tf.shape` always returns a rank-1, we may need to turn it back to a
# scalar.
inferred_value = inferred_value[0]
return inferred_value # pylint: disable=protected-access
if isinstance(spec, sparse_tensor.SparseTensorSpec):
# nest.map_structure loses dense shape information for sparse tensors.
# So, we special-case sparse placeholder creation.
# This only preserves shape information for top-level sparse tensors;
# not for sparse tensors that are nested inside another composite
# tensor.
return array_ops.sparse_placeholder(dtype=spec.dtype, shape=spec.shape)
def component_to_placeholder(component):
return array_ops.placeholder(component.dtype, component.shape)
ph = nest.map_structure(
component_to_placeholder, spec, expand_composites=True)
return ph
else:
return x
@property
def ragged_rank(self):
return self.type_spec.ragged_rank
RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__add__') # pylint: disable=protected-access
RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__radd__') # pylint: disable=protected-access
RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__mul__') # pylint: disable=protected-access
RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__rmul__') # pylint: disable=protected-access
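The four registrations above borrow RaggedTensor's own operator implementations, which act elementwise on the ragged values, for example:

import tensorflow as tf

rt = tf.ragged.constant([[1, 2], [3]])
print((rt + rt).to_list())  # [[2, 4], [6]]
print((rt * rt).to_list())  # [[1, 4], [9]]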
# TODO(b/161487382):
# Special-case user-registered symbolic objects (registered by the
# private `register_symbolic_tensor_type` method) by passing them between
# scratch graphs directly.
# This is needed to not break Tensorflow probability
# while they finish migrating to composite tensors.
class UserRegisteredSpec(type_spec_module.TypeSpec):
"""TypeSpec to represent user-registered symbolic objects."""
@@ -425,72 +515,95 @@ class UserRegisteredSpec(type_spec_module.TypeSpec):
def value_type(self):
raise NotImplementedError
# Tensorflow tensors have a maximum dimension of 254
# (See //tensorflow/core/framework/tensor_shape.h )
# So we do not try to infer values for int32 tensors larger than this,
# As they cannot represent shapes.
_MAX_TENSOR_DIMS = 254
# TODO(b/161487382):
# Special-case user-registered symbolic objects (registered by the
# private `register_symbolic_tensor_type` method) by passing them between
# scratch graphs directly.
# This is needed to not break Tensorflow probability
# while they finish migrating to composite tensors.
class UserRegisteredTypeKerasTensor(KerasTensor):
"""KerasTensor that represents legacy register_symbolic_tensor_type."""
def keras_tensor_from_tensor(x):
"""Convert a traced (composite)tensor to a representative KerasTensor."""
name = getattr(x, 'name', None)
inferred_value = None
# TODO(b/161487382):
# Special-case user-registered symbolic objects (registered by the
# private `register_symbolic_tensor_type` method) by passing them between
# scratch graphs directly.
# This is needed to not break Tensorflow probability
# while they finish migrating to composite tensors.
user_registered_symbolic = False
try:
from tensorflow.python.keras.utils import tf_utils # pylint: disable=g-import-not-at-top to prevent circular imports
if isinstance(x, tuple(tf_utils._user_convertible_tensor_types)): # pylint: disable=protected-access
user_registered_symbolic = True
except ImportError:
pass
if user_registered_symbolic:
def __init__(self, user_registered_symbolic_object):
x = user_registered_symbolic_object
self._user_registered_symbolic_object = x
type_spec = UserRegisteredSpec(x.shape, x.dtype)
name = getattr(x, 'name', None)
super(UserRegisteredTypeKerasTensor, self).__init__(type_spec, name)
@classmethod
def from_tensor(cls, tensor):
return cls(tensor)
def _to_placeholder(self):
return self._user_registered_symbolic_object
class _KerasTensorIterator(object):
"""Iterates over the leading dim of a KerasTensor. Performs 0 error checks."""
def __init__(self, tensor, dim0):
self._tensor = tensor
self._index = 0
self._limit = dim0
def __iter__(self):
return self
def __next__(self):
if self._index == self._limit:
raise StopIteration
result = self._tensor[self._index]
self._index += 1
return result
next = __next__ # python2.x compatibility.
# Specify the mappings of tensor class to KerasTensor class.
# This is specifically a list instead of a dict for now because
# 1. we do a check w/ isinstance because a key lookup based on class
# would miss subclasses
# 2. a list allows us to control lookup ordering
# We include ops.Tensor -> KerasTensor in the first position as a fastpath,
# *and* include object -> KerasTensor at the end as a catch-all.
# We can re-visit these choices in the future as needed.
keras_tensor_classes = [
(ops.Tensor, KerasTensor),
(sparse_tensor.SparseTensor, SparseKerasTensor),
(ragged_tensor.RaggedTensor, RaggedKerasTensor),
(object, KerasTensor)
]
def register_keras_tensor_specialization(cls, keras_tensor_subclass):
"""Register a specialized KerasTensor subclass for a Tensor type."""
# We always leave (object, KerasTensor) at the end as a generic fallback
keras_tensor_classes.insert(-1, (cls, keras_tensor_subclass))
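A hedged usage sketch of the registry hook (the import path matches this diff; MaskedTensor and MaskedKerasTensor are illustrative stand-ins, not real TensorFlow classes):

from tensorflow.python.keras.engine import keras_tensor

class MaskedTensor:  # stand-in for some composite tensor type
  pass

class MaskedKerasTensor(keras_tensor.KerasTensor):
  pass  # a real specialization would override _to_placeholder, etc.

keras_tensor.register_keras_tensor_specialization(
    MaskedTensor, MaskedKerasTensor)
# The new pair lands just before the (object, KerasTensor) catch-all, so the
# generic fallback is preserved for every other type.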
def keras_tensor_to_placeholder(x):
"""Construct a graph placeholder to represent a KerasTensor when tracing."""
if isinstance(x, KerasTensor):
return x._to_placeholder() # pylint: disable=protected-access
else:
type_spec = type_spec_module.type_spec_from_value(x)
return x
if (isinstance(type_spec, tensor_spec.TensorSpec)
and type_spec.dtype == dtypes.int32
and type_spec.shape.rank < 2):
# If this tensor might be representing shape information,
# (dtype=int32, rank of 0 or 1, not too large to represent a shape)
# we attempt to capture any value information tensorflow's
# shape handling can extract from the current scratch graph.
#
# Even though keras layers each trace in their own scratch
# graph, this shape value info extraction allows us to capture
# a sizable and useful subset of the C++ shape value inference TF can do
# if all tf ops appear in the same graph when using shape ops.
#
# Examples of things this cannot infer concrete dimensions for
# that the full single-graph C++ shape inference sometimes can are:
# * cases where the shape tensor is cast out of int32 before being
# manipulated w/ floating point numbers then converted back
# * cases where int32 tensors w/ rank > 2 are manipulated before being
# used as a shape tensor
# * cases where int32 tensors too large to represent shapes are manipulated
# to a smaller size before being used as a shape tensor
inferred_value = array_ops.ones(shape=x).shape
if inferred_value.dims:
inferred_value = inferred_value.as_list()
if len(inferred_value) > _MAX_TENSOR_DIMS:
inferred_value = None
else:
inferred_value = None
out = KerasTensor(type_spec,
inferred_value=inferred_value, name=name)
if user_registered_symbolic:
out._user_registered_symbolic_object = x # pylint: disable=protected-access
def keras_tensor_from_tensor(tensor):
"""Convert a traced (composite)tensor to a representative KerasTensor."""
# Create a specialized KerasTensor that supports instance methods,
# operators, and additional value inference if possible
keras_tensor_cls = None
for tensor_type, cls in keras_tensor_classes:
if isinstance(tensor, tensor_type):
keras_tensor_cls = cls
break
if hasattr(x, '_keras_mask'):
out._keras_mask = KerasTensor( # pylint: disable=protected-access
type_spec_module.type_spec_from_value(x._keras_mask)) # pylint: disable=protected-access
out = keras_tensor_cls.from_tensor(tensor)
if hasattr(tensor, '_keras_mask'):
out._keras_mask = keras_tensor_from_tensor(tensor._keras_mask) # pylint: disable=protected-access
return out
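End to end, dispatch resolves a traced value to the first matching pair; assuming the registry above, a RaggedTensor matches its specialization before the catch-all (a sketch against this commit's internals, not public API):

import tensorflow as tf
from tensorflow.python.keras.engine import keras_tensor

rt = tf.ragged.constant([[1, 2], [3]])
kt = keras_tensor.keras_tensor_from_tensor(rt)
print(type(kt).__name__)  # RaggedKerasTensor
print(kt.ragged_rank)     # 1, via the specialized property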


@@ -0,0 +1,96 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RaggedKerasTensor tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import layers
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import training
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test
class RaggedKerasTensorTest(keras_parameterized.TestCase):
@parameterized.parameters(
{'batch_size': None, 'shape': (None, 5), 'ragged_rank': 1},
{'batch_size': None, 'shape': (None, 3, 5), 'ragged_rank': 1},
{'batch_size': None, 'shape': (5, None), 'ragged_rank': 2},
{'batch_size': None, 'shape': (3, 5, None), 'ragged_rank': 3},
{'batch_size': None, 'shape': (None, 3, 5, None), 'ragged_rank': 4},
{'batch_size': None, 'shape': (2, 3, None, 4, 5, None), 'ragged_rank': 6},
{'batch_size': 8, 'shape': (None, 5), 'ragged_rank': 1},
{'batch_size': 9, 'shape': (None, 3, 5), 'ragged_rank': 1},
{'batch_size': 1, 'shape': (5, None), 'ragged_rank': 2},
{'batch_size': 4, 'shape': (3, 5, None), 'ragged_rank': 3},
{'batch_size': 7, 'shape': (None, 3, 5, None), 'ragged_rank': 4},
{'batch_size': 12, 'shape': (2, 3, None, 4, 5, None), 'ragged_rank': 6},
)
def test_to_placeholder(self, shape, batch_size, ragged_rank):
with testing_utils.use_keras_tensors_scope(True):
inp = layers.Input(shape=shape, batch_size=batch_size, ragged=True)
self.assertEqual(inp.ragged_rank, ragged_rank)
self.assertAllEqual(inp.shape, [batch_size] + list(shape))
with func_graph.FuncGraph('test').as_default():
placeholder = inp._to_placeholder()
self.assertEqual(placeholder.ragged_rank, ragged_rank)
self.assertAllEqual(placeholder.shape, [batch_size] + list(shape))
def test_add(self):
inp = layers.Input(shape=[None], ragged=True)
out = inp + inp
model = training.Model(inp, out)
x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]])
self.assertAllEqual(model(x), x + x)
def test_mul(self):
inp = layers.Input(shape=[None], ragged=True)
out = inp * inp
model = training.Model(inp, out)
x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]])
self.assertAllEqual(model(x), x * x)
def test_sub(self):
inp = layers.Input(shape=[None], ragged=True)
out = inp - inp
model = training.Model(inp, out)
x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]])
self.assertAllEqual(model(x), x - x)
def test_div(self):
inp = layers.Input(shape=[None], ragged=True)
out = inp / inp
model = training.Model(inp, out)
x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]])
self.assertAllEqual(model(x), x / x)
if __name__ == '__main__':
ops.enable_eager_execution()
tensor_shape.enable_v2_tensorshape()
test.main()


@@ -158,15 +158,15 @@ def _int32_manipulation_at_max_shape_dims_limit():
# of the max tensor size Keras can try inferring values for.
inputs = keras.Input(batch_size=2, shape=(10,))
batch_size = array_ops.shape(inputs)[0]
num_features = int(keras_tensor._MAX_TENSOR_DIMS / int(inputs.shape[0]))
num_features = int(keras_tensor._MAX_TENSOR_RANK / int(inputs.shape[0]))
x = math_ops.range(batch_size * num_features, dtype='int32')
assert x.shape.as_list() == [keras_tensor._MAX_TENSOR_DIMS]
assert x.shape.as_list() == [keras_tensor._MAX_TENSOR_RANK]
# Verify that a value was actually inferred for a tensor that *might*
# represent the shape, by checking that a value in
# the range appears in the printed inferred value
if keras_tensor.keras_tensors_enabled():
assert str(keras_tensor._MAX_TENSOR_DIMS - 1) in str(x)
assert str(keras_tensor._MAX_TENSOR_RANK - 1) in str(x)
x = array_ops.reshape(x, (batch_size, num_features))
x = math_ops.cast(x, dtype='float32')


@@ -30,6 +30,7 @@ from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
@@ -363,6 +364,9 @@ def register_symbolic_tensor_type(cls):
cls: A `class` type which shall be regarded as a symbolic `Tensor`.
"""
global _user_convertible_tensor_types
if cls not in _user_convertible_tensor_types:
keras_tensor.register_keras_tensor_specialization(
cls, keras_tensor.UserRegisteredTypeKerasTensor)
_user_convertible_tensor_types.add(cls)
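A hedged sketch of the legacy hook now feeding the registry (MySymbolic is an illustrative stand-in; TensorFlow Probability's deferred tensors are the real motivating case):

import tensorflow as tf
from tensorflow.python.keras.utils import tf_utils

class MySymbolic(object):  # user type exposing tensor-like shape and dtype
  shape = tf.TensorShape([])
  dtype = tf.float32

tf_utils.register_symbolic_tensor_type(MySymbolic)
# Instances can now round-trip through Keras scratch graphs via
# UserRegisteredTypeKerasTensor, which stores and returns the original object.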