From 10666c59dd4858645d1b03ce01f4450da80710ec Mon Sep 17 00:00:00 2001 From: Thomas O'Malley Date: Wed, 19 Feb 2020 17:02:39 -0800 Subject: [PATCH] Keras ideal fit and compile. Kept all new abstractions private for now. In a few weeks, if we're comfortable that these abstractions are working and stable, we should expose many of them publicly. Capabilites added by this CL: (1) Easy to create a custom training step via overriding Model._train_step (2) Easy to create custom tf.function / DistStrat logic via overriding Model._make_train_function (3) Advanced users can override Model.compile and Model.fit (4) Full support for dicts, nested structures, etc with Subclassed Models. (5) "Power user" path (tf.data inputs) only modifies data in Model._train_step, where this behavior is easy to override and disable. This applies even to Keras's assumption that data is passed in (x, y, sample_weight) format. Behavior changes: (1) "loss" passed to Callbacks is now stateful (like all other metrics in Callbacks). This greatly simplifies the training step logic and callback logic. (2) ProgbarLogger always uses steps. If steps is not available, the ProgbarLogger handles inferring the steps after the first epoch. (3) validation_batch_size added in `fit`, rather than inferring from generator. (4) Model.inputs, Model.outputs, Model.input_names, and Model.output_names are no longer populated for subclassed Models. Instead, "pseudo" output names are created for subclassed Models, which are only used for metrics names and SavedModel's signature. (5) Cast NumPy floats to backend.floatx(), otherwise leave unchanged (this is likely not a change, we did something like this in our old version but the logic was scattered in many places) PiperOrigin-RevId: 296090972 Change-Id: Ia5ac833fd39085bddb016833bd338083d0dc5fc2 --- .../debug/lib/distributed_callbacks_test.py | 4 +- .../python/distribute/keras_save_load_test.py | 8 +- .../model_collection/simple_models.py | 6 +- .../distribute/saved_model_mixed_api_test.py | 8 +- .../distribute/saved_model_save_load_test.py | 16 +- .../distribute/saved_model_test_base.py | 18 +- tensorflow/python/eager/forwardprop.py | 4 +- tensorflow/python/eager/forwardprop_test.py | 2 +- tensorflow/python/eager/function.py | 3 +- tensorflow/python/keras/backend.py | 4 + tensorflow/python/keras/callbacks.py | 246 +- tensorflow/python/keras/callbacks_test.py | 107 +- .../distribute/distribute_strategy_test.py | 28 +- .../keras/distribute/keras_utils_test.py | 70 +- tensorflow/python/keras/engine/BUILD | 20 - tensorflow/python/keras/engine/base_layer.py | 51 +- .../python/keras/engine/base_layer_test.py | 44 +- .../python/keras/engine/compile_utils.py | 269 +- .../python/keras/engine/compile_utils_test.py | 65 +- .../python/keras/engine/data_adapter.py | 448 ++- .../python/keras/engine/data_adapter_test.py | 59 +- tensorflow/python/keras/engine/network.py | 68 +- tensorflow/python/keras/engine/sequential.py | 18 +- .../python/keras/engine/sequential_test.py | 64 +- tensorflow/python/keras/engine/training.py | 2712 ++++------------- .../python/keras/engine/training_arrays.py | 18 +- .../keras/engine/training_dataset_test.py | 43 +- .../keras/engine/training_eager_test.py | 9 +- .../python/keras/engine/training_generator.py | 18 +- .../keras/engine/training_generator_test.py | 38 +- .../python/keras/engine/training_test.py | 999 +----- tensorflow/python/keras/engine/training_v1.py | 69 +- tensorflow/python/keras/engine/training_v2.py | 778 ----- .../python/keras/engine/training_v2_utils.py | 556 ---- .../keras/engine/training_v2_utils_test.py | 160 - tensorflow/python/keras/layers/core.py | 19 +- tensorflow/python/keras/layers/merge.py | 20 +- .../python/keras/layers/normalization_test.py | 4 +- .../preprocessing/normalization_test.py | 32 +- .../python/keras/layers/wrappers_test.py | 47 +- tensorflow/python/keras/losses.py | 14 +- tensorflow/python/keras/metrics.py | 13 +- .../python/keras/metrics_correctness_test.py | 99 +- tensorflow/python/keras/models.py | 50 +- tensorflow/python/keras/models_test.py | 8 +- tensorflow/python/keras/premade/linear.py | 2 +- tensorflow/python/keras/premade/wide_deep.py | 56 +- .../python/keras/premade/wide_deep_test.py | 2 - .../python/keras/saving/hdf5_format_test.py | 26 +- .../keras/saving/losses_serialization_test.py | 16 +- .../saving/metrics_serialization_test.py | 11 - .../python/keras/saving/saved_model/load.py | 7 +- .../keras/saving/saved_model/revive_test.py | 26 +- .../keras/saving/saved_model/save_impl.py | 29 +- .../saving/saved_model/saved_model_test.py | 34 +- .../saving/saved_model_experimental_test.py | 21 +- .../python/keras/saving/saving_utils.py | 216 +- .../python/keras/saving/saving_utils_test.py | 58 +- tensorflow/python/keras/testing_utils.py | 3 + .../tests/model_subclassing_compiled_test.py | 2 - .../keras/tests/model_subclassing_test.py | 7 +- ...emporal_sample_weights_correctness_test.py | 45 +- .../utils/composite_tensor_support_test.py | 113 +- .../python/keras/utils/generic_utils.py | 34 +- tensorflow/python/keras/utils/layer_utils.py | 1 - tensorflow/python/keras/utils/tf_utils.py | 25 + .../python/keras/utils/tf_utils_test.py | 2 + tensorflow/python/layers/base.py | 2 +- .../golden/v1/tensorflow.keras.-model.pbtxt | 8 +- .../v1/tensorflow.keras.-sequential.pbtxt | 8 +- ...low.keras.experimental.-linear-model.pbtxt | 8 +- ....keras.experimental.-wide-deep-model.pbtxt | 8 +- .../v1/tensorflow.keras.models.-model.pbtxt | 8 +- .../tensorflow.keras.models.-sequential.pbtxt | 8 +- .../v1/tensorflow.keras.utils.-progbar.pbtxt | 2 +- .../golden/v2/tensorflow.keras.-model.pbtxt | 8 +- .../v2/tensorflow.keras.-sequential.pbtxt | 8 +- ...low.keras.experimental.-linear-model.pbtxt | 8 +- ....keras.experimental.-wide-deep-model.pbtxt | 8 +- .../v2/tensorflow.keras.models.-model.pbtxt | 8 +- .../tensorflow.keras.models.-sequential.pbtxt | 8 +- .../v2/tensorflow.keras.utils.-progbar.pbtxt | 2 +- 82 files changed, 2215 insertions(+), 5959 deletions(-) delete mode 100644 tensorflow/python/keras/engine/training_v2.py delete mode 100644 tensorflow/python/keras/engine/training_v2_utils.py delete mode 100644 tensorflow/python/keras/engine/training_v2_utils_test.py diff --git a/tensorflow/python/debug/lib/distributed_callbacks_test.py b/tensorflow/python/debug/lib/distributed_callbacks_test.py index 4b1eb3e498a..606f14b3230 100644 --- a/tensorflow/python/debug/lib/distributed_callbacks_test.py +++ b/tensorflow/python/debug/lib/distributed_callbacks_test.py @@ -195,6 +195,7 @@ class DistributedDumpingCallbackTest( self.assertAllClose(device_1_matmul_values[0], [[10.0]]) self.assertAllClose(device_1_bias_add_values[0], [[11.0]]) + # TODO(b/148461691): Fix for new Keras internals. @combinations.generate( combinations.combine( distribution=[ @@ -206,7 +207,8 @@ class DistributedDumpingCallbackTest( mode=["eager"], tensor_debug_mode=["NO_TENSOR", "FULL_TENSOR"], )) - def testKerasModelFitOnOneOrTwoDevices(self, distribution, tensor_debug_mode): + def DISABLED_testKerasModelFitOnOneOrTwoDevices(self, distribution, + tensor_debug_mode): writer = dumping_callback.enable_dump_debug_info( self.dump_root, tensor_debug_mode=tensor_debug_mode) diff --git a/tensorflow/python/distribute/keras_save_load_test.py b/tensorflow/python/distribute/keras_save_load_test.py index 494a348d050..6475406eb4b 100644 --- a/tensorflow/python/distribute/keras_save_load_test.py +++ b/tensorflow/python/distribute/keras_save_load_test.py @@ -33,8 +33,12 @@ class KerasSaveLoadTest(test_base.TestSavedModelBase): def _save_model(self, model, saved_dir): model.save(saved_dir, save_format='tf') - def _load_and_run_model(self, distribution, saved_dir, predict_dataset, - output_name, experimental_run_tf_function): + def _load_and_run_model(self, + distribution, + saved_dir, + predict_dataset, + experimental_run_tf_function, + output_name='output_1'): restored_keras_model = save.load_model(saved_dir) restored_keras_model._experimental_run_tf_function = ( experimental_run_tf_function) diff --git a/tensorflow/python/distribute/model_collection/simple_models.py b/tensorflow/python/distribute/model_collection/simple_models.py index 63a2bfcb520..ededb0a7f59 100644 --- a/tensorflow/python/distribute/model_collection/simple_models.py +++ b/tensorflow/python/distribute/model_collection/simple_models.py @@ -45,7 +45,7 @@ class SimpleFunctionalModel(model_collection_base.ModelAndInput): """A simple functional model and its inputs.""" def get_model(self, **kwargs): - output_name = 'output_layer' + output_name = 'output_1' x = keras.layers.Input(shape=(3,), dtype=dtypes.float32) y = keras.layers.Dense(5, dtype=dtypes.float32, name=output_name)(x) @@ -74,7 +74,7 @@ class SimpleSequentialModel(model_collection_base.ModelAndInput): """A simple sequential model and its inputs.""" def get_model(self, **kwargs): - output_name = 'output_layer' + output_name = 'output_1' model = keras.Sequential() y = keras.layers.Dense( @@ -106,7 +106,7 @@ class _SimpleModel(keras.Model): self._dense_layer = keras.layers.Dense(5, dtype=dtypes.float32) def call(self, inputs): - return {'output_layer': self._dense_layer(inputs)} + return self._dense_layer(inputs) class SimpleSubclassModel(model_collection_base.ModelAndInput): diff --git a/tensorflow/python/distribute/saved_model_mixed_api_test.py b/tensorflow/python/distribute/saved_model_mixed_api_test.py index 2b0e5e9e899..240f5f45f9f 100644 --- a/tensorflow/python/distribute/saved_model_mixed_api_test.py +++ b/tensorflow/python/distribute/saved_model_mixed_api_test.py @@ -41,8 +41,12 @@ class SavedModelSaveAndLoadTest(test_base.TestSavedModelBase): def _save_model(self, model, saved_dir): keras_saved_model.export_saved_model(model, saved_dir, serving_only=True) - def _load_and_run_model(self, distribution, saved_dir, predict_dataset, - output_name, experimental_run_tf_function): + def _load_and_run_model(self, + distribution, + saved_dir, + predict_dataset, + experimental_run_tf_function, + output_name='output_1'): return test_base.load_and_run_with_saved_model_api(distribution, saved_dir, predict_dataset, output_name) diff --git a/tensorflow/python/distribute/saved_model_save_load_test.py b/tensorflow/python/distribute/saved_model_save_load_test.py index 5380d6f9d1f..10dae8065bb 100644 --- a/tensorflow/python/distribute/saved_model_save_load_test.py +++ b/tensorflow/python/distribute/saved_model_save_load_test.py @@ -35,8 +35,12 @@ class SavedModelKerasModelTest(test_base.TestSavedModelBase): def _save_model(self, model, saved_dir): saved_model.save(model, saved_dir) - def _load_and_run_model(self, distribution, saved_dir, predict_dataset, - output_name, experimental_run_tf_function): + def _load_and_run_model(self, + distribution, + saved_dir, + predict_dataset, + experimental_run_tf_function, + output_name='output_1'): return test_base.load_and_run_with_saved_model_api(distribution, saved_dir, predict_dataset, output_name) @@ -100,8 +104,12 @@ class SavedModelTFModuleTest(test_base.TestSavedModelBase): call = model.__call__.get_concrete_function(tensor_spec.TensorSpec(None)) saved_model.save(model, saved_dir, signatures=call) - def _load_and_run_model(self, distribution, saved_dir, predict_dataset, - output_name, experimental_run_tf_function): + def _load_and_run_model(self, + distribution, + saved_dir, + predict_dataset, + experimental_run_tf_function, + output_name='output_1'): del output_name, experimental_run_tf_function model = saved_model.load(saved_dir) return self._predict_with_model(distribution, model, predict_dataset) diff --git a/tensorflow/python/distribute/saved_model_test_base.py b/tensorflow/python/distribute/saved_model_test_base.py index 832bb4f1dbd..5d3511c6cde 100644 --- a/tensorflow/python/distribute/saved_model_test_base.py +++ b/tensorflow/python/distribute/saved_model_test_base.py @@ -150,8 +150,12 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase): """ raise NotImplementedError('must be implemented in descendants') - def _load_and_run_model(self, distribution, saved_dir, predict_dataset, - output_name, experimental_run_tf_function): + def _load_and_run_model(self, + distribution, + saved_dir, + predict_dataset, + experimental_run_tf_function, + output_name='output_1'): """Load the model and run 1 step of predict with it. This method must be implemented by the subclasses. @@ -162,10 +166,10 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase): saved_dir: the string representing the path where the model is saved. predict_dataset: the data used to do the predict on the model for cross_replica context. - output_name: the string representing the name of the output layer of the - model. experimental_run_tf_function: Whether to use the single execution path for models. + output_name: the string representing the name of the output layer of the + model. """ raise NotImplementedError('must be implemented in descendants') @@ -211,10 +215,6 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase): distribution=distribution, saved_dir=saved_dir, predict_dataset=predict_dataset, - # Note that subclassed model's output names aren't defined until after - # the model is built (in these tests, this occurs when the model is - # trained). - output_name=getattr(model, 'output_names', [None])[0], experimental_run_tf_function=experimental_run_tf_function) tolerance = get_tolerance(None, distribution) @@ -248,7 +248,6 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase): distribution=None, saved_dir=saved_dir, predict_dataset=predict_dataset, - output_name=getattr(model, 'output_names', [None])[0], experimental_run_tf_function=experimental_run_tf_function) tolerance = get_tolerance(distribution, None) @@ -285,7 +284,6 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase): distribution=distribution_for_restoring, saved_dir=saved_dir, predict_dataset=predict_dataset, - output_name=getattr(model, 'output_names', [None])[0], experimental_run_tf_function=experimental_run_tf_function) tolerance = get_tolerance(distribution_for_saving, diff --git a/tensorflow/python/eager/forwardprop.py b/tensorflow/python/eager/forwardprop.py index 973e130ef0f..0bb1e89e4a3 100644 --- a/tensorflow/python/eager/forwardprop.py +++ b/tensorflow/python/eager/forwardprop.py @@ -186,7 +186,7 @@ class ForwardAccumulator(object): >>> x = tf.constant([[2.0, 3.0], [1.0, 4.0]]) >>> dense = tf.keras.layers.Dense(1) - >>> dense.build([2]) + >>> dense.build([None, 2]) >>> with tf.autodiff.ForwardAccumulator( ... primals=dense.kernel, ... tangents=tf.constant([[1.], [0.]])) as acc: @@ -210,7 +210,7 @@ class ForwardAccumulator(object): >>> x = tf.constant([[2.0, 3.0], [1.0, 4.0]]) >>> dense = tf.keras.layers.Dense(1) - >>> dense.build([2]) + >>> dense.build([None, 2]) >>> loss_fn = lambda: tf.reduce_sum((dense(x) - tf.constant([1., -1.])) ** 2.) >>> kernel_fprop = [] >>> with tf.autodiff.ForwardAccumulator( diff --git a/tensorflow/python/eager/forwardprop_test.py b/tensorflow/python/eager/forwardprop_test.py index 79c0714c720..fed04aec270 100644 --- a/tensorflow/python/eager/forwardprop_test.py +++ b/tensorflow/python/eager/forwardprop_test.py @@ -1067,7 +1067,7 @@ class HessianTests(test.TestCase, parameterized.TestCase): ("MapFn", False)]) def testHessianOfVariables(self, use_pfor): model = core.Dense(1) - model.build([2]) + model.build([None, 2]) def _loss(*unused_args): input_value = constant_op.constant([[-0.5, 1.], [0.5, -1.]]) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 76e036da74e..895a5de7765 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -2271,7 +2271,8 @@ def _convert_inputs_to_signature(inputs, input_signature, flat_input_signature): flatten_inputs = nest.flatten_up_to( input_signature, inputs[:len(input_signature)], - expand_composites=True) + expand_composites=True, + check_types=False) # lists are convert to tuples for `tf.data`. except ValueError: raise ValueError("Structure of Python function inputs does not match " "input_signature:\n%s" % diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 81323613231..50856e1f173 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -4347,6 +4347,10 @@ def in_train_phase(x, alt, training=None): Either `x` or `alt` based on the `training` flag. the `training` flag defaults to `K.learning_phase()`. """ + from tensorflow.python.keras.engine import base_layer_utils # pylint: disable=g-import-not-at-top + if training is None: + training = base_layer_utils.call_context().training + if training is None: training = learning_phase() diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 6fd3e0e902d..5fae5eb9218 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -49,6 +49,7 @@ from tensorflow.python.ops import summary_ops_v2 from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import checkpoint_management +from tensorflow.python.util import nest from tensorflow.python.util.compat import collections_abc from tensorflow.python.util.tf_export import keras_export from tensorflow.tools.docs import doc_controls @@ -187,26 +188,67 @@ def make_logs(model, logs, outputs, mode, prefix=''): class CallbackList(object): - """Container abstracting a list of callbacks. + """Container abstracting a list of callbacks.""" - Arguments: + def __init__(self, + callbacks=None, + add_history=False, + add_progbar=False, + model=None, + **params): + """Creates a container for `Callbacks`. + + Arguments: callbacks: List of `Callback` instances. - queue_length: Queue length for keeping - running statistics over callback execution time. - """ + add_history: Whether a `History` callback should be added, if one does not + already exist in `callback`s. + add_progbar: Whether a `ProgbarLogger` callback should be added, if one + does not already exist in `callback`s. + model: The `Model` these `Callback`s are used with.` + **params: If provided, parameters will be passed to each `Callback` via + `Callback.set_params`. + """ + self.callbacks = nest.flatten(callbacks) if callbacks else [] + self._add_default_callbacks(add_history, add_progbar) - def __init__(self, callbacks=None, queue_length=10): - callbacks = callbacks or [] - self.callbacks = [c for c in callbacks] - self.queue_length = queue_length - self.params = {} - self.model = None + if model: + self.set_model(model) + if params: + self.set_params(params) + + self._queue_length = 10 self._reset_batch_timing() + def _add_default_callbacks(self, add_history, add_progbar): + """Adds `Callback`s that are always present.""" + self._progbar = None + self._history = None + + for cb in self.callbacks: + if isinstance(cb, ProgbarLogger): + self._progbar = cb + elif isinstance(cb, History): + self._history = cb + + if self._progbar is None and add_progbar: + self._progbar = ProgbarLogger(count_mode='steps') + self.callbacks.append(self._progbar) + + if self._history is None and add_history: + self._history = History() + self.callbacks.append(self._history) + def _reset_batch_timing(self): self._delta_t_batch = 0. self._delta_ts = collections.defaultdict( - lambda: collections.deque([], maxlen=self.queue_length)) + lambda: collections.deque([], maxlen=self._queue_length)) + + def _process_logs(self, logs): + if logs: + return { + k: v.numpy() if hasattr(v, 'numpy') else v for k, v in logs.items() + } + return {} def append(self, callback): self.callbacks.append(callback) @@ -218,6 +260,8 @@ class CallbackList(object): def set_model(self, model): self.model = model + if self._history: + model.history = self._history for callback in self.callbacks: callback.set_model(model) @@ -266,9 +310,11 @@ class CallbackList(object): self.on_predict_end() def on_batch_begin(self, batch, logs=None): + logs = self._process_logs(logs) self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs) def on_batch_end(self, batch, logs=None): + logs = self._process_logs(logs) self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs) def on_epoch_begin(self, epoch, logs=None): @@ -281,7 +327,7 @@ class CallbackList(object): logs: dict. Currently no data is passed to this argument for this method but that may change in the future. """ - logs = logs or {} + logs = self._process_logs(logs) for callback in self.callbacks: callback.on_epoch_begin(epoch, logs) self._reset_batch_timing() @@ -297,7 +343,7 @@ class CallbackList(object): validation epoch if validation is performed. Validation result keys are prefixed with `val_`. """ - logs = logs or {} + logs = self._process_logs(logs) for callback in self.callbacks: callback.on_epoch_end(epoch, logs) @@ -309,6 +355,7 @@ class CallbackList(object): logs: dict. Has keys `batch` and `size` representing the current batch number and the size of the batch. """ + logs = self._process_logs(logs) self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs) def on_train_batch_end(self, batch, logs=None): @@ -318,6 +365,7 @@ class CallbackList(object): batch: integer, index of batch within the current epoch. logs: dict. Metric results for this batch. """ + logs = self._process_logs(logs) self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs) def on_test_batch_begin(self, batch, logs=None): @@ -328,6 +376,7 @@ class CallbackList(object): logs: dict. Has keys `batch` and `size` representing the current batch number and the size of the batch. """ + logs = self._process_logs(logs) self._call_batch_hook(ModeKeys.TEST, 'begin', batch, logs=logs) def on_test_batch_end(self, batch, logs=None): @@ -347,6 +396,7 @@ class CallbackList(object): logs: dict. Has keys `batch` and `size` representing the current batch number and the size of the batch. """ + logs = self._process_logs(logs) self._call_batch_hook(ModeKeys.PREDICT, 'begin', batch, logs=logs) def on_predict_batch_end(self, batch, logs=None): @@ -356,6 +406,7 @@ class CallbackList(object): batch: integer, index of batch within the current epoch. logs: dict. Metric results for this batch. """ + logs = self._process_logs(logs) self._call_batch_hook(ModeKeys.PREDICT, 'end', batch, logs=logs) def on_train_begin(self, logs=None): @@ -365,6 +416,7 @@ class CallbackList(object): logs: dict. Currently no data is passed to this argument for this method but that may change in the future. """ + logs = self._process_logs(logs) for callback in self.callbacks: callback.on_train_begin(logs) @@ -375,6 +427,7 @@ class CallbackList(object): logs: dict. Currently no data is passed to this argument for this method but that may change in the future. """ + logs = self._process_logs(logs) for callback in self.callbacks: callback.on_train_end(logs) @@ -385,6 +438,7 @@ class CallbackList(object): logs: dict. Currently no data is passed to this argument for this method but that may change in the future. """ + logs = self._process_logs(logs) for callback in self.callbacks: callback.on_test_begin(logs) @@ -395,6 +449,7 @@ class CallbackList(object): logs: dict. Currently no data is passed to this argument for this method but that may change in the future. """ + logs = self._process_logs(logs) for callback in self.callbacks: callback.on_test_end(logs) @@ -405,6 +460,7 @@ class CallbackList(object): logs: dict. Currently no data is passed to this argument for this method but that may change in the future. """ + logs = self._process_logs(logs) for callback in self.callbacks: callback.on_predict_begin(logs) @@ -415,6 +471,7 @@ class CallbackList(object): logs: dict. Currently no data is passed to this argument for this method but that may change in the future. """ + logs = self._process_logs(logs) for callback in self.callbacks: callback.on_predict_end(logs) @@ -721,6 +778,7 @@ class ProgbarLogger(Callback): should *not* be averaged over an epoch. Metrics in this list will be logged as-is. All others will be averaged over time (e.g. loss, etc). + If not provided, defaults to the `Model`'s metrics. Raises: ValueError: In case of invalid `count_mode`. @@ -734,59 +792,96 @@ class ProgbarLogger(Callback): self.use_steps = True else: raise ValueError('Unknown `count_mode`: ' + str(count_mode)) - self.stateful_metrics = set(stateful_metrics or []) - self.log_values = None + # Defaults to all Model's metrics except for loss. + self.stateful_metrics = set(stateful_metrics) if stateful_metrics else None + + self.seen = 0 + self.progbar = None + self.target = None + self.verbose = 1 + self.epochs = 1 + + self._called_in_fit = False + + def set_params(self, params): + self.verbose = params['verbose'] + self.epochs = params['epochs'] + if self.use_steps and 'steps' in params: + self.target = params['steps'] + elif not self.use_steps and 'samples' in params: + self.target = params['samples'] + else: + self.target = None # Will be inferred at the end of the first epoch. def on_train_begin(self, logs=None): - self.verbose = self.params['verbose'] - self.epochs = self.params['epochs'] + # When this logger is called inside `fit`, validation is silent. + self._called_in_fit = True + + def on_test_begin(self, logs=None): + if not self._called_in_fit: + self._reset_progbar() + + def on_predict_begin(self, logs=None): + self._reset_progbar() def on_epoch_begin(self, epoch, logs=None): - self.seen = 0 - if self.use_steps: - self.target = self.params['steps'] - else: - self.target = self.params['samples'] + self._reset_progbar() + if self.verbose and self.epochs > 1: + print('Epoch %d/%d' % (epoch + 1, self.epochs)) - if self.verbose: - if self.epochs > 1: - print('Epoch %d/%d' % (epoch + 1, self.epochs)) - self.progbar = Progbar( - target=self.target, - verbose=self.verbose, - stateful_metrics=self.stateful_metrics, - unit_name='step' if self.use_steps else 'sample') + def on_train_batch_end(self, batch, logs=None): + self._batch_update_progbar(logs) - def on_batch_begin(self, batch, logs=None): - self.log_values = [] + def on_test_batch_end(self, batch, logs=None): + if not self._called_in_fit: + self._batch_update_progbar(logs) - def on_batch_end(self, batch, logs=None): - logs = logs or {} - batch_size = logs.get('size', 0) - # In case of distribution strategy we can potentially run multiple steps - # at the same time, we should account for that in the `seen` calculation. - num_steps = logs.get('num_steps', 1) - if self.use_steps: - self.seen += num_steps - else: - self.seen += batch_size * num_steps - - for k in self.params['metrics']: - if k in logs: - self.log_values.append((k, logs[k])) - - # Skip progbar update for the last batch; - # will be handled by on_epoch_end. - if self.verbose and (self.target is None or self.seen < self.target): - self.progbar.update(self.seen, self.log_values) + def on_predict_batch_end(self, batch, logs=None): + self._batch_update_progbar(None) # Don't pass prediction results. def on_epoch_end(self, epoch, logs=None): + self._finalize_progbar(logs) + + def on_test_end(self, logs=None): + if not self._called_in_fit: + self._finalize_progbar(logs) + + def on_predict_end(self, logs=None): + self._finalize_progbar(logs) + + def _reset_progbar(self): + self.seen = 0 + self.progbar = None + + def _batch_update_progbar(self, logs=None): + """Updates the progbar.""" + if self.stateful_metrics is None: + if self.model: + self.stateful_metrics = (set(m.name for m in self.model.metrics)) + else: + self.stateful_metrics = set() + + if self.progbar is None: + self.progbar = Progbar( + target=self.target, + verbose=self.verbose, + stateful_metrics=self.stateful_metrics, + unit_name='step' if self.use_steps else 'sample') + + logs = copy.copy(logs) if logs else {} + batch_size = logs.pop('size', 0) + num_steps = logs.pop('num_steps', 1) # DistStrat can run >1 steps. + logs.pop('batch', None) + add_seen = num_steps if self.use_steps else num_steps * batch_size + self.seen += add_seen + self.progbar.update(self.seen, list(logs.items()), finalize=False) + + def _finalize_progbar(self, logs): + if self.target is None: + self.target = self.seen + self.progbar.target = self.seen logs = logs or {} - for k in self.params['metrics']: - if k in logs: - self.log_values.append((k, logs[k])) - if self.verbose: - self.progbar.update(self.seen, self.log_values) + self.progbar.update(self.seen, list(logs.items()), finalize=True) @keras_export('keras.callbacks.History') @@ -826,7 +921,7 @@ class ModelCheckpoint(Callback): - Definition of 'best'; which quantity to monitor and whether it should be maximized or minimized. - The frequency it should save at. Currently, the callback supports saving at - the end of every epoch, or after a fixed number of training samples. + the end of every epoch, or after a fixed number of training batches. - Whether only weights are saved, or the whole model is saved. Example: @@ -873,11 +968,10 @@ class ModelCheckpoint(Callback): (`model.save(filepath)`). save_freq: `'epoch'` or integer. When using `'epoch'`, the callback saves the model after each epoch. When using integer, the callback saves the - model at end of a batch at which this many samples have been seen since - last saving. Note that if the saving isn't aligned to epochs, the - monitored metric may potentially be less reliable (it could reflect as - little as 1 batch, since the metrics get reset every epoch). Defaults to - `'epoch'` + model at end of this many batches. Note that if the saving isn't aligned + to epochs, the monitored metric may potentially be less reliable (it + could reflect as little as 1 batch, since the metrics get reset every + epoch). Defaults to `'epoch'` **kwargs: Additional arguments for backwards compatibility. Possible key is `period`. """ @@ -899,7 +993,7 @@ class ModelCheckpoint(Callback): self.save_weights_only = save_weights_only self.save_freq = save_freq self.epochs_since_last_save = 0 - self._samples_seen_since_last_saving = 0 + self._batches_seen_since_last_saving = 0 # Deprecated field `load_weights_on_restart` is for loading the checkpoint # file from `filepath` at the start of `model.fit()` @@ -917,7 +1011,7 @@ class ModelCheckpoint(Callback): if 'period' in kwargs: self.period = kwargs['period'] logging.warning('`period` argument is deprecated. Please use `save_freq` ' - 'to specify the frequency in number of samples seen.') + 'to specify the frequency in number of batches seen.') else: self.period = 1 @@ -1000,15 +1094,15 @@ class ModelCheckpoint(Callback): # Restore the training state so the model is ready for next (possible) # multi worker training. del self._training_state - del self.model._training_state + self.model._training_state = None def on_batch_end(self, batch, logs=None): logs = logs or {} if isinstance(self.save_freq, int): - self._samples_seen_since_last_saving += logs.get('size', 1) - if self._samples_seen_since_last_saving >= self.save_freq: + self._batches_seen_since_last_saving += 1 + if self._batches_seen_since_last_saving >= self.save_freq: self._save_model(epoch=self._current_epoch, logs=logs) - self._samples_seen_since_last_saving = 0 + self._batches_seen_since_last_saving = 0 def on_epoch_begin(self, epoch, logs=None): self._current_epoch = epoch @@ -1228,16 +1322,10 @@ class EarlyStopping(Callback): >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)]) >>> model.compile(tf.keras.optimizers.SGD(), loss='mse') >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5), - ... epochs=10, callbacks=[callback]) - Train on 5 samples - Epoch 1/10 - 5/5 [==============================] - ... loss: 6533.1904 - Epoch 2/10 - 5/5 [==============================] - ... loss: 110183360.0000 - Epoch 3/10 - 5/5 [==============================] - ... loss: 1862575718400.0000 - Epoch 4/10 - 5/5 [==============================] - ... loss: 31485597793124352.0000 + ... epochs=10, batch_size=1, callbacks=[callback], + ... verbose=0) + >>> len(history.history['loss']) # Only 4 epochs are run. + 4 """ def __init__(self, diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index 6e5066e19ed..bf6d8cda6f2 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -35,6 +35,7 @@ import numpy as np from tensorflow.core.framework import summary_pb2 from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context from tensorflow.python.framework import random_seed from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils @@ -146,9 +147,10 @@ class CallbackCountsTest(keras_parameterized.TestCase): @parameterized.named_parameters(('with_numpy', _get_numpy()), ('with_sequence', _get_sequence())) def test_callback_hooks_are_called_in_fit(self, data): + if not context.executing_eagerly(): + self.skipTest('Behavior changed in v2.') x, y = data val_x, val_y = np.ones((4, 10)), np.ones((4, 1)) - is_sequence = isinstance(x, keras.utils.data_utils.Sequence) model = self._get_model() counter = Counter() @@ -156,8 +158,8 @@ class CallbackCountsTest(keras_parameterized.TestCase): x, y, validation_data=(val_x, val_y), - batch_size=2 if not is_sequence else None, - steps_per_epoch=5 if is_sequence else None, + batch_size=2, + steps_per_epoch=5, epochs=5, callbacks=[counter]) @@ -264,8 +266,8 @@ class KerasCallbacksTest(keras_parameterized.TestCase): def test_progbar_logging(self): model = self._get_model(input_shape=(3,)) - x = array_ops.ones((50, 3)) - y = array_ops.zeros((50, 2)) + x = array_ops.ones((200, 3)) + y = array_ops.zeros((200, 2)) dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(10) expected_log = r'(.*- loss:.*- my_acc:.*)+' @@ -279,8 +281,8 @@ class KerasCallbacksTest(keras_parameterized.TestCase): model = self._get_model() self.assertFalse(model.built) - x = array_ops.ones((50, 3)) - y = array_ops.zeros((50, 2)) + x = array_ops.ones((200, 3)) + y = array_ops.zeros((200, 2)) dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(10) expected_log = r'(.*- loss:.*- my_acc:.*)+' @@ -304,15 +306,15 @@ class KerasCallbacksTest(keras_parameterized.TestCase): self.assertRegexpMatches(printed.contents(), expected_log) @keras_parameterized.run_with_all_model_types - @keras_parameterized.run_all_keras_modes + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) def test_progbar_logging_validation_split(self): model = self._get_model(input_shape=(3,)) x = np.ones((100, 3)) y = np.zeros((100, 2)) expected_log = ( - r'(?s).*1/2.*80/80.*- loss:.*- my_acc:.*- val_loss:.*- val_my_acc:' - r'.*2/2.*80/80.*- loss:.*- my_acc:.*- val_loss:.*- val_my_acc:.*') + r'(?s).*1/2.*8/8.*- loss:.*- my_acc:.*- val_loss:.*- val_my_acc:' + r'.*2/2.*8/8.*- loss:.*- my_acc:.*- val_loss:.*- val_my_acc:.*') with self.captureWritesToStream(sys.stdout) as printed: model.fit(x, y, batch_size=10, epochs=2, validation_split=0.2) @@ -587,7 +589,7 @@ class KerasCallbacksTest(keras_parameterized.TestCase): monitor=monitor, save_best_only=save_best_only, mode=mode, - save_freq=30, + save_freq=15, period=100) # The period should be ignored (this test tests this). ] assert not os.path.exists(filepath.format(epoch=3)) @@ -638,8 +640,8 @@ class KerasCallbacksTest(keras_parameterized.TestCase): def get_input_datasets(): # Simple training input. - train_input = [[1]] * 16 - train_label = [[0]] * 16 + train_input = [[1.]] * 16 + train_label = [[0.]] * 16 ds = dataset_ops.Dataset.from_tensor_slices((train_input, train_label)) return ds.batch(8, drop_remainder=True) @@ -1268,40 +1270,40 @@ class KerasCallbacksTest(keras_parameterized.TestCase): values.append(x) assert 'nan' in values[-1], 'The last epoch was not logged.' + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) def test_TerminateOnNaN(self): - with self.cached_session(): - np.random.seed(1337) - (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( - train_samples=TRAIN_SAMPLES, - test_samples=TEST_SAMPLES, - input_shape=(INPUT_DIM,), - num_classes=NUM_CLASSES) + np.random.seed(1337) + (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( + train_samples=TRAIN_SAMPLES, + test_samples=TEST_SAMPLES, + input_shape=(INPUT_DIM,), + num_classes=NUM_CLASSES) - y_test = np_utils.to_categorical(y_test) - y_train = np_utils.to_categorical(y_train) - cbks = [keras.callbacks.TerminateOnNaN()] - model = keras.models.Sequential() - initializer = keras.initializers.Constant(value=1e5) - for _ in range(5): - model.add( - keras.layers.Dense( - 2, - input_dim=INPUT_DIM, - activation='relu', - kernel_initializer=initializer)) - model.add(keras.layers.Dense(NUM_CLASSES)) - model.compile(loss='mean_squared_error', optimizer='rmsprop') + y_test = np_utils.to_categorical(y_test) + y_train = np_utils.to_categorical(y_train) + cbks = [keras.callbacks.TerminateOnNaN()] + model = keras.models.Sequential() + initializer = keras.initializers.Constant(value=1e5) + for _ in range(5): + model.add( + keras.layers.Dense( + 2, + input_dim=INPUT_DIM, + activation='relu', + kernel_initializer=initializer)) + model.add(keras.layers.Dense(NUM_CLASSES)) + model.compile(loss='mean_squared_error', optimizer='rmsprop') - history = model.fit( - x_train, - y_train, - batch_size=BATCH_SIZE, - validation_data=(x_test, y_test), - callbacks=cbks, - epochs=20) - loss = history.history['loss'] - self.assertEqual(len(loss), 1) - self.assertEqual(loss[0], np.inf) + history = model.fit( + x_train, + y_train, + batch_size=BATCH_SIZE, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=20) + loss = history.history['loss'] + self.assertEqual(len(loss), 1) + self.assertTrue(np.isnan(loss[0])) @unittest.skipIf( os.name == 'nt', @@ -1406,14 +1408,17 @@ class KerasCallbacksTest(keras_parameterized.TestCase): callbacks=cbks, epochs=1) - def test_callback_params_samples(self): - x, y = np.ones((64, 3)), np.ones((64, 2)) - model = testing_utils.get_small_sequential_mlp( - num_hidden=10, num_classes=2, input_dim=3) + def test_progbar_infers_steps(self): + x, y = np.ones((10, 1)), np.ones((10, 1)) + data = dataset_ops.DatasetV2.from_tensor_slices((x, y)).batch(2) + data = data.filter(lambda x, y: True) # Unknown cardinality. + + progbar = keras.callbacks.ProgbarLogger('steps') + model = keras.Sequential([keras.layers.Dense(1)]) model.compile('sgd', 'mse') - callback = keras.callbacks.Callback() - model.evaluate(x, y, callbacks=[callback]) - self.assertEqual(callback.params['samples'], 64) + self.assertIsNone(progbar.target) + model.fit(data, epochs=2, callbacks=[progbar]) + self.assertEqual(progbar.target, 5) # A summary that was emitted during a test. Fields: diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py index 16f69a4410f..81609d7092c 100644 --- a/tensorflow/python/keras/distribute/distribute_strategy_test.py +++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py @@ -950,10 +950,16 @@ class TestDistributionStrategyWithDatasets(test.TestCase, optimizer='adam', experimental_run_tf_function=experimental_run_tf_function) - def map_fn(img, lbl, weight): - inputs = {'img': img, 'lbl': lbl, 'weight': weight} - targets = {} - return inputs, targets + if context.executing_eagerly(): + + def map_fn(img, lbl, weight): + inputs = {'img': img, 'lbl': lbl, 'weight': weight} + return (inputs,) + else: + + def map_fn(img, lbl, weight): + inputs = {'img': img, 'lbl': lbl, 'weight': weight} + return inputs, {} fake_imgs = np.ones([50, 64, 64, 3], dtype=np.float32) fake_lbls = np.ones([50, 64, 64, 1], dtype=np.float32) @@ -1178,7 +1184,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase, dataset = dataset.repeat(100) dataset = dataset.batch(10) - with self.assertRaisesRegexp(ValueError, 'expected input to have shape'): + with self.assertRaisesRegexp(ValueError, 'incompatible with the layer'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) @combinations.generate( @@ -1776,7 +1782,9 @@ class TestDistributionStrategyWithKerasModels(test.TestCase, experimental_run_tf_function=experimental_run_tf_function) ds_history = ds_model.fit( x, y, validation_data=(x, y), validation_steps=2, epochs=2) - self.assertLen(ds_model.metrics, 1) + # includes stateful loss metric in eager. + metrics_len = 2 if context.executing_eagerly() else 1 + self.assertLen(ds_model.metrics, metrics_len) self.assertAllClose(history.history, ds_history.history) @@ -1830,7 +1838,9 @@ class TestDistributionStrategyWithKerasModels(test.TestCase, experimental_run_tf_function=experimental_run_tf_function) ds_history = ds_model.fit( x, y, validation_data=(x, y), validation_steps=2, epochs=2) - self.assertLen(ds_model.metrics, 1) + # includes stateful loss metric in eager. + metrics_len = 2 if context.executing_eagerly() else 1 + self.assertLen(ds_model.metrics, metrics_len) self.assertAllClose(history.history, ds_history.history) @@ -1870,7 +1880,9 @@ class TestDistributionStrategyWithKerasModels(test.TestCase, experimental_run_tf_function=experimental_run_tf_function) ds_history = ds_model.fit( x, y, validation_data=(x, y), validation_steps=2, epochs=2) - self.assertLen(ds_model.metrics, 1) + # includes stateful loss metric in eager. + metrics_len = 2 if context.executing_eagerly() else 1 + self.assertLen(ds_model.metrics, metrics_len) self.assertAllClose(history.history, ds_history.history) diff --git a/tensorflow/python/keras/distribute/keras_utils_test.py b/tensorflow/python/keras/distribute/keras_utils_test.py index 2454b9cdee6..20a4f98d881 100644 --- a/tensorflow/python/keras/distribute/keras_utils_test.py +++ b/tensorflow/python/keras/distribute/keras_utils_test.py @@ -257,11 +257,8 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): experimental_run_tf_function=experimental_run_tf_function) dataset = keras_test_lib.get_dataset(distribution) - exception_error_message = ( - '`validation_split` argument is not supported when ') - # Test with validation split - with self.assertRaisesRegexp(ValueError, exception_error_message): + with self.assertRaises(ValueError): model.fit( dataset, epochs=1, @@ -272,9 +269,7 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): # Test with sample weight. sample_weight = np.random.random((10,)) - with self.assertRaisesRegexp( - ValueError, '`sample_weight` argument is not supported when.*' - 'dataset'): + with self.assertRaises(ValueError): model.fit( dataset, epochs=1, @@ -285,69 +280,14 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): # Test with not specifying the `steps` argument for dataset with infinite # cardinality. dataset = dataset.repeat() - with self.assertRaisesRegexp( - ValueError, 'When passing an infinitely ' - 'repeating dataset, you must specify the ' - '`steps_per_epoch` argument'): + with self.assertRaises(ValueError): model.fit(dataset, epochs=1, verbose=0) - with self.assertRaisesRegexp( - ValueError, 'When passing an infinitely ' - 'repeating dataset, you must specify the ' - '`steps` argument'): + with self.assertRaises(ValueError): model.evaluate(dataset, verbose=0) - with self.assertRaisesRegexp( - ValueError, 'When passing an infinitely ' - 'repeating dataset, you must specify the ' - '`steps` argument'): + with self.assertRaises(ValueError): model.predict(dataset, verbose=0) - @combinations.generate( - combinations.combine( - distribution=[ - strategy_combinations.mirrored_strategy_with_gpu_and_cpu, - ], - mode=['graph', 'eager'], - experimental_run_tf_function=[True, False])) - def test_calling_with_unsupported_predefined_callbacks( - self, distribution, experimental_run_tf_function): - with self.cached_session(): - with distribution.scope(): - model = keras_test_lib.get_model() - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - metrics = ['mae'] - model.compile( - optimizer, - loss, - metrics=metrics, - experimental_run_tf_function=experimental_run_tf_function) - - dataset = keras_test_lib.get_dataset(distribution) - - def schedule(_): - return 0.001 - - with self.assertRaisesRegexp( - ValueError, 'You must specify a Keras Optimizer V2 when ' - 'using'): - model.fit( - dataset, - epochs=1, - steps_per_epoch=2, - verbose=0, - callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) - - with self.assertRaisesRegexp( - ValueError, 'You must specify a Keras Optimizer V2 when ' - 'using'): - model.fit( - dataset, - epochs=1, - steps_per_epoch=2, - verbose=0, - callbacks=[keras.callbacks.ReduceLROnPlateau()]) - @combinations.generate( combinations.combine( distribution=[ diff --git a/tensorflow/python/keras/engine/BUILD b/tensorflow/python/keras/engine/BUILD index 3ecc31905ba..47765190ff6 100644 --- a/tensorflow/python/keras/engine/BUILD +++ b/tensorflow/python/keras/engine/BUILD @@ -29,8 +29,6 @@ py_library( "training_generator.py", "training_utils.py", "training_v1.py", - "training_v2.py", - "training_v2_utils.py", ], srcs_version = "PY2AND3", deps = [ @@ -428,24 +426,6 @@ tf_py_test( ], ) -tf_py_test( - name = "training_v2_utils_test", - size = "medium", - srcs = ["training_v2_utils_test.py"], - python_version = "PY3", - tags = [ - "no_oss", # TODO(b/135021748) reenable - "notsan", - ], - deps = [ - "//tensorflow/python:client_testlib", - "//tensorflow/python/distribute:strategy_combinations", - "//tensorflow/python/keras", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - ], -) - tf_py_test( name = "network_test", size = "medium", diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 24d3432fb8e..c097398d90d 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -22,6 +22,7 @@ import collections import functools import itertools import threading +import weakref import numpy as np import six @@ -230,6 +231,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector): # A list of metric instances corresponding to the symbolic metric tensors # added using the `add_metric` API. self._metrics = [] + # Ensures the same metric is not added multiple times in `MirroredStrategy`. + self._metrics_lock = threading.Lock() # Both graph and subclassed networks have a dtype policy. For graph # networks, the policy's compute and variable dtypes are ignored, but other @@ -849,10 +852,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): if hasattr(self, '_set_inputs') and not self.inputs: # Subclassed network: explicitly set metadata normally set by # a call to self._set_inputs(). - # TODO(b/120997007): This should be done in Eager as well, but - # causes garbage collection issues because of the placeholders - # created on the default Keras graph. - self._set_inputs(inputs, outputs) + self._set_inputs(cast_inputs, outputs) else: # Eager execution on data tensors. with backend.name_scope(self._name_scope()): @@ -863,6 +863,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector): outputs = self.call(cast_inputs, *args, **kwargs) self._handle_activity_regularization(inputs, outputs) self._set_mask_metadata(inputs, outputs, input_masks) + if hasattr(self, '_set_save_spec'): + self._set_save_spec(cast_inputs) return outputs @@ -1146,7 +1148,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector): collected_metrics = [] all_layers = self._gather_unique_layers() for layer in all_layers: - collected_metrics.extend(layer._metrics) + with layer._metrics_lock: + collected_metrics.extend(layer._metrics) return collected_metrics @doc_controls.for_subclass_implementers @@ -1938,20 +1941,29 @@ class Layer(module.Module, version_utils.LayerVersionSelector): # on it, otherwise we create a new metric instance and # add it to the `metrics` list. metric_obj = getattr(value, '_metric_obj', None) - if metric_obj: - name = metric_obj.name + # Tensors that come from a Metric object already updated the Metric state. + should_update_state = not metric_obj + name = metric_obj.name if metric_obj else name - match = self._get_existing_metric(name) - if match: - # Tensors that come from a Metric object already updated the Metric state. - if not metric_obj: - match(value) - return + with self._metrics_lock: + match = self._get_existing_metric(name) + if match: + metric_obj = match + elif metric_obj: + self._metrics.append(metric_obj) + else: + from tensorflow.python.keras import metrics as metrics_mod # pylint:disable=g-import-not-at-top + if aggregation is None: + raise ValueError( + '`aggregation` must be specified when passing a `Tensor` ' + 'to `add_metric`.') + assert aggregation is not None + metric_obj = metrics_mod.Mean(name=name, dtype=value.dtype) + self._metrics.append(metric_obj) - if not metric_obj: - assert aggregation is not None - metric_obj, _ = base_layer_utils.create_mean_metric(value, name) - self._metrics.append(metric_obj) + if should_update_state: + metric_obj(value) + return def _symbolic_add_metric(self, value, aggregation=None, name=None): base_layer_utils.check_graph_consistency(value, method='add_metric') @@ -2259,7 +2271,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector): layers = trackable_layer_utils.filter_empty_layer_containers(self._layers) # Keep track of each top-level layers' `trainable` as well as the # state of all of its sublayers. - trainable_state = {self: self.trainable} + trainable_state = weakref.WeakKeyDictionary() + trainable_state[self] = self.trainable for layer in layers: trainable_state.update(layer._get_trainable_state()) return trainable_state @@ -2565,10 +2578,12 @@ class Layer(module.Module, version_utils.LayerVersionSelector): # so shouldn't be copied. state = self.__dict__.copy() state.pop('_thread_local', None) + state.pop('_metrics_lock', None) return state def __setstate__(self, state): state['_thread_local'] = threading.local() + state['_metrics_lock'] = threading.Lock() # Bypass Trackable logic as `__dict__` already contains this info. object.__setattr__(self, '__dict__', state) diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py index 5e07f77265e..86b0689d026 100644 --- a/tensorflow/python/keras/engine/base_layer_test.py +++ b/tensorflow/python/keras/engine/base_layer_test.py @@ -187,7 +187,7 @@ class BaseLayerTest(keras_parameterized.TestCase): model.compile(rmsprop.RMSprop(0.001), loss='mse') self.assertEqual(model.run_eagerly, True) model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3))) - self.assertEqual(model.outputs, [None]) + self.assertEqual(model.outputs, None) def test_dynamic_subclassed_model_with_shape_inference(self): @@ -210,8 +210,10 @@ class BaseLayerTest(keras_parameterized.TestCase): model = MyModel() self.assertEqual(model.dynamic, True) model.compile(rmsprop.RMSprop(0.001), loss='mse') - model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3))) - self.assertEqual(model.outputs[0].shape.as_list(), [None, 3]) + x, y = np.random.random((2, 3)), np.random.random((2, 3)) + model.train_on_batch(x, y) + outputs = model(x) + self.assertEqual(outputs.shape.as_list(), [2, 3]) def test_deepcopy(self): with context.eager_mode(): @@ -331,42 +333,6 @@ class BaseLayerTest(keras_parameterized.TestCase): keras.backend.set_learning_phase(0) self.assertEqual(get_learning_phase_value(), 0) - @keras_parameterized.run_all_keras_modes - def test_learning_phase_freezing_for_layers_in_predict(self): - if not (testing_utils.should_run_eagerly() or - testing_utils.should_run_tf_function()): - self.skipTest('Predict fails to override the outer learning phase in' - 'the FuncGraph path.') - - class LearningPhaseLayer(keras.layers.Layer): - - def call(self, inputs): - return keras.backend.in_train_phase( - lambda: array_ops.ones_like(inputs), - lambda: array_ops.zeros_like(inputs)) - - def get_learning_phase_value(): - model = keras.models.Sequential([LearningPhaseLayer(input_shape=(1,))]) - model._run_eagerly = testing_utils.should_run_eagerly() - model._experimental_run_tf_function = ( - testing_utils.should_run_tf_function()) - return np.sum(model.predict(np.ones((1, 1)))) - - self.assertEqual(get_learning_phase_value(), 0) - - # Test scope. - with keras.backend.learning_phase_scope(1): - self.assertEqual(get_learning_phase_value(), 0) - - # The effects of the scope end after exiting it. - self.assertEqual(get_learning_phase_value(), 0) - - # Test setting. - keras.backend.set_learning_phase(1) - self.assertEqual(get_learning_phase_value(), 0) - keras.backend.set_learning_phase(0) - self.assertEqual(get_learning_phase_value(), 0) - # Cannot be enabled with `run_eagerly=True`, see b/123904578 @test_util.run_all_in_graph_and_eager_modes def test_layer_can_return_variable(self): diff --git a/tensorflow/python/keras/engine/compile_utils.py b/tensorflow/python/keras/engine/compile_utils.py index b9241280d0f..74c6370fce6 100644 --- a/tensorflow/python/keras/engine/compile_utils.py +++ b/tensorflow/python/keras/engine/compile_utils.py @@ -21,9 +21,9 @@ import copy import six +from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.keras import losses as losses_mod from tensorflow.python.keras import metrics as metrics_mod -from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import losses_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -35,6 +35,10 @@ class LossesContainer(object): """A container class for losses passed to `Model.compile`.""" def __init__(self, losses, loss_weights=None, output_names=None): + # Keep user-supplied values untouched for recompiling and serialization. + self._user_losses = losses + self._user_loss_weights = loss_weights + self._losses = losses self._loss_weights = loss_weights self._output_names = output_names @@ -59,7 +63,7 @@ class LossesContainer(object): if self._output_names is None: # In Subclass API, output names like 'output_1' are used for # `Metric` names. - self._output_names = create_output_names(y_pred) + self._output_names = create_pseudo_output_names(y_pred) # Accept a dict of losses keyed by output_name when outputs are a flat # list. @@ -94,7 +98,11 @@ class LossesContainer(object): self._built = True - def __call__(self, y_true, y_pred, sample_weight=None): + def __call__(self, + y_true, + y_pred, + sample_weight=None, + regularization_losses=None): """Computes the overall loss. Arguments: @@ -104,14 +112,19 @@ class LossesContainer(object): per-sample loss weights. If one Tensor is passed, it is used for all losses. If multiple Tensors are passed, the structure should match `y_pred`. + regularization_losses: Additional losses to be added to the total loss. Returns: Tuple of `(total_loss, per_output_loss_list)` """ + y_true = map_to_output_names(y_pred, self._output_names, y_true) + sample_weight = map_to_output_names(y_pred, self._output_names, + sample_weight) + if not self._built: self._build(y_pred) - y_true = nest.flatten(y_true) + y_true = nest.flatten(y_true) if y_true is not None else [] y_pred = nest.flatten(y_pred) # TODO(omalleyt): Remove ambiguity here. @@ -127,45 +140,47 @@ class LossesContainer(object): if len(sample_weight) == 1 and len(y_pred) > 1: sample_weight = sample_weight * len(y_pred) - loss_values = [] + loss_values = [] # Used for gradient calculation. + loss_metric_values = [] # Used for loss metric calculation. zip_args = (y_true, y_pred, sample_weight, self._losses, self._loss_weights, self._per_output_metrics) for y_t, y_p, sw, loss_obj, loss_weight, metric_obj in zip(*zip_args): if loss_obj is None: # Ok to have no loss for an output. continue - y_t = math_ops.cast(y_t, y_p.dtype) - if sw is not None: - sw = math_ops.cast(sw, y_p.dtype) - - # Handle Keras mask on outputs. - mask = getattr(y_p, '_keras_mask', None) - if mask is not None: - mask = math_ops.cast(mask, y_p.dtype) - if sw is not None: - mask, _, sw = ( - tf_losses_utils.squeeze_or_expand_dimensions( - mask, sample_weight=sw)) - sw *= mask - else: - sw = mask + y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw) + sw = apply_mask(y_p, sw) loss_value = loss_obj(y_t, y_p, sample_weight=sw) + loss_metric_value = loss_value + # Correct for the `Mean` loss metrics counting each replica as a batch. + if loss_obj.reduction == losses_utils.ReductionV2.SUM: + loss_metric_value *= ds_context.get_strategy().num_replicas_in_sync if metric_obj is not None: - metric_obj.update_state(loss_value) + metric_obj.update_state(loss_metric_value) if loss_weight is not None: loss_value *= loss_weight + loss_metric_value *= loss_weight if (loss_obj.reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE or loss_obj.reduction == losses_utils.ReductionV2.AUTO): loss_value = losses_utils.scale_loss_for_distribution(loss_value) + loss_values.append(loss_value) + loss_metric_values.append(loss_metric_value) + + if regularization_losses: + reg_loss = math_ops.add_n(regularization_losses) + loss_metric_values.append(reg_loss) + loss_values.append(losses_utils.scale_loss_for_distribution(reg_loss)) if loss_values: + total_loss_metric_value = math_ops.add_n(loss_metric_values) + self._loss_metric.update_state(total_loss_metric_value) + total_loss = math_ops.add_n(loss_values) - self._loss_metric.update_state(total_loss) return total_loss else: # Ok for a model to have no compiled loss. @@ -188,7 +203,8 @@ class LossesContainer(object): loss = losses_mod.get(loss) if not isinstance(loss, losses_mod.Loss): - loss = losses_mod.LossFunctionWrapper(loss, name=loss.__name__) + loss_name = loss.__name__ + loss = losses_mod.LossFunctionWrapper(loss, name=loss_name) loss._allow_sum_over_batch_size = True # pylint: disable=protected-access return loss @@ -197,6 +213,10 @@ class MetricsContainer(object): """A container class for metrics passed to `Model.compile`.""" def __init__(self, metrics=None, weighted_metrics=None, output_names=None): + # Keep user-supplied values untouched for recompiling and serialization. + self._user_metrics = metrics + self._user_weighted_metrics = weighted_metrics + self._metrics = metrics self._weighted_metrics = weighted_metrics self._output_names = output_names @@ -207,22 +227,19 @@ class MetricsContainer(object): """Metrics created by this container.""" if not self._built: return [] - metrics = [ - metric_obj for metric_obj in nest.flatten(self._metrics) - if metric_obj is not None - ] - weighted_metrics = [ - metric_obj for metric_obj in nest.flatten(self._weighted_metrics) - if metric_obj is not None - ] - return metrics + weighted_metrics + return self._metrics_in_order def _build(self, y_pred, y_true): """One-time setup of metric objects.""" if self._output_names is None: # Subclass output names like 'output_1' are used for `Metric` names. - self._output_names = create_output_names(y_pred) + self._output_names = create_pseudo_output_names(y_pred) + + # If a single metric or flat list of metrics, apply to all outputs. + self._metrics = self._maybe_broadcast(self._metrics, y_pred) + self._weighted_metrics = self._maybe_broadcast(self._weighted_metrics, + y_pred) # Accept a dict of metrics keyed by output_name when outputs are a flat # list. @@ -231,10 +248,13 @@ class MetricsContainer(object): self._weighted_metrics = map_to_output_names(y_pred, self._output_names, self._weighted_metrics) - # If a single metric is supplied, apply to all outputs. - self._metrics = self._maybe_broadcast(self._metrics, y_pred) - self._weighted_metrics = self._maybe_broadcast(self._weighted_metrics, - y_pred) + # Standardize on tuple since `tf.data` turns lists into `Tensor`s. + # pylint: disable=protected-access + y_pred = nest._list_to_tuple(y_pred) + y_true = nest._list_to_tuple(y_true) + self._metrics = nest._list_to_tuple(self._metrics) + self._weighted_metrics = nest._list_to_tuple(self._weighted_metrics) + # pylint: enable=protected-access # Convert to `Metric` objects, potentially disambiguating based on output # properties. @@ -252,6 +272,17 @@ class MetricsContainer(object): # Assumes metrics, weighted_metrics have been flattened up to outputs. self._set_metric_names() + # Cache the flat order needed when returning metrics, for backwards compat. + self._metrics_in_order = [] + for output_metrics, output_weighted_metrics in zip(self._metrics, + self._weighted_metrics): + for m in nest.flatten(output_metrics): + if m is not None: + self._metrics_in_order.append(m) + for wm in nest.flatten(output_weighted_metrics): + if wm is not None: + self._metrics_in_order.append(wm) + self._built = True def _set_metric_names(self): @@ -277,9 +308,13 @@ class MetricsContainer(object): if wm is None: continue if is_multi_output: - wm._name = output_name + '_' + wm._name - if wm._name in metric_names: + if output_name + '_' + wm._name in metric_names: + wm._name = output_name + '_weighted_' + wm._name + else: + wm._name = output_name + '_' + wm._name + elif wm._name in metric_names: wm._name = 'weighted_' + wm._name + if wm._name in metric_names: raise ValueError('Found two metrics with the same name: {}'.format( wm._name)) @@ -288,9 +323,16 @@ class MetricsContainer(object): def update_state(self, y_true, y_pred, sample_weight=None): """Updates the state of per-output metrics.""" - flat_y_true = nest.flatten(y_true) + y_true = map_to_output_names(y_pred, self._output_names, y_true) + sample_weight = map_to_output_names(y_pred, self._output_names, + sample_weight) + + flat_y_true = nest.flatten(y_true) if y_true is not None else [] flat_y_pred = nest.flatten(y_pred) + if not flat_y_true: + return # Handle case where no targets are passed. + # TODO(omalleyt): Remove ambiguity here (see LossesContainer). if len(flat_y_true) == 1 and len(flat_y_pred) > 1: y_true = nest.map_structure(lambda _: flat_y_true[0], y_pred) @@ -311,21 +353,8 @@ class MetricsContainer(object): zip_args = (y_true, y_pred, sample_weight, self._metrics, self._weighted_metrics) for y_t, y_p, sw, metric_objs, weighted_metric_objs in zip(*zip_args): - y_t = math_ops.cast(y_t, y_p.dtype) - if sw is not None: - sw = math_ops.cast(sw, y_p.dtype) - - # Handle Keras mask on outputs. - mask = getattr(y_p, '_keras_mask', None) - if mask is not None: - mask = math_ops.cast(mask, y_p.dtype) - if sw is not None: - mask, _, sw = ( - tf_losses_utils.squeeze_or_expand_dimensions( - mask, sample_weight=sw)) - sw *= mask - else: - sw = mask + y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw) + sw = apply_mask(y_p, sw) for metric_obj in metric_objs: if metric_obj is None: @@ -339,7 +368,7 @@ class MetricsContainer(object): def _get_metric_objects(self, metrics, y_t, y_p): """Convert user-supplied metrics to `Metric` objects.""" - metrics = generic_utils.to_list(metrics) + metrics = nest.flatten(metrics) return [self._get_metric_object(m, y_t, y_p) for m in metrics] def _get_metric_object(self, metric, y_t, y_p): @@ -399,31 +428,47 @@ class MetricsContainer(object): return metric_obj def _maybe_broadcast(self, metrics, y_pred): - """If a single Metric is supplied, applies it to all outputs.""" + """If a flat list of Metrics is supplied, apply them to all outputs.""" def _should_broadcast(metrics): - single_valued_list = ( - isinstance(metrics, list) and len(metrics) == 1 and - not nest.is_sequence(metrics[0])) - # I.e. `metrics=['accuracy']` or `metrics='accuracy'`. - # In this special case we apply the metric to each output. - return not nest.is_sequence(metrics) or single_valued_list - - def _copy(metric): - if isinstance(metric, metrics_mod.Metric): - return metrics_mod.Metric.from_config(metric.get_config()) - return metric + # e.g. 'mse'. + if not nest.is_sequence(metrics): + return True + # e.g. ['mse'] or ['mse', 'mae']. + return (isinstance(metrics, (list, tuple)) and + not any(nest.is_sequence(m) for m in metrics)) if _should_broadcast(metrics): - metric = metrics[0] if isinstance(metrics, list) else metrics - return nest.map_structure(lambda _: _copy(metric), y_pred) + copy_metrics = len(nest.flatten(y_pred)) > 1 + + def _maybe_copy(m): + if copy_metrics and isinstance(m, metrics_mod.Metric): + return m.__class__.from_config(m.get_config()) + return m + + metrics = nest.flatten(metrics) + return nest.map_structure(lambda _: [_maybe_copy(m) for m in metrics], + y_pred) + return metrics -def create_output_names(y_pred): - """Creates output names for subclassed Model outputs. +def create_pseudo_output_names(outputs): + """Create pseudo output names for a subclassed Model.""" + return _create_pseudo_names(outputs, prefix='output_') - These names are used for naming `Metric`s. + +def create_pseudo_input_names(inputs): + """Create pseudo input names for a subclassed Model.""" + return _create_pseudo_names(inputs, prefix='input_') + + +def _create_pseudo_names(tensors, prefix): + """Creates pseudo {input | output} names for subclassed Models. + + Warning: this function should only be used to define default + names for `Metics` and `SavedModel`. No other use cases should + rely on a `Model`'s input or output names. Example with dict: @@ -436,10 +481,11 @@ def create_output_names(y_pred): `['output_1', 'output_2']` Arguments: - y_pred: `Model`'s outputs. + tensors: `Model`'s outputs or inputs. + prefix: 'output_' for outputs, 'input_' for inputs. Returns: - Flattened list of output names. + Flattened list of pseudo names. """ def one_index(ele): @@ -448,18 +494,18 @@ def create_output_names(y_pred): return ele + 1 return ele - flat_paths = list(nest.yield_flat_paths(y_pred)) + flat_paths = list(nest.yield_flat_paths(tensors)) flat_paths = nest.map_structure(one_index, flat_paths) - output_names = [] + names = [] for path in flat_paths: if not path: - output_name = 'output_1' + name = prefix + '1' # Single output. else: - output_name = '_'.join(str(p) for p in path) + name = '_'.join(str(p) for p in path) if isinstance(path[0], int): - output_name = 'output_' + output_name - output_names.append(output_name) - return output_names + name = prefix + name + names.append(name) + return names def map_to_output_names(y_pred, output_names, struct): @@ -473,7 +519,7 @@ def map_to_output_names(y_pred, output_names, struct): For the Functional API, the output names are the names of the last layer of each output. For the Subclass API, the output names - are determined by `create_output_names` (For example: + are determined by `create_pseudo_output_names` (For example: `['output_1', 'output_2']` for a list of outputs). This mapping preserves backwards compatibility for `compile` and @@ -492,17 +538,52 @@ def map_to_output_names(y_pred, output_names, struct): outputs_are_flat_list = ( isinstance(y_pred, (list, tuple)) and not any(nest.is_sequence(y_p) for y_p in y_pred)) - if not outputs_are_flat_list: - # In this case, `y_pred` and `struct` must have the same structure. + single_output = not nest.is_sequence(y_pred) + + if (single_output or outputs_are_flat_list) and isinstance(struct, dict): + output_names = output_names or create_pseudo_output_names(y_pred) + struct = copy.copy(struct) + new_struct = [struct.pop(name, None) for name in output_names] + if struct: + raise ValueError('Found unexpected keys that do not correspond ' + 'to any Model output: {}. Expected: {}'.format( + struct.keys(), output_names)) + if len(new_struct) == 1: + return new_struct[0] + return new_struct + else: return struct - if not isinstance(struct, dict): - return struct - struct = copy.copy(struct) - new_struct = [struct.pop(name, None) for name in output_names] - if struct: - raise ValueError('Found unexpected keys that do not correspond ' - 'to any Model output: {}. Expected: {}'.format( - struct.keys(), output_names)) - return new_struct +def match_dtype_and_rank(y_t, y_p, sw): + """Match dtype and rank of predictions.""" + # Rank. + y_t_rank = len(y_t.shape) + y_p_rank = len(y_p.shape) + if y_t_rank == 1 and y_p_rank == 2: + y_t = array_ops.expand_dims_v2(y_t, axis=-1) + if sw is not None: + sw_rank = len(sw.shape) + if sw_rank == 1 and y_p_rank == 2: + sw = array_ops.expand_dims_v2(sw, axis=-1) + + # Dtype. + y_t = math_ops.cast(y_t, y_p.dtype) + if sw is not None: + sw = math_ops.cast(sw, y_p.dtype) + return y_t, y_p, sw + + +def apply_mask(y_p, sw): + """Applies any mask on predictions to sample weights.""" + # Handle Keras mask on outputs. + mask = getattr(y_p, '_keras_mask', None) + if mask is not None: + mask = math_ops.cast(mask, y_p.dtype) + if sw is not None: + mask, _, sw = ( + tf_losses_utils.squeeze_or_expand_dimensions(mask, sample_weight=sw)) + sw *= mask + else: + sw = mask + return sw diff --git a/tensorflow/python/keras/engine/compile_utils_test.py b/tensorflow/python/keras/engine/compile_utils_test.py index 58d92d41e1f..f888797746d 100644 --- a/tensorflow/python/keras/engine/compile_utils_test.py +++ b/tensorflow/python/keras/engine/compile_utils_test.py @@ -234,29 +234,37 @@ class MetricsContainerTest(keras_parameterized.TestCase): def test_list_of_metrics_list_of_outputs(self): metric_container = compile_utils.MetricsContainer( - metrics=['mse', 'mae'], + metrics=['mse', 'mae'], # Should broadcast to both outputs. weighted_metrics=['accuracy']) # Should broadcast to both outputs. y_t = [array_ops.ones((10, 1)), array_ops.zeros((10, 1))] y_p = [array_ops.ones((10, 1)), 2 * array_ops.ones((10, 1))] sw = ops.convert_to_tensor_v2([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) metric_container.update_state(y_t, y_p, sample_weight=sw) - self.assertLen(metric_container.metrics, 4) + self.assertLen(metric_container.metrics, 6) mse_metric = metric_container.metrics[0] self.assertEqual(mse_metric.name, 'output_1_mse') self.assertEqual(mse_metric.result().numpy(), 0.) - mae_metric = metric_container.metrics[1] - self.assertEqual(mae_metric.name, 'output_2_mae') - self.assertEqual(mae_metric.result().numpy(), 2.) + mse_metric = metric_container.metrics[1] + self.assertEqual(mse_metric.name, 'output_1_mae') + self.assertEqual(mse_metric.result().numpy(), 0.) acc_metric_1 = metric_container.metrics[2] self.assertEqual(acc_metric_1.name, 'output_1_accuracy') self.assertEqual(acc_metric_1.result().numpy(), 1.) self.assertEqual(acc_metric_1._fn, metrics_mod.binary_accuracy) - acc_metric_2 = metric_container.metrics[3] + mae_metric = metric_container.metrics[3] + self.assertEqual(mae_metric.name, 'output_2_mse') + self.assertEqual(mae_metric.result().numpy(), 4.) + + mae_metric = metric_container.metrics[4] + self.assertEqual(mae_metric.name, 'output_2_mae') + self.assertEqual(mae_metric.result().numpy(), 2.) + + acc_metric_2 = metric_container.metrics[5] self.assertEqual(acc_metric_2.name, 'output_2_accuracy') self.assertEqual(acc_metric_2.result().numpy(), 0.) self.assertEqual(acc_metric_2._fn, metrics_mod.binary_accuracy) @@ -281,16 +289,16 @@ class MetricsContainerTest(keras_parameterized.TestCase): self.assertEqual(mse_metric.name, 'out1_mse') self.assertEqual(mse_metric.result().numpy(), 0.) - mae_metric = metric_container.metrics[1] + weighted_mse_metric = metric_container.metrics[1] + self.assertEqual(weighted_mse_metric.name, 'out1_weighted_mse') + self.assertEqual(weighted_mse_metric.result().numpy(), 0.) + + mae_metric = metric_container.metrics[2] self.assertEqual(mae_metric.name, 'out2_mae') self.assertEqual(mae_metric.result().numpy(), 2.) - weighted_mse_metric = metric_container.metrics[2] - self.assertEqual(weighted_mse_metric.name, 'weighted_out1_mse') - self.assertEqual(weighted_mse_metric.result().numpy(), 0.) - weighted_mae_metric = metric_container.metrics[3] - self.assertEqual(weighted_mae_metric.name, 'weighted_out2_mae') + self.assertEqual(weighted_mae_metric.name, 'out2_weighted_mae') self.assertEqual(weighted_mae_metric.result().numpy(), 2.) def test_metric_partial_dict_with_output_names(self): @@ -355,14 +363,14 @@ class MetricsContainerTest(keras_parameterized.TestCase): self.assertEqual(a_mae_metric.name, 'a_mae') self.assertEqual(a_mae_metric.result().numpy(), 1.) - b_1_mse_metric = metric_container.metrics[1] - self.assertEqual(b_1_mse_metric.name, 'b_1_mse') - self.assertEqual(b_1_mse_metric.result().numpy(), 4.) - - weighted_a_mae_metric = metric_container.metrics[2] + weighted_a_mae_metric = metric_container.metrics[1] self.assertEqual(weighted_a_mae_metric.name, 'a_mse') self.assertEqual(weighted_a_mae_metric.result().numpy(), 1.) + b_1_mse_metric = metric_container.metrics[2] + self.assertEqual(b_1_mse_metric.name, 'b_1_mse') + self.assertEqual(b_1_mse_metric.result().numpy(), 4.) + def test_crossentropy(self): metric_container = compile_utils.MetricsContainer('crossentropy') y_t, y_p = array_ops.ones((10, 1)), array_ops.ones((10, 1)) @@ -422,6 +430,29 @@ class MetricsContainerTest(keras_parameterized.TestCase): self.assertEqual(weighted_mae_metric.name, 'weighted_mae') self.assertEqual(weighted_mae_metric.result().numpy(), 0.) + def test_broadcast_metrics_to_dict(self): + metric_container = compile_utils.MetricsContainer(metrics=['mae']) + + y_p = {'output': ops.convert_to_tensor([[0], [1], [2]])} + y_t = {'output': ops.convert_to_tensor([[1], [2], [3]])} + metric_container.update_state(y_t, y_p) + + mae_metric = metric_container.metrics[0] + self.assertEqual(mae_metric.name, 'mae') + self.assertEqual(mae_metric.result().numpy(), 1.) + + def test_broadcast_metrics_to_dict_with_output_names(self): + metric_container = compile_utils.MetricsContainer( + metrics=['mae'], output_names=['output']) + + y_p = ops.convert_to_tensor([[0], [1], [2]]) + y_t = {'output': ops.convert_to_tensor([[1], [2], [3]])} + metric_container.update_state(y_t, y_p) + + mae_metric = metric_container.metrics[0] + self.assertEqual(mae_metric.name, 'mae') + self.assertEqual(mae_metric.result().numpy(), 1.) + if __name__ == '__main__': ops.enable_eager_execution() diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py index d040a1fbdaa..3fc66d05b6f 100644 --- a/tensorflow/python/keras/engine/data_adapter.py +++ b/tensorflow/python/keras/engine/data_adapter.py @@ -36,6 +36,9 @@ from tensorflow.python.distribute import distribution_strategy_context as ds_con from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework.ops import composite_tensor from tensorflow.python.keras import backend from tensorflow.python.keras.engine import training_utils @@ -211,6 +214,15 @@ class DataAdapter(object): """Returns whether a new iterator should be created every epoch.""" raise NotImplementedError + def get_samples(self): + """Returns number of samples in the data, or `None`.""" + if not self.get_size() or not self.batch_size(): + return None + total_sample = self.get_size() * self.batch_size() + if self.has_partial_batch(): + total_sample -= (self.batch_size() - self.partial_batch_size()) + return total_sample + class TensorLikeDataAdapter(DataAdapter): """Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy.""" @@ -245,25 +257,15 @@ class TensorLikeDataAdapter(DataAdapter): shuffle=False, **kwargs): super(TensorLikeDataAdapter, self).__init__(x, y, **kwargs) - x = _process_numpy_inputs(x) - y = _process_numpy_inputs(y) - sample_weights = _process_numpy_inputs(sample_weights) + x, y, sample_weights = _process_tensorlike((x, y, sample_weights)) sample_weight_modes = broadcast_sample_weight_modes( sample_weights, sample_weight_modes) # If sample_weights are not specified for an output use 1.0 as weights. - (sample_weights, any_sample_weight, _ - ) = training_utils.handle_partial_sample_weights( + (sample_weights, _, _) = training_utils.handle_partial_sample_weights( y, sample_weights, sample_weight_modes, check_all_flat=True) - if y is not None and any_sample_weight: - inputs = (x, y, sample_weights) - elif y is not None: - # Sample weight is only needed for training, so if y is None, then - # sample_weight is ignored. - inputs = (x, y) - else: - inputs = (x,) + inputs = pack_x_y_sample_weight(x, y, sample_weights) num_samples = set(int(i.shape[0]) for i in nest.flatten(inputs)) if len(num_samples) > 1: @@ -276,13 +278,9 @@ class TensorLikeDataAdapter(DataAdapter): num_samples = num_samples.pop() # If batch_size is not passed but steps is, calculate from the input data. - if steps and not batch_size: - batch_size = int(math.ceil(num_samples / steps)) - + # Default to 32 for backwards compat. if not batch_size: - raise ValueError( - "`batch_size` or `steps` is required for `Tensor` or `NumPy`" - " input data.") + batch_size = int(math.ceil(num_samples / steps)) if steps else 32 self._size = int(math.ceil(num_samples / batch_size)) self._batch_size = batch_size @@ -557,25 +555,15 @@ class CompositeTensorDataAdapter(DataAdapter): shuffle=False, **kwargs): super(CompositeTensorDataAdapter, self).__init__(x, y, **kwargs) - x = _process_numpy_inputs(x) - y = _process_numpy_inputs(y) - sample_weights = _process_numpy_inputs(sample_weights) + x, y, sample_weights = _process_tensorlike((x, y, sample_weights)) sample_weight_modes = broadcast_sample_weight_modes( sample_weights, sample_weight_modes) # If sample_weights are not specified for an output use 1.0 as weights. - (sample_weights, any_sample_weight, _ - ) = training_utils.handle_partial_sample_weights( + (sample_weights, _, _) = training_utils.handle_partial_sample_weights( y, sample_weights, sample_weight_modes, check_all_flat=True) - if y is not None and any_sample_weight: - inputs = (x, y, sample_weights) - elif y is not None: - # Sample weight is only needed for training, so if y is None, then - # sample_weight is ignored. - inputs = (x, y) - else: - inputs = (x,) + inputs = pack_x_y_sample_weight(x, y, sample_weights) dataset = dataset_ops.DatasetV2.from_tensor_slices(inputs) num_samples = int(nest.flatten(x)[0].shape[0]) @@ -583,13 +571,9 @@ class CompositeTensorDataAdapter(DataAdapter): dataset = dataset.shuffle(num_samples) # If batch_size is not passed but steps is, calculate from the input data. - if steps and not batch_size: - batch_size = int(math.ceil(num_samples/steps)) - + # Default to 32 for backwards compat. if not batch_size: - raise ValueError( - "`batch_size` or `steps` is required for `Tensor` or `NumPy`" - " input data.") + batch_size = int(math.ceil(num_samples / steps)) if steps else 32 dataset = dataset.batch(batch_size) self._size = int(math.ceil(num_samples / batch_size)) @@ -648,7 +632,6 @@ class ListsOfScalarsDataAdapter(DataAdapter): sample_weight_modes=None, batch_size=None, shuffle=False, - standardize_function=None, **kwargs): super(ListsOfScalarsDataAdapter, self).__init__(x, y, **kwargs) x = np.asarray(x) @@ -659,10 +642,6 @@ class ListsOfScalarsDataAdapter(DataAdapter): sample_weight_modes = broadcast_sample_weight_modes( sample_weights, sample_weight_modes) - if standardize_function is not None: - x, y, sample_weights = standardize_function( - x=x, y=y, sample_weight=sample_weights) - self._internal_adapter = TensorLikeDataAdapter( x, y=y, @@ -703,32 +682,22 @@ class DatasetAdapter(DataAdapter): y=None, sample_weights=None, steps=None, - standardize_function=None, **kwargs): super(DatasetAdapter, self).__init__(x, y, **kwargs) - if not is_none_or_empty(y): - raise ValueError("`y` argument is not supported when using " - "dataset as input.") - if not is_none_or_empty(sample_weights): - raise ValueError("`sample_weight` argument is not supported when using " - "dataset as input.") - - if standardize_function is not None: - x = standardize_function(x) - - # Note that the dataset instance is immutable, its fine to reusing the user + # Note that the dataset instance is immutable, its fine to reuse the user # provided dataset. self._dataset = x # The user-provided steps. self._user_steps = steps + self._validate_args(y, sample_weights, steps) + def get_dataset(self): return self._dataset def get_size(self): - # The size of dataset is unknown, unless its fully consumed. - return None + return # Inferred in `DataHandler`. def batch_size(self): return None @@ -746,6 +715,21 @@ class DatasetAdapter(DataAdapter): return (self._user_steps is None or cardinality.cardinality(self._dataset).numpy() == self._user_steps) + def _validate_args(self, y, sample_weights, steps): + """Validates `__init__` arguments.""" + # Arguments that shouldn't be passed. + if not is_none_or_empty(y): + raise ValueError("`y` argument is not supported when using " + "dataset as input.") + if not is_none_or_empty(sample_weights): + raise ValueError("`sample_weight` argument is not supported when using " + "dataset as input.") + + size = cardinality.cardinality(self._dataset).numpy() + if size == cardinality.INFINITE and steps is None: + raise ValueError("When providing an infinite dataset, you must specify " + "the number of steps to run.") + class GeneratorDataAdapter(DataAdapter): """Adapter that handles python generators and iterators.""" @@ -756,8 +740,14 @@ class GeneratorDataAdapter(DataAdapter): and hasattr(x, "__iter__") and not isinstance(x, data_utils.Sequence)) - def __init__(self, x, y=None, sample_weights=None, standardize_function=None, - workers=1, use_multiprocessing=False, max_queue_size=10, + def __init__(self, + x, + y=None, + sample_weights=None, + workers=1, + use_multiprocessing=False, + max_queue_size=10, + model=None, **kwargs): # Generators should never shuffle as exhausting the generator in order to # shuffle the batches is inefficient. @@ -769,115 +759,75 @@ class GeneratorDataAdapter(DataAdapter): if not is_none_or_empty(sample_weights): raise ValueError("`sample_weight` argument is not supported when using " "python generator as input.") + super(GeneratorDataAdapter, self).__init__(x, y, **kwargs) # Since we have to know the dtype of the python generator when we build the # dataset, we have to look at a batch to infer the structure. peek, x = self._peek_and_restore(x) assert_not_namedtuple(peek) + peek = self._standardize_batch(peek) + peek = _process_tensorlike(peek) - (peek, wrap_in_tuple, elements_to_keep, partial_sample_weight, - sample_weight_modes, nested_shape, nested_dtypes - ) = self._canonicalize_peek(peek, kwargs.get("sample_weight_modes")) + # Need to build the Model on concrete input shapes. + if model is not None and not model.built: + concrete_x, _, _ = unpack_x_y_sample_weight(peek) + model.distribute_strategy.experimental_run_v2( + lambda x: model(x, training=False), args=(concrete_x,)) + + self._first_batch_size = int(nest.flatten(peek)[0].shape[0]) + + def _get_dynamic_shape(t): + shape = t.shape + # Unknown number of dimensions, `as_list` cannot be called. + if shape.rank is None: + return shape + return tensor_shape.TensorShape([None for _ in shape.as_list()]) + + output_shapes = nest.map_structure(_get_dynamic_shape, peek) + output_types = nest.map_structure(lambda t: t.dtype, peek) # Note that dataset API takes a callable that creates a generator object, # rather than generator itself, which is why we define a function here. - generator_fn = self._make_callable(x, workers, use_multiprocessing, - max_queue_size) + generator_fn = self._handle_multiprocessing(x, workers, use_multiprocessing, + max_queue_size) - generator_fn = self._make_bridging_callable( - generator_fn, wrap_in_tuple, peek, elements_to_keep, - partial_sample_weight, sample_weight_modes) + def wrapped_generator(): + for data in generator_fn(): + yield self._standardize_batch(data) dataset = dataset_ops.DatasetV2.from_generator( - generator_fn, nested_dtypes, output_shapes=nested_shape) - - if standardize_function is not None: - dataset = standardize_function(dataset) + wrapped_generator, output_types, output_shapes=output_shapes) if workers == 1 and not use_multiprocessing: dataset = dataset.prefetch(1) self._dataset = dataset - def _canonicalize_peek(self, peek, sample_weight_modes): - """Map the peeked batch into a regular form. + def _standardize_batch(self, data): + """Standardizes a batch output by a generator.""" + # Removes `None`s. + x, y, sample_weight = unpack_x_y_sample_weight(data) + data = pack_x_y_sample_weight(x, y, sample_weight) - This function serves two purposes. First, it determines if per-batch - transformations are needed. Second, it extracts the structure to be used - by Dataset.from_generator. + data = nest._list_to_tuple(data) # pylint: disable=protected-access - Args: - peek: The first batch of the user's data - sample_weight_modes: Optional structure indicating how to handle sample - weights. If it is a string, it will be mapped to match the target - structure. + def _convert_dtype(t): + if (isinstance(t, np.ndarray) and issubclass(t.dtype.type, np.floating)): + return np.array(t, dtype=backend.floatx()) + return t - Returns: - An updated peek and various inspection results. - """ - wrap_in_tuple = False - if not isinstance(peek, tuple): - peek, wrap_in_tuple = (peek,), True - - if len(peek) not in (1, 2, 3): - raise ValueError( - "Output of generator should be a tuple of 1 or 2 or 3 elements: " - "(input,) or (input, target) or (input, target, sample_weights). " - "Received {}".format(peek)) - - x_peek, y_peek, sample_weights_peek = list(peek) + [None] * (3 - len(peek)) - - any_sample_weight, partial_sample_weight = False, False - sample_weight_modes = broadcast_sample_weight_modes( - sample_weights_peek if sample_weights_peek is not None else y_peek, - sample_weight_modes) - - if len(peek) == 3: - (sample_weights_peek, any_sample_weight, partial_sample_weight - ) = training_utils.handle_partial_sample_weights( - y_peek, sample_weights_peek, sample_weight_modes, check_all_flat=True) - peek = (x_peek, y_peek, sample_weights_peek) - - # Users often return None for fields which are not used. For instance: - # (x, y, None) to indicate no sample weights. - if len(peek) >= 2 and y_peek is None: - if any_sample_weight: - raise ValueError("Found sample weights but no targets\n{}".format(peek)) - elements_to_keep = 1 - elif len(peek) == 3 and not any_sample_weight: - elements_to_keep = 2 - else: - elements_to_keep = len(peek) - - def dynamic_shape_like(t): - return tuple(None for _ in t.shape) - - def convert_for_inspection(t): - if getattr(t, "shape", None) and getattr(t, "dtype", None): - return t - return np.array(t, dtype=backend.floatx()) - - canonicalized_peek = nest._list_to_tuple( # pylint: disable=protected-access - nest.map_structure(convert_for_inspection, peek[:elements_to_keep])) - nested_dtypes = nest.map_structure(lambda t: t.dtype, canonicalized_peek) - nested_shape = nest.map_structure(dynamic_shape_like, canonicalized_peek) - - try: - self._first_batch_size = int(nest.flatten(canonicalized_peek)[0].shape[0]) - except IndexError: - raise IndexError("Could not infer batch size from: {}".format(peek)) - - return (peek, wrap_in_tuple, elements_to_keep, partial_sample_weight, - sample_weight_modes, nested_shape, nested_dtypes) + data = nest.map_structure(_convert_dtype, data) + return data @staticmethod def _peek_and_restore(x): peek = next(x) return peek, itertools.chain([peek], x) - def _make_callable(self, x, workers, use_multiprocessing, max_queue_size): - """Create a callable, and possibly include an Enqueuer.""" + def _handle_multiprocessing(self, x, workers, use_multiprocessing, + max_queue_size): + """Create a callable, possibly including an Enqueuer.""" if workers > 1 or (workers > 0 and use_multiprocessing): if use_multiprocessing: logging.warning( @@ -893,44 +843,6 @@ class GeneratorDataAdapter(DataAdapter): generator_fn = lambda: x return generator_fn - @staticmethod - def _make_bridging_callable( - generator_fn, wrap_in_tuple, peek, elements_to_keep, - partial_sample_weight, sample_weight_modes): - """Optional compatibility layer between user's data and Dataset.""" - must_prune_nones = (elements_to_keep != len(peek)) - try: - nest.assert_same_structure(peek, nest._list_to_tuple(peek)) # pylint: disable=protected-access - must_extract_lists = False - except TypeError: - must_extract_lists = True - - # No additional transformations are needed. - if not (wrap_in_tuple or must_extract_lists or must_prune_nones or - partial_sample_weight): - return generator_fn - - def wrapped_generator(): - """Remove Nones and lists before invoking Dataset.from_generator.""" - for batch in generator_fn(): - if wrap_in_tuple: - batch = (batch,) - - if must_extract_lists: - batch = nest._list_to_tuple(batch) # pylint: disable=protected-access - - if must_prune_nones: - batch = batch[:elements_to_keep] - - if partial_sample_weight: - sample_weights, _, _ = training_utils.handle_partial_sample_weights( - batch[1], batch[2], sample_weight_modes, check_all_flat=False) - batch = batch[:2] + (sample_weights,) - - yield batch - - return wrapped_generator - def get_dataset(self): return self._dataset @@ -960,31 +872,40 @@ class KerasSequenceAdapter(GeneratorDataAdapter): def can_handle(x, y=None): return isinstance(x, data_utils.Sequence) - def __init__(self, x, y=None, sample_weights=None, standardize_function=None, - shuffle=False, workers=1, use_multiprocessing=False, - max_queue_size=10, **kwargs): + def __init__(self, + x, + y=None, + sample_weights=None, + shuffle=False, + workers=1, + use_multiprocessing=False, + max_queue_size=10, + model=None, + **kwargs): if not is_none_or_empty(y): raise ValueError("`y` argument is not supported when using " "`keras.utils.Sequence` as input.") if not is_none_or_empty(sample_weights): raise ValueError("`sample_weight` argument is not supported when using " "`keras.utils.Sequence` as input.") + self._size = len(x) self._shuffle_sequence = shuffle super(KerasSequenceAdapter, self).__init__( x, - standardize_function=standardize_function, shuffle=False, # Shuffle is handed in the _make_callable override. workers=workers, use_multiprocessing=use_multiprocessing, max_queue_size=max_queue_size, + model=model, **kwargs) @staticmethod def _peek_and_restore(x): return x[0], x - def _make_callable(self, x, workers, use_multiprocessing, max_queue_size): + def _handle_multiprocessing(self, x, workers, use_multiprocessing, + max_queue_size): if workers > 1 or (workers > 0 and use_multiprocessing): def generator_fn(): enqueuer = data_utils.OrderedEnqueuer( @@ -1051,37 +972,34 @@ def _type_name(x): return str(type(x)) -def _process_numpy_inputs(inputs): - """Process numpy array inputs. +def _process_tensorlike(inputs): + """Process tensor-like inputs. - For numpy inputs, it is possible to be single numpy array, or list/dict of - them. They could also be preprocessed by other lib to match with the order - of position for the model. The result here should be something that can be - used to build dataset. + This function: + + (1) Converts `Numpy` arrays to `Tensor`s. + (2) Converts `Scipy` sparse matrices to `SparseTensor`s. + (2) Converts `list`s to `tuple`s (for `tf.data` support). Args: - inputs: single or list/tuple/dict of numpy array. - Returns: - numpy arrays can be used to build dataset. - """ - if is_none_or_empty(inputs): - return None - flat_inputs = nest.flatten(inputs) - if len(flat_inputs) == 1: - return flat_inputs[0] + inputs: Structure of `Tensor`s, `NumPy` arrays, or tensor-like. - def _convert_non_tensor(x): - # Don't call `ops.convert_to_tensor_v2` on all `inputs` because - # `SparseTensors` can't be converted to `Tensor`. + Returns: + Structure of `Tensor`s or tensor-like. + """ + + def _convert_numpy_and_scipy(x): if isinstance(x, np.ndarray): - return ops.convert_to_tensor_v2(x) + dtype = None + if issubclass(x.dtype.type, np.floating): + dtype = backend.floatx() + return ops.convert_to_tensor(x, dtype=dtype) + elif scipy_sparse and scipy_sparse.issparse(x): + return _scipy_sparse_to_sparse_tensor(x) return x - inputs = nest.map_structure(_convert_non_tensor, inputs) - # For more complicated structure, we only convert the out most list to tuple - # since dataset will stack the list, but treat elements in the tuple as - # individual element. - return training_utils.list_to_tuple(inputs) + inputs = nest.map_structure(_convert_numpy_and_scipy, inputs) + return nest._list_to_tuple(inputs) # pylint: disable=protected-access def is_none_or_empty(inputs): @@ -1147,8 +1065,6 @@ def assert_not_namedtuple(x): class DataHandler(object): """Handles iterating over epoch-level `tf.data.Iterator` objects.""" - # TODO(omalleyt): Handle `validation_split` with separate utility. - # TODO(omalleyt): Handle `validation_data` batch size when `x` is a gen. def __init__(self, x, y=None, @@ -1161,7 +1077,8 @@ class DataHandler(object): class_weight=None, max_queue_size=10, workers=1, - use_multiprocessing=False): + use_multiprocessing=False, + model=None): self._initial_epoch = initial_epoch self._epochs = epochs @@ -1173,20 +1090,21 @@ class DataHandler(object): y, batch_size=batch_size, steps=steps_per_epoch, - epochs=epochs, + epochs=epochs - initial_epoch, sample_weights=sample_weight, shuffle=shuffle, max_queue_size=max_queue_size, workers=workers, use_multiprocessing=use_multiprocessing, - distribution_strategy=ds_context.get_strategy()) + distribution_strategy=ds_context.get_strategy(), + model=model) strategy = ds_context.get_strategy() dataset = self._train_adapter.get_dataset() if class_weight: dataset = dataset.map(_make_class_weight_map_fn(class_weight)) + self._steps_per_epoch = self._infer_steps(steps_per_epoch, dataset) self._train_dataset = strategy.experimental_distribute_dataset(dataset) - self._steps_per_epoch = self._infer_steps(steps_per_epoch) def enumerate_epochs(self): """Yields `(epoch, tf.data.Iterator)`.""" @@ -1231,7 +1149,7 @@ class DataHandler(object): yield self._current_step self._current_step += 1 - def _infer_steps(self, steps): + def _infer_steps(self, steps, dataset): """Infers steps_per_epoch needed to loop through a dataset.""" if steps is not None: return steps @@ -1240,7 +1158,6 @@ class DataHandler(object): if adapter_steps is not None: return adapter_steps - dataset = self._train_dataset if (ds_context.get_strategy().extended._in_multi_worker_mode() and # pylint: disable=protected-access (dataset.options().experimental_distribute.auto_shard_policy != distribute_options.AutoShardPolicy.OFF)): @@ -1256,6 +1173,14 @@ class DataHandler(object): return size return None + @property + def _samples(self): + return self._train_adapter.get_samples() + + @property + def _steps(self): + return self._train_adapter.get_size() + def _make_class_weight_map_fn(class_weight): """Applies class weighting to a `Dataset`. @@ -1280,25 +1205,29 @@ def _make_class_weight_map_fn(class_weight): raise ValueError(error_msg) class_weight_tensor = ops.convert_to_tensor_v2( - [class_weight[c] for c in class_ids]) + [int(class_weight[c]) for c in class_ids], dtype="int64") def _class_weights_map_fn(*data): """Convert `class_weight` to `sample_weight`.""" - if len(data) == 2: - x, y = data - sw = None - else: - x, y, sw = data + x, y, sw = unpack_x_y_sample_weight(data) if nest.is_sequence(y): raise ValueError( - "`class_weight` is only supported for `Model`s with a single output.") + "`class_weight` is only supported for Models with a single output.") - cw = array_ops.gather_v2(class_weight_tensor, y) + if y.shape.rank > 2: + raise ValueError("`class_weight` not supported for " + "3+ dimensional targets.") + + y_classes = smart_cond.smart_cond( + y.shape.rank == 2 and backend.shape(y)[1] > 1, + lambda: backend.argmax(y, axis=1), + lambda: math_ops.cast(backend.reshape(y, (-1,)), dtypes.int64)) + + cw = array_ops.gather_v2(class_weight_tensor, y_classes) if sw is not None: cw = math_ops.cast(cw, sw.dtype) - if len(cw.shape.as_list()) > len(sw.shape.as_list()): - cw = array_ops.squeeze(cw) + sw, cw = expand_1d((sw, cw)) # `class_weight` and `sample_weight` are multiplicative. sw = sw * cw else: @@ -1309,6 +1238,18 @@ def _make_class_weight_map_fn(class_weight): return _class_weights_map_fn +def expand_1d(data): + """Expands 1-dimensional `Tensor`s into 2-dimensional `Tensor`s.""" + + def _expand_single_1d_tensor(t): + if (hasattr(t, "shape") and + isinstance(t.shape, tensor_shape.TensorShape) and t.shape.rank == 1): + return array_ops.expand_dims_v2(t, axis=-1) + return t + + return nest.map_structure(_expand_single_1d_tensor, data) + + def train_validation_split(arrays, validation_split, shuffle=True): """Split arrays into random train and validation subsets. @@ -1368,3 +1309,60 @@ def train_validation_split(arrays, validation_split, shuffle=True): functools.partial(_split, indices=val_indices), arrays) return train_arrays, val_arrays + + +def unpack_x_y_sample_weight(data): + """Unpacks user-provided data tuple.""" + if not isinstance(data, tuple): + return (data, None, None) + elif len(data) == 1: + return (data[0], None, None) + elif len(data) == 2: + return (data[0], data[1], None) + elif len(data) == 3: + return (data[0], data[1], data[2]) + + raise ValueError("Data not understood.") + + +def pack_x_y_sample_weight(x, y=None, sample_weight=None): + """Packs user-provided data into a tuple.""" + if y is None: + return (x,) + elif sample_weight is None: + return (x, y) + else: + return (x, y, sample_weight) + + +def single_batch_iterator(strategy, + x, + y=None, + sample_weight=None, + class_weight=None): + """Creates a single-batch dataset.""" + x, y, sample_weight = _process_tensorlike((x, y, sample_weight)) + if y is None: + data = (x,) + elif sample_weight is None: + data = (x, y) + else: + data = (x, y, sample_weight) + + dataset = dataset_ops.DatasetV2.from_tensors(data) + if class_weight: + dataset = dataset.map(_make_class_weight_map_fn(class_weight)) + dataset = strategy.experimental_distribute_dataset(dataset) + return iter(dataset) + + +def _scipy_sparse_to_sparse_tensor(t): + """Converts a SciPy sparse matrix to a SparseTensor.""" + sparse_coo = t.tocoo() + row, col = sparse_coo.row, sparse_coo.col + data, shape = sparse_coo.data, sparse_coo.shape + if issubclass(data.dtype.type, np.floating): + data = data.astype(backend.floatx()) + indices = np.concatenate( + (np.expand_dims(row, axis=1), np.expand_dims(col, axis=1)), axis=1) + return sparse_tensor.SparseTensor(indices, data, shape) diff --git a/tensorflow/python/keras/engine/data_adapter_test.py b/tensorflow/python/keras/engine/data_adapter_test.py index 1bb91303aa8..75ddf0f7d6e 100644 --- a/tensorflow/python/keras/engine/data_adapter_test.py +++ b/tensorflow/python/keras/engine/data_adapter_test.py @@ -124,11 +124,6 @@ class TensorLikeDataAdapterTest(DataAdapterTestBase): self.assertFalse(self.adapter_cls.can_handle(self.generator_input)) self.assertFalse(self.adapter_cls.can_handle(self.sequence_input)) - def test_iterator_expect_batch_size_numpy(self): - with self.assertRaisesRegexp( - ValueError, r'`batch_size` or `steps` is required'): - self.adapter_cls(self.numpy_input, self.numpy_target) - def test_size_numpy(self): adapter = self.adapter_cls( self.numpy_input, self.numpy_target, batch_size=5) @@ -428,12 +423,6 @@ class GenericArrayLikeDataAdapterTest(DataAdapterTestBase): self.assertFalse(self.adapter_cls.can_handle(self.generator_input)) self.assertFalse(self.adapter_cls.can_handle(self.sequence_input)) - def test_iterator_expect_batch_size_generic_arraylike(self): - with self.assertRaisesRegexp( - ValueError, r'`batch_size` or `steps` is required'): - self.adapter_cls(self.arraylike_input, - self.arraylike_target) - def test_size(self): adapter = self.adapter_cls( self.arraylike_input, @@ -885,6 +874,7 @@ class DataHandlerTest(keras_parameterized.TestCase): def test_insufficient_data(self): ds = dataset_ops.DatasetV2.from_tensor_slices([0, 1]) + ds = ds.filter(lambda *args, **kwargs: True) data_handler = data_adapter.DataHandler( ds, initial_epoch=0, epochs=2, steps_per_epoch=3) returned_data = [] @@ -963,53 +953,6 @@ class DataHandlerTest(keras_parameterized.TestCase): self.assertEqual(returned_data, [[([0],), ([1],), ([2],)], [([0],), ([1],), ([2],)]]) - def test_class_weight(self): - data_handler = data_adapter.DataHandler( - x=[[0], [1], [2]], - y=[[2], [1], [0]], - class_weight={ - 0: 0.5, - 1: 1., - 2: 1.5 - }, - epochs=2, - steps_per_epoch=3) - returned_data = [] - for _, iterator in data_handler.enumerate_epochs(): - epoch_data = [] - for _ in data_handler.steps(): - epoch_data.append(next(iterator)) - returned_data.append(epoch_data) - returned_data = self.evaluate(returned_data) - self.assertEqual(returned_data, [[([0], [2], [1.5]), ([1], [1], [1.]), - ([2], [0], [0.5])], - [([0], [2], [1.5]), ([1], [1], [1.]), - ([2], [0], [0.5])]]) - - def test_class_weight_and_sample_weight(self): - data_handler = data_adapter.DataHandler( - x=[[0], [1], [2]], - y=[[2], [1], [0]], - sample_weight=[[1.], [2.], [4.]], - class_weight={ - 0: 0.5, - 1: 1., - 2: 1.5 - }, - epochs=2, - steps_per_epoch=3) - returned_data = [] - for _, iterator in data_handler.enumerate_epochs(): - epoch_data = [] - for _ in data_handler.steps(): - epoch_data.append(next(iterator)) - returned_data.append(epoch_data) - returned_data = self.evaluate(returned_data) - self.assertEqual(returned_data, [[([0], [2], [1.5]), ([1], [1], [2.]), - ([2], [0], [2.])], - [([0], [2], [1.5]), ([1], [1], [2.]), - ([2], [0], [2.])]]) - def test_class_weight_user_errors(self): with self.assertRaisesRegexp(ValueError, 'to be a dict with keys'): data_adapter.DataHandler( diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index deb3bd27928..166553a324b 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -40,6 +40,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.keras import backend from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import base_layer_utils +from tensorflow.python.keras.engine import compile_utils from tensorflow.python.keras.engine import input_layer as input_layer_module from tensorflow.python.keras.engine import node as node_module from tensorflow.python.keras.engine import training_utils @@ -50,6 +51,7 @@ from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import tf_utils from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite +from tensorflow.python.ops import math_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import checkpoint_management @@ -200,7 +202,10 @@ class Network(base_layer.Layer): super(Network, self).__init__(name=name, **kwargs) + self.output_names = None + self.input_names = None self._is_compiled = False + self._saved_model_inputs_spec = None # This is True for Sequential networks and Functional networks. self._compute_output_and_mask_jointly = False @@ -326,6 +331,7 @@ class Network(base_layer.Layer): self._feed_inputs.append(layer.input) self._compute_tensor_usage_count() + self._set_save_spec(self._nested_inputs) def _set_output_names(self): """Assigns unique names to the Network's outputs. @@ -354,8 +360,8 @@ class Network(base_layer.Layer): self._autocast = kwargs.get('autocast', base_layer_utils.v2_dtype_behavior_enabled()) self._supports_ragged_inputs = None - self.outputs = [] - self.inputs = [] + self.outputs = None + self.inputs = None self.built = False self._build_input_shape = None @@ -573,24 +579,7 @@ class Network(base_layer.Layer): A list of `InputSpec` instances (one per input to the model) or a single instance if the model has only one input. """ - # If subclassed model, can't assume anything. - if not self._is_graph_network: - return None - - specs = [] - for layer in self._input_layers: - if layer.input_spec is None: - specs.append(None) - else: - if not isinstance(layer.input_spec, list): - raise TypeError('Layer ' + layer.name + - ' has an input_spec attribute that ' - 'is not a list. We expect a list. ' - 'Found input_spec = ' + str(layer.input_spec)) - specs += layer.input_spec - if len(specs) == 1: - return specs[0] - return specs + return @base_layer_utils.default def build(self, input_shape): @@ -648,6 +637,11 @@ class Network(base_layer.Layer): if isinstance(input_shape, list): x = [base_layer_utils.generate_placeholders_from_shape(shape) for shape in input_shape] + elif isinstance(input_shape, dict): + x = { + k: base_layer_utils.generate_placeholders_from_shape(shape) + for k, shape in input_shape.items() + } else: x = base_layer_utils.generate_placeholders_from_shape(input_shape) @@ -834,8 +828,7 @@ class Network(base_layer.Layer): tensor_dict = {} for x, y in zip(self.inputs, inputs): - x_id = str(id(x)) - tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id] + # Set shape and dtype based on `keras.Input`s. if isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor): try: y.set_shape(y.shape.merge_with(x.shape)) @@ -844,6 +837,11 @@ class Network(base_layer.Layer): 'Model was constructed with shape {} for input {}, but it was ' 're-called on a Tensor with incompatible shape {}.' .format(x, x.shape, y.shape)) + if isinstance(x, (ops.Tensor, composite_tensor.CompositeTensor)): + y = math_ops.cast(y, dtype=x.dtype) + + x_id = str(id(x)) + tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id] depth_keys = list(self._nodes_by_depth.keys()) depth_keys.sort(reverse=True) @@ -1533,6 +1531,32 @@ class Network(base_layer.Layer): new_layers.append(add_metric_layer) self._insert_layers(new_layers, new_nodes) + @trackable.no_automatic_dependency_tracking + def _set_save_spec(self, inputs): + if self._saved_model_inputs_spec is not None: + return # Already set. + + input_names = self.input_names + if not input_names: + input_names = compile_utils.create_pseudo_input_names(inputs) + + flat_inputs = nest.flatten(inputs) + specs = [] + for name, tensor in zip(input_names, flat_inputs): + specs.append( + tf_utils.get_tensor_spec(tensor, dynamic_batch=False, name=name)) + specs = nest.pack_sequence_as(inputs, specs) + + self._saved_model_inputs_spec = specs + + def _get_save_spec(self, dynamic_batch=True): + if self._saved_model_inputs_spec is None: + return None + + return nest.map_structure( + lambda t: tf_utils.get_tensor_spec(t, dynamic_batch=dynamic_batch), + self._saved_model_inputs_spec) + @property def _trackable_saved_model_saver(self): return network_serialization.NetworkSavedModelSaver(self) diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py index a86084f1a35..4ae06bc46e1 100644 --- a/tensorflow/python/keras/engine/sequential.py +++ b/tensorflow/python/keras/engine/sequential.py @@ -266,6 +266,10 @@ class Sequential(training.Model): self.built = True def call(self, inputs, training=None, mask=None): # pylint: disable=redefined-outer-name + if self._build_input_shape is None: + input_shapes = nest.map_structure(_get_shape_tuple, inputs) + self._build_input_shape = input_shapes + if self._is_graph_network: if not self.built: self._init_graph_network(self.inputs, self.outputs, name=self.name) @@ -364,7 +368,7 @@ class Sequential(training.Model): 'name': self.name, 'layers': copy.deepcopy(layer_configs) } - if self._build_input_shape: + if self._build_input_shape is not None: config['build_input_shape'] = self._build_input_shape return config @@ -383,7 +387,8 @@ class Sequential(training.Model): layer = layer_module.deserialize(layer_config, custom_objects=custom_objects) model.add(layer) - if not model.inputs and build_input_shape: + if (not model.inputs and build_input_shape and + isinstance(build_input_shape, (tuple, list))): model.build(build_input_shape) return model @@ -396,3 +401,12 @@ class Sequential(training.Model): @property def _trackable_saved_model_saver(self): return model_serialization.SequentialSavedModelSaver(self) + + +def _get_shape_tuple(t): + if hasattr(t, 'shape'): + shape = t.shape + if shape.rank is not None: + return tuple(shape.as_list()) + return None + return None diff --git a/tensorflow/python/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py index 65e58fd82cd..b5f24674b06 100644 --- a/tensorflow/python/keras/engine/sequential_test.py +++ b/tensorflow/python/keras/engine/sequential_test.py @@ -286,9 +286,16 @@ class TestSequential(keras_parameterized.TestCase): self.assertTrue(model.built) config = model.get_config() - self.assertIn('build_input_shape', config) - new_model = keras.models.Sequential.from_config(config) + new_model.compile( + loss='mse', + optimizer='rmsprop', + metrics=[keras.metrics.CategoricalAccuracy()], + run_eagerly=testing_utils.should_run_eagerly(), + experimental_run_tf_function=testing_utils.should_run_tf_function()) + x = np.random.random((batch_size, input_dim)) + y = np.random.random((batch_size, num_classes)) + new_model.train_on_batch(x, y) self.assertEqual(len(new_model.layers), 2) self.assertEqual(len(new_model.weights), 4) @@ -321,15 +328,12 @@ class TestSequential(keras_parameterized.TestCase): self.assertFalse(model.built) model(array_ops.zeros([1, 2])) self.assertTrue(model.built) - self.assertEqual(len(model.outputs), 0) model.compile( 'rmsprop', loss='mse', run_eagerly=testing_utils.should_run_eagerly(), experimental_run_tf_function=testing_utils.should_run_tf_function()) - self.assertEqual(len(model.outputs), 0) model.train_on_batch(np.zeros((1, 2)), np.zeros((1, 5))) - self.assertEqual(len(model.outputs), 1) @keras_parameterized.run_all_keras_modes def test_sequential_nesting(self): @@ -399,29 +403,21 @@ class TestSequential(keras_parameterized.TestCase): ValueError, 'should have a single output tensor'): keras.Sequential([MultiOutputLayer()])(np.zeros((10, 10))) - @keras_parameterized.run_all_keras_modes + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) def test_layer_add_after_compile_deferred(self): model = keras.Sequential([keras.layers.Dense(3)]) - self.assertFalse(model.built) - self.assertFalse(model.inputs) - self.assertFalse(model.outputs) model.compile('adam', loss='mse') model.fit(np.random.random((1, 3)), np.random.random((1, 3))) - self.assertTrue(model.built) - self.assertTrue(model.inputs) - self.assertTrue(model.outputs) model.add(keras.layers.Dense(3)) - - self.assertTrue(model.built) - self.assertTrue(model.inputs) - self.assertTrue(model.outputs) + self.assertFalse(model.built) model.compile('adam', loss='mse') model.fit(np.random.random((1, 3)), np.random.random((1, 3))) + self.assertTrue(model.built) def test_sequential_layer_tracking(self): """Test that Sequential only tracks layers added in init or `.add`.""" @@ -442,21 +438,6 @@ class TestSequential(keras_parameterized.TestCase): model.pop() self.assertEqual(model._layers[-1], layer) - @testing_utils.enable_v2_dtype_behavior - def test_sequential_does_not_autocast(self): - - class AssertFloat64InputLayer(keras.layers.Layer): - - def __init__(self): - super(AssertFloat64InputLayer, self).__init__(autocast=False) - - def call(self, inputs): - assert inputs.dtype == 'float64', 'inputs are %s' % inputs.dtype - return array_ops.identity(inputs) - - model = keras.Sequential([AssertFloat64InputLayer(), keras.layers.Dense(4)]) - model(np.random.random((4, 4))) - class TestSequentialEagerIntegration(keras_parameterized.TestCase): @@ -500,27 +481,6 @@ class TestSequentialEagerIntegration(keras_parameterized.TestCase): y = np.random.random((2, 5)) model.fit(x, y, epochs=1) - @keras_parameterized.run_all_keras_modes - def test_sequential_model_fails_with_dict_inputs(self): - num_classes = 5 - model = testing_utils.get_small_sequential_mlp( - num_hidden=10, num_classes=num_classes) - model.compile( - 'rmsprop', - metrics=['acc'], - weighted_metrics=['mae'], - loss='categorical_crossentropy', - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) - - x = {'dense_input': np.random.random((10, 1))} - y = np.random.randint(num_classes, size=(10, 1)) - - with self.assertRaisesRegexp( - ValueError, 'Passing a dictionary input to a Sequential Model which ' - 'doesn\'t have FeatureLayer as the first layer is an error'): - model.fit(x, y, batch_size=5, epochs=1) - if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 298c09a0f12..7e86d9e2d8b 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -18,61 +18,73 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections - -import numpy as np - -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.distribute import distribute_coordinator as dc from tensorflow.python.distribute import distribution_strategy_context as ds_context +from tensorflow.python.distribute import values as ds_values +from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import monitoring -from tensorflow.python.framework import composite_tensor_utils -from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec -from tensorflow.python.framework import tensor_util -from tensorflow.python.keras import backend as K -from tensorflow.python.keras import losses -from tensorflow.python.keras import metrics as metrics_module +from tensorflow.python.keras import callbacks as callbacks_module from tensorflow.python.keras import optimizers -from tensorflow.python.keras.distribute import distributed_training_utils +from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils +from tensorflow.python.keras.engine import compile_utils +from tensorflow.python.keras.engine import data_adapter from tensorflow.python.keras.engine import network -from tensorflow.python.keras.engine import training_distributed from tensorflow.python.keras.engine import training_utils -from tensorflow.python.keras.engine import training_v2 -from tensorflow.python.keras.engine import training_v2_utils -from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer -from tensorflow.python.keras.optimizer_v2 import optimizer_v2 +from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as lso from tensorflow.python.keras.saving.saved_model import model_serialization -from tensorflow.python.keras.utils import data_utils -from tensorflow.python.keras.utils import losses_utils -from tensorflow.python.keras.utils import tf_utils from tensorflow.python.keras.utils import version_utils from tensorflow.python.keras.utils.mode_keys import ModeKeys from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.losses import util as tf_losses_utils -from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops.ragged import ragged_concat_ops +from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.training.tracking import base as trackable -from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils from tensorflow.python.util import deprecation from tensorflow.python.util import nest -from tensorflow.python.util import tf_inspect +from tensorflow.python.util import tf_decorator from tensorflow.python.util.tf_export import keras_export -try: - from scipy.sparse import issparse # pylint: disable=g-import-not-at-top -except ImportError: - issparse = None _keras_api_gauge = monitoring.BoolGauge('/tensorflow/api/keras', 'keras api usage', 'method') +def enable_multi_worker(method): + """Decorator that handles running `method` with multi-worker strategy.""" + + def _method_wrapper(self, *args, **kwargs): + if not self._in_multi_worker_mode(): # pylint: disable=protected-access + return method(self, *args, **kwargs) + + return dc.run_distribute_coordinator( + lambda _: method(self, *args, **kwargs), + self.distribute_strategy, + mode=dc.CoordinatorMode.INDEPENDENT_WORKER) + + return tf_decorator.make_decorator( + target=method, decorator_func=_method_wrapper) + + +def disable_multi_worker(method): + """Decorator that disallows multi-worker use of `method`.""" + + def _method_wrapper(self, *args, **kwargs): + strategy = self.distribute_strategy + if (self._in_multi_worker_mode() or dist_utils.is_tpu_strategy(strategy) and # pylint: disable=protected-access + strategy.extended.num_hosts > 1): + raise ValueError('{} is not supported in multi-worker mode.'.format( + method.__name__)) + + return method(self, *args, **kwargs) + + return tf_decorator.make_decorator( + target=method, decorator_func=_method_wrapper) + + @keras_export('keras.Model', 'keras.models.Model') class Model(network.Network, version_utils.ModelVersionSelector): """`Model` groups layers into an object with training and inference features. @@ -148,7 +160,6 @@ class Model(network.Network, version_utils.ModelVersionSelector): def __init__(self, *args, **kwargs): super(Model, self).__init__(*args, **kwargs) _keras_api_gauge.get_cell('model').set(True) - # Model must be created under scope of DistStrat it will be trained with. if ds_context.has_strategy(): self._distribution_strategy = ds_context.get_strategy() @@ -156,6 +167,12 @@ class Model(network.Network, version_utils.ModelVersionSelector): self._distribution_strategy = None # Defaults to value of `tf.config.experimental_functions_run_eagerly`. self._run_eagerly = None + self.stop_training = False + # Initialize cache attrs. + self._reset_compile_cache() + + # Fault-tolerance handler. Set in `ModelCheckpoint`. + self._training_state = None def get_weights(self): """Retrieves the weights of the model. @@ -212,14 +229,13 @@ class Model(network.Network, version_utils.ModelVersionSelector): ValueError: If `skip_mismatch` is set to `True` when `by_name` is `False`. """ - if distributed_training_utils.is_tpu_strategy(self._distribution_strategy): + if dist_utils.is_tpu_strategy(self._distribution_strategy): if (self._distribution_strategy.extended.steps_per_run > 1 and (not network._is_hdf5_filepath(filepath))): # pylint: disable=protected-access raise ValueError('Load weights is not yet supported with TPUStrategy ' 'with steps_per_run greater than 1.') return super(Model, self).load_weights(filepath, by_name, skip_mismatch) - @trackable.no_automatic_dependency_tracking def compile(self, optimizer='rmsprop', loss=None, @@ -291,105 +307,52 @@ class Model(network.Network, version_utils.ModelVersionSelector): ValueError: In case of invalid arguments for `optimizer`, `loss`, `metrics` or `sample_weight_mode`. """ + _keras_api_gauge.get_cell('compile').set(True) self._validate_compile(optimizer, **kwargs) self._run_eagerly = kwargs.pop('run_eagerly', None) - self._set_optimizer(optimizer) - # We've disabled automatic dependency tracking for this method, but do want - # to add a checkpoint dependency on the optimizer if it's trackable. - if isinstance(self.optimizer, trackable.Trackable): - self._track_trackable( - self.optimizer, name='optimizer', overwrite=True) - self.loss = loss or {} - self.loss_weights = loss_weights - self.sample_weight_mode = sample_weight_mode - self._compile_metrics = metrics or [] - self._compile_weighted_metrics = weighted_metrics - # _training_endpoints contains a list of _TrainingEndpoint object, which has - # all the model output/target/loss and related metadata. - self._training_endpoints = [] + self.optimizer = self._get_optimizer(optimizer) + self.compiled_loss = compile_utils.LossesContainer( + loss, loss_weights, output_names=self.output_names) + self.compiled_metrics = compile_utils.MetricsContainer( + metrics, weighted_metrics, output_names=self.output_names) - # Used to freeze the behavior of the Model once `compile` has been called. - self._compiled_trainable_state = self._get_trainable_state() - - # Set tf.distribute.Strategy specific parameters. - self._distributed_model_cache = {} - self._distributed_function_cache = {} - - # Clear any `_eager_losses` cached from a previous `Model.__call__`. - self._clear_losses() - - # Initialize model metric attributes. - self._init_metric_attributes() - if not self.built or not self.inputs or not self.outputs: - # Model is not compilable because it does not know its number of inputs - # and outputs, nor their shapes and names. We will compile after the first - # time the model gets called on training data. - return + # Initializes attrs that are reset each time `compile` is called. + self._reset_compile_cache() self._is_compiled = True - _keras_api_gauge.get_cell('compile').set(True) - # Prepare list of loss functions, same size of model outputs. - self.loss_functions = training_utils.prepare_loss_functions( - self.loss, self.output_names) + self.loss = loss or {} # Backwards compat. - target_tensors = self._process_target_tensor_for_compile(None) - for o, n, l, t in zip(self.outputs, self.output_names, - self.loss_functions, target_tensors): - endpoint = _TrainingEndpoint(o, n, l) - endpoint.create_training_target(t, run_eagerly=self.run_eagerly) - self._training_endpoints.append(endpoint) + def _get_optimizer(self, optimizer): + """Wraps `optimizer` in `LossScaleOptimizer` if necessary.""" - # Prepare list loss weights, same size of model outputs. - training_utils.prepare_loss_weights(self._training_endpoints, loss_weights) + def _get_single_optimizer(opt): + opt = optimizers.get(opt) + if (self._dtype_policy.loss_scale is not None and + not isinstance(opt, lso.LossScaleOptimizer)): + opt = lso.LossScaleOptimizer(opt, self._dtype_policy.loss_scale) + return opt - # Initialization for Eager mode execution. - if self.run_eagerly: - self._compile_eagerly(metrics, weighted_metrics, sample_weight_mode) - return - - with K.get_graph().as_default(): - # Save all metric attributes per output of the model. - self._cache_output_metric_attributes(metrics, weighted_metrics) - - # Set metric attributes on model. - self._set_metric_attributes() - - # Invoke metric functions (unweighted) for all the outputs. - self._handle_metrics( - self.outputs, - targets=self._targets, - skip_target_masks=self._prepare_skip_target_masks(), - masks=self._prepare_output_masks()) - - # Prepare sample weight modes. List with the same length as model outputs. - training_utils.prepare_sample_weight_modes( - self._training_endpoints, sample_weight_mode) - - # Creates the model loss and weighted metrics sub-graphs. - self._compile_weights_loss_and_weighted_metrics() - - # Functions for train, test and predict will - # be compiled lazily when required. - # This saves time when the user is not using all functions. - self.train_function = None - self.test_function = None - self.predict_function = None - - # Collected trainable weights, sorted in topological order. - self._collected_trainable_weights = self.trainable_weights + return nest.map_structure(_get_single_optimizer, optimizer) @trackable.no_automatic_dependency_tracking - def _init_distributed_function_cache_if_not_compiled(self): - if not hasattr(self, '_distributed_function_cache'): - self._distributed_function_cache = {} + def _reset_compile_cache(self): + self.train_function = None + self.test_function = None + self.predict_function = None + + # Used to cache `trainable` attr of `Layer`s for `fit`. + self._compiled_trainable_state = self._get_trainable_state() @property def metrics(self): """Returns the model's metrics added using `compile`, `add_metric` APIs.""" metrics = [] if self._is_compiled: - metrics += self._compile_metric_functions + # TODO(omalleyt): Track `CompiledLoss` and `CompiledMetrics` objects + # so that attr names are not load-bearing. + metrics = self.compiled_loss.metrics + self.compiled_metrics.metrics + all_layers = self._gather_unique_layers() for l in all_layers: metrics.extend(l._metrics) # pylint: disable=protected-access @@ -401,26 +364,12 @@ class Model(network.Network, version_utils.ModelVersionSelector): # This property includes all output names including `loss` and per-output # losses for backward compatibility. - metrics_names = ['loss'] - if self._is_compiled: - # Add output loss metric names to the metric names list. - if len(self._training_endpoints) > 1: - metrics_names.extend([ - e.loss_name() - for e in self._training_endpoints - if not e.should_skip_target() - ]) - - # Add all metric names. - metrics_names += [m.name for m in self.metrics] - return metrics_names + return [m.name for m in self.metrics] @property def distribute_strategy(self): """The `tf.distribute.Strategy` this model was created under.""" - if self._distribution_strategy is None: - return ds_context._get_default_strategy() # pylint: disable=protected-access - return self._distribution_strategy + return self._distribution_strategy or ds_context.get_strategy() @property def run_eagerly(self): @@ -465,26 +414,93 @@ class Model(network.Network, version_utils.ModelVersionSelector): def run_eagerly(self, value): self._run_eagerly = value - def _select_training_loop(self, inputs): - """Select training loop for fit/eval/predict based on the inputs.""" - # TODO(kaftan) or TODO(scottzhu): This check should eventually be nicely - # integrated into the data adapters in the v2 loop. We can't do this yet - # because we currently have to fall back for unhandled data types. - if isinstance(inputs, (iterator_ops.Iterator, - iterator_ops.OwnedIterator)): - raise ValueError('For performance reasons Keras `fit`, `evaluate` and' - '`predict` accept tf.data `Datasets` as input but not ' - 'iterators that have been manually generated from ' - 'Datasets by users. Please directly pass in the ' - 'original `Dataset` object instead of passing in ' - '`iter(dataset)`.') + def _train_step(self, data): + """The logic for one training step. - if self._in_multi_worker_mode(): - return training_distributed.DistributionMultiWorkerTrainingLoop( - training_v2.Loop()) - else: - return training_v2.Loop() + This method can be overridden to support custom training logic. + This method is called by `Model._make_train_function`. + This method should contain the mathemetical logic for one step of training. + This typically includes the forward pass, loss calculation, backpropagation, + and metric updates. + + Configuration details for *how* this logic is run (e.g. `tf.function` and + `tf.distribute.Strategy` settings), should be left to + `Model._make_train_function`, which can also be overridden. + + Arguments: + data: A nested structure of `Tensor`s. + + Returns: + A `dict` containing values that will be passed to + `tf.keras.callbacks.CallbackList.on_train_batch_end`. Typically, the + values of the `Model`'s metrics are returned. Example: + `{'loss': 0.2, 'accuracy': 0.7}`. + + """ + # These are the only transformations `Model.fit` applies to user-input + # data when a `tf.data.Dataset` is provided. These utilities will be exposed + # publicly. + data = data_adapter.expand_1d(data) + x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data) + + with backprop.GradientTape() as tape: + y_pred = self(x, training=True) + loss = self.compiled_loss( + y, y_pred, sample_weight, regularization_losses=self.losses) + if isinstance(self.optimizer, lso.LossScaleOptimizer): + loss = self.optimizer.get_scaled_loss(loss) + + trainable_variables = self.trainable_variables + gradients = tape.gradient(loss, trainable_variables) + if isinstance(self.optimizer, lso.LossScaleOptimizer): + gradients = self.optimizer.get_unscaled_gradients(gradients) + gradients = self.optimizer._clip_gradients(gradients) # pylint: disable=protected-access + if trainable_variables: + self.optimizer.apply_gradients(zip(gradients, trainable_variables)) + + self.compiled_metrics.update_state(y, y_pred, sample_weight) + return {m.name: m.result() for m in self.metrics} + + def _make_train_function(self): + """Creates a function that executes one step of training. + + This method can be overridden to support custom training logic. + This method is called by `Model.fit` and `Model.train_on_batch`. + + Typically, this method directly controls `tf.function` and + `tf.distribute.Strategy` settings, and delegates the actual training + logic to `Model._train_step`. + + This function is cached the first time `Model.fit` or + `Model.train_on_batch` is called. The cache is cleared whenever + `Model.compile` is called. + + Returns: + Function. The function created by this method should accept a + `tf.data.Iterator`, and return a `dict` containing values that will + be passed to `tf.keras.Callbacks.on_train_batch_end`, such as + `{'loss': 0.2, 'accuracy': 0.7}`. + """ + if self.train_function is not None: + return self.train_function + + def train_function(iterator): + data = next(iterator) + outputs = self.distribute_strategy.experimental_run_v2( + self._train_step, args=(data,)) + outputs = reduce_per_replica( + outputs, self.distribute_strategy, reduction='first') + return outputs + + if not self.run_eagerly: + train_function = def_function.function( + train_function, experimental_relax_shapes=True) + + self.train_function = train_function + return self.train_function + + @enable_multi_worker def fit(self, x=None, y=None, @@ -500,6 +516,7 @@ class Model(network.Network, version_utils.ModelVersionSelector): initial_epoch=0, steps_per_epoch=None, validation_steps=None, + validation_batch_size=None, validation_freq=1, max_queue_size=10, workers=1, @@ -532,9 +549,8 @@ class Model(network.Network, version_utils.ModelVersionSelector): Number of samples per gradient update. If unspecified, `batch_size` will default to 32. Do not specify the `batch_size` if your data is in the - form of symbolic tensors, datasets, - generators, or `keras.utils.Sequence` instances (since they generate - batches). + form of datasets, generators, or `keras.utils.Sequence` instances + (since they generate batches). epochs: Integer. Number of epochs to train the model. An epoch is an iteration over the entire `x` and `y` data provided. @@ -624,6 +640,12 @@ class Model(network.Network, version_utils.ModelVersionSelector): the dataset will be consumed, the evaluation will start from the beginning of the dataset at each epoch. This ensures that the same validation samples are used every time. + validation_batch_size: Integer or `None`. + Number of samples per validation batch. + If unspecified, will default to `batch_size`. + Do not specify the `validation_batch_size` if your data is in the + form of datasets, generators, or `keras.utils.Sequence` instances + (since they generate batches). validation_freq: Only relevant if validation data is provided. Integer or `collections_abc.Container` instance (e.g. list, tuple, etc.). If an integer, specifies how many training epochs to run before a @@ -685,38 +707,160 @@ class Model(network.Network, version_utils.ModelVersionSelector): _keras_api_gauge.get_cell('fit').set(True) # Legacy graph support is contained in `training_v1.Model`. version_utils.disallow_legacy_graph('Model', 'fit') - # Legacy support - if 'nb_epoch' in kwargs: - logging.warning( - 'The `nb_epoch` argument in `fit` has been renamed `epochs`.') - epochs = kwargs.pop('nb_epoch') - if kwargs: - raise TypeError('Unrecognized keyword arguments: ' + str(kwargs)) self._assert_compile_was_called() self._check_call_args('fit') - func = self._select_training_loop(x) - return func.fit( - self, - x=x, - y=y, - batch_size=batch_size, - epochs=epochs, - verbose=verbose, - callbacks=callbacks, - validation_split=validation_split, - validation_data=validation_data, - shuffle=shuffle, - class_weight=class_weight, - sample_weight=sample_weight, - initial_epoch=initial_epoch, - steps_per_epoch=steps_per_epoch, - validation_steps=validation_steps, - validation_freq=validation_freq, - max_queue_size=max_queue_size, - workers=workers, - use_multiprocessing=use_multiprocessing) + if validation_split: + # Create the validation data using the training data. Only supported for + # `Tensor` and `NumPy` input. + (x, y, sample_weight), validation_data = ( + data_adapter.train_validation_split((x, y, sample_weight), + validation_split=validation_split, + shuffle=False)) + with self.distribute_strategy.scope(), \ + training_utils.RespectCompiledTrainableState(self): + # Creates a `tf.data.Dataset` and handles batch and epoch iteration. + data_handler = data_adapter.DataHandler( + x=x, + y=y, + sample_weight=sample_weight, + batch_size=batch_size, + steps_per_epoch=steps_per_epoch, + initial_epoch=initial_epoch, + epochs=epochs, + shuffle=shuffle, + class_weight=class_weight, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing, + model=self) + + # Container that configures and calls `tf.keras.Callback`s. + callbacks = callbacks_module.CallbackList( + callbacks, + add_history=True, + add_progbar=True, + model=self, + verbose=verbose, + epochs=epochs, + steps=data_handler._steps) # pylint: disable=protected-access + + self.stop_training = False + train_function = self._make_train_function() + callbacks.on_train_begin() + # Handle fault-tolerance for multi-worker. + # TODO(omalleyt): Fix the ordering issues that mean this has to + # happen after `callbacks.on_train_begin`. + data_handler._initial_epoch = ( # pylint: disable=protected-access + self._maybe_load_initial_epoch_from_ckpt(initial_epoch)) + for epoch, iterator in data_handler.enumerate_epochs(): + self.reset_metrics() + callbacks.on_epoch_begin(epoch) + with data_handler.catch_stop_iteration(): + for step in data_handler.steps(): + callbacks.on_train_batch_begin(step) + logs = train_function(iterator) + callbacks.on_train_batch_end(step, logs) + epoch_logs = {m.name: m.result() for m in self.metrics} + + # Run validation. + if validation_data and self._should_eval(epoch, validation_freq): + val_x, val_y, val_sample_weight = ( + data_adapter.unpack_x_y_sample_weight(validation_data)) + val_logs = self.evaluate( + x=val_x, + y=val_y, + sample_weight=val_sample_weight, + batch_size=validation_batch_size or batch_size, + steps=validation_steps, + callbacks=callbacks, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing, + return_dict=True) + val_logs = {'val_' + name: val for name, val in val_logs.items()} + epoch_logs.update(val_logs) + + callbacks.on_epoch_end(epoch, epoch_logs) + if self.stop_training: + break + + callbacks.on_train_end() + return self.history + + def _test_step(self, data): + """The logic for one evaluation step. + + This method can be overridden to support custom evaluation logic. + This method is called by `Model._make_test_function`. + + This function should contain the mathemetical logic for one step of + evaluation. + This typically includes the forward pass, loss calculation, and metrics + updates. + + Configuration details for *how* this logic is run (e.g. `tf.function` and + `tf.distribute.Strategy` settings), should be left to + `Model._make_test_function`, which can also be overridden. + + Arguments: + data: A nested structure of `Tensor`s. + + Returns: + A `dict` containing values that will be passed to + `tf.keras.callbacks.CallbackList.on_train_batch_end`. Typically, the + values of the `Model`'s metrics are returned. + """ + data = data_adapter.expand_1d(data) + x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data) + + y_pred = self(x, training=False) + # Updates stateful loss metrics. + self.compiled_loss( + y, y_pred, sample_weight, regularization_losses=self.losses) + + self.compiled_metrics.update_state(y, y_pred, sample_weight) + return {m.name: m.result() for m in self.metrics} + + def _make_test_function(self): + """Creates a function that executes one step of evaluation. + + This method can be overridden to support custom evaluation logic. + This method is called by `Model.evaluate` and `Model.test_on_batch`. + + Typically, this method directly controls `tf.function` and + `tf.distribute.Strategy` settings, and delegates the actual evaluation + logic to `Model._test_step`. + + This function is cached the first time `Model.evaluate` or + `Model.test_on_batch` is called. The cache is cleared whenever + `Model.compile` is called. + + Returns: + Function. The function created by this method should accept a + `tf.data.Iterator`, and return a `dict` containing values that will + be passed to `tf.keras.Callbacks.on_test_batch_end`. + """ + if self.test_function is not None: + return self.test_function + + def test_function(iterator): + data = next(iterator) + outputs = self.distribute_strategy.experimental_run_v2( + self._test_step, args=(data,)) + outputs = reduce_per_replica( + outputs, self.distribute_strategy, reduction='first') + return outputs + + if not self.run_eagerly: + test_function = def_function.function( + test_function, experimental_relax_shapes=True) + + self.test_function = test_function + return self.test_function + + @enable_multi_worker def evaluate(self, x=None, y=None, @@ -727,76 +871,67 @@ class Model(network.Network, version_utils.ModelVersionSelector): callbacks=None, max_queue_size=10, workers=1, - use_multiprocessing=False): + use_multiprocessing=False, + return_dict=False): """Returns the loss value & metrics values for the model in test mode. Computation is done in batches. Arguments: - x: Input data. It could be: - - A Numpy array (or array-like), or a list of arrays - (in case the model has multiple inputs). - - A TensorFlow tensor, or a list of tensors - (in case the model has multiple inputs). - - A dict mapping input names to the corresponding array/tensors, - if the model has named inputs. - - A `tf.data` dataset. - - A generator or `keras.utils.Sequence` instance. - A more detailed description of unpacking behavior for iterator types - (Dataset, generator, Sequence) is given in the `Unpacking behavior - for iterator-like inputs` section of `Model.fit`. - y: Target data. Like the input data `x`, - it could be either Numpy array(s) or TensorFlow tensor(s). - It should be consistent with `x` (you cannot have Numpy inputs and - tensor targets, or inversely). - If `x` is a dataset, generator or - `keras.utils.Sequence` instance, `y` should not be specified (since - targets will be obtained from the iterator/dataset). - batch_size: Integer or `None`. - Number of samples per gradient update. - If unspecified, `batch_size` will default to 32. - Do not specify the `batch_size` if your data is in the - form of symbolic tensors, dataset, - generators, or `keras.utils.Sequence` instances (since they generate - batches). - verbose: 0 or 1. Verbosity mode. - 0 = silent, 1 = progress bar. - sample_weight: Optional Numpy array of weights for - the test samples, used for weighting the loss function. - You can either pass a flat (1D) - Numpy array with the same length as the input samples - (1:1 mapping between weights and samples), - or in the case of temporal data, - you can pass a 2D array with shape - `(samples, sequence_length)`, - to apply a different weight to every timestep of every sample. - In this case you should make sure to specify - `sample_weight_mode="temporal"` in `compile()`. This argument is not - supported when `x` is a dataset, instead pass - sample weights as the third element of `x`. - steps: Integer or `None`. - Total number of steps (batches of samples) - before declaring the evaluation round finished. - Ignored with the default value of `None`. - If x is a `tf.data` dataset and `steps` is - None, 'evaluate' will run until the dataset is exhausted. - This argument is not supported with array inputs. - callbacks: List of `keras.callbacks.Callback` instances. - List of callbacks to apply during evaluation. - See [callbacks](/api_docs/python/tf/keras/callbacks). + x: Input data. It could be: - A Numpy array (or array-like), or a list + of arrays (in case the model has multiple inputs). - A TensorFlow + tensor, or a list of tensors (in case the model has multiple inputs). + - A dict mapping input names to the corresponding array/tensors, if + the model has named inputs. - A `tf.data` dataset. - A generator or + `keras.utils.Sequence` instance. A more detailed description of + unpacking behavior for iterator types (Dataset, generator, Sequence) + is given in the `Unpacking behavior for iterator-like inputs` section + of `Model.fit`. + y: Target data. Like the input data `x`, it could be either Numpy + array(s) or TensorFlow tensor(s). It should be consistent with `x` + (you cannot have Numpy inputs and tensor targets, or inversely). If + `x` is a dataset, generator or `keras.utils.Sequence` instance, `y` + should not be specified (since targets will be obtained from the + iterator/dataset). + batch_size: Integer or `None`. Number of samples per gradient update. If + unspecified, `batch_size` will default to 32. Do not specify the + `batch_size` if your data is in the form of a dataset, generators, + or `keras.utils.Sequence` instances (since they generate batches). + verbose: 0 or 1. Verbosity mode. 0 = silent, 1 = progress bar. + sample_weight: Optional Numpy array of weights for the test samples, + used for weighting the loss function. You can either pass a flat (1D) + Numpy array with the same length as the input samples + (1:1 mapping between weights and samples), or in the case of + temporal data, you can pass a 2D array with shape `(samples, + sequence_length)`, to apply a different weight to every timestep + of every sample. In this case you should make sure to specify + `sample_weight_mode="temporal"` in `compile()`. This argument is + not supported when `x` is a dataset, instead pass sample weights + as the third element of `x`. + steps: Integer or `None`. Total number of steps (batches of samples) + before declaring the evaluation round finished. Ignored with the + default value of `None`. If x is a `tf.data` dataset and `steps` is + None, 'evaluate' will run until the dataset is exhausted. This + argument is not supported with array inputs. + callbacks: List of `keras.callbacks.Callback` instances. List of + callbacks to apply during evaluation. See + [callbacks](/api_docs/python/tf/keras/callbacks). max_queue_size: Integer. Used for generator or `keras.utils.Sequence` - input only. Maximum size for the generator queue. - If unspecified, `max_queue_size` will default to 10. + input only. Maximum size for the generator queue. If unspecified, + `max_queue_size` will default to 10. workers: Integer. Used for generator or `keras.utils.Sequence` input - only. Maximum number of processes to spin up when using - process-based threading. If unspecified, `workers` will default - to 1. If 0, will execute the generator on the main thread. + only. Maximum number of processes to spin up when using process-based + threading. If unspecified, `workers` will default to 1. If 0, will + execute the generator on the main thread. use_multiprocessing: Boolean. Used for generator or - `keras.utils.Sequence` input only. If `True`, use process-based - threading. If unspecified, `use_multiprocessing` will default to - `False`. Note that because this implementation relies on - multiprocessing, you should not pass non-picklable arguments to - the generator as they can't be passed easily to children processes. + `keras.utils.Sequence` input only. If `True`, use process-based + threading. If unspecified, `use_multiprocessing` will default to + `False`. Note that because this implementation relies on + multiprocessing, you should not pass non-picklable arguments to the + generator as they can't be passed easily to children processes. + return_dict: If `True`, loss and metric results are returned as a dict, + with each key being the name of the metric. If `False`, they are + returned as a list. See the discussion of `Unpacking behavior for iterator-like inputs` for `Model.fit`. @@ -815,20 +950,112 @@ class Model(network.Network, version_utils.ModelVersionSelector): self._assert_compile_was_called() self._check_call_args('evaluate') - func = self._select_training_loop(x) - return func.evaluate( - self, - x=x, - y=y, - batch_size=batch_size, - verbose=verbose, - sample_weight=sample_weight, - steps=steps, - callbacks=callbacks, - max_queue_size=max_queue_size, - workers=workers, - use_multiprocessing=use_multiprocessing) + with self.distribute_strategy.scope(): + # Creates a `tf.data.Dataset` and handles batch and epoch iteration. + data_handler = data_adapter.DataHandler( + x=x, + y=y, + sample_weight=sample_weight, + batch_size=batch_size, + steps_per_epoch=steps, + initial_epoch=0, + epochs=1, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing, + model=self) + # Container that configures and calls `tf.keras.Callback`s. + if not isinstance(callbacks, callbacks_module.CallbackList): + callbacks = callbacks_module.CallbackList( + callbacks, + add_history=True, + add_progbar=True, + model=self, + verbose=verbose, + epochs=1, + steps=data_handler._steps) # pylint: disable=protected-access + + test_function = self._make_test_function() + callbacks.on_test_begin() + for _, iterator in data_handler.enumerate_epochs(): # Single epoch. + self.reset_metrics() + with data_handler.catch_stop_iteration(): + for step in data_handler.steps(): + callbacks.on_test_batch_begin(step) + logs = test_function(iterator) + callbacks.on_test_batch_end(step, logs) + callbacks.on_test_end() + + if return_dict: + return {m.name: m.result().numpy() for m in self.metrics} + else: + results = [m.result().numpy() for m in self.metrics] + if len(results) == 1: + return results[0] + return results + + def _predict_step(self, data): + """The logic for one inference step. + + This method can be overridden to support custom inference logic. + This method is called by `Model._make_predict_function`. + + This method should contain the mathemetical logic for one step of inference. + This typically includes the forward pass. + + Configuration details for *how* this logic is run (e.g. `tf.function` and + `tf.distribute.Strategy` settings), should be left to + `Model._make_predict_function`, which can also be overridden. + + Arguments: + data: A nested structure of `Tensor`s. + + Returns: + The result of one inference step, typically the output of calling the + `Model` on data. + """ + data = data_adapter.expand_1d(data) + x, _, _ = data_adapter.unpack_x_y_sample_weight(data) + return self(x, training=False) + + def _make_predict_function(self): + """Creates a function that executes one step of inference. + + This method can be overridden to support custom inference logic. + This method is called by `Model.predict` and `Model.predict_on_batch`. + + Typically, this method directly controls `tf.function` and + `tf.distribute.Strategy` settings, and delegates the actual evaluation + logic to `Model._predict_step`. + + This function is cached the first time `Model.predict` or + `Model.predict_on_batch` is called. The cache is cleared whenever + `Model.compile` is called. + + Returns: + Function. The function created by this method should accept a + `tf.data.Iterator`, and return the outputs of the `Model`. + """ + if self.predict_function is not None: + return self.predict_function + + def predict_function(iterator): + data = next(iterator) + outputs = self.distribute_strategy.experimental_run_v2( + self._predict_step, args=(data,)) + outputs = reduce_per_replica( + outputs, self.distribute_strategy, reduction='concat') + return outputs + + if not self.run_eagerly: + predict_function = def_function.function( + predict_function, experimental_relax_shapes=True) + + self.predict_function = predict_function + return self.predict_function + + @disable_multi_worker def predict(self, x, batch_size=None, @@ -862,9 +1089,8 @@ class Model(network.Network, version_utils.ModelVersionSelector): Number of samples per batch. If unspecified, `batch_size` will default to 32. Do not specify the `batch_size` if your data is in the - form of symbolic tensors, dataset, - generators, or `keras.utils.Sequence` instances (since they generate - batches). + form of dataset, generators, or `keras.utils.Sequence` instances + (since they generate batches). verbose: Verbosity mode, 0 or 1. steps: Total number of steps (batches of samples) before declaring the prediction round finished. @@ -906,22 +1132,53 @@ class Model(network.Network, version_utils.ModelVersionSelector): version_utils.disallow_legacy_graph('Model', 'predict') self._check_call_args('predict') - func = self._select_training_loop(x) - return func.predict( - self, - x=x, - batch_size=batch_size, - verbose=verbose, - steps=steps, - callbacks=callbacks, - max_queue_size=max_queue_size, - workers=workers, - use_multiprocessing=use_multiprocessing) + outputs = None + with self.distribute_strategy.scope(): + # Creates a `tf.data.Dataset` and handles batch and epoch iteration. + data_handler = data_adapter.DataHandler( + x=x, + batch_size=batch_size, + steps_per_epoch=steps, + initial_epoch=0, + epochs=1, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing, + model=self) + + # Container that configures and calls `tf.keras.Callback`s. + callbacks = callbacks_module.CallbackList( + callbacks, + add_history=True, + add_progbar=True, + model=self, + verbose=verbose, + epochs=1, + steps=data_handler._steps) # pylint: disable=protected-access + + predict_function = self._make_predict_function() + callbacks.on_predict_begin() + for _, iterator in data_handler.enumerate_epochs(): # Single epoch. + with data_handler.catch_stop_iteration(): + for step in data_handler.steps(): + callbacks.on_predict_batch_begin(step) + batch_outputs = predict_function(iterator) + if outputs is None: + outputs = nest.map_structure(lambda batch_output: [batch_output], + batch_outputs) + else: + nest.map_structure_up_to( + batch_outputs, + lambda output, batch_output: output.append(batch_output), + outputs, batch_outputs) + callbacks.on_predict_batch_end(step, {'outputs': batch_outputs}) + callbacks.on_predict_end() + all_outputs = nest.map_structure_up_to(batch_outputs, concat, outputs) + return to_numpy(all_outputs) def reset_metrics(self): """Resets the state of metrics.""" - metrics = self._get_training_eval_metrics() - for m in metrics: + for m in self.metrics: m.reset_states() def train_on_batch(self, @@ -940,19 +1197,15 @@ class Model(network.Network, version_utils.ModelVersionSelector): (in case the model has multiple inputs). - A dict mapping input names to the corresponding array/tensors, if the model has named inputs. - - A `tf.data` dataset. y: Target data. Like the input data `x`, it could be either Numpy array(s) or TensorFlow tensor(s). It should be consistent with `x` - (you cannot have Numpy inputs and tensor targets, or inversely). If - `x` is a dataset, `y` should not be specified - (since targets will be obtained from the iterator). + (you cannot have Numpy inputs and tensor targets, or inversely). sample_weight: Optional array of the same length as x, containing weights to apply to the model's loss for each sample. In the case of temporal data, you can pass a 2D array with shape (samples, sequence_length), to apply a different weight to every timestep of every sample. In this case you should make sure to specify - sample_weight_mode="temporal" in compile(). This argument is not - supported when `x` is a dataset. + sample_weight_mode="temporal" in compile(). class_weight: Optional dictionary mapping class indices (integers) to a weight (float) to apply to the model's loss for the samples from this class during training. This can be useful to tell the model to "pay @@ -973,46 +1226,38 @@ class Model(network.Network, version_utils.ModelVersionSelector): """ self._assert_compile_was_called() self._check_call_args('train_on_batch') - outputs = training_v2_utils.train_on_batch( - self, - x, - y=y, - sample_weight=sample_weight, - class_weight=class_weight, - reset_metrics=reset_metrics, - standalone=True) - outputs = ( - outputs['total_loss'] + outputs['output_losses'] + outputs['metrics']) - outputs = [training_v2_utils._non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access - if len(outputs) == 1: - outputs = outputs[0] - return outputs + with self.distribute_strategy.scope(), \ + training_utils.RespectCompiledTrainableState(self): + iterator = data_adapter.single_batch_iterator(self.distribute_strategy, x, + y, sample_weight, + class_weight) + train_function = self._make_train_function() + train_function(iterator) + metrics = [m.result().numpy() for m in self.metrics] + if reset_metrics: + self.reset_metrics() + if len(metrics) == 1: + return metrics[0] + return metrics def test_on_batch(self, x, y=None, sample_weight=None, reset_metrics=True): """Test the model on a single batch of samples. Arguments: - x: Input data. It could be: - - A Numpy array (or array-like), or a list of arrays - (in case the model has multiple inputs). - - A TensorFlow tensor, or a list of tensors - (in case the model has multiple inputs). - - A dict mapping input names to the corresponding array/tensors, - if the model has named inputs. - - A `tf.data` dataset. - y: Target data. Like the input data `x`, - it could be either Numpy array(s) or TensorFlow tensor(s). - It should be consistent with `x` (you cannot have Numpy inputs and - tensor targets, or inversely). If `x` is a dataset `y` should - not be specified (since targets will be obtained from the iterator). + x: Input data. It could be: - A Numpy array (or array-like), or a list + of arrays (in case the model has multiple inputs). - A TensorFlow + tensor, or a list of tensors (in case the model has multiple inputs). + - A dict mapping input names to the corresponding array/tensors, if + the model has named inputs. + y: Target data. Like the input data `x`, it could be either Numpy + array(s) or TensorFlow tensor(s). It should be consistent with `x` + (you cannot have Numpy inputs and tensor targets, or inversely). sample_weight: Optional array of the same length as x, containing - weights to apply to the model's loss for each sample. - In the case of temporal data, you can pass a 2D array - with shape (samples, sequence_length), - to apply a different weight to every timestep of every sample. - In this case you should make sure to specify - sample_weight_mode="temporal" in compile(). This argument is not - supported when `x` is a dataset. + weights to apply to the model's loss for each sample. In the case of + temporal data, you can pass a 2D array with shape (samples, + sequence_length), to apply a different weight to every timestep of + every sample. In this case you should make sure to specify + sample_weight_mode="temporal" in compile(). reset_metrics: If `True`, the metrics returned will be only for this batch. If `False`, the metrics will be statefully accumulated across batches. @@ -1028,30 +1273,25 @@ class Model(network.Network, version_utils.ModelVersionSelector): """ self._assert_compile_was_called() self._check_call_args('test_on_batch') - outputs = training_v2_utils.test_on_batch( - self, - x, - y=y, - sample_weight=sample_weight, - reset_metrics=reset_metrics, - standalone=True) - outputs = ( - outputs['total_loss'] + outputs['output_losses'] + outputs['metrics']) - outputs = [training_v2_utils._non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access - if len(outputs) == 1: - outputs = outputs[0] - return outputs + with self.distribute_strategy.scope(): + iterator = data_adapter.single_batch_iterator(self.distribute_strategy, x, + y, sample_weight) + test_function = self._make_test_function() + test_function(iterator) + metrics = [m.result().numpy() for m in self.metrics] + if reset_metrics: + self.reset_metrics() + if len(metrics) == 1: + return metrics[0] + return metrics def predict_on_batch(self, x): """Returns predictions for a single batch of samples. Arguments: - x: Input data. It could be: - - A Numpy array (or array-like), or a list of arrays - (in case the model has multiple inputs). - - A TensorFlow tensor, or a list of tensors - (in case the model has multiple inputs). - - A `tf.data` dataset. + x: Input data. It could be: - A Numpy array (or array-like), or a list + of arrays (in case the model has multiple inputs). - A TensorFlow + tensor, or a list of tensors (in case the model has multiple inputs). Returns: Numpy array(s) of predictions. @@ -1061,7 +1301,11 @@ class Model(network.Network, version_utils.ModelVersionSelector): expectations of the model. """ self._check_call_args('predict_on_batch') - return training_v2_utils.predict_on_batch(self, x, standalone=True) + with self.distribute_strategy.scope(): + iterator = data_adapter.single_batch_iterator(self.distribute_strategy, x) + predict_function = self._make_predict_function() + outputs = predict_function(iterator) + return to_numpy(outputs) @deprecation.deprecated( None, 'Please use Model.fit, which supports generators.') @@ -1176,54 +1420,11 @@ class Model(network.Network, version_utils.ModelVersionSelector): 'and the first argument in `call` as positional arguments, ' 'found: ' + str(extra_args) + '.') - def _set_optimizer(self, optimizer): - """Sets self.optimizer. - - Sets self.optimizer to `optimizer`, potentially wrapping it with a - LossScaleOptimizer. - - Args: - optimizer: The optimizer(s) to assign to self.optimizer. - """ - if isinstance(optimizer, (list, tuple)): - self.optimizer = [optimizers.get(opt) for opt in optimizer] - else: - self.optimizer = optimizers.get(optimizer) - - if (self._dtype_policy.loss_scale is not None and - not isinstance(self.optimizer, - loss_scale_optimizer.LossScaleOptimizer)): - if isinstance(self.optimizer, list): - raise ValueError('When a dtype policy with a loss scale is used, you ' - 'can only pass a single optimizer. Using policy %s ' - 'and got optimizers: %s' % - self._dtype_policy, self.optimizer) - if not isinstance(self.optimizer, optimizer_v2.OptimizerV2): - raise ValueError('"optimizer" must be an instance of ' - 'tf.keras.optimizers.Optimizer when a dype policy ' - 'with a loss scale used, but got: %s. Using policy: ' - '%s' % - (self.optimizer, self._dtype_policy)) - self.optimizer = loss_scale_optimizer.LossScaleOptimizer( - self.optimizer, self._dtype_policy.loss_scale) - if (isinstance(self.optimizer, loss_scale_optimizer.LossScaleOptimizer) and - self._dtype_policy.loss_scale and - self.optimizer.loss_scale != self._dtype_policy.loss_scale): - logging.warning('LossScale of LossScaleOptimizer passed to compile (%s) ' - 'is not the same as the dtype policy\'s loss scale (%s). ' - 'Because the dtype policy has a loss scale, you should ' - 'pass an optimizer that is not wrapped with a ' - 'LossScaleOptimizer,' - % (self.optimizer.loss_scale, - self._dtype_policy.loss_scale)) - def _validate_compile(self, optimizer, **kwargs): """Performs validation checks for the default `compile`.""" - is_any_keras_optimizer_v1 = any( - (isinstance(opt, optimizers.Optimizer) and - not isinstance(opt, optimizers.TFOptimizer)) - for opt in nest.flatten(optimizer)) - if is_any_keras_optimizer_v1: + if any( + isinstance(opt, optimizers.Optimizer) + for opt in nest.flatten(optimizer)): raise ValueError( '`tf.compat.v1.keras` Optimizer (', optimizer, ') is ' 'not supported when eager execution is enabled. Use a ' @@ -1259,1331 +1460,7 @@ class Model(network.Network, version_utils.ModelVersionSelector): ' model=_create_model()\n' ' model.compile(...)' % (v, strategy)) - def _prepare_validation_data(self, validation_data, batch_size, - validation_steps): - """Unpack and check the validation data.""" - val_x, val_y, val_sample_weights = training_utils.unpack_validation_data( - validation_data) - return self._standardize_user_data( - val_x, - val_y, - sample_weight=val_sample_weights, - batch_size=batch_size, - steps=validation_steps, - steps_name='validation_steps') - - def _process_target_tensor_for_compile(self, target_tensors): - if self.run_eagerly: - # target tensor is not supported with run_eagerly. Create a list with None - # as placeholder for each output. - return [None for _ in self.output_names] - - if target_tensors is not None and not (isinstance(target_tensors, list) and - target_tensors == []): # pylint: disable=g-explicit-bool-comparison - if isinstance(target_tensors, list): - if len(target_tensors) != len(self.outputs): - raise ValueError( - 'When passing a list as `target_tensors`, ' - 'it should have one entry per model output. ' - 'The model has %s outputs, but you passed target_tensors=%s' % - (len(self.outputs), target_tensors)) - elif isinstance(target_tensors, dict): - unexpected_target_tensor_names = set(target_tensors.keys()).difference( - self.output_names) - if unexpected_target_tensor_names: - raise ValueError( - 'Unknown entry in `target_tensors` dictionary: "{name}". ' - 'Only expected the following keys: {keys}'.format( - name=unexpected_target_tensor_names, - keys=str(self.output_names))) - tmp_target_tensors = [] - for name in self.output_names: - tmp_target_tensors.append(target_tensors.get(name, None)) - target_tensors = tmp_target_tensors - elif tensor_util.is_tensor(target_tensors): - target_tensors = [target_tensors] - else: - raise TypeError('Expected `target_tensors` to be a list or tuple or ' - 'dict or a single tensor, but got:', target_tensors) - else: - # In case target tensor is empty or None, create a list with Nones - # that has same length as self.output_names. With that, the None check of - # target tensor can be skipped downstream. - target_tensors = [None for _ in self.output_names] - return target_tensors - - def _compile_eagerly(self, metrics, weighted_metrics, sample_weight_mode): - # Prepare sample weight modes. List with the same length as model outputs. - training_utils.prepare_sample_weight_modes( - self._training_endpoints, sample_weight_mode) - # Prepare sample weights. - self._prepare_sample_weights() - # Save all metric attributes per output of the model. - self._cache_output_metric_attributes(metrics, weighted_metrics) - self.total_loss = None - # Set metric attributes on model. - self._set_metric_attributes() - - self._collected_trainable_weights = self.trainable_weights - - def _update_sample_weight_modes(self, sample_weights=None): - """Updates sample weight modes based on training/eval inputs. - - Sample weight placeholders will be created for all or no outputs - based on whether sample_weight is provided for any output. - - If model contains `_sample_weight_modes` we check if the input - `sample_weights` corresponds to the sample weight modes. - 1. Set sample weight mode to be 'temporal' for output i, if `compile` - sample_weight_mode was set to `temporal` and sample weight inputs - are given for one or more outputs. - 2. Set sample weight mode to be 'samplewise' for output i, if `compile` - sample_weight_mode was not set and sample weight inputs are given for - one or more outputs. - 3. Reset sample weight mode to None for output i if sample weight mode - was set but there is no sample weight input. - - Args: - sample_weights: List of sample weights of the same length as model outputs - or None. - """ - if not self._is_compiled: - return - if sample_weights and any(s is not None for s in sample_weights): - for endpoint in self._training_endpoints: - endpoint.sample_weight_mode = ( - endpoint.sample_weight_mode or 'samplewise') - else: - for endpoint in self._training_endpoints: - endpoint.sample_weight_mode = None - - def _recompile_weights_loss_and_weighted_metrics(self): - if not self._is_compiled: - return False - recompile = any( - e.sample_weights_mismatch() for e in self._training_endpoints) - - if recompile: - self._compile_weights_loss_and_weighted_metrics() - return recompile - - @trackable.no_automatic_dependency_tracking - def _compile_weights_loss_and_weighted_metrics(self, sample_weights=None): - """Compiles the model loss and weighted metric sub-graphs. - - This may be used to set graph tensors as sample weights (instead of creating - placeholders). This functionality is necessary for - `tf.keras.estimator.model_to_estimator`, which calls Keras models in a v1 - graph, and creates iterator tensors for inputs, targets, and sample weights. - - Args: - sample_weights: List of tensors to use as the sample weights. Must be the - same length as the number of outputs. If left as `None`, placeholders - are used instead. - """ - with K.get_graph().as_default(): - if sample_weights is not None: - self._update_sample_weight_modes(sample_weights) - self._prepare_sample_weights(sample_weights) - - masks = self._prepare_output_masks() - - # Compute weighted metrics. - self._handle_metrics( - self.outputs, - targets=self._targets, - skip_target_masks=self._prepare_skip_target_masks(), - sample_weights=self.sample_weights, - masks=masks, - return_weighted_metrics=True) - - # Compute total loss. - # Used to keep track of the total loss value (stateless). - # eg., total_loss = loss_weight_1 * output_1_loss_fn(...) + - # loss_weight_2 * output_2_loss_fn(...) + - # layer losses. - self.total_loss = self._prepare_total_loss(masks) - - def _prepare_skip_target_masks(self): - """Boolean mask for whether the target in the output list should be skipped. - - If the loss function corresponding to a model output is None, then this - output will be skipped during total loss calculation and feed targets - preparation. - - Returns: - A boolean list for whether the corresponding target in the output list - should be skipped during loss calculation. - """ - return [l is None for l in self.loss_functions] - - def _prepare_output_masks(self): - """Returns masks corresponding to model outputs.""" - return [getattr(x, '_keras_mask', None) for x in self.outputs] - - def _prepare_total_loss(self, masks): - """Computes total loss from loss functions. - - Arguments: - masks: List of mask values corresponding to each model output. - - Returns: - A list of loss weights of python floats. - - Raises: - TypeError: If model run_eagerly is True. - """ - if self.run_eagerly: - raise TypeError('total loss can not be computed when compiled with ' - 'run_eagerly = True.') - total_loss = None - with K.name_scope('loss'): - for endpoint, mask in zip(self._training_endpoints, masks): - if endpoint.should_skip_target(): - continue - y_true = endpoint.training_target.target - y_pred = endpoint.output - loss_fn = endpoint.loss_fn - loss_weight = endpoint.loss_weight - loss_name = endpoint.loss_name() - sample_weight = endpoint.sample_weight - - with K.name_scope(loss_name): - if mask is not None: - mask = math_ops.cast(mask, y_pred.dtype) - # Update weights with mask. - if sample_weight is None: - sample_weight = mask - else: - # Update dimensions of weights to match with mask if possible. - mask, _, sample_weight = ( - tf_losses_utils.squeeze_or_expand_dimensions( - mask, sample_weight=sample_weight)) - sample_weight *= mask - - if hasattr(loss_fn, 'reduction'): - per_sample_losses = loss_fn.call(y_true, y_pred) - weighted_losses = losses_utils.compute_weighted_loss( - per_sample_losses, - sample_weight=sample_weight, - reduction=losses_utils.ReductionV2.NONE) - loss_reduction = loss_fn.reduction - - # `AUTO` loss reduction defaults to `SUM_OVER_BATCH_SIZE` for all - # compile use cases. - if loss_reduction == losses_utils.ReductionV2.AUTO: - loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE - - # Compute the stateless loss value. - output_loss = losses_utils.reduce_weighted_loss( - weighted_losses, reduction=loss_reduction) - else: - # Compute the stateless loss value for a custom loss class. - # Here we assume that the class takes care of loss reduction - # because if this class returns a vector value we cannot - # differentiate between use case where a custom optimizer - # expects a vector loss value vs unreduced per-sample loss value. - output_loss = loss_fn(y_true, y_pred, sample_weight=sample_weight) - loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE - - if len(self.outputs) > 1: - # Keep track of stateful result tensor for the loss. - endpoint.output_loss_metric(output_loss) - - # Scale output loss for distribution. For custom losses we assume - # reduction was mean. - if loss_reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE: - output_loss = losses_utils.scale_loss_for_distribution(output_loss) - - if total_loss is None: - total_loss = loss_weight * output_loss - else: - total_loss += loss_weight * output_loss - if total_loss is None: - if not self.losses: - raise ValueError('The model cannot be compiled ' - 'because it has no loss to optimize.') - else: - total_loss = 0. - - # Add regularization penalties and other layer-specific losses. - custom_losses = self.get_losses_for(None) + self.get_losses_for( - self.inputs) - if custom_losses: - total_loss += losses_utils.scale_loss_for_distribution( - math_ops.add_n(custom_losses)) - return total_loss - - def _get_callback_model(self): - """Returns the Callback Model for this Model.""" - - if hasattr(self, '_replicated_model') and self._replicated_model: - # When using training_distributed, we set the callback model - # to an instance of the `DistributedModel` that we create in - # the `compile` call. The `DistributedModel` is initialized - # with the first replicated model. We need to set the callback - # model to a DistributedModel to allow us to override saving - # and loading weights when we checkpoint the model during training. - return self._replicated_model - if hasattr(self, 'callback_model') and self.callback_model: - return self.callback_model - return self - - def _validate_or_infer_batch_size(self, batch_size, steps, x): - """Validates that the `batch_size` provided is consistent with InputLayer. - - It's possible that the user specified a static batch size in their - InputLayer. If so, this method checks the provided `batch_size` and `x` - arguments are consistent with this static batch size. Also, if - `batch_size` is `None`, this method will attempt to infer the batch size - from the static batch size of the InputLayer. Lastly, ValueError will be - raised if `x` is a tf.data.Dataset and `batch_size` is specified as we - expect users to provide batched datasets. - - Arguments: - batch_size: The batch_size provided as an argument to - fit/evaluate/predict. - steps: The steps provided as an argument to fit/evaluate/predict. - x: The data passed as `x` to fit/evaluate/predict. - - Returns: - The validated batch_size, auto-inferred from the first layer if not - provided. - """ - if (isinstance(x, (dataset_ops.DatasetV1, - dataset_ops.DatasetV2, - data_utils.Sequence)) or - tf_inspect.isgenerator(x)): - if batch_size is not None: - raise ValueError( - 'The `batch_size` argument must not be specified for the given ' - 'input type. Received input: {}, batch_size: {}'.format( - x, batch_size)) - return - - # Avoids the override in Sequential.layers which filters Input layers. - # (Which are often the very layers that we're after.) - layers = trackable_layer_utils.filter_empty_layer_containers(self._layers) - first_layer = next(layers, None) - if first_layer: - # The per-replica static batch size. - static_batch_size = training_utils.get_static_batch_size(first_layer) - if static_batch_size is not None: - - # Determine number of times the user-supplied batch size will be split. - if (self._distribution_strategy and - distributed_training_utils.global_batch_size_supported( - self._distribution_strategy)): - num_splits_for_ds = self._distribution_strategy.num_replicas_in_sync - else: - num_splits_for_ds = 1 - - # Check `batch_size` argument is consistent with InputLayer. - if batch_size is not None: - if batch_size % num_splits_for_ds != 0: - raise ValueError('The `batch_size` argument ({}) must be divisible ' - 'the by number of replicas ({})'.format( - batch_size, num_splits_for_ds)) - per_replica_batch_size = batch_size // num_splits_for_ds - - if per_replica_batch_size != static_batch_size: - raise ValueError('The `batch_size` argument value {} is ' - 'incompatible with the specified batch size of ' - 'your Input Layer: {}'.format( - per_replica_batch_size, static_batch_size)) - - # Check Dataset/Iterator batch size is consistent with InputLayer. - if isinstance(x, (dataset_ops.DatasetV2, iterator_ops.Iterator, - iterator_ops.OwnedIterator)): - ds_batch_size = tensor_shape.as_dimension( - nest.flatten(dataset_ops.get_legacy_output_shapes(x))[0][0]).value - if ds_batch_size is not None: - if ds_batch_size % num_splits_for_ds != 0: - raise ValueError( - 'The batch output shape of your `Dataset` {} ' - 'cannot be divisible by number of replicas {}'.format( - ds_batch_size, num_splits_for_ds)) - - ds_per_replica_batch_size = ds_batch_size // num_splits_for_ds - if ds_per_replica_batch_size != static_batch_size: - raise ValueError('The batch output shape of your `Dataset` is ' - '{}, which is incompatible with the specified ' - 'batch size of your Input Layer: {}'.format( - ds_per_replica_batch_size, - static_batch_size)) - - # Set inferred batch size from the InputLayer. - if steps is None: - batch_size = static_batch_size * num_splits_for_ds - - if batch_size is None and steps is None: - # Backwards compatibility - batch_size = 32 - return batch_size - - def _prepare_sample_weights(self, sample_weights=None): - """Sets sample weight attribute on the model.""" - # List with the same length as model outputs. - if sample_weights is not None: - if len(sample_weights) != len(self._training_endpoints): - raise ValueError('Provided sample weights must have same length as the ' - 'number of outputs. Expected: {}, got: {}.'.format( - len(self._training_endpoints), - len(sample_weights))) - else: - sample_weights = [None] * len(self._training_endpoints) - for endpoint, weight in zip(self._training_endpoints, sample_weights): - endpoint.populate_sample_weight(weight, endpoint.sample_weight_mode) - - def _cache_output_metric_attributes(self, metrics, weighted_metrics): - """Caches metric name and function attributes for every model output.""" - output_shapes = [] - for output in self.outputs: - if output is None or output.shape.rank is None: - output_shapes.append(None) - else: - output_shapes.append(output.shape.as_list()) - self._per_output_metrics = training_utils.collect_per_output_metric_info( - metrics, self.output_names, output_shapes, self.loss_functions) - self._per_output_weighted_metrics = ( - training_utils.collect_per_output_metric_info( - weighted_metrics, - self.output_names, - output_shapes, - self.loss_functions, - is_weighted=True)) - - def _add_unique_metric_name(self, metric_name, output_index): - """Makes the metric name unique and adds it to the model's metric name list. - - If there are multiple outputs for which the metrics are calculated, the - metric names have to be made unique by appending an integer. - - Arguments: - metric_name: Metric name that corresponds to the metric specified by the - user. For example: 'acc'. - output_index: The index of the model output for which the metric name is - being added. - - Returns: - string, name of the model's unique metric name - """ - if len(self.output_names) > 1: - metric_name = '%s_%s' % (self.output_names[output_index], metric_name) - j = 1 - base_metric_name = metric_name - while metric_name in self.metrics_names: - metric_name = '%s_%d' % (base_metric_name, j) - j += 1 - - return metric_name - - def _init_metric_attributes(self): - """Initialized model metric attributes.""" - # List of stateful metric functions. Used for resetting metric state during - # training/eval. - self._compile_metric_functions = [] - - def _set_per_output_metric_attributes(self, metrics_dict, output_index): - """Sets the metric attributes on the model for the given output. - - Arguments: - metrics_dict: A dict with metric names as keys and metric fns as values. - output_index: The index of the model output for which the metric - attributes are added. - - Returns: - Metrics dict updated with unique metric names as keys. - """ - updated_metrics_dict = collections.OrderedDict() - for metric_name, metric_fn in metrics_dict.items(): - metric_name = self._add_unique_metric_name(metric_name, output_index) - - # Update the name on the metric class to be the unique generated name. - metric_fn._name = metric_name # pylint: disable=protected-access - updated_metrics_dict[metric_name] = metric_fn - # Keep track of metric name and function. - self._compile_metric_functions.append(metric_fn) - return updated_metrics_dict - - def _set_metric_attributes(self): - """Sets the metric attributes on the model for all the model outputs.""" - updated_per_output_metrics = [] - updated_per_output_weighted_metrics = [] - for i, endpoint in enumerate(self._training_endpoints): - if endpoint.should_skip_target(): - updated_per_output_metrics.append(self._per_output_metrics[i]) - updated_per_output_weighted_metrics.append( - self._per_output_weighted_metrics[i]) - continue - updated_per_output_metrics.append( - self._set_per_output_metric_attributes(self._per_output_metrics[i], - i)) - updated_per_output_weighted_metrics.append( - self._set_per_output_metric_attributes( - self._per_output_weighted_metrics[i], i)) - - # Create a metric wrapper for each output loss. This computes mean of an - # output loss across mini-batches (irrespective of how we reduce within a - # batch). - if len(self._training_endpoints) > 1: - for endpoint in self._training_endpoints: - if not endpoint.should_skip_target(): - endpoint.output_loss_metric = metrics_module.Mean( - name=endpoint.loss_name()) - - self._per_output_metrics = updated_per_output_metrics - self._per_output_weighted_metrics = updated_per_output_weighted_metrics - - def _handle_per_output_metrics(self, - metrics_dict, - y_true, - y_pred, - mask, - weights=None): - """Calls metric functions for a single output. - - Arguments: - metrics_dict: A dict with metric names as keys and metric fns as values. - y_true: Target output. - y_pred: Predicted output. - mask: Computed mask value for the current output. - weights: Weights to be applied on the current output. - - Returns: - A list of metric result tensors. - """ - metric_results = [] - for metric_name, metric_fn in metrics_dict.items(): - with K.name_scope(metric_name): - metric_result = training_utils.call_metric_function( - metric_fn, y_true, y_pred, weights=weights, mask=mask) - metric_results.append(metric_result) - return metric_results - - def _handle_metrics(self, - outputs, - targets=None, - skip_target_masks=None, - sample_weights=None, - masks=None, - return_weighted_metrics=False, - return_weighted_and_unweighted_metrics=False): - """Handles calling metric functions. - - Arguments: - outputs: List of outputs (predictions). - targets: List of targets. - skip_target_masks: Optional. List of boolean for whether the corresponding - target should be ignored or not. - sample_weights: Optional list of sample weight arrays. - masks: List of computed output mask values. - return_weighted_metrics: Flag that indicates whether weighted metrics - should be computed instead of unweighted metrics. This flag is ignored - when `return_weighted_and_unweighted_metrics` is enabled. - return_weighted_and_unweighted_metrics: Flag that is used to indicate - whether both weighted and unweighted metrics should be computed. When - this is not enabled, we use `return_weighted_metrics` param to indicate - whether weighted or unweighted metrics should be returned. - - Returns: - A list of metric result tensors. - """ - # TODO(scottzhu): Update this to use the new training_endpoints. Currently - # the eager and graph logic is bit different. - skip_target_masks = skip_target_masks or [False] * len(outputs) - metric_results = [] - with K.name_scope('metrics'): - # Invoke all metrics added using `compile`. - for i in range(len(outputs)): - if skip_target_masks[i]: - continue - output = outputs[i] if outputs else None - target = targets[i] if targets else None - output_mask = masks[i] if masks else None - - if (return_weighted_and_unweighted_metrics or - not return_weighted_metrics): - metric_results.extend( - self._handle_per_output_metrics(self._per_output_metrics[i], - target, output, output_mask)) - if return_weighted_and_unweighted_metrics or return_weighted_metrics: - metric_results.extend( - self._handle_per_output_metrics( - self._per_output_weighted_metrics[i], - target, - output, - output_mask, - weights=sample_weights[i] if sample_weights else None)) - return metric_results - - def _check_trainable_weights_consistency(self): - """Check trainable weights count consistency. - - This will raise a warning if `trainable_weights` and - `_collected_trainable_weights` are inconsistent (i.e. have different - number of parameters). - Inconsistency will typically arise when one modifies `model.trainable` - without calling `model.compile` again. - """ - if not hasattr(self, '_collected_trainable_weights'): - return - - if len(self.trainable_weights) != len(self._collected_trainable_weights): - logging.log_first_n( - logging.WARN, 'Discrepancy between trainable weights and collected' - ' trainable weights, did you set `model.trainable`' - ' without calling `model.compile` after ?', 1) - - def _make_train_function(self): - has_recompiled = self._recompile_weights_loss_and_weighted_metrics() - self._check_trainable_weights_consistency() - if isinstance(self.optimizer, list): - raise ValueError('The `optimizer` in `compile` should be a single ' - 'optimizer.') - # If we have re-compiled the loss/weighted metric sub-graphs then create - # train function even if one exists already. This is because - # `_feed_sample_weights` list has been updated on re-compile. - if getattr(self, 'train_function', None) is None or has_recompiled: - # Restore the compiled trainable state. - current_trainable_state = self._get_trainable_state() - self._set_trainable_state(self._compiled_trainable_state) - - inputs = (self._feed_inputs + - self._feed_targets + - self._feed_sample_weights) - if not isinstance(K.symbolic_learning_phase(), int): - inputs += [K.symbolic_learning_phase()] - - with K.get_graph().as_default(): - with K.name_scope('training'): - # Training updates - updates = self.optimizer.get_updates( - params=self._collected_trainable_weights, loss=self.total_loss) - # Unconditional updates - updates += self.get_updates_for(None) - # Conditional updates relevant to this model - updates += self.get_updates_for(self.inputs) - - metrics = self._get_training_eval_metrics() - metrics_tensors = [ - m._call_result for m in metrics if hasattr(m, '_call_result') # pylint: disable=protected-access - ] - - with K.name_scope('training'): - # Gets loss and metrics. Updates weights at each call. - fn = K.function( - inputs, [self.total_loss] + metrics_tensors, - updates=updates, - name='train_function') - setattr(self, 'train_function', fn) - - # Restore the current trainable state - self._set_trainable_state(current_trainable_state) - - def _make_test_function(self): - has_recompiled = self._recompile_weights_loss_and_weighted_metrics() - # If we have re-compiled the loss/weighted metric sub-graphs then create - # test function even if one exists already. This is because - # `_feed_sample_weights` list has been updated on re-compile. - if getattr(self, 'test_function', None) is None or has_recompiled: - inputs = (self._feed_inputs + - self._feed_targets + - self._feed_sample_weights) - - with K.get_graph().as_default(): - metrics = self._get_training_eval_metrics() - metrics_tensors = [ - m._call_result for m in metrics if hasattr(m, '_call_result') # pylint: disable=protected-access - ] - - with K.name_scope('evaluation'): - updates = self.state_updates - # Return loss and metrics, no gradient updates. - # Does update the network states. - fn = K.function( - inputs, [self.total_loss] + metrics_tensors, - updates=updates, - name='test_function') - setattr(self, 'test_function', fn) - - def _make_predict_function(self): - if not hasattr(self, 'predict_function'): - self.predict_function = None - if self.predict_function is None: - inputs = self._feed_inputs - # Gets network outputs. Does not update weights. - # Does update the network states. - kwargs = getattr(self, '_function_kwargs', {}) - with K.name_scope(ModeKeys.PREDICT): - self.predict_function = K.function( - inputs, - self.outputs, - updates=self.state_updates, - name='predict_function', - **kwargs) - - def _make_execution_function(self, mode): - if mode == ModeKeys.TRAIN: - self._make_train_function() - return self.train_function - if mode == ModeKeys.TEST: - self._make_test_function() - return self.test_function - if mode == ModeKeys.PREDICT: - self._make_predict_function() - return self.predict_function - - def _distribution_standardize_user_data(self, - x, - y=None, - sample_weight=None, - class_weight=None, - batch_size=None, - validation_split=0, - shuffle=False, - epochs=1, - allow_partial_batch=False): - """Runs validation checks on input and target data passed by the user. - - This is called when using tf.distribute.Strategy to train, evaluate or serve - the model. - - Args: - x: Input data. A numpy array or `tf.data` dataset. - y: Target data. A numpy array or None if x is a `tf.data` dataset. - sample_weight: An optional sample-weight array passed by the user to - weight the importance of each sample in `x`. - class_weight: An optional class-weight array by the user to - weight the importance of samples in `x` based on the class they belong - to, as conveyed by `y`. - batch_size: Integer batch size. If provided, it is used to run additional - validation checks on stateful models. - validation_split: Float between 0 and 1. - Fraction of the training data to be used as validation data. - shuffle: Boolean whether to shuffle the training data before each epoch. - epochs: Integer epochs. If > 1, repeat the numpy training data epochs - times when converting to training dataset. - allow_partial_batch: Boolean whether to enforce that all batches have the - same size. - - Returns: - Dataset instance. - - Raises: - ValueError: In case of invalid user-provided data. - RuntimeError: If the model was never compiled. - """ - if class_weight: - raise NotImplementedError('`class_weight` is currently not supported ' - 'when using tf.distribute.Strategy.') - - if (sample_weight is not None and sample_weight.all() and - distributed_training_utils.is_tpu_strategy( - self._distribution_strategy)): - raise NotImplementedError('`sample_weight` is currently not supported ' - 'when using TPUStrategy.') - - # Validates `steps` and `shuffle` arguments right at the beginning - # since we use it to construct the dataset object. - # TODO(anjalisridhar): Remove this check once we refactor the - # _standardize_user_data code path. This check is already present elsewhere - # in the codebase. - if isinstance(x, dataset_ops.DatasetV2): - if shuffle: - training_utils.verify_dataset_shuffled(x) - - strategy = self._distribution_strategy - with strategy.scope(): - # We should be sure to call get_session() inside the strategy.scope() - # so the strategy can affect the session options. - if ops.executing_eagerly_outside_functions(): - session = None - else: - session = K.get_session() - - first_x_value = nest.flatten(x)[0] - if isinstance(first_x_value, np.ndarray): - x = training_utils.list_to_tuple(x) - if y is not None: - y = training_utils.list_to_tuple(y) - if sample_weight is not None: - sample_weight = training_utils.list_to_tuple(sample_weight) - in_tuple = (x, y, sample_weight) - else: - in_tuple = (x, y) - else: - in_tuple = x - - ds = strategy.extended.experimental_make_numpy_dataset(in_tuple, - session=session) - if shuffle: - # We want a buffer size that is larger than the batch size provided by - # the user and provides sufficient randomness. Note that larger - # numbers introduce more memory usage based on the size of each - # sample. - ds = ds.shuffle(max(1024, batch_size * 8)) - if epochs > 1: - ds = ds.repeat(epochs) - - # We need to use the drop_remainder argument to get a known static - # input shape which is required for TPUs. - drop_remainder = (not allow_partial_batch and - strategy.extended.experimental_require_static_shapes) - - # TODO(b/131720208): We still drop remainder here if number of examples - # is divisible by batch size, as sometimes dynamic padder will time out - # with keras.metrics.CategoricalAccuracy() metric. - if distributed_training_utils.is_tpu_strategy( - strategy) and not drop_remainder: - dataset_size = first_x_value.shape[0] - if dataset_size % batch_size == 0: - drop_remainder = True - - x = ds.batch(batch_size, drop_remainder=drop_remainder) - else: - assert isinstance(x, dataset_ops.DatasetV2) - training_utils.validate_dataset_input(x, y, sample_weight, - validation_split) - return x - - def _standardize_user_data(self, - x, - y=None, - sample_weight=None, - class_weight=None, - batch_size=None, - check_steps=False, - steps_name='steps', - steps=None, - validation_split=0, - shuffle=False, - extract_tensors_from_dataset=False): - """Runs validation checks on input and target data passed by the user. - - Also standardizes the data to lists of arrays, in order. - - Also builds and compiles the model on the fly if it is a subclassed model - that has never been called before (and thus has no inputs/outputs). - - This is a purely internal method, subject to refactoring at any time. - - Args: - x: Input data. It could be: - - A Numpy array (or array-like), or a list of arrays - (in case the model has multiple inputs). - - A TensorFlow tensor, or a list of tensors - (in case the model has multiple inputs). - - A dict mapping input names to the corresponding array/tensors, - if the model has named inputs. - - A `tf.data` dataset. - y: Target data. Like the input data `x`, - it could be either Numpy array(s) or TensorFlow tensor(s). - It should be consistent with `x` (you cannot have Numpy inputs and - tensor targets, or inversely). If `x` is a dataset, `y` should not be - specified (since targets will be obtained from the iterator). - sample_weight: An optional sample-weight array passed by the user to - weight the importance of each sample in `x`. - class_weight: An optional class-weight array by the user to - weight the importance of samples in `x` based on the class they belong - to, as conveyed by `y`. If both `sample_weight` and `class_weight` are - provided, the weights are multiplied. - batch_size: Integer batch size. If provided, it is used to run additional - validation checks on stateful models. - check_steps: boolean, True if we want to check for validity of `steps` and - False, otherwise. For example, when we are standardizing one batch of - data for train_on_batch/predict_on_batch/test_on_batch APIs, `steps` - value is not required and we should not check for its validity in these - cases. - steps_name: The public API's parameter name for `steps`. - steps: Integer or `None`. Total number of steps (batches of samples) to - execute. - validation_split: Float between 0 and 1. - Fraction of the training data to be used as validation data. - shuffle: Boolean whether to shuffle the training data before each epoch. - extract_tensors_from_dataset: Boolean. When `x` is a dataset instance, - this indicates whether to extract actual tensors from the dataset or - instead output the dataset instance itself. - Set to True when calling from `train_on_batch`/etc. - - Returns: - A tuple of 3: inputs (arrays or dicts, depending on whether `x` was a dict - or not), target arrays, sample-weight arrays. - If the model's input and targets are symbolic, these lists are empty - (since the model takes no user-provided data, instead the data comes - from the symbolic inputs/targets). - - Raises: - ValueError: In case of invalid user-provided data. - RuntimeError: If the model was never compiled. - """ - if isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)): - # Graph mode dataset. We'll pass the dataset as-is (unless - # `extract_tensors_from_dataset` is True, in which case we extract - # the tensors from the dataset and we output them. - training_utils.validate_dataset_input(x, y, sample_weight, - validation_split) - if shuffle: - training_utils.verify_dataset_shuffled(x) - - is_dataset = True - if extract_tensors_from_dataset: - # We do this for `train_on_batch`/etc. - x, y, sample_weight = training_utils.extract_tensors_from_dataset(x) - elif isinstance(x, iterator_ops.Iterator): - # Graph mode iterator. We extract the symbolic tensors. - training_utils.validate_dataset_input(x, y, sample_weight, - validation_split) - iterator = x - x, y, sample_weight = training_utils.unpack_iterator_input(iterator) - is_dataset = True - else: - is_dataset = False - - # Validates `steps` argument based on x's type. - if check_steps: - training_utils.check_steps_argument(x, steps, steps_name) - - # First, we build the model on the fly if necessary. - if not self.inputs: - all_inputs, y_input, dict_inputs = self._build_model_with_inputs(x, y) - is_build_called = True - else: - all_inputs = [] - # Whether this is a subclassed model that expects dictionary inputs - # rather than list inputs (e.g. FeatureColumn-based models). - dict_inputs = isinstance(self.inputs, dict) - is_build_called = False - y_input = y - - # Second, we compile the model on the fly if necessary, mostly for subclass - # models. - is_compile_called = False - if not self._is_compiled and self.optimizer: - self._compile_from_inputs(all_inputs, y_input, x, y) - is_compile_called = True - - # In graph mode, if we had just set inputs and targets as symbolic tensors - # by invoking build and compile on the model respectively, we do not have to - # feed anything to the model. Model already has input and target data as - # part of the graph. - # Note: in this case, `any` and `all` are equivalent since we disallow - # mixed symbolic/value inputs. - - # self.run_eagerly is not free to compute, so we want to reuse the value. - run_eagerly = self.run_eagerly - - if (not run_eagerly and is_build_called and is_compile_called and - not is_dataset and any(_is_symbolic_tensor(v) for v in all_inputs)): - return [], [], None - - return self._standardize_tensors( - x, y, sample_weight, - run_eagerly=run_eagerly, - dict_inputs=dict_inputs, - is_dataset=is_dataset, - class_weight=class_weight, - batch_size=batch_size) - - def _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs, - is_dataset, class_weight=None, batch_size=None): - if run_eagerly: - # In eager mode, do not do shape validation - # since the network has no input nodes (placeholders) to be fed. - feed_input_names = self.input_names - feed_input_shapes = None - elif not self._is_graph_network: - # Case: symbolic-mode subclassed network. Do not do shape validation. - feed_input_names = self._feed_input_names - feed_input_shapes = None - else: - # Case: symbolic-mode graph network. - # In this case, we run extensive shape validation checks. - feed_input_names = self._feed_input_names - feed_input_shapes = self._feed_input_shapes - - # Standardize the inputs. - if not isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)): - # TODO(fchollet): run static checks with dataset output shape(s). - x = training_utils.standardize_input_data( - x, - feed_input_names, - feed_input_shapes, - check_batch_axis=False, # Don't enforce the batch size. - exception_prefix='input') - - # Get typespecs for the input data and sanitize it if necessary. - # TODO(momernick): This should be capable of doing full input validation - # at all times - validate that this is so and refactor the standardization - # code. - if isinstance(x, dataset_ops.DatasetV2): - x_shapes = dataset_ops.get_structure(x) - if isinstance(x_shapes, tuple): - # If the output of a Dataset is a tuple, we assume it's either of the - # form (x_data, y_data) or (x_data, y_data, sample_weights). In either - # case, we only care about x_data here. - x_shapes = x_shapes[0] - else: - flat_inputs = nest.flatten(x, expand_composites=False) - flat_expected_inputs = nest.flatten(self.inputs, expand_composites=False) - converted_x = [] - for (a, b) in zip(flat_inputs, flat_expected_inputs): - converted_x.append(_convert_scipy_sparse_tensor(a, b)) - x = nest.pack_sequence_as(x, converted_x, expand_composites=False) - - x_shapes = nest.map_structure(tf_utils.type_spec_from_value, x) - - flat_inputs = nest.flatten(x_shapes, expand_composites=False) - - x_expected_shapes = nest.map_structure(tf_utils.type_spec_from_value, - self.inputs) - flat_expected_inputs = nest.flatten( - x_expected_shapes, expand_composites=False) - for (a, b) in zip(flat_inputs, flat_expected_inputs): - nest.assert_same_structure(a, b, expand_composites=True) - - if y is not None: - # Prepare self._sample_weight_modes. List with the same length as - # model outputs. - training_utils.prepare_sample_weight_modes(self._training_endpoints, - self.sample_weight_mode) - feed_output_names = self._feed_output_names - feed_sample_weight_modes = self._sample_weight_modes - if not self._is_graph_network: - feed_output_shapes = None - else: - feed_output_shapes = self._feed_output_shapes - - # Standardize the outputs. - y = training_utils.standardize_input_data( - y, - feed_output_names, - # Don't enforce target shapes to match output shapes. - # Precise checks will be run in `check_loss_and_target_compatibility`. - shapes=None, - check_batch_axis=False, # Don't enforce the batch size. - exception_prefix='target') - - # Generate sample-wise weight values given the `sample_weight` and - # `class_weight` arguments. - sample_weights = training_utils.standardize_sample_weights( - sample_weight, feed_output_names) - class_weights = training_utils.standardize_class_weights( - class_weight, feed_output_names) - - sample_weights = [ - training_utils.standardize_weights(ref, sw, cw, mode) - for (ref, sw, cw, mode) in zip(y, sample_weights, class_weights, - feed_sample_weight_modes) - ] - # Check that all arrays have the same length. - if not self._distribution_strategy: - training_utils.check_array_lengths(x, y, sample_weights) - if self._is_graph_network and not run_eagerly: - # Additional checks to avoid users mistakenly using improper loss fns. - training_utils.check_loss_and_target_compatibility( - y, self._feed_loss_fns, feed_output_shapes) - - sample_weights, _, _ = training_utils.handle_partial_sample_weights( - y, sample_weights, feed_sample_weight_modes, check_all_flat=True) - else: - y = [] - sample_weights = None - - if self.stateful and batch_size and not is_dataset: - # Check that for stateful networks, number of samples is a multiple - # of the static batch size. - if x[0].shape[0] % batch_size != 0: - raise ValueError('In a stateful network, ' - 'you should only pass inputs with ' - 'a number of samples that can be ' - 'divided by the batch size. Found: ' + - str(x[0].shape[0]) + ' samples') - - # If dictionary inputs were provided, we return a dictionary as well. - if dict_inputs and not isinstance(x, (dataset_ops.DatasetV1, - dataset_ops.DatasetV2)): - x = dict(zip(feed_input_names, x)) - return x, y, sample_weights - - def _build_model_with_inputs(self, inputs, targets): - """Build the model (set model inputs/outputs), mainly for subclass model.""" - processed_inputs = [] - is_dict_inputs = False - orig_inputs = inputs - # We need to use `inputs` to set the model inputs. - # If input data is a dataset iterator in graph mode or if it is an eager - # iterator and only one batch of samples is required, we fetch the data - # tensors from the iterator and then standardize them. - if isinstance(inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)): - inputs, targets, _ = training_utils.extract_tensors_from_dataset(inputs) - # We type-check that `inputs` and `targets` are either single arrays - # or lists of arrays, and extract a flat list of inputs from the passed - # structure. - training_utils.validate_input_types(inputs, orig_inputs) - - if isinstance(inputs, (list, tuple)): - processed_inputs += list(inputs) - elif isinstance(inputs, dict): - is_dict_inputs = True - keys = sorted(inputs.keys()) - processed_inputs = [inputs[k] for k in keys] - else: - processed_inputs.append(inputs) - # Now that we have a flat set of inputs, we make sure that none of them - # are CompositeTensors or CompositeTensorValues of any type (or scipy - # sparse arrays, which we treat as SparseTensor values). We cannot safely - # infer input data from an arbitrary composite tensor, so we don't try - - # users should explicitly add composite tensor inputs to their subclassed - # models. - for input_tensor in processed_inputs: - if composite_tensor_utils.is_composite_or_composite_value(input_tensor): - # TODO(b/132691975): Document subclass-model CT input handling. - raise ValueError( - 'All SparseTensor and RaggedTensor inputs must be explicitly ' - 'declared using a keras.Input() with sparse=True or ragged=True. ' - 'We found an undeclared input %s. For Sequential models, please ' - 'add a keras.Input() as your first Layer. For subclassed models, ' - 'please call self._set_inputs() on your input set, which you can ' - 'create using keras.Input() for each input to your model.' % - (input_tensor,)) - # Build the model using the retrieved inputs (value or symbolic). - # If values are generated from a dataset, then in symbolic-mode - # placeholders will be created to match the value shapes. - if isinstance(orig_inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2, - iterator_ops.Iterator)): - if not self.inputs: - # For subclassed models, a robust input spec is not available so we - # must cast to the model dtype. - inputs = training_utils.cast_if_floating_dtype(inputs, self.dtype) - - def create_tensor_spec(t): - return tensor_spec.TensorSpec(t.shape, t.dtype) - - cast_inputs = nest.map_structure(create_tensor_spec, inputs) - elif training_utils.has_tensors(inputs): - cast_inputs = training_utils.cast_if_floating_dtype(inputs) - else: - cast_inputs = inputs - self._set_inputs(cast_inputs) - return processed_inputs, targets, is_dict_inputs - - def _compile_from_inputs(self, all_inputs, target, orig_inputs, orig_target): - if target is not None: - # We need to use `y` to set the model targets. - if training_utils.has_tensors(target): - target = training_utils.cast_if_floating_dtype_and_mismatch( - target, self.outputs) - training_utils.validate_input_types(target, orig_target, - allow_dict=False, field_name='target') - if isinstance(target, (list, tuple)): - all_inputs += list(target) - else: - all_inputs.append(target) - # Type check that all inputs are *either* value *or* symbolic. - # TODO(fchollet): this check could be removed in Eager mode? - if any(tensor_util.is_tensor(v) for v in all_inputs): - if not all(tensor_util.is_tensor(v) for v in all_inputs): - raise ValueError('Do not pass inputs that mix Numpy arrays and ' - 'TensorFlow tensors. ' - 'You passed: x=' + str(orig_inputs) + - '; y=' + str(orig_target)) - is_dataset = isinstance(orig_inputs, (dataset_ops.DatasetV1, - dataset_ops.DatasetV2, - iterator_ops.Iterator)) - if is_dataset or context.executing_eagerly(): - target_tensors = None - else: - # Handle target tensors if any passed. - if target is not None: - if not isinstance(target, (list, tuple)): - target = [target] - target_tensors = [v for v in target if _is_symbolic_tensor(v)] - else: - target_tensors = None - - self.compile( - optimizer=self.optimizer, - loss=self.loss, - metrics=self._compile_metrics, - weighted_metrics=self._compile_weighted_metrics, - loss_weights=self.loss_weights, - target_tensors=target_tensors, - sample_weight_mode=self.sample_weight_mode, - run_eagerly=self.run_eagerly) - - # TODO(omalleyt): Consider changing to a more descriptive function name. - def _set_inputs(self, inputs, outputs=None, training=None): - """Set model's input and output specs based on the input data received. - - This is to be used for Model subclasses, which do not know at instantiation - time what their inputs look like. - - Args: - inputs: Single array, or list of arrays. The arrays could be placeholders, - Numpy arrays, data tensors, or TensorSpecs. - - if placeholders: the model is built on top of these placeholders, - and we expect Numpy data to be fed for them when calling `fit`/etc. - - if Numpy data or TensorShapes: we create placeholders matching the - TensorShapes or shapes of the Numpy arrays. We expect Numpy data to be - fed for these placeholders when calling `fit`/etc. - - if data tensors: the model is built on top of these tensors. - We do not expect any Numpy data to be provided when calling `fit`/etc. - outputs: None, a data tensor, or a list of tensors. If None, the - outputs will be determined by invoking `self.call()`, otherwise the - provided value will be used. - training: Boolean or None. Only relevant in symbolic mode. Specifies - whether to build the model's graph in inference mode (False), training - mode (True), or using the Keras learning phase (None). - Raises: - ValueError: If dict inputs are passed to a Sequential Model where the - first layer isn't FeatureLayer. - """ - inputs = self._set_input_attrs(inputs) - - if outputs is None: - kwargs = {} - if self._expects_training_arg: - # In V2 mode, feeding `training=None` is not allowed because any value - # explicitly passed by the user is respected, even `None`.` - if training is None and not ops.executing_eagerly_outside_functions(): - training = K.learning_phase() - if training is not None: - kwargs['training'] = training - try: - outputs = self(inputs, **kwargs) - except NotImplementedError: - # This Model or a submodel is dynamic and hasn't overridden - # `compute_output_shape`. - outputs = None - - self._set_output_attrs(outputs) - - @trackable.no_automatic_dependency_tracking - def _set_input_attrs(self, inputs): - """Sets attributes related to the inputs of the Model.""" - if self.inputs: - raise ValueError('Model inputs are already set.') - - if self.__class__.__name__ == 'Sequential' and not self.built: - if tensor_util.is_tensor(inputs): - input_shape = (None,) + tuple(inputs.shape.as_list()[1:]) - elif isinstance(inputs, tensor_shape.TensorShape): - input_shape = (None,) + tuple(inputs.as_list()[1:]) - elif isinstance(inputs, dict): - # We assert that the first layer is a FeatureLayer. - if not training_utils.is_feature_layer(self.layers[0]): - raise ValueError('Passing a dictionary input to a Sequential Model ' - 'which doesn\'t have FeatureLayer as the first layer' - ' is an error.') - input_shape = (None,) - else: - input_shape = (None,) + tuple(inputs.shape[1:]) - self._build_input_shape = input_shape - - # Cast inputs to the compute dtype. This is primarily used - # when saving to determine the correct dtype in the input signature. - inputs = self._maybe_cast_inputs(inputs) - - # On-the-fly setting of symbolic model inputs (either by using the tensor - # provided, or by creating a placeholder if Numpy data was provided). - model_inputs = training_utils.ModelInputs(inputs) - inputs = model_inputs.get_symbolic_inputs() - self.inputs = model_inputs.get_symbolic_inputs(return_single_as_list=True) - self.input_names = model_inputs.get_input_names() - - self._feed_inputs = [] - self._feed_input_names = [] - self._feed_input_shapes = [] - - for k, v in model_inputs.as_dict(): - if K.is_placeholder(v): - self._feed_input_names.append(k) - self._feed_inputs.append(v) - self._feed_input_shapes.append(K.int_shape(v)) - - return inputs - - @trackable.no_automatic_dependency_tracking - def _set_output_attrs(self, outputs): - """Sets attributes related to the outputs of the Model.""" - # NOTE(taylorrobie): This convention cannot be changed without updating the - # data adapter since it assumes nest.flatten ordering. - outputs = nest.flatten(outputs) - self.outputs = outputs - self.output_names = training_utils.generic_output_names(outputs) - # TODO(scottzhu): Should we cleanup the self._training_endpoints here? - self.built = True - - @property - def _targets(self): - """The output target tensors for the model.""" - return [ - e.training_target.target - for e in self._training_endpoints - if e.has_training_target() - ] - - @property - def _feed_targets(self): - return [ - e.training_target.target - for e in self._training_endpoints - if e.has_feedable_training_target() - ] - - @property - def _feed_output_names(self): - return [ - e.output_name - for e in self._training_endpoints - if e.has_feedable_training_target() - ] - - @property - def _feed_output_shapes(self): - return [ - e.feed_output_shape - for e in self._training_endpoints - if e.has_feedable_training_target() - ] - - @property - def _feed_loss_fns(self): - return [ - e.loss_fn - for e in self._training_endpoints - if e.has_feedable_training_target() - ] - - @property - def _loss_weights_list(self): - return [e.loss_weight for e in self._training_endpoints] - - @property - def _output_loss_metrics(self): - if hasattr(self, '_training_endpoints'): - return [ - e.output_loss_metric - for e in self._training_endpoints - if e.output_loss_metric is not None - ] - return None - - @property - def sample_weights(self): - return [e.sample_weight for e in self._training_endpoints] - - @property - def _sample_weight_modes(self): - return [e.sample_weight_mode for e in self._training_endpoints] - - @property - def _feed_sample_weights(self): - return [e.sample_weight for e in self._training_endpoints - if e.sample_weight is not None] - - def _maybe_load_initial_epoch_from_ckpt(self, initial_epoch, mode): + def _maybe_load_initial_epoch_from_ckpt(self, initial_epoch): """Maybe load initial epoch from ckpt considering possible worker recovery. Refer to tensorflow/python/keras/distribute/multi_worker_training_state.py @@ -2591,375 +1468,134 @@ class Model(network.Network, version_utils.ModelVersionSelector): Arguments: initial_epoch: The original initial_epoch user passes in in `fit()`. - mode: The mode for running `model.fit()`. Returns: If the training is recovering from previous failure under multi-worker training setting, return the epoch the training is supposed to continue at. Otherwise, return the `initial_epoch` the user passes in. """ - if hasattr(self, '_training_state'): + if self._training_state is not None: return self._training_state.maybe_load_initial_epoch_from_ckpt( - initial_epoch, mode) + initial_epoch, mode=ModeKeys.TRAIN) return initial_epoch - def _get_training_eval_metrics(self): - """Returns all the metrics that are to be reported. - - This includes the output loss metrics, compile metrics/weighted metrics, - add_metric metrics. - """ - metrics = [] - metrics.extend(getattr(self, '_output_loss_metrics', None) or []) - metrics.extend(getattr(self, 'metrics', None) or []) - return metrics - def _assert_compile_was_called(self): # Checks whether `compile` has been called. If it has been called, # then the optimizer is set. This is different from whether the # model is compiled # (i.e. whether the model is built and its inputs/outputs are set). - if not self.optimizer: + if not self._is_compiled: raise RuntimeError('You must compile your model before ' 'training/testing. ' 'Use `model.compile(optimizer, loss)`.') - def _in_multi_worker_mode(self): - """Method to infer if this `Model` is working in multi-worker settings. - - Multi-worker training refers to the setup where the training is - distributed across multiple workers, as opposed to the case where - only a local process performs the training. This function is - used to infer for example whether or not a distribute coordinator - should be run, and thus TensorFlow servers should be started for - communication with other servers in the cluster, or whether or not - saving/restoring checkpoints is relevant for preemption fault tolerance. - - Experimental. Signature and implementation are subject to change. - - Returns: - Whether this model indicates it's working in multi-worker settings. - """ - strategy = self._get_distribution_strategy() - return strategy and strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access - - def _get_distribution_strategy(self): - # If the model was compiled under the scope of a `tf.distribute.Strategy', - # `self._distribution_strategy` would have been set and model should infer - # that as the used strategy (even if it's out of strategy scope already). - strategy = self._distribution_strategy - - # Otherwise, use the strategy whose scope this is in. - if not strategy and ds_context.has_strategy(): - strategy = ds_context.get_strategy() - - return strategy + def _set_inputs(self, inputs, outputs=None, training=None): + """This method is for compat with Modelv1. Only inputs are needed here.""" + self._set_save_spec(inputs) @property def _trackable_saved_model_saver(self): return model_serialization.ModelSavedModelSaver(self) + def _list_functions_for_serialization(self, serialization_cache): + # SavedModel needs to ignore the execution functions. + train_function = self.train_function + test_function = self.test_function + predict_function = self.predict_function + self.train_function = None + self.test_function = None + self.predict_function = None + functions = super( + Model, self)._list_functions_for_serialization(serialization_cache) + self.train_function = train_function + self.test_function = test_function + self.predict_function = predict_function + return functions -class _TrainingEndpoint(object): - """A container for the training output/target and related entities. - - In the case of model with multiple outputs, there is a one-to-one mapping - between model output (y_pred), model target (y_true), loss, metrics etc. - By unifying these entities into one class, different entity can access - information between each other, rather than currently access different list of - attributes of the model. - """ - - def __init__(self, - output, - output_name, - loss_fn, - loss_weight=None, - training_target=None, - output_loss_metric=None, - sample_weight=None, - sample_weight_mode=None): - """Initialize the _TrainingEndpoint. - - Note that the output and output_name should be stable as long as the model - structure doesn't change. The training_target suppose to be mutable since - the information is provided via `compile()` - - Args: - output: the output tensor of the model. - output_name: the unique name of the output tensor. - loss_fn: the loss function for the output tensor. - loss_weight: float, the weights for the loss. - training_target: the _TrainingTarget for the model. - output_loss_metric: the metric object for the loss function. - sample_weight: the weights for how a sample is weighted during metric and - loss calculation. Could be None. - sample_weight_mode: string, 'temporal', 'samplewise' or None. The mode for - how the sample_weight is populated. - """ - self._output = output - self._output_name = output_name - self._loss_fn = loss_fn - self._loss_weight = loss_weight - self._training_target = training_target - self._output_loss_metric = output_loss_metric - self._sample_weight = sample_weight - self._sample_weight_mode = sample_weight_mode - - @property - def output(self): - return self._output - - @property - def output_name(self): - return self._output_name - - @property - def shape(self): - return K.int_shape(self.output) - - @property - def loss_fn(self): - return self._loss_fn - - @property - def loss_weight(self): - return self._loss_weight - - @loss_weight.setter - def loss_weight(self, value): - self._loss_weight = value - - @property - def training_target(self): - return self._training_target - - @training_target.setter - def training_target(self, value): - self._training_target = value - - def create_training_target(self, target, run_eagerly=False): - """Create training_target instance and update the self.training_target. - - Note that the input target should just be a tensor or None, and - corresponding training target will be created based on the output and - loss_fn. - - Args: - target: the target tensor for the current output. Could be None. - run_eagerly: boolean, whether the model is in run_eagerly mode. - - Raises: - ValueError if the training_target field for the current instance has - already been populated. - """ - if self.has_training_target(): - raise ValueError('The training_target field for the _TrainingEndpoint ' - 'instance has already been populated') - if run_eagerly: - # When run_eagerly, the target tensor is ignored, and the None placeholder - # is created instead. - self.training_target = _TrainingTarget( - None, feedable=True, skip_target_weights=False) - return - - if self.should_skip_target(): - self.training_target = _TrainingTarget(None) + def _should_eval(self, epoch, validation_freq): + epoch = epoch + 1 # one-index the user-facing epoch. + if isinstance(validation_freq, int): + return epoch % validation_freq == 0 + elif isinstance(validation_freq, list): + return epoch in validation_freq else: - if target is not None and not K.is_placeholder(target): - feedable = False - skip_target_weights = True - else: - feedable = True - skip_target_weights = False + raise ValueError('Expected `validation_freq` to be a list or int.') - if target is None: - target_dtype = losses.LABEL_DTYPES_FOR_LOSSES.get( - self.loss_fn, K.dtype(self.output)) + ###################################################################### + # Functions below exist only as v1 / v2 compatibility shims. + ###################################################################### - target = K.placeholder( - ndim=len(self.shape), - name=self.output_name + '_target', - sparse=K.is_sparse(self.output), - dtype=target_dtype) + def _get_compile_args(self): + """Used for saving or cloning a Model.""" + self._assert_compile_was_called() + # pylint: disable=protected-access + compile_args = { + 'optimizer': self.optimizer, + 'loss': self.compiled_loss._user_losses, + 'metrics': self.compiled_metrics._user_metrics, + 'weighted_metrics': self.compiled_metrics._user_weighted_metrics, + 'loss_weights': self.compiled_loss._user_loss_weights, + 'sample_weight_mode': None, + } + # pylint: enable=protected-access + return compile_args - self.training_target = _TrainingTarget( - target, - feedable=feedable, - skip_target_weights=skip_target_weights) + def _get_callback_model(self): + return self + + def _in_multi_worker_mode(self): + return self.distribute_strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access + + def _get_distribution_strategy(self): + return self.distribute_strategy @property - def output_loss_metric(self): - return self._output_loss_metric - - @output_loss_metric.setter - def output_loss_metric(self, value): - self._output_loss_metric = value - - @property - def sample_weight(self): - return self._sample_weight - - @sample_weight.setter - def sample_weight(self, value): - self._sample_weight = value - - @property - def sample_weight_mode(self): - return self._sample_weight_mode - - @sample_weight_mode.setter - def sample_weight_mode(self, value): - self._sample_weight_mode = value - - def should_skip_target(self): - return self._loss_fn is None - - def should_skip_target_weights(self): - return (self.should_skip_target() or self.training_target is None or - self.training_target.skip_target_weights) - - def has_training_target(self): - return self.training_target is not None - - def has_feedable_training_target(self): - return (not self.should_skip_target() and - self.training_target is not None and self.training_target.feedable) - - def loss_name(self): - if self._loss_fn is not None: - return self._output_name + '_loss' - return None - - @property - def feed_output_shape(self): - """The output shape for the feedable target.""" - if not self.has_feedable_training_target(): - return None - - if ((isinstance(self.loss_fn, losses.LossFunctionWrapper) and - self.loss_fn.fn == losses.sparse_categorical_crossentropy)) or ( - isinstance(self.loss_fn, losses.SparseCategoricalCrossentropy)): - if K.image_data_format() == 'channels_first': - return (self.shape[0], 1) + self.shape[2:] - else: - return self.shape[:-1] + (1,) - elif (not isinstance(self.loss_fn, losses.Loss) or - (isinstance(self.loss_fn, losses.LossFunctionWrapper) and - (getattr(losses, self.loss_fn.fn.__name__, None) is None))): - # If the given loss is not an instance of the `Loss` class (custom - # class) or if the loss function that is wrapped is not in the - # `losses` module, then it is a user-defined loss and we make no - # assumptions about it. - return None - else: - return self.shape - - def sample_weights_mismatch(self): - """Check if the sample weight and the mode match or not.""" - # If there is a mismatch between sample weight mode and the placeholders - # created, then recompile the sub-graphs that depend on sample weights. - return ( - (self.sample_weight_mode is not None and self.sample_weight is None) or - (self.sample_weight_mode is None and self.sample_weight is not None)) - - def populate_sample_weight(self, sample_weight, sample_weight_mode): - """Populate the sample weight and based on the sample weight mode.""" - if (sample_weight is None and - (self.should_skip_target_weights() or sample_weight_mode is None or - context.executing_eagerly())): - self._sample_weight = None - return - - assert sample_weight_mode in ['temporal', 'samplewise'] - if sample_weight_mode == 'temporal': - default_value = [[1.]] - shape = [None, None] - else: - # sample_weight_mode == 'samplewise' - default_value = [1.] - shape = [None] - - if sample_weight is not None: - if not sample_weight.shape.is_compatible_with(shape): - raise ValueError('Received sample weight with shape {}. Expected shape ' - '{}.'.format(sample_weight.shape, shape)) - self._sample_weight = sample_weight - else: - self._sample_weight = array_ops.placeholder_with_default( - constant_op.constant(default_value, dtype=K.floatx()), - shape=shape, - name=self.output_name + '_sample_weights') + def _compile_was_called(self): + return self._is_compiled -class _TrainingTarget(object): - """Container for a target tensor (y_true) and its metadata (shape, loss...). +def reduce_per_replica(values, strategy, reduction='first'): + """Reduce PerReplica objects. Arguments: - target: A target tensor for the model. It may be `None` if the - output is excluded from loss computation. It is still kept as None - since each output of the model should have a corresponding target. If - the target is None, the rest of the attributes will be None as well. - feedable: Boolean, whether the target is feedable (requires data to be - passed in `fit` or `train_on_batch`), or not (model compiled with - `target_tensors` argument). - skip_target_weights: Boolean, whether the target should be skipped during - weights calculation. - """ - - def __init__(self, target, feedable=False, skip_target_weights=True): - self._target = target - self._feedable = feedable - self._skip_target_weights = skip_target_weights - - @property - def target(self): - return self._target - - @property - def feedable(self): - return self._feedable - - @property - def skip_target_weights(self): - return self._skip_target_weights - - -def _is_symbolic_tensor(x): - return tensor_util.is_tensor(x) and not isinstance(x, ops.EagerTensor) - - -def _convert_scipy_sparse_tensor(value, expected_input): - """Handle scipy sparse tensor conversions. - - This method takes a value 'value' and returns the proper conversion. If - value is a scipy sparse tensor and the expected input is a dense tensor, - we densify 'value'. If value is a scipy sparse tensor and the expected input - is a TF SparseTensor, we convert 'value' to a SparseTensor. If 'value' is - not a scipy sparse tensor, or scipy is not imported, we pass it through - unchanged. - - Arguments: - value: An object that may be a scipy sparse tensor - expected_input: The expected input placeholder. + values: Structure of `PerReplica` objects or `Tensor`s. `Tensor`s are + returned as-is. + strategy: `tf.distribute.Strategy` object. + reduction: One of 'first', 'concat'. Returns: - The possibly-converted 'value'. + Structure of `Tensor`s. """ - if issparse is not None and issparse(value): - if ops.is_dense_tensor_like(expected_input): - if ops.executing_eagerly_outside_functions(): - # In TF2 we do not silently densify sparse matrices. - raise ValueError('A SciPy sparse matrix was passed to a model ' - 'that expects dense inputs. Please densify your ' - 'inputs first, such as by calling `x.toarray().') - return value.toarray() + + def _reduce(v): + """Reduce a single `PerReplica` object.""" + if not isinstance(v, ds_values.PerReplica): + return v + elif reduction == 'first': + return strategy.unwrap(v)[0] # pylint: disable=protected-access + elif reduction == 'concat': + return concat(strategy.unwrap(v)) # pylint: disable=protected-access else: - sparse_coo = value.tocoo() - row, col = sparse_coo.row, sparse_coo.col - data, shape = sparse_coo.data, sparse_coo.shape - indices = np.concatenate((np.expand_dims(row, 1), np.expand_dims(col, 1)), - 1) - return sparse_tensor.SparseTensor(indices, data, shape) - else: - return value + raise ValueError('`reduction` must be "first" or "concat".') + + return nest.map_structure(_reduce, values) + + +def concat(tensors, axis=0): + """Concats `tensor`s along `axis`.""" + if isinstance(tensors[0], sparse_tensor.SparseTensor): + return sparse_ops.sparse_concat_v2(axis=axis, sp_inputs=tensors) + if isinstance(tensors[0], ragged_tensor.RaggedTensor): + return ragged_concat_ops.concat(tensors, axis=axis) + return array_ops.concat(tensors, axis=axis) + + +def to_numpy(tensors): + """Converts a structure of `Tensor`s to `NumPy` arrays.""" + + def _to_single_numpy(t): + if isinstance(t, ops.Tensor): + return t.numpy() + return t # Don't turn ragged or sparse tensors to NumPy. + + return nest.map_structure(_to_single_numpy, tensors) diff --git a/tensorflow/python/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py index a9c746d6a52..531e576662b 100644 --- a/tensorflow/python/keras/engine/training_arrays.py +++ b/tensorflow/python/keras/engine/training_arrays.py @@ -226,13 +226,9 @@ def model_iteration(model, epochs=epochs, steps_per_epoch=steps_per_epoch, samples=num_samples_or_steps, - verbose=0, # Handle ProgBarLogger separately in this loop. + count_mode=count_mode, + verbose=verbose, mode=mode) - # TODO(omalleyt): Handle ProgBar as part of Callbacks once hooks are ready. - progbar = training_utils.get_progbar( - model, count_mode, mode != ModeKeys.PREDICT) - progbar.params = callbacks.params - progbar.params['verbose'] = verbose # Find beforehand arrays that need sparse-to-dense conversion. if issparse is not None and not use_steps: @@ -259,7 +255,6 @@ def model_iteration(model, callbacks.model.stop_training = False callbacks._call_begin_hook(mode) - progbar.on_train_begin() initial_epoch = model._maybe_load_initial_epoch_from_ckpt(initial_epoch, mode) @@ -275,7 +270,6 @@ def model_iteration(model, model.reset_metrics() if mode == ModeKeys.TRAIN: callbacks.on_epoch_begin(epoch, epoch_logs) - progbar.on_epoch_begin(epoch, epoch_logs) if use_steps: # Step-wise loop. @@ -290,7 +284,6 @@ def model_iteration(model, while step < target_steps: batch_logs = {'batch': step, 'size': 1} callbacks._call_batch_hook(mode, 'begin', step, batch_logs) - progbar.on_batch_begin(step, batch_logs) # Get outputs. try: @@ -320,9 +313,6 @@ def model_iteration(model, elif step > 0: steps_per_epoch = step aggregator.steps = steps_per_epoch - if mode == ModeKeys.TRAIN: - progbar.params['steps'] = steps_per_epoch - progbar.progbar.target = steps_per_epoch else: # We ran out of batches while the user passed an iterator (legacy). callbacks.model.stop_training = True @@ -350,7 +340,6 @@ def model_iteration(model, # Callbacks batch end. batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode) callbacks._call_batch_hook(mode, 'end', step, batch_logs) - progbar.on_batch_end(step, batch_logs) step += 1 if callbacks.model.stop_training: @@ -392,7 +381,6 @@ def model_iteration(model, # Callbacks batch_begin. batch_logs = {'batch': batch_index, 'size': len(batch_ids)} callbacks._call_batch_hook(mode, 'begin', batch_index, batch_logs) - progbar.on_batch_begin(batch_index, batch_logs) # Get outputs. batch_outs = f(ins_batch) @@ -407,7 +395,6 @@ def model_iteration(model, # Callbacks batch end. batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode) callbacks._call_batch_hook(mode, 'end', batch_index, batch_logs) - progbar.on_batch_end(batch_index, batch_logs) if callbacks.model.stop_training: break @@ -452,7 +439,6 @@ def model_iteration(model, if mode == ModeKeys.TRAIN: # Epochs only apply to `fit`. callbacks.on_epoch_end(epoch, epoch_logs) - progbar.on_epoch_end(epoch, epoch_logs) # Reinitialize dataset iterator for the next epoch. if reset_dataset_after_each_epoch and epoch < epochs - 1: diff --git a/tensorflow/python/keras/engine/training_dataset_test.py b/tensorflow/python/keras/engine/training_dataset_test.py index 684c966cdd2..79719012c47 100644 --- a/tensorflow/python/keras/engine/training_dataset_test.py +++ b/tensorflow/python/keras/engine/training_dataset_test.py @@ -107,8 +107,7 @@ class TestTrainingWithDataset(keras_parameterized.TestCase): validation_data=dataset, validation_steps=2) # Test with validation split - with self.assertRaisesRegexp( - ValueError, '`validation_split` argument is not supported when '): + with self.assertRaises(ValueError): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, validation_split=0.5, validation_steps=2) @@ -124,19 +123,6 @@ class TestTrainingWithDataset(keras_parameterized.TestCase): verbose=0, sample_weight=sample_weight) - # Test invalid usage - with self.assertRaisesRegexp( - ValueError, 'The `batch_size` argument must not be specified'): - model.fit(dataset, batch_size=10, epochs=1, steps_per_epoch=2, - verbose=0) - - with self.assertRaisesRegexp( - ValueError, 'The `batch_size` argument must not be specified'): - model.predict(dataset, batch_size=10, steps=2, verbose=0) - with self.assertRaisesRegexp( - ValueError, 'The `batch_size` argument must not be specified'): - model.evaluate(dataset, batch_size=10, steps=2, verbose=0) - with self.assertRaisesRegexp( ValueError, '(you should not specify a target)|' '(`y` argument is not supported when using dataset as input.)'): @@ -144,14 +130,11 @@ class TestTrainingWithDataset(keras_parameterized.TestCase): epochs=1, steps_per_epoch=2, verbose=0) # With an infinite dataset, `steps_per_epoch`/`steps` argument is required. - with self.assertRaisesRegexp( - ValueError, 'the `steps_per_epoch` argument'): + with self.assertRaises(ValueError): model.fit(dataset, epochs=1, verbose=0) - with self.assertRaisesRegexp(ValueError, - 'the `steps` argument'): + with self.assertRaises(ValueError): model.evaluate(dataset, verbose=0) - with self.assertRaisesRegexp(ValueError, - 'the `steps` argument'): + with self.assertRaises(ValueError): model.predict(dataset, verbose=0) @keras_parameterized.run_with_all_model_types(exclude_models='sequential') @@ -185,14 +168,6 @@ class TestTrainingWithDataset(keras_parameterized.TestCase): model.fit(dataset_tuple, epochs=1, steps_per_epoch=2, verbose=1) model.evaluate(dataset_tuple, steps=2, verbose=1) - predict_dataset_tuple = dataset_ops.Dataset.from_tensor_slices( - (input_a_np, input_b_np)) - # TODO(b/123360757): Remove below assertion once predict() supports - # muti-input datasets. - with self.assertRaisesRegexp(ValueError, - 'Error when checking model input'): - model.predict(predict_dataset_tuple, steps=1) - # Test with dict input_dict = {'input_1': input_a_np, 'input_2': input_b_np} if testing_utils.get_model_type() == 'subclass': @@ -457,15 +432,7 @@ class TestTrainingWithDataset(keras_parameterized.TestCase): self.assertIn('10/10', lines[-1]) self.assertLen(history.history['loss'], 2) - # The first epoch will invoke batch begin 11 times, since it doesn't know - # the cardinality. The second epoch should just invoke 10 times. - if (testing_utils.should_run_eagerly() - or testing_utils.should_run_tf_function()): - expected_batch_begin_count = 21 - else: - expected_batch_begin_count = 20 - self.assertEqual(batch_counter.batch_begin_count, - expected_batch_begin_count) + self.assertEqual(batch_counter.batch_begin_count, 21) self.assertEqual(batch_counter.batch_end_count, 20) model.evaluate(dataset) out = model.predict(dataset) diff --git a/tensorflow/python/keras/engine/training_eager_test.py b/tensorflow/python/keras/engine/training_eager_test.py index 8ac94f346c0..d6cd412d1ec 100644 --- a/tensorflow/python/keras/engine/training_eager_test.py +++ b/tensorflow/python/keras/engine/training_eager_test.py @@ -194,12 +194,10 @@ class TrainingTest(keras_parameterized.TestCase): model.fit(dataset, epochs=1, verbose=0) # Step argument is required for infinite datasets. - with self.assertRaisesRegexp(ValueError, - 'specify the `validation_steps` argument.'): + with self.assertRaises(ValueError): model.fit(dataset, steps_per_epoch=2, epochs=1, verbose=0, validation_data=validation_dataset) - with self.assertRaisesRegexp(ValueError, - 'specify the `validation_steps` argument.'): + with self.assertRaises(ValueError): model.fit(dataset, steps_per_epoch=2, epochs=1, verbose=0, validation_data=validation_dataset) @@ -355,7 +353,8 @@ class CorrectnessTest(keras_parameterized.TestCase): x = np.ones((20, 4)).astype(np.float32) y = np.random.randint(0, 3, size=(20,)).astype(np.int64) dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(2) - evaluation_results = dict(zip(model.metrics_names, model.evaluate(dataset))) + results = model.evaluate(dataset) + evaluation_results = dict(zip(model.metrics_names, results)) # Rate of dropout depends on the learning phase. self.assertEqual(evaluation_results['regularization_loss'], expected_validation_loss) diff --git a/tensorflow/python/keras/engine/training_generator.py b/tensorflow/python/keras/engine/training_generator.py index d19b2907aa4..1fcf3ef25e4 100644 --- a/tensorflow/python/keras/engine/training_generator.py +++ b/tensorflow/python/keras/engine/training_generator.py @@ -174,12 +174,9 @@ def model_iteration(model, steps_per_epoch=steps_per_epoch, batch_size=batch_size, samples=num_samples_or_steps, - verbose=0, # Handle ProgBar as part of Callbacks once hooks are ready. + count_mode=count_mode, + verbose=verbose, mode=mode) - # TODO(omalleyt): Handle ProgBar as part of Callbacks once hooks are ready. - progbar = training_utils.get_progbar(model, count_mode) - progbar.params = callbacks.params - progbar.params['verbose'] = verbose if mode == ModeKeys.PREDICT: aggregator = training_utils.OutputsAggregator(True, steps=steps_per_epoch) @@ -194,7 +191,6 @@ def model_iteration(model, callbacks.model.stop_training = False callbacks._call_begin_hook(mode) - progbar.on_train_begin() initial_epoch = model._maybe_load_initial_epoch_from_ckpt(initial_epoch, mode) @@ -207,7 +203,6 @@ def model_iteration(model, epoch_logs = {} if mode == ModeKeys.TRAIN: callbacks.on_epoch_begin(epoch, epoch_logs) - progbar.on_epoch_begin(epoch, epoch_logs) if steps_per_epoch is None: # Loop over dataset until `OutOfRangeError` is raised. @@ -237,9 +232,6 @@ def model_iteration(model, elif step > 0: steps_per_epoch = step aggregator.steps = steps_per_epoch - if mode == ModeKeys.TRAIN: - progbar.params['steps'] = steps_per_epoch - progbar.progbar.target = steps_per_epoch else: # We ran out of batches while the user passed an iterator (legacy). callbacks.model.stop_training = True @@ -259,7 +251,6 @@ def model_iteration(model, # Callbacks batch begin. batch_logs = {'batch': step, 'size': batch_size} callbacks._call_batch_hook(mode, 'begin', step, batch_logs) - progbar.on_batch_begin(step, batch_logs) is_deferred = not model._is_compiled batch_outs = batch_function(*batch_data) @@ -283,16 +274,12 @@ def model_iteration(model, verbose=verbose, mode=mode) - progbar.params = callbacks.params - progbar.params['verbose'] = verbose - # Aggregate results. aggregator.aggregate(batch_outs) # Callbacks batch end. batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode) callbacks._call_batch_hook(mode, 'end', step, batch_logs) - progbar.on_batch_end(step, batch_logs) step += 1 if callbacks.model.stop_training: @@ -330,7 +317,6 @@ def model_iteration(model, if mode == ModeKeys.TRAIN: # Epochs only apply to `fit`. callbacks.on_epoch_end(epoch, epoch_logs) - progbar.on_epoch_end(epoch, epoch_logs) # Recreate dataset iterator for the next epoch. if reset_dataset_after_each_epoch and epoch < epochs - 1: diff --git a/tensorflow/python/keras/engine/training_generator_test.py b/tensorflow/python/keras/engine/training_generator_test.py index 30e59114e75..c9642fd7c7f 100644 --- a/tensorflow/python/keras/engine/training_generator_test.py +++ b/tensorflow/python/keras/engine/training_generator_test.py @@ -245,15 +245,14 @@ class TestGeneratorMethods(keras_parameterized.TestCase): run_eagerly=testing_utils.should_run_eagerly(), experimental_run_tf_function=testing_utils.should_run_tf_function()) - err_msg = 'Output of generator should be a tuple of 1 or 2 or 3 elements' - with self.assertRaisesRegex(ValueError, err_msg): + with self.assertRaises(ValueError): model.fit_generator(invalid_generator(), steps_per_epoch=5, epochs=1, verbose=1, max_queue_size=10, use_multiprocessing=False) - with self.assertRaisesRegex(ValueError, err_msg): + with self.assertRaises(ValueError): model.fit_generator(custom_generator(), steps_per_epoch=5, epochs=1, @@ -262,12 +261,12 @@ class TestGeneratorMethods(keras_parameterized.TestCase): use_multiprocessing=False, validation_data=invalid_generator(), validation_steps=10) - with self.assertRaisesRegex(ValueError, err_msg): + with self.assertRaises(ValueError): model.predict_generator(invalid_generator(), steps=5, max_queue_size=10, use_multiprocessing=False) - with self.assertRaisesRegex(ValueError, err_msg): + with self.assertRaises(ValueError): model.evaluate_generator(invalid_generator(), steps=5, max_queue_size=10, @@ -330,38 +329,11 @@ class TestGeneratorMethods(keras_parameterized.TestCase): model.evaluate(custom_generator_changing_batch_size(), steps=5) model.predict(custom_generator_changing_batch_size(), steps=5) - @keras_parameterized.run_with_all_model_types - @keras_parameterized.run_all_keras_modes - def test_invalid_batch_size_argument(self): - - def ones_generator(): - while True: - yield np.ones([10, 10], np.float32), np.ones([10, 1], np.float32) - - model = testing_utils.get_small_mlp( - num_hidden=10, num_classes=1, input_dim=10) - - model.compile( - 'adam', - 'binary_crossentropy', - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) - - with self.assertRaisesRegexp( - ValueError, 'The `batch_size` argument must not be specified'): - model.fit(ones_generator(), batch_size=2, epochs=2) - with self.assertRaisesRegexp( - ValueError, 'The `batch_size` argument must not be specified'): - model.evaluate(ones_generator(), batch_size=2) - - with self.assertRaisesRegexp( - ValueError, 'The `batch_size` argument must not be specified'): - model.predict(ones_generator(), batch_size=2) - @keras_parameterized.run_with_all_model_types @keras_parameterized.run_all_keras_modes @data_utils.dont_use_multiprocessing_pool def test_generator_dynamic_shapes(self): + x = [ 'I think juice is great', 'unknown is the best language since slicedbread', diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index ac2f3972ad8..6ee8971d567 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -20,8 +20,6 @@ from __future__ import print_function import collections import io -import logging -import re import sys from absl.testing import parameterized @@ -29,16 +27,13 @@ import numpy as np import six from tensorflow.python import keras -from tensorflow.python import tf2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context -from tensorflow.python.eager import def_function from tensorflow.python.eager import function from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras import keras_parameterized -from tensorflow.python.keras import losses from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.keras import testing_utils from tensorflow.python.keras.callbacks import Callback @@ -53,7 +48,6 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test -from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.rmsprop import RMSPropOptimizer try: @@ -62,206 +56,6 @@ except ImportError: scipy_sparse = None -class CompileTest(keras_parameterized.TestCase): - - def _get_multi_output_model(self): - input_a = keras.layers.Input(shape=(3,), name='input_a') - output_a = keras.layers.Dense(1, name='dense_1')(input_a) - output_b = keras.layers.Dense(1, name='dense_2')(input_a) - return keras.models.Model(input_a, [output_a, output_b]) - - def _do_test_compile_with_model_and_single_loss(self, model, loss): - model.compile( - optimizer='adam', - loss=loss, - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) - self.assertEqual(model.loss, loss) - - loss = losses.get(loss) - if not isinstance(loss, list): - loss_list = [loss] * len(model.outputs) - - self.assertEqual(len(model.loss_functions), len(loss_list)) - for i in range(len(loss_list)): - self.assertIsInstance(model.loss_functions[i], losses.LossFunctionWrapper) - if not isinstance(loss_list[i], losses.LossFunctionWrapper): - self.assertEqual(model.loss_functions[i].fn, loss_list[i]) - self.assertAllEqual(model._loss_weights_list, [1.] * len(loss_list)) - - def test_respect_run_functions_eagerly(self): - with context.eager_mode(): - model = testing_utils.get_small_sequential_mlp( - num_hidden=10, num_classes=2, input_dim=3) - model.compile('sgd', 'mse') - def_function.run_functions_eagerly(True) - self.assertTrue(model.run_eagerly) - def_function.run_functions_eagerly(False) - self.assertFalse(model.run_eagerly) - - @keras_parameterized.run_all_keras_modes - @parameterized.named_parameters(('loss_string', 'mse'), - ('loss_function', losses.mean_squared_error), - ('loss_instance', losses.MeanSquaredError())) - def test_compile_with_single_output(self, loss): - model = testing_utils.get_small_sequential_mlp( - num_hidden=10, num_classes=2, input_dim=3) - self._do_test_compile_with_model_and_single_loss(model, loss) - - @keras_parameterized.run_all_keras_modes - @parameterized.named_parameters(('loss_string', 'mse'), - ('loss_function', losses.mean_squared_error), - ('loss_instance', losses.MeanSquaredError())) - def test_compile_with_multi_output(self, loss): - model = self._get_multi_output_model() - self._do_test_compile_with_model_and_single_loss(model, loss) - - @keras_parameterized.run_all_keras_modes - def test_compile_with_multi_output_and_multi_loss(self): - model = self._get_multi_output_model() - # Test loss is a list. - loss = ['mse', 'mae'] - model.compile( - optimizer='adam', - loss=loss, - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) - self.assertEqual(model.loss_functions[0].fn, losses.mean_squared_error) - self.assertEqual(model.loss_functions[1].fn, losses.mean_absolute_error) - self.assertAllEqual(model._loss_weights_list, [1., 1.]) - - # Test loss is a dict. - loss = {'dense_1': 'mae', 'dense_2': 'mse'} - model.compile( - optimizer='adam', - loss=loss, - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) - self.assertEqual(model.loss_functions[0].fn, losses.mean_absolute_error) - self.assertEqual(model.loss_functions[1].fn, losses.mean_squared_error) - self.assertAllEqual(model._loss_weights_list, [1., 1.]) - - @keras_parameterized.run_all_keras_modes - def test_compile_with_multi_output_and_loss_weights_list(self): - model = self._get_multi_output_model() - loss_weights = [1., 2.] - model.compile( - optimizer='adam', - loss='mse', - loss_weights=loss_weights, - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) - self.assertAllEqual(model._loss_weights_list, [1., 2.]) - - def test_compile_with_multi_output_and_loss_weights_dict(self): - with ops.get_default_graph().as_default(): - model = self._get_multi_output_model() - loss_weights = {'dense_1': 1., 'dense_2': 2.} - model.compile(optimizer='adam', loss='mse', loss_weights=loss_weights) - self.assertAllEqual(model._loss_weights_list, [1., 2.]) - - input_np = np.random.random((10, 3)) - output_a_np = np.random.random((10, 1)) - output_b_np = np.random.random((10, 1)) - - with self.cached_session() as sess: - sess.run(variables_lib.global_variables_initializer()) - total_loss, y_preds = sess.run( - [model.total_loss, model.outputs], - feed_dict={ - 'input_a:0': input_np, - 'dense_1_target:0': output_a_np, - 'dense_2_target:0': output_b_np - }) - self.assertAllClose( - total_loss, - np.mean( - np.add((output_a_np - y_preds[0])**2, - 2 * (output_b_np - y_preds[1])**2))) - - @keras_parameterized.run_all_keras_modes - def test_compile_with_incorrect_loss_size(self): - model = testing_utils.get_small_sequential_mlp( - num_hidden=10, num_classes=2, input_dim=3) - with self.assertRaisesRegexp(ValueError, 'The model has 1 outputs'): - model.compile( - optimizer='adam', - loss=['mse', 'mae'], - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) - - @keras_parameterized.run_all_keras_modes - def test_compile_with_incorrect_loss_key(self): - model = testing_utils.get_small_sequential_mlp( - num_hidden=10, num_classes=2, input_dim=3) - with self.assertRaisesRegexp( - ValueError, - r'Unknown entries in loss dictionary: \[\'unknown_output\'\]. ' - r'Only expected following keys: \[\'dense_1\'\]'): - model.compile( - optimizer='adam', - loss={'unknown_output': 'mse'}, - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) - - @keras_parameterized.run_all_keras_modes - def test_compile_with_incorrect_loss_weights_size(self): - model = testing_utils.get_small_sequential_mlp( - num_hidden=10, num_classes=2, input_dim=3) - with self.assertRaisesRegexp(ValueError, - 'it should have one entry per model output'): - model.compile( - optimizer='adam', - loss='mse', - loss_weights=[1., 2.], - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) - - @keras_parameterized.run_all_keras_modes - def test_compile_with_incorrect_loss_weights_key(self): - model = testing_utils.get_small_sequential_mlp( - num_hidden=10, num_classes=2, input_dim=3) - with self.assertRaisesRegexp( - ValueError, - r'Unknown entries in loss_weights dictionary: \[\'unknown_output\'\]. ' - r'Only expected following keys: \[\'dense_1\'\]'): - model.compile( - optimizer='adam', - loss='mse', - loss_weights={'unknown_output': 1.}, - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) - - @keras_parameterized.run_all_keras_modes - def test_compile_with_incorrect_sample_weight_mode(self): - model = testing_utils.get_small_sequential_mlp( - num_hidden=10, num_classes=2, input_dim=3) - with self.assertRaisesRegexp( - ValueError, - r'Unknown entries in sample_weight_mode dictionary: \[\'unknown\'\]. ' - r'Only expected following keys: \[\'dense_1\'\]'): - model.compile( - optimizer='adam', - loss='mse', - sample_weight_mode={'unknown': 'temporal'}, - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) - - def test_compile_with_session_kwargs(self): - with ops.Graph().as_default(): - model = testing_utils.get_small_sequential_mlp( - num_hidden=10, num_classes=2, input_dim=3) - - # Test that unknown arguments are not accepted - with self.assertRaisesRegexp( - TypeError, - r'Invalid keyword argument'): - model.compile( - optimizer='adam', - loss='mse', - foo=True) - - class TrainingTest(keras_parameterized.TestCase): @keras_parameterized.run_with_all_model_types @@ -356,7 +150,7 @@ class TrainingTest(keras_parameterized.TestCase): @keras_parameterized.run_with_all_model_types def test_target_dtype_matches_output(self): - def _loss_fn(labels, preds): + def loss_fn(labels, preds): self.assertEqual(labels.dtype, preds.dtype) return labels - preds @@ -367,7 +161,7 @@ class TrainingTest(keras_parameterized.TestCase): targets = np.ones(10, dtype=np.float64) model.compile( 'sgd', - loss=_loss_fn, + loss=loss_fn, run_eagerly=testing_utils.should_run_eagerly(), experimental_run_tf_function=testing_utils.should_run_tf_function()) model.train_on_batch(inputs, targets) @@ -584,31 +378,6 @@ class TrainingTest(keras_parameterized.TestCase): batch_size=5, verbose=0) - # Invalid use cases - with self.assertRaises(ValueError): - model.train_on_batch({'input_a': input_a_np}, - [output_d_np, output_e_np]) - with self.assertRaises(ValueError): - model.fit( - [input_a_np, input_b_np], [output_d_np, output_e_np], - epochs=1, - validation_data=([input_a_np, input_b_np], 0, 0), - verbose=0) - with self.assertRaises(ValueError): - model.train_on_batch([input_a_np], [output_d_np, output_e_np]) - with self.assertRaises(ValueError): - model.train_on_batch(1, [output_d_np, output_e_np]) - with self.assertRaises(ValueError): - model.train_on_batch(input_a_np, [output_d_np, output_e_np]) - with self.assertRaises(ValueError): - bad_input = np.random.random((11, 3)) - model.train_on_batch([bad_input, input_b_np], - [output_d_np, output_e_np]) - with self.assertRaises(ValueError): - bad_target = np.random.random((11, 4)) - model.train_on_batch([input_a_np, input_b_np], - [bad_target, output_e_np]) - # Build single-input model x = keras.layers.Input(shape=(3,), name='input_a') y = keras.layers.Dense(4)(x) @@ -620,10 +389,6 @@ class TrainingTest(keras_parameterized.TestCase): experimental_run_tf_function=testing_utils.should_run_tf_function()) # This will work model.fit([input_a_np], output_d_np, epochs=1) - # TODO(gsundeep) Test only works in eager, file ticket - if testing_utils.should_run_eagerly() and context.executing_eagerly(): - with self.assertRaises(ValueError): - model.fit([input_a_np, input_a_np], output_d_np, epochs=1) # Test model on a list of floats input_a_np = np.random.random((10, 3)) @@ -841,22 +606,6 @@ class TrainingTest(keras_parameterized.TestCase): model.evaluate(xy_function(use_namedtuple=False), **evaluate_kwargs) model.predict(x_function(use_namedtuple=False), **predict_kwargs) - xy_pattern = re.escape( - "Received namedtuple () with fields " - "`('x', 'y')` as input.") - x_pattern = re.escape( - "Received namedtuple () with fields " - "`('x',)` as input.") - - with self.assertRaisesRegex(ValueError, xy_pattern): - model.fit(xy_function(use_namedtuple=True), **fit_kwargs) - - with self.assertRaisesRegex(ValueError, xy_pattern): - model.evaluate(xy_function(use_namedtuple=True), **evaluate_kwargs) - - with self.assertRaisesRegex(ValueError, x_pattern): - model.predict(x_function(use_namedtuple=True), **predict_kwargs) - @keras_parameterized.run_all_keras_modes def test_custom_mapping_in_config(self): @@ -872,41 +621,6 @@ class TrainingTest(keras_parameterized.TestCase): model = MyModel() self.assertIn('{"a": {}}', model.to_json()) - @keras_parameterized.run_all_keras_modes(always_skip_v1=True) - def test_training_on_sparse_data_with_dense_placeholders(self): - if scipy_sparse is None: - return - - test_inputs = [ - scipy_sparse.random(6, 3, density=0.25).tocsr() for _ in range(2) - ] - test_outputs = [ - scipy_sparse.random(6, i, density=0.25).tocsr() for i in range(3, 5) - ] - in1 = keras.layers.Input(shape=(3,)) - in2 = keras.layers.Input(shape=(3,)) - out1 = keras.layers.Dropout(0.5, name='dropout')(in1) - out2 = keras.layers.Dense(4, name='dense_1')(in2) - model = keras.Model([in1, in2], [out1, out2]) - model.experimental_run_tf_function = testing_utils.should_run_tf_function() - - with self.assertRaisesRegexp(ValueError, 'Please densify'): - model.predict(test_inputs, batch_size=2) - optimizer = 'rmsprop' - model.compile( - optimizer, - 'mse', - metrics=['mae', metrics_module.CategoricalAccuracy()], - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) - - with self.assertRaisesRegexp(ValueError, 'Please densify'): - model.fit(test_inputs, test_outputs, - epochs=1, batch_size=2) - - with self.assertRaisesRegexp(ValueError, 'Please densify'): - model.evaluate(test_inputs, test_outputs, batch_size=2) - def test_training_on_sparse_data_with_dense_placeholders_v1(self): with ops.Graph().as_default(): if scipy_sparse is None: @@ -1087,66 +801,61 @@ class TrainingTest(keras_parameterized.TestCase): self.assertEqual(l.non_trainable_variables, [l.layer1.non_trainable_var]) self.assertLen(l.get_weights(), 2) + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) def test_logs_passed_to_callbacks(self): - with self.cached_session(): - input_dim = 5 - num_classes = 1 + input_dim = 5 + num_classes = 1 - class TestCallback(Callback): + class TestCallback(Callback): - def __init__(self): - super(TestCallback, self).__init__() - self.epoch_end_logs = None - self.batch_end_logs = None - self.epoch_end_call_count = 0 - self.batch_end_call_count = 0 + def __init__(self): + super(TestCallback, self).__init__() + self.epoch_end_logs = None + self.batch_end_logs = None + self.epoch_end_call_count = 0 + self.batch_end_call_count = 0 - def on_epoch_end(self, epoch, logs=None): - self.epoch_end_logs = logs - self.epoch_end_call_count += 1 + def on_epoch_end(self, epoch, logs=None): + self.epoch_end_logs = logs + self.epoch_end_call_count += 1 - def on_batch_end(self, batch, logs=None): - self.batch_end_logs = logs - self.batch_end_call_count += 1 + def on_batch_end(self, batch, logs=None): + self.batch_end_logs = logs + self.batch_end_call_count += 1 - model = testing_utils.get_small_sequential_mlp( - num_hidden=10, num_classes=num_classes, input_dim=input_dim) - model.compile( - loss='binary_crossentropy', - metrics=['acc'], - weighted_metrics=['mae'], - optimizer=RMSPropOptimizer(learning_rate=0.01)) + model = testing_utils.get_small_sequential_mlp( + num_hidden=10, num_classes=num_classes, input_dim=input_dim) + model.compile( + loss='binary_crossentropy', + metrics=['acc'], + weighted_metrics=['mae'], + optimizer=RMSPropOptimizer(learning_rate=0.01), + run_eagerly=testing_utils.should_run_eagerly()) - np.random.seed(1337) - (x_train, y_train), (_, _) = testing_utils.get_test_data( - train_samples=10, - test_samples=10, - input_shape=(input_dim,), - num_classes=num_classes) + np.random.seed(1337) + (x_train, y_train), (_, _) = testing_utils.get_test_data( + train_samples=10, + test_samples=10, + input_shape=(input_dim,), + num_classes=num_classes) - test_callback = TestCallback() - model.fit( - x_train, - y_train, - batch_size=2, - epochs=2, - verbose=0, - callbacks=[test_callback], - validation_data=(x_train, y_train)) - self.assertEqual(test_callback.batch_end_call_count, 10) - self.assertEqual(test_callback.epoch_end_call_count, 2) + test_callback = TestCallback() + model.fit( + x_train, + y_train, + batch_size=2, + epochs=2, + verbose=0, + callbacks=[test_callback], + validation_data=(x_train, y_train)) + self.assertEqual(test_callback.batch_end_call_count, 10) + self.assertEqual(test_callback.epoch_end_call_count, 2) - weighted_metric = ('mae' - if tf2.enabled() else 'weighted_mean_absolute_error') - self.assertSetEqual( - set(test_callback.batch_end_logs.keys()), - set(['batch', 'size', 'acc', 'loss', weighted_metric])) - self.assertSetEqual( - set(test_callback.epoch_end_logs.keys()), - set([ - 'acc', 'loss', weighted_metric, 'val_acc', 'val_loss', - 'val_' + weighted_metric - ])) + self.assertSetEqual( + set(test_callback.batch_end_logs.keys()), set(['acc', 'loss', 'mae'])) + self.assertSetEqual( + set(test_callback.epoch_end_logs.keys()), + set(['acc', 'loss', 'mae', 'val_acc', 'val_loss', 'val_mae'])) @keras_parameterized.run_all_keras_modes def test_mismatched_output_shape_and_target_shape(self): @@ -1160,8 +869,8 @@ class TrainingTest(keras_parameterized.TestCase): run_eagerly=testing_utils.should_run_eagerly(), experimental_run_tf_function=testing_utils.should_run_tf_function()) # Test with Numpy data - x_train = np.random.random((10, 3, 4)) - y_train = np.random.randint(0, 5, size=(10, 3)) + x_train = np.random.random((10, 3, 4)).astype(np.float32) + y_train = np.random.randint(0, 5, size=(10, 3)).astype(np.float32) model.fit(x_train, y_train, batch_size=5, epochs=1) # Test with iterator @@ -1238,6 +947,8 @@ class TrainingTest(keras_parameterized.TestCase): @tf_test_util.run_in_graph_and_eager_modes def test_static_batch_in_input_layer(self): + if context.executing_eagerly(): + self.skipTest('Not inferred in eager.') class Counter(keras.callbacks.Callback): @@ -1268,6 +979,8 @@ class TrainingTest(keras_parameterized.TestCase): @tf_test_util.run_in_graph_and_eager_modes def test_static_batch_in_input_layer_consistency_checks(self): + if context.executing_eagerly(): + self.skipTest('Not inferred in eager.') x, y = np.ones((64, 10), 'float32'), np.ones((64, 1), 'float32') inputs = keras.Input(batch_size=2, shape=(10,)) @@ -1408,6 +1121,8 @@ class TrainingTest(keras_parameterized.TestCase): @keras_parameterized.run_with_all_model_types @keras_parameterized.run_all_keras_modes def test_validation_steps_without_data(self): + if context.executing_eagerly(): + self.skipTest('Check removed in new `fit`') x, y = np.ones((10, 10)), np.ones((10, 1)) model = testing_utils.get_small_mlp(2, 1, 10) model.compile( @@ -1484,9 +1199,6 @@ class TrainingTest(keras_parameterized.TestCase): dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(2) model.fit(dataset) self.assertEqual(model._compute_dtype, 'float32') - # Input dtype should match the model dtype, even if the inputs passed to the - # model have a different dtype. - self.assertEqual(model.inputs[0].dtype, 'float32') @keras_parameterized.run_all_keras_modes(always_skip_v1=True) def test_subclassed_model_with_training_arg(self): @@ -1546,62 +1258,6 @@ class TrainingTest(keras_parameterized.TestCase): class TestExceptionsAndWarnings(keras_parameterized.TestCase): - @keras_parameterized.run_with_all_model_types - @keras_parameterized.run_all_keras_modes - def test_invalid_batch_dimension(self): - - def custom_reshape(inputs): - return keras.backend.reshape(inputs, (-1, 8, 8, 3)) - - layer_1 = keras.layers.Lambda(custom_reshape) - layer_2 = keras.layers.Conv2D(32, (3, 3)) - - model = testing_utils.get_model_from_layers([layer_1, layer_2], - input_shape=(8, 8, 6)) - model.compile('sgd', loss='mse') - - with self.assertRaisesRegex( - ValueError, - 'Mismatch between expected batch size and model output batch size. ' - r'Output shape = \(20, 6, 6, 32\), expected output shape = ' - r'shape \(10, 6, 6, 32\)'): - model.predict(np.ones((10, 8, 8, 6)), batch_size=10) - - @keras_parameterized.run_all_keras_modes - def test_invalid_loss(self): - num_classes = 5 - train_samples = 1000 - test_samples = 1000 - input_dim = 5 - - model = testing_utils.get_small_sequential_mlp( - num_hidden=10, num_classes=num_classes, input_dim=input_dim) - optimizer = RMSPropOptimizer(learning_rate=0.001) - model.compile(optimizer, loss='categorical_crossentropy') - np.random.seed(1337) - (x_train, y_train), (_, _) = testing_utils.get_test_data( - train_samples=train_samples, - test_samples=test_samples, - input_shape=(input_dim,), - num_classes=num_classes) - - with self.assertRaisesRegexp( - ValueError, - 'Input arrays should have the same number of samples as target arrays'): - model.fit(x_train, np.concatenate([y_train, y_train], axis=-1)) - - with self.assertRaisesRegexp(ValueError, - 'expects targets to be binary matrices'): - model.fit(x_train, y_train) - - with self.assertRaisesRegexp(ValueError, 'no loss to optimize'): - model.compile( - optimizer, - loss=None, - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) - model.fit(x_train) - @keras_parameterized.run_all_keras_modes def test_compile_warning_for_loss_missing_output(self): with self.cached_session(): @@ -1611,98 +1267,17 @@ class TestExceptionsAndWarnings(keras_parameterized.TestCase): model = keras.models.Model(inputs=[inp], outputs=[out_1, out_2]) optimizer = RMSPropOptimizer(learning_rate=0.001) - with test.mock.patch.object(logging, 'warning') as mock_log: - model.compile( - optimizer, - loss={ - 'dense_2': 'categorical_crossentropy', - }, - metrics={ - 'dense_2': 'categorical_accuracy', - 'dense_1': metrics_module.CategoricalAccuracy(), - }, - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) - msg = ('Output dense_1 missing from loss dictionary. We assume this ' - 'was done on purpose. The fit and evaluate APIs will not be ' - 'expecting any data to be passed to dense_1.') - self.assertRegexpMatches(str(mock_log.call_args), msg) - - @keras_parameterized.run_all_keras_modes - def test_invalid_steps_per_epoch_usage(self): - x = keras.layers.Input(shape=(1,)) - y = keras.layers.Dense(1)(x) - - model = keras.Model(x, y) - model.compile( - 'sgd', - loss='mse', - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=False) - err_msg = 'When passing input data as arrays, do not specify' - - with test.mock.patch.object(logging, 'warning') as mock_log: - model._standardize_user_data( - np.zeros((100, 1)), np.ones((100, 1)), check_steps=True, steps=4) - self.assertRegexpMatches(str(mock_log.call_args), err_msg) - - @keras_parameterized.run_with_all_model_types - @keras_parameterized.run_all_keras_modes - def test_invalid_batch_size_argument_with_sequence_input(self): - - class DummySequence(data_utils.Sequence): - - def __getitem__(self, idx): - return np.zeros([10, 2]), np.ones([10, 4]) - - def __len__(self): - return 10 - - model = testing_utils.get_small_mlp( - num_hidden=10, num_classes=1, input_dim=10) - - model.compile( - 'adam', - 'binary_crossentropy', - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) - - with self.assertRaisesRegexp( - ValueError, 'The `batch_size` argument must not be specified'): - model.fit(DummySequence(), batch_size=2, epochs=2) - with self.assertRaisesRegexp( - ValueError, 'The `batch_size` argument must not be specified'): - model.evaluate(DummySequence(), batch_size=2) - - with self.assertRaisesRegexp( - ValueError, 'The `batch_size` argument must not be specified'): - model.predict(DummySequence(), batch_size=2) - - @keras_parameterized.run_with_all_model_types - @keras_parameterized.run_all_keras_modes(always_skip_v1=True) - def test_non_returning_sequence(self): - if not testing_utils.should_run_tf_function(): - self.skipTest('This case is only handled in the new execution path.') - - class DummySequence(data_utils.Sequence): - - def __getitem__(self, idx): - return - - def __len__(self): - return 10 - - model = testing_utils.get_small_mlp( - num_hidden=10, num_classes=1, input_dim=10) - - model.compile( - 'adam', - 'binary_crossentropy', - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) - - with self.assertRaisesRegexp(IndexError, 'Could not infer batch size'): - model.fit(DummySequence(), epochs=2) + model.compile( + optimizer, + loss={ + 'dense_2': 'categorical_crossentropy', + }, + metrics={ + 'dense_2': 'categorical_accuracy', + 'dense_1': metrics_module.CategoricalAccuracy(), + }, + run_eagerly=testing_utils.should_run_eagerly(), + experimental_run_tf_function=testing_utils.should_run_tf_function()) @keras_parameterized.run_with_all_model_types @keras_parameterized.run_all_keras_modes @@ -1972,100 +1547,11 @@ class LossWeightingTest(keras_parameterized.TestCase): x = np.random.random((10, 3)) y = np.random.random((10, 2)) - with self.assertRaisesRegexp( - ValueError, - r'Unknown entries in sample_weight dictionary: \[\'unknown\'\]. ' - r'Only expected following keys: \[\'output_1\', \'output_2\'\]'): - model.fit([x, x], [y, y], - epochs=1, - sample_weight={'unknown': 'something'}) + with self.assertRaises(ValueError): + model.fit([x, x], [y, y], epochs=1, sample_weight={'unknown': x}) - with self.assertRaisesRegexp( - ValueError, - r'Unknown entries in class_weight dictionary: \[\'unknown\'\]. ' - r'Only expected following keys: \[\'output_1\', \'output_2\'\]'): - model.fit([x, x], [y, y], epochs=1, class_weight={'unknown': 'something'}) - - @keras_parameterized.run_all_keras_modes - def test_class_weight_invalid_use_case(self): - num_classes = 5 - train_samples = 1000 - test_samples = 1000 - input_dim = 5 - timesteps = 3 - learning_rate = 0.001 - - with self.cached_session(): - model = keras.models.Sequential() - model.add( - keras.layers.TimeDistributed( - keras.layers.Dense(num_classes), - input_shape=(timesteps, input_dim))) - model.add(keras.layers.Activation('softmax')) - optimizer = RMSPropOptimizer(learning_rate=learning_rate) - model.compile( - optimizer, - loss='binary_crossentropy', - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) - - (x_train, y_train), _ = testing_utils.get_test_data( - train_samples=train_samples, - test_samples=test_samples, - input_shape=(input_dim,), - num_classes=num_classes) - # convert class vectors to binary class matrices - y_train = np_utils.to_categorical(y_train, num_classes) - class_weight = dict([(i, 1.) for i in range(num_classes)]) - - del class_weight[1] - with self.assertRaises(ValueError): - model.fit(x_train, y_train, - epochs=0, verbose=0, class_weight=class_weight) - - with self.assertRaises(ValueError): - model.compile( - optimizer, - loss='binary_crossentropy', - sample_weight_mode=[], - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) - - # Build multi-output model - x = keras.Input((3,)) - y1 = keras.layers.Dense(4, name='1')(x) - y2 = keras.layers.Dense(4, name='2')(x) - model = keras.models.Model(x, [y1, y2]) - model.compile( - optimizer, - loss='mse', - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) - x_np = np.random.random((10, 3)) - y_np = np.random.random((10, 4)) - w_np = np.random.random((10,)) - # This will work - model.fit(x_np, [y_np, y_np], epochs=1, - sample_weight={'1': w_np}) - # These will not - with self.assertRaises(ValueError): - model.fit(x_np, [y_np, y_np], epochs=1, - sample_weight=[w_np]) - with self.assertRaises(TypeError): - model.fit(x_np, [y_np, y_np], epochs=1, - sample_weight=w_np) - with self.assertRaises(ValueError): - bad_w_np = np.random.random((11,)) - model.fit(x_np, [y_np, y_np], epochs=1, - sample_weight={'1': bad_w_np}) - with self.assertRaises(ValueError): - bad_w_np = np.random.random((10, 2)) - model.fit(x_np, [y_np, y_np], epochs=1, - sample_weight={'1': bad_w_np}) - with self.assertRaises(ValueError): - bad_w_np = np.random.random((10, 2, 2)) - model.fit(x_np, [y_np, y_np], epochs=1, - sample_weight={'1': bad_w_np}) + with self.assertRaises(ValueError): + model.fit([x, x], [y, y], epochs=1, class_weight={'unknown': 1}) @keras_parameterized.run_all_keras_modes def test_default_sample_weight(self): @@ -2169,39 +1655,6 @@ class LossWeightingTest(keras_parameterized.TestCase): self.assertAllClose( (2+ .4 + .3 + 1) / 4, sess.run(model.total_loss, feed_dict=feeds)) - def test_prepare_sample_weights(self): - # pylint:disable=anomalous-backslash-in-string - input_layer = keras.layers.Input(shape=1, name='input_layer') - model = keras.Model(inputs=input_layer, outputs=[input_layer, input_layer]) - sample_weights = array_ops.constant([0, .4, 1, 1]) - temporal_weights = array_ops.constant([[1, 2], [3, 4], [5, 6]]) - - model.compile( - loss='mean_absolute_error', - optimizer='adam', - sample_weight_mode=None) - - with self.assertRaises(AssertionError): - model._prepare_sample_weights([sample_weights, sample_weights]) - - model.compile(loss='mean_absolute_error', optimizer='adam', - sample_weight_mode='temporal') - model._prepare_sample_weights([temporal_weights, temporal_weights]) - with self.assertRaisesRegexp(ValueError, 'Expected shape \[None, None\]'): - model._prepare_sample_weights([sample_weights, sample_weights]) - - with self.assertRaisesRegexp(ValueError, - 'sample weights must have same length as the ' - 'number of outputs'): - model._prepare_sample_weights([temporal_weights]) - - model.compile(loss='mean_absolute_error', optimizer='adam', - sample_weight_mode='samplewise') - model._prepare_sample_weights([sample_weights, sample_weights]) - with self.assertRaisesRegexp(ValueError, 'Expected shape \[None\]'): - model._prepare_sample_weights([temporal_weights, temporal_weights]) - # pylint:enable=anomalous-backslash-in-string - @keras_parameterized.run_all_keras_modes class MaskingTest(keras_parameterized.TestCase): @@ -2524,100 +1977,90 @@ class TestTrainingWithDataTensors(keras_parameterized.TestCase): validation_data=(inputs, targets), validation_steps=2) def test_training_and_eval_methods_on_symbolic_tensors_multi_io(self): - with ops.Graph().as_default(): - a = keras.layers.Input(shape=(3,), name='input_a') - b = keras.layers.Input(shape=(3,), name='input_b') + a = keras.layers.Input(shape=(3,), name='input_a') + b = keras.layers.Input(shape=(3,), name='input_b') - dense = keras.layers.Dense(4, name='dense') - c = dense(a) - d = dense(b) - e = keras.layers.Dropout(0.5, name='dropout')(c) + dense = keras.layers.Dense(4, name='dense') + c = dense(a) + d = dense(b) + e = keras.layers.Dropout(0.5, name='dropout')(c) - model = keras.models.Model([a, b], [d, e]) + model = keras.models.Model([a, b], [d, e]) - optimizer = 'rmsprop' - loss = 'mse' - loss_weights = [1., 0.5] - model.compile( - optimizer, - loss, - metrics=['mae', metrics_module.CategoricalAccuracy()], - loss_weights=loss_weights) + optimizer = 'rmsprop' + loss = 'mse' + loss_weights = [1., 0.5] + model.compile( + optimizer, + loss, + metrics=['mae', metrics_module.CategoricalAccuracy()], + loss_weights=loss_weights) - input_a_tf = keras.backend.zeros(shape=(10, 3)) - input_b_tf = keras.backend.zeros(shape=(10, 3)) + input_a_tf = array_ops.zeros(shape=(10, 3)) + input_b_tf = array_ops.zeros(shape=(10, 3)) - output_d_tf = keras.backend.zeros(shape=(10, 4)) - output_e_tf = keras.backend.zeros(shape=(10, 4)) + output_d_tf = array_ops.zeros(shape=(10, 4)) + output_e_tf = array_ops.zeros(shape=(10, 4)) - model.fit( - [input_a_tf, input_b_tf], [output_d_tf, output_e_tf], - epochs=1, - steps_per_epoch=2, - verbose=0) - with self.assertRaisesRegexp(ValueError, - 'should specify the `steps_per_epoch`'): - model.fit( - [input_a_tf, input_b_tf], [output_d_tf, output_e_tf], - epochs=1, - batch_size=5, - verbose=0) - model.train_on_batch([input_a_tf, input_b_tf], [output_d_tf, output_e_tf]) + model.fit([input_a_tf, input_b_tf], [output_d_tf, output_e_tf], + epochs=1, + steps_per_epoch=2, + verbose=0) + model.train_on_batch([input_a_tf, input_b_tf], [output_d_tf, output_e_tf]) - # Test with dictionary inputs - model.fit( - {'input_a': input_a_tf, - 'input_b': input_b_tf}, - {'dense': output_d_tf, - 'dropout': output_e_tf}, - epochs=1, - steps_per_epoch=2, - verbose=0) - model.fit( - {'input_a': input_a_tf, - 'input_b': input_b_tf}, - {'dense': output_d_tf, - 'dropout': output_e_tf}, - validation_data=({'input_a': input_a_tf, - 'input_b': input_b_tf}, - {'dense': output_d_tf, - 'dropout': output_e_tf}), - epochs=1, - steps_per_epoch=2, - validation_steps=2, - verbose=0) - model.train_on_batch( - {'input_a': input_a_tf, - 'input_b': input_b_tf}, - {'dense': output_d_tf, - 'dropout': output_e_tf}) + # Test with dictionary inputs + model.fit({ + 'input_a': input_a_tf, + 'input_b': input_b_tf + }, { + 'dense': output_d_tf, + 'dropout': output_e_tf + }, + epochs=1, + steps_per_epoch=2, + verbose=0) + model.fit({ + 'input_a': input_a_tf, + 'input_b': input_b_tf + }, { + 'dense': output_d_tf, + 'dropout': output_e_tf + }, + validation_data=({ + 'input_a': input_a_tf, + 'input_b': input_b_tf + }, { + 'dense': output_d_tf, + 'dropout': output_e_tf + }), + epochs=1, + steps_per_epoch=2, + validation_steps=2, + verbose=0) + model.train_on_batch({ + 'input_a': input_a_tf, + 'input_b': input_b_tf + }, { + 'dense': output_d_tf, + 'dropout': output_e_tf + }) - # Test with validation data - model.fit( - [input_a_tf, input_b_tf], [output_d_tf, output_e_tf], - validation_data=([input_a_tf, input_b_tf], - [output_d_tf, output_e_tf]), - epochs=1, - steps_per_epoch=2, - validation_steps=2, - verbose=0) - # Test with validation split - with self.assertRaisesRegexp(ValueError, - 'you cannot use `validation_split`'): - model.fit( - [input_a_tf, input_b_tf], [output_d_tf, output_e_tf], - epochs=2, - steps_per_epoch=2, - verbose=0, - validation_split=0.2, - validation_steps=2) - - # Test evaluation / prediction methods - model.evaluate([input_a_tf, input_b_tf], [output_d_tf, output_e_tf], - steps=2, verbose=0) - model.predict([input_a_tf, input_b_tf], steps=2) - model.test_on_batch([input_a_tf, input_b_tf], [output_d_tf, output_e_tf]) + # Test with validation data + model.fit([input_a_tf, input_b_tf], [output_d_tf, output_e_tf], + validation_data=([input_a_tf, + input_b_tf], [output_d_tf, output_e_tf]), + epochs=1, + steps_per_epoch=2, + validation_steps=2, + verbose=0) + # Test evaluation / prediction methods + model.evaluate([input_a_tf, input_b_tf], [output_d_tf, output_e_tf], + steps=2, + verbose=0) + model.predict([input_a_tf, input_b_tf], steps=2) + model.test_on_batch([input_a_tf, input_b_tf], [output_d_tf, output_e_tf]) + @tf_test_util.run_deprecated_v1 def test_model_with_input_feed_tensor(self): """We test building a model with a TF variable as input. @@ -2862,31 +2305,6 @@ class TestTrainingWithDataTensors(keras_parameterized.TestCase): out = model.test_on_batch(None, None) out = model.predict_on_batch(None) - # test fit - with self.assertRaises(ValueError): - out = model.fit(None, None, epochs=1, batch_size=10) - out = model.fit(None, None, epochs=1, steps_per_epoch=1) - - # test fit with validation data - with self.assertRaises(ValueError): - out = model.fit(None, None, epochs=1, - steps_per_epoch=None, - validation_steps=2) - out = model.fit(None, None, epochs=1, - steps_per_epoch=2, - validation_steps=2) - - # test evaluate - with self.assertRaises(ValueError): - out = model.evaluate(None, None, batch_size=10) - out = model.evaluate(None, None, steps=3) - - # test predict - with self.assertRaises(ValueError): - out = model.predict(None, batch_size=10) - out = model.predict(None, steps=3) - self.assertEqual(out.shape, (10 * 3, 4)) - # Test multi-output model with no external data at all. self.evaluate(variables_lib.variables_initializer([input_v])) a = keras.Input(tensor=input_v) @@ -2904,19 +2322,6 @@ class TestTrainingWithDataTensors(keras_parameterized.TestCase): out = model.test_on_batch(None, None) out = model.predict_on_batch(None) - # test fit - with self.assertRaises(ValueError): - out = model.fit(None, None, epochs=1, batch_size=10) - out = model.fit(None, None, epochs=1, steps_per_epoch=1) - - # test evaluate - with self.assertRaises(ValueError): - out = model.evaluate(None, None, batch_size=10) - out = model.evaluate(None, None, steps=3) - - # test predict - with self.assertRaises(ValueError): - out = model.predict(None, batch_size=10, verbose=1) out = model.predict(None, steps=3) self.assertEqual(len(out), 2) self.assertEqual(out[0].shape, (10 * 3, 4)) @@ -3074,15 +2479,13 @@ class TestTrainingWithMetrics(keras_parameterized.TestCase): run_eagerly=testing_utils.should_run_eagerly(), experimental_run_tf_function=testing_utils.should_run_tf_function()) - mse_metric = 'mse' if tf2.enabled() else 'mean_squared_error' + mse_metric = 'mse' if context.executing_eagerly() else 'mean_squared_error' reference_metric_names = [ 'loss', 'dense_loss', 'dropout_loss', 'dense_' + mse_metric, 'dense_binary_accuracy', 'dropout_' + mse_metric, 'dropout_binary_accuracy' ] - self.assertEqual(reference_metric_names, model.metrics_names) - # Verify that model metric names are not altered during training. input_a_np = np.random.random((10, 3)) input_b_np = np.random.random((10, 3)) @@ -3181,63 +2584,6 @@ class TestTrainingWithMetrics(keras_parameterized.TestCase): run_eagerly=testing_utils.should_run_eagerly(), experimental_run_tf_function=testing_utils.should_run_tf_function()) - @keras_parameterized.run_all_keras_modes - def test_invalid_metrics(self): - num_classes = 5 - input_dim = 5 - - model = testing_utils.get_small_sequential_mlp( - num_hidden=10, num_classes=num_classes, input_dim=input_dim) - - with self.assertRaisesRegexp( - TypeError, 'Type of `metrics` argument not understood. ' - 'Expected a list or dictionary, found: '): - model.compile( - RMSPropOptimizer(learning_rate=0.001), - loss='categorical_crossentropy', - metrics=metrics_module.CategoricalAccuracy(), - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) - - inp = keras.layers.Input(shape=(1,)) - x = keras.layers.Dense(3, activation='relu')(inp) - out_1 = keras.layers.Dense(1, activation='sigmoid', name='output_1')(x) - out_2 = keras.layers.Dense(1, activation='sigmoid', name='output_2')(x) - model = keras.models.Model(inp, [out_1, out_2]) - with self.assertRaisesRegex( - ValueError, 'When passing a list of lists as `metrics`, ' - 'it should have one entry per model output. ' - 'The model has 2 outputs, but you passed metrics='): - model.compile('rmsprop', loss='mse', metrics=[['mse']]) - - with self.assertRaisesRegex( - ValueError, - r'Unknown entries in metrics dictionary: \[\'output_3\'\]. Only ' - r'expected following keys: \[\'output_1\', \'output_2\'\]'): - model.compile( - optimizer='rmsprop', - loss='mse', - metrics={ - 'output_1': 'mse', - 'output_3': 'mse', - }, - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) - - with self.assertRaisesRegex( - ValueError, - r'Unknown entries in metrics dictionary: \[\'output_3\'\]. Only ' - r'expected following keys: \[\'output_1\', \'output_2\'\]'): - model.compile( - optimizer='rmsprop', - loss='mse', - weighted_metrics={ - 'output_1': 'mse', - 'output_3': 'mse', - }, - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) - @keras_parameterized.run_all_keras_modes def test_metrics_masking(self): np.random.seed(1337) @@ -3382,7 +2728,7 @@ class TestTrainingWithMetrics(keras_parameterized.TestCase): self.assertEqual(history.history['metric_1'][-1], 5) self.assertAlmostEqual(history.history['val_metric_1'][-1], 5, 0) - @keras_parameterized.run_all_keras_modes + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) def test_model_metrics_list(self): class LayerWithAddMetric(keras.layers.Layer): @@ -3435,13 +2781,14 @@ class TestTrainingWithMetrics(keras_parameterized.TestCase): run_eagerly=testing_utils.should_run_eagerly(), experimental_run_tf_function=testing_utils.should_run_tf_function()) + model.fit(np.ones((10, 1)), np.ones((10, 1)), batch_size=10) + # Verify that the metrics added using `compile` and `add_metric` API are # included - self.assertEqual([m.name for m in model._compile_metrics], ['metric_4']) self.assertEqual([m.name for m in model.metrics], - ['metric_4', 'metric_2', 'metric_1', 'metric_3']) + ['loss', 'metric_4', 'metric_2', 'metric_1', 'metric_3']) - @keras_parameterized.run_all_keras_modes + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) def test_model_metrics_list_in_call(self): class TestModel(keras.Model): @@ -3466,8 +2813,8 @@ class TestTrainingWithMetrics(keras_parameterized.TestCase): y = np.ones(shape=(10, 2)) model.fit(x, y, epochs=2, batch_size=5, validation_data=(x, y)) - self.assertEqual([m.name for m in model._compile_metrics], ['acc']) - self.assertEqual([m.name for m in model.metrics], ['acc', 'metric_1']) + self.assertEqual([m.name for m in model.metrics], + ['loss', 'acc', 'metric_1']) @keras_parameterized.run_all_keras_modes def test_multiple_add_metric_calls(self): @@ -3508,36 +2855,6 @@ class TestTrainingWithMetrics(keras_parameterized.TestCase): model.train_on_batch(x, y) model.test_on_batch(x, y) - @keras_parameterized.run_with_all_model_types - @keras_parameterized.run_all_keras_modes - def test_invalid_metric_tensor(self): - - class TestLayer(keras.layers.Layer): - - def build(self, input_shape): - self.built = True - - def call(self, inputs): - self.add_metric(math_ops.reduce_mean(inputs), name='metric_1') - return inputs + 1 - - layers = [TestLayer(input_shape=(1,))] - layers.append(keras.layers.Dense(2, kernel_initializer='ones')) - x = np.ones(shape=(10, 1)) - y = np.ones(shape=(10, 2)) - - with self.assertRaisesRegexp( - ValueError, - 'We do not support adding an aggregated metric result tensor that is ' - 'not the output of a `tf.keras.metrics.Metric` metric instance.'): - model = testing_utils.get_model_from_layers(layers, input_shape=(1,)) - model.compile( - loss='mse', - optimizer=RMSPropOptimizer(0.01), - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) - model.fit(x, y, epochs=2, batch_size=5, validation_data=(x, y)) - @keras_parameterized.run_all_keras_modes def test_duplicate_metric_name_in_add_metric(self): @@ -3677,7 +2994,7 @@ class TestTrainingWithMetrics(keras_parameterized.TestCase): 'one': [1.0, 1.0, 1.0] }) - @keras_parameterized.run_all_keras_modes + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) def test_model_with_nested_compiled_model(self): class LayerWithAddMetric(keras.layers.Layer): @@ -3705,9 +3022,10 @@ class TestTrainingWithMetrics(keras_parameterized.TestCase): metrics=[metrics_module.Accuracy('acc')], run_eagerly=testing_utils.should_run_eagerly(), experimental_run_tf_function=testing_utils.should_run_tf_function()) + inner_model.fit(np.ones((10, 1)), np.ones((10, 1)), batch_size=10) self.assertEqual([m.name for m in inner_model.metrics], - ['acc', 'mean', 'mean1']) + ['loss', 'acc', 'mean', 'mean1']) x = keras.layers.Input(shape=[1]) y = inner_model(x) @@ -3721,8 +3039,9 @@ class TestTrainingWithMetrics(keras_parameterized.TestCase): metrics=[metrics_module.Accuracy('acc2')], run_eagerly=testing_utils.should_run_eagerly(), experimental_run_tf_function=testing_utils.should_run_tf_function()) + outer_model.fit(np.ones((10, 1)), np.ones((10, 1)), batch_size=10) self.assertEqual([m.name for m in outer_model.metrics], - ['acc2', 'mean', 'mean1', 'mean2']) + ['loss', 'acc2', 'mean', 'mean1', 'mean2']) class BareUpdateLayer(keras.layers.Layer): diff --git a/tensorflow/python/keras/engine/training_v1.py b/tensorflow/python/keras/engine/training_v1.py index 67840a505e9..9261ab30889 100644 --- a/tensorflow/python/keras/engine/training_v1.py +++ b/tensorflow/python/keras/engine/training_v1.py @@ -49,8 +49,6 @@ from tensorflow.python.keras.engine import training_distributed from tensorflow.python.keras.engine import training_eager from tensorflow.python.keras.engine import training_generator from tensorflow.python.keras.engine import training_utils -from tensorflow.python.keras.engine import training_v2 -from tensorflow.python.keras.engine import training_v2_utils from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.keras.saving.saved_model import model_serialization @@ -162,6 +160,8 @@ class Model(training_lib.Model): self._experimental_run_tf_function = ( ops.executing_eagerly_outside_functions()) + self._v1_compile_was_called = False + @trackable.no_automatic_dependency_tracking def _set_strategy(self, strategy): self._compile_time_distribution_strategy = strategy @@ -301,6 +301,7 @@ class Model(training_lib.Model): self._run_eagerly = kwargs.pop('run_eagerly', None) self._experimental_run_tf_function = kwargs.pop( 'experimental_run_tf_function', True) + self._v1_compile_was_called = True # Prepare Session arguments (legacy). kwargs.pop('cloning', None) # Legacy DistStrat argument, never used. @@ -561,14 +562,6 @@ class Model(training_lib.Model): 'original `Dataset` object instead of passing in ' '`iter(dataset)`.') - # Experiment training loop with default DS path. - if context.executing_eagerly() and self._experimental_run_tf_function: - if self._in_multi_worker_mode(): - return training_distributed.DistributionMultiWorkerTrainingLoop( - training_v2.Loop()) - else: - return training_v2.Loop() - # Case 1: distribution strategy. if self._distribution_strategy: if self._in_multi_worker_mode(): @@ -1031,18 +1024,6 @@ class Model(training_lib.Model): """ self._assert_compile_was_called() self._check_call_args('train_on_batch') - if self._experimental_run_tf_function: - outputs = training_v2_utils.train_on_batch( - self, x, y=y, sample_weight=sample_weight, - class_weight=class_weight, reset_metrics=reset_metrics, - standalone=True) - outputs = (outputs['total_loss'] + outputs['output_losses'] + - outputs['metrics']) - outputs = [ - training_v2_utils._non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access - if len(outputs) == 1: - outputs = outputs[0] - return outputs # If at this point we are in the replica context, then it is okay to execute # the Eager code path. The expected way to get here is to call `fit` that @@ -1069,8 +1050,7 @@ class Model(training_lib.Model): output_loss_metrics=self._output_loss_metrics) outputs = (output_dict['total_loss'] + output_dict['output_losses'] + output_dict['metrics']) - outputs = [ - training_v2_utils._non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access + outputs = [_non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access else: x = training_utils.ModelInputs(x).as_list() ins = x + list(y or []) + list(sample_weights or []) @@ -1129,17 +1109,6 @@ class Model(training_lib.Model): """ self._assert_compile_was_called() self._check_call_args('test_on_batch') - if self._experimental_run_tf_function: - outputs = training_v2_utils.test_on_batch( - self, x, y=y, sample_weight=sample_weight, - reset_metrics=reset_metrics, standalone=True) - outputs = (outputs['total_loss'] + outputs['output_losses'] + - outputs['metrics']) - outputs = [ - training_v2_utils._non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access - if len(outputs) == 1: - outputs = outputs[0] - return outputs if (self._distribution_strategy and distribution_strategy_context.in_cross_replica_context()): @@ -1160,8 +1129,7 @@ class Model(training_lib.Model): output_loss_metrics=self._output_loss_metrics) outputs = (output_dict['total_loss'] + output_dict['output_losses'] + output_dict['metrics']) - outputs = [ - training_v2_utils._non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access + outputs = [_non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access else: x = training_utils.ModelInputs(x).as_list() inputs = x + list(y or []) + list(sample_weights or []) @@ -1196,8 +1164,6 @@ class Model(training_lib.Model): expectations of the model. """ self._check_call_args('predict_on_batch') - if self._experimental_run_tf_function: - return training_v2_utils.predict_on_batch(self, x, standalone=True) if (self._distribution_strategy and distribution_strategy_context.in_cross_replica_context()): @@ -2601,6 +2567,7 @@ class Model(training_lib.Model): ValueError: If dict inputs are passed to a Sequential Model where the first layer isn't FeatureLayer. """ + self._set_save_spec(inputs) inputs = self._set_input_attrs(inputs) if outputs is None: @@ -2760,7 +2727,7 @@ class Model(training_lib.Model): training setting, return the epoch the training is supposed to continue at. Otherwise, return the `initial_epoch` the user passes in. """ - if hasattr(self, '_training_state'): + if self._training_state is not None: return self._training_state.maybe_load_initial_epoch_from_ckpt( initial_epoch, mode) return initial_epoch @@ -2781,7 +2748,7 @@ class Model(training_lib.Model): # then the optimizer is set. This is different from whether the # model is compiled # (i.e. whether the model is built and its inputs/outputs are set). - if not self.optimizer: + if not self._compile_was_called: raise RuntimeError('You must compile your model before ' 'training/testing. ' 'Use `model.compile(optimizer, loss)`.') @@ -2821,6 +2788,21 @@ class Model(training_lib.Model): def _trackable_saved_model_saver(self): return model_serialization.ModelSavedModelSaver(self) + def _get_compile_args(self): + self._assert_compile_was_called() + kwargs = { + 'loss': self.loss, + 'metrics': self._compile_metrics, + 'loss_weights': self.loss_weights, + 'sample_weight_mode': self.sample_weight_mode, + 'weighted_metrics': self._compile_weighted_metrics, + } + return kwargs + + @property + def _compile_was_called(self): + return self._v1_compile_was_called + class DistributedCallbackModel(Model): """Model that is used for callbacks with tf.distribute.Strategy.""" @@ -3189,3 +3171,8 @@ def _get_metrics_from_layers(layers): else: metrics.extend(layer.metrics) return metrics + + +def _non_none_constant_value(v): + constant_value = tensor_util.constant_value(v) + return constant_value if constant_value is not None else v diff --git a/tensorflow/python/keras/engine/training_v2.py b/tensorflow/python/keras/engine/training_v2.py deleted file mode 100644 index e994a8cd187..00000000000 --- a/tensorflow/python/keras/engine/training_v2.py +++ /dev/null @@ -1,778 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Training related logic for Keras model in TF 2.0 context. - -Note that all the code under this module is under active development, please DO -NOT use it unless you are really sure what you are doing. -""" - -# pylint: disable=protected-access -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools - -import numpy as np - -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.distribute import distribution_strategy_context as ds_context -from tensorflow.python.framework import errors -from tensorflow.python.keras import callbacks as cbks -from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils -from tensorflow.python.keras.engine import data_adapter -from tensorflow.python.keras.engine import training_utils -from tensorflow.python.keras.engine import training_v2_utils -from tensorflow.python.keras.utils.mode_keys import ModeKeys -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.profiler import traceme -from tensorflow.python.util import nest -from tensorflow.python.util import tf_contextlib - - -# The list of DataAdapter that support validation_split, only numpy and data -# tensor support validation_split for now. -_ADAPTER_FOR_VALIDATION_SPLIT = [data_adapter.TensorLikeDataAdapter, - data_adapter.GenericArrayLikeDataAdapter] - -# The list of DataAdapter that support model._standardize_user_data. Currently -# keras.sequence/python generator will cause error when calling -# model._standardize_user_data, this should be updated in future cl, eg, the -# dataset/generate/sequence input will be peeked and processed by -# model._standardize_user_data() -_ADAPTER_FOR_STANDARDIZE_USER_DATA = [ - data_adapter.TensorLikeDataAdapter, - data_adapter.GenericArrayLikeDataAdapter, - data_adapter.CompositeTensorDataAdapter -] - - -def run_one_epoch(model, - iterator, - execution_function, - dataset_size=None, - batch_size=None, - strategy=None, - steps_per_epoch=None, - num_samples=None, - mode=ModeKeys.TRAIN, - training_context=None, - total_epochs=None): - """Run the execution function with the data from iterator. - - Given the dataset iterator and execution function, get the data from iterator - and call it with the execution function to get the result (metric/loss). - It will run for steps_per_epoch or until to the iterator is fully consumed. - - Args: - model: The keras model to run. - iterator: the dataset iterator to fetch the data. - execution_function: a tf.function that can be called with data. - dataset_size: the size of iterator, None when unknown. - batch_size: The size of the current batch. - strategy: the distribution strategy instance from the model. - steps_per_epoch: the number of steps to run for the epoch. - num_samples: the number of samples for the whole epoch if known. This can be - used to calculate the final partial batch, and scale the loss. - mode: the mode for the current epoch. - training_context: the context that contains callbacks and progress bar. - total_epochs: the total number of epochs that will be run. - Used when throw error when the iterator unexpectedly - reaches its end. - Returns: - The loss and metric value from the model. - """ - # Only use the sample to count if there is a partial batch at the end. - use_steps = num_samples is None - - if mode == ModeKeys.PREDICT: - aggregator = training_utils.OutputsAggregator( - use_steps=use_steps, - steps=steps_per_epoch, - num_samples=num_samples, - batch_size=batch_size) - else: - aggregator = training_utils.MetricsAggregator( - use_steps=use_steps, steps=steps_per_epoch, num_samples=num_samples) - callbacks = training_context.callbacks - progbar = training_context.progbar - - if callbacks.model.stop_training: - return - - target_steps = steps_per_epoch or np.inf - step = 0 - - while step < target_steps: - if use_steps: - current_batch_size = 1 - elif step < target_steps - 1: - current_batch_size = batch_size - else: - current_batch_size = num_samples - step * batch_size - with training_context.on_batch( - step=step, mode=mode, size=current_batch_size) as batch_logs: - try: - batch_outs = execution_function(iterator) - except (StopIteration, errors.OutOfRangeError): - # TODO(kaftan): File bug about tf function and errors.OutOfRangeError? - # Are there any other C++ errors tf function should recapture? - # The only acceptable case here is that the input has a unknown - # length, and configured to fully consume it. - if (dataset_size is None - and steps_per_epoch is None - and step > 0): - # The input passed by the user ran out of batches. - # Now we know the cardinality of the input(dataset or generator). - steps_per_epoch = step - aggregator.steps = steps_per_epoch - if mode == ModeKeys.TRAIN: - progbar.params['steps'] = steps_per_epoch - progbar.progbar.target = steps_per_epoch - else: - callbacks.model.stop_training = True - logging.warning( - 'Your input ran out of data; interrupting training. ' - 'Make sure that your dataset or generator can generate at ' - 'least `steps_per_epoch * epochs` batches (in this case, ' - '{} batches). You may need to use the repeat() function ' - 'when building your dataset.'.format( - total_epochs * steps_per_epoch)) - # In either case, break out the loop for training batch. - # Also note the training_context that data inputs are exhausted, so all - # the post batch hooks can be skipped. - batch_logs['data_exhausted'] = True - break - - if mode != ModeKeys.PREDICT: - data_batch_size = batch_outs['batch_size'] - batch_outs = (batch_outs['total_loss'] + batch_outs['output_losses'] - + batch_outs['metrics']) - if current_batch_size != data_batch_size: - batch_logs['size'] = data_batch_size - current_batch_size = data_batch_size - else: - batch_outs = training_v2_utils._aggregate_predict_results( - strategy, batch_outs, model) - - if step == 0: - aggregator.create(batch_outs) - - if use_steps: - aggregator.aggregate(batch_outs) - else: - aggregator.aggregate( - batch_outs, - batch_start=step * batch_size, - batch_end=step * batch_size + current_batch_size) - cbks.make_logs(model, batch_logs, batch_outs, mode) - step += 1 - - if callbacks.model.stop_training: - break - - # End of an epoch. - aggregator.finalize() - return aggregator.results - - -class Loop(training_utils.TrainingLoop): - """The training loop for the TF 2.0. - - This class has some existing assumption for runtime, eg eager by default, - have distribution strategy, etc. - """ - - def fit( - self, model, x=None, y=None, batch_size=None, epochs=1, verbose=1, - callbacks=None, validation_split=0., validation_data=None, shuffle=True, - class_weight=None, sample_weight=None, initial_epoch=0, - steps_per_epoch=None, validation_steps=None, validation_freq=1, - max_queue_size=10, workers=1, use_multiprocessing=False, **kwargs): - batch_size = model._validate_or_infer_batch_size( - batch_size, steps_per_epoch, x) - - strategy = model.distribute_strategy - batch_size, steps_per_epoch = dist_utils.process_batch_and_step_size( - strategy, - x, - batch_size, - steps_per_epoch, - ModeKeys.TRAIN, - validation_split=validation_split) - dist_utils.validate_callbacks(input_callbacks=callbacks, - optimizer=model.optimizer) - # Enter tf.distribute.Strategy scope. - with strategy.scope(): - training_data_adapter, validation_adapter = _process_training_inputs( - model, - x, - y, - batch_size=batch_size, - epochs=epochs, - sample_weights=sample_weight, - class_weights=class_weight, - validation_split=validation_split, - steps_per_epoch=steps_per_epoch, - shuffle=shuffle, - validation_data=validation_data, - validation_steps=validation_steps, - distribution_strategy=strategy, - max_queue_size=max_queue_size, - workers=workers, - use_multiprocessing=use_multiprocessing) - - total_samples = _get_total_number_of_samples(training_data_adapter) - use_sample = total_samples is not None - do_validation = (validation_adapter is not None) - - recreate_training_iterator = ( - training_data_adapter.should_recreate_iterator()) - if not steps_per_epoch: - # TODO(b/139762795): Add step inference for when steps is None to - # prevent end of sequence warning message. - steps_per_epoch = training_data_adapter.get_size() - - # tf.print('{} on {} steps.'.format(ModeKeys.TRAIN, steps_per_epoch)) - training_context = TrainingContext() - - training_dataset = training_data_adapter.get_dataset() - # Raise an error if steps_per_epoch isn't specified but the dataset - # is infinite. - # TODO(scottzhu): This check should probably happen in the adapter - inferred_steps = training_utils.infer_steps_for_dataset( - model, - training_dataset, - steps_per_epoch, - steps_name='steps_per_epoch', - epochs=0) - - steps_per_epoch = ( - inferred_steps if steps_per_epoch is None else steps_per_epoch) - - training_dataset = strategy.experimental_distribute_dataset( - training_dataset) - - training_function = training_v2_utils._get_or_make_execution_function( - model, ModeKeys.TRAIN) - - training_data_iter = None - if do_validation: - validation_dataset = validation_adapter.get_dataset() - if not validation_steps: - # Raise an error if validation_steps isn't specified but the - # validation dataset is infinite. - validation_steps = ( - validation_adapter.get_size() or - training_utils.infer_steps_for_dataset( - model, - validation_dataset, - validation_steps, - steps_name='validation_steps')) - eval_function = training_v2_utils._get_or_make_execution_function( - model, ModeKeys.TEST) - eval_data_iter = None - validation_dataset = strategy.experimental_distribute_dataset( - validation_dataset) - val_total_samples = _get_total_number_of_samples(validation_adapter) - else: - val_total_samples = None - - if verbose and (total_samples or steps_per_epoch): - _print_train_info(total_samples, steps_per_epoch, val_total_samples, - validation_steps) - - training_callbacks = cbks.configure_callbacks( - callbacks, - model, - do_validation=do_validation, - batch_size=batch_size, - epochs=epochs, - steps_per_epoch=steps_per_epoch, - samples=total_samples or steps_per_epoch, - count_mode='samples' if use_sample else 'steps', - verbose=0, # Handle ProgBarLogger separately in this loop. - mode=ModeKeys.TRAIN) - - with training_context.on_start(model, training_callbacks, use_sample, - verbose, ModeKeys.TRAIN): - - initial_epoch = model._maybe_load_initial_epoch_from_ckpt( - initial_epoch, ModeKeys.TRAIN) - - for epoch in range(initial_epoch, epochs): - if training_context.callbacks.model.stop_training: - break - - # Training - with training_context.on_epoch(epoch, ModeKeys.TRAIN) as epoch_logs: - model.reset_metrics() - if training_data_iter is None or recreate_training_iterator: - if training_data_iter is not None and ds_context.has_strategy(): - # TODO(kaftan): remove this when MultiDeviceIterator is a - ## compositetensor (unless this is more efficient) - training_data_iter._initializer # pylint: disable=pointless-statement - else: - training_data_iter = iter(training_dataset) - - training_result = run_one_epoch( - model, - training_data_iter, - training_function, - dataset_size=training_data_adapter.get_size(), - batch_size=training_data_adapter.batch_size(), - strategy=strategy, - steps_per_epoch=steps_per_epoch, - num_samples=total_samples, - mode=ModeKeys.TRAIN, - training_context=training_context, - total_epochs=epochs) - cbks.make_logs(model, epoch_logs, training_result, ModeKeys.TRAIN) - - # In the case of steps_per_epoch = None, the final cardinality will - # be determined when the inputs are fully consumed (eg dataset or - # generator). Update the steps_per_epoch to the new value. - if (steps_per_epoch is None - and training_context.progbar.progbar.target is not None): - steps_per_epoch = training_context.progbar.progbar.target - - # Evaluation - if (do_validation and - training_utils.should_run_validation(validation_freq, epoch) and - not training_callbacks.model.stop_training): - if eval_data_iter is not None and ds_context.has_strategy(): - # TODO(kaftan): remove this when MultiDeviceIterator is a - ## compositetensor (unless this is more efficient) - eval_data_iter._initializer # pylint: disable=pointless-statement - else: - eval_data_iter = iter(validation_dataset) - - validation_callbacks = cbks.configure_callbacks( - training_callbacks, - model, - batch_size=batch_size, - epochs=1, - steps_per_epoch=validation_steps, - samples=val_total_samples or validation_steps, - count_mode='samples' if use_sample else 'steps', - verbose=0, # Handle ProgBarLogger separately in this loop. - mode=ModeKeys.TEST) - - eval_context = TrainingContext() - with eval_context.on_start( - model, - validation_callbacks, - use_sample, - verbose=0, - mode=ModeKeys.TEST): - with eval_context.on_epoch(epoch, ModeKeys.TEST): - model.reset_metrics() - eval_result = run_one_epoch( - model, - eval_data_iter, - eval_function, - dataset_size=validation_adapter.get_size(), - batch_size=validation_adapter.batch_size(), - strategy=strategy, - steps_per_epoch=validation_steps, - num_samples=val_total_samples, - mode=ModeKeys.TEST, - training_context=eval_context, - total_epochs=1) - cbks.make_logs(model, epoch_logs, eval_result, ModeKeys.TEST, - prefix='val_') - - return model.history - - def _model_iteration( - self, model, mode, x=None, y=None, batch_size=None, verbose=1, - sample_weight=None, steps=None, callbacks=None, max_queue_size=10, - workers=1, use_multiprocessing=False, **kwargs): - - batch_size = model._validate_or_infer_batch_size( - batch_size, steps, x) - strategy = model.distribute_strategy - batch_size, steps = dist_utils.process_batch_and_step_size( - strategy, x, batch_size, steps, mode) - dist_utils.validate_callbacks(input_callbacks=callbacks, - optimizer=model.optimizer) - # Enter tf.distribute.Strategy scope. - with strategy.scope(): - adapter = _process_inputs( - model, - mode, - x, - y, - batch_size=batch_size, - sample_weights=sample_weight, - steps=steps, - distribution_strategy=strategy, - max_queue_size=max_queue_size, - workers=workers, - use_multiprocessing=use_multiprocessing) - total_samples = _get_total_number_of_samples(adapter) - use_sample = total_samples is not None - dataset = adapter.get_dataset() - - if not steps: - # Raise an error if `steps` isn't specified but the dataset - # is infinite. - steps = adapter.get_size() or training_utils.infer_steps_for_dataset( - model, dataset, steps, steps_name='steps') - - # tf.print('{} on {} steps.'.format(ModeKeys.TRAIN, steps_per_epoch)) - training_context = TrainingContext() - if training_v2_utils._should_add_batch_index_to_element(strategy, mode): - dataset = training_v2_utils._add_batch_index_to_element(dataset) - dataset = strategy.experimental_distribute_dataset(dataset) - - execution_function = training_v2_utils._get_or_make_execution_function( - model, mode) - - data_iterator = iter(dataset) - - callbacks = cbks.configure_callbacks( - callbacks, - model, - do_validation=False, - batch_size=batch_size, - epochs=1, - steps_per_epoch=steps, - samples=total_samples, - count_mode='samples' if use_sample else 'steps', - verbose=0, # Handle ProgBarLogger separately in this loop. - mode=mode) - - with training_context.on_start( - model, callbacks, use_sample, verbose, mode): - with training_context.on_epoch(0, mode) as epoch_logs: - model.reset_metrics() - result = run_one_epoch( - model, - data_iterator, - execution_function, - dataset_size=adapter.get_size(), - batch_size=adapter.batch_size(), - strategy=strategy, - steps_per_epoch=steps, - num_samples=total_samples, - mode=mode, - training_context=training_context, - total_epochs=1) - cbks.make_logs(model, epoch_logs, result, mode) - - if len(result) == 1: - result = result[0] - return result - - def evaluate( - self, model, x=None, y=None, batch_size=None, verbose=1, - sample_weight=None, steps=None, callbacks=None, max_queue_size=10, - workers=1, use_multiprocessing=False, **kwargs): - return self._model_iteration( - model, ModeKeys.TEST, x=x, y=y, batch_size=batch_size, verbose=verbose, - sample_weight=sample_weight, steps=steps, callbacks=callbacks, - max_queue_size=max_queue_size, workers=workers, - use_multiprocessing=use_multiprocessing, **kwargs) - - def predict(self, model, x, batch_size=None, verbose=0, steps=None, - callbacks=None, max_queue_size=10, workers=1, - use_multiprocessing=False, **kwargs): - return self._model_iteration( - model, ModeKeys.PREDICT, x=x, batch_size=batch_size, verbose=verbose, - steps=steps, callbacks=callbacks, max_queue_size=max_queue_size, - workers=workers, use_multiprocessing=use_multiprocessing, **kwargs) - - -def _process_training_inputs(model, - x, - y, - batch_size=None, - epochs=1, - sample_weights=None, - class_weights=None, - steps_per_epoch=None, - validation_split=0., - validation_data=None, - validation_steps=None, - shuffle=True, - distribution_strategy=None, - max_queue_size=10, - workers=1, - use_multiprocessing=False): - """Process the data input for fit() with respect to validation_split.""" - if validation_split and 0. < validation_split < 1. and validation_data: - raise ValueError('validation_data and validation_split cannot be used ' - 'at same time.') - - adapter_cls = data_adapter.select_data_adapter(x, y) - - # Handle validation_split, we want to split the data and get the training - # section before we give it to data adapter. - if validation_split and 0. < validation_split < 1.: - if adapter_cls not in _ADAPTER_FOR_VALIDATION_SPLIT: - raise ValueError( - '`validation_split` argument is not supported when ' - 'data adapter is {}. Received: x={}, validation_split={}'.format( - adapter_cls, x, validation_split)) - # Retrieve the training section from x and y, and then construct dataset - # from it. - x, y, sample_weights = model._standardize_user_data( - x, - y, - sample_weight=sample_weights, - class_weight=class_weights, - batch_size=batch_size, - check_steps=False, - steps=steps_per_epoch) - (x, y, sample_weights, - val_x, val_y, - val_sample_weights) = training_utils.split_training_and_validation_data( - x, y, sample_weights, validation_split) - - sample_weight_modes = [ - e.sample_weight_mode for e in model._training_endpoints - ] - train_adapter = adapter_cls( - x, - y, - batch_size=batch_size, - steps=steps_per_epoch, - epochs=epochs, - sample_weights=sample_weights, - sample_weight_modes=sample_weight_modes, - shuffle=shuffle, - distribution_strategy=distribution_strategy) - - val_adapter = adapter_cls( - val_x, - val_y, - steps=validation_steps, - sample_weights=val_sample_weights, - sample_weight_modes=sample_weight_modes, - batch_size=batch_size, - distribution_strategy=distribution_strategy) - else: - train_adapter = _process_inputs( - model, - ModeKeys.TRAIN, - x, - y, - sample_weights=sample_weights, - batch_size=batch_size, - steps=steps_per_epoch, - epochs=epochs, - class_weights=class_weights, - shuffle=shuffle, - distribution_strategy=distribution_strategy, - max_queue_size=max_queue_size, - workers=workers, - use_multiprocessing=use_multiprocessing) - val_adapter = None - if validation_data: - (val_x, val_y, - val_sample_weights) = training_utils.unpack_validation_data( - validation_data, raise_if_ambiguous=False) - # For eval data, we use a representative batch size of the - # training data if batch_size was unknown. - # This is useful for generator/sequence training data input with numpy - # validation data input. - if not batch_size: - batch_size = train_adapter.representative_batch_size() - val_adapter = _process_inputs( - model, - ModeKeys.TEST, - val_x, - val_y, - steps=validation_steps, - sample_weights=val_sample_weights, - batch_size=batch_size, - class_weights=class_weights, - distribution_strategy=distribution_strategy) - elif validation_steps: - raise ValueError('`validation_steps` should not be specified if ' - '`validation_data` is None.') - return train_adapter, val_adapter - - -def _process_inputs(model, - mode, - x, - y, - batch_size=None, - epochs=1, - sample_weights=None, - class_weights=None, - shuffle=False, - steps=None, - distribution_strategy=None, - max_queue_size=10, - workers=1, - use_multiprocessing=False): - """Process the inputs for fit/eval/predict().""" - adapter_cls = data_adapter.select_data_adapter(x, y) - standardize = functools.partial( - model._standardize_user_data, - class_weight=class_weights, - batch_size=batch_size, - check_steps=False, - steps=steps) - if adapter_cls in _ADAPTER_FOR_STANDARDIZE_USER_DATA: - standardize_function = None - x, y, sample_weights = standardize( - x, y, sample_weight=sample_weights) - elif adapter_cls is data_adapter.ListsOfScalarsDataAdapter: - standardize_function = standardize - else: - def standardize_function(dataset): - """Data adapters can standardize when appropriate.""" - # First we call _standardize_user_data with the dataset since that has - # enough structure to build the model. - if not model._is_compiled: - # We don't actually care about the values of these attributes, but they - # are only created in compile and are accessed in _standardize_user_data - model._training_endpoints = getattr(model, '_training_endpoints', []) - model.sample_weight_mode = getattr(model, 'sample_weight_mode', None) - - standardize(dataset, extract_tensors_from_dataset=False) - - # Then we map using only the tensor standardization portion. - def map_fn(x, y=None, sample_weights=None): - """Tensor manipulation portion of standardization for Dataset.map.""" - if (y is None and sample_weights is None): - # namedtuples are forbidden because it is ambiguous if they should be - # unpacked. If y or sample_weights is present then `x` was not the - # top level structure, and the correct behavior is unambiguous. - data_adapter.assert_not_namedtuple(x) - - standardized = model._standardize_tensors( - x, y, sample_weights, - run_eagerly=False, - dict_inputs=isinstance(x, dict), - is_dataset=False, - class_weight=class_weights, - batch_size=None) - x, y, sample_weights = nest._list_to_tuple(standardized) - if y is None: - return (x,) - if sample_weights is None: - return x, y - return x, y, sample_weights - return dataset.map(map_fn, num_parallel_calls=dataset_ops.AUTOTUNE) - - if mode == ModeKeys.PREDICT: - sample_weight_modes = None - else: - sample_weight_modes = [ - e.sample_weight_mode for e in model._training_endpoints - ] or model.sample_weight_mode - - adapter = adapter_cls( - x, - y, - standardize_function=standardize_function, - batch_size=batch_size, - epochs=epochs, - steps=steps, - sample_weights=sample_weights, - sample_weight_modes=sample_weight_modes, - shuffle=shuffle, - distribution_strategy=distribution_strategy, - max_queue_size=max_queue_size, - workers=workers, - use_multiprocessing=use_multiprocessing) - - return adapter - - -def _get_total_number_of_samples(adapter): - if not adapter.get_size() or not adapter.batch_size(): - return None - total_sample = adapter.get_size() * adapter.batch_size() - if adapter.has_partial_batch(): - total_sample -= (adapter.batch_size() - adapter.partial_batch_size()) - return total_sample - - -def _print_train_info(total_samples, steps, val_total_samples, val_steps): - increment = 'samples' if total_samples else 'steps' - conjunction = 'on' if total_samples else 'for' - msg = 'Train {} {} {}'.format(conjunction, total_samples or steps, increment) - if val_total_samples or val_steps: - increment = 'samples' if val_total_samples else 'steps' - conjunction = 'on' if val_total_samples else 'for' - msg += ', validate {} {} {}'.format(conjunction, val_total_samples or - val_steps, increment) - print(msg) - - -class TrainingContext(object): - """Utility object that wrap around callbacks and progress bars.""" - - @tf_contextlib.contextmanager - def on_start(self, model, callbacks=None, use_samples=False, verbose=0, - mode=ModeKeys.TRAIN): - """Provide a scope for the whole training process.""" - # TODO(omalleyt): Handle ProgBar as part of Callbacks once hooks are ready. - progbar = training_utils.get_progbar( - model, 'samples' if use_samples else 'steps') - progbar.params = callbacks.params - progbar.params['verbose'] = verbose - callbacks.model.stop_training = False - callbacks._call_begin_hook(mode) - progbar.on_train_begin() - - # Cache those two instance so that it can be used in other functions. - self.callbacks = callbacks - self.progbar = progbar - - try: - yield - model._successful_loop_finish = True - finally: - # End of all epochs - self.callbacks._call_end_hook(mode) - - @tf_contextlib.contextmanager - def on_epoch(self, epoch=0, mode=ModeKeys.TRAIN): - """Provide a scope for running one epoch.""" - epoch_logs = {} - if mode == ModeKeys.TRAIN: - self.callbacks.on_epoch_begin(epoch, epoch_logs) - self.progbar.on_epoch_begin(epoch, epoch_logs) - try: - yield epoch_logs - finally: - if mode == ModeKeys.TRAIN: - # Epochs only apply to `fit`. - self.callbacks.on_epoch_end(epoch, epoch_logs) - self.progbar.on_epoch_end(epoch, epoch_logs) - - @tf_contextlib.contextmanager - def on_batch(self, step=0, mode=ModeKeys.TRAIN, size=1): - """Provide a scope for running one batch.""" - with traceme.TraceMe( - 'TraceContext', graph_type=mode, step_num=step, batch_size=size): - batch_logs = {'batch': step, 'size': size} - self.callbacks._call_batch_hook( - mode, 'begin', step, batch_logs) - self.progbar.on_batch_begin(step, batch_logs) - try: - yield batch_logs - finally: - if not batch_logs.pop('data_exhausted', False): - self.callbacks._call_batch_hook( - mode, 'end', step, batch_logs) - self.progbar.on_batch_end(step, batch_logs) diff --git a/tensorflow/python/keras/engine/training_v2_utils.py b/tensorflow/python/keras/engine/training_v2_utils.py deleted file mode 100644 index b7eb1b123b6..00000000000 --- a/tensorflow/python/keras/engine/training_v2_utils.py +++ /dev/null @@ -1,556 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Training related logic for Keras model in TF 2.0 context. - -Note that all the code under this module is under active development, please DO -NOT use it unless you are really sure what you are doing. -""" - -# pylint: disable=protected-access -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections -import functools - -import numpy as np - -from tensorflow.python.distribute import distribution_strategy_context -from tensorflow.python.eager import def_function -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import tensor_util -from tensorflow.python.framework.ops import composite_tensor -from tensorflow.python.keras import backend -from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils -from tensorflow.python.keras.engine import training_eager -from tensorflow.python.keras.engine import training_utils -from tensorflow.python.keras.utils.mode_keys import ModeKeys -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.ragged import ragged_tensor -from tensorflow.python.util import nest - - -def _get_or_make_function(model, mode, key_fn, make_fn): - """Helper function for managing cached execution functions.""" - model._init_distributed_function_cache_if_not_compiled() - key = key_fn(mode) - - function = dist_utils.get_distributed_function(model, key) - if function: - return function - - function = make_fn(model, mode) - dist_utils.set_distributed_function(model, key, function) - return function - - -def _get_or_make_execution_function(model, mode): - """Makes or reuses function to run one step of distributed model execution.""" - return _get_or_make_function( - model, mode, - # Use a key with 'v2' to distinguish from fall-back execution functions. - key_fn=lambda m: (m, 'v2'), - make_fn=_make_execution_function) - - -def _make_execution_function(model, mode): - """Creates a function to run one step of distributed model execution.""" - per_replica_function = _make_replica_execution_function(model, mode) - - def distributed_function(input_iterator): - """A single step of the distributed execution across replicas.""" - # Call `Model.{train,test,predict}_on_batch` on every replica passing - # PerReplicas as arguments. On every replica inside this call, each - # PerReplica object will return the value for that replica. The outputs - # are PerReplicas too. - strategy = distribution_strategy_context.get_strategy() - args = _prepare_feed_values(model, input_iterator, mode, strategy) - outputs = strategy.experimental_run_v2( - per_replica_function, args=args) - # Out of PerReplica outputs reduce or pick values to return. - all_outputs = dist_utils.unwrap_output_dict( - strategy, outputs, mode) - return all_outputs - - if not model.run_eagerly: - distributed_function = def_function.function( - distributed_function, autograph=False) - - def execution_function(input_fn): - # `numpy` translates Tensors to values in Eager mode. - return nest.map_structure(_non_none_constant_value, - distributed_function(input_fn)) - - return execution_function - - -def _get_or_make_on_batch_function(model, mode): - """Makes or reuses function to run one step of distributed model execution.""" - return _get_or_make_function( - model, mode, - # Use a key with 'v2' to distinguish from fall-back execution functions. - key_fn=lambda m: (m, 'v2_on_batch'), - make_fn=_make_on_batch_function) - - -def _make_on_batch_function(model, mode): - """Creates a function of Model.*_on_batch methods.""" - if mode == ModeKeys.TRAIN: - func = training_eager.train_on_batch - elif mode == ModeKeys.TEST: - func = training_eager.test_on_batch - else: - func = model - - if not model.run_eagerly: - # Pass `experimental_relax_shapes` to avoid retracing for dynamic batch - # size, variable length sequences, etc. - func = def_function.function(func, experimental_relax_shapes=True) - - return func - - -def _non_none_constant_value(v): - constant_value = tensor_util.constant_value(v) - return constant_value if constant_value is not None else v - - -def _prepare_feed_values(model, inputs, mode, strategy): - """Prepare feed values to the model execution function. - - Arguments: - model: Model to prepare feed values for. - inputs: An iterator of model inputs, targets, and sample_weights. - model inputs may be lists, single values, or dicts mapping input feed - names to values. - mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT. - strategy: The current distribution strategy for the model. - - Returns: - Feed values for the model in the given mode. This is a tuple of - the structure (inputs, targets, sample_weights), where each of - (tuple, targets, sample_weights) may be a python list. Single values - for inputs will always be wrapped in lists. - """ - # For predict, we need to extract the manually added batch_index first. - with_batch_index = _should_add_batch_index_to_element(strategy, mode) - - inputs, targets, sample_weights, batch_index = _get_input_from_iterator( - inputs, with_batch_index) - - # When the inputs are dict, then we want to flatten it in the same order as - # the input layers, such that the data are fed into the input layers in the - # correct order. - if isinstance(inputs, dict): - inputs = [inputs[key] for key in model._feed_input_names] - else: - inputs = training_utils.ModelInputs(inputs).as_list() - - if mode == ModeKeys.PREDICT: - sample_weights = [] - targets = [] - - ins = [inputs, targets, sample_weights] - if batch_index is not None: - ins.append(batch_index) - return tuple(ins) - - -def _get_input_from_iterator(iterator, with_batch_index=False): - """Get elements from the iterator and verify the input shape and type.""" - next_element = next(iterator) - if with_batch_index: - batch_index, next_element = next_element - else: - batch_index = None - - if (tensor_util.is_tensor(next_element) or - isinstance(next_element, (dict, composite_tensor.CompositeTensor))): - next_element = [next_element] - if len(next_element) == 1: - x, = next_element - y = None - sample_weights = None - elif len(next_element) == 2: - x, y = next_element - sample_weights = None - else: - x, y, sample_weights = next_element - - # Validate that all the elements in x and y are of the same type and shape. - dist_utils.validate_distributed_dataset_inputs( - distribution_strategy_context.get_strategy(), x, y, sample_weights) - return x, y, sample_weights, batch_index - - -def _make_replica_execution_function(model, mode): - """A single step of the distributed execution on a replica.""" - if mode == ModeKeys.TRAIN: - func = functools.partial(train_on_batch, model) - elif mode == ModeKeys.TEST: - func = functools.partial(test_on_batch, model) - else: - def _predict_on_batch(x, y=None, sample_weights=None, batch_index=None): - del y, sample_weights - # Note that the x and batch_index is already per-replica value. - result = predict_on_batch(model, x) - if batch_index is None: - return result - else: - return batch_index, result - - func = _predict_on_batch - - if mode != ModeKeys.PREDICT: - # `reset_metrics` is set to False to maintain stateful metrics across - # batch-level calls. - func = functools.partial(func, reset_metrics=False) - - return func - - -def _aggregate_predict_results(strategy, batch_outs, model): - """Aggregate the prediction result from each replica.""" - num_replicas = strategy.num_replicas_in_sync - num_outputs = len(model.outputs) - - if not isinstance(batch_outs, list): - batch_outs = [batch_outs] - - with_batch_index = _should_add_batch_index_to_element( - strategy, ModeKeys.PREDICT) - - # batch_outs is in following structure: - # [ - # replica_1_batch_index, replica_2_batch_index, ...., replica_x_batch_index, - # replica_1_output_1, replica_2_output_1, ...., replica_x_output_1, - # ...... - # replica_1_output_y, replica_2_output_y, ...., replica_x_output_y, - # ] - # The replica_x_batch_index is optional and depended on teh strategy type. - if with_batch_index: - batch_index, batch_outs = (batch_outs[:num_replicas], - batch_outs[num_replicas:]) - batch_index = dist_utils.concat_along_batch_dimension(batch_index) - # Reorder the batch_index for it to do proper gather. Eg, if the original - # index is [0, 2, 4, 6, 1, 3, 5, 7], then the index for gather should be - # [0, 4, 1, 5, 2, 6, 3, 7]. - batch_index = np.argsort(batch_index) - # Only need to gather if the batch index is not sorted. - need_batch_index_gather = np.any(np.diff(batch_index) < 0) - else: - need_batch_index_gather = False - - total_batch_outs = [] - for i in range(num_outputs): - nested_outs = batch_outs[i * num_replicas:i * num_replicas + num_replicas] - per_output_result = dist_utils.concat_along_batch_dimension( - nest.flatten(nested_outs)) - - if need_batch_index_gather: - if _get_batch_size(per_output_result).numpy() == len(batch_index): - # Skip the gather if the output has a different batch size than the - # batch_index. There will be some error handling in upper layer. - per_output_result = _gather_result_by_index(per_output_result, - batch_index) - total_batch_outs.append(per_output_result) - return total_batch_outs - - -def _gather_result_by_index(input_tensor, batch_index): - """Handle the data element gather for different type of tensor.""" - if isinstance(input_tensor, sparse_tensor.SparseTensor): - # For sparse tensor, both the index and value component should be gathered. - return sparse_tensor.SparseTensor( - indices=array_ops.gather_v2(input_tensor.indices, batch_index), - values=array_ops.gather_v2(input_tensor.values, batch_index), - dense_shape=input_tensor.dense_shape - ) - # For both ragged tensor or eager tensor or np array, tf.gather should do the - # correct thing. - elif isinstance(input_tensor, ragged_tensor.RaggedTensor): - return array_ops.gather_v2(input_tensor, batch_index) - elif isinstance(input_tensor, (ops.EagerTensor, np.ndarray)): - return array_ops.gather_v2(input_tensor, batch_index).numpy() - else: - raise ValueError('Unexpected type {} encountered when gathering ' - 'batch slices.'.format(input_tensor)) - - -def _get_batch_size(inputs): - first_inputs = nest.flatten(inputs)[0] - if isinstance(first_inputs, ragged_tensor.RaggedTensor): - return first_inputs.bounding_shape()[0] - else: - return array_ops.shape(first_inputs)[0] - - -def _add_batch_index_to_element(dataset): - """Adding a new batch index field to the every element in the batch. - - This is need in the model.predict() when running with multi-worker - distribution strategy. When sharding/distributing a dataset, the continuity of - the sharded dataset can't be easily ensured without performance sacrifice. It - is fine to train and eval with the reordered data, but not for prediction. To - solve this issue, Keras will add a batch index to each of the element in the - dataset, which will then pass to pre-replica execution function. The real - execution function will remove it before feeding the input to the model, and - pre-replica function will then zip the index with the result. Finally Keras - will sort the batch result based on the added batch-index field, remove it and - return the sorted result. - - Note that we didn't add single index to the per-replica batch, but to each of - the element in the batch, since we can't ensure the data in pre-replica is - continuous. Eg: model with 2 replica and predict with 4 elements per batch - like [1, 2, 3, 4], it is possible to shard as [1, 2], [3, 4], - or [1, 3], [2, 4]. - - Args: - dataset: a dataset that is created by any of the data_adapter, with the - element structure as (x, y, sample_weights). - - Returns: - a new dataset, with the element shape as - (batch_index, (x, y, sample_weights)). - """ - return dataset.map(lambda *inp: (math_ops.range(_get_batch_size(inp)), inp)) - - -def _should_add_batch_index_to_element(strategy, mode): - """Whether or not the batch index should be added to the input dataset. - - See docstring of _add_batch_index_to_element() for more details. So far the - batch index is only need when using TPUStrategy with a multi-worker setting. - We will try to avoid adding batch index for other cases since it has the - performance implication. - - Args: - strategy: the current distribution strategy for the model. - mode: the current mode (Training/Eval/Predict) for the model. - Returns: - Boolean, whether the batch index should be added for the input data to - preserve the ordering. - """ - # TODO(priyag, rxsang): Come up a better way to determine when the batch index - # should be added. - return (mode == ModeKeys.PREDICT - and dist_utils.is_tpu_strategy(strategy) - and strategy.extended.num_hosts > 1) - - -def train_on_batch( - model, - x, - y=None, - sample_weight=None, - class_weight=None, - reset_metrics=True, - standalone=False): - """Runs a single gradient update on a single batch of data. - - Arguments: - model: The model to train. - x: Input data. It could be: - - A Numpy array (or array-like), or a list of arrays - (in case the model has multiple inputs). - - A TensorFlow tensor, or a list of tensors - (in case the model has multiple inputs). - - A dict mapping input names to the corresponding array/tensors, - if the model has named inputs. - - A `tf.data` dataset. - y: Target data. Like the input data `x`, it could be either Numpy - array(s) or TensorFlow tensor(s). It should be consistent with `x` - (you cannot have Numpy inputs and tensor targets, or inversely). If - `x` is a dataset `y` should not be specified - (since targets will be obtained from the iterator). - sample_weight: Optional array of the same length as x, containing - weights to apply to the model's loss for each sample. In the case of - temporal data, you can pass a 2D array with shape (samples, - sequence_length), to apply a different weight to every timestep of - every sample. In this case you should make sure to specify - sample_weight_mode="temporal" in compile(). This argument is not - supported when `x` is a dataset. - class_weight: Optional dictionary mapping class indices (integers) to a - weight (float) to apply to the model's loss for the samples from this - class during training. This can be useful to tell the model to "pay - more attention" to samples from an under-represented class. - reset_metrics: If `True`, the metrics returned will be only for this - batch. If `False`, the metrics will be statefully accumulated across - batches. - standalone: If True, this method is not called as part of - Model.fit/evaluate/predict and can therefore be tf.function'd. - - Returns: - Scalar training loss - (if the model has a single output and no metrics) - or list of scalars (if the model has multiple outputs - and/or metrics). The attribute `model.metrics_names` will give you - the display labels for the scalar outputs. - - Raises: - ValueError: In case of invalid user-provided arguments. - """ - model._assert_compile_was_called() - - # TODO(scottzhu): Standardization should happen in the data handlers, - ## not on a per batch basis in the *_on_batch methods - # Validate and standardize user data. - x, y, sample_weights = model._standardize_user_data( - x, y, sample_weight=sample_weight, class_weight=class_weight, - extract_tensors_from_dataset=True) - batch_size = array_ops.shape(nest.flatten(x, expand_composites=True)[0])[0] - # If `model._distribution_strategy` is True, then we are in a replica context - # at this point because of the check above. `train_on_batch` is being run - # for each replica by `model._distribution_strategy` and the same code path - # as Eager is expected to be taken. - - if standalone: - train_on_batch_fn = _get_or_make_on_batch_function(model, ModeKeys.TRAIN) - else: - train_on_batch_fn = training_eager.train_on_batch - - outputs = train_on_batch_fn( - model, - x, - y, - sample_weights=sample_weights, - output_loss_metrics=model._output_loss_metrics) - - if reset_metrics: - model.reset_metrics() - - outputs['batch_size'] = math_ops.cast(batch_size, dtypes.int64) - return outputs - - -def test_on_batch(model, x, y=None, sample_weight=None, reset_metrics=True, - standalone=False): - """Test the model on a single batch of samples. - - Arguments: - model: The model to test. - x: Input data. It could be: - - A Numpy array (or array-like), or a list of arrays - (in case the model has multiple inputs). - - A TensorFlow tensor, or a list of tensors - (in case the model has multiple inputs). - - A dict mapping input names to the corresponding array/tensors, - if the model has named inputs. - - A `tf.data` dataset. - y: Target data. Like the input data `x`, - it could be either Numpy array(s) or TensorFlow tensor(s). - It should be consistent with `x` (you cannot have Numpy inputs and - tensor targets, or inversely). If `x` is a dataset, - `y` should not be specified - (since targets will be obtained from the iterator). - sample_weight: Optional array of the same length as x, containing - weights to apply to the model's loss for each sample. - In the case of temporal data, you can pass a 2D array - with shape (samples, sequence_length), - to apply a different weight to every timestep of every sample. - In this case you should make sure to specify - sample_weight_mode="temporal" in compile(). This argument is not - supported when `x` is a dataset. - reset_metrics: If `True`, the metrics returned will be only for this - batch. If `False`, the metrics will be statefully accumulated across - batches. - standalone: If True, this method is not called as part of - Model.fit/evaluate/predict and can therefore be tf.function'd. - - Returns: - Scalar test loss (if the model has a single output and no metrics) - or list of scalars (if the model has multiple outputs - and/or metrics). The attribute `model.metrics_names` will give you - the display labels for the scalar outputs. - - Raises: - ValueError: In case of invalid user-provided arguments. - """ - model._assert_compile_was_called() - - # TODO(scottzhu): Standardization should happen in the data handlers, - ## not on a per batch basis in the *_on_batch methods - # Validate and standardize user data. - x, y, sample_weights = model._standardize_user_data( - x, y, sample_weight=sample_weight, extract_tensors_from_dataset=True) - - batch_size = array_ops.shape(nest.flatten(x, expand_composites=True)[0])[0] - - if standalone: - test_on_batch_fn = _get_or_make_on_batch_function(model, ModeKeys.TEST) - else: - test_on_batch_fn = training_eager.test_on_batch - - outputs = test_on_batch_fn( - model, - x, - y, - sample_weights=sample_weights, - output_loss_metrics=model._output_loss_metrics) - - if reset_metrics: - model.reset_metrics() - - outputs['batch_size'] = math_ops.cast(batch_size, dtypes.int64) - return outputs - - -def predict_on_batch(model, x, standalone=False): - """Returns predictions for a single batch of samples. - - Arguments: - model: The model to predict with. - x: Input data. It could be: - - A Numpy array (or array-like), or a list of arrays - (in case the model has multiple inputs). - - A TensorFlow tensor, or a list of tensors - (in case the model has multiple inputs). - - A `tf.data` dataset. - standalone: If True, this method is not called as part of - Model.fit/evaluate/predict and can therefore be tf.function'd. - - Returns: - Numpy array(s) of predictions. - - Raises: - ValueError: In case of mismatch between given number of inputs and - expectations of the model. - """ - # TODO(scottzhu): Standardization should happen in the data handlers, - ## not on a per batch basis in the *_on_batch methods - # Validate and standardize user data. - inputs, _, _ = model._standardize_user_data( - x, extract_tensors_from_dataset=True) - - # If `model._distribution_strategy` is True, then we are in a replica context - # at this point. - inputs = training_utils.cast_to_model_input_dtypes(inputs, model) - if isinstance(inputs, collections.Sequence): - # Unwrap lists with only one input, as we do when training on batch - if len(inputs) == 1: - inputs = inputs[0] - - if standalone: - predict_on_batch_fn = _get_or_make_on_batch_function( - model, ModeKeys.PREDICT) - else: - predict_on_batch_fn = model - - with backend.eager_learning_phase_scope(0): - return predict_on_batch_fn(inputs) # pylint: disable=not-callable diff --git a/tensorflow/python/keras/engine/training_v2_utils_test.py b/tensorflow/python/keras/engine/training_v2_utils_test.py deleted file mode 100644 index 4499ad3c8c6..00000000000 --- a/tensorflow/python/keras/engine/training_v2_utils_test.py +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for tensorflow.python.keras.engine.training_v2_utils.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections - -from absl.testing import parameterized -import mock -import numpy as np - - -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.distribute import mirrored_strategy -from tensorflow.python.distribute import strategy_combinations -from tensorflow.python.eager import def_function -from tensorflow.python.framework import combinations -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils -from tensorflow.python.keras.engine import training_v2_utils -from tensorflow.python.keras.utils.mode_keys import ModeKeys -from tensorflow.python.ops import array_ops -from tensorflow.python.ops.ragged import ragged_factory_ops -from tensorflow.python.platform import test - - -class AggregatePredictResultsTest(test.TestCase, parameterized.TestCase): - - def setUp(self): - super(AggregatePredictResultsTest, self).setUp() - strategy_combinations.set_virtual_cpus_to_at_least(3) - self.num_replica = 3 - self.batch_size = 16 - self.dense_shape = (2, 3) - self.total_sample = 2 * self.batch_size - - mock_model = collections.namedtuple('Model', ['outputs']) - self.mock_model = mock_model([1]) - - strategy = mirrored_strategy.MirroredStrategy( - ['/cpu:0', '/cpu:1', '/cpu:2']) - - execution_function = lambda *inp: inp - @def_function.function - def predict_loop(batch): - batch_result = strategy.experimental_run_v2(execution_function, batch) - batch_result = dist_utils.unwrap_output_dict( - strategy, batch_result, ModeKeys.PREDICT) - # swap the order of replica 1 and 2, to mimic random order. - batch_result[2], batch_result[1] = batch_result[1], batch_result[2] - batch_result[5], batch_result[4] = batch_result[4], batch_result[5] - return batch_result - - self.strategy = strategy - self.predict_loop = predict_loop - - @combinations.generate(combinations.combine(tf_api_version=[1, 2], - mode='eager')) - def test_aggregate_predict_results_dense(self): - dataset = dataset_ops.Dataset.range(self.total_sample) - def dense_map_fn(i): - # Mimic what we do for adding batch index - return i, array_ops.fill(self.dense_shape, i) - dense_dataset = dataset.map(dense_map_fn).batch(self.batch_size) - distributed_data = self.strategy.experimental_distribute_dataset( - dense_dataset) - - start = 0 - for batch in distributed_data: - with mock.patch.object(training_v2_utils, - '_should_add_batch_index_to_element', - fake_should_add_batch_index_to_element): - batch_result = self.predict_loop(batch) - final_result = training_v2_utils._aggregate_predict_results( - self.strategy, batch_result, self.mock_model) - - # Make sure the dense result is in a sorted order. - expected_result = np.arange( - start=start, stop=start+self.batch_size).reshape((-1, 1)) - expected_result = np.tile(expected_result, 6).reshape( - (-1,) + self.dense_shape) - self.assertAllClose(final_result[0], expected_result) - start += self.batch_size - - @combinations.generate(combinations.combine(tf_api_version=[1, 2], - mode='eager')) - def test_aggregate_predict_results_sparse(self): - dataset = dataset_ops.Dataset.range(self.total_sample) - def sparse_map_fn(i): - return i, sparse_tensor.SparseTensor( - indices=[(0, 0)], - values=[i], - dense_shape=self.dense_shape) - sparse_dataset = dataset.map(sparse_map_fn).batch(self.batch_size) - distributed_data = self.strategy.experimental_distribute_dataset( - sparse_dataset) - - start = 0 - for batch in distributed_data: - with mock.patch.object(training_v2_utils, - '_should_add_batch_index_to_element', - fake_should_add_batch_index_to_element): - batch_result = self.predict_loop(batch) - final_result = training_v2_utils._aggregate_predict_results( - self.strategy, batch_result, self.mock_model) - - # Make sure the dense result is in a sorted order. - expected_values = np.arange(start=start, stop=start+self.batch_size) - self.assertAllClose(final_result[0].values, expected_values) - start += self.batch_size - - @combinations.generate(combinations.combine(tf_api_version=[1, 2], - mode='eager')) - def test_aggregate_predict_results_ragged(self): - dataset = dataset_ops.Dataset.range(self.total_sample) - def ragged_map_fn(i): - return i, ragged_factory_ops.constant([[0], [], []], dtype=np.int64) + i - ragged_dataset = dataset.map(ragged_map_fn).batch(self.batch_size) - distributed_data = self.strategy.experimental_distribute_dataset( - ragged_dataset) - - start = 0 - for batch in distributed_data: - with mock.patch.object(training_v2_utils, - '_should_add_batch_index_to_element', - fake_should_add_batch_index_to_element): - batch_result = self.predict_loop(batch) - final_result = training_v2_utils._aggregate_predict_results( - self.strategy, batch_result, self.mock_model) - - # Make sure the dense result is in a sorted order. - expected_values = np.arange(start=start, stop=start+self.batch_size) - self.assertAllClose(final_result[0].flat_values, expected_values) - start += self.batch_size - - -def fake_should_add_batch_index_to_element(strategy, mode): - # Ignore the strategy instance check since we were using the MirroredStrategy - # for testing. - del strategy - return mode == ModeKeys.PREDICT - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py index 13134927409..65aadd7cd08 100644 --- a/tensorflow/python/keras/layers/core.py +++ b/tensorflow/python/keras/layers/core.py @@ -1122,12 +1122,17 @@ class Dense(Layer): raise TypeError('Unable to build `Dense` layer with non-floating point ' 'dtype %s' % (dtype,)) input_shape = tensor_shape.TensorShape(input_shape) - if tensor_shape.dimension_value(input_shape[-1]) is None: - raise ValueError('The last dimension of the inputs to `Dense` ' - 'should be defined. Found `None`.') - last_dim = tensor_shape.dimension_value(input_shape[-1]) - self.input_spec = InputSpec(min_ndim=2, - axes={-1: last_dim}) + # Handle 1-d inputs by reshaping to (-1, 1). + if input_shape.rank == 1: + input_shape = tensor_shape.TensorShape(input_shape.as_list() + [1]) + last_dim = tensor_shape.dimension_value(1) + self.input_spec = InputSpec(min_ndim=1, max_ndim=2) + else: + if tensor_shape.dimension_value(input_shape[-1]) is None: + raise ValueError('The last dimension of the inputs to `Dense` ' + 'should be defined. Found `None`.') + last_dim = tensor_shape.dimension_value(input_shape[-1]) + self.input_spec = InputSpec(min_ndim=2, axes={-1: last_dim}) self.kernel = self.add_weight( 'kernel', shape=[last_dim, self.units], @@ -1160,6 +1165,8 @@ class Dense(Layer): output_shape = shape[:-1] + [self.units] outputs.set_shape(output_shape) else: + if rank == 1: + inputs = array_ops.expand_dims_v2(inputs, axis=-1) inputs = math_ops.cast(inputs, self._compute_dtype) if K.is_sparse(inputs): outputs = sparse_ops.sparse_tensor_dense_matmul(inputs, self.kernel) diff --git a/tensorflow/python/keras/layers/merge.py b/tensorflow/python/keras/layers/merge.py index bf39f30b71a..57a97952e4f 100644 --- a/tensorflow/python/keras/layers/merge.py +++ b/tensorflow/python/keras/layers/merge.py @@ -89,7 +89,7 @@ class _Merge(Layer): @tf_utils.shape_type_conversion def build(self, input_shape): # Used purely for shape validation. - if not isinstance(input_shape, list): + if not isinstance(input_shape[0], tuple): raise ValueError('A merge layer should be called on a list of inputs.') if len(input_shape) < 2: raise ValueError('A merge layer should be called ' @@ -118,7 +118,7 @@ class _Merge(Layer): self._reshape_required = True def call(self, inputs): - if not isinstance(inputs, list): + if not isinstance(inputs, (list, tuple)): raise ValueError('A merge layer should be called on a list of inputs.') if self._reshape_required: reshaped_inputs = [] @@ -204,9 +204,9 @@ class _Merge(Layer): def compute_mask(self, inputs, mask=None): if mask is None: return None - if not isinstance(mask, list): + if not isinstance(mask, (tuple, list)): raise ValueError('`mask` should be a list.') - if not isinstance(inputs, list): + if not isinstance(inputs, (tuple, list)): raise ValueError('`inputs` should be a list.') if len(mask) != len(inputs): raise ValueError('The lists `inputs` and `mask` ' @@ -489,7 +489,7 @@ class Concatenate(_Merge): @tf_utils.shape_type_conversion def build(self, input_shape): # Used purely for shape validation. - if not isinstance(input_shape, list) or len(input_shape) < 2: + if not isinstance(input_shape[0], tuple) or len(input_shape) < 2: raise ValueError('A `Concatenate` layer should be called ' 'on a list of at least 2 inputs') if all(shape is None for shape in input_shape): @@ -523,7 +523,7 @@ class Concatenate(_Merge): @tf_utils.shape_type_conversion def compute_output_shape(self, input_shape): - if not isinstance(input_shape, list): + if not isinstance(input_shape, (tuple, list)): raise ValueError('A `Concatenate` layer should be called ' 'on a list of inputs.') input_shapes = input_shape @@ -538,9 +538,9 @@ class Concatenate(_Merge): def compute_mask(self, inputs, mask=None): if mask is None: return None - if not isinstance(mask, list): + if not isinstance(mask, (tuple, list)): raise ValueError('`mask` should be a list.') - if not isinstance(inputs, list): + if not isinstance(inputs, (tuple, list)): raise ValueError('`inputs` should be a list.') if len(mask) != len(inputs): raise ValueError('The lists `inputs` and `mask` ' @@ -656,7 +656,7 @@ class Dot(_Merge): @tf_utils.shape_type_conversion def build(self, input_shape): # Used purely for shape validation. - if not isinstance(input_shape, list) or len(input_shape) != 2: + if not isinstance(input_shape[0], tuple) or len(input_shape) != 2: raise ValueError('A `Dot` layer should be called ' 'on a list of 2 inputs.') shape1 = input_shape[0] @@ -701,7 +701,7 @@ class Dot(_Merge): @tf_utils.shape_type_conversion def compute_output_shape(self, input_shape): - if not isinstance(input_shape, list) or len(input_shape) != 2: + if not isinstance(input_shape, (tuple, list)) or len(input_shape) != 2: raise ValueError('A `Dot` layer should be called ' 'on a list of 2 inputs.') shape1 = list(input_shape[0]) diff --git a/tensorflow/python/keras/layers/normalization_test.py b/tensorflow/python/keras/layers/normalization_test.py index 5222a32857d..687b76dbe98 100644 --- a/tensorflow/python/keras/layers/normalization_test.py +++ b/tensorflow/python/keras/layers/normalization_test.py @@ -37,6 +37,7 @@ from tensorflow.python.keras.mixed_precision.experimental import policy from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_v2 from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker_v2 +from tensorflow.python.ops import math_ops from tensorflow.python.platform import test from tensorflow.python.training import gradient_descent @@ -498,7 +499,8 @@ class NormalizationLayersGraphModeOnlyTest( def _run_layernorm_correctness_test(layer, dtype='float32'): model = keras.models.Sequential() - norm = layer(input_shape=(2, 2, 2)) + model.add(keras.layers.Lambda(lambda x: math_ops.cast(x, dtype='float16'))) + norm = layer(input_shape=(2, 2, 2), dtype=dtype) model.add(norm) model.compile( loss='mse', diff --git a/tensorflow/python/keras/layers/preprocessing/normalization_test.py b/tensorflow/python/keras/layers/preprocessing/normalization_test.py index e1573df3387..227e961751e 100644 --- a/tensorflow/python/keras/layers/preprocessing/normalization_test.py +++ b/tensorflow/python/keras/layers/preprocessing/normalization_test.py @@ -43,36 +43,40 @@ def get_layer_class(): def _get_layer_computation_test_cases(): test_cases = ({ - "adapt_data": np.array([[1.], [2.], [3.], [4.], [5.]]), + "adapt_data": np.array([[1.], [2.], [3.], [4.], [5.]], dtype=np.float32), "axis": -1, - "test_data": np.array([[1.], [2.], [3.]]), - "expected": np.array([[-1.414214], [-.707107], [0]]), + "test_data": np.array([[1.], [2.], [3.]], np.float32), + "expected": np.array([[-1.414214], [-.707107], [0]], np.float32), "testcase_name": "2d_single_element" }, { "adapt_data": - np.array([[[1., 2., 3.], [2., 3., 4.]], [[3., 4., 5.], [4., 5., - 6.]]]), + np.array([[[1., 2., 3.], [2., 3., 4.]], [[3., 4., 5.], [4., 5., 6.]]], + np.float32), "axis": 1, "test_data": - np.array([[[1., 2., 3.], [2., 3., 4.]], [[3., 4., 5.], [4., 5., - 6.]]]), + np.array([[[1., 2., 3.], [2., 3., 4.]], [[3., 4., 5.], [4., 5., 6.]]], + np.float32), "expected": np.array([[[-1.549193, -0.774597, 0.], [-1.549193, -0.774597, 0.]], - [[0., 0.774597, 1.549193], [0., 0.774597, 1.549193]]]), + [[0., 0.774597, 1.549193], [0., 0.774597, 1.549193]]], + np.float32), "testcase_name": "3d_internal_axis" }, { "adapt_data": - np.array([[[1., 0., 3.], [2., 3., 4.]], [[3., -1., 5.], [4., 5., - 8.]]]), + np.array( + [[[1., 0., 3.], [2., 3., 4.]], [[3., -1., 5.], [4., 5., 8.]]], + np.float32), "axis": (1, 2), "test_data": - np.array([[[3., 1., -1.], [2., 5., 4.]], [[3., 0., 5.], [2., 5., - 8.]]]), + np.array( + [[[3., 1., -1.], [2., 5., 4.]], [[3., 0., 5.], [2., 5., 8.]]], + np.float32), "expected": - np.array([[[1., 3., -5.], [-1., 1., -1.]], - [[1., 1., 1.], [-1., 1., 1.]]]), + np.array( + [[[1., 3., -5.], [-1., 1., -1.]], [[1., 1., 1.], [-1., 1., 1.]]], + np.float32), "testcase_name": "3d_multiple_axis" }) diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py index d3da18e703e..1a7886cf369 100644 --- a/tensorflow/python/keras/layers/wrappers_test.py +++ b/tensorflow/python/keras/layers/wrappers_test.py @@ -253,29 +253,28 @@ class TimeDistributedTest(keras_parameterized.TestCase): self.assertAllEqual(mask_outputs_val[i], ref_mask_val[i]) self.assertIs(mask_outputs[-1], None) # final layer + @tf_test_util.run_in_graph_and_eager_modes def test_TimeDistributed_with_masking_layer(self): - with self.cached_session(): - # test with Masking layer - model = keras.models.Sequential() - model.add(keras.layers.TimeDistributed(keras.layers.Masking( - mask_value=0.,), input_shape=(None, 4))) - model.add(keras.layers.TimeDistributed(keras.layers.Dense(5))) - model.compile(optimizer='rmsprop', loss='mse') - model_input = np.random.randint(low=1, high=5, size=(10, 3, 4)) - for i in range(4): - model_input[i, i:, :] = 0. - model.compile(optimizer='rmsprop', loss='mse') - model.fit(model_input, - np.random.random((10, 3, 5)), epochs=1, batch_size=6) - mask_outputs = [model.layers[0].compute_mask(model.input)] - mask_outputs += [model.layers[1].compute_mask(model.layers[1].input, - mask_outputs[-1])] - func = keras.backend.function([model.input], mask_outputs) - mask_outputs_val = func([model_input]) - self.assertEqual((mask_outputs_val[0]).all(), - model_input.all()) - self.assertEqual((mask_outputs_val[1]).all(), - model_input.all()) + # test with Masking layer + model = keras.models.Sequential() + model.add( + keras.layers.TimeDistributed( + keras.layers.Masking(mask_value=0.,), input_shape=(None, 4))) + model.add(keras.layers.TimeDistributed(keras.layers.Dense(5))) + model.compile(optimizer='rmsprop', loss='mse') + model_input = np.random.randint(low=1, high=5, size=(10, 3, 4)) + for i in range(4): + model_input[i, i:, :] = 0. + model.compile(optimizer='rmsprop', loss='mse') + model.fit(model_input, np.random.random((10, 3, 5)), epochs=1, batch_size=6) + mask_outputs = [model.layers[0].compute_mask(model.input)] + mask_outputs += [ + model.layers[1].compute_mask(model.layers[1].input, mask_outputs[-1]) + ] + func = keras.backend.function([model.input], mask_outputs) + mask_outputs_val = func([model_input]) + self.assertEqual((mask_outputs_val[0]).all(), model_input.all()) + self.assertEqual((mask_outputs_val[1]).all(), model_input.all()) def test_TimeDistributed_with_different_time_shapes(self): time_dist = keras.layers.TimeDistributed(keras.layers.Dense(5)) @@ -574,9 +573,9 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase): output = bidi_rnn(inputs) model = keras.models.Model(inputs, output) - y_1 = model.predict(x) + y_1 = model.predict(x, batch_size=1) model.reset_states() - y_2 = model.predict(x) + y_2 = model.predict(x, batch_size=1) self.assertAllClose(y_1, y_2) diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py index 85731398ea7..061e31140b7 100644 --- a/tensorflow/python/keras/losses.py +++ b/tensorflow/python/keras/losses.py @@ -95,6 +95,17 @@ class Loss(object): # SUM_OVER_BATCH is only allowed in losses managed by `fit` or # CannedEstimators. self._allow_sum_over_batch_size = False + self._set_name_scope() + + def _set_name_scope(self): + """Creates a valid `name_scope` name.""" + if self.name is None: + self._name_scope = self.__class__.__name__ + elif self.name == '': + self._name_scope = 'lambda' + else: + # E.g. '_my_loss' => 'my_loss' + self._name_scope = self.name.strip('_') def __call__(self, y_true, y_pred, sample_weight=None): """Invokes the `Loss` instance. @@ -124,10 +135,9 @@ class Loss(object): """ # If we are wrapping a lambda function strip '<>' from the name as it is not # accepted in scope name. - scope_name = 'lambda' if self.name == '' else self.name graph_ctx = tf_utils.graph_context_for_symbolic_tensors( y_true, y_pred, sample_weight) - with K.name_scope(scope_name or self.__class__.__name__), graph_ctx: + with K.name_scope(self._name_scope), graph_ctx: losses = self.call(y_true, y_pred) return losses_utils.compute_weighted_loss( losses, sample_weight, reduction=self._get_reduction()) diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index bd0a8605135..1c851581a05 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -63,6 +63,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import variables as tf_variables from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.ops.losses import util as tf_losses_utils +from tensorflow.python.util import nest from tensorflow.python.util.tf_export import keras_export from tensorflow.tools.docs import doc_controls @@ -3220,11 +3221,7 @@ def clone_metric(metric): def clone_metrics(metrics): """Clones the given metric list/dict.""" - if metrics is None: - return None - if isinstance(metrics, dict): - return {key: clone_metric(value) for key, value in metrics.items()} - return [clone_metric(metric) for metric in metrics] + return nest.map_structure(clone_metric, metrics) @keras_export('keras.metrics.serialize') @@ -3243,6 +3240,7 @@ def deserialize(config, custom_objects=None): @keras_export('keras.metrics.get') def get(identifier): + """Return a metric given its identifer.""" if isinstance(identifier, dict): return deserialize(identifier) elif isinstance(identifier, six.string_types): @@ -3250,5 +3248,6 @@ def get(identifier): elif callable(identifier): return identifier else: - raise ValueError('Could not interpret ' - 'metric function identifier: %s' % identifier) + error_msg = 'Could not interpret metric function identifier: {}'.format( + identifier) + raise ValueError(error_msg) diff --git a/tensorflow/python/keras/metrics_correctness_test.py b/tensorflow/python/keras/metrics_correctness_test.py index f372996141b..ea4222b6935 100644 --- a/tensorflow/python/keras/metrics_correctness_test.py +++ b/tensorflow/python/keras/metrics_correctness_test.py @@ -21,7 +21,6 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np -from tensorflow.python import tf2 from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import layers from tensorflow.python.keras import losses @@ -29,6 +28,7 @@ from tensorflow.python.keras import metrics from tensorflow.python.keras import testing_utils from tensorflow.python.ops.losses import loss_reduction from tensorflow.python.platform import test +from tensorflow.python.util import nest def get_multi_io_model(): @@ -51,13 +51,6 @@ def custom_generator_multi_io(sample_weights=None): inputs = np.asarray([[1.], [2.], [3.], [4.]]) targets_1 = np.asarray([[2.], [4.], [6.], [8.]]) targets_2 = np.asarray([[1.], [2.], [3.], [4.]]) - if sample_weights: - assert len(sample_weights) == 2 - w1 = sample_weights[0] - w2 = sample_weights[1] - else: - w1 = None - w2 = None i = 0 while True: batch_index = i * batch_size % num_samples @@ -67,17 +60,14 @@ def custom_generator_multi_io(sample_weights=None): x = [inputs[start:end], inputs[start:end]] y = [targets_1[start:end], targets_2[start:end]] if sample_weights: - w = [ - None if w1 is None else w1[start:end], - None if w2 is None else w2[start:end] - ] + sw = nest.map_structure(lambda w: w[start:end], sample_weights) else: - w = None - yield x, y, w + sw = None + yield x, y, sw @keras_parameterized.run_with_all_model_types(exclude_models=['sequential']) -@keras_parameterized.run_all_keras_modes +@keras_parameterized.run_all_keras_modes(always_skip_v1=True) class TestMetricsCorrectnessMultiIO(keras_parameterized.TestCase): def _get_compiled_multi_io_model(self): @@ -100,8 +90,6 @@ class TestMetricsCorrectnessMultiIO(keras_parameterized.TestCase): self.y2 = np.asarray([[1.], [2.], [3.], [4.]]) self.sample_weight_1 = np.asarray([2., 3., 4., 5.]) self.sample_weight_2 = np.asarray([3.5, 2.5, 1.5, 0.5]) - self.class_weight_1 = {2: 2, 4: 3, 6: 4, 8: 5} - self.class_weight_2 = {1: 3.5, 2: 2.5, 3: 1.5, 4: 0.5} # y_true_1 = [[2.], [4.], [6.], [8.]], y_pred = [[3.], [6.], [9.], [12.]] # y_true_2 = [[1.], [2.], [3.], [4.]], y_pred = [[3.], [6.], [9.], [12.]] @@ -148,8 +136,6 @@ class TestMetricsCorrectnessMultiIO(keras_parameterized.TestCase): # Total loss without weights = 7.5 + 30 = 37.5 self.wmse = 'mean_squared_error_2' - if not tf2.enabled(): - self.wmse = 'weighted_' + self.wmse self.expected_fit_result_with_weights = { 'output_1_mean_squared_error': [7.5, 7.5], 'output_2_mean_squared_error': [30, 30], @@ -223,29 +209,6 @@ class TestMetricsCorrectnessMultiIO(keras_parameterized.TestCase): for key, value in self.expected_fit_result_with_weights_output_2.items(): self.assertAllClose(history.history[key], value, 1e-3) - def test_fit_with_class_weight(self): - model = self._get_compiled_multi_io_model() - history = model.fit([self.x, self.x], [self.y1, self.y2], - class_weight={ - 'output_1': self.class_weight_1, - 'output_2': self.class_weight_2, - }, - batch_size=2, - epochs=2, - shuffle=False) - for key, value in self.expected_fit_result_with_weights.items(): - self.assertAllClose(history.history[key], value, 1e-3) - - # Set weights for one output. - history = model.fit([self.x, self.x], [self.y1, self.y2], - class_weight={'output_2': self.class_weight_2}, - batch_size=2, - epochs=2, - shuffle=False) - - for key, value in self.expected_fit_result_with_weights_output_2.items(): - self.assertAllClose(history.history[key], value, 1e-3) - def test_eval(self): model = self._get_compiled_multi_io_model() eval_result = model.evaluate([self.x, self.x], [self.y1, self.y2], @@ -304,23 +267,6 @@ class TestMetricsCorrectnessMultiIO(keras_parameterized.TestCase): self.assertAllClose(result, self.expected_batch_result_with_weights_output_2, 1e-3) - def test_train_on_batch_with_class_weight(self): - model = self._get_compiled_multi_io_model() - result = model.train_on_batch([self.x, self.x], [self.y1, self.y2], - class_weight={ - 'output_1': self.class_weight_1, - 'output_2': self.class_weight_2, - }) - self.assertAllClose(result, self.expected_batch_result_with_weights, 1e-3) - - # Set weights for one output. - result = model.train_on_batch([self.x, self.x], [self.y1, self.y2], - class_weight={ - 'output_2': self.class_weight_2, - }) - self.assertAllClose(result, - self.expected_batch_result_with_weights_output_2, 1e-3) - def test_test_on_batch(self): model = self._get_compiled_multi_io_model() result = model.test_on_batch([self.x, self.x], [self.y1, self.y2]) @@ -362,29 +308,8 @@ class TestMetricsCorrectnessMultiIO(keras_parameterized.TestCase): # Set weights for one output. history = model.fit_generator( - custom_generator_multi_io(sample_weights=[None, self.sample_weight_2]), - steps_per_epoch=2, - epochs=2) - for key, value in self.expected_fit_result_with_weights_output_2.items(): - self.assertAllClose(history.history[key], value, 1e-3) - - def test_fit_generator_with_class_weight(self): - model = self._get_compiled_multi_io_model() - history = model.fit_generator( - custom_generator_multi_io(), - class_weight={ - 'output_1': self.class_weight_1, - 'output_2': self.class_weight_2, - }, - steps_per_epoch=2, - epochs=2) - for key, value in self.expected_fit_result_with_weights.items(): - self.assertAllClose(history.history[key], value, 1e-3) - - # Set weights for one output. - history = model.fit_generator( - custom_generator_multi_io(), - class_weight={'output_2': self.class_weight_2}, + custom_generator_multi_io( + sample_weights={'output_2': self.sample_weight_2}), steps_per_epoch=2, epochs=2) for key, value in self.expected_fit_result_with_weights_output_2.items(): @@ -406,14 +331,15 @@ class TestMetricsCorrectnessMultiIO(keras_parameterized.TestCase): # Set weights for one output. eval_result = model.evaluate_generator( - custom_generator_multi_io(sample_weights=[None, self.sample_weight_2]), + custom_generator_multi_io( + sample_weights={'output_2': self.sample_weight_2}), steps=2) self.assertAllClose(eval_result, self.expected_batch_result_with_weights_output_2, 1e-3) @keras_parameterized.run_with_all_model_types -@keras_parameterized.run_all_keras_modes +@keras_parameterized.run_all_keras_modes(always_skip_v1=True) class TestMetricsCorrectnessSingleIO(keras_parameterized.TestCase): def _get_model(self): @@ -452,7 +378,8 @@ class TestMetricsCorrectnessSingleIO(keras_parameterized.TestCase): self.x = np.asarray([[1.], [2.], [3.], [4.]]) self.y = np.asarray([[2.], [4.], [6.], [8.]]) self.sample_weight = np.asarray([2., 3., 4., 5.]) - self.class_weight = {2: 2, 4: 3, 6: 4, 8: 5} + self.class_weight = {i: 1 for i in range(10)} + self.class_weight.update({2: 2, 4: 3, 6: 4, 8: 5}) # y_true = [[2.], [4.], [6.], [8.]], y_pred = [[3.], [6.], [9.], [12.]] @@ -483,8 +410,6 @@ class TestMetricsCorrectnessSingleIO(keras_parameterized.TestCase): # Result = 7.5 wmse = 'mean_squared_error_2' - if not tf2.enabled(): - wmse = 'weighted_' + wmse self.expected_fit_result_with_weights = { 'mean_squared_error': [7.5, 7.5], diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py index 7620f2f072e..0b0121f521e 100644 --- a/tensorflow/python/keras/models.py +++ b/tensorflow/python/keras/models.py @@ -552,6 +552,8 @@ def _reset_build_compile_trackers(model): model.outputs = None # Reset compile state model._is_compiled = False # pylint:disable=protected-access + if not ops.executing_eagerly_outside_functions(): + model._v1_compile_was_called = False model.optimizer = None @@ -639,20 +641,23 @@ def clone_and_build_model( 'Error when cloning model: compile_clone was set to True, but the ' 'original model has not been compiled.') - with CustomObjectScope(custom_objects or {}): - if model._is_graph_network or isinstance(model, Sequential): - clone = clone_model(model, input_tensors=input_tensors) + if compile_clone: + compile_args = model._get_compile_args() # pylint: disable=protected-access + # Allows this method to be robust to switching graph and eager classes. + model._get_compile_args = lambda: compile_args - if all([ - isinstance(clone, Sequential), not clone._is_graph_network, - getattr(model, '_build_input_shape', None) is not None - ]): - # Set model inputs to build the model and add input/output properties. - # TODO(kathywu): Add multiple placeholders to handle edge case where - # sequential model has multiple inputs. - clone._set_inputs( - K.placeholder( - model._build_input_shape, dtype=model.inputs[0].dtype)) + with CustomObjectScope(custom_objects or {}): + if model._is_graph_network: + clone = clone_model(model, input_tensors=input_tensors) + elif isinstance(model, Sequential): + clone = clone_model(model, input_tensors=input_tensors) + if (not clone._is_graph_network and model._build_input_shape is not None): + if ops.executing_eagerly_outside_functions(): + clone.build(model._build_input_shape) + else: + clone._set_inputs( + K.placeholder( + model._build_input_shape, dtype=model.inputs[0].dtype)) else: try: # Prefer clonining the model if serial/deserial logic is implemented for @@ -704,14 +709,15 @@ def clone_and_build_model( if len(optimizer) == 1: optimizer = optimizer[0] - clone.compile( - optimizer, - model.loss, - metrics=metrics_module.clone_metrics(model._compile_metrics), - loss_weights=model.loss_weights, - sample_weight_mode=model.sample_weight_mode, - weighted_metrics=metrics_module.clone_metrics( - model._compile_weighted_metrics), - target_tensors=target_tensors) + + compile_args['optimizer'] = optimizer + if target_tensors is not None: + compile_args['target_tensors'] = target_tensors + # Ensure Metric objects in new model are separate from existing model. + compile_args['metrics'] = metrics_module.clone_metrics( + compile_args['metrics']) + compile_args['weighted_metrics'] = metrics_module.clone_metrics( + compile_args['weighted_metrics']) + clone.compile(**compile_args) return clone diff --git a/tensorflow/python/keras/models_test.py b/tensorflow/python/keras/models_test.py index 3f9289b1021..8120afa0a55 100644 --- a/tensorflow/python/keras/models_test.py +++ b/tensorflow/python/keras/models_test.py @@ -412,8 +412,6 @@ class TestCloneAndBuildModel(keras_parameterized.TestCase): isinstance(model.optimizer, (keras.optimizers.RMSprop, keras.optimizer_v2.rmsprop.RMSprop))) - self.assertEqual(['acc', metrics.categorical_accuracy], - model._compile_metrics) def _clone_and_build_test_helper(self, model, model_type): inp = np.random.random((10, 4)) @@ -500,15 +498,13 @@ class TestCloneAndBuildModel(keras_parameterized.TestCase): @keras_parameterized.run_with_all_model_types @keras_parameterized.run_all_keras_modes def test_replace_tf_optimizer_iterations_variable(self): + if context.executing_eagerly(): + self.skipTest('v1 optimizers not supported with eager.') self.assert_optimizer_iterations_increases(adam.AdamOptimizer(0.01)) @keras_parameterized.run_with_all_model_types @keras_parameterized.run_all_keras_modes def test_replace_keras_optimizer_iterations_variable(self): - if testing_utils.should_run_eagerly(): - # This needs to be updated to run with v2 optimizers. - self.skipTest('b/120991591') - self.assert_optimizer_iterations_increases('adam') def test_clone_optimizer_in_different_graph(self): diff --git a/tensorflow/python/keras/premade/linear.py b/tensorflow/python/keras/premade/linear.py index dd3e1fdfaeb..32300421afa 100644 --- a/tensorflow/python/keras/premade/linear.py +++ b/tensorflow/python/keras/premade/linear.py @@ -97,7 +97,7 @@ class LinearModel(training.Model): def build(self, input_shape): self.dense_layers = [] - if isinstance(input_shape, list): + if isinstance(input_shape, (tuple, list)): for shape in input_shape: layer = core.Dense( units=self.units, diff --git a/tensorflow/python/keras/premade/wide_deep.py b/tensorflow/python/keras/premade/wide_deep.py index ba524367bc6..2f339786c67 100644 --- a/tensorflow/python/keras/premade/wide_deep.py +++ b/tensorflow/python/keras/premade/wide_deep.py @@ -18,10 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.eager import backprop +from tensorflow.python.framework import ops from tensorflow.python.keras import activations from tensorflow.python.keras import backend as K from tensorflow.python.keras import layers as layer_module from tensorflow.python.keras.engine import base_layer +from tensorflow.python.keras.engine import data_adapter from tensorflow.python.keras.engine import training as keras_training from tensorflow.python.keras.utils import generic_utils from tensorflow.python.util import nest @@ -106,25 +109,38 @@ class WideDeepModel(keras_training.Model): return nest.map_structure(self.activation, output) return output - def _get_optimizers(self): - if isinstance(self.optimizer, (tuple, list)): - return (self.optimizer[0], self.optimizer[1]) - else: - return (self.optimizer, self.optimizer) - # This does not support gradient scaling and LossScaleOptimizer. - def _backwards(self, tape, loss): - linear_vars = self.linear_model.trainable_weights # pylint: disable=protected-access - dnn_vars = self.dnn_model.trainable_weights # pylint: disable=protected-access - linear_grads, dnn_grads = tape.gradient(loss, (linear_vars, dnn_vars)) - linear_optimizer, dnn_optimizer = self._get_optimizers() - linear_optimizer.apply_gradients(zip(linear_grads, linear_vars)) - dnn_optimizer.apply_gradients(zip(dnn_grads, dnn_vars)) - return + def _train_step(self, data): + x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data) + x, y, sample_weight = data_adapter.expand_1d((x, y, sample_weight)) + + with backprop.GradientTape() as tape: + y_pred = self(x, training=True) + loss = self.compiled_loss( + y, y_pred, sample_weight, regularization_losses=self.losses) + self.compiled_metrics.update_state(y, y_pred, sample_weight) + + if isinstance(self.optimizer, (list, tuple)): + linear_vars = self.linear_model.trainable_variables + dnn_vars = self.dnn_model.trainable_variables + linear_grads, dnn_grads = tape.gradient(loss, (linear_vars, dnn_vars)) + + linear_optimizer = self.optimizer[0] + dnn_optimizer = self.optimizer[1] + linear_optimizer.apply_gradients(zip(linear_grads, linear_vars)) + dnn_optimizer.apply_gradients(zip(dnn_grads, dnn_vars)) + else: + trainable_variables = self.trainable_variables + grads = tape.gradient(loss, trainable_variables) + self.optimizer.apply_gradients(zip(grads, trainable_variables)) + + return {m.name: m.result() for m in self.metrics} def _make_train_function(self): - # TODO(tanzheny): This is a direct copy from super to make it work - # refactor it so that common logic can be shared. + if ops.executing_eagerly_outside_functions(): + return super(WideDeepModel, self)._make_train_function() + + # Only needed for graph mode and model_to_estimator. has_recompiled = self._recompile_weights_loss_and_weighted_metrics() self._check_trainable_weights_consistency() # If we have re-compiled the loss/weighted metric sub-graphs then create @@ -140,7 +156,13 @@ class WideDeepModel(keras_training.Model): if not isinstance(K.symbolic_learning_phase(), int): inputs += [K.symbolic_learning_phase()] - linear_optimizer, dnn_optimizer = self._get_optimizers() + if isinstance(self.optimizer, (list, tuple)): + linear_optimizer = self.optimizer[0] + dnn_optimizer = self.optimizer[1] + else: + linear_optimizer = self.optimizer + dnn_optimizer = self.optimizer + with K.get_graph().as_default(): with K.name_scope('training'): # Training updates diff --git a/tensorflow/python/keras/premade/wide_deep_test.py b/tensorflow/python/keras/premade/wide_deep_test.py index e2f471e3575..3b58984bd11 100644 --- a/tensorflow/python/keras/premade/wide_deep_test.py +++ b/tensorflow/python/keras/premade/wide_deep_test.py @@ -258,8 +258,6 @@ class WideDeepModelTest(keras_parameterized.TestCase): run_eagerly=testing_utils.should_run_eagerly(), experimental_run_tf_function=testing_utils.should_run_tf_function()) wide_deep_model.fit(x={'symbol': data}, y=y, batch_size=32, epochs=10) - self.assertEqual(3, linear_model.inputs[0].shape[1]) - self.assertEqual(5, dnn_model.inputs[0].shape[1]) def test_config(self): linear_model = linear.LinearModel(units=1) diff --git a/tensorflow/python/keras/saving/hdf5_format_test.py b/tensorflow/python/keras/saving/hdf5_format_test.py index 66a712c4f2e..6c94ed50517 100644 --- a/tensorflow/python/keras/saving/hdf5_format_test.py +++ b/tensorflow/python/keras/saving/hdf5_format_test.py @@ -818,19 +818,23 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase): evaluation_results['sparse_categorical_crossentropy'] + evaluation_results['custom_loss'], evaluation_results['loss'], 1e-6) + @test_util.run_in_graph_and_eager_modes def test_save_uncompiled_model_with_optimizer(self): - saved_model_dir = self._save_model_dir() - save_format = testing_utils.get_save_format() - model = keras.models.Sequential([keras.layers.Dense(1, input_shape=(3,))]) - # Set the model's optimizer but don't compile. This can happen if the model - # is trained with a custom training loop. - model.optimizer = keras.optimizer_v2.rmsprop.RMSprop(lr=0.0001) - model.save(saved_model_dir, save_format=save_format) + with self.cached_session() as session: + saved_model_dir = self._save_model_dir() + save_format = testing_utils.get_save_format() + model = keras.models.Sequential([keras.layers.Dense(1, input_shape=(3,))]) + # Set the model's optimizer but don't compile. This can happen if the + # model is trained with a custom training loop. + model.optimizer = keras.optimizer_v2.rmsprop.RMSprop(lr=0.0001) + if not context.executing_eagerly(): + session.run([v.initializer for v in model.variables]) + model.save(saved_model_dir, save_format=save_format) - if save_format in ['tf', 'tensorflow']: - loaded = keras.models.load_model(saved_model_dir) - self.assertIsInstance(loaded.optimizer, - keras.optimizer_v2.optimizer_v2.OptimizerV2) + if save_format in ['tf', 'tensorflow']: + loaded = keras.models.load_model(saved_model_dir) + self.assertIsInstance(loaded.optimizer, + keras.optimizer_v2.optimizer_v2.OptimizerV2) # Factory functions to create models that will be serialized inside a Network. diff --git a/tensorflow/python/keras/saving/losses_serialization_test.py b/tensorflow/python/keras/saving/losses_serialization_test.py index 60252b1dbf4..8bdcc2a794d 100644 --- a/tensorflow/python/keras/saving/losses_serialization_test.py +++ b/tensorflow/python/keras/saving/losses_serialization_test.py @@ -48,11 +48,11 @@ class MyMeanAbsoluteError(losses.LossFunctionWrapper): reduction=losses_utils.ReductionV2.AUTO, name='mean_absolute_error'): super(MyMeanAbsoluteError, self).__init__( - _my_mae, name=name, reduction=reduction) + my_mae, name=name, reduction=reduction) # Custom loss function -def _my_mae(y_true, y_pred): +def my_mae(y_true, y_pred): return keras.backend.mean(math_ops.abs(y_pred - y_true), axis=-1) @@ -70,7 +70,7 @@ def _get_multi_io_model(): dict(testcase_name='string', value='mae'), dict(testcase_name='built_in_fn', value=losses.mae), dict(testcase_name='built_in_class', value=losses.MeanAbsoluteError()), - dict(testcase_name='custom_fn', value=_my_mae), + dict(testcase_name='custom_fn', value=my_mae), dict(testcase_name='custom_class', value=MyMeanAbsoluteError()), dict(testcase_name='list_of_strings', value=['mae', 'mae']), dict(testcase_name='list_of_built_in_fns', value=[losses.mae, losses.mae]), @@ -78,7 +78,7 @@ def _get_multi_io_model(): testcase_name='list_of_built_in_classes', value=[losses.MeanAbsoluteError(), losses.MeanAbsoluteError()]), - dict(testcase_name='list_of_custom_fns', value=[_my_mae, _my_mae]), + dict(testcase_name='list_of_custom_fns', value=[my_mae, my_mae]), dict( testcase_name='list_of_custom_classes', value=[MyMeanAbsoluteError(), @@ -104,8 +104,8 @@ def _get_multi_io_model(): dict( testcase_name='dict_of_custom_fn', value={ - 'output': _my_mae, - 'output_1': _my_mae + 'output': my_mae, + 'output_1': my_mae }), dict( testcase_name='dict_of_custom_class', @@ -128,7 +128,7 @@ class LossesSerialization(keras_parameterized.TestCase): def test_serializing_model_with_loss_with_custom_object_scope(self, value): with generic_utils.custom_object_scope({ 'MyMeanAbsoluteError': MyMeanAbsoluteError, - '_my_mae': _my_mae, + 'my_mae': my_mae, 'Bias': testing_utils.Bias, }): model = _get_multi_io_model() @@ -182,7 +182,7 @@ class LossesSerialization(keras_parameterized.TestCase): self.model_filename, custom_objects={ 'MyMeanAbsoluteError': MyMeanAbsoluteError, - '_my_mae': _my_mae, + 'my_mae': my_mae, 'Bias': testing_utils.Bias, }) loaded_model.predict([self.x, self.x]) diff --git a/tensorflow/python/keras/saving/metrics_serialization_test.py b/tensorflow/python/keras/saving/metrics_serialization_test.py index 10eee4d4175..7ecc2e5b087 100644 --- a/tensorflow/python/keras/saving/metrics_serialization_test.py +++ b/tensorflow/python/keras/saving/metrics_serialization_test.py @@ -69,17 +69,6 @@ def _get_multi_io_model(): dict(testcase_name='built_in_class', value=[metrics.MeanAbsoluteError]), dict(testcase_name='custom_fn', value=[_my_mae]), dict(testcase_name='custom_class', value=[MyMeanAbsoluteError]), - dict(testcase_name='list_of_strings', value=['mae', 'mae']), - dict( - testcase_name='list_of_built_in_fns', value=[metrics.mae, metrics.mae]), - dict( - testcase_name='list_of_built_in_classes', - value=[metrics.MeanAbsoluteError, metrics.MeanAbsoluteError]), - dict(testcase_name='list_of_custom_fns', value=[_my_mae, _my_mae]), - dict( - testcase_name='list_of_custom_classes', - value=[MyMeanAbsoluteError, MyMeanAbsoluteError]), - dict(testcase_name='list_of_string_and_list', value=['mae', ['mae']]), dict( testcase_name='list_of_built_in_fn_and_list', value=[metrics.mae, [metrics.mae]]), diff --git a/tensorflow/python/keras/saving/saved_model/load.py b/tensorflow/python/keras/saving/saved_model/load.py index 0aac128eb43..d53530ec1d7 100644 --- a/tensorflow/python/keras/saving/saved_model/load.py +++ b/tensorflow/python/keras/saving/saved_model/load.py @@ -445,8 +445,11 @@ class KerasObjectLoader(tf_load.Loader): model.__init__(layers, name=config['name']) if not model.inputs: first_layer = self._get_child_layer_node_ids(model_id, model.name)[0] - input_shape = self._infer_inputs(first_layer) - model._set_inputs(input_shape) # pylint: disable=protected-access + input_specs = self._infer_inputs(first_layer) + input_shapes = self._infer_inputs(first_layer, convert_to_shapes=True) + model._set_inputs(input_specs) # pylint: disable=protected-access + if not model.built and not isinstance(input_specs, dict): + model.build(input_shapes) else: (inputs, outputs, created_layers) = network_lib.reconstruct_from_config( config, created_layers={layer.name: layer for layer in layers}) diff --git a/tensorflow/python/keras/saving/saved_model/revive_test.py b/tensorflow/python/keras/saving/saved_model/revive_test.py index 36140e7fe20..3e267340caa 100644 --- a/tensorflow/python/keras/saving/saved_model/revive_test.py +++ b/tensorflow/python/keras/saving/saved_model/revive_test.py @@ -32,7 +32,6 @@ import numpy as np from tensorflow.python import keras from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec from tensorflow.python.keras import backend from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils @@ -121,12 +120,17 @@ class TestModelRevive(keras_parameterized.TestCase): def _assert_revived_correctness(self, model, revived): self.assertAllEqual(model.input_names, revived.input_names) self.assertAllEqual(model.output_names, revived.output_names) - self.assertTrue(all([ - i.shape.as_list() == r.shape.as_list() and i.dtype == r.dtype - for (i, r) in zip(model.inputs, revived.inputs)])) - self.assertTrue(all([ - i.shape.as_list() == r.shape.as_list() and i.dtype == r.dtype - for (i, r) in zip(model.outputs, revived.outputs)])) + if model.inputs is not None: + self.assertTrue( + all([ + i.shape.as_list() == r.shape.as_list() and i.dtype == r.dtype + for (i, r) in zip(model.inputs, revived.inputs) + ])) + self.assertTrue( + all([ + i.shape.as_list() == r.shape.as_list() and i.dtype == r.dtype + for (i, r) in zip(model.outputs, revived.outputs) + ])) self.assertAllClose(self.evaluate(model.weights), self.evaluate(revived.weights)) @@ -205,9 +209,8 @@ class TestModelRevive(keras_parameterized.TestCase): model = testing_utils.get_model_from_layers( layers, input_shape=input_shape) - # The inputs attribute must be defined in order to save the model. - if not model.inputs: - model._set_inputs(tensor_spec.TensorSpec((None, 2, 3))) + # Run data through the Model to create save spec and weights. + model.predict(np.ones((10, 2, 3)), batch_size=10) # Test that the correct checkpointed values are loaded, whether the layer is # created from the config or SavedModel. @@ -220,7 +223,8 @@ class TestModelRevive(keras_parameterized.TestCase): def test_revive_subclassed_with_nested_model(self): model = SubclassedModelNoConfig(1., 2.) - model._set_inputs(tensor_spec.TensorSpec((None, 2, 3))) + # Run data through the Model to create save spec and weights. + model.predict(np.ones((10, 2, 3)), batch_size=10) model.save(self.path, save_format='tf') revived = keras_load.load(self.path) self._assert_revived_correctness(model, revived) diff --git a/tensorflow/python/keras/saving/saved_model/save_impl.py b/tensorflow/python/keras/saving/saved_model/save_impl.py index 3fcc649cba5..7bd2b52fe84 100644 --- a/tensorflow/python/keras/saving/saved_model/save_impl.py +++ b/tensorflow/python/keras/saving/saved_model/save_impl.py @@ -67,28 +67,13 @@ sequential_lib = LazyLoader( def should_skip_serialization(layer): """Skip serializing extra objects and functions if layer inputs aren't set.""" - if isinstance(layer, training_lib.Model): - try: - # pylint:disable=pointless-statement - layer.inputs - layer.input_names - # pylint:enable=pointless-statement - except AttributeError: - # If the model does not have inputs set, because it was not called or its - # input shapes were not recorded, we won't have a signature so can't trace - # a function. But the user may still save an object with this Model - # attached; we won't fail the whole tf.saved_model.save. - logging.warning('Skipping full serialization of Keras model {}, because ' - 'its inputs are not defined.'.format(layer)) - return True - else: - return False - else: - if not layer.built: - logging.warning('Skipping full serialization of Keras layer {}, because ' - 'it is not built.'.format(layer)) - return True - return False + saved_model_input_spec_set = (isinstance(layer, training_lib.Model) and + layer._saved_model_inputs_spec is not None) # pylint: disable=protected-access + if not layer.built and not saved_model_input_spec_set: + logging.warning('Skipping full serialization of Keras layer {}, because ' + 'it is not built.'.format(layer)) + return True + return False def wrap_layer_objects(layer, serialization_cache): diff --git a/tensorflow/python/keras/saving/saved_model/saved_model_test.py b/tensorflow/python/keras/saving/saved_model/saved_model_test.py index 018edc030e7..da86a7cdac1 100644 --- a/tensorflow/python/keras/saving/saved_model/saved_model_test.py +++ b/tensorflow/python/keras/saving/saved_model/saved_model_test.py @@ -85,17 +85,23 @@ class LayerWithLoss(keras.layers.Layer): def call(self, inputs): self.add_loss(math_ops.reduce_sum(inputs), inputs) - return inputs + return inputs * 2 class LayerWithUpdate(keras.layers.Layer): def build(self, _): - self.v = self.add_weight('v', shape=[], dtype=dtypes.int32) + self.v = self.add_weight( + 'v', + shape=[], + initializer=keras.initializers.zeros, + trainable=False, + dtype=dtypes.float32) - def call(self, inputs): - self.add_update(self.v.assign_add(math_ops.reduce_sum(inputs))) - return inputs + def call(self, inputs, training=True): + if training: + self.add_update(self.v.assign_add(1.)) + return inputs * 2. @keras_parameterized.run_all_keras_modes @@ -249,7 +255,7 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase): model.add_loss(eager_loss) # Call predict to ensure that all layers are built and inputs are set. - model.predict(np.random.random((1, 3))) + model.predict(np.random.random((1, 3)).astype(np.float32)) saved_model_dir = self._save_model_dir() tf_save.save(model, saved_model_dir) @@ -608,13 +614,13 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase): def _testAddUpdate(self, scope): with scope: - layer_with_update = LayerWithUpdate(dtype=dtypes.int32) + layer_with_update = LayerWithUpdate() model = testing_utils.get_model_from_layers([layer_with_update], - input_shape=(3,), - input_dtype=dtypes.int32) + input_shape=(3,)) + x = np.ones((10, 3)) if testing_utils.get_model_type() == 'subclass': - model._set_inputs(constant_op.constant([[1, 2, 3]], dtype=dtypes.int32)) + model.predict(x, batch_size=10) self.evaluate(variables.variables_initializer(model.variables)) saved_model_dir = self._save_model_dir() model.save(saved_model_dir, save_format='tf') @@ -622,11 +628,11 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase): loaded = keras_load.load(saved_model_dir) loaded_layer = loaded.layers[-1] self.evaluate(variables.variables_initializer(loaded.variables)) - self.assertEqual(self.evaluate(loaded_layer.v), 0) + self.assertEqual(self.evaluate(loaded_layer.v), 0.) - loaded.predict(constant_op.constant([[1, 2, 3]], dtype=dtypes.int32), - steps=1) - self.assertEqual(self.evaluate(loaded_layer.v), 6) + loaded.compile('sgd', 'mse') + loaded.fit(x, x, batch_size=10) + self.assertEqual(self.evaluate(loaded_layer.v), 1.) @keras_parameterized.run_with_all_model_types def testSaveLayerWithUpdates(self): diff --git a/tensorflow/python/keras/saving/saved_model_experimental_test.py b/tensorflow/python/keras/saving/saved_model_experimental_test.py index 11a3ff5e1ab..2f3cf7cf9c9 100644 --- a/tensorflow/python/keras/saving/saved_model_experimental_test.py +++ b/tensorflow/python/keras/saving/saved_model_experimental_test.py @@ -32,8 +32,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util -from tensorflow.python.keras import keras_parameterized -from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import training as model_lib from tensorflow.python.keras.optimizer_v2 import adadelta from tensorflow.python.keras.optimizer_v2 import rmsprop @@ -47,7 +45,7 @@ from tensorflow.python.saved_model import model_utils from tensorflow.python.training import training as training_module -@keras_parameterized.run_all_keras_modes() +@test_util.run_deprecated_v1 # Removed in v2. class TestModelSavingandLoading(parameterized.TestCase, test.TestCase): def _save_model_dir(self, dirname='saved_model'): @@ -65,9 +63,7 @@ class TestModelSavingandLoading(parameterized.TestCase, test.TestCase): loss=keras.losses.MSE, optimizer=rmsprop.RMSprop(lr=0.0001), metrics=[keras.metrics.categorical_accuracy], - sample_weight_mode='temporal', - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) + sample_weight_mode='temporal') x = np.random.random((1, 3)) y = np.random.random((1, 3, 3)) model.train_on_batch(x, y) @@ -81,7 +77,6 @@ class TestModelSavingandLoading(parameterized.TestCase, test.TestCase): y = loaded_model.predict(x) self.assertAllClose(ref_y, y, atol=1e-05) - @test_util.run_in_graph_and_eager_modes def test_saving_sequential_model_without_compile(self): with self.cached_session(): model = keras.models.Sequential() @@ -109,9 +104,7 @@ class TestModelSavingandLoading(parameterized.TestCase, test.TestCase): model.compile( loss=keras.losses.MSE, optimizer=rmsprop.RMSprop(lr=0.0001), - metrics=[keras.metrics.categorical_accuracy], - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) + metrics=[keras.metrics.categorical_accuracy]) x = np.random.random((1, 3)) y = np.random.random((1, 3)) model.train_on_batch(x, y) @@ -125,7 +118,6 @@ class TestModelSavingandLoading(parameterized.TestCase, test.TestCase): y = loaded_model.predict(x) self.assertAllClose(ref_y, y, atol=1e-05) - @test_util.run_in_graph_and_eager_modes def test_saving_functional_model_without_compile(self): with self.cached_session(): inputs = keras.layers.Input(shape=(3,)) @@ -146,7 +138,6 @@ class TestModelSavingandLoading(parameterized.TestCase, test.TestCase): y = loaded_model.predict(x) self.assertAllClose(ref_y, y, atol=1e-05) - @test_util.run_in_graph_and_eager_modes def test_saving_with_tf_optimizer(self): model = keras.models.Sequential() model.add(keras.layers.Dense(2, input_shape=(3,))) @@ -167,9 +158,7 @@ class TestModelSavingandLoading(parameterized.TestCase, test.TestCase): loaded_model.compile( loss='mse', optimizer=training_module.RMSPropOptimizer(0.1), - metrics=['acc'], - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) + metrics=['acc']) y = loaded_model.predict(x) self.assertAllClose(ref_y, y, atol=1e-05) @@ -290,7 +279,7 @@ def load_model(sess, path, mode): return inputs, outputs, meta_graph_def -@test_util.run_all_in_graph_and_eager_modes +@test_util.run_deprecated_v1 # Removed in v2. class TestModelSavedModelExport(test.TestCase, parameterized.TestCase): def _save_model_dir(self, dirname='saved_model'): diff --git a/tensorflow/python/keras/saving/saving_utils.py b/tensorflow/python/keras/saving/saving_utils.py index fe8d26485b9..9a82f69a2fd 100644 --- a/tensorflow/python/keras/saving/saving_utils.py +++ b/tensorflow/python/keras/saving/saving_utils.py @@ -19,13 +19,14 @@ from __future__ import print_function import collections import os +import six from tensorflow.python.eager import def_function -from tensorflow.python.framework import tensor_spec from tensorflow.python.keras import backend as K from tensorflow.python.keras import losses from tensorflow.python.keras import optimizers from tensorflow.python.keras.engine import base_layer_utils +from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest @@ -43,13 +44,12 @@ def extract_model_metrics(model): Dictionary mapping metric names to metric instances. May return `None` if the model does not contain any metrics. """ - if not getattr(model, '_compile_metrics', None): - return None - - # TODO(psv/kathywu): use this implementation in model to estimator flow. - # We are not using model.metrics here because we want to exclude the metrics - # added using `add_metric` API. - return {m.name: m for m in model._compile_metric_functions} # pylint: disable=protected-access + if getattr(model, '_compile_metrics', None): + # TODO(psv/kathywu): use this implementation in model to estimator flow. + # We are not using model.metrics here because we want to exclude the metrics + # added using `add_metric` API. + return {m.name: m for m in model._compile_metric_functions} # pylint: disable=protected-access + return None def model_input_signature(model, keep_original_batch_size=False): @@ -73,29 +73,9 @@ def model_input_signature(model, keep_original_batch_size=False): A list containing either a single TensorSpec or an object with nested TensorSpecs. This list does not contain the `training` argument. """ - try: - inputs = model.inputs - input_names = model.input_names - except AttributeError: + input_specs = model._get_save_spec(dynamic_batch=not keep_original_batch_size) # pylint: disable=protected-access + if input_specs is None: return None - flat_inputs = nest.flatten(inputs) - flat_input_names = nest.flatten(input_names) - flat_input_specs = [] - for input_tensor, input_name in zip(flat_inputs, flat_input_names): - if keep_original_batch_size: - input_shape = input_tensor.shape.as_list() - else: - # If the user has not explicitly provided the input_signature, we - # create it from the inputs. We make sure to set the first dimension - # (batch) to None here, as in serving or retraining, batch should not - # be fixed. See b/132783590 for context. - input_shape = [None] + input_tensor.shape[1:].as_list() - flat_input_specs.append(tensor_spec.TensorSpec( - shape=input_shape, dtype=input_tensor.dtype, - name=input_name)) - input_specs = nest.pack_sequence_as(structure=inputs, - flat_sequence=flat_input_specs) - # Return a list with a single element as the model's input signature. if isinstance(input_specs, collections.Sequence) and len(input_specs) == 1: # Note that the isinstance check filters out single-element dictionaries, @@ -147,14 +127,15 @@ def trace_model_call(model, input_signature=None): with base_layer_utils.call_context().enter( model, inputs=inputs, build_graph=False, training=False, saving=True): - outputs_list = nest.flatten(model(inputs, training=False)) + outputs = model(inputs, training=False) - try: - output_names = model.output_names - except AttributeError: - from tensorflow.python.keras.engine import training_utils # pylint: disable=g-import-not-at-top - output_names = training_utils.generic_output_names(outputs_list) - return {name: output for name, output in zip(output_names, outputs_list)} + # Outputs always has to be a flat dict. + output_names = model.output_names # Functional Model. + if output_names is None: # Subclassed Model. + from tensorflow.python.keras.engine import compile_utils # pylint: disable=g-import-not-at-top + output_names = compile_utils.create_pseudo_output_names(outputs) + outputs = nest.flatten(outputs) + return {name: output for name, output in zip(output_names, outputs)} return _wrapped_model @@ -187,32 +168,22 @@ def model_metadata(model, include_optimizer=True, require_config=True): 'You will have to compile your model again after loading it. ' 'Prefer using a Keras optimizer instead ' '(see keras.io/optimizers).') - else: - try: - metadata['training_config'] = { - 'loss': model.loss, - # pylint: disable=protected-access - 'metrics': model._compile_metrics, - 'weighted_metrics': model._compile_weighted_metrics, - # pylint: enable=protected-access - 'sample_weight_mode': model.sample_weight_mode, - 'loss_weights': model.loss_weights, + elif model._compile_was_called: # pylint: disable=protected-access + training_config = model._get_compile_args() # pylint: disable=protected-access + training_config.pop('optimizer', None) # Handled separately. + metadata['training_config'] = _serialize_nested_config(training_config) + if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer): + raise NotImplementedError( + 'As of now, Optimizers loaded from SavedModel cannot be saved. ' + 'If you\'re calling `model.save` or `tf.keras.models.save_model`,' + ' please set the `include_optimizer` option to `False`. For ' + '`tf.saved_model.save`, delete the optimizer from the model.') + else: + optimizer_config = { + 'class_name': model.optimizer.__class__.__name__, + 'config': model.optimizer.get_config() } - if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer): - raise NotImplementedError( - 'As of now, Optimizers loaded from SavedModel cannot be saved. ' - 'If you\'re calling `model.save` or `tf.keras.models.save_model`,' - ' please set the `include_optimizer` option to `False`. For ' - '`tf.saved_model.save`, delete the optimizer from the model.') - else: - optimizer_config = { - 'class_name': model.optimizer.__class__.__name__, - 'config': model.optimizer.get_config()} - metadata['training_config']['optimizer_config'] = optimizer_config - except AttributeError: - pass # If the model has an optimizer, but not all of the attributes - # loss, _compile_metrics, etc., then it was not compiled using - # model.compile. In this case, do not save the training config. + metadata['training_config']['optimizer_config'] = optimizer_config return metadata @@ -224,73 +195,36 @@ def should_overwrite(filepath, overwrite): return True -def convert_output_metrics(metrics_config, custom_objects): - from tensorflow.python.keras import metrics as metrics_module # pylint:disable=g-import-not-at-top - if isinstance(metrics_config, list): - return [convert_output_metrics(mc, custom_objects) for mc in metrics_config] - elif (isinstance(metrics_config, dict) or - (metrics_config not in ['accuracy', 'acc', 'crossentropy', 'ce'])): - # Do not deserialize accuracy and cross-entropy strings as we have special - # case handling for these in compile, based on model output shape. - return metrics_module.deserialize(metrics_config, custom_objects) - return metrics_config - - def compile_args_from_training_config(training_config, custom_objects=None): """Return model.compile arguments from training config.""" if custom_objects is None: custom_objects = {} - optimizer_config = training_config['optimizer_config'] - optimizer = optimizers.deserialize( - optimizer_config, custom_objects=custom_objects) + with generic_utils.CustomObjectScope(custom_objects): + optimizer_config = training_config['optimizer_config'] + optimizer = optimizers.deserialize(optimizer_config) - # Recover losses. - loss_config = training_config['loss'] - if isinstance(loss_config, list): # Loss fed to compile as a list. - loss = [losses.deserialize(lc, custom_objects) for lc in loss_config] - elif isinstance(loss_config, dict) and 'class_name' not in loss_config: - # Loss fed to compile as a dict. - loss = { - k: losses.deserialize(v, custom_objects) - for (k, v) in loss_config.items() - } - else: # Loss fed to compile as a str/ function/ class instance. - loss = losses.deserialize(loss_config, custom_objects) + # Recover losses. + loss = None + loss_config = training_config.get('loss', None) + if loss_config is not None: + loss = _deserialize_nested_config(losses.deserialize, loss_config) - # Recover metrics. - metrics_config = training_config.get('metrics', None) - if isinstance(metrics_config, dict): # Metrics fed to compile as a dict. - metrics = { - k: convert_output_metrics(v, custom_objects) - for (k, v) in metrics_config.items() - } - elif isinstance(metrics_config, list): # Metrics fed to compile as a list. - metrics = [ - convert_output_metrics(m, custom_objects) for m in metrics_config - ] - else: # No metrics. + # Recover metrics. metrics = None + metrics_config = training_config.get('metrics', None) + if metrics_config is not None: + metrics = _deserialize_nested_config(_deserialize_metric, metrics_config) - # Recover weighted metrics. - weighted_metrics_config = training_config.get('weighted_metrics', None) - if isinstance(weighted_metrics_config, dict): - # Metrics fed to compile as a dict. - weighted_metrics = { - k: convert_output_metrics(v, custom_objects) - for (k, v) in weighted_metrics_config.items() - } - elif isinstance(weighted_metrics_config, list): - # Metrics fed to compile as a list. - weighted_metrics = [ - convert_output_metrics(m, custom_objects) - for m in weighted_metrics_config - ] - else: # No metrics. + # Recover weighted metrics. weighted_metrics = None + weighted_metrics_config = training_config.get('weighted_metrics', None) + if weighted_metrics_config is not None: + weighted_metrics = _deserialize_nested_config(_deserialize_metric, + weighted_metrics_config) - sample_weight_mode = training_config['sample_weight_mode'] - loss_weights = training_config['loss_weights'] + sample_weight_mode = training_config['sample_weight_mode'] + loss_weights = training_config['loss_weights'] return dict( optimizer=optimizer, @@ -299,3 +233,49 @@ def compile_args_from_training_config(training_config, custom_objects=None): weighted_metrics=weighted_metrics, loss_weights=loss_weights, sample_weight_mode=sample_weight_mode) + + +def _deserialize_nested_config(deserialize_fn, config): + """Deserializes arbitrary Keras `config` using `deserialize_fn`.""" + + def _is_single_object(obj): + if isinstance(obj, dict) and 'class_name' in obj: + return True # Serialized Keras object. + if isinstance(obj, six.string_types): + return True # Serialized function or string. + return False + + if config is None: + return None + if _is_single_object(config): + return deserialize_fn(config) + elif isinstance(config, dict): + return { + k: _deserialize_nested_config(deserialize_fn, v) + for k, v in config.items() + } + elif isinstance(config, (tuple, list)): + return [_deserialize_nested_config(deserialize_fn, obj) for obj in config] + + raise ValueError('Saved configuration not understood.') + + +def _serialize_nested_config(config): + """Serialized a nested structure of Keras objects.""" + + def _serialize_fn(obj): + if callable(obj): + return generic_utils.serialize_keras_object(obj) + return obj + + return nest.map_structure(_serialize_fn, config) + + +def _deserialize_metric(metric_config): + """Deserialize metrics, leaving special strings untouched.""" + from tensorflow.python.keras import metrics as metrics_module # pylint:disable=g-import-not-at-top + if metric_config in ['accuracy', 'acc', 'crossentropy', 'ce']: + # Do not deserialize accuracy and cross-entropy strings as we have special + # case handling for these in compile, based on model output shape. + return metric_config + return metrics_module.deserialize(metric_config) diff --git a/tensorflow/python/keras/saving/saving_utils_test.py b/tensorflow/python/keras/saving/saving_utils_test.py index 92bee3df50a..4687e8a617a 100644 --- a/tensorflow/python/keras/saving/saving_utils_test.py +++ b/tensorflow/python/keras/saving/saving_utils_test.py @@ -76,7 +76,10 @@ class TraceModelCallTest(keras_parameterized.TestCase): fn = saving_utils.trace_model_call(model) signature_outputs = fn(inputs) - expected_outputs = {model.output_names[0]: model(inputs)} + if model.output_names: + expected_outputs = {model.output_names[0]: model(inputs)} + else: + expected_outputs = {'output_1': model(inputs)} self._assert_all_close(expected_outputs, signature_outputs) @@ -90,14 +93,19 @@ class TraceModelCallTest(keras_parameterized.TestCase): loss='mse', run_eagerly=testing_utils.should_run_eagerly(), experimental_run_tf_function=testing_utils.should_run_tf_function()) - model.fit(x=np.random.random((8, 5)), - y=np.random.random((8, 3)), epochs=2) + model.fit( + x=np.random.random((8, 5)).astype(np.float32), + y=np.random.random((8, 3)).astype(np.float32), + epochs=2) inputs = array_ops.ones((8, 5)) fn = saving_utils.trace_model_call(model) signature_outputs = fn(inputs) - expected_outputs = {model.output_names[0]: model(inputs)} + if model.output_names: + expected_outputs = {model.output_names[0]: model(inputs)} + else: + expected_outputs = {'output_1': model(inputs)} self._assert_all_close(expected_outputs, signature_outputs) @@ -140,9 +148,13 @@ class TraceModelCallTest(keras_parameterized.TestCase): fn = saving_utils.trace_model_call(model) signature_outputs = fn([input_a_np, input_b_np]) outputs = model([input_a_np, input_b_np]) - expected_outputs = {model.output_names[0]: outputs[0], - model.output_names[1]: outputs[1]} - + if model.output_names: + expected_outputs = { + model.output_names[0]: outputs[0], + model.output_names[1]: outputs[1] + } + else: + expected_outputs = {'output_1': outputs[0], 'output_2': outputs[1]} self._assert_all_close(expected_outputs, signature_outputs) @test_util.run_in_graph_and_eager_modes @@ -177,7 +189,10 @@ class TraceModelCallTest(keras_parameterized.TestCase): fn = saving_utils.trace_model_call( model, [tensor_spec.TensorSpec(shape=[None, 5], dtype=dtypes.float32)]) signature_outputs = fn(inputs) - expected_outputs = {model.output_names[0]: model(inputs)} + if model.output_names: + expected_outputs = {model.output_names[0]: model(inputs)} + else: + expected_outputs = {'output_1': model(inputs)} self._assert_all_close(expected_outputs, signature_outputs) @test_util.run_in_graph_and_eager_modes @@ -242,7 +257,9 @@ def _import_and_infer(save_dir, inputs): model = loader.load(session, [tag_constants.SERVING], save_dir) signature = model.signature_def[ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] - assert set(inputs.keys()) == set(signature.inputs.keys()) + assert set(inputs.keys()) == set( + signature.inputs.keys()), ('expected {}, found {}'.format( + signature.inputs.keys(), inputs.keys())) feed_dict = {} for arg_name in inputs.keys(): feed_dict[graph.get_tensor_by_name(signature.inputs[arg_name].name)] = ( @@ -254,10 +271,10 @@ def _import_and_infer(save_dir, inputs): return session.run(output_dict, feed_dict=feed_dict) +@keras_parameterized.run_with_all_model_types +@keras_parameterized.run_all_keras_modes(always_skip_v1=True) class ModelSaveTest(keras_parameterized.TestCase): - @keras_parameterized.run_with_all_model_types - @test_util.run_v2_only def test_model_save(self): input_dim = 5 model = testing_utils.get_small_mlp(10, 3, input_dim) @@ -269,14 +286,21 @@ class ModelSaveTest(keras_parameterized.TestCase): save_dir = os.path.join(self.get_temp_dir(), 'saved_model') save_lib.save(model, save_dir) - self.assertAllClose( - {model.output_names[0]: model.predict_on_batch(inputs)}, - _import_and_infer(save_dir, {model.input_names[0]: np.ones((8, 5))})) + if model.output_names: + output_name = model.output_names[0] + input_name = model.input_names[0] + else: + output_name = 'output_1' + input_name = 'input_1' + + self.assertAllClose({output_name: model.predict_on_batch(inputs)}, + _import_and_infer(save_dir, + {input_name: np.ones((8, 5))})) +@test_util.run_deprecated_v1 # Not used in v2. class ExtractModelMetricsTest(keras_parameterized.TestCase): - @keras_parameterized.run_all_keras_modes def test_extract_model_metrics(self): a = keras.layers.Input(shape=(3,), name='input_a') b = keras.layers.Input(shape=(3,), name='input_b') @@ -308,9 +332,7 @@ class ExtractModelMetricsTest(keras_parameterized.TestCase): keras.metrics.BinaryAccuracy(), 'mae', keras.metrics.mean_squared_error ], - optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01), - run_eagerly=testing_utils.should_run_eagerly(), - experimental_run_tf_function=testing_utils.should_run_tf_function()) + optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) extract_metrics = saving_utils.extract_model_metrics(model) self.assertEqual(set(model_metric_names), set(model.metrics_names)) self.assertEqual(set(extract_metric_names), set(extract_metrics.keys())) diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py index a3867927e70..564c1d07fe2 100644 --- a/tensorflow/python/keras/testing_utils.py +++ b/tensorflow/python/keras/testing_utils.py @@ -632,6 +632,9 @@ class _MultiIOSubclassModel(keras.Model): inputs = layer(inputs) a = inputs b = inputs + elif isinstance(inputs, dict): + a = inputs['input_1'] + b = inputs['input_2'] else: a, b = inputs diff --git a/tensorflow/python/keras/tests/model_subclassing_compiled_test.py b/tensorflow/python/keras/tests/model_subclassing_compiled_test.py index 404c9f0c975..aa94f8400e0 100644 --- a/tensorflow/python/keras/tests/model_subclassing_compiled_test.py +++ b/tensorflow/python/keras/tests/model_subclassing_compiled_test.py @@ -134,8 +134,6 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase): self.assertEqual(len(model.weights), 10) self.assertEqual(len(model.trainable_weights), 8) self.assertEqual(len(model.non_trainable_weights), 2) - self.assertEqual(len(model.inputs), 2) - self.assertEqual(len(model.outputs), 2) def test_updates(self): # test that updates get run during training diff --git a/tensorflow/python/keras/tests/model_subclassing_test.py b/tensorflow/python/keras/tests/model_subclassing_test.py index 56cdbb17d27..d3b601e75ed 100644 --- a/tensorflow/python/keras/tests/model_subclassing_test.py +++ b/tensorflow/python/keras/tests/model_subclassing_test.py @@ -340,7 +340,7 @@ class ModelSubclassingTest(keras_parameterized.TestCase): # Single-io model = testing_utils.SmallSubclassMLP( num_hidden=32, num_classes=4, use_bn=True, use_dp=True) - model._set_inputs(np.ones((3, 4))) # need to build model first + model(np.ones((3, 4))) # need to build model first print_fn = ToString() model.summary(print_fn=print_fn) self.assertTrue('Trainable params: 356' in print_fn.contents) @@ -348,8 +348,7 @@ class ModelSubclassingTest(keras_parameterized.TestCase): # Multi-io model = model_util.get_multi_io_subclass_model( num_classes=(5, 6), use_bn=True, use_dp=True) - model._set_inputs([np.ones((3, 4)), - np.ones((3, 4))]) # need to build model first + model([np.ones((3, 4)), np.ones((3, 4))]) # need to build model first print_fn = ToString() model.summary(print_fn=print_fn) self.assertTrue('Trainable params: 587' in print_fn.contents) @@ -677,6 +676,8 @@ class CustomCallSignatureTests(test.TestCase): @test_util.assert_no_new_tensors @test_util.assert_no_garbage_created def test_training_no_default(self): + if not context.executing_eagerly(): + return model = model_util.TrainingNoDefaultModel() arg = array_ops.ones([1, 1]) model(arg, True) diff --git a/tensorflow/python/keras/tests/temporal_sample_weights_correctness_test.py b/tensorflow/python/keras/tests/temporal_sample_weights_correctness_test.py index 0d9f77cb000..8854783ea05 100644 --- a/tensorflow/python/keras/tests/temporal_sample_weights_correctness_test.py +++ b/tensorflow/python/keras/tests/temporal_sample_weights_correctness_test.py @@ -20,13 +20,13 @@ from __future__ import print_function import numpy as np -from tensorflow.python import tf2 from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import layers from tensorflow.python.keras import metrics from tensorflow.python.keras import optimizer_v2 from tensorflow.python.keras import testing_utils from tensorflow.python.platform import test +from tensorflow.python.util import nest class Bias(layers.Layer): @@ -102,7 +102,7 @@ def run_with_different_sample_weight_mode_inputs(fn, partial_sw=True): @keras_parameterized.run_with_all_model_types(exclude_models=['sequential']) -@keras_parameterized.run_all_keras_modes +@keras_parameterized.run_all_keras_modes(always_skip_v1=True) class TestMetricsCorrectnessMultiIOTemporal(keras_parameterized.TestCase): def custom_generator_multi_io_temporal(self, sample_weights=None): @@ -116,13 +116,6 @@ class TestMetricsCorrectnessMultiIOTemporal(keras_parameterized.TestCase): """ batch_size = 3 num_samples = 3 - if sample_weights: - assert len(sample_weights) == 2 - w1 = sample_weights[0] - w2 = sample_weights[1] - else: - w1 = None - w2 = None iteration = 0 while True: batch_index = iteration * batch_size % num_samples @@ -132,13 +125,10 @@ class TestMetricsCorrectnessMultiIOTemporal(keras_parameterized.TestCase): x = [self.x[start:end], self.x[start:end]] y = [self.y1[start:end], self.y2[start:end]] if sample_weights: - w = [ - None if w1 is None else w1[start:end], - None if w2 is None else w2[start:end] - ] + sw = nest.map_structure(lambda w: w[start:end], sample_weights) else: - w = None - yield x, y, w + sw = None + yield x, y, sw def setUp(self): super(TestMetricsCorrectnessMultiIOTemporal, self).setUp() @@ -147,11 +137,6 @@ class TestMetricsCorrectnessMultiIOTemporal(keras_parameterized.TestCase): self.y1 = np.asarray([[[.5], [1.]], [[2.], [2.5]], [[3.5], [2.5]]]) self.y2 = np.asarray([[[.5], [1.5]], [[2.], [1.5]], [[3.5], [3.]]]) - if tf2.enabled(): - self.wmae = 'mae_2' - else: - self.wmae = 'weighted_mae_2' - # Without weights: # Epoch 1 - bias = 0 # y_pred_1 = [[[0.], [0.]], [[1.], [1.]], [[2.], [2.]]] @@ -172,8 +157,8 @@ class TestMetricsCorrectnessMultiIOTemporal(keras_parameterized.TestCase): self.expected_fit_result = { 'output_1_mae': [1, 0.9], 'output_2_mae': [1, 0.9], - 'output_1_' + self.wmae: [1, 0.9], - 'output_2_' + self.wmae: [1, 0.9], + 'output_1_mae_2': [1, 0.9], + 'output_2_mae_2': [1, 0.9], 'loss': [2., 1.8], 'output_1_loss': [1, 0.9], 'output_2_loss': [1, 0.9], @@ -229,8 +214,8 @@ class TestMetricsCorrectnessMultiIOTemporal(keras_parameterized.TestCase): self.expected_fit_result_with_weights = { 'output_1_mae': [1, 0.875], 'output_2_mae': [1, 0.875], - 'output_1_' + self.wmae: [1, 0.875], - 'output_2_' + self.wmae: [1, 0.875], + 'output_1_mae_2': [1, 0.875], + 'output_2_mae_2': [1, 0.875], 'loss': [2.5, 2.1875], 'output_1_loss': [1.25, 1.09375], 'output_2_loss': [1.25, 1.09375], @@ -239,8 +224,8 @@ class TestMetricsCorrectnessMultiIOTemporal(keras_parameterized.TestCase): self.expected_fit_result_with_weights_output_2 = { 'output_1_mae': [1., 0.9], 'output_2_mae': [1, 0.875], - 'output_1_' + self.wmae: [1., 0.9], - 'output_2_' + self.wmae: [1., 0.875], + 'output_1_mae_2': [1., 0.9], + 'output_2_mae_2': [1., 0.875], 'loss': [2.25, 1.99375], 'output_1_loss': [1., 0.9], 'output_2_loss': [1.25, 1.09375], @@ -461,7 +446,7 @@ class TestMetricsCorrectnessMultiIOTemporal(keras_parameterized.TestCase): def _train_and_assert(model): history = model.fit_generator( self.custom_generator_multi_io_temporal( - sample_weights=[None, self.sample_weight_2]), + sample_weights={'output_2': self.sample_weight_2}), steps_per_epoch=1, epochs=2) for key, value in self.expected_fit_result_with_weights_output_2.items(): @@ -506,7 +491,7 @@ class TestMetricsCorrectnessMultiIOTemporal(keras_parameterized.TestCase): }) eval_result = model.evaluate_generator( self.custom_generator_multi_io_temporal( - sample_weights=[None, self.sample_weight_2]), + sample_weights={'output_2': self.sample_weight_2}), steps=2) self.assertAllClose(eval_result, self.expected_batch_result_with_weights_output_2, @@ -517,9 +502,7 @@ class TestMetricsCorrectnessMultiIOTemporal(keras_parameterized.TestCase): def test_error_on_fit_with_class_weight(self): def _train_and_assert(model): - with self.assertRaisesRegex( - ValueError, - r'`class_weight` not supported for 3\+ dimensional targets.'): + with self.assertRaises(ValueError): model.fit([self.x, self.x], [self.y1, self.y2], class_weight={'output_1': { .5: .5, diff --git a/tensorflow/python/keras/utils/composite_tensor_support_test.py b/tensorflow/python/keras/utils/composite_tensor_support_test.py index 87e70a239ce..13af9590e80 100644 --- a/tensorflow/python/keras/utils/composite_tensor_support_test.py +++ b/tensorflow/python/keras/utils/composite_tensor_support_test.py @@ -44,6 +44,7 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import test +from tensorflow.python.util import nest # Define test-only Layer classes to validate passing Sparse and Ragged tensors @@ -57,6 +58,10 @@ class ToDense(Layer): self._supports_ragged_inputs = True def call(self, inputs): + if isinstance(inputs, dict): # Dicts are no longer flattened. + # Always a single element in these tests. + inputs = nest.flatten(inputs)[0] + if isinstance(inputs, ragged_tensor.RaggedTensor): output = inputs.to_tensor(default_value=self._default_value) elif isinstance(inputs, sparse_tensor.SparseTensor): @@ -610,80 +615,6 @@ class RaggedTensorInputValidationTest(keras_parameterized.TestCase, result = model.predict(input_data, **kwargs) self.assertAllEqual(expected_output, result) - def test_ragged_tensor_input_with_wrong_ragged_rank_fails( - self, use_dict, use_dataset): - # Define some input data that will NOT match the input shape spec. - data = [(ragged_factory_ops.constant([[[1, 0]], [[2, 3]]]), None)] - - # Prepare the model to test. - input_shape = (None, 2) # RaggedTensorInputTest uses (None, None). - input_name = get_input_name(use_dict) - model_input = input_layer.Input( - shape=input_shape, ragged=True, name=input_name, dtype=dtypes.int32) - layers = [ToDense(default_value=-1)] - model = get_model_from_layers_with_input(layers, model_input=model_input) - model.compile( - optimizer="sgd", - loss="mse", - metrics=["accuracy"], - **get_test_mode_kwargs()) - - # Define some input data with the wrong ragged rank - for data_element in data: - input_data, _ = prepare_inputs( - data_element, - use_dict, - use_dataset, - action="predict", - input_name=input_name) - with self.assertRaisesRegex(ValueError, ".*don't have the same nested.*"): - _ = model.predict(input_data) - - -# CompositeTensor shape validation only happens in non-eager modes and in non- -# subclassed models, so we run a separate parameterized test for them. -@keras_parameterized.run_with_all_model_types(exclude_models=["subclass"]) -@keras_parameterized.run_all_keras_modes(always_skip_eager=True) -class SparseTensorInputValidationTest(keras_parameterized.TestCase): - - def test_sparse_scipy_input_checks_shape(self): - model_input = input_layer.Input(shape=(3,), sparse=True, dtype=dtypes.int32) - layers = [ToDense(default_value=-1)] - model = get_model_from_layers_with_input(layers, model_input=model_input) - - input_data = scipy.sparse.coo_matrix(([1, 2, 3], ([0, 1, 1], [0, 0, 1])), - shape=[2, 4]) - with self.assertRaisesRegex(ValueError, ".*got array with shape.*"): - _ = model.predict(input_data) - - def test_sparse_tensor_input_checks_shapes(self): - # Create a model that accepts a sparse input and converts the sparse tensor - # back to a dense tensor. - model_input = input_layer.Input( - shape=(2, None), sparse=True, dtype=dtypes.int32) - layers = [ToDense(default_value=-1)] - model = get_model_from_layers_with_input(layers, model_input=model_input) - - # Define some input data. - input_data = sparse_tensor.SparseTensor([[0, 0, 0], [1, 0, 0], [1, 0, 1]], - [1, 2, 3], [2, 1, 3]) - kwargs = get_kwargs(use_dataset=False) - with self.assertRaisesRegex(ValueError, ".*got array with shape.*"): - _ = model.predict(input_data, **kwargs) - - def test_ragged_tensor_input_with_wrong_value_shape(self): - # Create a model that accepts a ragged input and converts it to dense. - model_input = input_layer.Input( - shape=(None, 4), ragged=True, dtype=dtypes.int32) - layers = [ToDense(default_value=-1)] - model = get_model_from_layers_with_input(layers, model_input=model_input) - - # Define some input data with the wrong ragged rank - input_data = ragged_factory_ops.constant([[[1, 0]], [[2, 3]]], - ragged_rank=1) - with self.assertRaisesRegex(ValueError, ".*got array with shape.*"): - _ = model.predict(input_data) - @keras_parameterized.run_with_all_model_types() @keras_parameterized.run_all_keras_modes(always_skip_v1=True) @@ -707,7 +638,7 @@ class CompositeTensorModelPredictTest(keras_parameterized.TestCase): sparse_input = sparse_tensor.SparseTensor( # A two-row matrix indices=[(0, 0), (0, 1), (0, 2), (5, 0), (5, 1), (5, 2)], - values=[1, 1, 1, 1, 1, 1], + values=[1., 1., 1., 1., 1., 1.], dense_shape=(6, 3)) shape = model(sparse_input).shape @@ -736,37 +667,5 @@ class CompositeTensorModelPredictTest(keras_parameterized.TestCase): self.assertEqual((2, None, 5), self._normalize_shape(shape)) -@keras_parameterized.run_with_all_model_types( - exclude_models=["functional"]) -@keras_parameterized.run_all_keras_modes -class UndefinedCompositeTensorInputsTest(keras_parameterized.TestCase): - - def test_subclass_implicit_sparse_inputs_fails(self): - # Create a model that accepts a sparse input and converts the sparse tensor - # back to a dense tensor. - layers = [ToDense(default_value=-1)] - model = testing_utils.get_model_from_layers(layers) - - # Define some input data. - input_data = sparse_tensor.SparseTensor([[0, 0], [1, 0], [1, 1]], [1, 2, 3], - [2, 3]) - kwargs = get_kwargs(False) - with self.assertRaisesRegex( - ValueError, ".*All SparseTensor and RaggedTensor inputs .*"): - _ = model.predict(input_data, **kwargs) - - def test_subclass_implicit_sparse_scipy_inputs_fails(self): - # Create a model that accepts a sparse input and converts the sparse tensor - # back to a dense tensor. - layers = [ToDense(default_value=-1)] - model = testing_utils.get_model_from_layers(layers) - - # Define some input data. - input_data = scipy.sparse.coo_matrix(([1, 2, 3], ([0, 1, 1], [0, 0, 1])), - shape=[2, 3]) - with self.assertRaisesRegex(ValueError, ".*either a single array.*"): - _ = model.predict(input_data) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py index 801f5ad99bc..edbfed6d776 100644 --- a/tensorflow/python/keras/utils/generic_utils.py +++ b/tensorflow/python/keras/utils/generic_utils.py @@ -539,7 +539,7 @@ class Progbar(object): self._start = time.time() self._last_update = 0 - def update(self, current, values=None): + def update(self, current, values=None, finalize=None): """Updates the progress bar. Arguments: @@ -547,7 +547,15 @@ class Progbar(object): values: List of tuples: `(name, value_for_last_step)`. If `name` is in `stateful_metrics`, `value_for_last_step` will be displayed as-is. Else, an average of the metric over time will be displayed. + finalize: Whether this is the last update for the progress bar. If + `None`, defaults to `current >= self.target`. """ + if finalize is None: + if self.target is None: + finalize = False + else: + finalize = current >= self.target + values = values or [] for k, v in values: if k not in self._values_order: @@ -573,8 +581,7 @@ class Progbar(object): now = time.time() info = ' - %.0fs' % (now - self._start) if self.verbose == 1: - if (now - self._last_update < self.interval and - self.target is not None and current < self.target): + if now - self._last_update < self.interval and not finalize: return prev_total_width = self._total_width @@ -607,7 +614,15 @@ class Progbar(object): time_per_unit = (now - self._start) / current else: time_per_unit = 0 - if self.target is not None and current < self.target: + + if self.target is None or finalize: + if time_per_unit >= 1 or time_per_unit == 0: + info += ' %.0fs/%s' % (time_per_unit, self.unit_name) + elif time_per_unit >= 1e-3: + info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name) + else: + info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name) + else: eta = time_per_unit * (self.target - current) if eta > 3600: eta_format = '%d:%02d:%02d' % (eta // 3600, @@ -618,13 +633,6 @@ class Progbar(object): eta_format = '%ds' % eta info = ' - ETA: %s' % eta_format - else: - if time_per_unit >= 1 or time_per_unit == 0: - info += ' %.0fs/%s' % (time_per_unit, self.unit_name) - elif time_per_unit >= 1e-3: - info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name) - else: - info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name) for k in self._values_order: info += ' - %s:' % k @@ -641,14 +649,14 @@ class Progbar(object): if prev_total_width > self._total_width: info += (' ' * (prev_total_width - self._total_width)) - if self.target is not None and current >= self.target: + if finalize: info += '\n' sys.stdout.write(info) sys.stdout.flush() elif self.verbose == 2: - if self.target is not None and current >= self.target: + if finalize: numdigits = int(np.log10(self.target)) + 1 count = ('%' + str(numdigits) + 'd/%d') % (current, self.target) info = count + info diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py index dcb42abf687..1dfd2f517c6 100644 --- a/tensorflow/python/keras/utils/layer_utils.py +++ b/tensorflow/python/keras/utils/layer_utils.py @@ -258,7 +258,6 @@ def print_summary(model, line_length=None, positions=None, print_fn=None): else: print_fn('_' * line_length) - model._check_trainable_weights_consistency() if hasattr(model, '_collected_trainable_weights'): trainable_count = count_params(model._collected_trainable_weights) else: diff --git a/tensorflow/python/keras/utils/tf_utils.py b/tensorflow/python/keras/utils/tf_utils.py index 1a85b838be6..57b5c605db9 100644 --- a/tensorflow/python/keras/utils/tf_utils.py +++ b/tensorflow/python/keras/utils/tf_utils.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import copy import six from tensorflow.python.data.experimental.ops import cardinality @@ -464,3 +465,27 @@ def dataset_is_infinite(dataset): else: dataset_size = K.get_session().run(cardinality.cardinality(dataset)) return dataset_size == cardinality.INFINITE + + +def get_tensor_spec(t, dynamic_batch=False, name=None): + """Returns a `TensorSpec` given a single `Tensor` or `TensorSpec`.""" + if isinstance(t, type_spec.TypeSpec): + spec = t + elif isinstance(t, composite_tensor.CompositeTensor): + # TODO(b/148821952): Should these specs have a name attr? + spec = t._type_spec # pylint: disable=protected-access + elif hasattr(t, 'shape') and hasattr(t, 'dtype'): + spec = tensor_spec.TensorSpec(shape=t.shape, dtype=t.dtype, name=name) + else: + return None # Allow non-Tensors to pass through. + + if not dynamic_batch: + return spec + + dynamic_batch_spec = copy.deepcopy(spec) + # RaggedTensorSpec only has a private _shape. + shape = dynamic_batch_spec._shape.as_list() # pylint: disable=protected-access + if shape: + shape[0] = None + dynamic_batch_spec._shape = tensor_shape.TensorShape(shape) # pylint: disable=protected-access + return dynamic_batch_spec diff --git a/tensorflow/python/keras/utils/tf_utils_test.py b/tensorflow/python/keras/utils/tf_utils_test.py index 392ab7d59a5..2f87af2ef06 100644 --- a/tensorflow/python/keras/utils/tf_utils_test.py +++ b/tensorflow/python/keras/utils/tf_utils_test.py @@ -79,6 +79,8 @@ class TestIsSymbolicTensor(test.TestCase): self.assertTrue(tf_utils.is_symbolic_tensor(CustomClass())) def test_enables_nontensor_plumbing(self): + if context.executing_eagerly(): + self.skipTest('`compile` functionality changed.') # Setup. class Foo(object): diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index 28741d82bbc..33abd5c664e 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -552,7 +552,7 @@ class Layer(base_layer.Layer): return outputs def __deepcopy__(self, memo): - no_copy = set(['_graph', '_thread_local']) + no_copy = set(['_graph', '_thread_local', '_metrics_lock']) shallow_copy = set(['_scope', '_always_reuse_variable_scope']) cls = self.__class__ result = cls.__new__(cls) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt index a823b172ace..440e6c8a5c4 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt @@ -97,10 +97,6 @@ tf_class { name: "run_eagerly" mtype: "" } - member { - name: "sample_weights" - mtype: "" - } member { name: "state_updates" mtype: "" @@ -195,7 +191,7 @@ tf_class { } member_method { name: "evaluate" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'return_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'False\'], " } member_method { name: "evaluate_generator" @@ -203,7 +199,7 @@ tf_class { } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_batch_size\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], " } member_method { name: "fit_generator" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt index 77b0239181b..eee65bc6db4 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt @@ -98,10 +98,6 @@ tf_class { name: "run_eagerly" mtype: "" } - member { - name: "sample_weights" - mtype: "" - } member { name: "state_updates" mtype: "" @@ -200,7 +196,7 @@ tf_class { } member_method { name: "evaluate" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'return_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'False\'], " } member_method { name: "evaluate_generator" @@ -208,7 +204,7 @@ tf_class { } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_batch_size\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], " } member_method { name: "fit_generator" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt index 4a6a96e3952..c64a1881f88 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt @@ -98,10 +98,6 @@ tf_class { name: "run_eagerly" mtype: "" } - member { - name: "sample_weights" - mtype: "" - } member { name: "state_updates" mtype: "" @@ -196,7 +192,7 @@ tf_class { } member_method { name: "evaluate" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'return_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'False\'], " } member_method { name: "evaluate_generator" @@ -204,7 +200,7 @@ tf_class { } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_batch_size\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], " } member_method { name: "fit_generator" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt index 4c44837ef5f..238701103f7 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt @@ -98,10 +98,6 @@ tf_class { name: "run_eagerly" mtype: "" } - member { - name: "sample_weights" - mtype: "" - } member { name: "state_updates" mtype: "" @@ -196,7 +192,7 @@ tf_class { } member_method { name: "evaluate" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'return_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'False\'], " } member_method { name: "evaluate_generator" @@ -204,7 +200,7 @@ tf_class { } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_batch_size\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], " } member_method { name: "fit_generator" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt index c63d5ff76b3..788efce0063 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt @@ -97,10 +97,6 @@ tf_class { name: "run_eagerly" mtype: "" } - member { - name: "sample_weights" - mtype: "" - } member { name: "state_updates" mtype: "" @@ -195,7 +191,7 @@ tf_class { } member_method { name: "evaluate" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'return_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'False\'], " } member_method { name: "evaluate_generator" @@ -203,7 +199,7 @@ tf_class { } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_batch_size\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], " } member_method { name: "fit_generator" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt index 6ca4124190d..6166b16f964 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt @@ -98,10 +98,6 @@ tf_class { name: "run_eagerly" mtype: "" } - member { - name: "sample_weights" - mtype: "" - } member { name: "state_updates" mtype: "" @@ -200,7 +196,7 @@ tf_class { } member_method { name: "evaluate" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'return_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'False\'], " } member_method { name: "evaluate_generator" @@ -208,7 +204,7 @@ tf_class { } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_batch_size\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], " } member_method { name: "fit_generator" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-progbar.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-progbar.pbtxt index 8177cc71ed3..d7882583515 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-progbar.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-progbar.pbtxt @@ -12,6 +12,6 @@ tf_class { } member_method { name: "update" - argspec: "args=[\'self\', \'current\', \'values\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'current\', \'values\', \'finalize\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt index a823b172ace..440e6c8a5c4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt @@ -97,10 +97,6 @@ tf_class { name: "run_eagerly" mtype: "" } - member { - name: "sample_weights" - mtype: "" - } member { name: "state_updates" mtype: "" @@ -195,7 +191,7 @@ tf_class { } member_method { name: "evaluate" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'return_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'False\'], " } member_method { name: "evaluate_generator" @@ -203,7 +199,7 @@ tf_class { } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_batch_size\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], " } member_method { name: "fit_generator" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt index 77b0239181b..eee65bc6db4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt @@ -98,10 +98,6 @@ tf_class { name: "run_eagerly" mtype: "" } - member { - name: "sample_weights" - mtype: "" - } member { name: "state_updates" mtype: "" @@ -200,7 +196,7 @@ tf_class { } member_method { name: "evaluate" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'return_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'False\'], " } member_method { name: "evaluate_generator" @@ -208,7 +204,7 @@ tf_class { } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_batch_size\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], " } member_method { name: "fit_generator" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt index 4a6a96e3952..c64a1881f88 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt @@ -98,10 +98,6 @@ tf_class { name: "run_eagerly" mtype: "" } - member { - name: "sample_weights" - mtype: "" - } member { name: "state_updates" mtype: "" @@ -196,7 +192,7 @@ tf_class { } member_method { name: "evaluate" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'return_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'False\'], " } member_method { name: "evaluate_generator" @@ -204,7 +200,7 @@ tf_class { } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_batch_size\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], " } member_method { name: "fit_generator" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt index 4c44837ef5f..238701103f7 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt @@ -98,10 +98,6 @@ tf_class { name: "run_eagerly" mtype: "" } - member { - name: "sample_weights" - mtype: "" - } member { name: "state_updates" mtype: "" @@ -196,7 +192,7 @@ tf_class { } member_method { name: "evaluate" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'return_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'False\'], " } member_method { name: "evaluate_generator" @@ -204,7 +200,7 @@ tf_class { } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_batch_size\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], " } member_method { name: "fit_generator" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt index c63d5ff76b3..788efce0063 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt @@ -97,10 +97,6 @@ tf_class { name: "run_eagerly" mtype: "" } - member { - name: "sample_weights" - mtype: "" - } member { name: "state_updates" mtype: "" @@ -195,7 +191,7 @@ tf_class { } member_method { name: "evaluate" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'return_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'False\'], " } member_method { name: "evaluate_generator" @@ -203,7 +199,7 @@ tf_class { } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_batch_size\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], " } member_method { name: "fit_generator" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt index 6ca4124190d..6166b16f964 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt @@ -98,10 +98,6 @@ tf_class { name: "run_eagerly" mtype: "" } - member { - name: "sample_weights" - mtype: "" - } member { name: "state_updates" mtype: "" @@ -200,7 +196,7 @@ tf_class { } member_method { name: "evaluate" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'return_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'False\'], " } member_method { name: "evaluate_generator" @@ -208,7 +204,7 @@ tf_class { } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_batch_size\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], " } member_method { name: "fit_generator" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-progbar.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-progbar.pbtxt index 8177cc71ed3..d7882583515 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-progbar.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-progbar.pbtxt @@ -12,6 +12,6 @@ tf_class { } member_method { name: "update" - argspec: "args=[\'self\', \'current\', \'values\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'current\', \'values\', \'finalize\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } }