Fix the cache key problem when compute_output_shape().
This is a very tricky one wrt the id() of int in python. Under the hood, id returns memory address for the int, and python has a cache location for the ints, which result into different ints get same hash value. Changed to use tuples of shape itself as the dict key, since the tuple itself is immutable and hashable. Same tuple value will return the same hash value. Also remove the generic utils for that where network.py is only usage for that function. Fix #32029 PiperOrigin-RevId: 296302946 Change-Id: I865c9380a06ed6ee80fea7f942c21c4d102473c2
This commit is contained in:
parent
4030aa1fe5
commit
3ba8bd697f
@ -720,7 +720,9 @@ class Network(base_layer.Layer):
|
||||
': model has ' + str(len(self._input_layers)) +
|
||||
' tensor inputs.')
|
||||
|
||||
cache_key = generic_utils.object_list_uid(input_shape)
|
||||
# Use the tuple of TensorShape as the cache key, since tuple is hashable
|
||||
# and can be used as hash key.
|
||||
cache_key = tuple(tf_utils.convert_shapes(input_shape, to_tuples=True))
|
||||
if cache_key in self._output_shape_cache:
|
||||
# Cache hit. Return shapes as TensorShapes.
|
||||
return self._output_shape_cache[cache_key]
|
||||
@ -905,7 +907,7 @@ class Network(base_layer.Layer):
|
||||
|
||||
if output_shapes is not None:
|
||||
input_shapes = [x.shape for x in inputs]
|
||||
cache_key = generic_utils.object_list_uid(input_shapes)
|
||||
cache_key = tuple(tf_utils.convert_shapes(input_shapes, to_tuples=True))
|
||||
self._output_shape_cache[cache_key] = nest.pack_sequence_as(
|
||||
self._nested_outputs, output_shapes)
|
||||
|
||||
|
@ -1869,6 +1869,15 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(network.dynamic, False)
|
||||
self.assertEqual(network.stateful, False)
|
||||
|
||||
def test_compute_output_shape_cache(self):
|
||||
# See https://github.com/tensorflow/tensorflow/issues/32029.
|
||||
x = input_layer_lib.Input(shape=(None, 32))
|
||||
dense = keras.layers.Dense(2)
|
||||
y = dense(x)
|
||||
network = network_lib.Network(x, y, name='dense_network')
|
||||
|
||||
for i in range(999, 1024):
|
||||
self.assertEqual(network.compute_output_shape((1, i, 32)), (1, i, 2))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -756,12 +756,6 @@ def to_list(x):
|
||||
return [x]
|
||||
|
||||
|
||||
def object_list_uid(object_list):
|
||||
"""Creates a single string from object ids."""
|
||||
object_list = nest.flatten(object_list)
|
||||
return ', '.join(str(abs(id(x))) for x in object_list)
|
||||
|
||||
|
||||
def to_snake_case(name):
|
||||
intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name)
|
||||
insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower()
|
||||
|
Loading…
Reference in New Issue
Block a user