Sort the Keras predict results based on the batch index in the v2 code path.

PiperOrigin-RevId: 279158153
Change-Id: I76638fa6613cb0de1f8f6494c960debbee3055ce
Scott Zhu 2019-11-07 13:43:01 -08:00 committed by TensorFlower Gardener
parent 37cdb806bc
commit 8e0f8c4d8c
4 changed files with 287 additions and 20 deletions
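At a high level, the change tags every element of the predict dataset with its index within the batch, has each replica return (index, output) pairs, and then uses an argsort of the concatenated indices to restore the original order. A minimal NumPy sketch of just the reordering step (the helper name is illustrative, not part of the change):

import numpy as np

def restore_batch_order(batch_index, outputs):
  # batch_index holds the per-replica indices concatenated along the batch
  # dimension, e.g. [0, 2, 4, 6, 1, 3, 5, 7] when two replicas each took
  # every other element of an 8-element batch.
  gather_index = np.argsort(batch_index)  # -> [0, 4, 1, 5, 2, 6, 3, 7]
  # Skip the gather entirely when the indices are already in order.
  if not np.any(np.diff(gather_index) < 0):
    return outputs
  return outputs[gather_index]

outputs = np.array([10, 12, 14, 16, 11, 13, 15, 17])
print(restore_batch_order(np.array([0, 2, 4, 6, 1, 3, 5, 7]), outputs))
# -> [10 11 12 13 14 15 16 17]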

tensorflow/python/keras/BUILD

@@ -1612,6 +1612,23 @@ tf_py_test(
],
)
tf_py_test(
name = "training_v2_utils_test",
size = "medium",
srcs = ["engine/training_v2_utils_test.py"],
additional_deps = [
":keras",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python/distribute:strategy_combinations",
],
tags = [
"no_oss", # TODO(b/135021748) reenable
"notsan",
],
)
py_library(
name = "model_subclassing_test_util",
srcs = ["model_subclassing_test_util.py"],

tensorflow/python/keras/engine/training_v2.py

@@ -164,7 +164,8 @@ def run_one_epoch(model,
batch_logs['size'] = data_batch_size
current_batch_size = data_batch_size
else:
batch_outs = _aggregate_predict_results(strategy, batch_outs, model)
batch_outs = training_v2_utils._aggregate_predict_results(
strategy, batch_outs, model)
if step == 0:
aggregator.create(batch_outs)
@@ -435,6 +436,8 @@ class Loop(training_utils.TrainingLoop):
# tf.print('{} on {} steps.'.format(ModeKeys.TRAIN, steps_per_epoch))
training_context = TrainingContext()
if mode == ModeKeys.PREDICT:
dataset = training_v2_utils._add_batch_index_to_element(dataset)
dataset = strategy.experimental_distribute_dataset(dataset)
execution_function = training_v2_utils._get_or_make_execution_function(
@@ -708,18 +711,6 @@ def _get_total_number_of_samples(adapter):
return total_sample
def _aggregate_predict_results(strategy, batch_outs, model):
if not isinstance(batch_outs, list):
batch_outs = [batch_outs]
total_batch_outs = []
for i in range(len(model.outputs)):
num_replicas = strategy.num_replicas_in_sync
nested_outs = batch_outs[i * num_replicas:i * num_replicas + num_replicas]
total_batch_outs.append(
dist_utils.concat_along_batch_dimension(nest.flatten(nested_outs)))
return total_batch_outs
def _print_train_info(total_samples, steps, val_total_samples, val_steps):
increment = 'samples' if total_samples else 'steps'
conjunction = 'on' if total_samples else 'for'
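Taken together, the predict path in this file now indexes the dataset before distribution and unscrambles each batch after the distributed step. A condensed, hypothetical sketch of that wiring (the real control flow lives in Loop and run_one_epoch above):

def predict_loop(model, dataset, strategy, execution_function):
  # Tag every element with its index within the batch before sharding.
  dataset = training_v2_utils._add_batch_index_to_element(dataset)
  dist_dataset = strategy.experimental_distribute_dataset(dataset)
  results = []
  for batch in dist_dataset:
    # Each replica returns (batch_index, prediction); the outputs arrive as
    # a flat list of per-replica values.
    batch_outs = execution_function(batch)
    # Concatenate across replicas and sort by the recovered batch index.
    batch_outs = training_v2_utils._aggregate_predict_results(
        strategy, batch_outs, model)
    results.append(batch_outs)
  return results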

tensorflow/python/keras/engine/training_v2_utils.py

@@ -26,9 +26,13 @@ from __future__ import print_function
import collections
import functools
import numpy as np
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework.ops import composite_tensor
from tensorflow.python.keras import backend
@@ -38,6 +42,7 @@ from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import nest
@@ -62,14 +67,14 @@ def _make_execution_function(model, mode):
def distributed_function(input_iterator):
"""A single step of the distributed execution across replicas."""
x, y, sample_weights = _prepare_feed_values(model, input_iterator, mode)
args = _prepare_feed_values(model, input_iterator, mode)
# Call `Model.{train,test,predict}_on_batch` on every replica passing
# PerReplicas as arguments. On every replica inside this call, each
# PerReplica object will return the value for that replica. The outputs
# are PerReplicas too.
strategy = distribution_strategy_context.get_strategy()
outputs = strategy.experimental_run_v2(
per_replica_function, args=(x, y, sample_weights))
per_replica_function, args=args)
# Out of PerReplica outputs reduce or pick values to return.
all_outputs = dist_utils.unwrap_output_dict(
strategy, outputs, mode)
@@ -108,7 +113,11 @@ def _prepare_feed_values(model, inputs, mode):
(tuple, targets, sample_weights) may be a python list. Single values
for inputs will always be wrapped in lists.
"""
inputs, targets, sample_weights = _get_input_from_iterator(inputs)
# For predict, we need to extract the manually added batch_index first.
with_batch_index = mode == ModeKeys.PREDICT
inputs, targets, sample_weights, batch_index = _get_input_from_iterator(
inputs, with_batch_index)
# When the inputs are dict, then we want to flatten it in the same order as
# the input layers, such that the data are fed into the input layers in the
@@ -123,12 +132,18 @@ def _prepare_feed_values(model, inputs, mode):
targets = []
ins = [inputs, targets, sample_weights]
if batch_index is not None:
ins.append(batch_index)
return tuple(ins)
def _get_input_from_iterator(iterator):
def _get_input_from_iterator(iterator, with_batch_index=False):
"""Get elements from the iterator and verify the input shape and type."""
next_element = next(iterator)
if with_batch_index:
batch_index, next_element = next_element
else:
batch_index = None
if (tensor_util.is_tensor(next_element) or
isinstance(next_element, (dict, composite_tensor.CompositeTensor))):
@@ -146,7 +161,7 @@ def _get_input_from_iterator(iterator):
# Validate that all the elements in x and y are of the same type and shape.
dist_utils.validate_distributed_dataset_inputs(
distribution_strategy_context.get_strategy(), x, y, sample_weights)
return x, y, sample_weights
return x, y, sample_weights, batch_index
def _make_replica_execution_function(model, mode):
@@ -156,9 +171,11 @@ def _make_replica_execution_function(model, mode):
elif mode == ModeKeys.TEST:
func = functools.partial(test_on_batch, model)
else:
def _predict_on_batch(x, y=None, sample_weights=None):
def _predict_on_batch(x, y=None, sample_weights=None, batch_index=None):
del y, sample_weights
return predict_on_batch(model, x)
      # Note that x and batch_index are already per-replica values.
result = predict_on_batch(model, x)
return (batch_index, result)
func = _predict_on_batch
@@ -170,6 +187,105 @@ def _make_replica_execution_function(model, mode):
return func
def _aggregate_predict_results(strategy, batch_outs, model):
"""Aggregate the prediction result from each replica."""
num_replicas = strategy.num_replicas_in_sync
num_outputs = len(model.outputs)
if not isinstance(batch_outs, list):
batch_outs = [batch_outs]
  # batch_outs has the following structure:
# [
# replica_1_batch_index, replica_2_batch_index, ...., replica_x_batch_index,
# replica_1_output_1, replica_2_output_1, ...., replica_x_output_1,
# ......
# replica_1_output_y, replica_2_output_y, ...., replica_x_output_y,
# ]
batch_index, batch_outs = batch_outs[:num_replicas], batch_outs[num_replicas:]
batch_index = dist_utils.concat_along_batch_dimension(batch_index)
  # Reorder the batch_index so it can be used to gather the outputs back into
  # their original order. E.g. if the original index is
  # [0, 2, 4, 6, 1, 3, 5, 7], then the index for the gather (its argsort) is
  # [0, 4, 1, 5, 2, 6, 3, 7].
  batch_index = np.argsort(batch_index)
  # Only gather if the batch index was not already in sorted order.
  need_batch_index_gather = np.any(np.diff(batch_index) < 0)
total_batch_outs = []
for i in range(num_outputs):
nested_outs = batch_outs[i * num_replicas:i * num_replicas + num_replicas]
per_output_result = dist_utils.concat_along_batch_dimension(
nest.flatten(nested_outs))
    if need_batch_index_gather:
      # Only gather when the output batch size matches the length of
      # batch_index; otherwise skip the gather and let the upper layers
      # handle the size mismatch.
      if _get_batch_size(per_output_result).numpy() == len(batch_index):
        per_output_result = _gather_result_by_index(per_output_result,
                                                    batch_index)
total_batch_outs.append(per_output_result)
return total_batch_outs
def _gather_result_by_index(input_tensor, batch_index):
"""Handle the data element gather for different type of tensor."""
if isinstance(input_tensor, sparse_tensor.SparseTensor):
# For sparse tensor, both the index and value component should be gathered.
return sparse_tensor.SparseTensor(
indices=array_ops.gather_v2(input_tensor.indices, batch_index),
values=array_ops.gather_v2(input_tensor.values, batch_index),
dense_shape=input_tensor.dense_shape
)
  # For ragged tensors, eager tensors, and np arrays, tf.gather does the
  # right thing.
elif isinstance(input_tensor, ragged_tensor.RaggedTensor):
return array_ops.gather_v2(input_tensor, batch_index)
elif isinstance(input_tensor, (ops.EagerTensor, np.ndarray)):
return array_ops.gather_v2(input_tensor, batch_index).numpy()
else:
raise ValueError('Unexpected type {} encountered when gathering '
'batch slices.'.format(input_tensor))
def _get_batch_size(inputs):
first_inputs = nest.flatten(inputs)[0]
if isinstance(first_inputs, ragged_tensor.RaggedTensor):
return first_inputs.bounding_shape()[0]
else:
return array_ops.shape(first_inputs)[0]
def _add_batch_index_to_element(dataset):
"""Adding a new batch index field to the every element in the batch.
This is need in the model.predict() when running with multi-worker
distribution strategy. When sharding/distributing a dataset, the continuity of
the sharded dataset can't be easily ensured without performance sacrifice. It
is fine to train and eval with the reordered data, but not for prediction. To
solve this issue, Keras will add a batch index to each of the element in the
dataset, which will then pass to pre-replica execution function. The real
execution function will remove it before feeding the input to the model, and
pre-replica function will then zip the index with the result. Finally Keras
will sort the batch result based on the added batch-index field, remove it and
return the sorted result.
Note that we didn't add single index to the per-replica batch, but to each of
the element in the batch, since we can't ensure the data in pre-replica is
continuous. Eg: model with 2 replica and predict with 4 elements per batch
like [1, 2, 3, 4], it is possible to shard as [1, 2], [3, 4],
or [1, 3], [2, 4].
Args:
dataset: a dataset that is created by any of the data_adapter, with the
element structure as (x, y, sample_weights).
Returns:
a new dataset, with the element shape as
(batch_index, (x, y, sample_weights)).
"""
return dataset.map(lambda *inp: (math_ops.range(_get_batch_size(inp)), inp))
def train_on_batch(
model,
x,

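The dataset transformation itself is small. A standalone approximation using the public tf.data API (dense inputs only; the real _get_batch_size above also handles sparse and ragged tensors):

import tensorflow as tf

def add_batch_index(dataset):
  # Pair every batched element with a [0, batch_size) index tensor, mirroring
  # _add_batch_index_to_element above.
  def _map_fn(*element):
    batch_size = tf.shape(tf.nest.flatten(element)[0])[0]
    return tf.range(batch_size), element
  return dataset.map(_map_fn)

ds = tf.data.Dataset.from_tensor_slices(tf.range(8)).batch(4)
for batch_index, (x,) in add_batch_index(ds):
  print(batch_index.numpy(), x.numpy())
# [0 1 2 3] [0 1 2 3]
# [0 1 2 3] [4 5 6 7]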
tensorflow/python/keras/engine/training_v2_utils_test.py

@@ -0,0 +1,143 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.keras.engine.training_v2_utils."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import def_function
from tensorflow.python.framework import combinations
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils
from tensorflow.python.keras.engine import training_v2_utils
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test
class AggregatePredictResultsTest(test.TestCase, parameterized.TestCase):
def setUp(self):
super(AggregatePredictResultsTest, self).setUp()
strategy_combinations.set_virtual_cpus_to_at_least(3)
self.num_replica = 3
self.batch_size = 16
self.dense_shape = (2, 3)
self.total_sample = 2 * self.batch_size
mock_model = collections.namedtuple('Model', ['outputs'])
self.mock_model = mock_model([1])
strategy = mirrored_strategy.MirroredStrategy(
['/cpu:0', '/cpu:1', '/cpu:2'])
execution_function = lambda *inp: inp
@def_function.function
def predict_loop(batch):
batch_result = strategy.experimental_run_v2(execution_function, batch)
batch_result = dist_utils.unwrap_output_dict(
strategy, batch_result, ModeKeys.PREDICT)
      # Swap the order of replicas 1 and 2 to mimic an out-of-order result.
batch_result[2], batch_result[1] = batch_result[1], batch_result[2]
batch_result[5], batch_result[4] = batch_result[4], batch_result[5]
return batch_result
self.strategy = strategy
self.predict_loop = predict_loop
@combinations.generate(combinations.combine(tf_api_version=[1, 2],
mode='eager'))
def test_aggregate_predict_results_dense(self):
dataset = dataset_ops.Dataset.range(self.total_sample)
def dense_map_fn(i):
      # Mimic what we do when adding the batch index.
return i, array_ops.fill(self.dense_shape, i)
dense_dataset = dataset.map(dense_map_fn).batch(self.batch_size)
distributed_data = self.strategy.experimental_distribute_dataset(
dense_dataset)
start = 0
for batch in distributed_data:
batch_result = self.predict_loop(batch)
final_result = training_v2_utils._aggregate_predict_results(
self.strategy, batch_result, self.mock_model)
      # Make sure the dense result is in sorted order.
expected_result = np.arange(
start=start, stop=start+self.batch_size).reshape((-1, 1))
expected_result = np.tile(expected_result, 6).reshape(
(-1,) + self.dense_shape)
self.assertAllClose(final_result[0], expected_result)
start += self.batch_size
@combinations.generate(combinations.combine(tf_api_version=[1, 2],
mode='eager'))
def test_aggregate_predict_results_sparse(self):
dataset = dataset_ops.Dataset.range(self.total_sample)
def sparse_map_fn(i):
return i, sparse_tensor.SparseTensor(
indices=[(0, 0)],
values=[i],
dense_shape=self.dense_shape)
sparse_dataset = dataset.map(sparse_map_fn).batch(self.batch_size)
distributed_data = self.strategy.experimental_distribute_dataset(
sparse_dataset)
start = 0
for batch in distributed_data:
batch_result = self.predict_loop(batch)
final_result = training_v2_utils._aggregate_predict_results(
self.strategy, batch_result, self.mock_model)
      # Make sure the sparse result values are in sorted order.
expected_values = np.arange(start=start, stop=start+self.batch_size)
self.assertAllClose(final_result[0].values, expected_values)
start += self.batch_size
@combinations.generate(combinations.combine(tf_api_version=[1, 2],
mode='eager'))
def test_aggregate_predict_results_ragged(self):
dataset = dataset_ops.Dataset.range(self.total_sample)
def ragged_map_fn(i):
return i, ragged_factory_ops.constant([[0], [], []], dtype=np.int64) + i
ragged_dataset = dataset.map(ragged_map_fn).batch(self.batch_size)
distributed_data = self.strategy.experimental_distribute_dataset(
ragged_dataset)
start = 0
for batch in distributed_data:
batch_result = self.predict_loop(batch)
final_result = training_v2_utils._aggregate_predict_results(
self.strategy, batch_result, self.mock_model)
      # Make sure the ragged result values are in sorted order.
expected_values = np.arange(start=start, stop=start+self.batch_size)
self.assertAllClose(final_result[0].flat_values, expected_values)
start += self.batch_size
if __name__ == '__main__':
test.main()