diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index a15a1c03143..05809accba3 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -1612,6 +1612,23 @@ tf_py_test(
     ],
 )
 
+tf_py_test(
+    name = "training_v2_utils_test",
+    size = "medium",
+    srcs = ["engine/training_v2_utils_test.py"],
+    additional_deps = [
+        ":keras",
+        "@absl_py//absl/testing:parameterized",
+        "//third_party/py/numpy",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python/distribute:strategy_combinations",
+    ],
+    tags = [
+        "no_oss",  # TODO(b/135021748) reenable
+        "notsan",
+    ],
+)
+
 py_library(
     name = "model_subclassing_test_util",
     srcs = ["model_subclassing_test_util.py"],
diff --git a/tensorflow/python/keras/engine/training_v2.py b/tensorflow/python/keras/engine/training_v2.py
index c20e119752a..3025d186668 100644
--- a/tensorflow/python/keras/engine/training_v2.py
+++ b/tensorflow/python/keras/engine/training_v2.py
@@ -164,7 +164,8 @@ def run_one_epoch(model,
         batch_logs['size'] = data_batch_size
         current_batch_size = data_batch_size
       else:
-        batch_outs = _aggregate_predict_results(strategy, batch_outs, model)
+        batch_outs = training_v2_utils._aggregate_predict_results(
+            strategy, batch_outs, model)
 
       if step == 0:
         aggregator.create(batch_outs)
@@ -435,6 +436,8 @@ class Loop(training_utils.TrainingLoop):
       # tf.print('{} on {} steps.'.format(ModeKeys.TRAIN, steps_per_epoch))
       training_context = TrainingContext()
 
+      if mode == ModeKeys.PREDICT:
+        dataset = training_v2_utils._add_batch_index_to_element(dataset)
       dataset = strategy.experimental_distribute_dataset(dataset)
 
       execution_function = training_v2_utils._get_or_make_execution_function(
@@ -708,18 +711,6 @@ def _get_total_number_of_samples(adapter):
   return total_sample
 
 
-def _aggregate_predict_results(strategy, batch_outs, model):
-  if not isinstance(batch_outs, list):
-    batch_outs = [batch_outs]
-  total_batch_outs = []
-  for i in range(len(model.outputs)):
-    num_replicas = strategy.num_replicas_in_sync
-    nested_outs = batch_outs[i * num_replicas:i * num_replicas + num_replicas]
-    total_batch_outs.append(
-        dist_utils.concat_along_batch_dimension(nest.flatten(nested_outs)))
-  return total_batch_outs
-
-
 def _print_train_info(total_samples, steps, val_total_samples, val_steps):
   increment = 'samples' if total_samples else 'steps'
   conjunction = 'on' if total_samples else 'for'
diff --git a/tensorflow/python/keras/engine/training_v2_utils.py b/tensorflow/python/keras/engine/training_v2_utils.py
index 4cb4283954f..665a4a26391 100644
--- a/tensorflow/python/keras/engine/training_v2_utils.py
+++ b/tensorflow/python/keras/engine/training_v2_utils.py
@@ -26,9 +26,13 @@ from __future__ import print_function
 import collections
 import functools
 
+import numpy as np
+
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.framework.ops import composite_tensor
 from tensorflow.python.keras import backend
@@ -38,6 +42,7 @@ from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.utils.mode_keys import ModeKeys
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.util import nest
 
 
@@ -62,14 +67,14 @@ def _make_execution_function(model, mode):
 
   def distributed_function(input_iterator):
     """A single step of the distributed execution across replicas."""
-    x, y, sample_weights = _prepare_feed_values(model, input_iterator, mode)
+    args = _prepare_feed_values(model, input_iterator, mode)
     # Call `Model.{train,test,predict}_on_batch` on every replica passing
    # PerReplicas as arguments. On every replica inside this call, each
     # PerReplica object will return the value for that replica. The outputs
     # are PerReplicas too.
     strategy = distribution_strategy_context.get_strategy()
     outputs = strategy.experimental_run_v2(
-        per_replica_function, args=(x, y, sample_weights))
+        per_replica_function, args=args)
     # Out of PerReplica outputs reduce or pick values to return.
     all_outputs = dist_utils.unwrap_output_dict(
         strategy, outputs, mode)
@@ -108,7 +113,11 @@ def _prepare_feed_values(model, inputs, mode):
     (inputs, targets, sample_weights) may be a python list. Single values
     for inputs will always be wrapped in lists.
   """
-  inputs, targets, sample_weights = _get_input_from_iterator(inputs)
+  # For predict, we need to extract the manually added batch_index first.
+  with_batch_index = mode == ModeKeys.PREDICT
+
+  inputs, targets, sample_weights, batch_index = _get_input_from_iterator(
+      inputs, with_batch_index)
 
   # When the inputs are dict, then we want to flatten it in the same order as
   # the input layers, such that the data are fed into the input layers in the
@@ -123,12 +132,18 @@
     targets = []
 
   ins = [inputs, targets, sample_weights]
+  if batch_index is not None:
+    ins.append(batch_index)
   return tuple(ins)
 
 
-def _get_input_from_iterator(iterator):
+def _get_input_from_iterator(iterator, with_batch_index=False):
   """Get elements from the iterator and verify the input shape and type."""
   next_element = next(iterator)
+  if with_batch_index:
+    batch_index, next_element = next_element
+  else:
+    batch_index = None
 
   if (tensor_util.is_tensor(next_element) or
       isinstance(next_element, (dict, composite_tensor.CompositeTensor))):
@@ -146,7 +161,7 @@
   # Validate that all the elements in x and y are of the same type and shape.
   dist_utils.validate_distributed_dataset_inputs(
       distribution_strategy_context.get_strategy(), x, y, sample_weights)
-  return x, y, sample_weights
+  return x, y, sample_weights, batch_index
 
 
 def _make_replica_execution_function(model, mode):
@@ -156,9 +171,11 @@
   elif mode == ModeKeys.TEST:
     func = functools.partial(test_on_batch, model)
   else:
-    def _predict_on_batch(x, y=None, sample_weights=None):
+    def _predict_on_batch(x, y=None, sample_weights=None, batch_index=None):
       del y, sample_weights
-      return predict_on_batch(model, x)
+      # Note that x and batch_index are already per-replica values.
+      result = predict_on_batch(model, x)
+      return (batch_index, result)
 
     func = _predict_on_batch
 
@@ -170,6 +187,105 @@
   return func
 
 
+def _aggregate_predict_results(strategy, batch_outs, model):
+  """Aggregate the prediction results from each replica."""
+  num_replicas = strategy.num_replicas_in_sync
+  num_outputs = len(model.outputs)
+
+  if not isinstance(batch_outs, list):
+    batch_outs = [batch_outs]
+
+  # batch_outs is in the following structure:
+  # [
+  #  replica_1_batch_index, replica_2_batch_index, ...., replica_x_batch_index,
+  #  replica_1_output_1, replica_2_output_1, ...., replica_x_output_1,
+  #  ......
+  #  replica_1_output_y, replica_2_output_y, ...., replica_x_output_y,
+  # ]
+  batch_index, batch_outs = (batch_outs[:num_replicas],
+                             batch_outs[num_replicas:])
+  batch_index = dist_utils.concat_along_batch_dimension(batch_index)
+  # Reorder the batch_index so it can be used for a proper gather. E.g. if
+  # the original index is [0, 2, 4, 6, 1, 3, 5, 7], then the index for gather
+  # should be [0, 4, 1, 5, 2, 6, 3, 7].
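+  # np.argsort yields exactly this index: for each position of the restored
+  # dataset order, it returns the position in the concatenated per-replica
+  # results where that sample currently sits.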
+  batch_index = np.argsort(batch_index)
+  # Only need to gather if the batch index is not sorted.
+  need_batch_index_gather = np.any(np.diff(batch_index) < 0)
+
+  total_batch_outs = []
+  for i in range(num_outputs):
+    nested_outs = batch_outs[i * num_replicas:i * num_replicas + num_replicas]
+    per_output_result = dist_utils.concat_along_batch_dimension(
+        nest.flatten(nested_outs))
+
+    if need_batch_index_gather:
+      # Only gather when the concatenated output has the same batch size as
+      # the batch_index; a size mismatch is left to the error handling in the
+      # upper layers.
+      if _get_batch_size(per_output_result).numpy() == len(batch_index):
+        per_output_result = _gather_result_by_index(per_output_result,
+                                                    batch_index)
+    total_batch_outs.append(per_output_result)
+  return total_batch_outs
+
+
+def _gather_result_by_index(input_tensor, batch_index):
+  """Handle the data element gather for different types of tensors."""
+  if isinstance(input_tensor, sparse_tensor.SparseTensor):
+    # For sparse tensors, both the indices and values components should be
+    # gathered.
+    return sparse_tensor.SparseTensor(
+        indices=array_ops.gather_v2(input_tensor.indices, batch_index),
+        values=array_ops.gather_v2(input_tensor.values, batch_index),
+        dense_shape=input_tensor.dense_shape
+    )
+  # For ragged tensors, eager tensors, and np arrays, tf.gather does the
+  # correct thing.
+  elif isinstance(input_tensor, ragged_tensor.RaggedTensor):
+    return array_ops.gather_v2(input_tensor, batch_index)
+  elif isinstance(input_tensor, (ops.EagerTensor, np.ndarray)):
+    return array_ops.gather_v2(input_tensor, batch_index).numpy()
+  else:
+    raise ValueError('Unexpected type {} encountered when gathering '
                     'batch slices.'.format(input_tensor))
+
+
+def _get_batch_size(inputs):
+  """Return the batch size of the first element of the (nested) inputs."""
+  first_inputs = nest.flatten(inputs)[0]
+  if isinstance(first_inputs, ragged_tensor.RaggedTensor):
+    return first_inputs.bounding_shape()[0]
+  else:
+    return array_ops.shape(first_inputs)[0]
+
+
+def _add_batch_index_to_element(dataset):
+  """Add a batch index field to every element in the batch.
+
+  This is needed in model.predict() when running with a multi-worker
+  distribution strategy. When sharding/distributing a dataset, the continuity
+  of the sharded dataset can't be easily ensured without sacrificing
+  performance. It is fine to train and evaluate with reordered data, but not
+  to predict. To solve this issue, Keras adds a batch index to every element
+  in the dataset, which is then passed to the per-replica execution function.
+  The real execution function removes the index before feeding the input to
+  the model, and the per-replica function zips the index with the result.
+  Finally, Keras sorts the batch results based on the added batch-index
+  field, removes it, and returns the sorted results.
+
+  Note that we don't attach a single index to each per-replica batch, but one
+  index to each element in the batch, since we can't ensure that the data each
+  replica receives is contiguous. E.g. for a model with 2 replicas predicting
+  a batch of 4 elements [1, 2, 3, 4], it is possible to shard them as
+  [1, 2], [3, 4], or as [1, 3], [2, 4].
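+
+  For example, an element of structure (x, y, sample_weights) with batch size
+  4 is mapped to ([0, 1, 2, 3], (x, y, sample_weights)); the index restarts
+  from 0 for every batch, which is enough to undo any reordering within that
+  batch when the results are aggregated.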
+
+  Args:
+    dataset: a dataset that is created by any of the data_adapters, with the
+      element structure as (x, y, sample_weights).
+
+  Returns:
+    a new dataset, with the element structure as
+    (batch_index, (x, y, sample_weights)).
+  """
+  return dataset.map(lambda *inp: (math_ops.range(_get_batch_size(inp)), inp))
+
+
 def train_on_batch(
     model,
     x,
diff --git a/tensorflow/python/keras/engine/training_v2_utils_test.py b/tensorflow/python/keras/engine/training_v2_utils_test.py
new file mode 100644
index 00000000000..84f90fe9a82
--- /dev/null
+++ b/tensorflow/python/keras/engine/training_v2_utils_test.py
@@ -0,0 +1,143 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.python.keras.engine.training_v2_utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from absl.testing import parameterized
+import numpy as np
+
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import mirrored_strategy
+from tensorflow.python.distribute import strategy_combinations
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import combinations
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils
+from tensorflow.python.keras.engine import training_v2_utils
+from tensorflow.python.keras.utils.mode_keys import ModeKeys
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops.ragged import ragged_factory_ops
+from tensorflow.python.platform import test
+
+
+class AggregatePredictResultsTest(test.TestCase, parameterized.TestCase):
+
+  def setUp(self):
+    super(AggregatePredictResultsTest, self).setUp()
+    strategy_combinations.set_virtual_cpus_to_at_least(3)
+    self.num_replica = 3
+    self.batch_size = 16
+    self.dense_shape = (2, 3)
+    self.total_sample = 2 * self.batch_size
+
+    mock_model = collections.namedtuple('Model', ['outputs'])
+    self.mock_model = mock_model([1])
+
+    strategy = mirrored_strategy.MirroredStrategy(
+        ['/cpu:0', '/cpu:1', '/cpu:2'])
+
+    execution_function = lambda *inp: inp
+    @def_function.function
+    def predict_loop(batch):
+      batch_result = strategy.experimental_run_v2(execution_function, batch)
+      batch_result = dist_utils.unwrap_output_dict(
+          strategy, batch_result, ModeKeys.PREDICT)
+      # Swap the order of replicas 1 and 2 to mimic a random shard order.
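+      # At this point batch_result is a flat list of per-replica values:
+      # [index_r0, index_r1, index_r2, output_r0, output_r1, output_r2].
+      # Swapping entries 1/2 and 4/5 shuffles the replica order while
+      # keeping each batch index aligned with its output.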
+      batch_result[2], batch_result[1] = batch_result[1], batch_result[2]
+      batch_result[5], batch_result[4] = batch_result[4], batch_result[5]
+      return batch_result
+
+    self.strategy = strategy
+    self.predict_loop = predict_loop
+
+  @combinations.generate(combinations.combine(tf_api_version=[1, 2],
+                                              mode='eager'))
+  def test_aggregate_predict_results_dense(self):
+    dataset = dataset_ops.Dataset.range(self.total_sample)
+    def dense_map_fn(i):
+      # Mimic what we do when adding the batch index.
+      return i, array_ops.fill(self.dense_shape, i)
+    dense_dataset = dataset.map(dense_map_fn).batch(self.batch_size)
+    distributed_data = self.strategy.experimental_distribute_dataset(
+        dense_dataset)
+
+    start = 0
+    for batch in distributed_data:
+      batch_result = self.predict_loop(batch)
+      final_result = training_v2_utils._aggregate_predict_results(
+          self.strategy, batch_result, self.mock_model)
+
+      # Make sure the dense result is in sorted order.
+      expected_result = np.arange(
+          start=start, stop=start + self.batch_size).reshape((-1, 1))
+      expected_result = np.tile(expected_result, 6).reshape(
+          (-1,) + self.dense_shape)
+      self.assertAllClose(final_result[0], expected_result)
+      start += self.batch_size
+
+  @combinations.generate(combinations.combine(tf_api_version=[1, 2],
+                                              mode='eager'))
+  def test_aggregate_predict_results_sparse(self):
+    dataset = dataset_ops.Dataset.range(self.total_sample)
+    def sparse_map_fn(i):
+      return i, sparse_tensor.SparseTensor(
+          indices=[(0, 0)],
+          values=[i],
+          dense_shape=self.dense_shape)
+    sparse_dataset = dataset.map(sparse_map_fn).batch(self.batch_size)
+    distributed_data = self.strategy.experimental_distribute_dataset(
+        sparse_dataset)
+
+    start = 0
+    for batch in distributed_data:
+      batch_result = self.predict_loop(batch)
+      final_result = training_v2_utils._aggregate_predict_results(
+          self.strategy, batch_result, self.mock_model)
+
+      # Make sure the sparse result is in sorted order.
+      expected_values = np.arange(start=start, stop=start + self.batch_size)
+      self.assertAllClose(final_result[0].values, expected_values)
+      start += self.batch_size
+
+  @combinations.generate(combinations.combine(tf_api_version=[1, 2],
+                                              mode='eager'))
+  def test_aggregate_predict_results_ragged(self):
+    dataset = dataset_ops.Dataset.range(self.total_sample)
+    def ragged_map_fn(i):
+      return i, ragged_factory_ops.constant([[0], [], []], dtype=np.int64) + i
+    ragged_dataset = dataset.map(ragged_map_fn).batch(self.batch_size)
+    distributed_data = self.strategy.experimental_distribute_dataset(
+        ragged_dataset)
+
+    start = 0
+    for batch in distributed_data:
+      batch_result = self.predict_loop(batch)
+      final_result = training_v2_utils._aggregate_predict_results(
+          self.strategy, batch_result, self.mock_model)
+
+      # Make sure the ragged result is in sorted order.
+      expected_values = np.arange(start=start, stop=start + self.batch_size)
+      self.assertAllClose(final_result[0].flat_values, expected_values)
+      start += self.batch_size
+
+
+if __name__ == '__main__':
+  test.main()
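
For reference, the reordering performed by _aggregate_predict_results can be
reproduced with plain numpy; the replica count, shard order, and values below
are illustrative only, not taken from the patch:

    import numpy as np

    # Batch indices and outputs as two replicas might return them; the shard
    # order [0, 2] / [1, 3] stands in for an arbitrary reordering.
    batch_index = np.concatenate([[0, 2], [1, 3]])
    outputs = np.concatenate([[10.0, 30.0], [20.0, 40.0]])

    # argsort maps each output position back to its dataset position.
    gather = np.argsort(batch_index)  # -> [0, 2, 1, 3]
    print(outputs[gather])            # -> [10. 20. 30. 40.]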