Rolling forward changing distribution strategies to use MultiDeviceIterator.
The underlying issue with NaN's has now been resolved. PiperOrigin-RevId: 217014692
This commit is contained in:
parent
f667188dae
commit
47b04fdb3e
@ -22,7 +22,6 @@ py_library(
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":input_ops",
|
||||
":prefetching_ops_v2",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:device_util",
|
||||
@ -31,6 +30,7 @@ py_library(
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/data/ops:multi_device_iterator_ops",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/training/checkpointable:base",
|
||||
"@six_archive//:six",
|
||||
@ -666,32 +666,6 @@ cuda_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "prefetching_ops_v2",
|
||||
srcs = ["prefetching_ops_v2.py"],
|
||||
deps = [
|
||||
"//tensorflow/contrib/data/python/ops:prefetching_ops",
|
||||
"//tensorflow/python:experimental_dataset_ops_gen",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/data/util:nest",
|
||||
"//tensorflow/python/data/util:sparse",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "prefetching_ops_v2_test",
|
||||
srcs = ["prefetching_ops_v2_test.py"],
|
||||
additional_deps = [
|
||||
":prefetching_ops_v2",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/data/ops:iterator_ops",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "input_ops",
|
||||
srcs = ["input_ops.py"],
|
||||
|
@ -96,7 +96,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
|
||||
def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn):
|
||||
with ops.Graph().as_default(), distribution.scope():
|
||||
iterator = distribution.distribute_dataset(
|
||||
dataset_fn).make_one_shot_iterator()
|
||||
dataset_fn).make_initializable_iterator()
|
||||
if isinstance(distribution, tpu_strategy.TPUStrategy):
|
||||
def step_fn(ctx, inputs):
|
||||
value, update = distribution.call_for_each_tower(
|
||||
@ -120,6 +120,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
|
||||
# replace "distribution.num_towers" with "1".
|
||||
batches_per_update = distribution.num_towers
|
||||
|
||||
self.evaluate(iterator.initializer)
|
||||
self.evaluate(distribution.initialize())
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
|
||||
|
@ -41,6 +41,14 @@ from tensorflow.python.ops.losses import losses_impl
|
||||
|
||||
class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def _get_iterator(self, ds):
|
||||
if context.executing_eagerly():
|
||||
iterator = ds.make_one_shot_iterator()
|
||||
else:
|
||||
iterator = ds.make_initializable_iterator()
|
||||
self.evaluate(iterator.initializer)
|
||||
return iterator
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
combinations.distributions_and_v1_optimizers(),
|
||||
@ -62,8 +70,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
|
||||
distribution.call_for_each_tower(
|
||||
model_fn, *inputs, run_concurrently=layer.built))
|
||||
|
||||
iterator = distribution.distribute_dataset(
|
||||
dataset_fn).make_one_shot_iterator()
|
||||
iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
|
||||
|
||||
def run_step():
|
||||
return distribution.run_steps_on_dataset(
|
||||
@ -99,8 +106,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
|
||||
model_fn, dataset_fn, layer = minimize_loss_example(
|
||||
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
|
||||
|
||||
iterator = distribution.distribute_dataset(
|
||||
dataset_fn).make_one_shot_iterator()
|
||||
iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
|
||||
|
||||
def run_step():
|
||||
return distribution.group(
|
||||
@ -159,8 +165,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
|
||||
distribution.call_for_each_tower(
|
||||
model_fn, *inputs, run_concurrently=layer.built))
|
||||
|
||||
iterator = distribution.distribute_dataset(
|
||||
dataset_fn).make_one_shot_iterator()
|
||||
iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
|
||||
|
||||
def run_step():
|
||||
return distribution.run_steps_on_dataset(
|
||||
@ -239,8 +244,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
|
||||
fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS)
|
||||
return control_flow_ops.group(fetches)
|
||||
|
||||
iterator = distribution.distribute_dataset(
|
||||
dataset_fn).make_one_shot_iterator()
|
||||
iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
|
||||
|
||||
def run_step():
|
||||
return distribution.run_steps_on_dataset(
|
||||
@ -333,8 +337,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
|
||||
distribution.call_for_each_tower(
|
||||
model_fn, x, y, run_concurrently=False))
|
||||
|
||||
iterator = distribution.distribute_dataset(
|
||||
dataset_fn).make_one_shot_iterator()
|
||||
iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
|
||||
|
||||
def run_step():
|
||||
return distribution.run_steps_on_dataset(
|
||||
@ -427,8 +430,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
|
||||
output=loss)
|
||||
return distribution.group(train_op)
|
||||
|
||||
iterator = distribution.distribute_dataset(
|
||||
dataset_fn).make_one_shot_iterator()
|
||||
iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
|
||||
|
||||
def run_step():
|
||||
initial_loss = lambda: constant_op.constant(1e7)
|
||||
|
@ -307,9 +307,15 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
|
||||
|
||||
dist = mirrored_strategy.MirroredStrategy(
|
||||
["/device:GPU:0", "/device:CPU:0"])
|
||||
features = dist.distribute_dataset(
|
||||
lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
|
||||
).make_one_shot_iterator().get_next()
|
||||
ds = dist.distribute_dataset(
|
||||
lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10))
|
||||
if context.executing_eagerly():
|
||||
iterator = ds.make_one_shot_iterator()
|
||||
else:
|
||||
iterator = ds.make_initializable_iterator()
|
||||
self.evaluate([iterator.initializer])
|
||||
|
||||
features = iterator.get_next()
|
||||
|
||||
with dist.scope():
|
||||
result = dist.call_for_each_tower(
|
||||
|
@ -51,6 +51,7 @@ class Monitor(object):
|
||||
else:
|
||||
if session is None:
|
||||
raise ValueError("Should provide a `session` in Graph mode.")
|
||||
session.run(step_callable._iterator.initializer) # pylint: disable=protected-access
|
||||
self._run_step = session.make_callable(step_callable())
|
||||
session.run(variables.global_variables_initializer())
|
||||
|
||||
|
@ -42,8 +42,11 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
|
||||
model_fn, dataset_fn, layer = minimize_loss_example(
|
||||
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
|
||||
|
||||
iterator = distribution.distribute_dataset(
|
||||
dataset_fn).make_one_shot_iterator()
|
||||
ds = distribution.distribute_dataset(dataset_fn)
|
||||
if context.executing_eagerly():
|
||||
iterator = ds.make_one_shot_iterator()
|
||||
else:
|
||||
iterator = ds.make_initializable_iterator()
|
||||
|
||||
def run_step():
|
||||
return control_flow_ops.group(distribution.unwrap(
|
||||
@ -52,6 +55,7 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
|
||||
|
||||
if not context.executing_eagerly():
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
run_step = sess.make_callable(run_step())
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
|
||||
|
@ -1,232 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Extension of prefetching_ops to support more than one device."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import warnings
|
||||
|
||||
from tensorflow.python.data.experimental.ops import prefetching_ops
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.data.util import nest as data_nest
|
||||
from tensorflow.python.data.util import sparse
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
class _PrefetchToDeviceIterator(object):
|
||||
"""A replacement for `tf.data.Iterator` that prefetches to another device.
|
||||
|
||||
Args:
|
||||
input_dataset: The input dataset.
|
||||
one_shot: If true, we make a one shot iterator that's already initialized.
|
||||
devices: Devices on which to prefetch.
|
||||
buffer_size: Size of the prefetching buffer.
|
||||
shared_name: (Optional.) If non-empty, the returned iterator will be shared
|
||||
under the given name across multiple sessions that share the same devices
|
||||
(e.g. when using a remote server). Only used if one_shot is False.
|
||||
|
||||
Returns:
|
||||
An Iterator type object.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_dataset,
|
||||
one_shot,
|
||||
devices,
|
||||
buffer_size,
|
||||
shared_name=None):
|
||||
self._input_dataset = input_dataset
|
||||
self._get_next_call_count = 0
|
||||
self._one_shot = one_shot
|
||||
if shared_name is None:
|
||||
shared_name = ""
|
||||
self._devices = devices
|
||||
|
||||
if self._one_shot:
|
||||
self._input_iterator = input_dataset.make_one_shot_iterator()
|
||||
else:
|
||||
self._input_iterator = iterator_ops.Iterator.from_structure(
|
||||
self._input_dataset.output_types, self._input_dataset.output_shapes,
|
||||
shared_name, self._input_dataset.output_classes)
|
||||
input_iterator_handle = self._input_iterator.string_handle()
|
||||
|
||||
@function.Defun(dtypes.string)
|
||||
def _prefetch_fn(handle):
|
||||
"""Prefetches one element from `input_iterator`."""
|
||||
remote_iterator = iterator_ops.Iterator.from_string_handle(
|
||||
handle, self._input_iterator.output_types,
|
||||
self._input_iterator.output_shapes,
|
||||
self._input_iterator.output_classes)
|
||||
ret = remote_iterator.get_next()
|
||||
return nest.flatten(sparse.serialize_sparse_tensors(ret))
|
||||
|
||||
target_device = ged_ops.experimental_iterator_get_device(
|
||||
self._input_iterator._iterator_resource)
|
||||
self._buffering_resources = []
|
||||
for device in nest.flatten(self._devices):
|
||||
with ops.device(device):
|
||||
buffer_resource_handle = prefetching_ops.function_buffering_resource(
|
||||
f=_prefetch_fn,
|
||||
output_types=data_nest.flatten(
|
||||
sparse.as_dense_types(self._input_dataset.output_types,
|
||||
self._input_dataset.output_classes)),
|
||||
target_device=target_device,
|
||||
string_arg=input_iterator_handle,
|
||||
buffer_size=buffer_size,
|
||||
shared_name=shared_name)
|
||||
self._buffering_resources.append(buffer_resource_handle)
|
||||
|
||||
if not self._one_shot:
|
||||
reset_ops = []
|
||||
for buffer_resource in self._buffering_resources:
|
||||
reset_ops.append(
|
||||
ged_ops.experimental_function_buffering_resource_reset(
|
||||
buffer_resource))
|
||||
with ops.control_dependencies(reset_ops):
|
||||
self._initializer = self._input_iterator.make_initializer(
|
||||
self._input_dataset)
|
||||
|
||||
def get_next(self, name=None):
|
||||
"""See `tf.data.Iterator.get_next`."""
|
||||
self._get_next_call_count += 1
|
||||
if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
|
||||
warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
|
||||
|
||||
flat_result = []
|
||||
# TODO(priyag): This will fail if the input size (typically number of
|
||||
# batches) is not divisible by number of devices.
|
||||
# How do we handle that more gracefully / let the user know?
|
||||
for buffer_resource in self._buffering_resources:
|
||||
flat_ret = ged_ops.experimental_function_buffering_resource_get_next(
|
||||
buffer_resource,
|
||||
output_types=data_nest.flatten(
|
||||
sparse.as_dense_types(self.output_types, self.output_classes)),
|
||||
name=name)
|
||||
|
||||
ret = sparse.deserialize_sparse_tensors(
|
||||
data_nest.pack_sequence_as(self.output_types, flat_ret),
|
||||
self.output_types, self.output_shapes, self.output_classes)
|
||||
|
||||
for tensor, shape in zip(
|
||||
data_nest.flatten(ret), data_nest.flatten(self.output_shapes)):
|
||||
if isinstance(tensor, ops.Tensor):
|
||||
tensor.set_shape(shape)
|
||||
flat_result.append(ret)
|
||||
|
||||
return nest.pack_sequence_as(self._devices, flat_result)
|
||||
|
||||
@property
|
||||
def initializer(self):
|
||||
if self._one_shot:
|
||||
raise NotImplementedError("Can't initialize a one_shot_iterator")
|
||||
return self._initializer
|
||||
|
||||
@property
|
||||
def output_classes(self):
|
||||
return self._input_dataset.output_classes
|
||||
|
||||
@property
|
||||
def output_shapes(self):
|
||||
return self._input_dataset.output_shapes
|
||||
|
||||
@property
|
||||
def output_types(self):
|
||||
return self._input_dataset.output_types
|
||||
|
||||
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset):
|
||||
"""A `Dataset` whose iterator prefetches elements to other device(s)."""
|
||||
|
||||
def __init__(self, input_dataset, devices, buffer_size):
|
||||
super(_PrefetchToDeviceDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
self._devices = devices
|
||||
self._buffer_size = buffer_size if buffer_size is not None else 1
|
||||
|
||||
def make_one_shot_iterator(self):
|
||||
return _PrefetchToDeviceIterator(
|
||||
self._input_dataset,
|
||||
one_shot=True,
|
||||
devices=self._devices,
|
||||
buffer_size=self._buffer_size)
|
||||
|
||||
def make_initializable_iterator(self, shared_name=None):
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError(
|
||||
"make_initializable_iterator is not supported when eager "
|
||||
"execution is enabled.")
|
||||
|
||||
return _PrefetchToDeviceIterator(
|
||||
self._input_dataset,
|
||||
one_shot=False,
|
||||
devices=self._devices,
|
||||
buffer_size=self._buffer_size,
|
||||
shared_name=shared_name)
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
# TODO(mrry): Raise this error earlier (e.g. when one of the Dataset
|
||||
# transformation methods is called.
|
||||
# TODO(mrry): Investigate support for chaining further transformations after
|
||||
# the prefetch, including GPU support.
|
||||
raise NotImplementedError("`prefetch_to_devices()` must be the last "
|
||||
"transformation in a dataset pipeline.")
|
||||
|
||||
# TODO(priyag): Fix the output types, shapes and classes to match the result
|
||||
# of get_next (which has the additional nesting layer of devices now).
|
||||
@property
|
||||
def output_types(self):
|
||||
return self._input_dataset.output_types
|
||||
|
||||
@property
|
||||
def output_shapes(self):
|
||||
return self._input_dataset.output_shapes
|
||||
|
||||
@property
|
||||
def output_classes(self):
|
||||
return self._input_dataset.output_classes
|
||||
|
||||
|
||||
def prefetch_to_devices(devices, buffer_size=None):
|
||||
"""A transformation that prefetches dataset values to the given `devices`.
|
||||
|
||||
NOTE: Although the transformation creates a `tf.data.Dataset`, the
|
||||
transformation must be the final `Dataset` in the input pipeline.
|
||||
|
||||
Args:
|
||||
devices: A nested structure of devices on which to prefetch the data. It can
|
||||
be a single device name, or a tuple or list of device names.
|
||||
buffer_size: (Optional.) The number of elements to buffer on each device.
|
||||
Defaults to an automatically chosen value.
|
||||
|
||||
Returns:
|
||||
A `Dataset` transformation function, which can be passed to
|
||||
`tf.data.Dataset.apply`.
|
||||
"""
|
||||
|
||||
def _apply_fn(dataset):
|
||||
return _PrefetchToDeviceDataset(dataset, devices, buffer_size)
|
||||
|
||||
return _apply_fn
|
@ -1,90 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for prefetching_ops_v2."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.distribute.python import prefetching_ops_v2
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class PrefetchingOpsV2Test(test.TestCase):
|
||||
|
||||
def testPrefetchToOneDevice(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
device_dataset = host_dataset.apply(
|
||||
prefetching_ops_v2.prefetch_to_devices("/gpu:0"))
|
||||
|
||||
iterator = device_dataset.make_one_shot_iterator()
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
def testPrefetchToTwoDevicesInAList(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
device_dataset = host_dataset.apply(
|
||||
prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
|
||||
|
||||
iterator = device_dataset.make_one_shot_iterator()
|
||||
next_element = iterator.get_next()
|
||||
|
||||
output = []
|
||||
# TODO(rohanj): Modify test to go till the end of the dataset when we
|
||||
# switch to MultiDeviceIterator.
|
||||
with self.cached_session() as sess:
|
||||
for _ in range(4):
|
||||
result = sess.run(next_element)
|
||||
self.assertEqual(2, len(result))
|
||||
output.extend(result)
|
||||
self.assertEquals(set(range(8)), set(output))
|
||||
|
||||
def testPrefetchToTwoDevicesWithReinit(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
device_dataset = host_dataset.apply(
|
||||
prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
|
||||
|
||||
iterator = device_dataset.make_initializable_iterator()
|
||||
next_element = iterator.get_next()
|
||||
|
||||
# TODO(rohanj): Modify test to go till the end of the dataset when we
|
||||
# switch to MultiDeviceIterator.
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
for _ in range(4):
|
||||
sess.run(next_element)
|
||||
sess.run(iterator.initializer)
|
||||
for _ in range(4):
|
||||
sess.run(next_element)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.training import optimizer as optimizer_lib
|
||||
|
||||
|
||||
@ -50,7 +51,11 @@ class StandardInputStep(Step):
|
||||
def __init__(self, dataset_fn, distribution):
|
||||
super(StandardInputStep, self).__init__(distribution)
|
||||
self._distributed_input = distribution.distribute_dataset(dataset_fn)
|
||||
if context.executing_eagerly():
|
||||
self._iterator = self._distributed_input.make_one_shot_iterator()
|
||||
else:
|
||||
# TODO(priyag): Expose initializer via some initializer property.
|
||||
self._iterator = self._distributed_input.make_initializable_iterator()
|
||||
|
||||
|
||||
class StandardSingleLossStep(StandardInputStep):
|
||||
|
@ -50,6 +50,7 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase):
|
||||
run_step = single_loss_step
|
||||
else:
|
||||
with self.cached_session() as sess:
|
||||
sess.run(single_loss_step._iterator.initializer)
|
||||
run_step = sess.make_callable(single_loss_step())
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
|
||||
|
@ -27,7 +27,7 @@ import weakref
|
||||
import six
|
||||
|
||||
from tensorflow.contrib.distribute.python import input_ops
|
||||
from tensorflow.contrib.distribute.python import prefetching_ops_v2
|
||||
from tensorflow.python.data.ops import multi_device_iterator_ops
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import tape
|
||||
from tensorflow.python.framework import device as tf_device
|
||||
@ -1089,7 +1089,7 @@ class PerDeviceDataIterator(object):
|
||||
def get_next(self, name=None):
|
||||
"""Scatter the input across devices."""
|
||||
if self._prefetch_on_device:
|
||||
data_list = self._iterator.get_next(name=name)
|
||||
data_list = self._iterator.get_next()
|
||||
index = dict(zip(self._devices, data_list))
|
||||
else:
|
||||
batch = self._iterator.get_next(name=name)
|
||||
@ -1113,17 +1113,15 @@ class PerDeviceDataset(object):
|
||||
self._devices = devices
|
||||
|
||||
# Default to using prefetching in graph mode, unless specified.
|
||||
# TODO(priyag): Enable prefetching in eager mode.
|
||||
# TODO(rohanj): Enable prefetching in eager mode.
|
||||
self._prefetch_on_device = prefetch_on_device
|
||||
if self._prefetch_on_device is None:
|
||||
self._prefetch_on_device = not context.executing_eagerly()
|
||||
assert not (self._prefetch_on_device and context.executing_eagerly()), (
|
||||
"Prefetching is only supported in graph mode currently")
|
||||
|
||||
if self._prefetch_on_device:
|
||||
self._dataset = dataset.apply(
|
||||
prefetching_ops_v2.prefetch_to_devices(self._devices))
|
||||
else:
|
||||
self._dataset = dataset
|
||||
if not self._prefetch_on_device:
|
||||
# TODO(priyag): If dropping remainder is not appropriate, find another
|
||||
# approach to distributing the dataset when not possible to divide evenly.
|
||||
# Possibly not an issue when we start using PartitionedDataset.
|
||||
@ -1131,15 +1129,33 @@ class PerDeviceDataset(object):
|
||||
|
||||
def make_one_shot_iterator(self):
|
||||
"""Get a one time use iterator for the distributed PerDeviceDataset."""
|
||||
# Graph mode with one shot iterator is disabled.
|
||||
if not context.executing_eagerly():
|
||||
raise ValueError("Cannot create a one shot iterator. Please use "
|
||||
"`make_initializable_iterator()` instead.")
|
||||
# Eager mode prefetching would error out in constructor. Only remaining
|
||||
# case is non-prefetching in eager mode. We delegate to
|
||||
# PerDeviceDataIterator to handle that case.
|
||||
dataset_iterator = self._dataset.make_one_shot_iterator()
|
||||
return PerDeviceDataIterator(dataset_iterator, self._devices,
|
||||
self._prefetch_on_device)
|
||||
return PerDeviceDataIterator(
|
||||
dataset_iterator, self._devices, prefetch_on_device=False)
|
||||
|
||||
def make_initializable_iterator(self):
|
||||
"""Get an initializable iterator for the distributed PerDeviceDataset."""
|
||||
# Eager mode generates already initialized iterators. Hence we cannot create
|
||||
# an initializable iterator.
|
||||
if context.executing_eagerly():
|
||||
raise ValueError("Cannot create initializable iterator in Eager mode. "
|
||||
"Please use `make_one_shot_iterator` instead.")
|
||||
if self._prefetch_on_device:
|
||||
dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
|
||||
self._dataset, self._devices)
|
||||
else:
|
||||
dataset_iterator = self._dataset.make_initializable_iterator()
|
||||
return PerDeviceDataIterator(dataset_iterator, self._devices,
|
||||
self._prefetch_on_device)
|
||||
return PerDeviceDataIterator(
|
||||
dataset_iterator,
|
||||
self._devices,
|
||||
prefetch_on_device=self._prefetch_on_device)
|
||||
|
||||
|
||||
class MultiWorkerDataIterator(object):
|
||||
|
@ -349,7 +349,11 @@ class PerDeviceDatasetTest(test.TestCase):
|
||||
def _test_iterator_no_prefetch(self, devices, dataset, expected_values):
|
||||
per_device_dataset = values.PerDeviceDataset(
|
||||
dataset, devices, prefetch_on_device=False)
|
||||
if context.executing_eagerly():
|
||||
iterator = per_device_dataset.make_one_shot_iterator()
|
||||
else:
|
||||
iterator = per_device_dataset.make_initializable_iterator()
|
||||
self.evaluate([iterator.initializer])
|
||||
|
||||
for expected_value in expected_values:
|
||||
next_element = iterator.get_next()
|
||||
@ -366,21 +370,14 @@ class PerDeviceDatasetTest(test.TestCase):
|
||||
if not context.executing_eagerly():
|
||||
per_device_dataset = values.PerDeviceDataset(
|
||||
dataset, devices, prefetch_on_device=True)
|
||||
iterator = per_device_dataset.make_one_shot_iterator()
|
||||
iterator = per_device_dataset.make_initializable_iterator()
|
||||
self.evaluate([iterator.initializer])
|
||||
|
||||
# With prefetching, we cannot guarantee which input ends up on which
|
||||
# device, so we verify that the complete set seen on all devices is
|
||||
# correct, and equal numbers are distributed to each device.
|
||||
combined_actual = []
|
||||
combined_expected = []
|
||||
for expected_value in expected_values:
|
||||
next_element = iterator.get_next()
|
||||
combined_actual.extend(
|
||||
self.evaluate(
|
||||
[values.select_device(d, next_element) for d in devices]))
|
||||
combined_expected.extend(expected_value)
|
||||
|
||||
self.assertEqual(set(combined_expected), set(combined_actual))
|
||||
computed_value = self.evaluate(
|
||||
[values.select_device(d, next_element) for d in devices])
|
||||
self.assertEqual(expected_value, computed_value)
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
next_element = iterator.get_next()
|
||||
|
@ -229,3 +229,15 @@ class MultiDeviceIterator(object):
|
||||
@property
|
||||
def initializer(self):
|
||||
return self._initializer
|
||||
|
||||
@property
|
||||
def output_types(self):
|
||||
return self._dataset.output_types
|
||||
|
||||
@property
|
||||
def output_shapes(self):
|
||||
return self._dataset.output_shapes
|
||||
|
||||
@property
|
||||
def output_classes(self):
|
||||
return self._dataset.output_classes
|
||||
|
Loading…
Reference in New Issue
Block a user