Roll forward the change that switches distribution strategies to use MultiDeviceIterator.
The underlying issue with NaNs has now been resolved.
PiperOrigin-RevId: 217014692
commit 47b04fdb3e
parent f667188dae
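As background for the hunks below, here is a minimal sketch of the MultiDeviceIterator usage this change migrates to. It relies only on calls that appear in this diff (the MultiDeviceIterator(dataset, devices) constructor, its initializer property, and get_next() returning one element per device); the device names, toy dataset, and session setup are illustrative assumptions rather than part of the commit.

# Illustrative sketch; device names and the toy dataset are assumptions.
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.framework import ops

with ops.Graph().as_default():
  dataset = dataset_ops.Dataset.range(10)
  # One host-side iterator whose elements are prefetched to each device.
  multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
      dataset, ["/cpu:0", "/gpu:0"])
  # With no device argument, get_next() returns a list with one element per
  # device, in the order the devices were passed to the constructor.
  per_device_elements = multi_device_iterator.get_next()

  with session.Session() as sess:
    sess.run(multi_device_iterator.initializer)
    print(sess.run(per_device_elements))  # e.g. [0, 1] on the first call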
@@ -22,7 +22,6 @@ py_library(
     visibility = ["//tensorflow:internal"],
     deps = [
         ":input_ops",
-        ":prefetching_ops_v2",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:device_util",
@@ -31,6 +30,7 @@ py_library(
         "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:training",
         "//tensorflow/python:util",
+        "//tensorflow/python/data/ops:multi_device_iterator_ops",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/training/checkpointable:base",
         "@six_archive//:six",
@@ -666,32 +666,6 @@ cuda_py_test(
     ],
 )
 
-py_library(
-    name = "prefetching_ops_v2",
-    srcs = ["prefetching_ops_v2.py"],
-    deps = [
-        "//tensorflow/contrib/data/python/ops:prefetching_ops",
-        "//tensorflow/python:experimental_dataset_ops_gen",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/data/util:nest",
-        "//tensorflow/python/data/util:sparse",
-    ],
-)
-
-cuda_py_test(
-    name = "prefetching_ops_v2_test",
-    srcs = ["prefetching_ops_v2_test.py"],
-    additional_deps = [
-        ":prefetching_ops_v2",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python:framework_test_lib",
-        "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/data/ops:iterator_ops",
-    ],
-)
-
 py_library(
     name = "input_ops",
     srcs = ["input_ops.py"],
@@ -96,7 +96,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
   def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn):
     with ops.Graph().as_default(), distribution.scope():
       iterator = distribution.distribute_dataset(
-          dataset_fn).make_one_shot_iterator()
+          dataset_fn).make_initializable_iterator()
       if isinstance(distribution, tpu_strategy.TPUStrategy):
         def step_fn(ctx, inputs):
           value, update = distribution.call_for_each_tower(
@@ -120,6 +120,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
       # replace "distribution.num_towers" with "1".
       batches_per_update = distribution.num_towers
 
+      self.evaluate(iterator.initializer)
       self.evaluate(distribution.initialize())
       self.evaluate(variables.local_variables_initializer())
 
@@ -41,6 +41,14 @@ from tensorflow.python.ops.losses import losses_impl
 
 class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
 
+  def _get_iterator(self, ds):
+    if context.executing_eagerly():
+      iterator = ds.make_one_shot_iterator()
+    else:
+      iterator = ds.make_initializable_iterator()
+      self.evaluate(iterator.initializer)
+    return iterator
+
   @combinations.generate(
       combinations.times(
           combinations.distributions_and_v1_optimizers(),
@@ -62,8 +70,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
           distribution.call_for_each_tower(
               model_fn, *inputs, run_concurrently=layer.built))
 
-    iterator = distribution.distribute_dataset(
-        dataset_fn).make_one_shot_iterator()
+    iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
 
     def run_step():
       return distribution.run_steps_on_dataset(
@@ -99,8 +106,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     model_fn, dataset_fn, layer = minimize_loss_example(
         optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
 
-    iterator = distribution.distribute_dataset(
-        dataset_fn).make_one_shot_iterator()
+    iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
 
     def run_step():
       return distribution.group(
@@ -159,8 +165,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
           distribution.call_for_each_tower(
               model_fn, *inputs, run_concurrently=layer.built))
 
-    iterator = distribution.distribute_dataset(
-        dataset_fn).make_one_shot_iterator()
+    iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
 
     def run_step():
       return distribution.run_steps_on_dataset(
@@ -239,8 +244,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
       fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS)
       return control_flow_ops.group(fetches)
 
-    iterator = distribution.distribute_dataset(
-        dataset_fn).make_one_shot_iterator()
+    iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
 
     def run_step():
       return distribution.run_steps_on_dataset(
@@ -333,8 +337,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
           distribution.call_for_each_tower(
               model_fn, x, y, run_concurrently=False))
 
-    iterator = distribution.distribute_dataset(
-        dataset_fn).make_one_shot_iterator()
+    iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
 
     def run_step():
       return distribution.run_steps_on_dataset(
@@ -427,8 +430,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
           output=loss)
       return distribution.group(train_op)
 
-    iterator = distribution.distribute_dataset(
-        dataset_fn).make_one_shot_iterator()
+    iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
 
     def run_step():
       initial_loss = lambda: constant_op.constant(1e7)
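The test hunks above (and several below) all switch to the same eager-versus-graph iterator selection. Below is a condensed, hedged sketch of that pattern outside the test fixtures; the helper name get_iterator and the evaluate_fn callback are illustrative stand-ins for the _get_iterator method and self.evaluate used in the tests.

# Condensed sketch of the iterator-selection pattern; names are illustrative.
from tensorflow.python.eager import context


def get_iterator(ds, evaluate_fn):
  """Returns a ready-to-use iterator for `ds` in both eager and graph mode."""
  if context.executing_eagerly():
    # Eager-mode iterators are created already initialized.
    return ds.make_one_shot_iterator()
  # In graph mode the distributed dataset now only supports initializable
  # iterators, so the initializer must be run before the first get_next().
  iterator = ds.make_initializable_iterator()
  evaluate_fn(iterator.initializer)
  return iterator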
@@ -307,9 +307,15 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
 
     dist = mirrored_strategy.MirroredStrategy(
         ["/device:GPU:0", "/device:CPU:0"])
-    features = dist.distribute_dataset(
-        lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
-    ).make_one_shot_iterator().get_next()
+    ds = dist.distribute_dataset(
+        lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10))
+    if context.executing_eagerly():
+      iterator = ds.make_one_shot_iterator()
+    else:
+      iterator = ds.make_initializable_iterator()
+      self.evaluate([iterator.initializer])
+
+    features = iterator.get_next()
 
     with dist.scope():
       result = dist.call_for_each_tower(
@@ -51,6 +51,7 @@ class Monitor(object):
     else:
       if session is None:
         raise ValueError("Should provide a `session` in Graph mode.")
+      session.run(step_callable._iterator.initializer)  # pylint: disable=protected-access
       self._run_step = session.make_callable(step_callable())
       session.run(variables.global_variables_initializer())
 
@@ -42,8 +42,11 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
     model_fn, dataset_fn, layer = minimize_loss_example(
         optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
 
-    iterator = distribution.distribute_dataset(
-        dataset_fn).make_one_shot_iterator()
+    ds = distribution.distribute_dataset(dataset_fn)
+    if context.executing_eagerly():
+      iterator = ds.make_one_shot_iterator()
+    else:
+      iterator = ds.make_initializable_iterator()
 
     def run_step():
       return control_flow_ops.group(distribution.unwrap(
@@ -52,6 +55,7 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
 
     if not context.executing_eagerly():
       with self.cached_session() as sess:
+        sess.run(iterator.initializer)
         run_step = sess.make_callable(run_step())
     self.evaluate(variables.global_variables_initializer())
 
@@ -1,232 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Extension of prefetching_ops to support more than one device."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import warnings
-
-from tensorflow.python.data.experimental.ops import prefetching_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.data.util import nest as data_nest
-from tensorflow.python.data.util import sparse
-from tensorflow.python.eager import context
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
-from tensorflow.python.util import nest
-
-
-# pylint: disable=protected-access
-class _PrefetchToDeviceIterator(object):
-  """A replacement for `tf.data.Iterator` that prefetches to another device.
-
-  Args:
-    input_dataset: The input dataset.
-    one_shot: If true, we make a one shot iterator that's already initialized.
-    devices: Devices on which to prefetch.
-    buffer_size: Size of the prefetching buffer.
-    shared_name: (Optional.) If non-empty, the returned iterator will be shared
-      under the given name across multiple sessions that share the same devices
-      (e.g. when using a remote server). Only used if one_shot is False.
-
-  Returns:
-    An Iterator type object.
-  """
-
-  def __init__(self,
-               input_dataset,
-               one_shot,
-               devices,
-               buffer_size,
-               shared_name=None):
-    self._input_dataset = input_dataset
-    self._get_next_call_count = 0
-    self._one_shot = one_shot
-    if shared_name is None:
-      shared_name = ""
-    self._devices = devices
-
-    if self._one_shot:
-      self._input_iterator = input_dataset.make_one_shot_iterator()
-    else:
-      self._input_iterator = iterator_ops.Iterator.from_structure(
-          self._input_dataset.output_types, self._input_dataset.output_shapes,
-          shared_name, self._input_dataset.output_classes)
-    input_iterator_handle = self._input_iterator.string_handle()
-
-    @function.Defun(dtypes.string)
-    def _prefetch_fn(handle):
-      """Prefetches one element from `input_iterator`."""
-      remote_iterator = iterator_ops.Iterator.from_string_handle(
-          handle, self._input_iterator.output_types,
-          self._input_iterator.output_shapes,
-          self._input_iterator.output_classes)
-      ret = remote_iterator.get_next()
-      return nest.flatten(sparse.serialize_sparse_tensors(ret))
-
-    target_device = ged_ops.experimental_iterator_get_device(
-        self._input_iterator._iterator_resource)
-    self._buffering_resources = []
-    for device in nest.flatten(self._devices):
-      with ops.device(device):
-        buffer_resource_handle = prefetching_ops.function_buffering_resource(
-            f=_prefetch_fn,
-            output_types=data_nest.flatten(
-                sparse.as_dense_types(self._input_dataset.output_types,
-                                      self._input_dataset.output_classes)),
-            target_device=target_device,
-            string_arg=input_iterator_handle,
-            buffer_size=buffer_size,
-            shared_name=shared_name)
-        self._buffering_resources.append(buffer_resource_handle)
-
-    if not self._one_shot:
-      reset_ops = []
-      for buffer_resource in self._buffering_resources:
-        reset_ops.append(
-            ged_ops.experimental_function_buffering_resource_reset(
-                buffer_resource))
-      with ops.control_dependencies(reset_ops):
-        self._initializer = self._input_iterator.make_initializer(
-            self._input_dataset)
-
-  def get_next(self, name=None):
-    """See `tf.data.Iterator.get_next`."""
-    self._get_next_call_count += 1
-    if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
-      warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
-
-    flat_result = []
-    # TODO(priyag): This will fail if the input size (typically number of
-    # batches) is not divisible by number of devices.
-    # How do we handle that more gracefully / let the user know?
-    for buffer_resource in self._buffering_resources:
-      flat_ret = ged_ops.experimental_function_buffering_resource_get_next(
-          buffer_resource,
-          output_types=data_nest.flatten(
-              sparse.as_dense_types(self.output_types, self.output_classes)),
-          name=name)
-
-      ret = sparse.deserialize_sparse_tensors(
-          data_nest.pack_sequence_as(self.output_types, flat_ret),
-          self.output_types, self.output_shapes, self.output_classes)
-
-      for tensor, shape in zip(
-          data_nest.flatten(ret), data_nest.flatten(self.output_shapes)):
-        if isinstance(tensor, ops.Tensor):
-          tensor.set_shape(shape)
-      flat_result.append(ret)
-
-    return nest.pack_sequence_as(self._devices, flat_result)
-
-  @property
-  def initializer(self):
-    if self._one_shot:
-      raise NotImplementedError("Can't initialize a one_shot_iterator")
-    return self._initializer
-
-  @property
-  def output_classes(self):
-    return self._input_dataset.output_classes
-
-  @property
-  def output_shapes(self):
-    return self._input_dataset.output_shapes
-
-  @property
-  def output_types(self):
-    return self._input_dataset.output_types
-
-
-# pylint: enable=protected-access
-
-
-class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset):
-  """A `Dataset` whose iterator prefetches elements to other device(s)."""
-
-  def __init__(self, input_dataset, devices, buffer_size):
-    super(_PrefetchToDeviceDataset, self).__init__(input_dataset)
-    self._input_dataset = input_dataset
-    self._devices = devices
-    self._buffer_size = buffer_size if buffer_size is not None else 1
-
-  def make_one_shot_iterator(self):
-    return _PrefetchToDeviceIterator(
-        self._input_dataset,
-        one_shot=True,
-        devices=self._devices,
-        buffer_size=self._buffer_size)
-
-  def make_initializable_iterator(self, shared_name=None):
-    if context.executing_eagerly():
-      raise RuntimeError(
-          "make_initializable_iterator is not supported when eager "
-          "execution is enabled.")
-
-    return _PrefetchToDeviceIterator(
-        self._input_dataset,
-        one_shot=False,
-        devices=self._devices,
-        buffer_size=self._buffer_size,
-        shared_name=shared_name)
-
-  def _as_variant_tensor(self):
-    # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset
-    # transformation methods is called.
-    # TODO(mrry): Investigate support for chaining further transformations after
-    # the prefetch, including GPU support.
-    raise NotImplementedError("`prefetch_to_devices()` must be the last "
-                              "transformation in a dataset pipeline.")
-
-  # TODO(priyag): Fix the output types, shapes and classes to match the result
-  # of get_next (which has the additional nesting layer of devices now).
-  @property
-  def output_types(self):
-    return self._input_dataset.output_types
-
-  @property
-  def output_shapes(self):
-    return self._input_dataset.output_shapes
-
-  @property
-  def output_classes(self):
-    return self._input_dataset.output_classes
-
-
-def prefetch_to_devices(devices, buffer_size=None):
-  """A transformation that prefetches dataset values to the given `devices`.
-
-  NOTE: Although the transformation creates a `tf.data.Dataset`, the
-  transformation must be the final `Dataset` in the input pipeline.
-
-  Args:
-    devices: A nested structure of devices on which to prefetch the data. It can
-      be a single device name, or a tuple or list of device names.
-    buffer_size: (Optional.) The number of elements to buffer on each device.
-      Defaults to an automatically chosen value.
-
-  Returns:
-    A `Dataset` transformation function, which can be passed to
-    `tf.data.Dataset.apply`.
-  """
-  def _apply_fn(dataset):
-    return _PrefetchToDeviceDataset(dataset, devices, buffer_size)
-
-  return _apply_fn
@@ -1,90 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for prefetching_ops_v2."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.distribute.python import prefetching_ops_v2
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import test_util
-from tensorflow.python.platform import test
-
-
-class PrefetchingOpsV2Test(test.TestCase):
-
-  def testPrefetchToOneDevice(self):
-    if not test_util.is_gpu_available():
-      self.skipTest("No GPU available")
-
-    host_dataset = dataset_ops.Dataset.range(10)
-    device_dataset = host_dataset.apply(
-        prefetching_ops_v2.prefetch_to_devices("/gpu:0"))
-
-    iterator = device_dataset.make_one_shot_iterator()
-    next_element = iterator.get_next()
-
-    with self.cached_session() as sess:
-      for i in range(10):
-        self.assertEqual(i, sess.run(next_element))
-      with self.assertRaises(errors.OutOfRangeError):
-        sess.run(next_element)
-
-  def testPrefetchToTwoDevicesInAList(self):
-    if not test_util.is_gpu_available():
-      self.skipTest("No GPU available")
-
-    host_dataset = dataset_ops.Dataset.range(10)
-    device_dataset = host_dataset.apply(
-        prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
-
-    iterator = device_dataset.make_one_shot_iterator()
-    next_element = iterator.get_next()
-
-    output = []
-    # TODO(rohanj): Modify test to go till the end of the dataset when we
-    # switch to MultiDeviceIterator.
-    with self.cached_session() as sess:
-      for _ in range(4):
-        result = sess.run(next_element)
-        self.assertEqual(2, len(result))
-        output.extend(result)
-      self.assertEquals(set(range(8)), set(output))
-
-  def testPrefetchToTwoDevicesWithReinit(self):
-    if not test_util.is_gpu_available():
-      self.skipTest("No GPU available")
-
-    host_dataset = dataset_ops.Dataset.range(10)
-    device_dataset = host_dataset.apply(
-        prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
-
-    iterator = device_dataset.make_initializable_iterator()
-    next_element = iterator.get_next()
-
-    # TODO(rohanj): Modify test to go till the end of the dataset when we
-    # switch to MultiDeviceIterator.
-    with self.cached_session() as sess:
-      sess.run(iterator.initializer)
-      for _ in range(4):
-        sess.run(next_element)
-      sess.run(iterator.initializer)
-      for _ in range(4):
-        sess.run(next_element)
-
-
-if __name__ == "__main__":
-  test.main()
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
 from tensorflow.python.training import optimizer as optimizer_lib
 
 
@@ -50,7 +51,11 @@ class StandardInputStep(Step):
   def __init__(self, dataset_fn, distribution):
     super(StandardInputStep, self).__init__(distribution)
     self._distributed_input = distribution.distribute_dataset(dataset_fn)
-    self._iterator = self._distributed_input.make_one_shot_iterator()
+    if context.executing_eagerly():
+      self._iterator = self._distributed_input.make_one_shot_iterator()
+    else:
+      # TODO(priyag): Expose initializer via some initializer property.
+      self._iterator = self._distributed_input.make_initializable_iterator()
 
 
 class StandardSingleLossStep(StandardInputStep):
@@ -50,6 +50,7 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase):
       run_step = single_loss_step
     else:
       with self.cached_session() as sess:
+        sess.run(single_loss_step._iterator.initializer)
         run_step = sess.make_callable(single_loss_step())
     self.evaluate(variables.global_variables_initializer())
 
@@ -27,7 +27,7 @@ import weakref
 import six
 
 from tensorflow.contrib.distribute.python import input_ops
-from tensorflow.contrib.distribute.python import prefetching_ops_v2
+from tensorflow.python.data.ops import multi_device_iterator_ops
 from tensorflow.python.eager import context
 from tensorflow.python.eager import tape
 from tensorflow.python.framework import device as tf_device
@@ -1089,7 +1089,7 @@ class PerDeviceDataIterator(object):
   def get_next(self, name=None):
     """Scatter the input across devices."""
     if self._prefetch_on_device:
-      data_list = self._iterator.get_next(name=name)
+      data_list = self._iterator.get_next()
       index = dict(zip(self._devices, data_list))
     else:
       batch = self._iterator.get_next(name=name)
@@ -1113,17 +1113,15 @@ class PerDeviceDataset(object):
     self._devices = devices
 
     # Default to using prefetching in graph mode, unless specified.
-    # TODO(priyag): Enable prefetching in eager mode.
+    # TODO(rohanj): Enable prefetching in eager mode.
     self._prefetch_on_device = prefetch_on_device
     if self._prefetch_on_device is None:
       self._prefetch_on_device = not context.executing_eagerly()
     assert not (self._prefetch_on_device and context.executing_eagerly()), (
         "Prefetching is only supported in graph mode currently")
 
-    if self._prefetch_on_device:
-      self._dataset = dataset.apply(
-          prefetching_ops_v2.prefetch_to_devices(self._devices))
-    else:
+    self._dataset = dataset
+    if not self._prefetch_on_device:
       # TODO(priyag): If dropping remainder is not appropriate, find another
       # approach to distributing the dataset when not possible to divide evenly.
       # Possibly not an issue when we start using PartitionedDataset.
@@ -1131,15 +1129,33 @@ class PerDeviceDataset(object):
 
   def make_one_shot_iterator(self):
     """Get a one time use iterator for the distributed PerDeviceDataset."""
+    # Graph mode with one shot iterator is disabled.
+    if not context.executing_eagerly():
+      raise ValueError("Cannot create a one shot iterator. Please use "
+                       "`make_initializable_iterator()` instead.")
+    # Eager mode prefetching would error out in constructor. Only remaining
+    # case is non-prefetching in eager mode. We delegate to
+    # PerDeviceDataIterator to handle that case.
     dataset_iterator = self._dataset.make_one_shot_iterator()
-    return PerDeviceDataIterator(dataset_iterator, self._devices,
-                                 self._prefetch_on_device)
+    return PerDeviceDataIterator(
+        dataset_iterator, self._devices, prefetch_on_device=False)
 
   def make_initializable_iterator(self):
     """Get an initializable iterator for the distributed PerDeviceDataset."""
-    dataset_iterator = self._dataset.make_initializable_iterator()
-    return PerDeviceDataIterator(dataset_iterator, self._devices,
-                                 self._prefetch_on_device)
+    # Eager mode generates already initialized iterators. Hence we cannot create
+    # an initializable iterator.
+    if context.executing_eagerly():
+      raise ValueError("Cannot create initializable iterator in Eager mode. "
+                       "Please use `make_one_shot_iterator` instead.")
+    if self._prefetch_on_device:
+      dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+          self._dataset, self._devices)
+    else:
+      dataset_iterator = self._dataset.make_initializable_iterator()
+    return PerDeviceDataIterator(
+        dataset_iterator,
+        self._devices,
+        prefetch_on_device=self._prefetch_on_device)
 
 
 class MultiWorkerDataIterator(object):
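With the PerDeviceDataset changes above, one-shot iterators become eager-only and initializable iterators (backed by MultiDeviceIterator when prefetching) become graph-only. Here is a hedged sketch of what graph-mode consumption now looks like; the devices, toy dataset, and session wiring are assumptions, while the calls themselves mirror the test changes that follow.

# Graph-mode consumption sketch; devices and the toy dataset are assumptions.
from tensorflow.contrib.distribute.python import values
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops

with ops.Graph().as_default():
  devices = ["/device:CPU:0", "/device:GPU:0"]
  dataset = dataset_ops.Dataset.range(8).batch(2)
  per_device_dataset = values.PerDeviceDataset(
      dataset, devices, prefetch_on_device=True)

  # make_one_shot_iterator() would now raise ValueError in graph mode.
  iterator = per_device_dataset.make_initializable_iterator()
  next_element = iterator.get_next()

  with session.Session() as sess:
    sess.run(iterator.initializer)
    first_values = sess.run(
        [values.select_device(d, next_element) for d in devices])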
@@ -349,7 +349,11 @@ class PerDeviceDatasetTest(test.TestCase):
   def _test_iterator_no_prefetch(self, devices, dataset, expected_values):
     per_device_dataset = values.PerDeviceDataset(
         dataset, devices, prefetch_on_device=False)
-    iterator = per_device_dataset.make_one_shot_iterator()
+    if context.executing_eagerly():
+      iterator = per_device_dataset.make_one_shot_iterator()
+    else:
+      iterator = per_device_dataset.make_initializable_iterator()
+      self.evaluate([iterator.initializer])
 
     for expected_value in expected_values:
       next_element = iterator.get_next()
@@ -366,21 +370,14 @@ class PerDeviceDatasetTest(test.TestCase):
     if not context.executing_eagerly():
       per_device_dataset = values.PerDeviceDataset(
           dataset, devices, prefetch_on_device=True)
-      iterator = per_device_dataset.make_one_shot_iterator()
+      iterator = per_device_dataset.make_initializable_iterator()
+      self.evaluate([iterator.initializer])
 
-      # With prefetching, we cannot guarantee which input ends up on which
-      # device, so we verify that the complete set seen on all devices is
-      # correct, and equal numbers are distributed to each device.
-      combined_actual = []
-      combined_expected = []
       for expected_value in expected_values:
         next_element = iterator.get_next()
-        combined_actual.extend(
-            self.evaluate(
-                [values.select_device(d, next_element) for d in devices]))
-        combined_expected.extend(expected_value)
-
-      self.assertEqual(set(combined_expected), set(combined_actual))
+        computed_value = self.evaluate(
+            [values.select_device(d, next_element) for d in devices])
+        self.assertEqual(expected_value, computed_value)
 
       with self.assertRaises(errors.OutOfRangeError):
         next_element = iterator.get_next()
@@ -229,3 +229,15 @@ class MultiDeviceIterator(object):
   @property
   def initializer(self):
     return self._initializer
+
+  @property
+  def output_types(self):
+    return self._dataset.output_types
+
+  @property
+  def output_shapes(self):
+    return self._dataset.output_shapes
+
+  @property
+  def output_classes(self):
+    return self._dataset.output_classes
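The three properties added above let a MultiDeviceIterator answer the same element-structure queries as a plain tf.data iterator, which the distribution-strategy code relies on when wrapping it. A small sketch of what they expose; the toy dataset and device list are assumptions.

# Structure-introspection sketch; the dataset and device list are assumptions.
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.framework import ops

with ops.Graph().as_default():
  dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
  iterator = multi_device_iterator_ops.MultiDeviceIterator(
      dataset, ["/cpu:0"])

  # All three simply forward to the wrapped dataset.
  print(iterator.output_types)    # tf.float32
  print(iterator.output_shapes)   # (1, 1)
  print(iterator.output_classes)  # the Tensor class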