Change distribution.distribute_dataset to accept an input_fn instead of a dataset.

PiperOrigin-RevId: 193437651
Yuefeng Zhou 2018-04-18 16:35:44 -07:00 committed by TensorFlower Gardener
parent e9d47fbff0
commit fddfa9f8dc
10 changed files with 79 additions and 59 deletions
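
For orientation, the sketch below shows the calling convention before and after this change. It is illustrative only: it assumes a TF 1.x build from this period where `tf.contrib.distribute.MirroredStrategy` is available, and the device list and dataset values are borrowed from the tests in this diff.

# Illustrative sketch, not an excerpt from the changed files.
import tensorflow as tf

def dataset_fn():
  return tf.data.Dataset.from_tensors([[1.]]).repeat()

strategy = tf.contrib.distribute.MirroredStrategy(
    ["/device:GPU:0", "/device:CPU:0"])

# Before this change the strategy was handed a concrete dataset:
#   iterator = strategy.distribute_dataset(dataset_fn()).make_one_shot_iterator()
# After this change the function itself is passed, and the strategy decides
# when to call it (and validates the result via _call_dataset_fn, below):
iterator = strategy.distribute_dataset(dataset_fn).make_one_shot_iterator()
features = iterator.get_next()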


@@ -54,21 +54,18 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
   def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss,
                        is_tpu):
     with distribution.scope():
-      model_fn, dataset, layer = minimize_loss_example(
-          optimizer_fn,
-          use_bias=True,
-          use_callable_loss=use_callable_loss)
+      model_fn, dataset_fn, layer = minimize_loss_example(
+          optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)

+      def tpu_dataset_fn():
+        return dataset_fn().batch(2)
       # TODO(isaprykin): Eliminate `is_tpu`. Probably add a
       # `DistributionStrategy.create_monitor` so that each DistributionStrategy
       # could influence its training loop. That method would return an instance
       # of Monitor. TPUMonitor would execute tpu.initialize_system() and
       # tpu.shutdown_system().
-      if is_tpu:
-        dataset = dataset.batch(2)
-
       iterator = distribution.distribute_dataset(
-          dataset).make_one_shot_iterator()
+          tpu_dataset_fn if is_tpu else dataset_fn).make_one_shot_iterator()

       def run_step():
         # TODO(isaprykin): Make iterator get_next() return a list of sub-
@@ -122,14 +119,14 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     # `distribution.scope`.
     with variable_scope.variable_creator_scope(
         appending_creator), distribution.scope():
-      model_fn, dataset, layer = minimize_loss_example(
+      model_fn, dataset_fn, layer = minimize_loss_example(
           optimizer_fn,
           use_bias=True,
           use_callable_loss=True,
           create_optimizer_inside_model_fn=True)

       iterator = distribution.distribute_dataset(
-          dataset).make_one_shot_iterator()
+          dataset_fn).make_one_shot_iterator()

       def run_step():
         return distribution.group(
@@ -176,7 +173,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     """Verifies that moving mean updates are reduced across towers."""
     with distribution.scope():
       num_towers = len(distribution.worker_devices)
-      model_fn, dataset, batchnorm = batchnorm_example(
+      model_fn, dataset_fn, batchnorm = batchnorm_example(
           optimizer_fn,
           batch_per_epoch=num_towers,
           momentum=momentum,
@@ -188,7 +185,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
       if isinstance(distribution, mirrored_strategy.MirroredStrategy):
         distribution._prefetch_on_device = False
       iterator = distribution.distribute_dataset(
-          dataset).make_one_shot_iterator()
+          dataset_fn).make_one_shot_iterator()

       def run_step():
         return control_flow_ops.group(
@@ -260,11 +257,13 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
         else:
           return optimizer.minimize(loss_fn())

-      features = dataset_ops.Dataset.from_tensors([[2.], [7.]])
-      labels = dataset_ops.Dataset.from_tensors([[6.], [21.]])
-      dataset = dataset_ops.Dataset.zip((features, labels)).repeat()
+      def dataset_fn():
+        features = dataset_ops.Dataset.from_tensors([[2.], [7.]])
+        labels = dataset_ops.Dataset.from_tensors([[6.], [21.]])
+        return dataset_ops.Dataset.zip((features, labels)).repeat()
+
       iterator = distribution.distribute_dataset(
-          dataset).make_one_shot_iterator()
+          dataset_fn).make_one_shot_iterator()

       def run_step():
         return distribution.group(


@@ -140,9 +140,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
         g.add_to_collections(collections, result)
       return result

-  def distribute_dataset(self, dataset):
+  def distribute_dataset(self, dataset_fn):
     return values.PerDeviceDataset(
-        dataset, self._devices, self._prefetch_on_device)
+        self._call_dataset_fn(dataset_fn), self._devices,
+        self._prefetch_on_device)

   def _broadcast(self, tensor, destinations):
     # TODO(josh11b): In eager mode, use one thread per device, or async mode.


@@ -247,9 +247,9 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
     dist = mirrored_strategy.MirroredStrategy(
         ["/device:GPU:0", "/device:CPU:0"])

-    features = dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
     features = dist.distribute_dataset(
-        features).make_one_shot_iterator().get_next()
+        lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
+    ).make_one_shot_iterator().get_next()

     with dist.scope():
       result = dist.call_for_each_tower(

@@ -60,8 +60,8 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
       with ops.colocate_with(colocate_with):
         return next_creator(*args, **kwargs)

-  def distribute_dataset(self, dataset):
-    return dataset
+  def distribute_dataset(self, dataset_fn):
+    return self._call_dataset_fn(dataset_fn)

   def _broadcast(self, tensor, destinations):
     return tensor


@@ -39,11 +39,11 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
   def testTrainNetwork(self, distribution, optimizer_fn,
                        use_callable_loss=True):
     with distribution.scope():
-      model_fn, dataset, layer = minimize_loss_example(
+      model_fn, dataset_fn, layer = minimize_loss_example(
           optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)

       iterator = distribution.distribute_dataset(
-          dataset).make_one_shot_iterator()
+          dataset_fn).make_one_shot_iterator()

       def run_step():
         return control_flow_ops.group(distribution.unwrap(


@@ -29,7 +29,10 @@ from tensorflow.python.ops import math_ops

 def single_loss_example(optimizer_fn, distribution, use_bias=False):
   """Build a very simple network to use in tests and examples."""
-  dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat()
+
+  def dataset_fn():
+    return dataset_ops.Dataset.from_tensors([[1.]]).repeat()
+
   optimizer = optimizer_fn()
   layer = core.Dense(1, use_bias=use_bias)
@@ -37,8 +40,8 @@ def single_loss_example(optimizer_fn, distribution, use_bias=False):
     y = array_ops.reshape(layer(x), []) - constant_op.constant(1.)
     return y * y

-  single_loss_step = step_fn.StandardSingleLossStep(dataset, loss_fn, optimizer,
-                                                    distribution)
+  single_loss_step = step_fn.StandardSingleLossStep(dataset_fn, loss_fn,
+                                                    optimizer, distribution)

   # Layer is returned for inspecting the kernels in tests.
   return single_loss_step, layer
@@ -49,7 +52,10 @@ def minimize_loss_example(optimizer_fn,
                           use_callable_loss=True,
                           create_optimizer_inside_model_fn=False):
   """Example of non-distribution-aware legacy code."""
-  dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat()
+
+  def dataset_fn():
+    return dataset_ops.Dataset.from_tensors([[1.]]).repeat()
+
   # An Optimizer instance is created either outside or inside model_fn.
   outer_optimizer = None
   if not create_optimizer_inside_model_fn:
@@ -71,7 +77,7 @@
     else:
       return optimizer.minimize(loss_fn())

-  return model_fn, dataset, layer
+  return model_fn, dataset_fn, layer


 def batchnorm_example(optimizer_fn,
@@ -79,12 +85,15 @@
                       momentum=0.9,
                       renorm=False):
   """Example of non-distribution-aware legacy code with batch normalization."""
-  # input shape is [16, 8], input values are increasing in both dimensions.
-  dataset = dataset_ops.Dataset.from_tensor_slices(
-      [[[float(x * 8 + y + z * 100)
-         for y in range(8)]
-        for x in range(16)]
-       for z in range(batch_per_epoch)]).repeat()
+
+  def dataset_fn():
+    # input shape is [16, 8], input values are increasing in both dimensions.
+    return dataset_ops.Dataset.from_tensor_slices(
+        [[[float(x * 8 + y + z * 100)
+           for y in range(8)]
+          for x in range(16)]
+         for z in range(batch_per_epoch)]).repeat()
+
   optimizer = optimizer_fn()
   batchnorm = normalization.BatchNormalization(
       renorm=renorm, momentum=momentum, fused=False)
@@ -99,4 +108,4 @@ def batchnorm_example(optimizer_fn,
       # Callable loss.
       return optimizer.minimize(loss_fn)

-  return model_fn, dataset, batchnorm
+  return model_fn, dataset_fn, batchnorm


@@ -49,13 +49,14 @@ class StandardInputStep(Step):
   """Step with a standard implementation of input handling.

   Args:
-    input_dataset: a tf.data Dataset that provides input.
+    dataset_fn: a function that returns a tf.data Dataset that produces the
+      input for the model.
   """

-  def __init__(self, input_dataset, distribution):
+  def __init__(self, dataset_fn, distribution):
     Step.__init__(self, distribution)
     self._distributed_input = distribution.distribute_dataset(
-        input_dataset).make_one_shot_iterator()
+        dataset_fn).make_one_shot_iterator()

   def inputs(self):
     return self._distributed_input.get_next()
@@ -77,14 +78,15 @@ class StandardSingleLossStep(StandardInputStep):
   ```

   Args:
-    input_dataset: a tf.data Dataset that provides input.
+    dataset_fn: a function that returns a tf.data Dataset that produces the
+      input for the model.
     loss_fn: a function that returns loss.
     optimizer: an optimizer that implements an update rule.
     distribution: a `DistributionStrategy` object.
   """

-  def __init__(self, input_dataset, loss_fn, optimizer, distribution):
-    StandardInputStep.__init__(self, input_dataset, distribution)
+  def __init__(self, dataset_fn, loss_fn, optimizer, distribution):
+    StandardInputStep.__init__(self, dataset_fn, distribution)
     self._loss_fn = loss_fn
     self._optimizer = optimizer
     self._is_run_concurrently = False
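
A hedged usage sketch of the updated `StandardSingleLossStep` constructor, mirroring `single_loss_example` above. The import paths, the optimizer, and the strategy choice are assumptions made for illustration; only the constructor signature comes from this diff.

# Illustrative sketch; module paths and strategy choice are assumptions.
import tensorflow as tf
from tensorflow.contrib.distribute.python import step_fn

def dataset_fn():
  return tf.data.Dataset.from_tensors([[1.]]).repeat()

def loss_fn(x):
  return tf.reduce_sum(x * x)

optimizer = tf.train.GradientDescentOptimizer(0.1)
distribution = tf.contrib.distribute.OneDeviceStrategy("/device:CPU:0")

single_loss_step = step_fn.StandardSingleLossStep(dataset_fn, loss_fn,
                                                  optimizer, distribution)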


@@ -3048,6 +3048,7 @@ py_library(
         ":state_ops",
         ":util",
         ":variable_scope",
+        "//tensorflow/python/data",
         "//tensorflow/python/ops/losses",
     ],
 )


@@ -688,22 +688,19 @@ class Estimator(object):
   def _get_features_and_labels_from_input_fn(self, input_fn, mode):
     """Extracts the `features` and labels from return values of `input_fn`."""
-    result = self._call_input_fn(input_fn, mode)
-    # TODO(anjalisridhar): What about the default DistributionStrategy? Perhaps
-    # using any input is alright in that case. There is also a
-    # has_dataset_or_queue_runner function that we may want to extend and use.
-    if (self._distribution is not None and
-        not isinstance(result, dataset_ops.Dataset) and
-        mode == model_fn_lib.ModeKeys.TRAIN):
-      raise ValueError('input_fn() must return a tf.data.Dataset when using a '
-                       'DistributionStrategy.')
     input_hooks = []
-    if isinstance(result, dataset_ops.Dataset):
-      if self._distribution is not None and mode == model_fn_lib.ModeKeys.TRAIN:
-        result = self._distribution.distribute_dataset(result)
+    if self._distribution is not None and mode == model_fn_lib.ModeKeys.TRAIN:
+      result = self._distribution.distribute_dataset(
+          lambda: self._call_input_fn(input_fn, mode))
       iterator = result.make_initializable_iterator()
       input_hooks.append(_DatasetInitializerHook(iterator))
       result = iterator.get_next()
+    else:
+      result = self._call_input_fn(input_fn, mode)
+      if isinstance(result, dataset_ops.Dataset):
+        iterator = result.make_initializable_iterator()
+        input_hooks.append(_DatasetInitializerHook(iterator))
+        result = iterator.get_next()

     if isinstance(result, (list, tuple)):
       if len(result) != 2:
         raise ValueError(
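
With this change the Estimator no longer calls `input_fn` eagerly and inspects the result; instead it hands `lambda: self._call_input_fn(input_fn, mode)` to the strategy, which enforces the Dataset requirement through `_call_dataset_fn`. A hedged sketch of an `input_fn` that satisfies that contract, reusing the toy values from the test changes above:

# Illustrative sketch of a train input_fn usable with a DistributionStrategy:
# it must return a tf.data.Dataset (enforced by _call_dataset_fn).
import tensorflow as tf

def train_input_fn():
  features = tf.data.Dataset.from_tensors([[2.], [7.]])
  labels = tf.data.Dataset.from_tensors([[6.], [21.]])
  return tf.data.Dataset.zip((features, labels)).repeat()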


@@ -20,6 +20,7 @@ from __future__ import print_function

 import threading

+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -672,25 +673,35 @@ class DistributionStrategy(object):
     _require_distribution_strategy_scope(self)
     return variable_scope.variable_creator_scope(create_colocated_variable)

+  def _call_dataset_fn(self, dataset_fn):
+    result = dataset_fn()
+    if not isinstance(result, dataset_ops.Dataset):
+      raise ValueError(
+          "dataset_fn() must return a tf.data.Dataset when using a "
+          "DistributionStrategy.")
+    return result
+
   # TODO(josh11b): `PerDeviceDataset` currently only implements a few methods of
   # Dataset API such as make_one_shot_iterator and make_initializable_iterator.
   # Extend to implement more functionality of datasets.
-  def distribute_dataset(self, dataset):
+  def distribute_dataset(self, dataset_fn):
     """Return a `dataset` split across all towers.

     Suitable for providing input to for `call_for_each_tower()` by creating an
     iterator:

     ```
+    def dataset_fn():
+      return tf.data.Dataset.from_tensors([[1.]]).repeat()
+
     with distribution_strategy.scope():
-      distributed_dataset = distribution_strategy.distribute_dataset(dataset)
+      distributed_dataset = distribution_strategy.distribute_dataset(dataset_fn)
       iterator = distributed_dataset.make_one_shot_iterator()
       tower_results = distribution_strategy.call_for_each_tower(
           tower_fn, iterator.get_next())
     ```

     Args:
-      dataset: A `tf.data.Dataset`.
+      dataset_fn: A function that returns a `tf.data.Dataset`.

     Returns:
       A `PerDeviceDataset` that will produce data for each tower.
@@ -1135,8 +1146,8 @@ class _DefaultDistributionStrategy(DistributionStrategy):
     _require_distribution_strategy_scope(self)
     return ops.colocate_with(colocate_with_variable)

-  def distribute_dataset(self, dataset):
-    return dataset
+  def distribute_dataset(self, dataset_fn):
+    return self._call_dataset_fn(dataset_fn)

   def _broadcast(self, tensor, destinations):
     if destinations is None:
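
For completeness, a hedged sketch of the contract that the new `_call_dataset_fn` helper enforces for every strategy's `distribute_dataset`; `strategy` below stands for any already-constructed `DistributionStrategy` and is an assumption of the sketch.

# Illustrative sketch of the ValueError introduced by _call_dataset_fn.
import tensorflow as tf

def good_dataset_fn():
  return tf.data.Dataset.from_tensors([[1.]]).repeat()  # a tf.data.Dataset

def bad_dataset_fn():
  return [[1.]]  # not a tf.data.Dataset

def check(strategy):
  strategy.distribute_dataset(good_dataset_fn)  # accepted
  try:
    strategy.distribute_dataset(bad_dataset_fn)
  except ValueError as e:
    print(e)  # "dataset_fn() must return a tf.data.Dataset when using a
              #  DistributionStrategy."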