diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index f83ed74c2f8..eccd3ba6ea9 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -240,14 +240,33 @@ def reset_uids(): @keras_export('keras.backend.clear_session') def clear_session(): - """Destroys the current TF graph and session, and creates a new one. + """Resets all state generated by Keras. - Calling clear_session() releases the global graph state that Keras is - holding on to; resets the counters used for naming layers and - variables in Keras; and resets the learning phase. This helps avoid clutter - from old models and layers, especially when memory is limited, and a - common use-case for clear_session is releasing memory when building models - and layers in a loop. + Keras manages a global state, which it uses to implement the Functional + model-building API and to uniquify autogenerated layer names. + + If you are creating many models in a loop, this global state will consume + an increasing amount of memory over time, and you may want to clear it. + Calling `clear_session()` releases the global state: this helps avoid clutter + from old models and layers, especially when memory is limited. + + Example 1: calling `clear_session()` when creating models in a loop + + ```python + for _ in range(100): + # Without `clear_session()`, each iteration of this loop will + # slightly increase the size of the global state managed by Keras + model = tf.keras.Sequential([tf.keras.layers.Dense(10) for _ in range(10)]) + + for _ in range(100): + # With `clear_session()` called at the beginning, + # Keras starts with a blank state at each iteration + # and memory consumption is constant over time. + tf.keras.backend.clear_session() + model = tf.keras.Sequential([tf.keras.layers.Dense(10) for _ in range(10)]) + ``` + + Example 2: resetting the layer name generation counter >>> import tensorflow as tf >>> layers = [tf.keras.layers.Dense(10) for _ in range(10)] @@ -261,8 +280,6 @@ def clear_session(): >>> new_layer = tf.keras.layers.Dense(10) >>> print(new_layer.name) dense - >>> print(tf.keras.backend.learning_phase()) - 0 """ global _SESSION global _GRAPH_LEARNING_PHASES # pylint: disable=global-variable-not-assigned