Split input-related classes out of python/distribute/values.py into new file .../input_lib.py.

PiperOrigin-RevId: 227227637
parent e0aa938725
commit d2dd369f9d
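For orientation before reading the diff: the commit only moves the input-pipeline classes (InputWorkers, PerReplicaDataset, MultiWorkerDataset, DatasetIterator, InputFunctionIterator, MultiStepContext) from the `values` module to the new `input_lib` module; `ReplicaDeviceMap` and the other value wrappers stay in `values`. The snippet below is a minimal illustrative sketch of what a call site looks like after the move, assuming the TF 1.x contrib-era API visible in this diff; it is not code from the commit itself.

    # Illustrative only; assumes the contrib-era API shown in this diff.
    from tensorflow.python.data.ops import dataset_ops
    from tensorflow.python.distribute import input_lib  # new home of the input classes
    from tensorflow.python.distribute import values     # still provides ReplicaDeviceMap

    devices = ["/device:CPU:0"]
    device_map = values.ReplicaDeviceMap(devices)
    # Before this change, InputWorkers and DatasetIterator lived in `values`.
    input_workers = input_lib.InputWorkers(device_map, [("", devices)])
    iterator = input_lib.DatasetIterator(dataset_ops.Dataset.range(10), input_workers)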
@@ -23,17 +23,14 @@ cuda_py_test(
     additional_deps = [
         ":combinations",
         ":mirrored_strategy",
-        ":multi_worker_test_base",
         "@absl_py//absl/testing:parameterized",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:constant_op",
-        "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:training",
         "//tensorflow/python:variable_scope",
-        "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/distribute:device_util",
         "//tensorflow/python/distribute:values",
         "//tensorflow/python/eager:context",
@@ -45,14 +42,36 @@ cuda_py_test(
     ],
 )
 
+cuda_py_test(
+    name = "input_lib_test",
+    srcs = ["input_lib_test.py"],
+    additional_deps = [
+        ":combinations",
+        ":mirrored_strategy",
+        ":multi_worker_test_base",
+        "@absl_py//absl/testing:parameterized",
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:errors",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/distribute:input_lib",
+        "//tensorflow/python/distribute:values",
+        "//tensorflow/python/eager:context",
+        "//tensorflow/python/eager:test",
+    ],
+    tags = [
+        "no_pip",
+    ],
+)
+
 py_library(
     name = "mirrored_strategy",
     srcs = ["mirrored_strategy.py"],
     visibility = ["//tensorflow:internal"],
     deps = [
         "//tensorflow/python/distribute:distribute_lib",
+        "//tensorflow/python/distribute:input_lib",
         "//tensorflow/python/distribute:mirrored_strategy",
-        "//tensorflow/python/distribute:values",
     ],
 )
 
@@ -69,6 +88,7 @@ py_library(
         "//tensorflow/python:training",
         "//tensorflow/python:util",
         "//tensorflow/python/distribute:cross_device_ops",
+        "//tensorflow/python/distribute:input_lib",
         "//tensorflow/python/distribute:multi_worker_util",
         "//tensorflow/python/distribute:reduce_util",
         "//tensorflow/python/distribute:values",
@@ -119,6 +139,7 @@ py_library(
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python/distribute:distribute_lib",
+        "//tensorflow/python/distribute:input_lib",
         "//tensorflow/python/distribute:reduce_util",
         "//tensorflow/python/distribute:values",
         "//tensorflow/python/eager:context",
@@ -139,6 +160,7 @@ py_library(
         "//tensorflow/python:training",
         "//tensorflow/python/distribute:cross_device_ops",
         "//tensorflow/python/distribute:cross_device_utils",
+        "//tensorflow/python/distribute:input_lib",
        "//tensorflow/python/distribute:multi_worker_util",
         "//tensorflow/python/distribute:values",
         "//tensorflow/python/eager:context",
@@ -289,6 +311,7 @@ py_library(
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:tensor_util",
         "//tensorflow/python:util",
+        "//tensorflow/python/distribute:input_lib",
         "//tensorflow/python/distribute:reduce_util",
         "//tensorflow/python/distribute:values",
     ],
@@ -26,6 +26,7 @@ from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
 from tensorflow.python.distribute import cross_device_utils
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import multi_worker_util
 from tensorflow.python.distribute import values
 from tensorflow.python.eager import context
@@ -130,7 +131,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
 
     self._collective_keys = cross_device_utils.CollectiveKeys()
     self._initialize_local(local_devices)
-    self._input_workers = values.InputWorkers(
+    self._input_workers = input_lib.InputWorkers(
         self._device_map, [(self._worker_device, self.worker_devices)])
     self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
         num_workers=self._num_workers,
@@ -229,13 +230,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
     """Distributes the dataset to each local GPU."""
     # TODO(yuefengz): shard the dataset.
     worker_index = 0
-    return values.PerReplicaDataset(
+    return input_lib.PerReplicaDataset(
         self._call_dataset_fn(dataset_fn), self._input_workers, worker_index,
         prefetch_on_device=True)
 
   def _make_dataset_iterator(self, dataset):
-    return values.DatasetIterator(dataset, self._input_workers,
+    return input_lib.DatasetIterator(dataset, self._input_workers,
                                   self._num_replicas_in_sync)
 
   def _make_input_fn_iterator(
       self,
@@ -252,7 +253,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
           input_pipeline_id=input_pipeline_id,
           num_replicas_in_sync=self._num_replicas_in_sync)
 
-    return values.InputFunctionIterator(
+    return input_lib.InputFunctionIterator(
         input_fn, self._input_workers, [input_context])
 
   def _configure(self,
tensorflow/contrib/distribute/python/input_lib_test.py (new file, 480 lines)
@@ -0,0 +1,480 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the input_lib library."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.contrib.distribute.python import multi_worker_test_base
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import input_lib
+from tensorflow.python.distribute import values
+from tensorflow.python.eager import context
+from tensorflow.python.eager import test
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.util import nest
+
+
+class PerReplicaDatasetTest(test.TestCase):
+
+  config = config_pb2.ConfigProto()
+  config.allow_soft_placement = True
+
+  def _test_iterator(self, devices, dataset, expected_values):
+    device_map = values.ReplicaDeviceMap(devices)
+    input_workers = input_lib.InputWorkers(device_map)
+    per_replica_dataset = input_lib.PerReplicaDataset(dataset, input_workers, 0)
+    if context.executing_eagerly():
+      iterator = per_replica_dataset.make_one_shot_iterator()
+    else:
+      iterator = per_replica_dataset.make_initializable_iterator()
+      self.evaluate([iterator.initializer])
+
+    for expected_value in expected_values:
+      next_element = iterator.get_next_as_list()
+      computed_value = self.evaluate(next_element)
+      self.assertEqual(expected_value, computed_value)
+
+    with self.assertRaises(errors.OutOfRangeError):
+      next_element = iterator.get_next_as_list()
+      self.evaluate(next_element)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testOneDevice(self):
+    devices = ["/device:CPU:0"]
+    dataset = dataset_ops.Dataset.range(10)
+
+    expected_values = [[i] for i in range(10)]
+
+    self._test_iterator(devices, dataset, expected_values)
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testMultipleDevices(self):
+    if context.num_gpus() < 1 and context.executing_eagerly():
+      self.skipTest("A GPU is not available for this test in eager mode.")
+
+    devices = ["/device:CPU:0", "/device:GPU:0"]
+    dataset = dataset_ops.Dataset.range(10)
+
+    expected_values = [[i, i+1] for i in range(0, 10, 2)]
+
+    self._test_iterator(devices, dataset, expected_values)
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testTupleDataset(self):
+    if context.num_gpus() < 1 and context.executing_eagerly():
+      self.skipTest("A GPU is not available for this test in eager mode.")
+
+    devices = ["/device:CPU:0", "/device:GPU:0"]
+    dataset1 = dataset_ops.Dataset.range(10)
+    dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
+    dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
+
+    expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)]
+
+    self._test_iterator(devices, dataset, expected_values)
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testUnevenDatasetBatches(self):
+    if context.num_gpus() < 1 and context.executing_eagerly():
+      self.skipTest("A GPU is not available for this test in eager mode.")
+
+    devices = ["/device:CPU:0", "/device:GPU:0"]
+    dataset = dataset_ops.Dataset.range(11)
+
+    expected_values = [[i, i+1] for i in range(0, 10, 2)]
+    self._test_iterator(devices, dataset, expected_values)
+
+  def testInitializableIterator(self):
+    with context.graph_mode():
+      devices = ["/device:CPU:0"]
+      # Using random input since that is only allowed with initializable
+      # iterator.
+      dataset = dataset_ops.Dataset.from_tensor_slices(
+          random_ops.random_uniform((10,)))
+
+      device_map = values.ReplicaDeviceMap(devices)
+      input_workers = input_lib.InputWorkers(device_map)
+      per_replica_dataset = input_lib.PerReplicaDataset(
+          dataset, input_workers, 0)
+      iterator = per_replica_dataset.make_initializable_iterator()
+
+      self.evaluate(iterator.initializer)
+      next_element = iterator.get_next_as_list()
+      for _ in range(10):
+        self.evaluate(next_element)
+
+      # Should fail after the input is finished.
+      with self.assertRaises(errors.OutOfRangeError):
+        self.evaluate(next_element)
+
+      # After re-initializing the iterator, should be able to iterate again.
+      self.evaluate(iterator.initializer)
+      for _ in range(10):
+        self.evaluate(next_element)
+
+
+class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
+
+  def _test_iterator(self, sess, iterator, devices, expected_values):
+    next_element = iterator.get_next()
+    for r, device in enumerate(devices):
+      v = values.select_replica(r, next_element)
+      # The `v` here can be a tuple.
+      for element in nest.flatten(v):
+        self.assertTrue(element.device in device)
+
+    for expected_value in expected_values:
+      t = [values.select_replica(r, next_element) for r in range(len(devices))]
+      actual = sess.run(t)
+      self.assertEqual(expected_value, actual)
+
+    with self.assertRaises(errors.OutOfRangeError):
+      sess.run([values.select_replica(r, next_element)
+                for r in range(len(devices))])
+
+  def _test_dataset(self, dataset_fn, worker_devices, devices,
+                    expected_values):
+    device_map = values.ReplicaDeviceMap(devices)
+    input_workers = input_lib.InputWorkers(device_map, worker_devices)
+    multi_worker_dataset = input_lib.MultiWorkerDataset(
+        dataset_fn, input_workers)
+    multi_worker_iterator = multi_worker_dataset.make_initializable_iterator()
+    with self.cached_session() as sess:
+      sess.run(multi_worker_iterator.initializer)
+      self._test_iterator(sess, multi_worker_iterator, devices, expected_values)
+
+  def _cpu_devices(self):
+    worker_devices = (
+        ("/job:worker/replica:0/task:0",
+         ["/job:worker/replica:0/task:0/device:CPU:0"]),
+        ("/job:worker/replica:0/task:1",
+         ["/job:worker/replica:0/task:1/device:CPU:0"])
+    )
+    devices = [
+        "/job:worker/replica:0/task:0/device:CPU:0",
+        "/job:worker/replica:0/task:1/device:CPU:0"
+    ]
+    return worker_devices, devices
+
+  def _cpu_and_one_gpu_devices(self):
+    worker_devices = (
+        ("/job:worker/replica:0/task:0", (
+            "/job:worker/replica:0/task:0/device:GPU:0",
+            "/job:worker/replica:0/task:0/device:CPU:0"
+        )),
+        ("/job:worker/replica:0/task:1", (
+            "/job:worker/replica:0/task:1/device:GPU:0",
+            "/job:worker/replica:0/task:1/device:CPU:0"
+        ))
+    )
+    devices = [
+        "/job:worker/replica:0/task:0/device:GPU:0",
+        "/job:worker/replica:0/task:0/device:CPU:0",
+        "/job:worker/replica:0/task:1/device:GPU:0",
+        "/job:worker/replica:0/task:1/device:CPU:0"
+    ]
+    return worker_devices, devices
+
+  def testDataDistributionOneDevicePerWorker(self):
+    worker_devices, devices = self._cpu_devices()
+    with context.graph_mode():
+      dataset_fn = lambda: dataset_ops.Dataset.range(8)
+      self._test_dataset(
+          dataset_fn, worker_devices, devices,
+          [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]])
+
+  def testDataDistributionTwoDevicePerWorker(self):
+    if context.num_gpus() < 1:
+      self.skipTest("A GPU is not available for this test.")
+    worker_devices, devices = self._cpu_and_one_gpu_devices()
+    with context.graph_mode():
+      dataset_fn = lambda: dataset_ops.Dataset.range(8)
+      self._test_dataset(
+          dataset_fn, worker_devices, devices,
+          [[0, 1, 0, 1], [2, 3, 2, 3], [4, 5, 4, 5], [6, 7, 6, 7]])
+
+  def testTupleDataset(self):
+    worker_devices, devices = self._cpu_devices()
+
+    with context.graph_mode():
+
+      def dataset_fn():
+        dataset1 = dataset_ops.Dataset.range(8)
+        dataset2 = dataset_ops.Dataset.range(8).map(lambda x: x**2)
+        return dataset_ops.Dataset.zip((dataset1, dataset2))
+
+      expected_values = [[(i, i**2), (i, i**2)] for i in range(8)]
+      self._test_dataset(dataset_fn, worker_devices, devices,
+                         expected_values)
+
+  def testInitializableIterator(self):
+    worker_devices, devices = self._cpu_devices()
+    with context.graph_mode(), self.cached_session() as sess:
+      dataset_fn = lambda: dataset_ops.Dataset.range(8)
+      device_map = values.ReplicaDeviceMap(devices)
+      input_workers = input_lib.InputWorkers(device_map, worker_devices)
+      multi_worker_dataset = input_lib.MultiWorkerDataset(
+          dataset_fn, input_workers)
+      multi_worker_iterator = multi_worker_dataset.make_initializable_iterator()
+
+      sess.run(multi_worker_iterator.initializer)
+      self._test_iterator(
+          sess, multi_worker_iterator, devices,
+          [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]])
+
+      # After re-initializing the iterator, should be able to iterate again.
+      sess.run(multi_worker_iterator.initializer)
+      self._test_iterator(
+          sess, multi_worker_iterator, devices,
+          [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]])
+
+  def testValueErrorForIterator(self):
+    # Incompatiable arguments.
+    d1 = "/device:GPU:0"
+    d2 = "/device:GPU:1"
+    device_map = values.ReplicaDeviceMap([d1, d2])
+    input_workers = input_lib.InputWorkers(
+        device_map, (("w1", (d1,)), ("w2", (d2,))))
+    with self.assertRaises(ValueError):
+      input_lib.MultiWorkerDataIterator([("w1", None)], input_workers)
+
+  def testDuplicateDevices(self):
+    _, devices = self._cpu_devices()
+    devices.append("/job:worker/replica:0/task:0/device:CPU:0")
+    with self.assertRaises(ValueError):
+      _ = values.ReplicaDeviceMap(devices)
+
+
+class InputIteratorTestBase(test.TestCase):
+
+  def _test_iterator(self, input_type, dataset_fn, worker_device_pairs,
+                     expected_values, sess=None, split_batch_by=None):
+    devices = nest.flatten([ds for _, ds in worker_device_pairs])
+    device_map = values.ReplicaDeviceMap(devices)
+    input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)
+
+    if input_type == "input_fn":
+      input_contexts = [
+          distribute_lib.InputContext() for _ in worker_device_pairs]
+      input_fn = lambda _: dataset_fn()
+      iterator = input_lib.InputFunctionIterator(
+          input_fn, input_workers, input_contexts)
+    else:
+      iterator = input_lib.DatasetIterator(
+          dataset_fn(), input_workers, split_batch_by)
+
+    evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
+
+    evaluate(control_flow_ops.group(iterator.initialize()))
+
+    for expected_value in expected_values:
+      next_element = iterator.get_next()
+      computed_value = evaluate(
+          [values.select_replica(r, next_element) for r in range(len(devices))])
+      self.assertAllEqual(expected_value, computed_value)
+
+    with self.assertRaises(errors.OutOfRangeError):
+      next_element = iterator.get_next()
+      evaluate([values.select_replica(r, next_element)
+                for r in range(len(devices))])
+
+    # After re-initializing the iterator, should be able to iterate again.
+    evaluate(control_flow_ops.group(iterator.initialize()))
+
+    for expected_value in expected_values:
+      next_element = iterator.get_next()
+      computed_value = evaluate(
+          [values.select_replica(r, next_element) for r in range(len(devices))])
+      self.assertAllEqual(expected_value, computed_value)
+
+
+class InputIteratorSingleWorkerTest(InputIteratorTestBase,
+                                    parameterized.TestCase):
+
+  @combinations.generate(combinations.combine(
+      mode=["graph", "eager"],
+      input_type=["input_fn", "dataset"]))
+  def testOneDeviceCPU(self, input_type):
+    worker_device_pairs = [("", ["/device:CPU:0"])]
+    dataset_fn = lambda: dataset_ops.Dataset.range(10)
+
+    expected_values = [[i] for i in range(10)]
+
+    self._test_iterator(input_type, dataset_fn, worker_device_pairs,
+                        expected_values)
+
+  @combinations.generate(combinations.combine(
+      mode=["graph", "eager"],
+      input_type=["input_fn", "dataset"],
+      required_gpus=1))
+  def testTwoDevicesOneGPUOneCPU(self, input_type):
+    worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
+    dataset_fn = lambda: dataset_ops.Dataset.range(10)
+
+    expected_values = [[i, i+1] for i in range(0, 10, 2)]
+
+    self._test_iterator(input_type, dataset_fn, worker_device_pairs,
+                        expected_values)
+
+  @combinations.generate(combinations.combine(
+      mode=["graph", "eager"],
+      input_type=["input_fn", "dataset"],
+      required_gpus=1))
+  def testTupleDataset(self, input_type):
+    worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
+    def dataset_fn():
+      dataset1 = dataset_ops.Dataset.range(10)
+      dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
+      return dataset_ops.Dataset.zip((dataset1, dataset2))
+
+    expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)]
+
+    self._test_iterator(input_type, dataset_fn, worker_device_pairs,
+                        expected_values)
+
+  @combinations.generate(combinations.combine(
+      mode=["graph", "eager"],
+      input_type=["input_fn", "dataset"],
+      required_gpus=1))
+  def testUnevenDatasetBatches(self, input_type):
+    worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
+    dataset_fn = lambda: dataset_ops.Dataset.range(11)
+
+    expected_values = [[i, i+1] for i in range(0, 10, 2)]
+    self._test_iterator(input_type, dataset_fn, worker_device_pairs,
+                        expected_values)
+
+  @combinations.generate(combinations.combine(
+      mode=["graph", "eager"],
+      input_type=["dataset"],
+      split_batch_by=[None, 2],
+      required_gpus=1))
+  def testBatchSplitting(self, input_type, split_batch_by):
+    worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
+    batch_size = 10
+    dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size)
+
+    updated_batch_size = (
+        batch_size // split_batch_by if split_batch_by else batch_size)
+    expected_values = [[range(i, i+updated_batch_size),
+                        range(i+updated_batch_size, i+2*updated_batch_size)]
+                       for i in range(0, 100, updated_batch_size*2)]
+
+    self._test_iterator(input_type, dataset_fn, worker_device_pairs,
+                        expected_values, sess=None,
+                        split_batch_by=split_batch_by)
+
+
+class InputIteratorMultiWorkerTest(
+    multi_worker_test_base.MultiWorkerTestBase, InputIteratorTestBase,
+    parameterized.TestCase):
+
+  def _cpu_devices(self):
+    return [
+        ("/job:worker/replica:0/task:0",
+         ["/job:worker/replica:0/task:0/device:CPU:0"]),
+        ("/job:worker/replica:0/task:1",
+         ["/job:worker/replica:0/task:1/device:CPU:0"])]
+
+  def _cpu_and_one_gpu_devices(self):
+    return [
+        ("/job:worker/replica:0/task:0", [
+            "/job:worker/replica:0/task:0/device:GPU:0",
+            "/job:worker/replica:0/task:0/device:CPU:0"
+        ]),
+        ("/job:worker/replica:0/task:1", [
+            "/job:worker/replica:0/task:1/device:GPU:0",
+            "/job:worker/replica:0/task:1/device:CPU:0"
+        ])
+    ]
+
+  @combinations.generate(combinations.combine(
+      mode=["graph"],
+      input_type=["input_fn", "dataset"]))
+  def testOneDevicePerWorker(self, input_type):
+    worker_devices = self._cpu_devices()
+    with context.graph_mode(), self.cached_session() as sess:
+      dataset_fn = lambda: dataset_ops.Dataset.range(4)
+      self._test_iterator(input_type, dataset_fn, worker_devices,
+                          [[0, 0], [1, 1], [2, 2], [3, 3]], sess)
+
+  @combinations.generate(combinations.combine(
+      mode=["graph"],
+      input_type=["input_fn", "dataset"],
+      required_gpus=1))
+  def testTwoDevicesPerWorker(self, input_type):
+    worker_devices = self._cpu_and_one_gpu_devices()
+    with context.graph_mode(), self.cached_session() as sess:
+      dataset_fn = lambda: dataset_ops.Dataset.range(4)
+      self._test_iterator(input_type, dataset_fn, worker_devices,
+                          [[0, 1, 0, 1], [2, 3, 2, 3]], sess)
+
+  @combinations.generate(combinations.combine(
+      mode=["graph"],
+      input_type=["input_fn", "dataset"]))
+  def testTupleDataset(self, input_type):
+    worker_devices = self._cpu_devices()
+    with context.graph_mode(), self.cached_session() as sess:
+      def dataset_fn():
+        dataset1 = dataset_ops.Dataset.range(4)
+        dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2)
+        return dataset_ops.Dataset.zip((dataset1, dataset2))
+
+      expected_values = [[(i, i**2), (i, i**2)] for i in range(0, 4)]
+      self._test_iterator(input_type, dataset_fn, worker_devices,
+                          expected_values, sess)
+
+
+class SplitDatasetBatchTest(test.TestCase):
+
+  def testBatchDataset(self):
+    dataset = dataset_ops.Dataset.range(100).batch(20)
+    split_batch_by = 2
+    result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by)
+    expected_values = [range(i, i+10) for i in range(0, 100, 10)]
+    result = [self.evaluate(el) for el in result_dataset]
+    self.assertAllEqual(expected_values, result)
+
+  def testMapAndBatchDataset(self):
+    dataset = dataset_ops.Dataset.range(100)
+    dataset = dataset.apply(batching.map_and_batch(lambda x: x, 20))
+    split_batch_by = 2
+    result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by)
+    expected_values = [range(i, i+10) for i in range(0, 100, 10)]
+    result = [self.evaluate(el) for el in result_dataset]
+    self.assertAllEqual(expected_values, result)
+
+  def testPrefetchDataset(self):
+    dataset = dataset_ops.Dataset.range(100).batch(20).prefetch(1)
+    split_batch_by = 2
+    result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by)
+    expected_values = [range(i, i+10) for i in range(0, 100, 10)]
+    result = [self.evaluate(el) for el in result_dataset]
+    self.assertAllEqual(expected_values, result)
+
+
+if __name__ == "__main__":
+  test.main()
@@ -21,8 +21,8 @@ from __future__ import print_function
 import functools
 
 from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import mirrored_strategy
-from tensorflow.python.distribute import values
 
 
 # pylint: disable=protected-access,invalid-name
@@ -135,14 +135,14 @@ class MirroredExtended(CoreMirroredExtended):
     Returns:
       An `InputIterator` which returns inputs for each step of the computation.
     """
-    return values.DatasetIterator(dataset, self._input_workers)
+    return input_lib.DatasetIterator(dataset, self._input_workers)
 
   def _distribute_dataset(self, dataset_fn):
     if self._local_mode:
-      return values.PerReplicaDataset(
+      return input_lib.PerReplicaDataset(
          self._call_dataset_fn(dataset_fn), self._input_workers, 0)
     else:
-      return values.MultiWorkerDataset(
+      return input_lib.MultiWorkerDataset(
          functools.partial(self._call_dataset_fn, dataset_fn),
          self._input_workers,
          auto_shard=self._auto_shard_dataset)
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import values
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -52,7 +53,8 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended):
     worker = device_util.canonicalize("/device:CPU:0")
     worker_device_pairs = [(worker, [self._device])]
     device_map = values.SingleDeviceMap(device)
-    self._input_workers = values.InputWorkers(device_map, worker_device_pairs)
+    self._input_workers = input_lib.InputWorkers(
+        device_map, worker_device_pairs)
 
   def _create_variable(self, next_creator, *args, **kwargs):
     colocate_with = kwargs.pop("colocate_with", None)
@@ -67,17 +69,17 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended):
 
   def _make_dataset_iterator(self, dataset):
     """Make iterator from dataset without splitting the batch."""
-    return values.DatasetIterator(dataset, self._input_workers)
+    return input_lib.DatasetIterator(dataset, self._input_workers)
 
   def _distribute_dataset(self, dataset_fn):
-    return values.PerReplicaDataset(
+    return input_lib.PerReplicaDataset(
         self._call_dataset_fn(dataset_fn), self._input_workers, 0)
 
   def _make_input_fn_iterator(
       self,
       input_fn,
       replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
-    return values.InputFunctionIterator(
+    return input_lib.InputFunctionIterator(
         input_fn, self._input_workers, [distribute_lib.InputContext()])
 
   def _broadcast_to(self, tensor, destinations):
@@ -91,7 +93,7 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended):
       initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)
 
-    ctx = values.MultiStepContext()
+    ctx = input_lib.MultiStepContext()
    def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
      del args
@@ -24,6 +24,7 @@ from tensorflow.contrib.distribute.python import mirrored_strategy
 from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import multi_worker_util
 from tensorflow.python.distribute import values
 from tensorflow.python.eager import context
@@ -153,7 +154,7 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended):
       compute_devices = (worker_device,)
 
     self._device_map = values.ReplicaDeviceMap(compute_devices)
-    self._input_workers = values.InputWorkers(
+    self._input_workers = input_lib.InputWorkers(
         self._device_map, [(worker_device, compute_devices)])
 
     # In distributed mode, place variables on ps jobs in a round-robin fashion.
@@ -210,7 +211,7 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended):
       compute_devices = (_LOCAL_CPU,)
 
     self._device_map = values.ReplicaDeviceMap(compute_devices)
-    self._input_workers = values.InputWorkers(
+    self._input_workers = input_lib.InputWorkers(
        self._device_map, [(worker_device, compute_devices)])
 
    # If there is only one GPU, put everything on that GPU. Otherwise, place
@@ -237,13 +238,13 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended):
 
   def _distribute_dataset(self, dataset_fn):
     """Distributes the dataset to each local GPU."""
-    return values.PerReplicaDataset(
+    return input_lib.PerReplicaDataset(
         self._call_dataset_fn(dataset_fn), self._input_workers, 0,
         prefetch_on_device=True)
 
   def _make_dataset_iterator(self, dataset):
-    return values.DatasetIterator(dataset, self._input_workers,
+    return input_lib.DatasetIterator(dataset, self._input_workers,
                                   self._num_replicas_in_sync)
 
   def _make_input_fn_iterator(
       self,
@@ -262,7 +263,7 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended):
         num_input_pipelines=num_input_pipelines,
         input_pipeline_id=input_pipeline_id,
         num_replicas_in_sync=self._num_replicas_in_sync)
-    return values.InputFunctionIterator(
+    return input_lib.InputFunctionIterator(
        input_fn, self._input_workers, [input_context])
 
   def _broadcast_to(self, tensor, destinations):
@@ -33,6 +33,7 @@ from tensorflow.python.client import session as session_lib
 from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import values
 from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver_lib
@@ -204,7 +205,8 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
         (self.get_host(hid), [self.get_host_cpu_device(hid)])
         for hid in range(self.num_hosts)
     ]
-    self._input_workers = values.InputWorkers(input_device_map, worker_devices)
+    self._input_workers = input_lib.InputWorkers(
+        input_device_map, worker_devices)
 
     # TODO(sourabhbajaj): Remove this once performance of running one step
     # at a time is comparable to multiple steps.
@@ -304,11 +306,11 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
   def _make_dataset_iterator(self, dataset):
     """Make iterators for each of the TPU hosts."""
 
-    return values.DatasetIterator(dataset, self._input_workers,
+    return input_lib.DatasetIterator(dataset, self._input_workers,
                                   self._num_replicas_in_sync)
 
   def _distribute_dataset(self, dataset_fn):
-    return values.MultiWorkerDataset(
+    return input_lib.MultiWorkerDataset(
        functools.partial(self._call_dataset_fn, dataset_fn),
        self._input_workers)
 
@@ -339,7 +341,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
     if initial_loop_values is None:
       initial_loop_values = {}
     initial_loop_values = nest.flatten(initial_loop_values)
-    ctx = values.MultiStepContext()
+    ctx = input_lib.MultiStepContext()
 
     def run_fn(*args, **kwargs):
       """Single step on the TPU device."""
@@ -22,28 +22,20 @@ import os
 from absl.testing import parameterized
 
 from tensorflow.contrib.distribute.python import combinations
-from tensorflow.contrib.distribute.python import multi_worker_test_base
 from tensorflow.core.protobuf import config_pb2
-from tensorflow.python.data.experimental.ops import batching
-from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import device_util
-from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import values
 from tensorflow.python.eager import context
 from tensorflow.python.eager import test
 from tensorflow.python.estimator import model_fn as model_fn_lib
 from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.training import saver as saver_lib
-from tensorflow.python.util import nest
 
 
 class DistributedValuesTest(test.TestCase):
@@ -354,444 +346,6 @@ class RegroupAndSelectDeviceTest(test.TestCase):
                        merged_estimator_spec))
 
 
-class PerReplicaDatasetTest(test.TestCase):
-
-  config = config_pb2.ConfigProto()
-  config.allow_soft_placement = True
-
-  def _test_iterator(self, devices, dataset, expected_values):
-    device_map = values.ReplicaDeviceMap(devices)
-    input_workers = values.InputWorkers(device_map)
-    per_replica_dataset = values.PerReplicaDataset(dataset, input_workers, 0)
-    if context.executing_eagerly():
-      iterator = per_replica_dataset.make_one_shot_iterator()
-    else:
-      iterator = per_replica_dataset.make_initializable_iterator()
-      self.evaluate([iterator.initializer])
-
-    for expected_value in expected_values:
-      next_element = iterator.get_next_as_list()
-      computed_value = self.evaluate(next_element)
-      self.assertEqual(expected_value, computed_value)
-
-    with self.assertRaises(errors.OutOfRangeError):
-      next_element = iterator.get_next_as_list()
-      self.evaluate(next_element)
-
-  @test_util.run_in_graph_and_eager_modes
-  def testOneDevice(self):
-    devices = ["/device:CPU:0"]
-    dataset = dataset_ops.Dataset.range(10)
-
-    expected_values = [[i] for i in range(10)]
-
-    self._test_iterator(devices, dataset, expected_values)
-
-  @test_util.run_in_graph_and_eager_modes(config=config)
-  def testMultipleDevices(self):
-    if context.num_gpus() < 1 and context.executing_eagerly():
-      self.skipTest("A GPU is not available for this test in eager mode.")
-
-    devices = ["/device:CPU:0", "/device:GPU:0"]
-    dataset = dataset_ops.Dataset.range(10)
-
-    expected_values = [[i, i+1] for i in range(0, 10, 2)]
-
-    self._test_iterator(devices, dataset, expected_values)
-
-  @test_util.run_in_graph_and_eager_modes(config=config)
-  def testTupleDataset(self):
-    if context.num_gpus() < 1 and context.executing_eagerly():
-      self.skipTest("A GPU is not available for this test in eager mode.")
-
-    devices = ["/device:CPU:0", "/device:GPU:0"]
-    dataset1 = dataset_ops.Dataset.range(10)
-    dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
-    dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
-
-    expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)]
-
-    self._test_iterator(devices, dataset, expected_values)
-
-  @test_util.run_in_graph_and_eager_modes(config=config)
-  def testUnevenDatasetBatches(self):
-    if context.num_gpus() < 1 and context.executing_eagerly():
-      self.skipTest("A GPU is not available for this test in eager mode.")
-
-    devices = ["/device:CPU:0", "/device:GPU:0"]
-    dataset = dataset_ops.Dataset.range(11)
-
-    expected_values = [[i, i+1] for i in range(0, 10, 2)]
-    self._test_iterator(devices, dataset, expected_values)
-
-  def testInitializableIterator(self):
-    with context.graph_mode():
-      devices = ["/device:CPU:0"]
-      # Using random input since that is only allowed with initializable
-      # iterator.
-      dataset = dataset_ops.Dataset.from_tensor_slices(
-          random_ops.random_uniform((10,)))
-
-      device_map = values.ReplicaDeviceMap(devices)
-      input_workers = values.InputWorkers(device_map)
-      per_replica_dataset = values.PerReplicaDataset(dataset, input_workers, 0)
-      iterator = per_replica_dataset.make_initializable_iterator()
-
-      self.evaluate(iterator.initializer)
-      next_element = iterator.get_next_as_list()
-      for _ in range(10):
-        self.evaluate(next_element)
-
-      # Should fail after the input is finished.
-      with self.assertRaises(errors.OutOfRangeError):
-        self.evaluate(next_element)
-
-      # After re-initializing the iterator, should be able to iterate again.
-      self.evaluate(iterator.initializer)
-      for _ in range(10):
-        self.evaluate(next_element)
-
-
-class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
-
-  def _test_iterator(self, sess, iterator, devices, expected_values):
-    next_element = iterator.get_next()
-    for r, device in enumerate(devices):
-      v = values.select_replica(r, next_element)
-      # The `v` here can be a tuple.
-      for element in nest.flatten(v):
-        self.assertTrue(element.device in device)
-
-    for expected_value in expected_values:
-      t = [values.select_replica(r, next_element) for r in range(len(devices))]
-      actual = sess.run(t)
-      self.assertEqual(expected_value, actual)
-
-    with self.assertRaises(errors.OutOfRangeError):
-      sess.run([values.select_replica(r, next_element)
-                for r in range(len(devices))])
-
-  def _test_dataset(self, dataset_fn, worker_devices, devices,
-                    expected_values):
-    device_map = values.ReplicaDeviceMap(devices)
-    input_workers = values.InputWorkers(device_map, worker_devices)
-    multi_worker_dataset = values.MultiWorkerDataset(
-        dataset_fn, input_workers)
-    multi_worker_iterator = multi_worker_dataset.make_initializable_iterator()
-    with self.cached_session() as sess:
-      sess.run(multi_worker_iterator.initializer)
-      self._test_iterator(sess, multi_worker_iterator, devices, expected_values)
-
-  def _cpu_devices(self):
-    worker_devices = (
-        ("/job:worker/replica:0/task:0",
-         ["/job:worker/replica:0/task:0/device:CPU:0"]),
-        ("/job:worker/replica:0/task:1",
-         ["/job:worker/replica:0/task:1/device:CPU:0"])
-    )
-    devices = [
-        "/job:worker/replica:0/task:0/device:CPU:0",
-        "/job:worker/replica:0/task:1/device:CPU:0"
-    ]
-    return worker_devices, devices
-
-  def _cpu_and_one_gpu_devices(self):
-    worker_devices = (
-        ("/job:worker/replica:0/task:0", (
-            "/job:worker/replica:0/task:0/device:GPU:0",
-            "/job:worker/replica:0/task:0/device:CPU:0"
-        )),
-        ("/job:worker/replica:0/task:1", (
-            "/job:worker/replica:0/task:1/device:GPU:0",
-            "/job:worker/replica:0/task:1/device:CPU:0"
-        ))
-    )
-    devices = [
-        "/job:worker/replica:0/task:0/device:GPU:0",
-        "/job:worker/replica:0/task:0/device:CPU:0",
-        "/job:worker/replica:0/task:1/device:GPU:0",
-        "/job:worker/replica:0/task:1/device:CPU:0"
-    ]
-    return worker_devices, devices
-
-  def testDataDistributionOneDevicePerWorker(self):
-    worker_devices, devices = self._cpu_devices()
-    with context.graph_mode():
-      dataset_fn = lambda: dataset_ops.Dataset.range(8)
-      self._test_dataset(
-          dataset_fn, worker_devices, devices,
-          [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]])
-
-  def testDataDistributionTwoDevicePerWorker(self):
-    if context.num_gpus() < 1:
-      self.skipTest("A GPU is not available for this test.")
-    worker_devices, devices = self._cpu_and_one_gpu_devices()
-    with context.graph_mode():
-      dataset_fn = lambda: dataset_ops.Dataset.range(8)
-      self._test_dataset(
-          dataset_fn, worker_devices, devices,
-          [[0, 1, 0, 1], [2, 3, 2, 3], [4, 5, 4, 5], [6, 7, 6, 7]])
-
-  def testTupleDataset(self):
-    worker_devices, devices = self._cpu_devices()
-
-    with context.graph_mode():
-
-      def dataset_fn():
-        dataset1 = dataset_ops.Dataset.range(8)
-        dataset2 = dataset_ops.Dataset.range(8).map(lambda x: x**2)
-        return dataset_ops.Dataset.zip((dataset1, dataset2))
-
-      expected_values = [[(i, i**2), (i, i**2)] for i in range(8)]
-      self._test_dataset(dataset_fn, worker_devices, devices,
-                         expected_values)
-
-  def testInitializableIterator(self):
-    worker_devices, devices = self._cpu_devices()
-    with context.graph_mode(), self.cached_session() as sess:
-      dataset_fn = lambda: dataset_ops.Dataset.range(8)
-      device_map = values.ReplicaDeviceMap(devices)
-      input_workers = values.InputWorkers(device_map, worker_devices)
-      multi_worker_dataset = values.MultiWorkerDataset(
-          dataset_fn, input_workers)
-      multi_worker_iterator = multi_worker_dataset.make_initializable_iterator()
-
-      sess.run(multi_worker_iterator.initializer)
-      self._test_iterator(
-          sess, multi_worker_iterator, devices,
-          [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]])
-
-      # After re-initializing the iterator, should be able to iterate again.
-      sess.run(multi_worker_iterator.initializer)
-      self._test_iterator(
-          sess, multi_worker_iterator, devices,
-          [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]])
-
-  def testValueErrorForIterator(self):
-    # Incompatiable arguments.
-    d1 = "/device:GPU:0"
-    d2 = "/device:GPU:1"
-    device_map = values.ReplicaDeviceMap([d1, d2])
-    input_workers = values.InputWorkers(
-        device_map, (("w1", (d1,)), ("w2", (d2,))))
-    with self.assertRaises(ValueError):
-      values.MultiWorkerDataIterator([("w1", None)], input_workers)
-
-  def testDuplicateDevices(self):
-    _, devices = self._cpu_devices()
-    devices.append("/job:worker/replica:0/task:0/device:CPU:0")
-    with self.assertRaises(ValueError):
-      _ = values.ReplicaDeviceMap(devices)
-
-
class InputIteratorTestBase(test.TestCase):

  def _test_iterator(self, input_type, dataset_fn, worker_device_pairs,
                     expected_values, sess=None, split_batch_by=None):
    devices = nest.flatten([ds for _, ds in worker_device_pairs])
    device_map = values.ReplicaDeviceMap(devices)
    input_workers = values.InputWorkers(device_map, worker_device_pairs)

    if input_type == "input_fn":
      input_contexts = [
          distribute_lib.InputContext() for _ in worker_device_pairs]
      input_fn = lambda _: dataset_fn()
      iterator = values.InputFunctionIterator(
          input_fn, input_workers, input_contexts)
    else:
      iterator = values.DatasetIterator(
          dataset_fn(), input_workers, split_batch_by)

    evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)

    evaluate(control_flow_ops.group(iterator.initialize()))

    for expected_value in expected_values:
      next_element = iterator.get_next()
      computed_value = evaluate(
          [values.select_replica(r, next_element) for r in range(len(devices))])
      self.assertAllEqual(expected_value, computed_value)

    with self.assertRaises(errors.OutOfRangeError):
      next_element = iterator.get_next()
      evaluate([values.select_replica(r, next_element)
                for r in range(len(devices))])

    # After re-initializing the iterator, should be able to iterate again.
    evaluate(control_flow_ops.group(iterator.initialize()))

    for expected_value in expected_values:
      next_element = iterator.get_next()
      computed_value = evaluate(
          [values.select_replica(r, next_element) for r in range(len(devices))])
      self.assertAllEqual(expected_value, computed_value)

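# Illustrative sketch only (not part of this change): how the two `input_type`
# paths exercised by `_test_iterator` above would be built directly. The
# helper name `_example_build_iterators` is hypothetical; `dataset_fn` is
# assumed to return a `tf.data.Dataset`, as in the tests.
def _example_build_iterators(dataset_fn, input_workers):
  # "input_fn" path: one InputContext per input worker.
  input_contexts = [distribute_lib.InputContext()
                    for _ in input_workers.worker_devices]
  fn_iterator = values.InputFunctionIterator(
      lambda _: dataset_fn(), input_workers, input_contexts)
  # "dataset" path: a single dataset, optionally re-batched per replica.
  ds_iterator = values.DatasetIterator(dataset_fn(), input_workers)
  return fn_iterator, ds_iterator
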
class InputIteratorSingleWorkerTest(InputIteratorTestBase,
                                    parameterized.TestCase):

  @combinations.generate(combinations.combine(
      mode=["graph", "eager"],
      input_type=["input_fn", "dataset"]))
  def testOneDeviceCPU(self, input_type):
    worker_device_pairs = [("", ["/device:CPU:0"])]
    dataset_fn = lambda: dataset_ops.Dataset.range(10)

    expected_values = [[i] for i in range(10)]

    self._test_iterator(input_type, dataset_fn, worker_device_pairs,
                        expected_values)

  @combinations.generate(combinations.combine(
      mode=["graph", "eager"],
      input_type=["input_fn", "dataset"],
      required_gpus=1))
  def testTwoDevicesOneGPUOneCPU(self, input_type):
    worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
    dataset_fn = lambda: dataset_ops.Dataset.range(10)

    expected_values = [[i, i+1] for i in range(0, 10, 2)]

    self._test_iterator(input_type, dataset_fn, worker_device_pairs,
                        expected_values)

  @combinations.generate(combinations.combine(
      mode=["graph", "eager"],
      input_type=["input_fn", "dataset"],
      required_gpus=1))
  def testTupleDataset(self, input_type):
    worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
    def dataset_fn():
      dataset1 = dataset_ops.Dataset.range(10)
      dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
      return dataset_ops.Dataset.zip((dataset1, dataset2))

    expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)]

    self._test_iterator(input_type, dataset_fn, worker_device_pairs,
                        expected_values)

  @combinations.generate(combinations.combine(
      mode=["graph", "eager"],
      input_type=["input_fn", "dataset"],
      required_gpus=1))
  def testUnevenDatasetBatches(self, input_type):
    worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
    dataset_fn = lambda: dataset_ops.Dataset.range(11)

    expected_values = [[i, i+1] for i in range(0, 10, 2)]
    self._test_iterator(input_type, dataset_fn, worker_device_pairs,
                        expected_values)

  @combinations.generate(combinations.combine(
      mode=["graph", "eager"],
      input_type=["dataset"],
      split_batch_by=[None, 2],
      required_gpus=1))
  def testBatchSplitting(self, input_type, split_batch_by):
    worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
    batch_size = 10
    dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size)

    updated_batch_size = (
        batch_size // split_batch_by if split_batch_by else batch_size)
    expected_values = [[range(i, i+updated_batch_size),
                        range(i+updated_batch_size, i+2*updated_batch_size)]
                       for i in range(0, 100, updated_batch_size*2)]

    self._test_iterator(input_type, dataset_fn, worker_device_pairs,
                        expected_values, sess=None,
                        split_batch_by=split_batch_by)

class InputIteratorMultiWorkerTest(
    multi_worker_test_base.MultiWorkerTestBase, InputIteratorTestBase,
    parameterized.TestCase):

  def _cpu_devices(self):
    return [
        ("/job:worker/replica:0/task:0",
         ["/job:worker/replica:0/task:0/device:CPU:0"]),
        ("/job:worker/replica:0/task:1",
         ["/job:worker/replica:0/task:1/device:CPU:0"])]

  def _cpu_and_one_gpu_devices(self):
    return [
        ("/job:worker/replica:0/task:0", [
            "/job:worker/replica:0/task:0/device:GPU:0",
            "/job:worker/replica:0/task:0/device:CPU:0"
        ]),
        ("/job:worker/replica:0/task:1", [
            "/job:worker/replica:0/task:1/device:GPU:0",
            "/job:worker/replica:0/task:1/device:CPU:0"
        ])
    ]

  @combinations.generate(combinations.combine(
      mode=["graph"],
      input_type=["input_fn", "dataset"]))
  def testOneDevicePerWorker(self, input_type):
    worker_devices = self._cpu_devices()
    with context.graph_mode(), self.cached_session() as sess:
      dataset_fn = lambda: dataset_ops.Dataset.range(4)
      self._test_iterator(input_type, dataset_fn, worker_devices,
                          [[0, 0], [1, 1], [2, 2], [3, 3]], sess)

  @combinations.generate(combinations.combine(
      mode=["graph"],
      input_type=["input_fn", "dataset"],
      required_gpus=1))
  def testTwoDevicesPerWorker(self, input_type):
    worker_devices = self._cpu_and_one_gpu_devices()
    with context.graph_mode(), self.cached_session() as sess:
      dataset_fn = lambda: dataset_ops.Dataset.range(4)
      self._test_iterator(input_type, dataset_fn, worker_devices,
                          [[0, 1, 0, 1], [2, 3, 2, 3]], sess)

  @combinations.generate(combinations.combine(
      mode=["graph"],
      input_type=["input_fn", "dataset"]))
  def testTupleDataset(self, input_type):
    worker_devices = self._cpu_devices()
    with context.graph_mode(), self.cached_session() as sess:
      def dataset_fn():
        dataset1 = dataset_ops.Dataset.range(4)
        dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2)
        return dataset_ops.Dataset.zip((dataset1, dataset2))

      expected_values = [[(i, i**2), (i, i**2)] for i in range(0, 4)]
      self._test_iterator(input_type, dataset_fn, worker_devices,
                          expected_values, sess)

class SplitDatasetBatchTest(test.TestCase):

  def testBatchDataset(self):
    dataset = dataset_ops.Dataset.range(100).batch(20)
    split_batch_by = 2
    result_dataset = values._split_dataset_batch(dataset, split_batch_by)
    expected_values = [range(i, i+10) for i in range(0, 100, 10)]
    result = [self.evaluate(el) for el in result_dataset]
    self.assertAllEqual(expected_values, result)

  def testMapAndBatchDataset(self):
    dataset = dataset_ops.Dataset.range(100)
    dataset = dataset.apply(batching.map_and_batch(lambda x: x, 20))
    split_batch_by = 2
    result_dataset = values._split_dataset_batch(dataset, split_batch_by)
    expected_values = [range(i, i+10) for i in range(0, 100, 10)]
    result = [self.evaluate(el) for el in result_dataset]
    self.assertAllEqual(expected_values, result)

  def testPrefetchDataset(self):
    dataset = dataset_ops.Dataset.range(100).batch(20).prefetch(1)
    split_batch_by = 2
    result_dataset = values._split_dataset_batch(dataset, split_batch_by)
    expected_values = [range(i, i+10) for i in range(0, 100, 10)]
    result = [self.evaluate(el) for el in result_dataset]
    self.assertAllEqual(expected_values, result)

class MirroredVariableTest(test.TestCase, parameterized.TestCase):

  config = config_pb2.ConfigProto()

@@ -219,6 +219,7 @@ py_library(
         ":cross_device_ops",
         ":device_util",
         ":distribute_lib",
+        ":input_lib",
         ":multi_worker_util",
         ":reduce_util",
         ":shared_variable_creator",
@@ -253,6 +254,23 @@ py_library(
     ],
 )
 
+py_library(
+    name = "input_lib",
+    srcs = ["input_lib.py"],
+    deps = [
+        ":device_util",
+        ":distribute_lib",
+        ":input_ops",
+        ":values",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:util",
+        "//tensorflow/python/data/ops:multi_device_iterator_ops",
+        "//tensorflow/python/eager:context",
+    ],
+)
+
 py_library(
     name = "input_ops",
     srcs = ["input_ops.py"],
@@ -348,14 +366,12 @@ py_library(
     deps = [
         ":device_util",
         ":distribute_lib",
-        ":input_ops",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:training",
         "//tensorflow/python:util",
-        "//tensorflow/python/data/ops:multi_device_iterator_ops",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/training/checkpointable:base",
         "@six_archive//:six",

707 tensorflow/python/distribute/input_lib.py Normal file
@@ -0,0 +1,707 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed inputs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_ops
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.util import nest


class InputWorkers(object):
  """A 1-to-many mapping from input worker devices to compute devices."""

  def __init__(self, device_map, worker_device_pairs=None, logical_device=0):
    """Initialize an `InputWorkers` object.

    Args:
      device_map: A `DeviceMap` with the computation devices fed by the
        input workers.
      worker_device_pairs: A sequence of pairs:
        `(input device, a tuple of compute devices fed by that input device)`.
      logical_device: The logical device of `device_map` to feed.
    """
    self._device_map = device_map
    self._logical_device = logical_device
    if worker_device_pairs is None:
      worker_device_pairs = ((
          device_util.canonicalize("/device:CPU:0"),
          device_map.logical_to_actual_devices(logical_device)),)
    self._input_worker_devices = tuple(d for d, _ in worker_device_pairs)
    self._fed_devices = tuple(tuple(device_util.canonicalize(d) for d in f)
                              for _, f in worker_device_pairs)
    flattened = tuple(d for l in self._fed_devices for d in l)
    assert (flattened ==
            device_map.logical_to_actual_devices(logical_device)), (
                "flattened: %s logical device %d: %s" %
                (flattened, logical_device,
                 device_map.logical_to_actual_devices(logical_device)))

  @property
  def device_map(self):
    return self._device_map

  @property
  def logical_device(self):
    return self._logical_device

  @property
  def num_workers(self):
    return len(self._input_worker_devices)

  @property
  def worker_devices(self):
    return self._input_worker_devices

  def compute_devices_for_worker(self, worker_index):
    return self._fed_devices[worker_index]

  def __repr__(self):
    devices = self.worker_devices
    debug_repr = ",\n".join(" %d %s: %s" %
                            (i, devices[i], self._fed_devices[i])
                            for i in range(len(devices)))
    return "%s:{\n%s\n device_map: %s}" % (
        self.__class__.__name__, debug_repr, self._device_map)

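# Illustrative sketch only (not part of this commit): constructing an
# `InputWorkers` that feeds two replica devices from the default local CPU
# input worker, then inspecting the mapping. `values.ReplicaDeviceMap` is the
# `DeviceMap` implementation used elsewhere in this change; the helper name
# below is hypothetical.
def _example_input_workers():
  device_map = values.ReplicaDeviceMap(["/device:GPU:0", "/device:CPU:0"])
  workers = InputWorkers(device_map)  # defaults to one local CPU feeder
  assert workers.num_workers == 1
  return workers.compute_devices_for_worker(0)  # the two replica devices
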
class PerReplicaDataIterator(object):
|
||||||
|
"""An iterator (like `tf.data.Iterator`) into a `PerReplicaDataset`."""
|
||||||
|
|
||||||
|
def __init__(self, iterator, input_workers, worker_index, prefetch_on_device):
|
||||||
|
assert isinstance(input_workers, InputWorkers)
|
||||||
|
self._iterator = iterator
|
||||||
|
self._input_workers = input_workers
|
||||||
|
self._worker_index = worker_index
|
||||||
|
self._prefetch_on_device = prefetch_on_device
|
||||||
|
|
||||||
|
@property
|
||||||
|
def initializer(self):
|
||||||
|
return self._iterator.initializer
|
||||||
|
|
||||||
|
def get_next_as_list(self, name=None):
|
||||||
|
"""Scatter the input across devices."""
|
||||||
|
if self._prefetch_on_device:
|
||||||
|
data_list = self._iterator.get_next()
|
||||||
|
else:
|
||||||
|
batch = self._iterator.get_next(name=name)
|
||||||
|
data_list = []
|
||||||
|
def get_ith(i):
|
||||||
|
return lambda x: x[i]
|
||||||
|
|
||||||
|
devices = self._input_workers.compute_devices_for_worker(
|
||||||
|
self._worker_index)
|
||||||
|
for i, d in enumerate(devices):
|
||||||
|
v = nest.map_structure(get_ith(i), batch)
|
||||||
|
if context.executing_eagerly():
|
||||||
|
with ops.device(d):
|
||||||
|
v = nest.map_structure(array_ops.identity, v)
|
||||||
|
data_list.append(v)
|
||||||
|
|
||||||
|
return data_list
|
||||||
|
|
||||||
|
def get_next(self, name=None):
|
||||||
|
assert self._input_workers.num_workers == 1
|
||||||
|
data_list = self.get_next_as_list(name)
|
||||||
|
return values.regroup(self._input_workers.device_map, data_list)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_classes(self):
|
||||||
|
return self._iterator.output_classes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_shapes(self):
|
||||||
|
return self._iterator.output_shapes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_types(self):
|
||||||
|
return self._iterator.output_types
|
||||||
|
|
||||||
|
|
||||||
|
class PerReplicaDataset(object):
|
||||||
|
"""Like `tf.data.Dataset` split devices, producing `PerReplica` data."""
|
||||||
|
|
||||||
|
def __init__(self, dataset, input_workers, worker_index,
|
||||||
|
prefetch_on_device=None):
|
||||||
|
assert isinstance(input_workers, InputWorkers)
|
||||||
|
assert worker_index is not None
|
||||||
|
assert worker_index is not True # pylint: disable=g-bool-id-comparison
|
||||||
|
assert worker_index is not False # pylint: disable=g-bool-id-comparison
|
||||||
|
self._input_workers = input_workers
|
||||||
|
self._worker_index = worker_index
|
||||||
|
|
||||||
|
# Default to using prefetching, unless specified.
|
||||||
|
self._prefetch_on_device = prefetch_on_device
|
||||||
|
if self._prefetch_on_device is None:
|
||||||
|
self._prefetch_on_device = True
|
||||||
|
|
||||||
|
self._dataset = dataset
|
||||||
|
if not self._prefetch_on_device:
|
||||||
|
# TODO(priyag): If dropping remainder is not appropriate, find another
|
||||||
|
# approach to distributing the dataset when not possible to divide evenly.
|
||||||
|
# Possibly not an issue when we start using PartitionedDataset.
|
||||||
|
num_replicas = len(
|
||||||
|
self._input_workers.compute_devices_for_worker(self._worker_index))
|
||||||
|
self._dataset = self._dataset.batch(num_replicas, drop_remainder=True)
|
||||||
|
else:
|
||||||
|
self._replica_devices = self._input_workers.compute_devices_for_worker(
|
||||||
|
self._worker_index)
|
||||||
|
|
||||||
|
def make_one_shot_iterator(self):
|
||||||
|
"""Get a one time use iterator for the distributed PerReplicaDataset."""
|
||||||
|
# Graph mode with one shot iterator is disabled.
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
raise ValueError("Cannot create a one shot iterator. Please use "
|
||||||
|
"`make_initializable_iterator()` instead.")
|
||||||
|
if self._prefetch_on_device:
|
||||||
|
dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
|
||||||
|
self._dataset, self._replica_devices)
|
||||||
|
else:
|
||||||
|
dataset_iterator = dataset_ops.make_one_shot_iterator(self._dataset)
|
||||||
|
return PerReplicaDataIterator(
|
||||||
|
dataset_iterator,
|
||||||
|
self._input_workers,
|
||||||
|
self._worker_index,
|
||||||
|
prefetch_on_device=self._prefetch_on_device)
|
||||||
|
|
||||||
|
def make_initializable_iterator(self):
|
||||||
|
"""Get an initializable iterator for the distributed PerReplicaDataset."""
|
||||||
|
# Eager mode generates already initialized iterators. Hence we cannot create
|
||||||
|
# an initializable iterator.
|
||||||
|
if context.executing_eagerly():
|
||||||
|
raise ValueError("Cannot create initializable iterator in Eager mode. "
|
||||||
|
"Please use `make_one_shot_iterator` instead.")
|
||||||
|
if self._prefetch_on_device:
|
||||||
|
dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
|
||||||
|
self._dataset, self._replica_devices)
|
||||||
|
else:
|
||||||
|
dataset_iterator = dataset_ops.make_initializable_iterator(self._dataset)
|
||||||
|
return PerReplicaDataIterator(
|
||||||
|
dataset_iterator, self._input_workers, self._worker_index,
|
||||||
|
prefetch_on_device=self._prefetch_on_device)
|
||||||
|
|
||||||
|
|
||||||
|
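# Illustrative sketch only (not part of this commit): wrapping one worker's
# dataset in a `PerReplicaDataset` and iterating it in graph mode. The
# `input_workers` argument is assumed to be an `InputWorkers` covering a
# single worker; the helper name is hypothetical.
def _example_per_replica_dataset(input_workers):
  dataset = dataset_ops.Dataset.range(8)
  per_replica = PerReplicaDataset(dataset, input_workers, worker_index=0,
                                  prefetch_on_device=False)
  iterator = per_replica.make_initializable_iterator()  # graph mode only
  return iterator.initializer, iterator.get_next()
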
class MultiWorkerDataIterator(object):
|
||||||
|
"""An iterator (like `tf.data.Iterator`) into a `MultiWorkerDataset`."""
|
||||||
|
|
||||||
|
def __init__(self, iterators, input_workers):
|
||||||
|
"""Initialize the `MultiWorkerDataIterator` object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
iterators: a list of worker, iterator pairs.
|
||||||
|
input_workers: an `InputWorkers` object.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if iterators and input_workers are not compatible.
|
||||||
|
"""
|
||||||
|
assert isinstance(input_workers, InputWorkers)
|
||||||
|
workers = tuple(d for d, _ in iterators)
|
||||||
|
if workers != input_workers.worker_devices:
|
||||||
|
raise ValueError("iterators and input_workers are not compatible. "
|
||||||
|
"iterator workers: %r input_workers devices: %r" %
|
||||||
|
(workers, input_workers.worker_devices))
|
||||||
|
self._iterators = tuple(i for _, i in iterators)
|
||||||
|
self._input_workers = input_workers
|
||||||
|
|
||||||
|
@property
|
||||||
|
def initializer(self):
|
||||||
|
return control_flow_ops.group(
|
||||||
|
tuple(iterator.initializer for iterator in self._iterators))
|
||||||
|
|
||||||
|
def get_iterator(self, worker):
|
||||||
|
for i, w in enumerate(self._input_workers.worker_devices):
|
||||||
|
if worker == w:
|
||||||
|
return self._iterators[i]
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_shapes(self):
|
||||||
|
return self._iterators[0].output_shapes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_types(self):
|
||||||
|
return self._iterators[0].output_types
|
||||||
|
|
||||||
|
def get_next(self, name=None):
|
||||||
|
"""Scatter the input across hosts and devices."""
|
||||||
|
replicas = []
|
||||||
|
for worker, iterator in zip(self._input_workers.worker_devices,
|
||||||
|
self._iterators):
|
||||||
|
if name is not None:
|
||||||
|
d = tf_device.DeviceSpec.from_string(worker)
|
||||||
|
new_name = "%s_%s_%d" % (name, d.job, d.task)
|
||||||
|
else:
|
||||||
|
new_name = None
|
||||||
|
with ops.device(worker):
|
||||||
|
data_per_worker = iterator.get_next_as_list(name=new_name)
|
||||||
|
# Append to replicas to get a flat list of values indexed by replica.
|
||||||
|
replicas.extend(data_per_worker)
|
||||||
|
|
||||||
|
return values.regroup(self._input_workers.device_map, replicas)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiWorkerDataset(object):
|
||||||
|
"""Like a `tf.data.Dataset` that distributes data to different workers.
|
||||||
|
|
||||||
|
Each worker gets one shard of the input dataset. This currently does not work
|
||||||
|
in eager mode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dataset_fn, input_workers, prefetch_on_device=None,
|
||||||
|
auto_shard=False):
|
||||||
|
"""Initialize the MultiWorkerDataset object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_fn: a function or a list of functions that returns a
|
||||||
|
`tf.data.Dataset`.
|
||||||
|
input_workers: an `InputWorkers` object.
|
||||||
|
prefetch_on_device: whether to prefetch to devices.
|
||||||
|
auto_shard: whether to auto-shard the dataset.
|
||||||
|
"""
|
||||||
|
assert isinstance(input_workers, InputWorkers)
|
||||||
|
if isinstance(dataset_fn, (list, tuple)):
|
||||||
|
if len(dataset_fn) != input_workers.num_workers:
|
||||||
|
raise ValueError("If `dataset_fn` is a list, it must have one entry "
|
||||||
|
"per worker")
|
||||||
|
# TODO(rohanj): b/120673685 to track re-enabling auto sharding.
|
||||||
|
if auto_shard:
|
||||||
|
raise ValueError("Currently autosharding is not supported.")
|
||||||
|
self._input_workers = input_workers
|
||||||
|
self._datasets = []
|
||||||
|
# TODO(yuefengz, priyag): support different set of jobs for input
|
||||||
|
# processing.
|
||||||
|
for i, worker in enumerate(input_workers.worker_devices):
|
||||||
|
with ops.device(worker):
|
||||||
|
if isinstance(dataset_fn, (list, tuple)):
|
||||||
|
worker_input = dataset_fn[i]()
|
||||||
|
else:
|
||||||
|
worker_input = dataset_fn()
|
||||||
|
dataset = PerReplicaDataset(worker_input, input_workers, i,
|
||||||
|
prefetch_on_device=prefetch_on_device)
|
||||||
|
self._datasets.append((worker, dataset))
|
||||||
|
|
||||||
|
def make_one_shot_iterator(self):
|
||||||
|
iterators = []
|
||||||
|
for worker, dataset in self._datasets:
|
||||||
|
with ops.device(worker):
|
||||||
|
iterators.append((worker, dataset_ops.make_one_shot_iterator(dataset)))
|
||||||
|
return MultiWorkerDataIterator(iterators, self._input_workers)
|
||||||
|
|
||||||
|
def make_initializable_iterator(self):
|
||||||
|
iterators = []
|
||||||
|
for worker, dataset in self._datasets:
|
||||||
|
with ops.device(worker):
|
||||||
|
iterators.append(
|
||||||
|
(worker, dataset_ops.make_initializable_iterator(dataset)))
|
||||||
|
return MultiWorkerDataIterator(iterators, self._input_workers)
|
||||||
|
|
||||||
|
|
||||||
|
class InputIterator(object):
|
||||||
|
"""An input iterator, intended to be passed to `DistributionStrategy.run`."""
|
||||||
|
|
||||||
|
def get_next(self):
|
||||||
|
"""Returns the next inputs for all replicas."""
|
||||||
|
raise NotImplementedError("must be implemented in descendants")
|
||||||
|
|
||||||
|
def initialize(self):
|
||||||
|
"""Initialize the underlying input dataset, when applicable.
|
||||||
|
|
||||||
|
In eager mode, this will create a new iterator and return it.
|
||||||
|
In graph mode, this will initialize the same underlying iterator(s).
|
||||||
|
|
||||||
|
Users are required to call this if
|
||||||
|
- This iterator was returned from a call to `make_input_fn_iterator` with an
|
||||||
|
input function that returns a dataset.
|
||||||
|
- Or this iterator was returned from a call to `make_dataset_iterator`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of initialization ops to be executed.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("must be implemented in descendants")
|
||||||
|
|
||||||
|
|
||||||
|
class InputIteratorImpl(InputIterator):
|
||||||
|
"""Common implementation for all input iterators."""
|
||||||
|
|
||||||
|
def __init__(self, input_workers, iterators):
|
||||||
|
assert isinstance(input_workers, InputWorkers)
|
||||||
|
if not input_workers.worker_devices:
|
||||||
|
raise ValueError("Should have at least one worker for input iterator.")
|
||||||
|
|
||||||
|
self._iterators = iterators
|
||||||
|
self._input_workers = input_workers
|
||||||
|
|
||||||
|
def get_next(self, name=None):
|
||||||
|
"""Returns the next input from the iterator for all replicas."""
|
||||||
|
replicas = []
|
||||||
|
for i, worker in enumerate(self._input_workers.worker_devices):
|
||||||
|
if name is not None:
|
||||||
|
d = tf_device.DeviceSpec.from_string(worker)
|
||||||
|
new_name = "%s_%s_%d" % (name, d.job, d.task)
|
||||||
|
else:
|
||||||
|
new_name = None
|
||||||
|
with ops.device(worker):
|
||||||
|
# Make `replicas` a flat list of values across all replicas.
|
||||||
|
replicas.extend(self._iterators[i].get_next_as_list(new_name))
|
||||||
|
|
||||||
|
return values.regroup(self._input_workers.device_map, replicas)
|
||||||
|
|
||||||
|
def initialize(self):
|
||||||
|
"""Initialze underlying iterators.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of any initializer ops that should be run.
|
||||||
|
"""
|
||||||
|
init_ops = []
|
||||||
|
for it in self._iterators:
|
||||||
|
init_ops.extend(it.initialize())
|
||||||
|
return init_ops
|
||||||
|
|
||||||
|
# TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
|
||||||
|
@property
|
||||||
|
def output_classes(self):
|
||||||
|
return self._iterators[0].output_classes
|
||||||
|
|
||||||
|
# TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
|
||||||
|
@property
|
||||||
|
def output_shapes(self):
|
||||||
|
return self._iterators[0].output_shapes
|
||||||
|
|
||||||
|
# TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
|
||||||
|
@property
|
||||||
|
def output_types(self):
|
||||||
|
return self._iterators[0].output_types
|
||||||
|
|
||||||
|
# TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
|
||||||
|
def get_iterator(self, worker):
|
||||||
|
for i, w in enumerate(self._input_workers.worker_devices):
|
||||||
|
if worker == w:
|
||||||
|
return self._iterators[i]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class InputFunctionIterator(InputIteratorImpl):
|
||||||
|
"""Iterator created from input function."""
|
||||||
|
|
||||||
|
def __init__(self, input_fn, input_workers, input_contexts):
|
||||||
|
"""Make an iterator for input provided via an input function.
|
||||||
|
|
||||||
|
Currently implements PER_WORKER mode, in which the `input_fn` is called
|
||||||
|
once on each worker.
|
||||||
|
|
||||||
|
TODO(priyag): Add other replication modes.
|
||||||
|
TODO(priyag): Allow taking input function that returns a callable that
|
||||||
|
returns nest of tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_fn: Input function that returns a `tf.data.Dataset` object.
|
||||||
|
input_workers: an `InputWorkers` object.
|
||||||
|
input_contexts: A list of `InputContext` instances to be passed to call(s)
|
||||||
|
to `input_fn`. Length and order should match worker order in
|
||||||
|
`worker_device_pairs`.
|
||||||
|
"""
|
||||||
|
assert isinstance(input_workers, InputWorkers)
|
||||||
|
if input_workers.num_workers != len(input_contexts):
|
||||||
|
raise ValueError(
|
||||||
|
"Number of input workers (%d) is not same as number of "
|
||||||
|
"input_contexts (%d)" %
|
||||||
|
(input_workers.num_workers, len(input_contexts)))
|
||||||
|
|
||||||
|
iterators = []
|
||||||
|
for i, ctx in enumerate(input_contexts):
|
||||||
|
worker = input_workers.worker_devices[i]
|
||||||
|
with ops.device(worker):
|
||||||
|
result = input_fn(ctx)
|
||||||
|
if not isinstance(result, dataset_ops.DatasetV2):
|
||||||
|
raise ValueError("input_fn must return a tf.data.Dataset.")
|
||||||
|
devices = input_workers.compute_devices_for_worker(i)
|
||||||
|
iterator = _SingleWorkerDatasetIterator(result, worker, devices)
|
||||||
|
iterators.append(iterator)
|
||||||
|
|
||||||
|
super(InputFunctionIterator, self).__init__(input_workers, iterators)
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetIterator(InputIteratorImpl):
|
||||||
|
"""Iterator created from input dataset."""
|
||||||
|
|
||||||
|
def __init__(self, dataset, input_workers, split_batch_by=None):
|
||||||
|
"""Make an iterator for the dataset on given devices.
|
||||||
|
|
||||||
|
If `split_batch_by` is not None, we "split" each batch of the
|
||||||
|
dataset by `split_batch_by` value. To achieve this, we first unbatch the
|
||||||
|
input dataset and then rebatch it with the per replica batch size that is
|
||||||
|
calculated using `global_batch_size // split_batch_by`.
|
||||||
|
The currently supported datasets are as follows:
|
||||||
|
`dataset.batch()` is the last operation on the dataset OR
|
||||||
|
`dataset.apply(map_and_batch)` is the last operation on the dataset OR
|
||||||
|
`dataset.batch().prefetch()` are the last 2 operations on the dataset OR
|
||||||
|
`dataset.apply(map_and_batch).prefetch()` are the last 2 operations.
|
||||||
|
|
||||||
|
TODO(priyag): Support multi worker / host cases properly by cloning
|
||||||
|
and sharding the dataset on each worker. Current setup will only work in
|
||||||
|
some cases, such as in-graph multi worker GPU case. If the input pipeline
|
||||||
|
has random shuffling (with a different seed on each worker), each worker
|
||||||
|
will see random input from the same overall dataset in each step. Otherwise,
|
||||||
|
each worker will see the same input in each step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: `tf.data.Dataset` that will be used as the input source.
|
||||||
|
input_workers: an `InputWorkers` object.
|
||||||
|
split_batch_by: Optional integer. If present, we "split" each batch of the
|
||||||
|
dataset by `split_batch_by` value.
|
||||||
|
"""
|
||||||
|
assert isinstance(input_workers, InputWorkers)
|
||||||
|
if split_batch_by:
|
||||||
|
dataset = _split_dataset_batch(dataset, split_batch_by)
|
||||||
|
|
||||||
|
iterators = []
|
||||||
|
for i, worker in enumerate(input_workers.worker_devices):
|
||||||
|
with ops.device(worker):
|
||||||
|
worker_devices = input_workers.compute_devices_for_worker(i)
|
||||||
|
cloned_dataset = dataset
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
cloned_dataset = input_ops._clone_dataset(dataset) # pylint: disable=protected-access
|
||||||
|
iterator = _SingleWorkerDatasetIterator(cloned_dataset, worker,
|
||||||
|
worker_devices)
|
||||||
|
iterators.append(iterator)
|
||||||
|
|
||||||
|
super(DatasetIterator, self).__init__(input_workers, iterators)
|
||||||
|
|
||||||
|
|
||||||
|
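# Illustrative sketch only (not part of this commit): splitting a globally
# batched dataset across replicas via `split_batch_by`, as described in the
# docstring above. With a global batch of 10 and split_batch_by=2, each
# replica sees batches of 5. The helper name is hypothetical.
def _example_dataset_iterator(input_workers):
  dataset = dataset_ops.Dataset.range(100).batch(10)
  return DatasetIterator(dataset, input_workers, split_batch_by=2)
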
class _SingleWorkerDatasetIterator(object):
|
||||||
|
"""Iterator for a single `tf.data.Dataset`."""
|
||||||
|
|
||||||
|
def __init__(self, dataset, worker, devices):
|
||||||
|
"""Create iterator for the `dataset` to fetch data to worker's `devices` .
|
||||||
|
|
||||||
|
`MultiDeviceIterator` is used to prefetch input to the devices on the
|
||||||
|
given worker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: A `tf.data.Dataset` instance.
|
||||||
|
worker: Worker on which ops should be created.
|
||||||
|
devices: Distribute data from `dataset` to these devices.
|
||||||
|
"""
|
||||||
|
self._dataset = dataset
|
||||||
|
self._worker = worker
|
||||||
|
self._devices = devices
|
||||||
|
self._make_iterator()
|
||||||
|
|
||||||
|
def _make_iterator(self):
|
||||||
|
"""Make appropriate iterator on the dataset."""
|
||||||
|
with ops.device(self._worker):
|
||||||
|
self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
|
||||||
|
self._dataset, self._devices)
|
||||||
|
|
||||||
|
def get_next_as_list(self, name=None):
|
||||||
|
"""Get next element from the underlying iterator."""
|
||||||
|
del name
|
||||||
|
with ops.device(self._worker):
|
||||||
|
data_list = self._iterator.get_next()
|
||||||
|
return data_list
|
||||||
|
|
||||||
|
def initialize(self):
|
||||||
|
"""Initialze underlying iterator.
|
||||||
|
|
||||||
|
In eager execution, this simply recreates the underlying iterator.
|
||||||
|
In graph execution, it returns the initializer ops for the underlying
|
||||||
|
iterator.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of any initializer ops that should be run.
|
||||||
|
"""
|
||||||
|
if context.executing_eagerly():
|
||||||
|
self._make_iterator()
|
||||||
|
return []
|
||||||
|
else:
|
||||||
|
return [self._iterator.initializer]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_classes(self):
|
||||||
|
return self._iterator.output_classes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_shapes(self):
|
||||||
|
return self._iterator.output_shapes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_types(self):
|
||||||
|
return self._iterator.output_types
|
||||||
|
|
||||||
|
|
||||||
|
def _split_dataset_batch(dataset, split_batch_by):
|
||||||
|
"""Divide a batch-ed dataset's batches into smaller batches."""
|
||||||
|
# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
def _get_batch_dataset(d):
|
||||||
|
"""Get the underlying batch dataset from the dataset object."""
|
||||||
|
if isinstance(d, dataset_ops.DatasetV1Adapter):
|
||||||
|
d = d._dataset
|
||||||
|
|
||||||
|
if isinstance(d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset)):
|
||||||
|
return d
|
||||||
|
elif isinstance(d, dataset_ops.PrefetchDataset):
|
||||||
|
return _get_batch_dataset(d._input_dataset)
|
||||||
|
raise ValueError(
|
||||||
|
"Unable to get batched dataset from the input dataset. `batch` "
|
||||||
|
"`map_and_batch` need to be the last operations on the dataset. "
|
||||||
|
"The batch operations can be followed by a prefetch.")
|
||||||
|
|
||||||
|
batched_dataset = _get_batch_dataset(dataset)
|
||||||
|
if isinstance(batched_dataset, dataset_ops.BatchDataset):
|
||||||
|
batch_size = batched_dataset._batch_size
|
||||||
|
drop_remainder = batched_dataset._drop_remainder
|
||||||
|
elif isinstance(batched_dataset, batching._MapAndBatchDataset):
|
||||||
|
batch_size = batched_dataset._batch_size_t
|
||||||
|
drop_remainder = batched_dataset._drop_remainder_t
|
||||||
|
|
||||||
|
prefetch_buffer = None
|
||||||
|
if isinstance(dataset, dataset_ops.PrefetchDataset):
|
||||||
|
prefetch_buffer = dataset._buffer_size
|
||||||
|
elif (isinstance(dataset, dataset_ops.DatasetV1Adapter)
|
||||||
|
and isinstance(dataset._dataset, dataset_ops.PrefetchDataset)):
|
||||||
|
prefetch_buffer = dataset._dataset._buffer_size
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
if tensor_util.is_tensor(batch_size):
|
||||||
|
batch_size = tensor_util.constant_value(batch_size)
|
||||||
|
|
||||||
|
if tensor_util.is_tensor(drop_remainder):
|
||||||
|
drop_remainder = tensor_util.constant_value(drop_remainder)
|
||||||
|
|
||||||
|
if batch_size % split_batch_by:
|
||||||
|
raise ValueError(
|
||||||
|
"Batch size %s cannot be sharded evenly across replicas %s" % (
|
||||||
|
batch_size, split_batch_by))
|
||||||
|
new_batch_size = batch_size // split_batch_by
|
||||||
|
|
||||||
|
dataset = dataset.apply(batching.unbatch())
|
||||||
|
dataset = dataset.batch(new_batch_size, drop_remainder=drop_remainder)
|
||||||
|
if prefetch_buffer is not None:
|
||||||
|
dataset = dataset.prefetch(prefetch_buffer)
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
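# Illustrative sketch only (not part of this commit): the rebatching performed
# by `_split_dataset_batch` above. A dataset batched by 20 and split by 2
# yields batches of 10, with any trailing prefetch preserved. The helper name
# is hypothetical.
def _example_split_batch():
  dataset = dataset_ops.Dataset.range(100).batch(20).prefetch(1)
  return _split_dataset_batch(dataset, split_batch_by=2)  # batches of 10
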
class MultiStepContext(object):
|
||||||
|
"""A context object that can be used to capture things when running steps.
|
||||||
|
|
||||||
|
This context object is useful when running multiple steps at a time using the
|
||||||
|
`experimental_run_steps_on_iterator` API. For example, it allows the user's step
|
||||||
|
function to specify which outputs to emit at what frequency. Currently it
|
||||||
|
supports capturing output from the last step, as well as capturing non tensor
|
||||||
|
outputs. In the future it will be augmented to support other use cases such
|
||||||
|
as output each N steps.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize an output context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A context object.
|
||||||
|
"""
|
||||||
|
self._last_step_outputs = {}
|
||||||
|
self._last_step_outputs_reduce_ops = {}
|
||||||
|
self._non_tensor_outputs = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def last_step_outputs(self):
|
||||||
|
"""A dictionary consisting of outputs to be captured on last step.
|
||||||
|
|
||||||
|
Keys in the dictionary are names of tensors to be captured, as specified
|
||||||
|
when `set_last_step_output` is called.
|
||||||
|
Values in the dictionary are the tensors themselves. If
|
||||||
|
`set_last_step_output` was called with a `reduce_op` for this output,
|
||||||
|
then the value is the reduced value.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary with last step outputs.
|
||||||
|
"""
|
||||||
|
return self._last_step_outputs
|
||||||
|
|
||||||
|
def _set_last_step_outputs(self, outputs):
|
||||||
|
"""Replace the entire dictionary of last step outputs."""
|
||||||
|
if not isinstance(outputs, dict):
|
||||||
|
raise ValueError("Need a dictionary to set last_step_outputs.")
|
||||||
|
self._last_step_outputs = outputs
|
||||||
|
|
||||||
|
def set_last_step_output(self, name, output, reduce_op=None):
|
||||||
|
"""Set `output` with `name` to be outputted from the last step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: String, name to identify the output. Doesn't need to match tensor
|
||||||
|
name.
|
||||||
|
output: The tensors that should be outputted with `name`. See below for
|
||||||
|
actual types supported.
|
||||||
|
reduce_op: Reduction method to use to reduce outputs from multiple
|
||||||
|
replicas. Required if `set_last_step_output` is called in a replica
|
||||||
|
context. Optional in cross_replica_context.
|
||||||
|
When present, the outputs from all the replicas are reduced using the
|
||||||
|
current distribution strategy's `reduce` method. Hence, the type of
|
||||||
|
`output` must be what's supported by the corresponding `reduce` method.
|
||||||
|
For example, if using MirroredStrategy and reduction is set, output
|
||||||
|
must be a `PerReplica` value.
|
||||||
|
The reduce method is also recorded in a dictionary
|
||||||
|
`_last_step_outputs_reduce_ops` for later interpreting of the
|
||||||
|
outputs as already reduced or not.
|
||||||
|
"""
|
||||||
|
if distribution_strategy_context.in_cross_replica_context():
|
||||||
|
self._last_step_outputs_reduce_ops[name] = reduce_op
|
||||||
|
if reduce_op is None:
|
||||||
|
self._last_step_outputs[name] = output
|
||||||
|
else:
|
||||||
|
distribution = distribution_strategy_context.get_distribution_strategy()
|
||||||
|
self._last_step_outputs[name] = distribution.reduce(reduce_op, output)
|
||||||
|
else:
|
||||||
|
assert reduce_op is not None
|
||||||
|
def merge_fn(distribution, value):
|
||||||
|
self._last_step_outputs[name] = distribution.reduce(reduce_op, value)
|
||||||
|
# Setting this inside the `merge_fn` because all replicas share the same
|
||||||
|
# context object, so it's more robust to set it only once (even if all
|
||||||
|
# the replicas are trying to set the same value).
|
||||||
|
self._last_step_outputs_reduce_ops[name] = reduce_op
|
||||||
|
|
||||||
|
distribution_strategy_context.get_replica_context().merge_call(
|
||||||
|
merge_fn, args=(output,))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def non_tensor_outputs(self):
|
||||||
|
"""A dictionary consisting of any non tensor outputs to be captured."""
|
||||||
|
return self._non_tensor_outputs
|
||||||
|
|
||||||
|
def set_non_tensor_output(self, name, output):
|
||||||
|
"""Set `output` with `name` to be captured as a non tensor output."""
|
||||||
|
if distribution_strategy_context.in_cross_replica_context():
|
||||||
|
self._non_tensor_outputs[name] = output
|
||||||
|
else:
|
||||||
|
def merge_fn(distribution, value):
|
||||||
|
# NOTE(priyag): For non tensor outputs, we simply return all the values
|
||||||
|
# in a list as reduction doesn't make sense on non tensors.
|
||||||
|
self._non_tensor_outputs[name] = distribution.unwrap(value)
|
||||||
|
distribution_strategy_context.get_replica_context().merge_call(
|
||||||
|
merge_fn, args=(output,))
|
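# Illustrative sketch only (not part of this commit): a step function that
# records outputs through `MultiStepContext` when driven by
# `experimental_run_steps_on_iterator`. `reduce_util.ReduceOp` is referenced
# for illustration and is not imported in this module; the loss computation
# is a stand-in, not a real training step.
def _example_step_fn(ctx, inputs):
  loss = inputs  # stand-in for a real per-replica loss tensor
  ctx.set_last_step_output("loss", loss, reduce_op=reduce_util.ReduceOp.SUM)
  ctx.set_non_tensor_output("logging_enabled", True)
  return []  # ops to group into the training step
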
@@ -27,6 +27,7 @@ from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import multi_worker_util
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import shared_variable_creator
@@ -456,7 +457,7 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
           "No duplicates allowed in `devices` argument: %s" % devices)
     # TODO(josh11b): Require at least 2 devices?
     self._device_map = values.ReplicaDeviceMap(devices)
-    self._input_workers = values.InputWorkers(self._device_map)
+    self._input_workers = input_lib.InputWorkers(self._device_map)
     self._inferred_cross_device_ops = cross_device_ops_lib.choose_the_best(
         devices)
 
@@ -489,7 +490,8 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
       self._default_device = workers[0]
 
     self._device_map = values.ReplicaDeviceMap(devices)
-    self._input_workers = values.InputWorkers(self._device_map, worker_devices)
+    self._input_workers = input_lib.InputWorkers(
+        self._device_map, worker_devices)
     self._inferred_cross_device_ops = cross_device_ops_lib.MultiWorkerAllReduce(
         workers, _infer_num_gpus_per_worker(devices))
 
@@ -543,16 +545,16 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
   def _distribute_dataset(self, dataset_fn):
     if self._local_mode:
       worker_index = 0
-      return values.PerReplicaDataset(
+      return input_lib.PerReplicaDataset(
           self._call_dataset_fn(dataset_fn), self._input_workers, worker_index)
     else:
-      return values.MultiWorkerDataset(
+      return input_lib.MultiWorkerDataset(
           functools.partial(self._call_dataset_fn, dataset_fn),
           self._input_workers,
          auto_shard=False)
 
   def _make_dataset_iterator(self, dataset):
-    return values.DatasetIterator(
+    return input_lib.DatasetIterator(
        dataset, self._input_workers, self._num_replicas_in_sync)
 
   def _make_input_fn_iterator(
@@ -566,7 +568,7 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
             num_input_pipelines=num_workers,
             input_pipeline_id=i,
             num_replicas_in_sync=self._num_replicas_in_sync))
-    return values.InputFunctionIterator(
+    return input_lib.InputFunctionIterator(
        input_fn, self._input_workers, input_contexts)
 
   # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
@@ -576,7 +578,7 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
     initial_loop_values = {}
     initial_loop_values = nest.flatten(initial_loop_values)
 
-    ctx = values.MultiStepContext()
+    ctx = input_lib.MultiStepContext()
     def body(i, *args):
       """A wrapper around `fn` to create the while loop body."""
       del args
@@ -23,17 +23,12 @@ import contextlib
 import weakref
 import six
 
-from tensorflow.python.data.experimental.ops import batching
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import multi_device_iterator_ops
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribution_strategy_context
-from tensorflow.python.distribute import input_ops
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.eager import context
 from tensorflow.python.eager import tape
-from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
@@ -1409,679 +1404,6 @@ def update_regroup(extended, device_map, updates, group):
|
|||||||
return nest.pack_sequence_as(regrouped, grouped_flat)
|
return nest.pack_sequence_as(regrouped, grouped_flat)
|
||||||
|
|
||||||
|
|
||||||
class InputWorkers(object):
|
|
||||||
"""A 1-to-many mapping from input worker devices to compute devices."""
|
|
||||||
|
|
||||||
def __init__(self, device_map, worker_device_pairs=None, logical_device=0):
|
|
||||||
"""Initialize an `InputWorkers` object.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
device_map: A `DeviceMap` with the computation devices fed by the
|
|
||||||
input workers.
|
|
||||||
worker_device_pairs: A sequence of pairs:
|
|
||||||
`(input device, a tuple of compute devices fed by that input device)`.
|
|
||||||
logical_device: The logical device of `device_map` to feed.
|
|
||||||
"""
|
|
||||||
self._device_map = device_map
|
|
||||||
self._logical_device = logical_device
|
|
||||||
if worker_device_pairs is None:
|
|
||||||
worker_device_pairs = ((
|
|
||||||
device_util.canonicalize("/device:CPU:0"),
|
|
||||||
device_map.logical_to_actual_devices(logical_device)),)
|
|
||||||
self._input_worker_devices = tuple(d for d, _ in worker_device_pairs)
|
|
||||||
self._fed_devices = tuple(tuple(device_util.canonicalize(d) for d in f)
|
|
||||||
for _, f in worker_device_pairs)
|
|
||||||
flattened = tuple(d for l in self._fed_devices for d in l)
|
|
||||||
assert (flattened ==
|
|
||||||
device_map.logical_to_actual_devices(logical_device)), (
|
|
||||||
"flattened: %s logical device %d: %s" %
|
|
||||||
(flattened, logical_device,
|
|
||||||
device_map.logical_to_actual_devices(logical_device)))
|
|
||||||
|
|
||||||
@property
|
|
||||||
def device_map(self):
|
|
||||||
return self._device_map
|
|
||||||
|
|
||||||
@property
|
|
||||||
def logical_device(self):
|
|
||||||
return self._logical_device
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_workers(self):
|
|
||||||
return len(self._input_worker_devices)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def worker_devices(self):
|
|
||||||
return self._input_worker_devices
|
|
||||||
|
|
||||||
def compute_devices_for_worker(self, worker_index):
|
|
||||||
return self._fed_devices[worker_index]
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
devices = self.worker_devices
|
|
||||||
debug_repr = ",\n".join(" %d %s: %s" %
|
|
||||||
(i, devices[i], self._fed_devices[i])
|
|
||||||
for i in range(len(devices)))
|
|
||||||
return "%s:{\n%s\n device_map: %s}" % (
|
|
||||||
self.__class__.__name__, debug_repr, self._device_map)
|
|
||||||
|
|
||||||
|
|
||||||
class PerReplicaDataIterator(object):
|
|
||||||
"""An iterator (like `tf.data.Iterator`) into a `PerReplicaDataset`."""
|
|
||||||
|
|
||||||
def __init__(self, iterator, input_workers, worker_index, prefetch_on_device):
|
|
||||||
assert isinstance(input_workers, InputWorkers)
|
|
||||||
self._iterator = iterator
|
|
||||||
self._input_workers = input_workers
|
|
||||||
self._worker_index = worker_index
|
|
||||||
self._prefetch_on_device = prefetch_on_device
|
|
||||||
|
|
||||||
@property
|
|
||||||
def initializer(self):
|
|
||||||
return self._iterator.initializer
|
|
||||||
|
|
||||||
def get_next_as_list(self, name=None):
|
|
||||||
"""Scatter the input across devices."""
|
|
||||||
if self._prefetch_on_device:
|
|
||||||
data_list = self._iterator.get_next()
|
|
||||||
else:
|
|
||||||
batch = self._iterator.get_next(name=name)
|
|
||||||
data_list = []
|
|
||||||
def get_ith(i):
|
|
||||||
return lambda x: x[i]
|
|
||||||
|
|
||||||
devices = self._input_workers.compute_devices_for_worker(
|
|
||||||
self._worker_index)
|
|
||||||
for i, d in enumerate(devices):
|
|
||||||
v = nest.map_structure(get_ith(i), batch)
|
|
||||||
if context.executing_eagerly():
|
|
||||||
with ops.device(d):
|
|
||||||
v = nest.map_structure(array_ops.identity, v)
|
|
||||||
data_list.append(v)
|
|
||||||
|
|
||||||
return data_list
|
|
||||||
|
|
||||||
def get_next(self, name=None):
|
|
||||||
assert self._input_workers.num_workers == 1
|
|
||||||
data_list = self.get_next_as_list(name)
|
|
||||||
return regroup(self._input_workers.device_map, data_list)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def output_classes(self):
|
|
||||||
return self._iterator.output_classes
|
|
||||||
|
|
||||||
@property
|
|
||||||
def output_shapes(self):
|
|
||||||
return self._iterator.output_shapes
|
|
||||||
|
|
||||||
@property
|
|
||||||
def output_types(self):
|
|
||||||
return self._iterator.output_types
|
|
||||||
|
|
||||||
|
|
||||||


class PerReplicaDataset(object):
  """Like `tf.data.Dataset` split across devices, producing `PerReplica` data."""

  def __init__(self, dataset, input_workers, worker_index,
               prefetch_on_device=None):
    assert isinstance(input_workers, InputWorkers)
    assert worker_index is not None
    assert worker_index is not True
    assert worker_index is not False
    self._input_workers = input_workers
    self._worker_index = worker_index

    # Default to using prefetching, unless specified otherwise.
    self._prefetch_on_device = prefetch_on_device
    if self._prefetch_on_device is None:
      self._prefetch_on_device = True

    self._dataset = dataset
    if not self._prefetch_on_device:
      # TODO(priyag): If dropping remainder is not appropriate, find another
      # approach to distributing the dataset when it cannot be divided evenly.
      # Possibly not an issue when we start using PartitionedDataset.
      num_replicas = len(
          self._input_workers.compute_devices_for_worker(self._worker_index))
      self._dataset = self._dataset.batch(num_replicas, drop_remainder=True)
    else:
      self._replica_devices = self._input_workers.compute_devices_for_worker(
          self._worker_index)

  def make_one_shot_iterator(self):
    """Get a one-time-use iterator for the distributed PerReplicaDataset."""
    # Graph mode with a one shot iterator is disabled.
    if not context.executing_eagerly():
      raise ValueError("Cannot create a one shot iterator. Please use "
                       "`make_initializable_iterator()` instead.")
    if self._prefetch_on_device:
      dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
          self._dataset, self._replica_devices)
    else:
      dataset_iterator = dataset_ops.make_one_shot_iterator(self._dataset)
    return PerReplicaDataIterator(
        dataset_iterator,
        self._input_workers,
        self._worker_index,
        prefetch_on_device=self._prefetch_on_device)

  def make_initializable_iterator(self):
    """Get an initializable iterator for the distributed PerReplicaDataset."""
    # Eager mode generates already-initialized iterators, so we cannot create
    # an initializable iterator there.
    if context.executing_eagerly():
      raise ValueError("Cannot create initializable iterator in Eager mode. "
                       "Please use `make_one_shot_iterator` instead.")
    if self._prefetch_on_device:
      dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
          self._dataset, self._replica_devices)
    else:
      dataset_iterator = dataset_ops.make_initializable_iterator(self._dataset)
    return PerReplicaDataIterator(
        dataset_iterator, self._input_workers, self._worker_index,
        prefetch_on_device=self._prefetch_on_device)
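
# Illustrative sketch (not part of the original change): `PerReplicaDataset`
# only supports one-shot iterators in eager mode and initializable iterators
# in graph mode. A graph-mode usage might look like:
#
#   per_replica_dataset = PerReplicaDataset(dataset, input_workers, 0)
#   iterator = per_replica_dataset.make_initializable_iterator()
#   init_op = iterator.initializer   # Run this before the first get_next().
#
# `dataset` and `input_workers` are assumed to be created elsewhere.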


class MultiWorkerDataIterator(object):
  """An iterator (like `tf.data.Iterator`) into a `MultiWorkerDataset`."""

  def __init__(self, iterators, input_workers):
    """Initialize the `MultiWorkerDataIterator` object.

    Args:
      iterators: a list of (worker, iterator) pairs.
      input_workers: an `InputWorkers` object.

    Raises:
      ValueError: if iterators and input_workers are not compatible.
    """
    assert isinstance(input_workers, InputWorkers)
    workers = tuple(d for d, _ in iterators)
    if workers != input_workers.worker_devices:
      raise ValueError("iterators and input_workers are not compatible. "
                       "iterator workers: %r input_workers devices: %r" %
                       (workers, input_workers.worker_devices))
    self._iterators = tuple(i for _, i in iterators)
    self._input_workers = input_workers

  @property
  def initializer(self):
    return control_flow_ops.group(
        tuple(iterator.initializer for iterator in self._iterators))

  def get_iterator(self, worker):
    for i, w in enumerate(self._input_workers.worker_devices):
      if worker == w:
        return self._iterators[i]
    return None

  @property
  def output_shapes(self):
    return self._iterators[0].output_shapes

  @property
  def output_types(self):
    return self._iterators[0].output_types

  def get_next(self, name=None):
    """Scatter the input across hosts and devices."""
    replicas = []
    for worker, iterator in zip(self._input_workers.worker_devices,
                                self._iterators):
      if name is not None:
        d = tf_device.DeviceSpec.from_string(worker)
        new_name = "%s_%s_%d" % (name, d.job, d.task)
      else:
        new_name = None
      with ops.device(worker):
        data_per_worker = iterator.get_next_as_list(name=new_name)
        # Append to replicas to get a flat list of values indexed by replica.
        replicas.extend(data_per_worker)

    return regroup(self._input_workers.device_map, replicas)


class MultiWorkerDataset(object):
  """Like a `tf.data.Dataset` that distributes data to different workers.

  Each worker gets one shard of the input dataset. This currently does not work
  in eager mode.
  """

  def __init__(self, dataset_fn, input_workers, prefetch_on_device=None,
               auto_shard=False):
    """Initialize the MultiWorkerDataset object.

    Args:
      dataset_fn: a function, or a list of functions with one entry per worker,
        that returns a `tf.data.Dataset`.
      input_workers: an `InputWorkers` object.
      prefetch_on_device: whether to prefetch to devices.
      auto_shard: whether to auto-shard the dataset.
    """
    assert isinstance(input_workers, InputWorkers)
    if isinstance(dataset_fn, (list, tuple)):
      if len(dataset_fn) != input_workers.num_workers:
        raise ValueError("If `dataset_fn` is a list, it must have one entry "
                         "per worker")
    # TODO(rohanj): b/120673685 to track re-enabling auto sharding.
    if auto_shard:
      raise ValueError("Currently autosharding is not supported.")
    self._input_workers = input_workers
    self._datasets = []
    # TODO(yuefengz, priyag): support different set of jobs for input
    # processing.
    for i, worker in enumerate(input_workers.worker_devices):
      with ops.device(worker):
        if isinstance(dataset_fn, (list, tuple)):
          worker_input = dataset_fn[i]()
        else:
          worker_input = dataset_fn()
        dataset = PerReplicaDataset(worker_input, input_workers, i,
                                    prefetch_on_device=prefetch_on_device)
        self._datasets.append((worker, dataset))

  def make_one_shot_iterator(self):
    iterators = []
    for worker, dataset in self._datasets:
      with ops.device(worker):
        iterators.append((worker, dataset_ops.make_one_shot_iterator(dataset)))
    return MultiWorkerDataIterator(iterators, self._input_workers)

  def make_initializable_iterator(self):
    iterators = []
    for worker, dataset in self._datasets:
      with ops.device(worker):
        iterators.append(
            (worker, dataset_ops.make_initializable_iterator(dataset)))
    return MultiWorkerDataIterator(iterators, self._input_workers)
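
# Illustrative sketch (not part of the original change): `dataset_fn` can be a
# single callable or one callable per worker. The per-worker form, using
# hypothetical shard names, might look like:
#
#   def make_dataset_fn(shard_path):
#     return lambda: dataset_ops.Dataset.from_tensor_slices([shard_path])
#
#   dataset_fns = [make_dataset_fn(p) for p in ("shard-0", "shard-1")]
#   multi_worker_dataset = MultiWorkerDataset(dataset_fns, input_workers)
#
# Here `input_workers` is assumed to be an `InputWorkers` instance describing
# two workers.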


class InputIterator(object):
  """An input iterator, intended to be passed to `DistributionStrategy.run`."""

  def get_next(self):
    """Returns the next inputs for all replicas."""
    raise NotImplementedError("must be implemented in descendants")

  def initialize(self):
    """Initialize the underlying input dataset, when applicable.

    In eager mode, this will create a new iterator and return it.
    In graph mode, this will initialize the same underlying iterator(s).

    Users are required to call this if
    - This iterator was returned from a call to `make_input_fn_iterator` with an
      input function that returns a dataset.
    - Or this iterator was returned from a call to `make_dataset_iterator`.

    Returns:
      A list of initialization ops to be executed.
    """
    raise NotImplementedError("must be implemented in descendants")


class InputIteratorImpl(InputIterator):
  """Common implementation for all input iterators."""

  def __init__(self, input_workers, iterators):
    assert isinstance(input_workers, InputWorkers)
    if not input_workers.worker_devices:
      raise ValueError("Should have at least one worker for input iterator.")

    self._iterators = iterators
    self._input_workers = input_workers

  def get_next(self, name=None):
    """Returns the next input from the iterator for all replicas."""
    replicas = []
    for i, worker in enumerate(self._input_workers.worker_devices):
      if name is not None:
        d = tf_device.DeviceSpec.from_string(worker)
        new_name = "%s_%s_%d" % (name, d.job, d.task)
      else:
        new_name = None
      with ops.device(worker):
        # Make `replicas` a flat list of values across all replicas.
        replicas.extend(self._iterators[i].get_next_as_list(new_name))

    return regroup(self._input_workers.device_map, replicas)

  def initialize(self):
    """Initialize underlying iterators.

    Returns:
      A list of any initializer ops that should be run.
    """
    init_ops = []
    for it in self._iterators:
      init_ops.extend(it.initialize())
    return init_ops

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_classes(self):
    return self._iterators[0].output_classes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_shapes(self):
    return self._iterators[0].output_shapes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_types(self):
    return self._iterators[0].output_types

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  def get_iterator(self, worker):
    for i, w in enumerate(self._input_workers.worker_devices):
      if worker == w:
        return self._iterators[i]
    return None


class InputFunctionIterator(InputIteratorImpl):
  """Iterator created from an input function."""

  def __init__(self, input_fn, input_workers, input_contexts):
    """Make an iterator for input provided via an input function.

    Currently implements PER_WORKER mode, in which the `input_fn` is called
    once on each worker.

    TODO(priyag): Add other replication modes.
    TODO(priyag): Allow taking an input function that returns a callable that
    returns a nest of tensors.

    Args:
      input_fn: Input function that returns a `tf.data.Dataset` object.
      input_workers: an `InputWorkers` object.
      input_contexts: A list of `InputContext` instances to be passed to call(s)
        to `input_fn`. Length and order should match worker order in
        `worker_device_pairs`.
    """
    assert isinstance(input_workers, InputWorkers)
    if input_workers.num_workers != len(input_contexts):
      raise ValueError(
          "Number of input workers (%d) is not same as number of "
          "input_contexts (%d)" %
          (input_workers.num_workers, len(input_contexts)))

    iterators = []
    for i, ctx in enumerate(input_contexts):
      worker = input_workers.worker_devices[i]
      with ops.device(worker):
        result = input_fn(ctx)
        if not isinstance(result, dataset_ops.DatasetV2):
          raise ValueError("input_fn must return a tf.data.Dataset.")
        devices = input_workers.compute_devices_for_worker(i)
        iterator = _SingleWorkerDatasetIterator(result, worker, devices)
        iterators.append(iterator)

    super(InputFunctionIterator, self).__init__(input_workers, iterators)
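
# Illustrative sketch (not part of the original change): an `input_fn` receives
# an `InputContext` and returns a `tf.data.Dataset`. A minimal example, where
# the dataset and batch size are hypothetical:
#
#   def input_fn(ctx):
#     del ctx  # A real input_fn would typically shard based on the context.
#     return dataset_ops.Dataset.range(16).batch(4)
#
#   iterator = InputFunctionIterator(input_fn, input_workers, input_contexts)
#   init_ops = iterator.initialize()  # Empty list in eager mode.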


class DatasetIterator(InputIteratorImpl):
  """Iterator created from an input dataset."""

  def __init__(self, dataset, input_workers, split_batch_by=None):
    """Make an iterator for the dataset on given devices.

    If `split_batch_by` is not None, we "split" each batch of the dataset by
    the `split_batch_by` value. To achieve this, we first unbatch the input
    dataset and then rebatch it with the per-replica batch size, calculated as
    `global_batch_size // split_batch_by`.
    The currently supported datasets are as follows:
    `dataset.batch()` is the last operation on the dataset OR
    `dataset.apply(map_and_batch)` is the last operation on the dataset OR
    `dataset.batch().prefetch()` are the last 2 operations on the dataset OR
    `dataset.apply(map_and_batch).prefetch()` are the last 2 operations.

    TODO(priyag): Support multi worker / host cases properly by cloning
    and sharding the dataset on each worker. The current setup will only work
    in some cases, such as the in-graph multi worker GPU case. If the input
    pipeline has random shuffling (with a different seed on each worker), each
    worker will see random input from the same overall dataset in each step.
    Otherwise, each worker will see the same input in each step.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      split_batch_by: Optional integer. If present, we "split" each batch of
        the dataset by the `split_batch_by` value.
    """
    assert isinstance(input_workers, InputWorkers)
    if split_batch_by:
      dataset = _split_dataset_batch(dataset, split_batch_by)

    iterators = []
    for i, worker in enumerate(input_workers.worker_devices):
      with ops.device(worker):
        worker_devices = input_workers.compute_devices_for_worker(i)
        cloned_dataset = dataset
        if not context.executing_eagerly():
          cloned_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
        iterator = _SingleWorkerDatasetIterator(cloned_dataset, worker,
                                                worker_devices)
        iterators.append(iterator)

    super(DatasetIterator, self).__init__(input_workers, iterators)
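
# Illustrative sketch (not part of the original change): with a global batch
# size of 8 and `split_batch_by=4`, each replica ends up with batches of size
# 8 // 4 = 2. Roughly:
#
#   dataset = dataset_ops.Dataset.range(32).batch(8, drop_remainder=True)
#   dist_iterator = DatasetIterator(dataset, input_workers, split_batch_by=4)
#
# In graph mode the dataset is additionally cloned per worker; the actual
# rebatching is done by `_split_dataset_batch` below.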


class _SingleWorkerDatasetIterator(object):
  """Iterator for a single `tf.data.Dataset`."""

  def __init__(self, dataset, worker, devices):
    """Create an iterator for `dataset` to fetch data to the worker's `devices`.

    A `MultiDeviceIterator` is used to prefetch input to the devices on the
    given worker.

    Args:
      dataset: A `tf.data.Dataset` instance.
      worker: Worker on which ops should be created.
      devices: Distribute data from `dataset` to these devices.
    """
    self._dataset = dataset
    self._worker = worker
    self._devices = devices
    self._make_iterator()

  def _make_iterator(self):
    """Make an appropriate iterator on the dataset."""
    with ops.device(self._worker):
      self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
          self._dataset, self._devices)

  def get_next_as_list(self, name=None):
    """Get the next element from the underlying iterator."""
    del name
    with ops.device(self._worker):
      data_list = self._iterator.get_next()
      return data_list

  def initialize(self):
    """Initialize the underlying iterator.

    In eager execution, this simply recreates the underlying iterator.
    In graph execution, it returns the initializer ops for the underlying
    iterator.

    Returns:
      A list of any initializer ops that should be run.
    """
    if context.executing_eagerly():
      self._make_iterator()
      return []
    else:
      return [self._iterator.initializer]

  @property
  def output_classes(self):
    return self._iterator.output_classes

  @property
  def output_shapes(self):
    return self._iterator.output_shapes

  @property
  def output_types(self):
    return self._iterator.output_types
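
# Illustrative sketch (not part of the original change): in eager mode the
# per-worker iterator can be driven directly. The worker and device strings
# below are hypothetical:
#
#   it = _SingleWorkerDatasetIterator(
#       dataset_ops.Dataset.range(4).batch(2),
#       worker="/job:localhost/replica:0/task:0",
#       devices=["/device:CPU:0"])
#   first = it.get_next_as_list()  # One element per device, here length 1.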


def _split_dataset_batch(dataset, split_batch_by):
  """Divide a batched dataset's batches into smaller batches."""
  # TODO(sourabhbajaj): Remove this in lieu of distributed datasets.
  # pylint: disable=protected-access
  def _get_batch_dataset(d):
    """Get the underlying batch dataset from the dataset object."""
    if isinstance(d, dataset_ops.DatasetV1Adapter):
      d = d._dataset

    if isinstance(d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset)):
      return d
    elif isinstance(d, dataset_ops.PrefetchDataset):
      return _get_batch_dataset(d._input_dataset)
    raise ValueError(
        "Unable to get batched dataset from the input dataset. `batch` or "
        "`map_and_batch` need to be the last operations on the dataset. "
        "The batch operations can be followed by a prefetch.")

  batched_dataset = _get_batch_dataset(dataset)
  if isinstance(batched_dataset, dataset_ops.BatchDataset):
    batch_size = batched_dataset._batch_size
    drop_remainder = batched_dataset._drop_remainder
  elif isinstance(batched_dataset, batching._MapAndBatchDataset):
    batch_size = batched_dataset._batch_size_t
    drop_remainder = batched_dataset._drop_remainder_t

  prefetch_buffer = None
  if isinstance(dataset, dataset_ops.PrefetchDataset):
    prefetch_buffer = dataset._buffer_size
  elif (isinstance(dataset, dataset_ops.DatasetV1Adapter)
        and isinstance(dataset._dataset, dataset_ops.PrefetchDataset)):
    prefetch_buffer = dataset._dataset._buffer_size
  # pylint: enable=protected-access

  if tensor_util.is_tensor(batch_size):
    batch_size = tensor_util.constant_value(batch_size)

  if tensor_util.is_tensor(drop_remainder):
    drop_remainder = tensor_util.constant_value(drop_remainder)

  if batch_size % split_batch_by:
    raise ValueError(
        "Batch size %s cannot be sharded evenly across %s replicas" % (
            batch_size, split_batch_by))
  new_batch_size = batch_size // split_batch_by

  dataset = dataset.apply(batching.unbatch())
  dataset = dataset.batch(new_batch_size, drop_remainder=drop_remainder)
  if prefetch_buffer is not None:
    dataset = dataset.prefetch(prefetch_buffer)
  return dataset
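
# Illustrative sketch (not part of the original change): for a dataset whose
# last op is `batch`, the transform above amounts to (sizes hypothetical):
#
#   ds = dataset_ops.Dataset.range(16).batch(8)      # global batch size 8
#   ds = ds.apply(batching.unbatch()).batch(8 // 4)  # per-replica batch size 2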


class MultiStepContext(object):
  """A context object that can be used to capture things when running steps.

  This context object is useful when running multiple steps at a time using the
  `experimental_run_steps_on_iterator` API. For example, it allows the user's
  step function to specify which outputs to emit at what frequency. Currently
  it supports capturing output from the last step, as well as capturing non
  tensor outputs. In the future it will be augmented to support other use
  cases such as output each N steps.
  """

  def __init__(self):
    """Initialize an output context.

    Returns:
      A context object.
    """
    self._last_step_outputs = {}
    self._last_step_outputs_reduce_ops = {}
    self._non_tensor_outputs = {}

  @property
  def last_step_outputs(self):
    """A dictionary consisting of outputs to be captured on the last step.

    Keys in the dictionary are names of tensors to be captured, as specified
    when `set_last_step_output` is called.
    Values in the dictionary are the tensors themselves. If
    `set_last_step_output` was called with a `reduce_op` for this output,
    then the value is the reduced value.

    Returns:
      A dictionary with last step outputs.
    """
    return self._last_step_outputs

  def _set_last_step_outputs(self, outputs):
    """Replace the entire dictionary of last step outputs."""
    if not isinstance(outputs, dict):
      raise ValueError("Need a dictionary to set last_step_outputs.")
    self._last_step_outputs = outputs

  def set_last_step_output(self, name, output, reduce_op=None):
    """Set `output` with `name` to be outputted from the last step.

    Args:
      name: String, name to identify the output. Doesn't need to match tensor
        name.
      output: The tensors that should be outputted with `name`. See below for
        actual types supported.
      reduce_op: Reduction method to use to reduce outputs from multiple
        replicas. Required if `set_last_step_output` is called in a replica
        context. Optional in cross_replica_context.
        When present, the outputs from all the replicas are reduced using the
        current distribution strategy's `reduce` method. Hence, the type of
        `output` must be what's supported by the corresponding `reduce` method.
        For example, if using MirroredStrategy and reduction is set, output
        must be a `PerReplica` value.
        The reduce method is also recorded in the
        `_last_step_outputs_reduce_ops` dictionary, so that the outputs can
        later be interpreted as already reduced or not.
    """
    if distribution_strategy_context.in_cross_replica_context():
      self._last_step_outputs_reduce_ops[name] = reduce_op
      if reduce_op is None:
        self._last_step_outputs[name] = output
      else:
        distribution = distribution_strategy_context.get_distribution_strategy()
        self._last_step_outputs[name] = distribution.reduce(reduce_op, output)
    else:
      assert reduce_op is not None
      def merge_fn(distribution, value):
        self._last_step_outputs[name] = distribution.reduce(reduce_op, value)
        # Setting this inside the `merge_fn` because all replicas share the same
        # context object, so it's more robust to set it only once (even if all
        # the replicas are trying to set the same value).
        self._last_step_outputs_reduce_ops[name] = reduce_op

      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, args=(output,))

  @property
  def non_tensor_outputs(self):
    """A dictionary consisting of any non tensor outputs to be captured."""
    return self._non_tensor_outputs

  def set_non_tensor_output(self, name, output):
    """Set `output` with `name` to be captured as a non tensor output."""
    if distribution_strategy_context.in_cross_replica_context():
      self._non_tensor_outputs[name] = output
    else:
      def merge_fn(distribution, value):
        # NOTE(priyag): For non tensor outputs, we simply return all the values
        # in a list as reduction doesn't make sense on non tensors.
        self._non_tensor_outputs[name] = distribution.unwrap(value)
      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, args=(output,))
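
# Illustrative sketch (not part of the original change): a step function can
# record outputs on the context it receives from
# `experimental_run_steps_on_iterator`. The names, the loss computation, and
# the reduce op reference below are hypothetical:
#
#   def step_fn(ctx, inputs):
#     loss = compute_loss(inputs)  # Hypothetical per-replica loss.
#     ctx.set_last_step_output("loss", loss,
#                              reduce_op=reduce_util.ReduceOp.SUM)
#     ctx.set_non_tensor_output("debug_info", {"phase": "train"})
#     return loss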


def value_container(val):
  """Returns the container that this per-replica `value` belongs to.