From 3ba8bd697faf4b831f78c3fa547d7956f1b1a0aa Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Thu, 20 Feb 2020 14:50:37 -0800 Subject: [PATCH] Fix the cache key problem when compute_output_shape(). This is a very tricky one wrt the id() of int in python. Under the hood, id() returns the memory address of the int, and python has a cache location for ints, which results in different ints getting the same id() value (and thus the same cache key). Changed to use a tuple of the shapes themselves as the dict key, since a tuple is immutable and hashable. The same tuple value will always return the same hash value. Also remove the generic util for that, since network.py was the only user of that function. Fix #32029 PiperOrigin-RevId: 296302946 Change-Id: I865c9380a06ed6ee80fea7f942c21c4d102473c2 --- tensorflow/python/keras/engine/network.py | 6 ++++-- tensorflow/python/keras/engine/network_test.py | 9 +++++++++ tensorflow/python/keras/utils/generic_utils.py | 6 ------ 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 166553a324b..79f15d9f3ae 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -720,7 +720,9 @@ class Network(base_layer.Layer): ': model has ' + str(len(self._input_layers)) + ' tensor inputs.') - cache_key = generic_utils.object_list_uid(input_shape) + # Use the tuple of TensorShape as the cache key, since tuple is hashable + # and can be used as hash key. + cache_key = tuple(tf_utils.convert_shapes(input_shape, to_tuples=True)) if cache_key in self._output_shape_cache: # Cache hit. Return shapes as TensorShapes. 
return self._output_shape_cache[cache_key] @@ -905,7 +907,7 @@ class Network(base_layer.Layer): if output_shapes is not None: input_shapes = [x.shape for x in inputs] - cache_key = generic_utils.object_list_uid(input_shapes) + cache_key = tuple(tf_utils.convert_shapes(input_shapes, to_tuples=True)) self._output_shape_cache[cache_key] = nest.pack_sequence_as( self._nested_outputs, output_shapes) diff --git a/tensorflow/python/keras/engine/network_test.py b/tensorflow/python/keras/engine/network_test.py index b3e19f2a6ea..17f08889936 100644 --- a/tensorflow/python/keras/engine/network_test.py +++ b/tensorflow/python/keras/engine/network_test.py @@ -1869,6 +1869,15 @@ class CacheCorrectnessTest(keras_parameterized.TestCase): self.assertEqual(network.dynamic, False) self.assertEqual(network.stateful, False) + def test_compute_output_shape_cache(self): + # See https://github.com/tensorflow/tensorflow/issues/32029. + x = input_layer_lib.Input(shape=(None, 32)) + dense = keras.layers.Dense(2) + y = dense(x) + network = network_lib.Network(x, y, name='dense_network') + + for i in range(999, 1024): + self.assertEqual(network.compute_output_shape((1, i, 32)), (1, i, 2)) if __name__ == '__main__': diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py index edbfed6d776..9ee644bf8cd 100644 --- a/tensorflow/python/keras/utils/generic_utils.py +++ b/tensorflow/python/keras/utils/generic_utils.py @@ -756,12 +756,6 @@ def to_list(x): return [x] -def object_list_uid(object_list): - """Creates a single string from object ids.""" - object_list = nest.flatten(object_list) - return ', '.join(str(abs(id(x))) for x in object_list) - - def to_snake_case(name): intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name) insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower()