Sort the Keras predict result based on the batch index in the v2 code path.
PiperOrigin-RevId: 279158153
Change-Id: I76638fa6613cb0de1f8f6494c960debbee3055ce
parent 37cdb806bc
commit 8e0f8c4d8c
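
Before the diff itself, a minimal standalone sketch (plain NumPy, made-up values) of the reordering idea this change introduces: each element of a predict batch is tagged with its index before the dataset is distributed, and after the per-replica results are concatenated, np.argsort on the collected indices restores the original order.

    import numpy as np

    # Two replicas saw interleaved rows, so the concatenated batch index is
    # out of order relative to the original input.
    batch_index = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    outputs = np.array([10, 12, 14, 16, 11, 13, 15, 17])
    # argsort yields the gather index used by _aggregate_predict_results below.
    gather_index = np.argsort(batch_index)          # [0, 4, 1, 5, 2, 6, 3, 7]
    print(outputs[gather_index])                    # [10 11 12 13 14 15 16 17]
    # The gather is skipped entirely when the index is already sorted:
    need_gather = np.any(np.diff(batch_index) < 0)  # True here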
tensorflow/python/keras/BUILD
@@ -1612,6 +1612,23 @@ tf_py_test(
     ],
 )
 
+tf_py_test(
+    name = "training_v2_utils_test",
+    size = "medium",
+    srcs = ["engine/training_v2_utils_test.py"],
+    additional_deps = [
+        ":keras",
+        "@absl_py//absl/testing:parameterized",
+        "//third_party/py/numpy",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python/distribute:strategy_combinations",
+    ],
+    tags = [
+        "no_oss",  # TODO(b/135021748) reenable
+        "notsan",
+    ],
+)
+
 py_library(
     name = "model_subclassing_test_util",
     srcs = ["model_subclassing_test_util.py"],
tensorflow/python/keras/engine/training_v2.py
@@ -164,7 +164,8 @@ def run_one_epoch(model,
       batch_logs['size'] = data_batch_size
       current_batch_size = data_batch_size
     else:
-      batch_outs = _aggregate_predict_results(strategy, batch_outs, model)
+      batch_outs = training_v2_utils._aggregate_predict_results(
+          strategy, batch_outs, model)
 
     if step == 0:
       aggregator.create(batch_outs)
@@ -435,6 +436,8 @@ class Loop(training_utils.TrainingLoop):
 
       # tf.print('{} on {} steps.'.format(ModeKeys.TRAIN, steps_per_epoch))
       training_context = TrainingContext()
+      if mode == ModeKeys.PREDICT:
+        dataset = training_v2_utils._add_batch_index_to_element(dataset)
       dataset = strategy.experimental_distribute_dataset(dataset)
 
       execution_function = training_v2_utils._get_or_make_execution_function(
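
For context, a rough public-API equivalent of what `_add_batch_index_to_element` (defined later in this diff) does to the predict dataset, with a hypothetical batch size:

    import tensorflow as tf

    # Tag each element of every batch with its position in that batch, so the
    # dataset element becomes (batch_index, original_element).
    dataset = tf.data.Dataset.range(8).batch(4)
    dataset = dataset.map(lambda *inp: (tf.range(tf.shape(inp[0])[0]), inp))
    for batch_index, element in dataset:
      print(batch_index.numpy(), element[0].numpy())
    # -> [0 1 2 3] [0 1 2 3]
    # -> [0 1 2 3] [4 5 6 7]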
@@ -708,18 +711,6 @@ def _get_total_number_of_samples(adapter):
   return total_sample
 
 
-def _aggregate_predict_results(strategy, batch_outs, model):
-  if not isinstance(batch_outs, list):
-    batch_outs = [batch_outs]
-  total_batch_outs = []
-  for i in range(len(model.outputs)):
-    num_replicas = strategy.num_replicas_in_sync
-    nested_outs = batch_outs[i * num_replicas:i * num_replicas + num_replicas]
-    total_batch_outs.append(
-        dist_utils.concat_along_batch_dimension(nest.flatten(nested_outs)))
-  return total_batch_outs
-
-
 def _print_train_info(total_samples, steps, val_total_samples, val_steps):
   increment = 'samples' if total_samples else 'steps'
   conjunction = 'on' if total_samples else 'for'
tensorflow/python/keras/engine/training_v2_utils.py
@@ -26,9 +26,13 @@ from __future__ import print_function
 import collections
 import functools
 
+import numpy as np
+
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.framework.ops import composite_tensor
 from tensorflow.python.keras import backend
@@ -38,6 +42,7 @@ from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.utils.mode_keys import ModeKeys
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.util import nest
 
 
@@ -62,14 +67,14 @@ def _make_execution_function(model, mode):
 
   def distributed_function(input_iterator):
     """A single step of the distributed execution across replicas."""
-    x, y, sample_weights = _prepare_feed_values(model, input_iterator, mode)
+    args = _prepare_feed_values(model, input_iterator, mode)
     # Call `Model.{train,test,predict}_on_batch` on every replica passing
     # PerReplicas as arguments. On every replica inside this call, each
     # PerReplica object will return the value for that replica. The outputs
     # are PerReplicas too.
     strategy = distribution_strategy_context.get_strategy()
     outputs = strategy.experimental_run_v2(
-        per_replica_function, args=(x, y, sample_weights))
+        per_replica_function, args=args)
     # Out of PerReplica outputs reduce or pick values to return.
     all_outputs = dist_utils.unwrap_output_dict(
         strategy, outputs, mode)
@@ -108,7 +113,11 @@ def _prepare_feed_values(model, inputs, mode):
     (tuple, targets, sample_weights) may be a python list. Single values
     for inputs will always be wrapped in lists.
   """
-  inputs, targets, sample_weights = _get_input_from_iterator(inputs)
+  # For predict, we need to extract the manually added batch_index first.
+  with_batch_index = mode == ModeKeys.PREDICT
+
+  inputs, targets, sample_weights, batch_index = _get_input_from_iterator(
+      inputs, with_batch_index)
 
   # When the inputs are dict, then we want to flatten it in the same order as
   # the input layers, such that the data are fed into the input layers in the
@@ -123,12 +132,18 @@ def _prepare_feed_values(model, inputs, mode):
     targets = []
 
   ins = [inputs, targets, sample_weights]
+  if batch_index is not None:
+    ins.append(batch_index)
   return tuple(ins)
 
 
-def _get_input_from_iterator(iterator):
+def _get_input_from_iterator(iterator, with_batch_index=False):
   """Get elements from the iterator and verify the input shape and type."""
   next_element = next(iterator)
+  if with_batch_index:
+    batch_index, next_element = next_element
+  else:
+    batch_index = None
 
   if (tensor_util.is_tensor(next_element) or
       isinstance(next_element, (dict, composite_tensor.CompositeTensor))):
@@ -146,7 +161,7 @@ def _get_input_from_iterator(iterator):
   # Validate that all the elements in x and y are of the same type and shape.
   dist_utils.validate_distributed_dataset_inputs(
       distribution_strategy_context.get_strategy(), x, y, sample_weights)
-  return x, y, sample_weights
+  return x, y, sample_weights, batch_index
 
 
 def _make_replica_execution_function(model, mode):
@@ -156,9 +171,11 @@ def _make_replica_execution_function(model, mode):
   elif mode == ModeKeys.TEST:
     func = functools.partial(test_on_batch, model)
   else:
-    def _predict_on_batch(x, y=None, sample_weights=None):
+    def _predict_on_batch(x, y=None, sample_weights=None, batch_index=None):
       del y, sample_weights
-      return predict_on_batch(model, x)
+      # Note that x and batch_index are already per-replica values.
+      result = predict_on_batch(model, x)
+      return (batch_index, result)
 
     func = _predict_on_batch
 
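
The per-replica (batch_index, result) tuples returned above are flattened by unwrap_output_dict into a single list, whose layout is documented in the next hunk. A toy NumPy sketch (2 hypothetical replicas, one model output) of the aggregation over that layout:

    import numpy as np

    # The per-replica batch indices come first, then the per-replica slices
    # of each output.
    num_replicas = 2
    batch_outs = [
        np.array([0, 2]), np.array([1, 3]),              # replica 0/1 batch indices
        np.array([[10], [12]]), np.array([[11], [13]]),  # replica 0/1 of output 1
    ]
    batch_index = np.concatenate(batch_outs[:num_replicas])  # [0 2 1 3]
    outputs = np.concatenate(batch_outs[num_replicas:])      # rows: 10, 12, 11, 13
    print(outputs[np.argsort(batch_index)].ravel())          # [10 11 12 13]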
@@ -170,6 +187,105 @@ def _make_replica_execution_function(model, mode):
   return func
 
 
+def _aggregate_predict_results(strategy, batch_outs, model):
+  """Aggregate the prediction results from each replica."""
+  num_replicas = strategy.num_replicas_in_sync
+  num_outputs = len(model.outputs)
+
+  if not isinstance(batch_outs, list):
+    batch_outs = [batch_outs]
+
+  # batch_outs has the following structure:
+  # [
+  #  replica_1_batch_index, replica_2_batch_index, ..., replica_x_batch_index,
+  #  replica_1_output_1, replica_2_output_1, ..., replica_x_output_1,
+  #  ......
+  #  replica_1_output_y, replica_2_output_y, ..., replica_x_output_y,
+  # ]
+  batch_index, batch_outs = batch_outs[:num_replicas], batch_outs[num_replicas:]
+  batch_index = dist_utils.concat_along_batch_dimension(batch_index)
+  # Reorder the batch_index so it can be used for a proper gather. E.g. if the
+  # original index is [0, 2, 4, 6, 1, 3, 5, 7], then the index for gather
+  # should be [0, 4, 1, 5, 2, 6, 3, 7].
+  batch_index = np.argsort(batch_index)
+  # Only need to gather if the batch index is not sorted.
+  need_batch_index_gather = np.any(np.diff(batch_index) < 0)
+
+  total_batch_outs = []
+  for i in range(num_outputs):
+    nested_outs = batch_outs[i * num_replicas:i * num_replicas + num_replicas]
+    per_output_result = dist_utils.concat_along_batch_dimension(
+        nest.flatten(nested_outs))
+
+    if need_batch_index_gather:
+      if _get_batch_size(per_output_result).numpy() == len(batch_index):
+        # Only gather when the output batch size matches the batch_index; a
+        # mismatch is error-handled in an upper layer.
+        per_output_result = _gather_result_by_index(per_output_result,
+                                                    batch_index)
+    total_batch_outs.append(per_output_result)
+  return total_batch_outs
+
+
+def _gather_result_by_index(input_tensor, batch_index):
+  """Handle the data element gather for different types of tensor."""
+  if isinstance(input_tensor, sparse_tensor.SparseTensor):
+    # For a sparse tensor, both the indices and values components should be
+    # gathered.
+    return sparse_tensor.SparseTensor(
+        indices=array_ops.gather_v2(input_tensor.indices, batch_index),
+        values=array_ops.gather_v2(input_tensor.values, batch_index),
+        dense_shape=input_tensor.dense_shape
+    )
+  # For ragged tensors, eager tensors and np arrays, tf.gather does the
+  # correct thing.
+  elif isinstance(input_tensor, ragged_tensor.RaggedTensor):
+    return array_ops.gather_v2(input_tensor, batch_index)
+  elif isinstance(input_tensor, (ops.EagerTensor, np.ndarray)):
+    return array_ops.gather_v2(input_tensor, batch_index).numpy()
+  else:
+    raise ValueError('Unexpected type {} encountered when gathering '
+                     'batch slices.'.format(input_tensor))
+
+
+def _get_batch_size(inputs):
+  first_inputs = nest.flatten(inputs)[0]
+  if isinstance(first_inputs, ragged_tensor.RaggedTensor):
+    return first_inputs.bounding_shape()[0]
+  else:
+    return array_ops.shape(first_inputs)[0]
+
+
+def _add_batch_index_to_element(dataset):
+  """Add a batch index field to every element in the batch.
+
+  This is needed in model.predict() when running with a multi-worker
+  distribution strategy. When sharding/distributing a dataset, the continuity
+  of the sharded dataset can't be easily ensured without a performance
+  sacrifice. It is fine to train and eval with the reordered data, but not to
+  predict. To solve this issue, Keras adds a batch index to each element in
+  the dataset, which is then passed to the per-replica execution function. The
+  real execution function removes it before feeding the input to the model,
+  and the per-replica function then zips the index with the result. Finally,
+  Keras sorts the batch results based on the added batch-index field, removes
+  it, and returns the sorted result.
+
+  Note that we don't add a single index to the per-replica batch, but one to
+  each element in the batch, since we can't ensure the per-replica data is
+  continuous. E.g. for a model with 2 replicas predicting a batch of 4
+  elements [1, 2, 3, 4], it is possible to shard as [1, 2], [3, 4],
+  or [1, 3], [2, 4].
+
+  Args:
+    dataset: a dataset that is created by any of the data_adapters, with the
+      element structure as (x, y, sample_weights).
+
+  Returns:
+    a new dataset, with the element structure as
+    (batch_index, (x, y, sample_weights)).
+  """
+  return dataset.map(lambda *inp: (math_ops.range(_get_batch_size(inp)), inp))
+
+
 def train_on_batch(
     model,
     x,
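
One detail worth noting in `_gather_result_by_index` above: tf.gather reorders the outermost (batch) dimension for dense and ragged tensors alike, which is why those branches can share a single gather call. A small illustration with made-up values:

    import tensorflow as tf

    rt = tf.ragged.constant([[2], [], [0, 0], [1]])
    batch_index = tf.constant([2, 3, 0, 1])
    # Rows move as whole (possibly empty) ragged rows.
    print(tf.gather(rt, batch_index))  # [[0, 0], [1], [2], []]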
tensorflow/python/keras/engine/training_v2_utils_test.py (new file, 143 lines)
@@ -0,0 +1,143 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.python.keras.engine.training_v2_utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import mirrored_strategy
+from tensorflow.python.distribute import strategy_combinations
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import combinations
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils
+from tensorflow.python.keras.engine import training_v2_utils
+from tensorflow.python.keras.utils.mode_keys import ModeKeys
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops.ragged import ragged_factory_ops
+from tensorflow.python.platform import test
+
+
+class AggregatePredictResultsTest(test.TestCase, parameterized.TestCase):
+
+  def setUp(self):
+    super(AggregatePredictResultsTest, self).setUp()
+    strategy_combinations.set_virtual_cpus_to_at_least(3)
+    self.num_replica = 3
+    self.batch_size = 16
+    self.dense_shape = (2, 3)
+    self.total_sample = 2 * self.batch_size
+
+    mock_model = collections.namedtuple('Model', ['outputs'])
+    self.mock_model = mock_model([1])
+
+    strategy = mirrored_strategy.MirroredStrategy(
+        ['/cpu:0', '/cpu:1', '/cpu:2'])
+
+    execution_function = lambda *inp: inp
+
+    @def_function.function
+    def predict_loop(batch):
+      batch_result = strategy.experimental_run_v2(execution_function, batch)
+      batch_result = dist_utils.unwrap_output_dict(
+          strategy, batch_result, ModeKeys.PREDICT)
+      # Swap the order of replicas 1 and 2 to mimic a random replica order.
+      batch_result[2], batch_result[1] = batch_result[1], batch_result[2]
+      batch_result[5], batch_result[4] = batch_result[4], batch_result[5]
+      return batch_result
+
+    self.strategy = strategy
+    self.predict_loop = predict_loop
+
+  @combinations.generate(combinations.combine(tf_api_version=[1, 2],
+                                              mode='eager'))
+  def test_aggregate_predict_results_dense(self):
+    dataset = dataset_ops.Dataset.range(self.total_sample)
+    def dense_map_fn(i):
+      # Mimic what we do for adding the batch index.
+      return i, array_ops.fill(self.dense_shape, i)
+    dense_dataset = dataset.map(dense_map_fn).batch(self.batch_size)
+    distributed_data = self.strategy.experimental_distribute_dataset(
+        dense_dataset)
+
+    start = 0
+    for batch in distributed_data:
+      batch_result = self.predict_loop(batch)
+      final_result = training_v2_utils._aggregate_predict_results(
+          self.strategy, batch_result, self.mock_model)
+
+      # Make sure the dense result is in a sorted order.
+      expected_result = np.arange(
+          start=start, stop=start+self.batch_size).reshape((-1, 1))
+      expected_result = np.tile(expected_result, 6).reshape(
+          (-1,) + self.dense_shape)
+      self.assertAllClose(final_result[0], expected_result)
+      start += self.batch_size
+
+  @combinations.generate(combinations.combine(tf_api_version=[1, 2],
+                                              mode='eager'))
+  def test_aggregate_predict_results_sparse(self):
+    dataset = dataset_ops.Dataset.range(self.total_sample)
+    def sparse_map_fn(i):
+      return i, sparse_tensor.SparseTensor(
+          indices=[(0, 0)],
+          values=[i],
+          dense_shape=self.dense_shape)
+    sparse_dataset = dataset.map(sparse_map_fn).batch(self.batch_size)
+    distributed_data = self.strategy.experimental_distribute_dataset(
+        sparse_dataset)
+
+    start = 0
+    for batch in distributed_data:
+      batch_result = self.predict_loop(batch)
+      final_result = training_v2_utils._aggregate_predict_results(
+          self.strategy, batch_result, self.mock_model)
+
+      # Make sure the sparse result is in a sorted order.
+      expected_values = np.arange(start=start, stop=start+self.batch_size)
+      self.assertAllClose(final_result[0].values, expected_values)
+      start += self.batch_size
+
+  @combinations.generate(combinations.combine(tf_api_version=[1, 2],
+                                              mode='eager'))
+  def test_aggregate_predict_results_ragged(self):
+    dataset = dataset_ops.Dataset.range(self.total_sample)
+    def ragged_map_fn(i):
+      return i, ragged_factory_ops.constant([[0], [], []], dtype=np.int64) + i
+    ragged_dataset = dataset.map(ragged_map_fn).batch(self.batch_size)
+    distributed_data = self.strategy.experimental_distribute_dataset(
+        ragged_dataset)
+
+    start = 0
+    for batch in distributed_data:
+      batch_result = self.predict_loop(batch)
+      final_result = training_v2_utils._aggregate_predict_results(
+          self.strategy, batch_result, self.mock_model)
+
+      # Make sure the ragged result is in a sorted order.
+      expected_values = np.arange(start=start, stop=start+self.batch_size)
+      self.assertAllClose(final_result[0].flat_values, expected_values)
+      start += self.batch_size
+
+
+if __name__ == '__main__':
+  test.main()