Change distribution.distribute_dataset to accept an input_fn instead of a dataset.
PiperOrigin-RevId: 193437651
Commit: fddfa9f8dc (parent: e9d47fbff0)
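Note on the new contract (not part of the diff below): callers now hand `distribute_dataset` a zero-argument function that builds the `tf.data.Dataset`, and the strategy invokes it through a new `_call_dataset_fn` helper that rejects anything that is not a `Dataset`. A minimal sketch of the updated calling convention, mirroring the docstring example further down; `distribution_strategy` and `tower_fn` are placeholders assumed to exist:

```
import tensorflow as tf

def dataset_fn():
  # Built lazily by the strategy; must return a tf.data.Dataset.
  return tf.data.Dataset.from_tensors([[1.]]).repeat()

# Before this commit: distribute_dataset(dataset) took a Dataset instance.
# After: pass the function itself; the strategy calls it and checks the result.
with distribution_strategy.scope():
  distributed_dataset = distribution_strategy.distribute_dataset(dataset_fn)
  iterator = distributed_dataset.make_one_shot_iterator()
  tower_results = distribution_strategy.call_for_each_tower(
      tower_fn, iterator.get_next())
```

Passing the function instead of the dataset lets each strategy decide where and when the input pipeline is built, e.g. wrapped for per-device prefetching or re-batched for TPU as in the test below.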
@@ -54,21 +54,18 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
   def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss,
                        is_tpu):
     with distribution.scope():
-      model_fn, dataset, layer = minimize_loss_example(
-          optimizer_fn,
-          use_bias=True,
-          use_callable_loss=use_callable_loss)
+      model_fn, dataset_fn, layer = minimize_loss_example(
+          optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
+
+      def tpu_dataset_fn():
+        return dataset_fn().batch(2)
       # TODO(isaprykin): Eliminate `is_tpu`. Probably add a
       # `DistributionStrategy.create_monitor` so that each DistributionStrategy
       # could influence its training loop. That method would return an instance
       # of Monitor. TPUMonitor would execute tpu.initialize_system() and
       # tpu.shutdown_system().
-      if is_tpu:
-        dataset = dataset.batch(2)
-
       iterator = distribution.distribute_dataset(
-          dataset).make_one_shot_iterator()
+          tpu_dataset_fn if is_tpu else dataset_fn).make_one_shot_iterator()

       def run_step():
         # TODO(isaprykin): Make iterator get_next() return a list of sub-
@@ -122,14 +119,14 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     # `distribution.scope`.
     with variable_scope.variable_creator_scope(
         appending_creator), distribution.scope():
-      model_fn, dataset, layer = minimize_loss_example(
+      model_fn, dataset_fn, layer = minimize_loss_example(
           optimizer_fn,
           use_bias=True,
           use_callable_loss=True,
           create_optimizer_inside_model_fn=True)

       iterator = distribution.distribute_dataset(
-          dataset).make_one_shot_iterator()
+          dataset_fn).make_one_shot_iterator()

       def run_step():
         return distribution.group(
@@ -176,7 +173,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     """Verifies that moving mean updates are reduced across towers."""
     with distribution.scope():
       num_towers = len(distribution.worker_devices)
-      model_fn, dataset, batchnorm = batchnorm_example(
+      model_fn, dataset_fn, batchnorm = batchnorm_example(
           optimizer_fn,
           batch_per_epoch=num_towers,
           momentum=momentum,
@@ -188,7 +185,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
       if isinstance(distribution, mirrored_strategy.MirroredStrategy):
         distribution._prefetch_on_device = False
       iterator = distribution.distribute_dataset(
-          dataset).make_one_shot_iterator()
+          dataset_fn).make_one_shot_iterator()

       def run_step():
         return control_flow_ops.group(
@@ -260,11 +257,13 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
         else:
           return optimizer.minimize(loss_fn())

-      features = dataset_ops.Dataset.from_tensors([[2.], [7.]])
-      labels = dataset_ops.Dataset.from_tensors([[6.], [21.]])
-      dataset = dataset_ops.Dataset.zip((features, labels)).repeat()
+      def dataset_fn():
+        features = dataset_ops.Dataset.from_tensors([[2.], [7.]])
+        labels = dataset_ops.Dataset.from_tensors([[6.], [21.]])
+        return dataset_ops.Dataset.zip((features, labels)).repeat()

       iterator = distribution.distribute_dataset(
-          dataset).make_one_shot_iterator()
+          dataset_fn).make_one_shot_iterator()

       def run_step():
         return distribution.group(
@@ -140,9 +140,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
         g.add_to_collections(collections, result)
       return result

-  def distribute_dataset(self, dataset):
+  def distribute_dataset(self, dataset_fn):
     return values.PerDeviceDataset(
-        dataset, self._devices, self._prefetch_on_device)
+        self._call_dataset_fn(dataset_fn), self._devices,
+        self._prefetch_on_device)

   def _broadcast(self, tensor, destinations):
     # TODO(josh11b): In eager mode, use one thread per device, or async mode.
@@ -247,9 +247,9 @@ class MirroredStrategyVariableCreationTest(test.TestCase):

     dist = mirrored_strategy.MirroredStrategy(
         ["/device:GPU:0", "/device:CPU:0"])
-    features = dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
     features = dist.distribute_dataset(
-        features).make_one_shot_iterator().get_next()
+        lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
+    ).make_one_shot_iterator().get_next()

     with dist.scope():
       result = dist.call_for_each_tower(
@@ -60,8 +60,8 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
       with ops.colocate_with(colocate_with):
         return next_creator(*args, **kwargs)

-  def distribute_dataset(self, dataset):
-    return dataset
+  def distribute_dataset(self, dataset_fn):
+    return self._call_dataset_fn(dataset_fn)

   def _broadcast(self, tensor, destinations):
     return tensor
@@ -39,11 +39,11 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
   def testTrainNetwork(self, distribution, optimizer_fn,
                        use_callable_loss=True):
     with distribution.scope():
-      model_fn, dataset, layer = minimize_loss_example(
+      model_fn, dataset_fn, layer = minimize_loss_example(
           optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)

       iterator = distribution.distribute_dataset(
-          dataset).make_one_shot_iterator()
+          dataset_fn).make_one_shot_iterator()

       def run_step():
         return control_flow_ops.group(distribution.unwrap(
@@ -29,7 +29,10 @@ from tensorflow.python.ops import math_ops

 def single_loss_example(optimizer_fn, distribution, use_bias=False):
   """Build a very simple network to use in tests and examples."""
-  dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat()
+
+  def dataset_fn():
+    return dataset_ops.Dataset.from_tensors([[1.]]).repeat()
+
   optimizer = optimizer_fn()
   layer = core.Dense(1, use_bias=use_bias)

@@ -37,8 +40,8 @@ def single_loss_example(optimizer_fn, distribution, use_bias=False):
     y = array_ops.reshape(layer(x), []) - constant_op.constant(1.)
     return y * y

-  single_loss_step = step_fn.StandardSingleLossStep(dataset, loss_fn, optimizer,
-                                                    distribution)
+  single_loss_step = step_fn.StandardSingleLossStep(dataset_fn, loss_fn,
+                                                    optimizer, distribution)

   # Layer is returned for inspecting the kernels in tests.
   return single_loss_step, layer
@@ -49,7 +52,10 @@ def minimize_loss_example(optimizer_fn,
                           use_callable_loss=True,
                           create_optimizer_inside_model_fn=False):
   """Example of non-distribution-aware legacy code."""
-  dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat()
+
+  def dataset_fn():
+    return dataset_ops.Dataset.from_tensors([[1.]]).repeat()
+
   # An Optimizer instance is created either outside or inside model_fn.
   outer_optimizer = None
   if not create_optimizer_inside_model_fn:
@@ -71,7 +77,7 @@ def minimize_loss_example(optimizer_fn,
     else:
       return optimizer.minimize(loss_fn())

-  return model_fn, dataset, layer
+  return model_fn, dataset_fn, layer


 def batchnorm_example(optimizer_fn,
@@ -79,12 +85,15 @@ def batchnorm_example(optimizer_fn,
                       momentum=0.9,
                       renorm=False):
   """Example of non-distribution-aware legacy code with batch normalization."""
-  # input shape is [16, 8], input values are increasing in both dimensions.
-  dataset = dataset_ops.Dataset.from_tensor_slices(
-      [[[float(x * 8 + y + z * 100)
-         for y in range(8)]
-        for x in range(16)]
-       for z in range(batch_per_epoch)]).repeat()
+
+  def dataset_fn():
+    # input shape is [16, 8], input values are increasing in both dimensions.
+    return dataset_ops.Dataset.from_tensor_slices(
+        [[[float(x * 8 + y + z * 100)
+           for y in range(8)]
+          for x in range(16)]
+         for z in range(batch_per_epoch)]).repeat()
+
   optimizer = optimizer_fn()
   batchnorm = normalization.BatchNormalization(
       renorm=renorm, momentum=momentum, fused=False)
@@ -99,4 +108,4 @@ def batchnorm_example(optimizer_fn,
       # Callable loss.
       return optimizer.minimize(loss_fn)

-  return model_fn, dataset, batchnorm
+  return model_fn, dataset_fn, batchnorm
@@ -49,13 +49,14 @@ class StandardInputStep(Step):
   """Step with a standard implementation of input handling.

   Args:
-    input_dataset: a tf.data Dataset that provides input.
+    dataset_fn: a function that returns a tf.data Dataset that produces the
+      input for the model.
   """

-  def __init__(self, input_dataset, distribution):
+  def __init__(self, dataset_fn, distribution):
     Step.__init__(self, distribution)
     self._distributed_input = distribution.distribute_dataset(
-        input_dataset).make_one_shot_iterator()
+        dataset_fn).make_one_shot_iterator()

   def inputs(self):
     return self._distributed_input.get_next()
@@ -77,14 +78,15 @@ class StandardSingleLossStep(StandardInputStep):
   ```

   Args:
-    input_dataset: a tf.data Dataset that provides input.
+    dataset_fn: a function that returns a tf.data Dataset that produces the
+      input for the model.
     loss_fn: a function that returns loss.
     optimizer: an optimizer that implements an update rule.
     distribution: a `DistributionStrategy` object.
   """

-  def __init__(self, input_dataset, loss_fn, optimizer, distribution):
-    StandardInputStep.__init__(self, input_dataset, distribution)
+  def __init__(self, dataset_fn, loss_fn, optimizer, distribution):
+    StandardInputStep.__init__(self, dataset_fn, distribution)
     self._loss_fn = loss_fn
     self._optimizer = optimizer
     self._is_run_concurrently = False
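Usage note for the step_fn change above: `StandardInputStep` and `StandardSingleLossStep` now receive the dataset-building function and defer dataset construction to `distribution.distribute_dataset`. A rough sketch, assuming `loss_fn`, `optimizer`, and `distribution` are built elsewhere and that the contrib import path below is correct:

```
import tensorflow as tf
# Import path assumed; step_fn is the module edited in the hunk above.
from tensorflow.contrib.distribute.python import step_fn

def dataset_fn():
  # Called by the strategy; must return a tf.data.Dataset.
  return tf.data.Dataset.from_tensors([[1.]]).repeat()

# `loss_fn`, `optimizer`, and `distribution` are assumed to exist already.
single_loss_step = step_fn.StandardSingleLossStep(
    dataset_fn, loss_fn, optimizer, distribution)
```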
@@ -3048,6 +3048,7 @@ py_library(
         ":state_ops",
         ":util",
         ":variable_scope",
+        "//tensorflow/python/data",
         "//tensorflow/python/ops/losses",
     ],
 )
@@ -688,22 +688,19 @@ class Estimator(object):

   def _get_features_and_labels_from_input_fn(self, input_fn, mode):
     """Extracts the `features` and labels from return values of `input_fn`."""
-    result = self._call_input_fn(input_fn, mode)
-    # TODO(anjalisridhar): What about the default DistributionStrategy? Perhaps
-    # using any input is alright in that case. There is also a
-    # has_dataset_or_queue_runner function that we may want to extend and use.
-    if (self._distribution is not None and
-        not isinstance(result, dataset_ops.Dataset) and
-        mode == model_fn_lib.ModeKeys.TRAIN):
-      raise ValueError('input_fn() must return a tf.data.Dataset when using a '
-                       'DistributionStrategy.')
     input_hooks = []
-    if isinstance(result, dataset_ops.Dataset):
-      if self._distribution is not None and mode == model_fn_lib.ModeKeys.TRAIN:
-        result = self._distribution.distribute_dataset(result)
+    if self._distribution is not None and mode == model_fn_lib.ModeKeys.TRAIN:
+      result = self._distribution.distribute_dataset(
+          lambda: self._call_input_fn(input_fn, mode))
       iterator = result.make_initializable_iterator()
       input_hooks.append(_DatasetInitializerHook(iterator))
       result = iterator.get_next()
+    else:
+      result = self._call_input_fn(input_fn, mode)
+      if isinstance(result, dataset_ops.Dataset):
+        iterator = result.make_initializable_iterator()
+        input_hooks.append(_DatasetInitializerHook(iterator))
+        result = iterator.get_next()
     if isinstance(result, (list, tuple)):
       if len(result) != 2:
         raise ValueError(
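To make the Estimator change concrete: with a `DistributionStrategy` set and `mode == TRAIN`, `input_fn` is no longer called up front; it is wrapped in a lambda, handed to `distribute_dataset`, and must therefore return a `tf.data.Dataset`. A sketch of an `input_fn` shape that satisfies that path (the tensor values are illustrative only, borrowed from the test above):

```
import tensorflow as tf

def input_fn():
  # Under a DistributionStrategy in TRAIN mode this must return a
  # tf.data.Dataset; anything else now fails inside _call_dataset_fn.
  features = tf.data.Dataset.from_tensors([[2.], [7.]]).repeat()
  labels = tf.data.Dataset.from_tensors([[6.], [21.]]).repeat()
  return tf.data.Dataset.zip((features, labels))
```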
@@ -20,6 +20,7 @@ from __future__ import print_function

 import threading

+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -672,25 +673,35 @@ class DistributionStrategy(object):
     _require_distribution_strategy_scope(self)
     return variable_scope.variable_creator_scope(create_colocated_variable)

+  def _call_dataset_fn(self, dataset_fn):
+    result = dataset_fn()
+    if not isinstance(result, dataset_ops.Dataset):
+      raise ValueError(
+          "dataset_fn() must return a tf.data.Dataset when using a "
+          "DistributionStrategy.")
+    return result
+
   # TODO(josh11b): `PerDeviceDataset` currently only implements a few methods of
   # Dataset API such as make_one_shot_iterator and make_initializable_iterator.
   # Extend to implement more functionality of datasets.
-  def distribute_dataset(self, dataset):
+  def distribute_dataset(self, dataset_fn):
     """Return a `dataset` split across all towers.

     Suitable for providing input to `call_for_each_tower()` by creating an
     iterator:

     ```
+    def dataset_fn():
+      return tf.data.Dataset.from_tensors([[1.]]).repeat()
     with distribution_strategy.scope():
-      distributed_dataset = distribution_strategy.distribute_dataset(dataset)
+      distributed_dataset = distribution_strategy.distribute_dataset(dataset_fn)
       iterator = distributed_dataset.make_one_shot_iterator()
       tower_results = distribution_strategy.call_for_each_tower(
           tower_fn, iterator.get_next())
     ```

     Args:
-      dataset: A `tf.data.Dataset`.
+      dataset_fn: A function that returns a `tf.data.Dataset`.

     Returns:
       A `PerDeviceDataset` that will produce data for each tower.
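The `_call_dataset_fn` helper added above is the single enforcement point for the new contract: every strategy routes its `dataset_fn` through it, and a callable that does not return a `tf.data.Dataset` raises the `ValueError` shown in the hunk. A small sketch of that failure mode, with `strategy` standing in for any concrete `DistributionStrategy`:

```
import tensorflow as tf

def bad_dataset_fn():
  # Returns a tensor, not a tf.data.Dataset, so _call_dataset_fn rejects it.
  return tf.constant([[1.]])

try:
  strategy.distribute_dataset(bad_dataset_fn)  # `strategy` assumed to exist.
except ValueError as err:
  print(err)  # "dataset_fn() must return a tf.data.Dataset when using a ..."
```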
@@ -1135,8 +1146,8 @@ class _DefaultDistributionStrategy(DistributionStrategy):
     _require_distribution_strategy_scope(self)
     return ops.colocate_with(colocate_with_variable)

-  def distribute_dataset(self, dataset):
-    return dataset
+  def distribute_dataset(self, dataset_fn):
+    return self._call_dataset_fn(dataset_fn)

   def _broadcast(self, tensor, destinations):
     if destinations is None: