Make the individual call of Model.evaluate use no cached data, while Model.fit still uses cached data for evaluation.
PiperOrigin-RevId: 337194749 Change-Id: I1a99eadac1831918a21ab994a1d7a5e84b0623c6
This commit is contained in:
parent
2882df7cf2
commit
5cb9129e4b
@ -225,6 +225,9 @@
|
|||||||
argument.
|
argument.
|
||||||
* Added `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints
|
* Added `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints
|
||||||
with the same implementation as their `tf.losses` equivalent.
|
with the same implementation as their `tf.losses` equivalent.
|
||||||
|
* For Keras model, the individual call of `Model.evaluate` uses no cached
|
||||||
|
data for evaluation, while `Model.fit` uses cached data when
|
||||||
|
`validation_data` arg is provided for better performance.
|
||||||
* `tf.function` / AutoGraph:
|
* `tf.function` / AutoGraph:
|
||||||
* Added `experimental_follow_type_hints` argument for `tf.function`. When
|
* Added `experimental_follow_type_hints` argument for `tf.function`. When
|
||||||
True, the function may use type annotations to optimize the tracing
|
True, the function may use type annotations to optimize the tracing
|
||||||
|
|||||||
@ -58,6 +58,7 @@ from tensorflow.python.keras.saving.saved_model import json_utils
|
|||||||
from tensorflow.python.keras.saving.saved_model import model_serialization
|
from tensorflow.python.keras.saving.saved_model import model_serialization
|
||||||
from tensorflow.python.keras.utils import generic_utils
|
from tensorflow.python.keras.utils import generic_utils
|
||||||
from tensorflow.python.keras.utils import layer_utils
|
from tensorflow.python.keras.utils import layer_utils
|
||||||
|
from tensorflow.python.keras.utils import tf_inspect
|
||||||
from tensorflow.python.keras.utils import tf_utils
|
from tensorflow.python.keras.utils import tf_utils
|
||||||
from tensorflow.python.keras.utils import version_utils
|
from tensorflow.python.keras.utils import version_utils
|
||||||
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
||||||
@ -1099,6 +1100,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
|||||||
if validation_data and self._should_eval(epoch, validation_freq):
|
if validation_data and self._should_eval(epoch, validation_freq):
|
||||||
# Create data_handler for evaluation and cache it.
|
# Create data_handler for evaluation and cache it.
|
||||||
if getattr(self, '_eval_data_handler', None) is None:
|
if getattr(self, '_eval_data_handler', None) is None:
|
||||||
|
self._fit_frame = tf_inspect.currentframe()
|
||||||
self._eval_data_handler = data_adapter.DataHandler(
|
self._eval_data_handler = data_adapter.DataHandler(
|
||||||
x=val_x,
|
x=val_x,
|
||||||
y=val_y,
|
y=val_y,
|
||||||
@ -1134,6 +1136,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
|||||||
# If eval data_hanlder exists, delete it after all epochs are done.
|
# If eval data_hanlder exists, delete it after all epochs are done.
|
||||||
if getattr(self, '_eval_data_handler', None) is not None:
|
if getattr(self, '_eval_data_handler', None) is not None:
|
||||||
del self._eval_data_handler
|
del self._eval_data_handler
|
||||||
|
del self._fit_frame
|
||||||
callbacks.on_train_end(logs=training_logs)
|
callbacks.on_train_end(logs=training_logs)
|
||||||
return self.history
|
return self.history
|
||||||
|
|
||||||
@ -1327,7 +1330,10 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
|||||||
_disallow_inside_tf_function('evaluate')
|
_disallow_inside_tf_function('evaluate')
|
||||||
|
|
||||||
with self.distribute_strategy.scope():
|
with self.distribute_strategy.scope():
|
||||||
if getattr(self, '_eval_data_handler', None) is not None:
|
# Use cached evaluation data only when it's called in `Model.fit`
|
||||||
|
if (getattr(self, '_fit_frame', None) is not None
|
||||||
|
and tf_inspect.currentframe().f_back is self._fit_frame
|
||||||
|
and getattr(self, '_eval_data_handler', None) is not None):
|
||||||
data_handler = self._eval_data_handler
|
data_handler = self._eval_data_handler
|
||||||
else:
|
else:
|
||||||
# Creates a `tf.data.Dataset` and handles batch and epoch iteration.
|
# Creates a `tf.data.Dataset` and handles batch and epoch iteration.
|
||||||
|
|||||||
@ -3642,16 +3642,20 @@ class TestAutoUpdates(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
class TestFunctionTracing(keras_parameterized.TestCase):
|
class TestFunctionTracing(keras_parameterized.TestCase):
|
||||||
|
|
||||||
|
def _seq_model_and_data(self):
|
||||||
|
model = sequential.Sequential([layers_module.Dense(4, activation='relu')])
|
||||||
|
model.compile(loss='mse', optimizer='rmsprop')
|
||||||
|
x = np.random.random((10, 6))
|
||||||
|
y = np.random.random((10, 4))
|
||||||
|
return model, x, y
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes(
|
@keras_parameterized.run_all_keras_modes(
|
||||||
always_skip_v1=True, always_skip_eager=True)
|
always_skip_v1=True, always_skip_eager=True)
|
||||||
def test_no_tracing_between_epoch(self):
|
def test_no_tracing_between_epoch(self):
|
||||||
if sys.version_info[0] < 3:
|
if sys.version_info[0] < 3:
|
||||||
self.skipTest('self.assertLogs() call is not available in Python 2.')
|
self.skipTest('self.assertLogs() call is not available in Python 2.')
|
||||||
|
|
||||||
model = sequential.Sequential([layers_module.Dense(4, activation='relu')])
|
model, x, y = self._seq_model_and_data()
|
||||||
model.compile(loss='mse', optimizer='rmsprop')
|
|
||||||
x = np.random.random((10, 6))
|
|
||||||
y = np.random.random((10, 4))
|
|
||||||
|
|
||||||
logging.set_verbosity(1)
|
logging.set_verbosity(1)
|
||||||
with self.assertLogs(level=1) as logs:
|
with self.assertLogs(level=1) as logs:
|
||||||
@ -3660,6 +3664,21 @@ class TestFunctionTracing(keras_parameterized.TestCase):
|
|||||||
new_func_graph = 'INFO:absl:Creating new FuncGraph for Python function'
|
new_func_graph = 'INFO:absl:Creating new FuncGraph for Python function'
|
||||||
self.assertEqual(sum(new_func_graph in log for log in logs.output), 9)
|
self.assertEqual(sum(new_func_graph in log for log in logs.output), 9)
|
||||||
|
|
||||||
|
@keras_parameterized.run_all_keras_modes(
|
||||||
|
always_skip_v1=True, always_skip_eager=True)
|
||||||
|
def test_evaluate_no_cached_data(self):
|
||||||
|
if sys.version_info[0] < 3:
|
||||||
|
self.skipTest('self.assertLogs() call is not available in Python 2.')
|
||||||
|
|
||||||
|
model, x, y = self._seq_model_and_data()
|
||||||
|
|
||||||
|
new_func_graph = 'INFO:absl:Creating new FuncGraph for Python function'
|
||||||
|
logging.set_verbosity(1)
|
||||||
|
with self.assertLogs(level=1) as eval_logs:
|
||||||
|
for _ in range(6):
|
||||||
|
model.evaluate(x, y, batch_size=5)
|
||||||
|
self.assertEqual(sum(new_func_graph in log for log in eval_logs.output), 20)
|
||||||
|
|
||||||
|
|
||||||
class TestBuildCustomModel(keras_parameterized.TestCase):
|
class TestBuildCustomModel(keras_parameterized.TestCase):
|
||||||
|
|
||||||
|
|||||||
@ -90,6 +90,11 @@ else:
|
|||||||
return _convert_maybe_argspec_to_fullargspec(getargspec(target))
|
return _convert_maybe_argspec_to_fullargspec(getargspec(target))
|
||||||
|
|
||||||
|
|
||||||
|
def currentframe():
|
||||||
|
"""TFDecorator-aware replacement for inspect.currentframe."""
|
||||||
|
return _inspect.stack()[1][0]
|
||||||
|
|
||||||
|
|
||||||
def getargspec(obj):
|
def getargspec(obj):
|
||||||
"""TFDecorator-aware replacement for `inspect.getargspec`.
|
"""TFDecorator-aware replacement for `inspect.getargspec`.
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user