Add support for final partial batch in predict() for TPUStrategy

PiperOrigin-RevId: 231341652
Author: A. Unique TensorFlower 2019-01-28 21:18:14 -08:00, committed by TensorFlower Gardener
parent 01cf6233d6
commit dd7de7bbac
8 changed files with 369 additions and 48 deletions
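
Summary: `predict()` under TPUStrategy no longer requires the number of samples
to divide evenly by the global batch size. The final short batch is padded up to
the full batch size before being fed to the TPU, and the rows corresponding to
the padding are stripped from the returned predictions. A minimal sketch of the
user-visible behavior this enables, assuming TF 1.13-era APIs (the resolver and
strategy setup are illustrative boilerplate, not part of this commit):

import numpy as np
import tensorflow as tf

resolver = tf.contrib.cluster_resolver.TPUClusterResolver()  # environment-specific
strategy = tf.contrib.distribute.TPUStrategy(resolver)

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(3,))])
  model.compile(tf.train.GradientDescentOptimizer(0.001), 'mse')

inputs = np.zeros((10, 3), dtype=np.float32)
# 10 samples with batch_size=4 leave a final partial batch of 2. Before this
# commit the call below raised a ValueError on TPUs; it now returns all 10
# predictions.
predictions = model.predict(inputs, batch_size=4, steps=3)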

View File

@@ -566,6 +566,7 @@ py_library(
 distribute_py_test(
     name = "keras_test",
     srcs = ["keras_test.py"],
+    full_precision = True,
     main = "keras_test.py",
     shard_count = 16,
     tags = [

View File

@@ -71,6 +71,18 @@ def simple_functional_model():
   return model
 
 
+def simple_multi_inputs_multi_outputs_model():
+  input_a = keras.layers.Input(shape=(16,), name='input_a')
+  input_b = keras.layers.Input(shape=(16,), name='input_b')
+
+  merged = keras.layers.concatenate([input_a, input_b], name='merge')
+  output_c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged)
+  output_d = keras.layers.Dense(2, activation='softmax', name='dense_3')(merged)
+  model = keras.models.Model(
+      inputs=[input_a, input_b], outputs=[output_c, output_d])
+  return model
+
+
 def multi_inputs_multi_outputs_model():
   input_a = keras.layers.Input(shape=(16,), name='input_a')
   input_b = keras.layers.Input(shape=(16,), name='input_b')
@@ -671,6 +683,61 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
     self.assertAllEqual([6, 7], outs[0].shape)
     self.assertAllEqual([6, 7], outs[1].shape)
 
+  @combinations.generate(tpu_strategy_combinations())
+  def test_predict_with_partial_batch(self, distribution):
+    with self.cached_session():
+      optimizer = gradient_descent.GradientDescentOptimizer(0.001)
+      loss = 'mse'
+
+      with distribution.scope():
+        model_with_ds_strategy = get_model()
+        model_with_ds_strategy.compile(optimizer, loss)
+
+      cpu_model = get_model()
+      cpu_model.compile(optimizer, loss)
+
+      inputs = np.zeros((10, 3), dtype=np.float32)
+
+      # As the sample size is 10, we batch by 4 so that the last batch is
+      # a partial batch. Also, `predict()` on numpy arrays without a
+      # distribution strategy uses the entire sample as one batch. As such,
+      # we omit the `batch_size` and `steps` arguments for the CPU model.
+      cpu_model.set_weights(model_with_ds_strategy.get_weights())
+      self.assertAllClose(
+          model_with_ds_strategy.predict(inputs, batch_size=4, steps=3),
+          cpu_model.predict(inputs),
+          atol=1e-5, rtol=1e-5)
+
+  @combinations.generate(tpu_strategy_combinations())
+  def test_predict_multi_output_model_with_partial_batch(
+      self, distribution):
+    with self.cached_session():
+      optimizer = gradient_descent.GradientDescentOptimizer(0.001)
+      loss = 'mse'
+
+      with distribution.scope():
+        model_with_ds_strategy = simple_multi_inputs_multi_outputs_model()
+        model_with_ds_strategy.compile(optimizer, loss)
+
+      cpu_model = simple_multi_inputs_multi_outputs_model()
+      cpu_model.compile(optimizer, loss)
+
+      input_data, _ = get_multi_inputs_multi_outputs_data()
+      input_dict = {
+          'input_a': input_data['input_a'],
+          'input_b': input_data['input_b'],
+      }
+
+      # As the sample size is 200, we batch by 18 so that the last batch is
+      # a partial batch. Also, `predict()` on numpy arrays without a
+      # distribution strategy uses the entire sample as one batch. As such,
+      # we omit the `batch_size` and `steps` arguments for the CPU model.
+      cpu_model.set_weights(model_with_ds_strategy.get_weights())
+      self.assertAllClose(
+          model_with_ds_strategy.predict(input_dict, batch_size=18, steps=12),
+          cpu_model.predict(input_dict),
+          atol=1e-4, rtol=1e-4)
+
 
 class TestDistributionStrategyWithDatasets(test.TestCase,
                                            parameterized.TestCase):
@@ -961,6 +1028,64 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
           callbacks=[keras.callbacks.LearningRateScheduler(schedule)])
       self.assertAllClose(0.001, keras.backend.get_value(model.optimizer.lr))
 
+  @combinations.generate(tpu_strategy_combinations())
+  def test_predict_with_dataset_with_partial_batch(self, distribution):
+    with self.cached_session():
+      optimizer = gradient_descent.GradientDescentOptimizer(0.001)
+      loss = 'mse'
+
+      with distribution.scope():
+        model_with_ds_strategy = get_model()
+        model_with_ds_strategy.compile(optimizer, loss)
+
+      cpu_model = get_model()
+      cpu_model.compile(optimizer, loss)
+
+      inputs = np.zeros((10, 3), dtype=np.float32)
+      dataset = dataset_ops.Dataset.from_tensor_slices((inputs))
+
+      # As the sample size is 10, we batch by 4 so that the last batch is
+      # a partial batch (batches of 4, 4, and 2).
+      dataset_with_partial_batch = dataset.batch(4)
+      cpu_model.set_weights(model_with_ds_strategy.get_weights())
+
+      self.assertAllClose(
+          model_with_ds_strategy.predict(dataset_with_partial_batch, steps=3),
+          cpu_model.predict(dataset_with_partial_batch, steps=3),
+          atol=1e-5, rtol=1e-5)
+
+  @combinations.generate(tpu_strategy_combinations())
+  def test_predict_multi_output_model_with_dataset_with_partial_batch(
+      self, distribution):
+    with self.cached_session():
+      optimizer = gradient_descent.GradientDescentOptimizer(0.001)
+      loss = 'mse'
+
+      with distribution.scope():
+        model_with_ds_strategy = simple_multi_inputs_multi_outputs_model()
+        model_with_ds_strategy.compile(optimizer, loss)
+
+      cpu_model = simple_multi_inputs_multi_outputs_model()
+      cpu_model.compile(optimizer, loss)
+
+      input_data, _ = get_multi_inputs_multi_outputs_data()
+      input_dict = {
+          'input_a': input_data['input_a'],
+          'input_b': input_data['input_b'],
+      }
+      dataset = dataset_ops.Dataset.from_tensor_slices(input_dict)
+
+      # As the sample size is 200, we batch by 18 and run 12 steps so that
+      # the last batch is a partial batch (11 batches of 18 plus one of 2).
+      dataset_with_partial_batch = dataset.batch(18)
+      cpu_model.set_weights(model_with_ds_strategy.get_weights())
+
+      self.assertAllClose(
+          model_with_ds_strategy.predict(dataset_with_partial_batch, steps=12),
+          cpu_model.predict(dataset_with_partial_batch, steps=12),
+          atol=1e-4, rtol=1e-4)
+
 
 class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase):
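
Note: the batch and step counts in the tests above work out as follows; a quick
check of the arithmetic the comments rely on:

# 10 samples batched by 4 -> batches of 4, 4, 2, hence steps=3 in predict().
assert -(-10 // 4) == 3
# 200 samples batched by 18 -> 11 full batches plus one of 2, hence steps=12.
assert -(-200 // 18) == 12 and 11 * 18 + 2 == 200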

View File

@@ -321,31 +321,35 @@ class _SingleWorkerDatasetIterator(object):
     return self._iterator.output_types
 
 
-def _split_dataset_batch(dataset, split_batch_by):
-  """Divide a batch-ed dataset's batches into smaller batches."""
-  # TODO(sourabhbajaj): Remove this in lieu of distributed datasets
+# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
+def _get_batched_dataset(d):
+  """Get the underlying batch dataset from the dataset object."""
   # pylint: disable=protected-access
-  def _get_batch_dataset(d):
-    """Get the underlying batch dataset from the dataset object."""
-    if isinstance(d, dataset_ops.DatasetV1Adapter):
-      d = d._dataset
+  if isinstance(d, dataset_ops.DatasetV1Adapter):
+    d = d._dataset
+
-    if isinstance(d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset)):
-      return d
-    elif isinstance(d, dataset_ops.PrefetchDataset):
-      return _get_batch_dataset(d._input_dataset)
-    raise ValueError(
-        "Unable to get batched dataset from the input dataset. `batch` "
-        "`map_and_batch` need to be the last operations on the dataset. "
-        "The batch operations can be followed by a prefetch.")
+  if isinstance(d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset)):
+    return d
+  elif isinstance(d, dataset_ops.PrefetchDataset):
+    return _get_batched_dataset(d._input_dataset)
+
-  batched_dataset = _get_batch_dataset(dataset)
-  if isinstance(batched_dataset, dataset_ops.BatchDataset):
-    batch_size = batched_dataset._batch_size
-    drop_remainder = batched_dataset._drop_remainder
-  elif isinstance(batched_dataset, batching._MapAndBatchDataset):
-    batch_size = batched_dataset._batch_size_t
-    drop_remainder = batched_dataset._drop_remainder_t
+  raise ValueError(
+      "Unable to get batched dataset from the input dataset. `batch` or "
+      "`map_and_batch` need to be the last operations on the dataset. "
+      "The batch operations can be followed by a prefetch.")
+
+
+def _get_batched_dataset_attributes(dataset):
+  """Get `batch_size`, `drop_remainder`, and `prefetch_buffer` of dataset."""
+  # pylint: disable=protected-access
+  assert isinstance(dataset,
+                    (dataset_ops.BatchDataset, batching._MapAndBatchDataset))
+  if isinstance(dataset, dataset_ops.BatchDataset):
+    batch_size = dataset._batch_size
+    drop_remainder = dataset._drop_remainder
+  elif isinstance(dataset, batching._MapAndBatchDataset):
+    batch_size = dataset._batch_size_t
+    drop_remainder = dataset._drop_remainder_t
 
   prefetch_buffer = None
   if isinstance(dataset, dataset_ops.PrefetchDataset):
@@ -361,6 +365,15 @@ def _split_dataset_batch(dataset, split_batch_by):
   if tensor_util.is_tensor(drop_remainder):
     drop_remainder = tensor_util.constant_value(drop_remainder)
 
+  return batch_size, drop_remainder, prefetch_buffer
+
+
+def _split_dataset_batch(dataset, split_batch_by):
+  """Divide a batched dataset's batches into smaller batches."""
+  batched_dataset = _get_batched_dataset(dataset)
+  batch_size, drop_remainder, prefetch_buffer = (
+      _get_batched_dataset_attributes(batched_dataset))
+
   if batch_size % split_batch_by:
     raise ValueError(
         "Batch size %s cannot be sharded evenly across replicas %s" % (

View File

@@ -134,6 +134,7 @@ py_library(
         "engine/input_layer.py",
         "engine/input_spec.py",
         "engine/network.py",
+        "engine/partial_batch_padding_handler.py",
        "engine/saving.py",
        "engine/sequential.py",
        "engine/training.py",
@@ -159,6 +160,8 @@ py_library(
        ":regularizers",
        ":saving",
        "//tensorflow/python/data",
+       "//tensorflow/python/distribute:distribute_lib",
+       "//tensorflow/python/distribute:input_lib",
        "//tensorflow/python/distribute:reduce_util",
        "//tensorflow/python/training/checkpointable:data_structures",
        "//tensorflow/tools/docs:doc_controls",

View File

@@ -379,7 +379,7 @@ def configure_and_create_session(distribution_strategy):
   K.set_session(session)
 
 
-def validate_inputs(x, y, distribution_strategy):
+def validate_inputs(x, y, distribution_strategy, allow_partial_batch=False):
   """Validate inputs when using DistributionStrategy.
 
   Args:
@@ -387,6 +387,8 @@ def validate_inputs(x, y, distribution_strategy):
     y: Model Targets.
     distribution_strategy: The DistributionStrategy with which the model is
       compiled.
+    allow_partial_batch: Boolean. If False, datasets must have fully
+      defined shapes.
 
   Raises:
     ValueError: if input is not a Dataset or a numpy array (when we use
@@ -400,18 +402,13 @@ def validate_inputs(x, y, distribution_strategy):
   if is_tpu_strategy(distribution_strategy):
     for i in [x, y]:
-      if isinstance(i, dataset_ops.DatasetV2):
-        shapes = nest.flatten(i.output_shapes)
-        try:
-          s = next(s for s in shapes if not s.is_fully_defined())
-        except StopIteration:
-          continue
-        else:
+      if (isinstance(i, dataset_ops.DatasetV2) and not allow_partial_batch):
+        if not is_dataset_shape_fully_defined(i):
           raise ValueError(
               'Using TPUs currently requires fully defined shapes. Either use '
               'set_shape() on the input tensors or use '
               'dataset.batch(..., drop_remainder=True).'
-              'Found unknown shape {} in input {}.'.format(s, i))
+              ' Found unknown shape in input {}.'.format(i))
 
 
 # TODO(b/118776054): Currently we support global batch size for TPUStrategy and
@@ -427,8 +424,15 @@ def is_tpu_strategy(strategy):
   return strategy is not None and strategy.__class__.__name__ == 'TPUStrategy'
 
 
+def is_dataset_shape_fully_defined(dataset):
+  """Returns whether all element shapes of the dataset are fully defined."""
+  shapes = nest.flatten(dataset.output_shapes)
+  unknown_shapes = [s for s in shapes if not s.is_fully_defined()]
+  return not unknown_shapes
+
+
 def get_input_params(distribution_strategy, first_x_value, steps, batch_size,
-                     is_training=False):
+                     mode=None):
   """Calculate the number of batches and steps/steps_per_epoch.
 
   Args:
@@ -437,8 +441,10 @@ def get_input_params(distribution_strategy, first_x_value, steps, batch_size,
       model input.
     steps: The specified number of steps.
    batch_size: The specified batch_size.
-    is_training: Boolean to relax the constraints on consuming all the training
-      samples to keep compatibility till we support partial batches.
+    mode: ModeKeys value representing whether input will be used for training,
+      evaluation, or prediction. This is used to relax the constraints on
+      consuming all the training samples to keep compatibility until we
+      support partial batches. If None, partial batches are not allowed.
 
   Returns:
     steps: The steps or steps_per_epoch argument depending on if a user is
@@ -456,6 +462,14 @@
   use_per_replica_batch = not global_batch_size_supported(
       distribution_strategy)
 
+  # Partial batches are allowed for training as we repeat the
+  # dataset when converting numpy arrays into a dataset.
+  # For other modes, uneven batch sizes are not allowed, except
+  # for `predict()` on TPUStrategy.
+  allow_partial_batch = (mode == ModeKeys.TRAIN or
+                         (mode == ModeKeys.PREDICT
+                          and is_tpu_strategy(distribution_strategy)))
+
   if steps is None:
     if batch_size is None:
       # If neither the batch size nor the number of steps is set, we choose
@@ -468,7 +482,7 @@
      global_batch_size = batch_size
      if use_per_replica_batch:
        global_batch_size *= distribution_strategy.num_replicas_in_sync
-      if not allow_partial_batch and num_samples % global_batch_size:
+      if not allow_partial_batch and num_samples % global_batch_size:
        raise ValueError('The number of samples %s is not divisible by '
                         'batch size %s.' % (num_samples, global_batch_size))
      steps = num_samples // global_batch_size
@@ -488,7 +502,11 @@
      if use_per_replica_batch:
        global_batch_size *= distribution_strategy.num_replicas_in_sync
-      if num_samples < (global_batch_size * steps):
+      min_num_samples = global_batch_size * steps
+      if allow_partial_batch:
+        min_num_samples = global_batch_size * (steps - 1) + 1 if steps > 1 else 0
+
+      if num_samples < min_num_samples:
        raise ValueError('Number of samples %s is less than samples required '
                         'for specified batch_size %s and steps %s' % (
                             num_samples, global_batch_size, steps))
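
Note: the relaxed lower bound above follows because only the last step may be a
partial batch, i.e. at least `steps - 1` full batches plus one sample are
needed. Checking the arithmetic for the 10-sample predict test:

global_batch_size, steps = 4, 3
min_num_samples = global_batch_size * (steps - 1) + 1 if steps > 1 else 0
assert min_num_samples == 9  # so the 10-sample test is accepted
# Without allow_partial_batch the bound is 4 * 3 = 12, and 10 samples would
# have been rejected.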

View File

@@ -0,0 +1,111 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility object to handle partial batches for TPUStrategy."""
+# pylint: disable=protected-access
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import six
+
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.keras import backend as K
+from tensorflow.python.ops import array_ops
+from tensorflow.python.util import nest
+
+
+class PartialBatchPaddingHandler(object):
+  """A container that holds info about partial batches for `predict()`."""
+
+  def __init__(self, output_shape):
+    self.padded_batch_size = 0
+    self.padding_mask = array_ops.zeros(0)
+    self.output_shape = output_shape
+
+  def get_real_batch_size(self, dataset_batch):
+    """Returns the number of elements in a potentially partial batch."""
+    if isinstance(dataset_batch, (tuple, list)):
+      dataset_batch = dataset_batch[0]
+
+    assert nest.flatten(dataset_batch)
+
+    def _find_any_tensor(batch_features):
+      tensors = [
+          x for x in nest.flatten(batch_features) if tensor_util.is_tensor(x)
+      ]
+      if not tensors:
+        raise ValueError('Cannot find any Tensor in features dict.')
+      return tensors[0]
+
+    return K.cast(K.shape(_find_any_tensor(dataset_batch))[0],
+                  dtype='int64')
+
+  def update_mask(self, padding_mask, dataset_batch):
+    """Calculate and append the padding mask for a batch to `padding_mask`."""
+    original_batch_size = self.get_real_batch_size(dataset_batch)
+    missing_count = self.padded_batch_size - original_batch_size
+    mask = K.concatenate([array_ops.ones(original_batch_size),
+                          array_ops.zeros(missing_count)], axis=0)
+    return K.concatenate([padding_mask, mask], axis=0)
+
+  def pad_batch(self, *dataset_batch_elements):
+    """Pads out the batch dimension of a tensor to the complete batch size."""
+    def _pad(batch):
+      """Helper function to pad nested data within each batch element."""
+      padded_dict_batch = {}
+      if isinstance(batch, dict):
+        for key, value in six.iteritems(batch):
+          padded_dict_batch[key] = _pad(value)
+        return padded_dict_batch
+
+      rank = len(batch.shape)
+      assert rank > 0
+      missing_count = (self.padded_batch_size -
+                       self.get_real_batch_size(batch))
+      padding = K.stack([[0, missing_count]] + [[0, 0]] * (rank - 1))
+      return array_ops.pad(batch, padding, 'constant')
+
+    if len(dataset_batch_elements) == 1:
+      return _pad(dataset_batch_elements[0])
+
+    batch_elements = []
+    for batch_element in dataset_batch_elements:
+      batch_elements.append(_pad(batch_element))
+    return tuple(batch_elements)
+
+  def apply_mask(self, prediction_result):
+    """Removes prediction output that corresponds to padded input."""
+    padding_mask = K.get_value(self.padding_mask)
+    assert len(padding_mask.shape) == 1
+
+    if len(self.output_shape) == 1:
+      prediction = np.take(prediction_result,
+                           np.nonzero(
+                               padding_mask[:len(prediction_result)]),
+                           axis=0)
+      if prediction.shape[0] == 1:
+        prediction = np.squeeze(prediction, axis=0)
+      return prediction
+
+    else:
+      predictions = []
+      for i in range(len(self.output_shape)):
+        prediction = prediction_result[i]
+        prediction = np.take(prediction, np.nonzero(
+            padding_mask[:len(prediction)]), axis=0)
+        predictions.append(np.squeeze(prediction))
+
+      return predictions
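
Note: the three methods of PartialBatchPaddingHandler compose as follows; a
minimal numpy sketch of the same bookkeeping, without the class itself:

import numpy as np

padded_batch_size = 4
real_sizes = [4, 4, 2]  # the last batch is partial
# update_mask: ones for real rows, zeros for padded rows, accumulated batch
# by batch via dataset.reduce().
mask = np.concatenate([
    np.concatenate([np.ones(n), np.zeros(padded_batch_size - n)])
    for n in real_sizes])
# pad_batch: every batch is padded up to padded_batch_size before running, so
# the model always sees full batches and predicts 3 * 4 = 12 rows.
predictions = np.arange(12.0).reshape(12, 1)
# apply_mask: rows corresponding to padding are dropped afterwards.
kept = np.take(predictions, np.nonzero(mask[:len(predictions)]), axis=0)
kept = np.squeeze(kept, axis=0)
assert kept.shape == (10, 1)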

View File

@@ -2136,7 +2136,9 @@ class Model(Network):
                                           steps_name='steps',
                                           steps=None,
                                           validation_split=0,
-                                          shuffle=False):
+                                          shuffle=False,
+                                          repeat=True,
+                                          allow_partial_batch=False):
    """Runs validation checks on input and target data passed by the user.
 
    This is called when using DistributionStrategy to train, evaluate or serve
@@ -2160,6 +2162,10 @@
      validation_split: Float between 0 and 1.
        Fraction of the training data to be used as validation data.
      shuffle: Boolean whether to shuffle the training data before each epoch.
+      repeat: Boolean whether to repeat the numpy training data when
+        converting it to a training dataset.
+      allow_partial_batch: Boolean whether to allow a final partial batch
+        instead of enforcing that all batches have the same size.
 
    Returns:
      Dataset instance.
@@ -2239,10 +2245,13 @@
          session=session)
      if shuffle_buffer:
        ds = ds.shuffle(shuffle_buffer)
-      ds = ds.repeat()
+      if repeat:
+        ds = ds.repeat()
+
      # We need to use the drop_remainder argument to get a known static
      # input shape which is required for TPUs.
-      drop_remainder = strategy.extended.experimental_require_static_shapes
+      drop_remainder = (not allow_partial_batch and
+                        strategy.extended.experimental_require_static_shapes)
      x = ds.batch(batch_size, drop_remainder=drop_remainder)
    else:
      assert isinstance(x, dataset_ops.DatasetV2)
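
Note: the drop_remainder interplay is what makes this work on TPUs. Batching
with drop_remainder=True gives the dataset element a static batch dimension,
while a plain batch() leaves it unknown; a quick illustration with TF 1.x
tf.data:

import tensorflow as tf

ds = tf.data.Dataset.range(10)
print(ds.batch(4).output_shapes)                       # (?,): unknown batch dim
print(ds.batch(4, drop_remainder=True).output_shapes)  # (4,): fully defined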

View File

@@ -21,7 +21,9 @@ from __future__ import print_function
 import numpy as np
 
 from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors
@@ -29,6 +31,7 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import callbacks as cbks
 from tensorflow.python.keras.engine import distributed_training_utils
+from tensorflow.python.keras.engine import partial_batch_padding_handler as padding_util
 from tensorflow.python.keras.engine import training_arrays
 from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.utils.generic_utils import Progbar
@@ -61,10 +64,13 @@ def fit_distributed(model,
   first_x_value = nest.flatten(x)[0]
   if isinstance(first_x_value, np.ndarray):
+    # Until support for partial batches is implemented across all
+    # functions and distribution strategies, we pass `mode` to selectively
+    # relax the constraint of consuming all the training samples.
     steps_per_epoch, batch_size = (
         distributed_training_utils.get_input_params(
             model._distribution_strategy, first_x_value, steps_per_epoch,
-            batch_size, is_training=True))
+            batch_size, mode=ModeKeys.TRAIN))
   batch_size = model._validate_or_infer_batch_size(
       batch_size, steps_per_epoch, x)
   dataset = model._distribution_standardize_user_data(
@@ -176,18 +182,21 @@
                        callbacks=None):
   """Predict loop for Distribution Strategies."""
   distributed_training_utils.validate_inputs(
-      x, None, model._distribution_strategy)
+      x, None, model._distribution_strategy, allow_partial_batch=True)
   first_x_value = nest.flatten(x)[0]
   if isinstance(first_x_value, np.ndarray):
     steps, batch_size = distributed_training_utils.get_input_params(
-        model._distribution_strategy, first_x_value, steps, batch_size)
+        model._distribution_strategy, first_x_value, steps,
+        batch_size, mode=ModeKeys.PREDICT)
   batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
   dataset = model._distribution_standardize_user_data(
       x,
       batch_size=batch_size,
       check_steps=True,
       steps_name='steps',
-      steps=steps)
+      steps=steps,
+      repeat=False,
+      allow_partial_batch=True)
   if distributed_training_utils.is_tpu_strategy(model._distribution_strategy):
     # TODO(fchollet): why aren't callbacks supported here?
     return experimental_tpu_predict_loop(
@@ -537,6 +546,32 @@ def experimental_tpu_predict_loop(model, dataset, verbose=0, steps=None):
       or list of arrays of predictions
       (if the model has multiple outputs).
   """
+  dataset_fully_shaped = (distributed_training_utils.
+                          is_dataset_shape_fully_defined(dataset))
+  padding_handler = None
+  if not dataset_fully_shaped:
+    # TODO(hongjunchoi): Investigate whether operations from
+    # PartialBatchPaddingHandler are unnecessarily pruned out
+    # during graph optimization.
+    padding_handler = padding_util.PartialBatchPaddingHandler(
+        model._feed_output_shapes)
+    batched_dataset = input_lib._get_batched_dataset(dataset)
+    batch_size, _, prefetch_buffer = input_lib._get_batched_dataset_attributes(
+        batched_dataset)
+    padding_handler.padded_batch_size = batch_size
+    padding_handler.padding_mask = dataset.reduce(padding_handler.padding_mask,
+                                                  padding_handler.update_mask)
+
+    dataset = dataset.map(padding_handler.pad_batch)
+    dataset = dataset.apply(batching.unbatch())
+    # At this point the dataset is guaranteed to contain no partial batches,
+    # so we set `drop_remainder=True` to get static shape information about
+    # the elements in the dataset.
+    dataset = dataset.batch(batch_size, drop_remainder=True)
+
+    if prefetch_buffer is not None:
+      dataset = dataset.prefetch(prefetch_buffer)
+
   current_strategy = model._distribution_strategy
   iterator = distributed_training_utils.get_iterator(dataset, current_strategy)
@@ -623,8 +658,14 @@
     scope.__exit__(None, None, None)
 
   if len(unconcatenated_outs) == 1:
-    return np.concatenate(unconcatenated_outs[0], axis=0)
-  return [
-      np.concatenate(unconcatenated_outs[i], axis=0)
-      for i in range(len(unconcatenated_outs))
-  ]
+    prediction_result = np.concatenate(unconcatenated_outs[0], axis=0)
+  else:
+    prediction_result = [
+        np.concatenate(unconcatenated_outs[i], axis=0)
+        for i in range(len(unconcatenated_outs))
+    ]
+
+  if padding_handler:
+    prediction_result = padding_handler.apply_mask(prediction_result)
+
+  return prediction_result
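
Note: the dataset rewriting at the top of experimental_tpu_predict_loop boils
down to a generic tf.data pattern: pad every batch up to the full batch size,
flatten, then rebatch with a static shape; the padded rows are later removed
using the recorded mask. A standalone sketch, assuming TF 1.13-era tf.data
APIs and no Keras:

import tensorflow as tf

batch_size = 4
ds = tf.data.Dataset.range(10).batch(batch_size)  # batches of 4, 4, 2

def pad_batch(batch):
  # Pad only the batch (first) dimension up to batch_size; the pad values are
  # irrelevant because their predictions are masked out later.
  missing = batch_size - tf.shape(batch)[0]
  paddings = tf.reshape(tf.stack([0, missing]), [1, 2])
  return tf.pad(batch, paddings)

ds = ds.map(pad_batch)
ds = ds.apply(tf.data.experimental.unbatch())
# No partial batches remain, so drop_remainder=True yields a static shape.
ds = ds.batch(batch_size, drop_remainder=True)
print(ds.output_shapes)  # (4,)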