Put a size limit on the int32 tensors for which KerasTensors will try to infer values. This is needed because Tensors have a maximum rank of 254, so an int32 tensor with more than 254 elements cannot represent a shape. (Before this CL, such tensors would cause KerasTensor shape inference to crash.)
PiperOrigin-RevId: 323860520
Change-Id: Icbf10b8220739c5a9474f6f588174606402b9ca8
parent 8295742853
commit d6066885d7
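For context, here is a minimal sketch of the rule this CL enforces: only int32 tensors of rank 0 or 1 with at most 254 elements are treated as possibly representing a shape, and value inference is skipped for anything larger. The helper name might_represent_a_shape is illustrative only and not part of the Keras internals; the actual implementation (shown in the diff below) applies the limit to the inferred value rather than up front.

import tensorflow as tf

# Maximum number of dimensions a TensorFlow tensor may have
# (see //tensorflow/core/framework/tensor_shape.h).
_MAX_TENSOR_DIMS = 254

def might_represent_a_shape(t):
  # Hypothetical helper mirroring the guard this CL adds: only int32
  # tensors of rank 0 or 1 whose element count does not exceed
  # _MAX_TENSOR_DIMS can plausibly describe a shape.
  if t.dtype != tf.int32:
    return False
  if t.shape.rank is None or t.shape.rank >= 2:
    return False
  num_elements = t.shape.num_elements()  # None when the size is unknown
  if num_elements is not None and num_elements > _MAX_TENSOR_DIMS:
    return False
  return True

# A 49152-element int32 tensor (as in the new test) is manipulable data,
# not a shape, so value inference should be skipped for it.
print(might_represent_a_shape(tf.range(3 * 1024 * 16, dtype=tf.int32)))  # False
print(might_represent_a_shape(tf.constant([2, 10], dtype=tf.int32)))     # True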
@@ -433,6 +433,12 @@ class UserRegisteredSpec(type_spec_module.TypeSpec):
   def value_type(self):
     raise NotImplementedError
 
+# Tensorflow tensors have a maximum dimension of 254
+# (See //tensorflow/core/framework/tensor_shape.h )
+# So we do not try to infer values for int32 tensors larger than this,
+# As they cannot represent shapes.
+_MAX_TENSOR_DIMS = 254
+
 
 def keras_tensor_from_tensor(x):
   """Convert a traced (composite)tensor to a representative KerasTensor."""
@@ -461,7 +467,7 @@ def keras_tensor_from_tensor(x):
       and type_spec.dtype == dtypes.int32
       and type_spec.shape.rank < 2):
     # If this tensor might be representing shape information,
-    # (dtype=int32, rank of 0 or 1)
+    # (dtype=int32, rank of 0 or 1, not too large to represent a shape)
     # we attempt to capture any value information tensorflow's
     # shape handling can extract from the current scratch graph.
     #
@@ -476,9 +482,13 @@ def keras_tensor_from_tensor(x):
    #   manipulated w/ floating point numbers then converted back
    # * cases where int32 tensors w/ rank > 2 are manipulated before being
    #   used as a shape tensor
+   # * cases where int32 tensors too large to represent shapes are manipulated
+   #   to a smaller size before being used as a shape tensor
    inferred_value = array_ops.ones(shape=x).shape
    if inferred_value.dims:
      inferred_value = inferred_value.as_list()
+     if len(inferred_value) > _MAX_TENSOR_DIMS:
+       inferred_value = None
    else:
      inferred_value = None
 
@@ -131,6 +131,55 @@ def _shape_op_slice_and_range_known_dim():
   return keras.Model(inputs, inputs)
 
 
+def _int32_manipulation_too_big_for_shape():
+  # This test verifies that the Keras Functional API
+  # won't crash when manipulating int32 tensors that are too large
+  # to represent shapes.
+  inputs = keras.Input(batch_size=2, shape=(10,))
+  batch_size = array_ops.shape(inputs)[0]
+  num_features = 3 * 1024 * 16
+  x = math_ops.range(batch_size * num_features, dtype='int32')
+  assert x.shape.as_list() == [inputs.shape[0] * num_features]
+  x = array_ops.reshape(x, (batch_size, num_features))
+  x = math_ops.cast(x, dtype='float32')
+  outputs = keras.layers.Dense(10)(x)
+  if context.executing_eagerly():
+    return keras.Model(inputs, outputs)
+  else:
+    # In V1 the op layer fails for some reason,
+    # but we don't have access to the test case to call
+    # self.skip_test in this util method
+    return keras.Model(inputs, inputs)
+
+
+def _int32_manipulation_at_max_shape_dims_limit():
+  # This test verifies that the Keras Functional API
+  # won't crash when manipulating int32 tensors that are at the limit
+  # of the max tensor size Keras can try inferring values for.
+  inputs = keras.Input(batch_size=2, shape=(10,))
+  batch_size = array_ops.shape(inputs)[0]
+  num_features = int(keras_tensor._MAX_TENSOR_DIMS / int(inputs.shape[0]))
+  x = math_ops.range(batch_size * num_features, dtype='int32')
+  assert x.shape.as_list() == [keras_tensor._MAX_TENSOR_DIMS]
+
+  # Verify that a value was actually inferred for a tensor that *might*
+  # represent the shape, by checking that a value in
+  # the range appears in the printed inferred value
+  if keras_tensor.keras_tensors_enabled():
+    assert str(keras_tensor._MAX_TENSOR_DIMS - 1) in str(x)
+
+  x = array_ops.reshape(x, (batch_size, num_features))
+  x = math_ops.cast(x, dtype='float32')
+  outputs = keras.layers.Dense(10)(x)
+  if context.executing_eagerly():
+    return keras.Model(inputs, outputs)
+  else:
+    # In V1 the op layer fails for some reason,
+    # but we don't have access to the test case to call
+    # self.skip_test in this util method
+    return keras.Model(inputs, inputs)
+
+
 def _single_standalone_branch():
   inputs = keras.Input(shape=(10,))
   x = keras.layers.Dense(10)(inputs)
@@ -252,6 +301,10 @@ class AutoLambdaTest(keras_parameterized.TestCase):
       ('shape_op_slice_and_range', _shape_op_slice_and_range),
       ('shape_op_slice_and_range_known_dim',
        _shape_op_slice_and_range_known_dim),
+      ('int32_manipulation_too_big_for_shape',
+       _int32_manipulation_too_big_for_shape),
+      ('int32_manipulation_at_max_shape_dims_limit',
+       _int32_manipulation_at_max_shape_dims_limit),
       ('single_standalone_branch', _single_standalone_branch),
       ('single_op_with_attrs', _single_op_with_attrs),
       ('multiple_uses', _multiple_uses),