Make the individual call of Model.evaluate use no cached data, while Model.fit still uses cached data for evaluation.

PiperOrigin-RevId: 337194749
Change-Id: I1a99eadac1831918a21ab994a1d7a5e84b0623c6
This commit is contained in:
Yanhui Liang 2020-10-14 16:15:12 -07:00 committed by TensorFlower Gardener
parent 2882df7cf2
commit 5cb9129e4b
4 changed files with 38 additions and 5 deletions
RELEASE.md
tensorflow/python/keras

View File

@ -225,6 +225,9 @@
argument.
* Added `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints
with the same implementation as their `tf.losses` equivalent.
* For Keras model, the individual call of `Model.evaluate` uses no cached
data for evaluation, while `Model.fit` uses cached data when
`validation_data` arg is provided for better performance.
* `tf.function` / AutoGraph:
* Added `experimental_follow_type_hints` argument for `tf.function`. When
True, the function may use type annotations to optimize the tracing

View File

@ -58,6 +58,7 @@ from tensorflow.python.keras.saving.saved_model import json_utils
from tensorflow.python.keras.saving.saved_model import model_serialization
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
@ -1099,6 +1100,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
if validation_data and self._should_eval(epoch, validation_freq):
# Create data_handler for evaluation and cache it.
if getattr(self, '_eval_data_handler', None) is None:
self._fit_frame = tf_inspect.currentframe()
self._eval_data_handler = data_adapter.DataHandler(
x=val_x,
y=val_y,
@ -1134,6 +1136,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
# If eval data_hanlder exists, delete it after all epochs are done.
if getattr(self, '_eval_data_handler', None) is not None:
del self._eval_data_handler
del self._fit_frame
callbacks.on_train_end(logs=training_logs)
return self.history
@ -1327,7 +1330,10 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
_disallow_inside_tf_function('evaluate')
with self.distribute_strategy.scope():
if getattr(self, '_eval_data_handler', None) is not None:
# Use cached evaluation data only when it's called in `Model.fit`
if (getattr(self, '_fit_frame', None) is not None
and tf_inspect.currentframe().f_back is self._fit_frame
and getattr(self, '_eval_data_handler', None) is not None):
data_handler = self._eval_data_handler
else:
# Creates a `tf.data.Dataset` and handles batch and epoch iteration.

View File

@ -3642,16 +3642,20 @@ class TestAutoUpdates(keras_parameterized.TestCase):
class TestFunctionTracing(keras_parameterized.TestCase):
def _seq_model_and_data(self):
model = sequential.Sequential([layers_module.Dense(4, activation='relu')])
model.compile(loss='mse', optimizer='rmsprop')
x = np.random.random((10, 6))
y = np.random.random((10, 4))
return model, x, y
@keras_parameterized.run_all_keras_modes(
always_skip_v1=True, always_skip_eager=True)
def test_no_tracing_between_epoch(self):
if sys.version_info[0] < 3:
self.skipTest('self.assertLogs() call is not available in Python 2.')
model = sequential.Sequential([layers_module.Dense(4, activation='relu')])
model.compile(loss='mse', optimizer='rmsprop')
x = np.random.random((10, 6))
y = np.random.random((10, 4))
model, x, y = self._seq_model_and_data()
logging.set_verbosity(1)
with self.assertLogs(level=1) as logs:
@ -3660,6 +3664,21 @@ class TestFunctionTracing(keras_parameterized.TestCase):
new_func_graph = 'INFO:absl:Creating new FuncGraph for Python function'
self.assertEqual(sum(new_func_graph in log for log in logs.output), 9)
@keras_parameterized.run_all_keras_modes(
always_skip_v1=True, always_skip_eager=True)
def test_evaluate_no_cached_data(self):
if sys.version_info[0] < 3:
self.skipTest('self.assertLogs() call is not available in Python 2.')
model, x, y = self._seq_model_and_data()
new_func_graph = 'INFO:absl:Creating new FuncGraph for Python function'
logging.set_verbosity(1)
with self.assertLogs(level=1) as eval_logs:
for _ in range(6):
model.evaluate(x, y, batch_size=5)
self.assertEqual(sum(new_func_graph in log for log in eval_logs.output), 20)
class TestBuildCustomModel(keras_parameterized.TestCase):

View File

@ -90,6 +90,11 @@ else:
return _convert_maybe_argspec_to_fullargspec(getargspec(target))
def currentframe():
"""TFDecorator-aware replacement for inspect.currentframe."""
return _inspect.stack()[1][0]
def getargspec(obj):
"""TFDecorator-aware replacement for `inspect.getargspec`.