Add internal registry for special-casing different ExtensionTypes with different KerasTensor subclasses, so that we can support unique instance methods/properties / shape/dtype inference of different (composite)tensor types when the generic approach is insufficient.
PiperOrigin-RevId: 335683002 Change-Id: Ic05cfc7c4109176d249455edc78db2a950575820
This commit is contained in:
parent
31029280a2
commit
95776148da
tensorflow/python/keras
@ -1207,11 +1207,10 @@ def placeholder(shape=None,
|
||||
if ndim:
|
||||
shape = (None,) * ndim
|
||||
if keras_tensor.keras_tensors_enabled():
|
||||
spec = tensor_spec.TensorSpec(
|
||||
shape=shape, dtype=dtype, name=name)
|
||||
if sparse:
|
||||
spec = sparse_tensor.SparseTensorSpec(
|
||||
shape=shape, dtype=dtype)
|
||||
x = keras_tensor.SparseKerasTensor(spec, name=name)
|
||||
elif ragged:
|
||||
ragged_rank = 0
|
||||
for i in range(1, len(shape)):
|
||||
@ -1224,7 +1223,11 @@ def placeholder(shape=None,
|
||||
spec = ragged_tensor.RaggedTensorSpec(
|
||||
shape=shape, dtype=dtype, ragged_rank=ragged_rank)
|
||||
|
||||
x = keras_tensor.KerasTensor(spec, name=name)
|
||||
x = keras_tensor.RaggedKerasTensor(spec, name=name)
|
||||
else:
|
||||
spec = tensor_spec.TensorSpec(
|
||||
shape=shape, dtype=dtype, name=name)
|
||||
x = keras_tensor.KerasTensor(spec, name=name)
|
||||
else:
|
||||
with get_graph().as_default():
|
||||
if sparse:
|
||||
|
@ -348,6 +348,22 @@ tf_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "ragged_keras_tensor_test",
|
||||
size = "small",
|
||||
srcs = ["ragged_keras_tensor_test.py"],
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"nomac", # TODO(mihaimaruseac): b/127695564
|
||||
],
|
||||
tfrt_enabled = True,
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/keras",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "input_spec_test",
|
||||
size = "small",
|
||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
@ -25,6 +26,8 @@ from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import type_spec as type_spec_module
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_operators # pylint: disable=unused-import
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import object_identity
|
||||
|
||||
@ -50,6 +53,13 @@ def keras_tensors_enabled():
|
||||
return _KERAS_TENSORS_ENABLED and ops.executing_eagerly_outside_functions()
|
||||
|
||||
|
||||
# Tensorflow tensors have a maximum rank of 254
|
||||
# (See `MaxDimensions()` in //tensorflow/core/framework/tensor_shape.h )
|
||||
# So we do not try to infer values for int32 tensors larger than this,
|
||||
# As they cannot represent shapes.
|
||||
_MAX_TENSOR_RANK = 254
|
||||
|
||||
|
||||
class KerasTensor(object):
|
||||
"""A representation of a Keras in/output during Functional API construction.
|
||||
|
||||
@ -153,6 +163,75 @@ class KerasTensor(object):
|
||||
# it can't access shape or dtype
|
||||
return self._type_spec._shape # pylint: disable=protected-access
|
||||
|
||||
@classmethod
|
||||
def from_tensor(cls, tensor):
|
||||
"""Convert a traced (composite)tensor to a representative KerasTensor."""
|
||||
if isinstance(tensor, ops.Tensor):
|
||||
name = getattr(tensor, 'name', None)
|
||||
type_spec = type_spec_module.type_spec_from_value(tensor)
|
||||
inferred_value = None
|
||||
if (type_spec.dtype == dtypes.int32 and type_spec.shape.rank < 2):
|
||||
# If this tensor might be representing shape information,
|
||||
# (dtype=int32, rank of 0 or 1, not too large to represent a shape)
|
||||
# we attempt to capture any value information tensorflow's
|
||||
# shape handling can extract from the current scratch graph.
|
||||
#
|
||||
# Even though keras layers each trace in their own scratch
|
||||
# graph, this shape value info extraction allows us to capture
|
||||
# a sizable and useful subset of the C++ shape value inference TF can do
|
||||
# if all tf ops appear in the same graph when using shape ops.
|
||||
#
|
||||
# Examples of things this cannot infer concrete dimensions for
|
||||
# that the full single-graph C++ shape inference sometimes can are:
|
||||
# * cases where the shape tensor is cast out of int32 before being
|
||||
# manipulated w/ floating point numbers then converted back
|
||||
# * cases where int32 tensors w/ rank >= 2 are manipulated before being
|
||||
# used as a shape tensor
|
||||
# * cases where int32 tensors too large to represent shapes are
|
||||
# manipulated to a smaller size before being used as a shape tensor
|
||||
inferred_value = array_ops.ones(shape=tensor).shape
|
||||
if inferred_value.dims:
|
||||
inferred_value = inferred_value.as_list()
|
||||
if len(inferred_value) > _MAX_TENSOR_RANK:
|
||||
inferred_value = None
|
||||
else:
|
||||
inferred_value = None
|
||||
|
||||
return KerasTensor(type_spec, inferred_value=inferred_value, name=name)
|
||||
else:
|
||||
# Fallback to the generic arbitrary-typespec KerasTensor
|
||||
name = getattr(tensor, 'name', None)
|
||||
type_spec = type_spec_module.type_spec_from_value(tensor)
|
||||
return cls(type_spec, name=name)
|
||||
|
||||
def _to_placeholder(self):
|
||||
"""Convert this KerasTensor to a placeholder in a graph."""
|
||||
# If there is an inferred value for this tensor, inject the inferred value
|
||||
if self._inferred_value is not None:
|
||||
# If we suspect this KerasTensor might be representing a shape tensor,
|
||||
# and we were able to extract value information with TensorFlow's shape
|
||||
# handling when making the KerasTensor, we construct the placeholder by
|
||||
# re-injecting the inferred value information into the graph. We
|
||||
# do this injection through the shape of a placeholder, because that
|
||||
# allows us to specify partially-unspecified shape values.
|
||||
#
|
||||
# See the comment on value extraction inside `from_tensor` for more info.
|
||||
inferred_value = array_ops.shape(
|
||||
array_ops.placeholder(
|
||||
shape=self._inferred_value, dtype=dtypes.int32))
|
||||
if self.type_spec.shape.rank == 0:
|
||||
# `tf.shape` always returns a rank-1, we may need to turn it back to a
|
||||
# scalar.
|
||||
inferred_value = inferred_value[0]
|
||||
return inferred_value
|
||||
|
||||
# Use the generic conversion from typespec to a placeholder.
|
||||
def component_to_placeholder(component):
|
||||
return array_ops.placeholder(component.dtype, component.shape)
|
||||
|
||||
return nest.map_structure(
|
||||
component_to_placeholder, self.type_spec, expand_composites=True)
|
||||
|
||||
def get_shape(self):
|
||||
return self.shape
|
||||
|
||||
@ -298,26 +377,27 @@ class KerasTensor(object):
|
||||
return self._name
|
||||
|
||||
@classmethod
|
||||
def _overload_all_operators(cls): # pylint: disable=invalid-name
|
||||
def _overload_all_operators(cls, tensor_class): # pylint: disable=invalid-name
|
||||
"""Register overloads for all operators."""
|
||||
for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
|
||||
cls._overload_operator(operator)
|
||||
cls._overload_operator(tensor_class, operator)
|
||||
|
||||
# We include `experimental_ref` for versions of TensorFlow that
|
||||
# still include the deprecated method in Tensors.
|
||||
if hasattr(ops.Tensor, 'experimental_ref'):
|
||||
cls._overload_operator('experimental_ref')
|
||||
if hasattr(tensor_class, 'experimental_ref'):
|
||||
cls._overload_operator(tensor_class, 'experimental_ref')
|
||||
|
||||
@classmethod
|
||||
def _overload_operator(cls, operator): # pylint: disable=invalid-name
|
||||
"""Overload an operator with the same overloading as `ops.Tensor`.
|
||||
def _overload_operator(cls, tensor_class, operator): # pylint: disable=invalid-name
|
||||
"""Overload an operator with the same implementation as a base Tensor class.
|
||||
|
||||
We pull the operator out of ops.Tensor dynamically to avoid ordering issues.
|
||||
We pull the operator out of the class dynamically to avoid ordering issues.
|
||||
|
||||
Args:
|
||||
tensor_class: The (Composite)Tensor to get the method from.
|
||||
operator: string. The operator name.
|
||||
"""
|
||||
tensor_oper = getattr(ops.Tensor, operator)
|
||||
tensor_oper = getattr(tensor_class, operator)
|
||||
|
||||
# Compatibility with Python 2:
|
||||
# Python 2 unbound methods have type checks for the first arg,
|
||||
@ -327,81 +407,91 @@ class KerasTensor(object):
|
||||
setattr(cls, operator, tensor_oper)
|
||||
|
||||
|
||||
KerasTensor._overload_all_operators() # pylint: disable=protected-access
|
||||
KerasTensor._overload_all_operators(ops.Tensor) # pylint: disable=protected-access
|
||||
|
||||
|
||||
class _KerasTensorIterator(object):
|
||||
"""Iterates over the leading dim of a KerasTensor. Performs 0 error checks."""
|
||||
class SparseKerasTensor(KerasTensor):
|
||||
"""A specialized KerasTensor representation for `tf.sparse.SparseTensor`s.
|
||||
|
||||
def __init__(self, tensor, dim0):
|
||||
self._tensor = tensor
|
||||
self._index = 0
|
||||
self._limit = dim0
|
||||
Specifically, it specializes the conversion to a placeholder in order
|
||||
to maintain dense shape information.
|
||||
"""
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
def _to_placeholder(self):
|
||||
spec = self.type_spec
|
||||
|
||||
def __next__(self):
|
||||
if self._index == self._limit:
|
||||
raise StopIteration
|
||||
result = self._tensor[self._index]
|
||||
self._index += 1
|
||||
# nest.map_structure loses dense shape information for sparse tensors.
|
||||
# So, we special-case sparse placeholder creation.
|
||||
# This only preserves shape information for top-level sparse tensors;
|
||||
# not for sparse tensors that are nested inside another composite
|
||||
# tensor.
|
||||
return array_ops.sparse_placeholder(dtype=spec.dtype, shape=spec.shape)
|
||||
|
||||
|
||||
class RaggedKerasTensor(KerasTensor):
|
||||
"""A specialized KerasTensor representation for `tf.RaggedTensor`s.
|
||||
|
||||
Specifically, it:
|
||||
|
||||
1. Specializes the conversion to a placeholder in order
|
||||
to maintain shape information for non-ragged dimensions.
|
||||
2. Overloads the KerasTensor's operators with the RaggedTensor versions
|
||||
when they don't match the `tf.Tensor` versions
|
||||
3. Exposes some of the instance method/attribute that are unique to
|
||||
the RaggedTensor API (such as ragged_rank).
|
||||
"""
|
||||
|
||||
def _to_placeholder(self):
|
||||
ragged_spec = self.type_spec
|
||||
if ragged_spec.ragged_rank == 0 or ragged_spec.shape.rank is None:
|
||||
return super(RaggedKerasTensor, self)._to_placeholder()
|
||||
|
||||
flat_shape = ragged_spec.shape[ragged_spec.ragged_rank:]
|
||||
result = array_ops.placeholder(ragged_spec.dtype, flat_shape)
|
||||
|
||||
known_num_splits = []
|
||||
prod = 1
|
||||
for axis_size in ragged_spec.shape:
|
||||
if prod is not None:
|
||||
if axis_size is None or (
|
||||
getattr(axis_size, 'value', True) is None):
|
||||
prod = None
|
||||
else:
|
||||
prod = prod * axis_size
|
||||
known_num_splits.append(prod)
|
||||
|
||||
for axis in range(ragged_spec.ragged_rank, 0, -1):
|
||||
axis_size = ragged_spec.shape[axis]
|
||||
if axis_size is None or (getattr(axis_size, 'value', True) is None):
|
||||
num_splits = known_num_splits[axis-1]
|
||||
if num_splits is not None:
|
||||
num_splits = num_splits + 1
|
||||
splits = array_ops.placeholder(
|
||||
ragged_spec.row_splits_dtype, [num_splits])
|
||||
result = ragged_tensor.RaggedTensor.from_row_splits(
|
||||
result, splits, validate=False)
|
||||
else:
|
||||
rowlen = constant_op.constant(axis_size, ragged_spec.row_splits_dtype)
|
||||
result = ragged_tensor.RaggedTensor.from_uniform_row_length(
|
||||
result, rowlen, validate=False)
|
||||
return result
|
||||
|
||||
next = __next__ # python2.x compatibility.
|
||||
|
||||
|
||||
def keras_tensor_to_placeholder(x):
|
||||
"""Construct a graph placeholder to represent a KerasTensor when tracing."""
|
||||
if hasattr(x, '_user_registered_symbolic_object'):
|
||||
return x._user_registered_symbolic_object # pylint: disable=protected-access
|
||||
|
||||
if isinstance(x, KerasTensor):
|
||||
spec = x.type_spec
|
||||
|
||||
if x._inferred_value is not None: # pylint: disable=protected-access
|
||||
# If we suspect this KerasTensor might be representing a shape tensor,
|
||||
# and we were able to extract value information with TensorFlow's shape
|
||||
# handling when making the KerasTensor, we construct the placeholder by
|
||||
# re-injecting the inferred value information into the graph.
|
||||
# Even though keras layers each trace in their own scratch
|
||||
# graph, this shape value info injection allows us to capture
|
||||
# a sizable and useful subset of the C++ shape value inference TF can do
|
||||
# if all tf ops appear in the same graph when using shape ops.
|
||||
#
|
||||
# Examples of things this cannot infer concrete dimensions for
|
||||
# that the full single-graph C++ shape inference sometimes can are:
|
||||
# * cases where the shape tensor is cast out of int32 before being
|
||||
# manipulated w/ floating point numbers then converted back
|
||||
# * cases where int32 tensors w/ rank > 2 are manipulated before being
|
||||
# used as a shape tensor
|
||||
inferred_value = array_ops.shape(
|
||||
array_ops.placeholder(
|
||||
shape=x._inferred_value, dtype=dtypes.int32)) # pylint: disable=protected-access
|
||||
if spec.shape.rank == 0:
|
||||
# `tf.shape` always returns a rank-1, we may need to turn it back to a
|
||||
# scalar.
|
||||
inferred_value = inferred_value[0]
|
||||
return inferred_value # pylint: disable=protected-access
|
||||
|
||||
if isinstance(spec, sparse_tensor.SparseTensorSpec):
|
||||
# nest.map_structure loses dense shape information for sparse tensors.
|
||||
# So, we special-case sparse placeholder creation.
|
||||
# This only preserves shape information for top-level sparse tensors;
|
||||
# not for sparse tensors that are nested inside another composite
|
||||
# tensor.
|
||||
return array_ops.sparse_placeholder(dtype=spec.dtype, shape=spec.shape)
|
||||
|
||||
def component_to_placeholder(component):
|
||||
return array_ops.placeholder(component.dtype, component.shape)
|
||||
|
||||
ph = nest.map_structure(
|
||||
component_to_placeholder, spec, expand_composites=True)
|
||||
return ph
|
||||
else:
|
||||
return x
|
||||
@property
|
||||
def ragged_rank(self):
|
||||
return self.type_spec.ragged_rank
|
||||
|
||||
RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__add__') # pylint: disable=protected-access
|
||||
RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__radd__') # pylint: disable=protected-access
|
||||
RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__mul__') # pylint: disable=protected-access
|
||||
RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__rmul__') # pylint: disable=protected-access
|
||||
|
||||
|
||||
# TODO(b/161487382):
|
||||
# Special-case user-registered symbolic objects (registered by the
|
||||
# private `register_symbolic_tensor_type` method) by passing them between
|
||||
# scratch graphs directly.
|
||||
# This is needed to not break Tensorflow probability
|
||||
# while they finish migrating to composite tensors.
|
||||
class UserRegisteredSpec(type_spec_module.TypeSpec):
|
||||
"""TypeSpec to represent user-registered symbolic objects."""
|
||||
|
||||
@ -425,72 +515,95 @@ class UserRegisteredSpec(type_spec_module.TypeSpec):
|
||||
def value_type(self):
|
||||
raise NotImplementedError
|
||||
|
||||
# Tensorflow tensors have a maximum dimension of 254
|
||||
# (See //tensorflow/core/framework/tensor_shape.h )
|
||||
# So we do not try to infer values for int32 tensors larger than this,
|
||||
# As they cannot represent shapes.
|
||||
_MAX_TENSOR_DIMS = 254
|
||||
|
||||
# TODO(b/161487382):
|
||||
# Special-case user-registered symbolic objects (registered by the
|
||||
# private `register_symbolic_tensor_type` method) by passing them between
|
||||
# scratch graphs directly.
|
||||
# This is needed to not break Tensorflow probability
|
||||
# while they finish migrating to composite tensors.
|
||||
class UserRegisteredTypeKerasTensor(KerasTensor):
|
||||
"""KerasTensor that represents legacy register_symbolic_tensor_type."""
|
||||
|
||||
def keras_tensor_from_tensor(x):
|
||||
"""Convert a traced (composite)tensor to a representative KerasTensor."""
|
||||
name = getattr(x, 'name', None)
|
||||
inferred_value = None
|
||||
|
||||
# TODO(b/161487382):
|
||||
# Special-case user-registered symbolic objects (registered by the
|
||||
# private `register_symbolic_tensor_type` method) by passing them between
|
||||
# scratch graphs directly.
|
||||
# This is needed to not break Tensorflow probability
|
||||
# while they finish migrating to composite tensors.
|
||||
user_registered_symbolic = False
|
||||
try:
|
||||
from tensorflow.python.keras.utils import tf_utils # pylint: disable=g-import-not-at-top to prevent circular imports
|
||||
if isinstance(x, tuple(tf_utils._user_convertible_tensor_types)): # pylint: disable=protected-access
|
||||
user_registered_symbolic = True
|
||||
except ImportError:
|
||||
pass
|
||||
if user_registered_symbolic:
|
||||
def __init__(self, user_registered_symbolic_object):
|
||||
x = user_registered_symbolic_object
|
||||
self._user_registered_symbolic_object = x
|
||||
type_spec = UserRegisteredSpec(x.shape, x.dtype)
|
||||
name = getattr(x, 'name', None)
|
||||
|
||||
super(UserRegisteredTypeKerasTensor, self).__init__(type_spec, name)
|
||||
|
||||
@classmethod
|
||||
def from_tensor(cls, tensor):
|
||||
return cls(tensor)
|
||||
|
||||
def _to_placeholder(self):
|
||||
return self._user_registered_symbolic_object
|
||||
|
||||
|
||||
class _KerasTensorIterator(object):
|
||||
"""Iterates over the leading dim of a KerasTensor. Performs 0 error checks."""
|
||||
|
||||
def __init__(self, tensor, dim0):
|
||||
self._tensor = tensor
|
||||
self._index = 0
|
||||
self._limit = dim0
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self._index == self._limit:
|
||||
raise StopIteration
|
||||
result = self._tensor[self._index]
|
||||
self._index += 1
|
||||
return result
|
||||
|
||||
next = __next__ # python2.x compatibility.
|
||||
|
||||
|
||||
# Specify the mappings of tensor class to KerasTensor class.
|
||||
# This is specifically a list instead of a dict for now because
|
||||
# 1. we do a check w/ isinstance because a key lookup based on class
|
||||
# would miss subclasses
|
||||
# 2. a list allows us to control lookup ordering
|
||||
# We include ops.Tensor -> KerasTensor in the first position as a fastpath,
|
||||
# *and* include object -> KerasTensor at the end as a catch-all.
|
||||
# We can re-visit these choices in the future as needed.
|
||||
keras_tensor_classes = [
|
||||
(ops.Tensor, KerasTensor),
|
||||
(sparse_tensor.SparseTensor, SparseKerasTensor),
|
||||
(ragged_tensor.RaggedTensor, RaggedKerasTensor),
|
||||
(object, KerasTensor)
|
||||
]
|
||||
|
||||
|
||||
def register_keras_tensor_specialization(cls, keras_tensor_subclass):
|
||||
"""Register a specialized KerasTensor subclass for a Tensor type."""
|
||||
# We always leave (object, KerasTensor) at the end as a generic fallback
|
||||
keras_tensor_classes.insert(-1, (cls, keras_tensor_subclass))
|
||||
|
||||
|
||||
def keras_tensor_to_placeholder(x):
|
||||
"""Construct a graph placeholder to represent a KerasTensor when tracing."""
|
||||
if isinstance(x, KerasTensor):
|
||||
return x._to_placeholder() # pylint: disable=protected-access
|
||||
else:
|
||||
type_spec = type_spec_module.type_spec_from_value(x)
|
||||
return x
|
||||
|
||||
if (isinstance(type_spec, tensor_spec.TensorSpec)
|
||||
and type_spec.dtype == dtypes.int32
|
||||
and type_spec.shape.rank < 2):
|
||||
# If this tensor might be representing shape information,
|
||||
# (dtype=int32, rank of 0 or 1, not too large to represent a shape)
|
||||
# we attempt to capture any value information tensorflow's
|
||||
# shape handling can extract from the current scratch graph.
|
||||
#
|
||||
# Even though keras layers each trace in their own scratch
|
||||
# graph, this shape value info extraction allows us to capture
|
||||
# a sizable and useful subset of the C++ shape value inference TF can do
|
||||
# if all tf ops appear in the same graph when using shape ops.
|
||||
#
|
||||
# Examples of things this cannot infer concrete dimensions for
|
||||
# that the full single-graph C++ shape inference sometimes can are:
|
||||
# * cases where the shape tensor is cast out of int32 before being
|
||||
# manipulated w/ floating point numbers then converted back
|
||||
# * cases where int32 tensors w/ rank > 2 are manipulated before being
|
||||
# used as a shape tensor
|
||||
# * cases where int32 tensors too large to represent shapes are manipulated
|
||||
# to a smaller size before being used as a shape tensor
|
||||
inferred_value = array_ops.ones(shape=x).shape
|
||||
if inferred_value.dims:
|
||||
inferred_value = inferred_value.as_list()
|
||||
if len(inferred_value) > _MAX_TENSOR_DIMS:
|
||||
inferred_value = None
|
||||
else:
|
||||
inferred_value = None
|
||||
|
||||
out = KerasTensor(type_spec,
|
||||
inferred_value=inferred_value, name=name)
|
||||
if user_registered_symbolic:
|
||||
out._user_registered_symbolic_object = x # pylint: disable=protected-access
|
||||
def keras_tensor_from_tensor(tensor):
|
||||
"""Convert a traced (composite)tensor to a representative KerasTensor."""
|
||||
# Create a specialized KerasTensor that supports instance methods,
|
||||
# operators, and additional value inference if possible
|
||||
keras_tensor_cls = None
|
||||
for tensor_type, cls in keras_tensor_classes:
|
||||
if isinstance(tensor, tensor_type):
|
||||
keras_tensor_cls = cls
|
||||
break
|
||||
|
||||
if hasattr(x, '_keras_mask'):
|
||||
out._keras_mask = KerasTensor( # pylint: disable=protected-access
|
||||
type_spec_module.type_spec_from_value(x._keras_mask)) # pylint: disable=protected-access
|
||||
out = keras_tensor_cls.from_tensor(tensor)
|
||||
|
||||
if hasattr(tensor, '_keras_mask'):
|
||||
out._keras_mask = keras_tensor_from_tensor(tensor._keras_mask) # pylint: disable=protected-access
|
||||
return out
|
||||
|
96
tensorflow/python/keras/engine/ragged_keras_tensor_test.py
Normal file
96
tensorflow/python/keras/engine/ragged_keras_tensor_test.py
Normal file
@ -0,0 +1,96 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""RaggedKerasTensor tests."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.framework import func_graph
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import layers
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class RaggedKerasTensorTest(keras_parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(
|
||||
{'batch_size': None, 'shape': (None, 5), 'ragged_rank': 1},
|
||||
{'batch_size': None, 'shape': (None, 3, 5), 'ragged_rank': 1},
|
||||
{'batch_size': None, 'shape': (5, None), 'ragged_rank': 2},
|
||||
{'batch_size': None, 'shape': (3, 5, None), 'ragged_rank': 3},
|
||||
{'batch_size': None, 'shape': (None, 3, 5, None), 'ragged_rank': 4},
|
||||
{'batch_size': None, 'shape': (2, 3, None, 4, 5, None), 'ragged_rank': 6},
|
||||
{'batch_size': 8, 'shape': (None, 5), 'ragged_rank': 1},
|
||||
{'batch_size': 9, 'shape': (None, 3, 5), 'ragged_rank': 1},
|
||||
{'batch_size': 1, 'shape': (5, None), 'ragged_rank': 2},
|
||||
{'batch_size': 4, 'shape': (3, 5, None), 'ragged_rank': 3},
|
||||
{'batch_size': 7, 'shape': (None, 3, 5, None), 'ragged_rank': 4},
|
||||
{'batch_size': 12, 'shape': (2, 3, None, 4, 5, None), 'ragged_rank': 6},
|
||||
)
|
||||
def test_to_placeholder(self, shape, batch_size, ragged_rank):
|
||||
with testing_utils.use_keras_tensors_scope(True):
|
||||
inp = layers.Input(shape=shape, batch_size=batch_size, ragged=True)
|
||||
self.assertEqual(inp.ragged_rank, ragged_rank)
|
||||
self.assertAllEqual(inp.shape, [batch_size] + list(shape))
|
||||
with func_graph.FuncGraph('test').as_default():
|
||||
placeholder = inp._to_placeholder()
|
||||
self.assertEqual(placeholder.ragged_rank, ragged_rank)
|
||||
self.assertAllEqual(placeholder.shape, [batch_size] + list(shape))
|
||||
|
||||
def test_add(self):
|
||||
inp = layers.Input(shape=[None], ragged=True)
|
||||
out = inp + inp
|
||||
model = training.Model(inp, out)
|
||||
|
||||
x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]])
|
||||
self.assertAllEqual(model(x), x + x)
|
||||
|
||||
def test_mul(self):
|
||||
inp = layers.Input(shape=[None], ragged=True)
|
||||
out = inp * inp
|
||||
model = training.Model(inp, out)
|
||||
|
||||
x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]])
|
||||
self.assertAllEqual(model(x), x * x)
|
||||
|
||||
def test_sub(self):
|
||||
inp = layers.Input(shape=[None], ragged=True)
|
||||
out = inp - inp
|
||||
model = training.Model(inp, out)
|
||||
|
||||
x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]])
|
||||
self.assertAllEqual(model(x), x - x)
|
||||
|
||||
def test_div(self):
|
||||
inp = layers.Input(shape=[None], ragged=True)
|
||||
out = inp / inp
|
||||
model = training.Model(inp, out)
|
||||
|
||||
x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]])
|
||||
self.assertAllEqual(model(x), x / x)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
tensor_shape.enable_v2_tensorshape()
|
||||
test.main()
|
@ -158,15 +158,15 @@ def _int32_manipulation_at_max_shape_dims_limit():
|
||||
# of the max tensor size Keras can try inferring values for.
|
||||
inputs = keras.Input(batch_size=2, shape=(10,))
|
||||
batch_size = array_ops.shape(inputs)[0]
|
||||
num_features = int(keras_tensor._MAX_TENSOR_DIMS / int(inputs.shape[0]))
|
||||
num_features = int(keras_tensor._MAX_TENSOR_RANK / int(inputs.shape[0]))
|
||||
x = math_ops.range(batch_size * num_features, dtype='int32')
|
||||
assert x.shape.as_list() == [keras_tensor._MAX_TENSOR_DIMS]
|
||||
assert x.shape.as_list() == [keras_tensor._MAX_TENSOR_RANK]
|
||||
|
||||
# Verify that a value was actually inferred for a tensor that *might*
|
||||
# represent the shape, bying checking that a value in
|
||||
# the range appears in the printed inferred value
|
||||
if keras_tensor.keras_tensors_enabled():
|
||||
assert str(keras_tensor._MAX_TENSOR_DIMS - 1) in str(x)
|
||||
assert str(keras_tensor._MAX_TENSOR_RANK - 1) in str(x)
|
||||
|
||||
x = array_ops.reshape(x, (batch_size, num_features))
|
||||
x = math_ops.cast(x, dtype='float32')
|
||||
|
@ -30,6 +30,7 @@ from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.framework import type_spec
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras.engine import keras_tensor
|
||||
from tensorflow.python.keras.utils import tf_contextlib
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
@ -363,6 +364,9 @@ def register_symbolic_tensor_type(cls):
|
||||
cls: A `class` type which shall be regarded as a symbolic `Tensor`.
|
||||
"""
|
||||
global _user_convertible_tensor_types
|
||||
if cls not in _user_convertible_tensor_types:
|
||||
keras_tensor.register_keras_tensor_specialization(
|
||||
cls, keras_tensor.UserRegisteredTypeKerasTensor)
|
||||
_user_convertible_tensor_types.add(cls)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user