Make the individual call of Model.evaluate
use no cached data, while Model.fit
still uses cached data for evaluation.
PiperOrigin-RevId: 337194749 Change-Id: I1a99eadac1831918a21ab994a1d7a5e84b0623c6
This commit is contained in:
parent
2882df7cf2
commit
5cb9129e4b
@ -225,6 +225,9 @@
|
||||
argument.
|
||||
* Added `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints
|
||||
with the same implementation as their `tf.losses` equivalent.
|
||||
* For Keras model, the individual call of `Model.evaluate` uses no cached
|
||||
data for evaluation, while `Model.fit` uses cached data when
|
||||
`validation_data` arg is provided for better performance.
|
||||
* `tf.function` / AutoGraph:
|
||||
* Added `experimental_follow_type_hints` argument for `tf.function`. When
|
||||
True, the function may use type annotations to optimize the tracing
|
||||
|
@ -58,6 +58,7 @@ from tensorflow.python.keras.saving.saved_model import json_utils
|
||||
from tensorflow.python.keras.saving.saved_model import model_serialization
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.keras.utils import layer_utils
|
||||
from tensorflow.python.keras.utils import tf_inspect
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.keras.utils import version_utils
|
||||
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
||||
@ -1099,6 +1100,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
if validation_data and self._should_eval(epoch, validation_freq):
|
||||
# Create data_handler for evaluation and cache it.
|
||||
if getattr(self, '_eval_data_handler', None) is None:
|
||||
self._fit_frame = tf_inspect.currentframe()
|
||||
self._eval_data_handler = data_adapter.DataHandler(
|
||||
x=val_x,
|
||||
y=val_y,
|
||||
@ -1134,6 +1136,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
# If eval data_handler exists, delete it after all epochs are done.
|
||||
if getattr(self, '_eval_data_handler', None) is not None:
|
||||
del self._eval_data_handler
|
||||
del self._fit_frame
|
||||
callbacks.on_train_end(logs=training_logs)
|
||||
return self.history
|
||||
|
||||
@ -1327,7 +1330,10 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
_disallow_inside_tf_function('evaluate')
|
||||
|
||||
with self.distribute_strategy.scope():
|
||||
if getattr(self, '_eval_data_handler', None) is not None:
|
||||
# Use cached evaluation data only when it's called in `Model.fit`
|
||||
if (getattr(self, '_fit_frame', None) is not None
|
||||
and tf_inspect.currentframe().f_back is self._fit_frame
|
||||
and getattr(self, '_eval_data_handler', None) is not None):
|
||||
data_handler = self._eval_data_handler
|
||||
else:
|
||||
# Creates a `tf.data.Dataset` and handles batch and epoch iteration.
|
||||
|
@ -3642,16 +3642,20 @@ class TestAutoUpdates(keras_parameterized.TestCase):
|
||||
|
||||
class TestFunctionTracing(keras_parameterized.TestCase):
|
||||
|
||||
def _seq_model_and_data(self):
|
||||
model = sequential.Sequential([layers_module.Dense(4, activation='relu')])
|
||||
model.compile(loss='mse', optimizer='rmsprop')
|
||||
x = np.random.random((10, 6))
|
||||
y = np.random.random((10, 4))
|
||||
return model, x, y
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(
|
||||
always_skip_v1=True, always_skip_eager=True)
|
||||
def test_no_tracing_between_epoch(self):
|
||||
if sys.version_info[0] < 3:
|
||||
self.skipTest('self.assertLogs() call is not available in Python 2.')
|
||||
|
||||
model = sequential.Sequential([layers_module.Dense(4, activation='relu')])
|
||||
model.compile(loss='mse', optimizer='rmsprop')
|
||||
x = np.random.random((10, 6))
|
||||
y = np.random.random((10, 4))
|
||||
model, x, y = self._seq_model_and_data()
|
||||
|
||||
logging.set_verbosity(1)
|
||||
with self.assertLogs(level=1) as logs:
|
||||
@ -3660,6 +3664,21 @@ class TestFunctionTracing(keras_parameterized.TestCase):
|
||||
new_func_graph = 'INFO:absl:Creating new FuncGraph for Python function'
|
||||
self.assertEqual(sum(new_func_graph in log for log in logs.output), 9)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(
|
||||
always_skip_v1=True, always_skip_eager=True)
|
||||
def test_evaluate_no_cached_data(self):
|
||||
if sys.version_info[0] < 3:
|
||||
self.skipTest('self.assertLogs() call is not available in Python 2.')
|
||||
|
||||
model, x, y = self._seq_model_and_data()
|
||||
|
||||
new_func_graph = 'INFO:absl:Creating new FuncGraph for Python function'
|
||||
logging.set_verbosity(1)
|
||||
with self.assertLogs(level=1) as eval_logs:
|
||||
for _ in range(6):
|
||||
model.evaluate(x, y, batch_size=5)
|
||||
self.assertEqual(sum(new_func_graph in log for log in eval_logs.output), 20)
|
||||
|
||||
|
||||
class TestBuildCustomModel(keras_parameterized.TestCase):
|
||||
|
||||
|
@ -90,6 +90,11 @@ else:
|
||||
return _convert_maybe_argspec_to_fullargspec(getargspec(target))
|
||||
|
||||
|
||||
def currentframe():
  """TFDecorator-aware replacement for inspect.currentframe."""
  # stack()[0] is this function's own frame; entry 1 is the caller's.
  caller_frame_info = _inspect.stack()[1]
  return caller_frame_info[0]
|
||||
|
||||
|
||||
def getargspec(obj):
|
||||
"""TFDecorator-aware replacement for `inspect.getargspec`.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user