Fix the cache key problem in compute_output_shape().

This is a very tricky one involving the id() of ints in Python. Under the hood,
id() returns the memory address of the int object, and Python caches and reuses
memory locations for ints, which can result in different ints getting the same
id() and hence the same cache key.

Changed to use a tuple of the shapes themselves as the dict key, since a tuple
is immutable and hashable. Tuples with equal values return the same hash
value.

Also remove the generic_utils helper (object_list_uid) that produced the old
key, since network.py was the only user of that function.

Fix #32029

PiperOrigin-RevId: 296302946
Change-Id: I865c9380a06ed6ee80fea7f942c21c4d102473c2
This commit is contained in:
Scott Zhu 2020-02-20 14:50:37 -08:00 committed by TensorFlower Gardener
parent 4030aa1fe5
commit 3ba8bd697f
3 changed files with 13 additions and 8 deletions

View File

@ -720,7 +720,9 @@ class Network(base_layer.Layer):
': model has ' + str(len(self._input_layers)) + ': model has ' + str(len(self._input_layers)) +
' tensor inputs.') ' tensor inputs.')
cache_key = generic_utils.object_list_uid(input_shape) # Use the tuple of TensorShape as the cache key, since tuple is hashable
# and can be used as hash key.
cache_key = tuple(tf_utils.convert_shapes(input_shape, to_tuples=True))
if cache_key in self._output_shape_cache: if cache_key in self._output_shape_cache:
# Cache hit. Return shapes as TensorShapes. # Cache hit. Return shapes as TensorShapes.
return self._output_shape_cache[cache_key] return self._output_shape_cache[cache_key]
@ -905,7 +907,7 @@ class Network(base_layer.Layer):
if output_shapes is not None: if output_shapes is not None:
input_shapes = [x.shape for x in inputs] input_shapes = [x.shape for x in inputs]
cache_key = generic_utils.object_list_uid(input_shapes) cache_key = tuple(tf_utils.convert_shapes(input_shapes, to_tuples=True))
self._output_shape_cache[cache_key] = nest.pack_sequence_as( self._output_shape_cache[cache_key] = nest.pack_sequence_as(
self._nested_outputs, output_shapes) self._nested_outputs, output_shapes)

View File

@ -1869,6 +1869,15 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
self.assertEqual(network.dynamic, False) self.assertEqual(network.dynamic, False)
self.assertEqual(network.stateful, False) self.assertEqual(network.stateful, False)
def test_compute_output_shape_cache(self):
# See https://github.com/tensorflow/tensorflow/issues/32029.
x = input_layer_lib.Input(shape=(None, 32))
dense = keras.layers.Dense(2)
y = dense(x)
network = network_lib.Network(x, y, name='dense_network')
for i in range(999, 1024):
self.assertEqual(network.compute_output_shape((1, i, 32)), (1, i, 2))
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -756,12 +756,6 @@ def to_list(x):
return [x] return [x]
def object_list_uid(object_list):
"""Creates a single string from object ids."""
object_list = nest.flatten(object_list)
return ', '.join(str(abs(id(x))) for x in object_list)
def to_snake_case(name): def to_snake_case(name):
intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name) intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name)
insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower() insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower()