Add support for final partial batch in predict() for TPUStrategy
PiperOrigin-RevId: 231341652
parent 01cf6233d6
commit dd7de7bbac
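In effect, this change lets `predict()` under `TPUStrategy` consume inputs whose sample count is not a multiple of the batch size: the final batch is padded up to full size on the way in, and the padded rows are masked out of the result. A minimal sketch of the now-supported call pattern, mirroring the tests below (`model_with_ds_strategy` stands for any Keras model compiled under a `TPUStrategy` scope):

import numpy as np

# 10 samples batched by 4 yield batches of 4, 4, and a final partial batch
# of 2. Previously, TPUStrategy required fully defined (static) shapes, so
# the trailing batch of 2 was rejected.
inputs = np.zeros((10, 3), dtype=np.float32)
predictions = model_with_ds_strategy.predict(inputs, batch_size=4, steps=3)
assert predictions.shape[0] == 10  # padded rows are stripped from the output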
@@ -566,6 +566,7 @@ py_library(
 distribute_py_test(
     name = "keras_test",
     srcs = ["keras_test.py"],
+    full_precision = True,
     main = "keras_test.py",
     shard_count = 16,
     tags = [
@@ -71,6 +71,18 @@ def simple_functional_model():
   return model
 
 
+def simple_multi_inputs_multi_outputs_model():
+  input_a = keras.layers.Input(shape=(16,), name='input_a')
+  input_b = keras.layers.Input(shape=(16,), name='input_b')
+
+  merged = keras.layers.concatenate([input_a, input_b], name='merge')
+  output_c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged)
+  output_d = keras.layers.Dense(2, activation='softmax', name='dense_3')(merged)
+  model = keras.models.Model(
+      inputs=[input_a, input_b], outputs=[output_c, output_d])
+  return model
+
+
 def multi_inputs_multi_outputs_model():
   input_a = keras.layers.Input(shape=(16,), name='input_a')
   input_b = keras.layers.Input(shape=(16,), name='input_b')
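The helper added above wires two 16-wide inputs through a shared concatenation into two softmax heads. A hedged usage sketch (plain Keras, with names taken from the helper):

import numpy as np

model = simple_multi_inputs_multi_outputs_model()
input_dict = {
    'input_a': np.zeros((6, 16), dtype=np.float32),
    'input_b': np.zeros((6, 16), dtype=np.float32),
}
# Two outputs: dense_2 has shape (6, 3), dense_3 has shape (6, 2).
output_c, output_d = model.predict(input_dict)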
@@ -671,6 +683,61 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
     self.assertAllEqual([6, 7], outs[0].shape)
     self.assertAllEqual([6, 7], outs[1].shape)
 
+  @combinations.generate(tpu_strategy_combinations())
+  def test_predict_with_partial_batch(self, distribution):
+    with self.cached_session():
+      optimizer = gradient_descent.GradientDescentOptimizer(0.001)
+      loss = 'mse'
+
+      with distribution.scope():
+        model_with_ds_strategy = get_model()
+        model_with_ds_strategy.compile(optimizer, loss)
+
+      cpu_model = get_model()
+      cpu_model.compile(optimizer, loss)
+
+      inputs = np.zeros((10, 3), dtype=np.float32)
+
+      # As sample size is 10, we batch by 4 so that the last batch is
+      # a partial batch. Also `predict()` using numpy arrays as inputs
+      # without distribution strategy uses the entire sample as a single
+      # batch. As such, we omit the `batch_size` and `steps` parameters.
+      cpu_model.set_weights(model_with_ds_strategy.get_weights())
+      self.assertAllClose(
+          model_with_ds_strategy.predict(inputs, batch_size=4, steps=3),
+          cpu_model.predict(inputs),
+          atol=1e-5, rtol=1e-5)
+
+  @combinations.generate(tpu_strategy_combinations())
+  def test_predict_multi_output_model_with_partial_batch(
+      self, distribution):
+    with self.cached_session():
+      optimizer = gradient_descent.GradientDescentOptimizer(0.001)
+      loss = 'mse'
+
+      with distribution.scope():
+        model_with_ds_strategy = simple_multi_inputs_multi_outputs_model()
+        model_with_ds_strategy.compile(optimizer, loss)
+
+      cpu_model = simple_multi_inputs_multi_outputs_model()
+      cpu_model.compile(optimizer, loss)
+
+      input_data, _ = get_multi_inputs_multi_outputs_data()
+      input_dict = {
+          'input_a': input_data['input_a'],
+          'input_b': input_data['input_b'],
+      }
+
+      # As sample size is 200, we batch by 18 so that the last batch is
+      # a partial batch. Also `predict()` using numpy arrays as inputs
+      # without distribution strategy uses the entire sample as a single
+      # batch. As such, we omit the `batch_size` and `steps` parameters.
+      cpu_model.set_weights(model_with_ds_strategy.get_weights())
+      self.assertAllClose(
+          model_with_ds_strategy.predict(input_dict, batch_size=18, steps=12),
+          cpu_model.predict(input_dict),
+          atol=1e-4, rtol=1e-4)
+
 
 class TestDistributionStrategyWithDatasets(test.TestCase,
                                            parameterized.TestCase):
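The step arithmetic in these tests: with 10 samples and batch_size=4, steps must be ceil(10/4) = 3 and the final step carries only 2 real samples; with 200 samples and batch_size=18, steps = ceil(200/18) = 12 with 2 real samples in the last step. A sketch of the bookkeeping:

num_samples, batch_size = 10, 4
steps = -(-num_samples // batch_size)                # ceil division -> 3
last_batch = num_samples - (steps - 1) * batch_size  # -> 2 real samples
padding = batch_size - last_batch                    # -> 2 padded rows, masked out later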
@@ -961,6 +1028,64 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
           callbacks=[keras.callbacks.LearningRateScheduler(schedule)])
       self.assertAllClose(0.001, keras.backend.get_value(model.optimizer.lr))
 
+  @combinations.generate(tpu_strategy_combinations())
+  def test_predict_with_dataset_with_partial_batch(self, distribution):
+    with self.cached_session():
+      optimizer = gradient_descent.GradientDescentOptimizer(0.001)
+      loss = 'mse'
+
+      with distribution.scope():
+        model_with_ds_strategy = get_model()
+        model_with_ds_strategy.compile(optimizer, loss)
+
+      cpu_model = get_model()
+      cpu_model.compile(optimizer, loss)
+
+      inputs = np.zeros((10, 3), dtype=np.float32)
+      dataset = dataset_ops.Dataset.from_tensor_slices((inputs))
+
+      # As sample size is 10, we batch by 4 so that the last batch is
+      # a partial batch.
+      dataset_with_partial_batch = dataset.batch(4)
+      cpu_model.set_weights(model_with_ds_strategy.get_weights())
+
+      self.assertAllClose(
+          model_with_ds_strategy.predict(dataset_with_partial_batch, steps=3),
+          cpu_model.predict(dataset_with_partial_batch, steps=3),
+          atol=1e-5, rtol=1e-5)
+
+  @combinations.generate(tpu_strategy_combinations())
+  def test_predict_multi_output_model_with_dataset_with_partial_batch(
+      self, distribution):
+    with self.cached_session():
+      optimizer = gradient_descent.GradientDescentOptimizer(0.001)
+      loss = 'mse'
+
+      with distribution.scope():
+        model_with_ds_strategy = simple_multi_inputs_multi_outputs_model()
+        model_with_ds_strategy.compile(optimizer, loss)
+
+      cpu_model = simple_multi_inputs_multi_outputs_model()
+      cpu_model.compile(optimizer, loss)
+
+      input_data, _ = get_multi_inputs_multi_outputs_data()
+      input_dict = {
+          'input_a': input_data['input_a'],
+          'input_b': input_data['input_b'],
+      }
+
+      dataset = dataset_ops.Dataset.from_tensor_slices(input_dict)
+
+      # As sample size is 200, we batch by 18 using 12 steps per epoch so
+      # that the last batch is a partial batch.
+      dataset_with_partial_batch = dataset.batch(18)
+      cpu_model.set_weights(model_with_ds_strategy.get_weights())
+
+      self.assertAllClose(
+          model_with_ds_strategy.predict(dataset_with_partial_batch, steps=12),
+          cpu_model.predict(dataset_with_partial_batch, steps=12),
+          atol=1e-4, rtol=1e-4)
+
 
 class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase):
@@ -321,31 +321,35 @@ class _SingleWorkerDatasetIterator(object):
     return self._iterator.output_types
 
 
-def _split_dataset_batch(dataset, split_batch_by):
-  """Divide a batch-ed dataset's batches into smaller batches."""
-  # TODO(sourabhbajaj): Remove this in lieu of distributed datasets
+# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
+def _get_batched_dataset(d):
+  """Get the underlying batch dataset from the dataset object."""
+  # pylint: disable=protected-access
-  def _get_batch_dataset(d):
-    """Get the underlying batch dataset from the dataset object."""
-    if isinstance(d, dataset_ops.DatasetV1Adapter):
-      d = d._dataset
+  if isinstance(d, dataset_ops.DatasetV1Adapter):
+    d = d._dataset
 
-    if isinstance(d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset)):
-      return d
-    elif isinstance(d, dataset_ops.PrefetchDataset):
-      return _get_batch_dataset(d._input_dataset)
-    raise ValueError(
-        "Unable to get batched dataset from the input dataset. `batch` "
-        "`map_and_batch` need to be the last operations on the dataset. "
-        "The batch operations can be followed by a prefetch.")
+  if isinstance(d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset)):
+    return d
+  elif isinstance(d, dataset_ops.PrefetchDataset):
+    return _get_batched_dataset(d._input_dataset)
 
-  batched_dataset = _get_batch_dataset(dataset)
-  if isinstance(batched_dataset, dataset_ops.BatchDataset):
-    batch_size = batched_dataset._batch_size
-    drop_remainder = batched_dataset._drop_remainder
-  elif isinstance(batched_dataset, batching._MapAndBatchDataset):
-    batch_size = batched_dataset._batch_size_t
-    drop_remainder = batched_dataset._drop_remainder_t
+  raise ValueError(
+      "Unable to get batched dataset from the input dataset. `batch` "
+      "`map_and_batch` need to be the last operations on the dataset. "
+      "The batch operations can be followed by a prefetch.")
+
+
+def _get_batched_dataset_attributes(dataset):
+  """Get `batch_size`, `drop_remainder`, and `prefetch_buffer` of dataset."""
+  # pylint: disable=protected-access
+  assert isinstance(dataset,
+                    (dataset_ops.BatchDataset, batching._MapAndBatchDataset))
+  if isinstance(dataset, dataset_ops.BatchDataset):
+    batch_size = dataset._batch_size
+    drop_remainder = dataset._drop_remainder
+  elif isinstance(dataset, batching._MapAndBatchDataset):
+    batch_size = dataset._batch_size_t
+    drop_remainder = dataset._drop_remainder_t
 
   prefetch_buffer = None
   if isinstance(dataset, dataset_ops.PrefetchDataset):
@@ -361,6 +365,15 @@ def _split_dataset_batch(dataset, split_batch_by):
   if tensor_util.is_tensor(drop_remainder):
     drop_remainder = tensor_util.constant_value(drop_remainder)
 
+  return batch_size, drop_remainder, prefetch_buffer
+
+
+def _split_dataset_batch(dataset, split_batch_by):
+  """Divide a batch-ed dataset's batches into smaller batches."""
+  batched_dataset = _get_batched_dataset(dataset)
+  batch_size, drop_remainder, prefetch_buffer = (
+      _get_batched_dataset_attributes(batched_dataset))
+
   if batch_size % split_batch_by:
     raise ValueError(
         "Batch size %s cannot be sharded evenly across replicas %s" % (
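The refactor above splits the old inline logic into two reusable helpers so the predict loop can inspect a dataset's batching. A hedged sketch of how they compose (both are private `input_lib` utilities, so this mirrors `_split_dataset_batch` rather than a public API):

# pylint: disable=protected-access
batched = _get_batched_dataset(dataset)        # unwraps V1Adapter/Prefetch layers
batch_size, drop_remainder, prefetch_buffer = (
    _get_batched_dataset_attributes(batched))  # reads the private batch attributes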
@@ -134,6 +134,7 @@ py_library(
         "engine/input_layer.py",
         "engine/input_spec.py",
         "engine/network.py",
+        "engine/partial_batch_padding_handler.py",
         "engine/saving.py",
         "engine/sequential.py",
         "engine/training.py",
@@ -159,6 +160,8 @@ py_library(
         ":regularizers",
         ":saving",
+        "//tensorflow/python/data",
         "//tensorflow/python/distribute:distribute_lib",
+        "//tensorflow/python/distribute:input_lib",
         "//tensorflow/python/distribute:reduce_util",
         "//tensorflow/python/training/checkpointable:data_structures",
        "//tensorflow/tools/docs:doc_controls",
@@ -379,7 +379,7 @@ def configure_and_create_session(distribution_strategy):
   K.set_session(session)
 
 
-def validate_inputs(x, y, distribution_strategy):
+def validate_inputs(x, y, distribution_strategy, allow_partial_batch=False):
   """Validate inputs when using DistributionStrategy.
 
   Args:
@@ -387,6 +387,8 @@ def validate_inputs(x, y, distribution_strategy):
     y: Model Targets.
     distribution_strategy: The DistributionStrategy with which the model is
       compiled.
+    allow_partial_batch: Boolean. If false, datasets must have fully
+      defined shapes.
 
   Raises:
     ValueError: if input is not a Dataset or a numpy array (when we use
@@ -400,18 +402,13 @@ def validate_inputs(x, y, distribution_strategy):
 
   if is_tpu_strategy(distribution_strategy):
     for i in [x, y]:
-      if isinstance(i, dataset_ops.DatasetV2):
-        shapes = nest.flatten(i.output_shapes)
-        try:
-          s = next(s for s in shapes if not s.is_fully_defined())
-        except StopIteration:
-          continue
-        else:
+      if (isinstance(i, dataset_ops.DatasetV2) and not allow_partial_batch):
+        if not is_dataset_shape_fully_defined(i):
          raise ValueError(
              'Using TPUs currently requires fully defined shapes. Either use '
              'set_shape() on the input tensors or use '
              'dataset.batch(..., drop_remainder=True).'
-             'Found unknown shape {} in input {}.'.format(s, i))
+             'Found unknown shape in input {}.'.format(i))
 
 
 # TODO(b/118776054): Currently we support global batch size for TPUStrategy and
@@ -427,8 +424,15 @@ def is_tpu_strategy(strategy):
   return strategy is not None and strategy.__class__.__name__ == 'TPUStrategy'
 
 
+def is_dataset_shape_fully_defined(dataset):
+  """Returns whether a dataset's shapes are fully defined (no partial batch)."""
+  shapes = nest.flatten(dataset.output_shapes)
+  unknown_shapes = [s for s in shapes if not s.is_fully_defined()]
+  return not unknown_shapes
+
+
 def get_input_params(distribution_strategy, first_x_value, steps, batch_size,
-                     is_training=False):
+                     mode=None):
   """Calculate the number of batches and steps/steps_per_epoch.
 
   Args:
@@ -437,8 +441,10 @@ def get_input_params(distribution_strategy, first_x_value, steps, batch_size,
       model input.
     steps: The specified number of steps.
     batch_size: The specified batch_size.
-    is_training: Boolean to relax the constraints on consuming all the training
-      samples to keep compatibility till we support partial batches.
+    mode: ModeKey representing whether input will be used for training,
+      evaluation, or prediction. This is used to relax the constraints on
+      consuming all the training samples to keep compatibility until we
+      support partial batches. If None, then partial batches are not allowed.
 
   Returns:
     steps: The steps or steps_per_epoch argument depending on if a user is
@@ -456,6 +462,14 @@ def get_input_params(distribution_strategy, first_x_value, steps, batch_size,
   use_per_replica_batch = not global_batch_size_supported(
       distribution_strategy)
 
+  # Partial batches are allowed for training as we repeat the
+  # dataset when converting numpy arrays into a dataset.
+  # For other modes uneven batch sizes are not allowed except
+  # for `predict()` on TPUStrategy.
+  allow_partial_batch = (mode == ModeKeys.TRAIN or
+                         (mode == ModeKeys.PREDICT
+                          and is_tpu_strategy(distribution_strategy)))
+
   if steps is None:
     if batch_size is None:
       # If neither the batch size nor the number of steps is set, we choose the
@@ -468,7 +482,7 @@ def get_input_params(distribution_strategy, first_x_value, steps, batch_size,
     global_batch_size = batch_size
     if use_per_replica_batch:
       global_batch_size *= distribution_strategy.num_replicas_in_sync
-    if not is_training and num_samples % global_batch_size:
+    if not allow_partial_batch and num_samples % global_batch_size:
       raise ValueError('The number of samples %s is not divisible by '
                        'batch size %s.' % (num_samples, global_batch_size))
     steps = num_samples // global_batch_size
@@ -488,7 +502,11 @@ def get_input_params(distribution_strategy, first_x_value, steps, batch_size,
     if use_per_replica_batch:
       global_batch_size *= distribution_strategy.num_replicas_in_sync
 
-    if num_samples < (global_batch_size * steps):
+    min_num_samples = global_batch_size * steps
+    if allow_partial_batch:
+      min_num_samples = global_batch_size * (steps-1) + 1 if steps > 1 else 0
+
+    if num_samples < min_num_samples:
       raise ValueError('Number of samples %s is less than samples required '
                        'for specified batch_size %s and steps %s' % (
                            num_samples, global_batch_size, steps))
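A worked example of the relaxed lower bound above: with allow_partial_batch, every step except the last must still be fillable, so at least global_batch_size * (steps - 1) + 1 samples are required:

global_batch_size, steps = 4, 3
strict_min = global_batch_size * steps             # 12 samples, no partial batch
relaxed_min = global_batch_size * (steps - 1) + 1  # 9 samples when steps > 1
# The tests above use 10 samples with batch_size=4, steps=3:
# 9 <= 10 < 12, so only the partial-batch path accepts that combination.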
tensorflow/python/keras/engine/partial_batch_padding_handler.py (new file, 111 lines)
@@ -0,0 +1,111 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility object to handle partial batches for TPUStrategy."""
+# pylint: disable=protected-access
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import six
+
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.keras import backend as K
+from tensorflow.python.ops import array_ops
+from tensorflow.python.util import nest
+
+
+class PartialBatchPaddingHandler(object):
+  """A container that holds info about partial batches for `predict()`."""
+
+  def __init__(self, output_shape):
+    self.padded_batch_size = 0
+    self.padding_mask = array_ops.zeros(0)
+    self.output_shape = output_shape
+
+  def get_real_batch_size(self, dataset_batch):
+    """Returns the number of elements in a potentially partial batch."""
+    if isinstance(dataset_batch, (tuple, list)):
+      dataset_batch = dataset_batch[0]
+
+    assert nest.flatten(dataset_batch)
+
+    def _find_any_tensor(batch_features):
+      tensors = [
+          x for x in nest.flatten(batch_features) if tensor_util.is_tensor(x)
+      ]
+      if not tensors:
+        raise ValueError('Cannot find any Tensor in features dict.')
+      return tensors[0]
+
+    return K.cast(K.shape(_find_any_tensor(dataset_batch))[0],
+                  dtype='int64')
+
+  def update_mask(self, padding_mask, dataset_batch):
+    """Calculate and cache the amount of padding required for a batch."""
+    original_batch_size = self.get_real_batch_size(dataset_batch)
+    missing_count = self.padded_batch_size - original_batch_size
+    mask = K.concatenate([array_ops.ones(original_batch_size),
+                          array_ops.zeros(missing_count)], axis=0)
+    return K.concatenate([padding_mask, mask], axis=0)
+
+  def pad_batch(self, *dataset_batch_elements):
+    """Pads out the batch dimension of a tensor to the complete batch size."""
+    def _pad(batch):
+      """Helper function to pad nested data within each batch element."""
+      padded_dict_batch = {}
+      if isinstance(batch, dict):
+        for key, value in six.iteritems(batch):
+          padded_dict_batch[key] = _pad(value)
+        return padded_dict_batch
+
+      rank = len(batch.shape)
+      assert rank > 0
+      missing_count = (self.padded_batch_size -
+                       self.get_real_batch_size(batch))
+      padding = K.stack([[0, missing_count]] + [[0, 0]] * (rank - 1))
+      return array_ops.pad(batch, padding, 'constant')
+
+    if len(dataset_batch_elements) == 1:
+      return _pad(dataset_batch_elements[0])
+
+    batch_elements = []
+    for batch_element in dataset_batch_elements:
+      batch_elements.append(_pad(batch_element))
+    return tuple(batch_elements)
+
+  def apply_mask(self, prediction_result):
+    """Removes prediction output that corresponds to padded input."""
+    padding_mask = K.get_value(self.padding_mask)
+    assert len(padding_mask.shape) == 1
+
+    if len(self.output_shape) == 1:
+      prediction = np.take(prediction_result,
+                           np.nonzero(
+                               padding_mask[:len(prediction_result)]),
+                           axis=0)
+      if prediction.shape[0] == 1:
+        prediction = np.squeeze(prediction, axis=0)
+      return prediction
+
+    else:
+      predictions = []
+      for i in range(len(self.output_shape)):
+        prediction = prediction_result[i]
+        prediction = np.take(prediction, np.nonzero(
+            padding_mask[:len(prediction)]), axis=0)
+        predictions.append(np.squeeze(prediction))
+
+      return predictions
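A minimal numpy-only sketch of the handler's pad-then-mask round trip (illustrative; the real class does the same with TF ops and accumulates the mask via dataset.reduce):

import numpy as np

padded_batch_size = 4
batch = np.arange(6, dtype=np.float32).reshape(2, 3)  # partial batch of 2

# pad_batch: pad the batch dimension up to the full batch size.
missing = padded_batch_size - batch.shape[0]
padded = np.pad(batch, [(0, missing)] + [(0, 0)] * (batch.ndim - 1))

# update_mask: ones for real rows, zeros for padding.
mask = np.concatenate([np.ones(batch.shape[0]), np.zeros(missing)])

# apply_mask: keep only rows whose mask entry is nonzero.
result = np.take(padded, np.nonzero(mask)[0], axis=0)
assert (result == batch).all()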
@@ -2136,7 +2136,9 @@ class Model(Network):
                                           steps_name='steps',
                                           steps=None,
                                           validation_split=0,
-                                          shuffle=False):
+                                          shuffle=False,
+                                          repeat=True,
+                                          allow_partial_batch=False):
     """Runs validation checks on input and target data passed by the user.
 
     This is called when using DistributionStrategy to train, evaluate or serve
@@ -2160,6 +2162,10 @@ class Model(Network):
       validation_split: Float between 0 and 1.
        Fraction of the training data to be used as validation data.
      shuffle: Boolean whether to shuffle the training data before each epoch.
+      repeat: Boolean whether to repeat the numpy training data when converting
+        to training dataset.
+      allow_partial_batch: Boolean whether to allow batches of unequal size
+        (i.e. not enforce that every batch is full).
 
     Returns:
       Dataset instance.
@@ -2239,10 +2245,13 @@ class Model(Network):
           session=session)
       if shuffle_buffer:
         ds = ds.shuffle(shuffle_buffer)
-      ds = ds.repeat()
+      if repeat:
+        ds = ds.repeat()
 
       # We need to use the drop_remainder argument to get a known static
       # input shape which is required for TPUs.
-      drop_remainder = strategy.extended.experimental_require_static_shapes
+      drop_remainder = (not allow_partial_batch and
+                        strategy.extended.experimental_require_static_shapes)
       x = ds.batch(batch_size, drop_remainder=drop_remainder)
     else:
       assert isinstance(x, dataset_ops.DatasetV2)
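The two new flags diverge by mode: fit() keeps repeat=True and full batches (TPUs need static shapes), while predict() passes repeat=False and allow_partial_batch=True so a finite dataset may end on a short batch. An illustrative tf.data sketch of the two resulting pipelines:

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(tf.zeros([10, 3]))

# fit() path: repeat=True, allow_partial_batch=False on TPU.
train_ds = ds.repeat().batch(4, drop_remainder=True)  # endless full batches of 4

# predict() path: repeat=False, allow_partial_batch=True.
predict_ds = ds.batch(4, drop_remainder=False)        # batches of 4, 4, and 2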
@@ -21,7 +21,9 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors
@@ -29,6 +31,7 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import callbacks as cbks
 from tensorflow.python.keras.engine import distributed_training_utils
+from tensorflow.python.keras.engine import partial_batch_padding_handler as padding_util
 from tensorflow.python.keras.engine import training_arrays
 from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.utils.generic_utils import Progbar
@@ -61,10 +64,13 @@ def fit_distributed(model,
 
   first_x_value = nest.flatten(x)[0]
   if isinstance(first_x_value, np.ndarray):
+    # Until support for partial batch is implemented across all
+    # functions and distribution strategies, we pass `mode` to selectively
+    # relax the constraint to consume all the training samples.
     steps_per_epoch, batch_size = (
         distributed_training_utils.get_input_params(
            model._distribution_strategy, first_x_value, steps_per_epoch,
-            batch_size, is_training=True))
+            batch_size, mode=ModeKeys.TRAIN))
   batch_size = model._validate_or_infer_batch_size(
       batch_size, steps_per_epoch, x)
   dataset = model._distribution_standardize_user_data(
@@ -176,18 +182,21 @@ def predict_distributed(model,
                         callbacks=None):
   """Predict loop for Distribution Strategies."""
   distributed_training_utils.validate_inputs(
-      x, None, model._distribution_strategy)
+      x, None, model._distribution_strategy, allow_partial_batch=True)
   first_x_value = nest.flatten(x)[0]
   if isinstance(first_x_value, np.ndarray):
     steps, batch_size = distributed_training_utils.get_input_params(
-        model._distribution_strategy, first_x_value, steps, batch_size)
+        model._distribution_strategy, first_x_value, steps,
+        batch_size, mode=ModeKeys.PREDICT)
   batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
   dataset = model._distribution_standardize_user_data(
       x,
       batch_size=batch_size,
       check_steps=True,
       steps_name='steps',
-      steps=steps)
+      steps=steps,
+      repeat=False,
+      allow_partial_batch=True)
   if distributed_training_utils.is_tpu_strategy(model._distribution_strategy):
     # TODO(fchollet): why aren't callbacks supported here?
     return experimental_tpu_predict_loop(
@@ -537,6 +546,32 @@ def experimental_tpu_predict_loop(model, dataset, verbose=0, steps=None):
       or list of arrays of predictions
       (if the model has multiple outputs).
   """
+  dataset_fully_shaped = (distributed_training_utils.
+                          is_dataset_shape_fully_defined(dataset))
+  padding_handler = None
+  if not dataset_fully_shaped:
+    # TODO(hongjunchoi): Investigate whether operations from
+    # PartialBatchPaddingHandler are unnecessarily pruned out
+    # during graph optimization.
+    padding_handler = padding_util.PartialBatchPaddingHandler(
+        model._feed_output_shapes)
+    batched_dataset = input_lib._get_batched_dataset(dataset)
+    batch_size, _, prefetch_buffer = input_lib._get_batched_dataset_attributes(
+        batched_dataset)
+    padding_handler.padded_batch_size = batch_size
+    padding_handler.padding_mask = dataset.reduce(padding_handler.padding_mask,
+                                                  padding_handler.update_mask)
+
+    dataset = dataset.map(padding_handler.pad_batch)
+    dataset = dataset.apply(batching.unbatch())
+    # At this point, the dataset is guaranteed to contain no partial
+    # batches, so we set `drop_remainder=True` to recover static shape
+    # information for the elements in the dataset.
+    dataset = dataset.batch(batch_size, drop_remainder=True)
+
+    if prefetch_buffer is not None:
+      dataset = dataset.prefetch(prefetch_buffer)
+
   current_strategy = model._distribution_strategy
   iterator = distributed_training_utils.get_iterator(dataset, current_strategy)
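The rebatching trick above can be reproduced with plain tf.data ops: pad every batch to full size, flatten, then rebatch with drop_remainder=True so element shapes become static. A hedged standalone sketch (using the modern ds.unbatch(); the 1.x code above applies batching.unbatch() instead):

import tensorflow as tf

batch_size = 4
ds = tf.data.Dataset.from_tensor_slices(tf.zeros([10, 3])).batch(batch_size)

def pad_to_full_batch(batch):
  missing = batch_size - tf.shape(batch)[0]
  return tf.pad(batch, [[0, missing], [0, 0]])  # zero-pad the batch dimension

ds = ds.map(pad_to_full_batch)                  # batches of 4, 4, 2 -> 4, 4, 4
ds = ds.unbatch()
ds = ds.batch(batch_size, drop_remainder=True)  # three statically shaped batches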
@@ -623,8 +658,14 @@ def experimental_tpu_predict_loop(model, dataset, verbose=0, steps=None):
   scope.__exit__(None, None, None)
 
   if len(unconcatenated_outs) == 1:
-    return np.concatenate(unconcatenated_outs[0], axis=0)
-  return [
-      np.concatenate(unconcatenated_outs[i], axis=0)
-      for i in range(len(unconcatenated_outs))
-  ]
+    prediction_result = np.concatenate(unconcatenated_outs[0], axis=0)
+  else:
+    prediction_result = [
+        np.concatenate(unconcatenated_outs[i], axis=0)
+        for i in range(len(unconcatenated_outs))
+    ]
+
+  if padding_handler:
+    prediction_result = padding_handler.apply_mask(prediction_result)
+
+  return prediction_result