Put a size limit on the int32 tensors KerasTensors will try to infer values for. This is needed because there is a maximum rank limit (of 254) for Tensors, so int32 tensors with more than 254 elements cannot represent shapes. (Before this change, such tensors would cause KerasTensor shape inference to crash.)
PiperOrigin-RevId: 323860520 Change-Id: Icbf10b8220739c5a9474f6f588174606402b9ca8
This commit is contained in:
parent
8295742853
commit
d6066885d7
@ -433,6 +433,12 @@ class UserRegisteredSpec(type_spec_module.TypeSpec):
|
||||
def value_type(self):
|
||||
raise NotImplementedError
|
||||
|
||||
# Tensorflow tensors have a maximum dimension of 254
|
||||
# (See //tensorflow/core/framework/tensor_shape.h )
|
||||
# So we do not try to infer values for int32 tensors larger than this,
|
||||
# As they cannot represent shapes.
|
||||
_MAX_TENSOR_DIMS = 254
|
||||
|
||||
|
||||
def keras_tensor_from_tensor(x):
|
||||
"""Convert a traced (composite)tensor to a representative KerasTensor."""
|
||||
@ -461,7 +467,7 @@ def keras_tensor_from_tensor(x):
|
||||
and type_spec.dtype == dtypes.int32
|
||||
and type_spec.shape.rank < 2):
|
||||
# If this tensor might be representing shape information,
|
||||
# (dtype=int32, rank of 0 or 1)
|
||||
# (dtype=int32, rank of 0 or 1, not too large to represent a shape)
|
||||
# we attempt to capture any value information tensorflow's
|
||||
# shape handling can extract from the current scratch graph.
|
||||
#
|
||||
@ -476,9 +482,13 @@ def keras_tensor_from_tensor(x):
|
||||
# manipulated w/ floating point numbers then converted back
|
||||
# * cases where int32 tensors w/ rank > 2 are manipulated before being
|
||||
# used as a shape tensor
|
||||
# * cases where int32 tensors too large to represent shapes are manipulated
|
||||
# to a smaller size before being used as a shape tensor
|
||||
inferred_value = array_ops.ones(shape=x).shape
|
||||
if inferred_value.dims:
|
||||
inferred_value = inferred_value.as_list()
|
||||
if len(inferred_value) > _MAX_TENSOR_DIMS:
|
||||
inferred_value = None
|
||||
else:
|
||||
inferred_value = None
|
||||
|
||||
|
@ -131,6 +131,55 @@ def _shape_op_slice_and_range_known_dim():
|
||||
return keras.Model(inputs, inputs)
|
||||
|
||||
|
||||
def _int32_manipulation_too_big_for_shape():
  """Model-builder used as a regression case: manipulating an int32 tensor
  that is far too large to ever represent a shape must not crash the Keras
  Functional API's value-inference machinery.
  """
  inputs = keras.Input(batch_size=2, shape=(10,))
  batch = array_ops.shape(inputs)[0]
  feature_count = 3 * 1024 * 16
  flat = math_ops.range(batch * feature_count, dtype='int32')
  # Sanity-check the traced static shape before reshaping.
  assert flat.shape.as_list() == [inputs.shape[0] * feature_count]
  reshaped = array_ops.reshape(flat, (batch, feature_count))
  as_float = math_ops.cast(reshaped, dtype='float32')
  outputs = keras.layers.Dense(10)(as_float)
  if not context.executing_eagerly():
    # In V1 graph mode the op layer fails for an unknown reason. This helper
    # has no handle on the test case, so it cannot call self.skip_test here;
    # return a trivial identity model instead.
    return keras.Model(inputs, inputs)
  return keras.Model(inputs, outputs)
|
||||
|
||||
|
||||
def _int32_manipulation_at_max_shape_dims_limit():
  """Model-builder used as a regression case: manipulating an int32 tensor
  whose size sits exactly at `keras_tensor._MAX_TENSOR_DIMS` — the largest
  size for which Keras still attempts static value inference — must not
  crash the Functional API.
  """
  inputs = keras.Input(batch_size=2, shape=(10,))
  batch = array_ops.shape(inputs)[0]
  feature_count = int(keras_tensor._MAX_TENSOR_DIMS / int(inputs.shape[0]))
  flat = math_ops.range(batch * feature_count, dtype='int32')
  # The flat tensor must land exactly on the inference size limit.
  assert flat.shape.as_list() == [keras_tensor._MAX_TENSOR_DIMS]

  # Verify a value really was inferred for a tensor that *might* encode a
  # shape, by checking that a value from the range shows up in the repr.
  if keras_tensor.keras_tensors_enabled():
    assert str(keras_tensor._MAX_TENSOR_DIMS - 1) in str(flat)

  reshaped = array_ops.reshape(flat, (batch, feature_count))
  as_float = math_ops.cast(reshaped, dtype='float32')
  outputs = keras.layers.Dense(10)(as_float)
  if not context.executing_eagerly():
    # In V1 graph mode the op layer fails for an unknown reason. This helper
    # has no handle on the test case, so it cannot call self.skip_test here;
    # return a trivial identity model instead.
    return keras.Model(inputs, inputs)
  return keras.Model(inputs, outputs)
|
||||
|
||||
|
||||
def _single_standalone_branch():
|
||||
inputs = keras.Input(shape=(10,))
|
||||
x = keras.layers.Dense(10)(inputs)
|
||||
@ -252,6 +301,10 @@ class AutoLambdaTest(keras_parameterized.TestCase):
|
||||
('shape_op_slice_and_range', _shape_op_slice_and_range),
|
||||
('shape_op_slice_and_range_known_dim',
|
||||
_shape_op_slice_and_range_known_dim),
|
||||
('int32_manipulation_too_big_for_shape',
|
||||
_int32_manipulation_too_big_for_shape),
|
||||
('int32_manipulation_at_max_shape_dims_limit',
|
||||
_int32_manipulation_at_max_shape_dims_limit),
|
||||
('single_standalone_branch', _single_standalone_branch),
|
||||
('single_op_with_attrs', _single_op_with_attrs),
|
||||
('multiple_uses', _multiple_uses),
|
||||
|
Loading…
x
Reference in New Issue
Block a user