Rolling forward the change that switches distribution strategies to use MultiDeviceIterator.

The underlying issue with NaNs has now been resolved.

PiperOrigin-RevId: 217014692
Rohan Jain, 2018-10-13 17:42:53 -07:00, committed by TensorFlower Gardener
parent f667188dae
commit 47b04fdb3e
13 changed files with 90 additions and 393 deletions
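To make the change concrete, below is a minimal, hedged sketch (not code from this commit) of the graph-mode pattern the strategies now use: a single host dataset is prefetched onto several devices through MultiDeviceIterator, which, unlike the removed prefetching_ops_v2 path, must be explicitly initialized. The dataset, device list, and printed run are illustrative assumptions only.

import tensorflow as tf
from tensorflow.python.data.ops import multi_device_iterator_ops  # same internal module this commit imports

# Illustrative inputs; substitute devices that actually exist on your machine.
dataset = tf.data.Dataset.range(10).batch(2)
devices = ["/cpu:0", "/gpu:0"]

# One host-side dataset, one prefetch buffer per target device.
multi_iterator = multi_device_iterator_ops.MultiDeviceIterator(dataset, devices)
per_device_elements = multi_iterator.get_next()  # one element per device

with tf.Session() as sess:
  # The iterator is initializable, matching the make_initializable_iterator()
  # paths added throughout the tests in this commit.
  sess.run(multi_iterator.initializer)
  print(sess.run(per_device_elements))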


@@ -22,7 +22,6 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
":input_ops",
":prefetching_ops_v2",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:device_util",
@@ -31,6 +30,7 @@ py_library(
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:multi_device_iterator_ops",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
"@six_archive//:six",
@@ -666,32 +666,6 @@ cuda_py_test(
],
)
py_library(
name = "prefetching_ops_v2",
srcs = ["prefetching_ops_v2.py"],
deps = [
"//tensorflow/contrib/data/python/ops:prefetching_ops",
"//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
],
)
cuda_py_test(
name = "prefetching_ops_v2_test",
srcs = ["prefetching_ops_v2_test.py"],
additional_deps = [
":prefetching_ops_v2",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
],
)
py_library(
name = "input_ops",
srcs = ["input_ops.py"],


@@ -96,7 +96,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn):
with ops.Graph().as_default(), distribution.scope():
iterator = distribution.distribute_dataset(
dataset_fn).make_one_shot_iterator()
dataset_fn).make_initializable_iterator()
if isinstance(distribution, tpu_strategy.TPUStrategy):
def step_fn(ctx, inputs):
value, update = distribution.call_for_each_tower(
@@ -120,6 +120,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
# replace "distribution.num_towers" with "1".
batches_per_update = distribution.num_towers
self.evaluate(iterator.initializer)
self.evaluate(distribution.initialize())
self.evaluate(variables.local_variables_initializer())


@@ -41,6 +41,14 @@ from tensorflow.python.ops.losses import losses_impl
class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
def _get_iterator(self, ds):
if context.executing_eagerly():
iterator = ds.make_one_shot_iterator()
else:
iterator = ds.make_initializable_iterator()
self.evaluate(iterator.initializer)
return iterator
@combinations.generate(
combinations.times(
combinations.distributions_and_v1_optimizers(),
@@ -62,8 +70,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution.call_for_each_tower(
model_fn, *inputs, run_concurrently=layer.built))
iterator = distribution.distribute_dataset(
dataset_fn).make_one_shot_iterator()
iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
def run_step():
return distribution.run_steps_on_dataset(
@@ -99,8 +106,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
model_fn, dataset_fn, layer = minimize_loss_example(
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
iterator = distribution.distribute_dataset(
dataset_fn).make_one_shot_iterator()
iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
def run_step():
return distribution.group(
@@ -159,8 +165,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution.call_for_each_tower(
model_fn, *inputs, run_concurrently=layer.built))
iterator = distribution.distribute_dataset(
dataset_fn).make_one_shot_iterator()
iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
def run_step():
return distribution.run_steps_on_dataset(
@@ -239,8 +244,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS)
return control_flow_ops.group(fetches)
iterator = distribution.distribute_dataset(
dataset_fn).make_one_shot_iterator()
iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
def run_step():
return distribution.run_steps_on_dataset(
@@ -333,8 +337,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution.call_for_each_tower(
model_fn, x, y, run_concurrently=False))
iterator = distribution.distribute_dataset(
dataset_fn).make_one_shot_iterator()
iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
def run_step():
return distribution.run_steps_on_dataset(
@@ -427,8 +430,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
output=loss)
return distribution.group(train_op)
iterator = distribution.distribute_dataset(
dataset_fn).make_one_shot_iterator()
iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
def run_step():
initial_loss = lambda: constant_op.constant(1e7)


@@ -307,9 +307,15 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
dist = mirrored_strategy.MirroredStrategy(
["/device:GPU:0", "/device:CPU:0"])
features = dist.distribute_dataset(
lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
).make_one_shot_iterator().get_next()
ds = dist.distribute_dataset(
lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10))
if context.executing_eagerly():
iterator = ds.make_one_shot_iterator()
else:
iterator = ds.make_initializable_iterator()
self.evaluate([iterator.initializer])
features = iterator.get_next()
with dist.scope():
result = dist.call_for_each_tower(


@@ -51,6 +51,7 @@ class Monitor(object):
else:
if session is None:
raise ValueError("Should provide a `session` in Graph mode.")
session.run(step_callable._iterator.initializer) # pylint: disable=protected-access
self._run_step = session.make_callable(step_callable())
session.run(variables.global_variables_initializer())


@@ -42,8 +42,11 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
model_fn, dataset_fn, layer = minimize_loss_example(
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
iterator = distribution.distribute_dataset(
dataset_fn).make_one_shot_iterator()
ds = distribution.distribute_dataset(dataset_fn)
if context.executing_eagerly():
iterator = ds.make_one_shot_iterator()
else:
iterator = ds.make_initializable_iterator()
def run_step():
return control_flow_ops.group(distribution.unwrap(
@@ -52,6 +55,7 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
if not context.executing_eagerly():
with self.cached_session() as sess:
sess.run(iterator.initializer)
run_step = sess.make_callable(run_step())
self.evaluate(variables.global_variables_initializer())


@@ -1,232 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Extension of prefetching_ops to support more than one device."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import warnings
from tensorflow.python.data.experimental.ops import prefetching_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest as data_nest
from tensorflow.python.data.util import sparse
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.util import nest
# pylint: disable=protected-access
class _PrefetchToDeviceIterator(object):
"""A replacement for `tf.data.Iterator` that prefetches to another device.
Args:
input_dataset: The input dataset.
one_shot: If true, we make a one shot iterator that's already initialized.
devices: Devices on which to prefetch.
buffer_size: Size of the prefetching buffer.
shared_name: (Optional.) If non-empty, the returned iterator will be shared
under the given name across multiple sessions that share the same devices
(e.g. when using a remote server). Only used if one_shot is False.
Returns:
An Iterator type object.
"""
def __init__(self,
input_dataset,
one_shot,
devices,
buffer_size,
shared_name=None):
self._input_dataset = input_dataset
self._get_next_call_count = 0
self._one_shot = one_shot
if shared_name is None:
shared_name = ""
self._devices = devices
if self._one_shot:
self._input_iterator = input_dataset.make_one_shot_iterator()
else:
self._input_iterator = iterator_ops.Iterator.from_structure(
self._input_dataset.output_types, self._input_dataset.output_shapes,
shared_name, self._input_dataset.output_classes)
input_iterator_handle = self._input_iterator.string_handle()
@function.Defun(dtypes.string)
def _prefetch_fn(handle):
"""Prefetches one element from `input_iterator`."""
remote_iterator = iterator_ops.Iterator.from_string_handle(
handle, self._input_iterator.output_types,
self._input_iterator.output_shapes,
self._input_iterator.output_classes)
ret = remote_iterator.get_next()
return nest.flatten(sparse.serialize_sparse_tensors(ret))
target_device = ged_ops.experimental_iterator_get_device(
self._input_iterator._iterator_resource)
self._buffering_resources = []
for device in nest.flatten(self._devices):
with ops.device(device):
buffer_resource_handle = prefetching_ops.function_buffering_resource(
f=_prefetch_fn,
output_types=data_nest.flatten(
sparse.as_dense_types(self._input_dataset.output_types,
self._input_dataset.output_classes)),
target_device=target_device,
string_arg=input_iterator_handle,
buffer_size=buffer_size,
shared_name=shared_name)
self._buffering_resources.append(buffer_resource_handle)
if not self._one_shot:
reset_ops = []
for buffer_resource in self._buffering_resources:
reset_ops.append(
ged_ops.experimental_function_buffering_resource_reset(
buffer_resource))
with ops.control_dependencies(reset_ops):
self._initializer = self._input_iterator.make_initializer(
self._input_dataset)
def get_next(self, name=None):
"""See `tf.data.Iterator.get_next`."""
self._get_next_call_count += 1
if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
flat_result = []
# TODO(priyag): This will fail if the input size (typically number of
# batches) is not divisible by number of devices.
# How do we handle that more gracefully / let the user know?
for buffer_resource in self._buffering_resources:
flat_ret = ged_ops.experimental_function_buffering_resource_get_next(
buffer_resource,
output_types=data_nest.flatten(
sparse.as_dense_types(self.output_types, self.output_classes)),
name=name)
ret = sparse.deserialize_sparse_tensors(
data_nest.pack_sequence_as(self.output_types, flat_ret),
self.output_types, self.output_shapes, self.output_classes)
for tensor, shape in zip(
data_nest.flatten(ret), data_nest.flatten(self.output_shapes)):
if isinstance(tensor, ops.Tensor):
tensor.set_shape(shape)
flat_result.append(ret)
return nest.pack_sequence_as(self._devices, flat_result)
@property
def initializer(self):
if self._one_shot:
raise NotImplementedError("Can't initialize a one_shot_iterator")
return self._initializer
@property
def output_classes(self):
return self._input_dataset.output_classes
@property
def output_shapes(self):
return self._input_dataset.output_shapes
@property
def output_types(self):
return self._input_dataset.output_types
# pylint: enable=protected-access
class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset):
"""A `Dataset` whose iterator prefetches elements to other device(s)."""
def __init__(self, input_dataset, devices, buffer_size):
super(_PrefetchToDeviceDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._devices = devices
self._buffer_size = buffer_size if buffer_size is not None else 1
def make_one_shot_iterator(self):
return _PrefetchToDeviceIterator(
self._input_dataset,
one_shot=True,
devices=self._devices,
buffer_size=self._buffer_size)
def make_initializable_iterator(self, shared_name=None):
if context.executing_eagerly():
raise RuntimeError(
"make_initializable_iterator is not supported when eager "
"execution is enabled.")
return _PrefetchToDeviceIterator(
self._input_dataset,
one_shot=False,
devices=self._devices,
buffer_size=self._buffer_size,
shared_name=shared_name)
def _as_variant_tensor(self):
# TODO(mrry): Raise this error earlier (e.g. when one of the Dataset
# transformation methods is called.
# TODO(mrry): Investigate support for chaining further transformations after
# the prefetch, including GPU support.
raise NotImplementedError("`prefetch_to_devices()` must be the last "
"transformation in a dataset pipeline.")
# TODO(priyag): Fix the output types, shapes and classes to match the result
# of get_next (which has the additional nesting layer of devices now).
@property
def output_types(self):
return self._input_dataset.output_types
@property
def output_shapes(self):
return self._input_dataset.output_shapes
@property
def output_classes(self):
return self._input_dataset.output_classes
def prefetch_to_devices(devices, buffer_size=None):
"""A transformation that prefetches dataset values to the given `devices`.
NOTE: Although the transformation creates a `tf.data.Dataset`, the
transformation must be the final `Dataset` in the input pipeline.
Args:
devices: A nested structure of devices on which to prefetch the data. It can
be a single device name, or a tuple or list of device names.
buffer_size: (Optional.) The number of elements to buffer on each device.
Defaults to an automatically chosen value.
Returns:
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
def _apply_fn(dataset):
return _PrefetchToDeviceDataset(dataset, devices, buffer_size)
return _apply_fn


@@ -1,90 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for prefetching_ops_v2."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distribute.python import prefetching_ops_v2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
class PrefetchingOpsV2Test(test.TestCase):
def testPrefetchToOneDevice(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
prefetching_ops_v2.prefetch_to_devices("/gpu:0"))
iterator = device_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with self.cached_session() as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testPrefetchToTwoDevicesInAList(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
iterator = device_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
output = []
# TODO(rohanj): Modify test to go till the end of the dataset when we
# switch to MultiDeviceIterator.
with self.cached_session() as sess:
for _ in range(4):
result = sess.run(next_element)
self.assertEqual(2, len(result))
output.extend(result)
self.assertEquals(set(range(8)), set(output))
def testPrefetchToTwoDevicesWithReinit(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
# TODO(rohanj): Modify test to go till the end of the dataset when we
# switch to MultiDeviceIterator.
with self.cached_session() as sess:
sess.run(iterator.initializer)
for _ in range(4):
sess.run(next_element)
sess.run(iterator.initializer)
for _ in range(4):
sess.run(next_element)
if __name__ == "__main__":
test.main()


@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.training import optimizer as optimizer_lib
@@ -50,7 +51,11 @@ class StandardInputStep(Step):
def __init__(self, dataset_fn, distribution):
super(StandardInputStep, self).__init__(distribution)
self._distributed_input = distribution.distribute_dataset(dataset_fn)
self._iterator = self._distributed_input.make_one_shot_iterator()
if context.executing_eagerly():
self._iterator = self._distributed_input.make_one_shot_iterator()
else:
# TODO(priyag): Expose initializer via some initializer property.
self._iterator = self._distributed_input.make_initializable_iterator()
class StandardSingleLossStep(StandardInputStep):


@@ -50,6 +50,7 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase):
run_step = single_loss_step
else:
with self.cached_session() as sess:
sess.run(single_loss_step._iterator.initializer)
run_step = sess.make_callable(single_loss_step())
self.evaluate(variables.global_variables_initializer())


@@ -27,7 +27,7 @@ import weakref
import six
from tensorflow.contrib.distribute.python import input_ops
from tensorflow.contrib.distribute.python import prefetching_ops_v2
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import device as tf_device
@@ -1089,7 +1089,7 @@ class PerDeviceDataIterator(object):
def get_next(self, name=None):
"""Scatter the input across devices."""
if self._prefetch_on_device:
data_list = self._iterator.get_next(name=name)
data_list = self._iterator.get_next()
index = dict(zip(self._devices, data_list))
else:
batch = self._iterator.get_next(name=name)
@@ -1113,17 +1113,15 @@ class PerDeviceDataset(object):
self._devices = devices
# Default to using prefetching in graph mode, unless specified.
# TODO(priyag): Enable prefetching in eager mode.
# TODO(rohanj): Enable prefetching in eager mode.
self._prefetch_on_device = prefetch_on_device
if self._prefetch_on_device is None:
self._prefetch_on_device = not context.executing_eagerly()
assert not (self._prefetch_on_device and context.executing_eagerly()), (
"Prefetching is only supported in graph mode currently")
if self._prefetch_on_device:
self._dataset = dataset.apply(
prefetching_ops_v2.prefetch_to_devices(self._devices))
else:
self._dataset = dataset
if not self._prefetch_on_device:
# TODO(priyag): If dropping remainder is not appropriate, find another
# approach to distributing the dataset when not possible to divide evenly.
# Possibly not an issue when we start using PartitionedDataset.
@@ -1131,15 +1129,33 @@
def make_one_shot_iterator(self):
"""Get a one time use iterator for the distributed PerDeviceDataset."""
# Graph mode with one shot iterator is disabled.
if not context.executing_eagerly():
raise ValueError("Cannot create a one shot iterator. Please use "
"`make_initializable_iterator()` instead.")
# Eager mode prefetching would error out in constructor. Only remaining
# case is non-prefetching in eager mode. We delegate to
# PerDeviceDataIterator to handle that case.
dataset_iterator = self._dataset.make_one_shot_iterator()
return PerDeviceDataIterator(dataset_iterator, self._devices,
self._prefetch_on_device)
return PerDeviceDataIterator(
dataset_iterator, self._devices, prefetch_on_device=False)
def make_initializable_iterator(self):
"""Get an initializable iterator for the distributed PerDeviceDataset."""
dataset_iterator = self._dataset.make_initializable_iterator()
return PerDeviceDataIterator(dataset_iterator, self._devices,
self._prefetch_on_device)
# Eager mode generates already initialized iterators. Hence we cannot create
# an initializable iterator.
if context.executing_eagerly():
raise ValueError("Cannot create initializable iterator in Eager mode. "
"Please use `make_one_shot_iterator` instead.")
if self._prefetch_on_device:
dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
self._dataset, self._devices)
else:
dataset_iterator = self._dataset.make_initializable_iterator()
return PerDeviceDataIterator(
dataset_iterator,
self._devices,
prefetch_on_device=self._prefetch_on_device)
class MultiWorkerDataIterator(object):


@@ -349,7 +349,11 @@ class PerDeviceDatasetTest(test.TestCase):
def _test_iterator_no_prefetch(self, devices, dataset, expected_values):
per_device_dataset = values.PerDeviceDataset(
dataset, devices, prefetch_on_device=False)
iterator = per_device_dataset.make_one_shot_iterator()
if context.executing_eagerly():
iterator = per_device_dataset.make_one_shot_iterator()
else:
iterator = per_device_dataset.make_initializable_iterator()
self.evaluate([iterator.initializer])
for expected_value in expected_values:
next_element = iterator.get_next()
@@ -366,21 +370,14 @@
if not context.executing_eagerly():
per_device_dataset = values.PerDeviceDataset(
dataset, devices, prefetch_on_device=True)
iterator = per_device_dataset.make_one_shot_iterator()
iterator = per_device_dataset.make_initializable_iterator()
self.evaluate([iterator.initializer])
# With prefetching, we cannot guarantee which input ends up on which
# device, so we verify that the complete set seen on all devices is
# correct, and equal numbers are distributed to each device.
combined_actual = []
combined_expected = []
for expected_value in expected_values:
next_element = iterator.get_next()
combined_actual.extend(
self.evaluate(
[values.select_device(d, next_element) for d in devices]))
combined_expected.extend(expected_value)
self.assertEqual(set(combined_expected), set(combined_actual))
computed_value = self.evaluate(
[values.select_device(d, next_element) for d in devices])
self.assertEqual(expected_value, computed_value)
with self.assertRaises(errors.OutOfRangeError):
next_element = iterator.get_next()


@@ -229,3 +229,15 @@ class MultiDeviceIterator(object):
@property
def initializer(self):
return self._initializer
@property
def output_types(self):
return self._dataset.output_types
@property
def output_shapes(self):
return self._dataset.output_shapes
@property
def output_classes(self):
return self._dataset.output_classes
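As a closing, hedged illustration (not part of the commit): the three properties added above let code that expects the standard tf.data iterator structure attributes (output_types, output_shapes, output_classes) accept a MultiDeviceIterator as well; each property simply delegates to the wrapped dataset. The dataset below is an assumption for demonstration.

import tensorflow as tf
from tensorflow.python.data.ops import multi_device_iterator_ops

# A single CPU device keeps the example hardware-independent.
dataset = tf.data.Dataset.from_tensors(([1.0, 2.0], [3]))
multi_iterator = multi_device_iterator_ops.MultiDeviceIterator(dataset, ["/cpu:0"])

# These mirror tf.data.Iterator's structure API and delegate to the dataset.
print(multi_iterator.output_types)    # (tf.float32, tf.int32)
print(multi_iterator.output_shapes)   # (TensorShape([2]), TensorShape([1]))
print(multi_iterator.output_classes)  # classes of the element components (tf.Tensor, tf.Tensor)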