Split input-related classes out of python/distribute/values.py into new file .../input_lib.py.

PiperOrigin-RevId: 227227637
A. Unique TensorFlower 2018-12-29 08:46:45 -08:00 committed by TensorFlower Gardener
parent e0aa938725
commit d2dd369f9d
12 changed files with 1272 additions and 1162 deletions
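
At a glance, the refactor is mechanical: call sites that previously reached the input-pipeline classes through `values` now import the new `input_lib` module and reference the same classes there, and the affected BUILD targets gain a `//tensorflow/python/distribute:input_lib` dependency. A minimal sketch of the call-site pattern, assembled from names that appear in the hunks below (illustrative only, not an excerpt from any single file in this commit):

# Old pattern (pre-split): the input classes were attributes of `values`, e.g.
#   from tensorflow.python.distribute import values
#   input_workers = values.InputWorkers(device_map)
#
# New pattern (post-split): the same classes are imported from `input_lib`,
# mirroring the test code added in input_lib_test.py below.
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import values

device_map = values.ReplicaDeviceMap(["/device:CPU:0"])
input_workers = input_lib.InputWorkers(device_map)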


@ -23,17 +23,14 @@ cuda_py_test(
additional_deps = [
":combinations",
":mirrored_strategy",
":multi_worker_test_base",
"@absl_py//absl/testing:parameterized",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:device_util",
"//tensorflow/python/distribute:values",
"//tensorflow/python/eager:context",
@ -45,14 +42,36 @@ cuda_py_test(
],
)
cuda_py_test(
name = "input_lib_test",
srcs = ["input_lib_test.py"],
additional_deps = [
":combinations",
":mirrored_strategy",
":multi_worker_test_base",
"@absl_py//absl/testing:parameterized",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:errors",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:input_lib",
"//tensorflow/python/distribute:values",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
],
tags = [
"no_pip",
],
)
py_library(
name = "mirrored_strategy",
srcs = ["mirrored_strategy.py"],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/distribute:input_lib",
"//tensorflow/python/distribute:mirrored_strategy",
"//tensorflow/python/distribute:values",
],
)
@ -69,6 +88,7 @@ py_library(
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python/distribute:cross_device_ops",
"//tensorflow/python/distribute:input_lib",
"//tensorflow/python/distribute:multi_worker_util",
"//tensorflow/python/distribute:reduce_util",
"//tensorflow/python/distribute:values",
@ -119,6 +139,7 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/distribute:input_lib",
"//tensorflow/python/distribute:reduce_util",
"//tensorflow/python/distribute:values",
"//tensorflow/python/eager:context",
@ -139,6 +160,7 @@ py_library(
"//tensorflow/python:training",
"//tensorflow/python/distribute:cross_device_ops",
"//tensorflow/python/distribute:cross_device_utils",
"//tensorflow/python/distribute:input_lib",
"//tensorflow/python/distribute:multi_worker_util",
"//tensorflow/python/distribute:values",
"//tensorflow/python/eager:context",
@ -289,6 +311,7 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:tensor_util",
"//tensorflow/python:util",
"//tensorflow/python/distribute:input_lib",
"//tensorflow/python/distribute:reduce_util",
"//tensorflow/python/distribute:values",
],


@ -26,6 +26,7 @@ from tensorflow.python.distribute import cross_device_ops as cross_device_ops_li
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
@ -130,7 +131,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
self._collective_keys = cross_device_utils.CollectiveKeys()
self._initialize_local(local_devices)
self._input_workers = values.InputWorkers(
self._input_workers = input_lib.InputWorkers(
self._device_map, [(self._worker_device, self.worker_devices)])
self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
num_workers=self._num_workers,
@ -229,13 +230,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
"""Distributes the dataset to each local GPU."""
# TODO(yuefengz): shard the dataset.
worker_index = 0
return values.PerReplicaDataset(
return input_lib.PerReplicaDataset(
self._call_dataset_fn(dataset_fn), self._input_workers, worker_index,
prefetch_on_device=True)
def _make_dataset_iterator(self, dataset):
return values.DatasetIterator(dataset, self._input_workers,
self._num_replicas_in_sync)
return input_lib.DatasetIterator(dataset, self._input_workers,
self._num_replicas_in_sync)
def _make_input_fn_iterator(
self,
@ -252,7 +253,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
input_pipeline_id=input_pipeline_id,
num_replicas_in_sync=self._num_replicas_in_sync)
return values.InputFunctionIterator(
return input_lib.InputFunctionIterator(
input_fn, self._input_workers, [input_context])
def _configure(self,


@ -0,0 +1,480 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the input_lib library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.util import nest
class PerReplicaDatasetTest(test.TestCase):
config = config_pb2.ConfigProto()
config.allow_soft_placement = True
def _test_iterator(self, devices, dataset, expected_values):
device_map = values.ReplicaDeviceMap(devices)
input_workers = input_lib.InputWorkers(device_map)
per_replica_dataset = input_lib.PerReplicaDataset(dataset, input_workers, 0)
if context.executing_eagerly():
iterator = per_replica_dataset.make_one_shot_iterator()
else:
iterator = per_replica_dataset.make_initializable_iterator()
self.evaluate([iterator.initializer])
for expected_value in expected_values:
next_element = iterator.get_next_as_list()
computed_value = self.evaluate(next_element)
self.assertEqual(expected_value, computed_value)
with self.assertRaises(errors.OutOfRangeError):
next_element = iterator.get_next_as_list()
self.evaluate(next_element)
@test_util.run_in_graph_and_eager_modes
def testOneDevice(self):
devices = ["/device:CPU:0"]
dataset = dataset_ops.Dataset.range(10)
expected_values = [[i] for i in range(10)]
self._test_iterator(devices, dataset, expected_values)
@test_util.run_in_graph_and_eager_modes(config=config)
def testMultipleDevices(self):
if context.num_gpus() < 1 and context.executing_eagerly():
self.skipTest("A GPU is not available for this test in eager mode.")
devices = ["/device:CPU:0", "/device:GPU:0"]
dataset = dataset_ops.Dataset.range(10)
expected_values = [[i, i+1] for i in range(0, 10, 2)]
self._test_iterator(devices, dataset, expected_values)
@test_util.run_in_graph_and_eager_modes(config=config)
def testTupleDataset(self):
if context.num_gpus() < 1 and context.executing_eagerly():
self.skipTest("A GPU is not available for this test in eager mode.")
devices = ["/device:CPU:0", "/device:GPU:0"]
dataset1 = dataset_ops.Dataset.range(10)
dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)]
self._test_iterator(devices, dataset, expected_values)
@test_util.run_in_graph_and_eager_modes(config=config)
def testUnevenDatasetBatches(self):
if context.num_gpus() < 1 and context.executing_eagerly():
self.skipTest("A GPU is not available for this test in eager mode.")
devices = ["/device:CPU:0", "/device:GPU:0"]
dataset = dataset_ops.Dataset.range(11)
expected_values = [[i, i+1] for i in range(0, 10, 2)]
self._test_iterator(devices, dataset, expected_values)
def testInitializableIterator(self):
with context.graph_mode():
devices = ["/device:CPU:0"]
# Using random input since that is only allowed with initializable
# iterator.
dataset = dataset_ops.Dataset.from_tensor_slices(
random_ops.random_uniform((10,)))
device_map = values.ReplicaDeviceMap(devices)
input_workers = input_lib.InputWorkers(device_map)
per_replica_dataset = input_lib.PerReplicaDataset(
dataset, input_workers, 0)
iterator = per_replica_dataset.make_initializable_iterator()
self.evaluate(iterator.initializer)
next_element = iterator.get_next_as_list()
for _ in range(10):
self.evaluate(next_element)
# Should fail after the input is finished.
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
# After re-initializing the iterator, should be able to iterate again.
self.evaluate(iterator.initializer)
for _ in range(10):
self.evaluate(next_element)
class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
def _test_iterator(self, sess, iterator, devices, expected_values):
next_element = iterator.get_next()
for r, device in enumerate(devices):
v = values.select_replica(r, next_element)
# The `v` here can be a tuple.
for element in nest.flatten(v):
self.assertTrue(element.device in device)
for expected_value in expected_values:
t = [values.select_replica(r, next_element) for r in range(len(devices))]
actual = sess.run(t)
self.assertEqual(expected_value, actual)
with self.assertRaises(errors.OutOfRangeError):
sess.run([values.select_replica(r, next_element)
for r in range(len(devices))])
def _test_dataset(self, dataset_fn, worker_devices, devices,
expected_values):
device_map = values.ReplicaDeviceMap(devices)
input_workers = input_lib.InputWorkers(device_map, worker_devices)
multi_worker_dataset = input_lib.MultiWorkerDataset(
dataset_fn, input_workers)
multi_worker_iterator = multi_worker_dataset.make_initializable_iterator()
with self.cached_session() as sess:
sess.run(multi_worker_iterator.initializer)
self._test_iterator(sess, multi_worker_iterator, devices, expected_values)
def _cpu_devices(self):
worker_devices = (
("/job:worker/replica:0/task:0",
["/job:worker/replica:0/task:0/device:CPU:0"]),
("/job:worker/replica:0/task:1",
["/job:worker/replica:0/task:1/device:CPU:0"])
)
devices = [
"/job:worker/replica:0/task:0/device:CPU:0",
"/job:worker/replica:0/task:1/device:CPU:0"
]
return worker_devices, devices
def _cpu_and_one_gpu_devices(self):
worker_devices = (
("/job:worker/replica:0/task:0", (
"/job:worker/replica:0/task:0/device:GPU:0",
"/job:worker/replica:0/task:0/device:CPU:0"
)),
("/job:worker/replica:0/task:1", (
"/job:worker/replica:0/task:1/device:GPU:0",
"/job:worker/replica:0/task:1/device:CPU:0"
))
)
devices = [
"/job:worker/replica:0/task:0/device:GPU:0",
"/job:worker/replica:0/task:0/device:CPU:0",
"/job:worker/replica:0/task:1/device:GPU:0",
"/job:worker/replica:0/task:1/device:CPU:0"
]
return worker_devices, devices
def testDataDistributionOneDevicePerWorker(self):
worker_devices, devices = self._cpu_devices()
with context.graph_mode():
dataset_fn = lambda: dataset_ops.Dataset.range(8)
self._test_dataset(
dataset_fn, worker_devices, devices,
[[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]])
def testDataDistributionTwoDevicePerWorker(self):
if context.num_gpus() < 1:
self.skipTest("A GPU is not available for this test.")
worker_devices, devices = self._cpu_and_one_gpu_devices()
with context.graph_mode():
dataset_fn = lambda: dataset_ops.Dataset.range(8)
self._test_dataset(
dataset_fn, worker_devices, devices,
[[0, 1, 0, 1], [2, 3, 2, 3], [4, 5, 4, 5], [6, 7, 6, 7]])
def testTupleDataset(self):
worker_devices, devices = self._cpu_devices()
with context.graph_mode():
def dataset_fn():
dataset1 = dataset_ops.Dataset.range(8)
dataset2 = dataset_ops.Dataset.range(8).map(lambda x: x**2)
return dataset_ops.Dataset.zip((dataset1, dataset2))
expected_values = [[(i, i**2), (i, i**2)] for i in range(8)]
self._test_dataset(dataset_fn, worker_devices, devices,
expected_values)
def testInitializableIterator(self):
worker_devices, devices = self._cpu_devices()
with context.graph_mode(), self.cached_session() as sess:
dataset_fn = lambda: dataset_ops.Dataset.range(8)
device_map = values.ReplicaDeviceMap(devices)
input_workers = input_lib.InputWorkers(device_map, worker_devices)
multi_worker_dataset = input_lib.MultiWorkerDataset(
dataset_fn, input_workers)
multi_worker_iterator = multi_worker_dataset.make_initializable_iterator()
sess.run(multi_worker_iterator.initializer)
self._test_iterator(
sess, multi_worker_iterator, devices,
[[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]])
# After re-initializing the iterator, should be able to iterate again.
sess.run(multi_worker_iterator.initializer)
self._test_iterator(
sess, multi_worker_iterator, devices,
[[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]])
def testValueErrorForIterator(self):
# Incompatible arguments.
d1 = "/device:GPU:0"
d2 = "/device:GPU:1"
device_map = values.ReplicaDeviceMap([d1, d2])
input_workers = input_lib.InputWorkers(
device_map, (("w1", (d1,)), ("w2", (d2,))))
with self.assertRaises(ValueError):
input_lib.MultiWorkerDataIterator([("w1", None)], input_workers)
def testDuplicateDevices(self):
_, devices = self._cpu_devices()
devices.append("/job:worker/replica:0/task:0/device:CPU:0")
with self.assertRaises(ValueError):
_ = values.ReplicaDeviceMap(devices)
class InputIteratorTestBase(test.TestCase):
def _test_iterator(self, input_type, dataset_fn, worker_device_pairs,
expected_values, sess=None, split_batch_by=None):
devices = nest.flatten([ds for _, ds in worker_device_pairs])
device_map = values.ReplicaDeviceMap(devices)
input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)
if input_type == "input_fn":
input_contexts = [
distribute_lib.InputContext() for _ in worker_device_pairs]
input_fn = lambda _: dataset_fn()
iterator = input_lib.InputFunctionIterator(
input_fn, input_workers, input_contexts)
else:
iterator = input_lib.DatasetIterator(
dataset_fn(), input_workers, split_batch_by)
evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
evaluate(control_flow_ops.group(iterator.initialize()))
for expected_value in expected_values:
next_element = iterator.get_next()
computed_value = evaluate(
[values.select_replica(r, next_element) for r in range(len(devices))])
self.assertAllEqual(expected_value, computed_value)
with self.assertRaises(errors.OutOfRangeError):
next_element = iterator.get_next()
evaluate([values.select_replica(r, next_element)
for r in range(len(devices))])
# After re-initializing the iterator, should be able to iterate again.
evaluate(control_flow_ops.group(iterator.initialize()))
for expected_value in expected_values:
next_element = iterator.get_next()
computed_value = evaluate(
[values.select_replica(r, next_element) for r in range(len(devices))])
self.assertAllEqual(expected_value, computed_value)
class InputIteratorSingleWorkerTest(InputIteratorTestBase,
parameterized.TestCase):
@combinations.generate(combinations.combine(
mode=["graph", "eager"],
input_type=["input_fn", "dataset"]))
def testOneDeviceCPU(self, input_type):
worker_device_pairs = [("", ["/device:CPU:0"])]
dataset_fn = lambda: dataset_ops.Dataset.range(10)
expected_values = [[i] for i in range(10)]
self._test_iterator(input_type, dataset_fn, worker_device_pairs,
expected_values)
@combinations.generate(combinations.combine(
mode=["graph", "eager"],
input_type=["input_fn", "dataset"],
required_gpus=1))
def testTwoDevicesOneGPUOneCPU(self, input_type):
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
dataset_fn = lambda: dataset_ops.Dataset.range(10)
expected_values = [[i, i+1] for i in range(0, 10, 2)]
self._test_iterator(input_type, dataset_fn, worker_device_pairs,
expected_values)
@combinations.generate(combinations.combine(
mode=["graph", "eager"],
input_type=["input_fn", "dataset"],
required_gpus=1))
def testTupleDataset(self, input_type):
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
def dataset_fn():
dataset1 = dataset_ops.Dataset.range(10)
dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
return dataset_ops.Dataset.zip((dataset1, dataset2))
expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)]
self._test_iterator(input_type, dataset_fn, worker_device_pairs,
expected_values)
@combinations.generate(combinations.combine(
mode=["graph", "eager"],
input_type=["input_fn", "dataset"],
required_gpus=1))
def testUnevenDatasetBatches(self, input_type):
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
dataset_fn = lambda: dataset_ops.Dataset.range(11)
expected_values = [[i, i+1] for i in range(0, 10, 2)]
self._test_iterator(input_type, dataset_fn, worker_device_pairs,
expected_values)
@combinations.generate(combinations.combine(
mode=["graph", "eager"],
input_type=["dataset"],
split_batch_by=[None, 2],
required_gpus=1))
def testBatchSplitting(self, input_type, split_batch_by):
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
batch_size = 10
dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size)
updated_batch_size = (
batch_size // split_batch_by if split_batch_by else batch_size)
expected_values = [[range(i, i+updated_batch_size),
range(i+updated_batch_size, i+2*updated_batch_size)]
for i in range(0, 100, updated_batch_size*2)]
self._test_iterator(input_type, dataset_fn, worker_device_pairs,
expected_values, sess=None,
split_batch_by=split_batch_by)
class InputIteratorMultiWorkerTest(
multi_worker_test_base.MultiWorkerTestBase, InputIteratorTestBase,
parameterized.TestCase):
def _cpu_devices(self):
return [
("/job:worker/replica:0/task:0",
["/job:worker/replica:0/task:0/device:CPU:0"]),
("/job:worker/replica:0/task:1",
["/job:worker/replica:0/task:1/device:CPU:0"])]
def _cpu_and_one_gpu_devices(self):
return [
("/job:worker/replica:0/task:0", [
"/job:worker/replica:0/task:0/device:GPU:0",
"/job:worker/replica:0/task:0/device:CPU:0"
]),
("/job:worker/replica:0/task:1", [
"/job:worker/replica:0/task:1/device:GPU:0",
"/job:worker/replica:0/task:1/device:CPU:0"
])
]
@combinations.generate(combinations.combine(
mode=["graph"],
input_type=["input_fn", "dataset"]))
def testOneDevicePerWorker(self, input_type):
worker_devices = self._cpu_devices()
with context.graph_mode(), self.cached_session() as sess:
dataset_fn = lambda: dataset_ops.Dataset.range(4)
self._test_iterator(input_type, dataset_fn, worker_devices,
[[0, 0], [1, 1], [2, 2], [3, 3]], sess)
@combinations.generate(combinations.combine(
mode=["graph"],
input_type=["input_fn", "dataset"],
required_gpus=1))
def testTwoDevicesPerWorker(self, input_type):
worker_devices = self._cpu_and_one_gpu_devices()
with context.graph_mode(), self.cached_session() as sess:
dataset_fn = lambda: dataset_ops.Dataset.range(4)
self._test_iterator(input_type, dataset_fn, worker_devices,
[[0, 1, 0, 1], [2, 3, 2, 3]], sess)
@combinations.generate(combinations.combine(
mode=["graph"],
input_type=["input_fn", "dataset"]))
def testTupleDataset(self, input_type):
worker_devices = self._cpu_devices()
with context.graph_mode(), self.cached_session() as sess:
def dataset_fn():
dataset1 = dataset_ops.Dataset.range(4)
dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2)
return dataset_ops.Dataset.zip((dataset1, dataset2))
expected_values = [[(i, i**2), (i, i**2)] for i in range(0, 4)]
self._test_iterator(input_type, dataset_fn, worker_devices,
expected_values, sess)
class SplitDatasetBatchTest(test.TestCase):
def testBatchDataset(self):
dataset = dataset_ops.Dataset.range(100).batch(20)
split_batch_by = 2
result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by)
expected_values = [range(i, i+10) for i in range(0, 100, 10)]
result = [self.evaluate(el) for el in result_dataset]
self.assertAllEqual(expected_values, result)
def testMapAndBatchDataset(self):
dataset = dataset_ops.Dataset.range(100)
dataset = dataset.apply(batching.map_and_batch(lambda x: x, 20))
split_batch_by = 2
result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by)
expected_values = [range(i, i+10) for i in range(0, 100, 10)]
result = [self.evaluate(el) for el in result_dataset]
self.assertAllEqual(expected_values, result)
def testPrefetchDataset(self):
dataset = dataset_ops.Dataset.range(100).batch(20).prefetch(1)
split_batch_by = 2
result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by)
expected_values = [range(i, i+10) for i in range(0, 100, 10)]
result = [self.evaluate(el) for el in result_dataset]
self.assertAllEqual(expected_values, result)
if __name__ == "__main__":
test.main()


@ -21,8 +21,8 @@ from __future__ import print_function
import functools
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import values
# pylint: disable=protected-access,invalid-name
@ -135,14 +135,14 @@ class MirroredExtended(CoreMirroredExtended):
Returns:
An `InputIterator` which returns inputs for each step of the computation.
"""
return values.DatasetIterator(dataset, self._input_workers)
return input_lib.DatasetIterator(dataset, self._input_workers)
def _distribute_dataset(self, dataset_fn):
if self._local_mode:
return values.PerReplicaDataset(
return input_lib.PerReplicaDataset(
self._call_dataset_fn(dataset_fn), self._input_workers, 0)
else:
return values.MultiWorkerDataset(
return input_lib.MultiWorkerDataset(
functools.partial(self._call_dataset_fn, dataset_fn),
self._input_workers,
auto_shard=self._auto_shard_dataset)


@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import values
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -52,7 +53,8 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended):
worker = device_util.canonicalize("/device:CPU:0")
worker_device_pairs = [(worker, [self._device])]
device_map = values.SingleDeviceMap(device)
self._input_workers = values.InputWorkers(device_map, worker_device_pairs)
self._input_workers = input_lib.InputWorkers(
device_map, worker_device_pairs)
def _create_variable(self, next_creator, *args, **kwargs):
colocate_with = kwargs.pop("colocate_with", None)
@ -67,17 +69,17 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended):
def _make_dataset_iterator(self, dataset):
"""Make iterator from dataset without splitting the batch."""
return values.DatasetIterator(dataset, self._input_workers)
return input_lib.DatasetIterator(dataset, self._input_workers)
def _distribute_dataset(self, dataset_fn):
return values.PerReplicaDataset(
return input_lib.PerReplicaDataset(
self._call_dataset_fn(dataset_fn), self._input_workers, 0)
def _make_input_fn_iterator(
self,
input_fn,
replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
return values.InputFunctionIterator(
return input_lib.InputFunctionIterator(
input_fn, self._input_workers, [distribute_lib.InputContext()])
def _broadcast_to(self, tensor, destinations):
@ -91,7 +93,7 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended):
initial_loop_values = {}
initial_loop_values = nest.flatten(initial_loop_values)
ctx = values.MultiStepContext()
ctx = input_lib.MultiStepContext()
def body(i, *args):
"""A wrapper around `fn` to create the while loop body."""
del args


@ -24,6 +24,7 @@ from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
@ -153,7 +154,7 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended):
compute_devices = (worker_device,)
self._device_map = values.ReplicaDeviceMap(compute_devices)
self._input_workers = values.InputWorkers(
self._input_workers = input_lib.InputWorkers(
self._device_map, [(worker_device, compute_devices)])
# In distributed mode, place variables on ps jobs in a round-robin fashion.
@ -210,7 +211,7 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended):
compute_devices = (_LOCAL_CPU,)
self._device_map = values.ReplicaDeviceMap(compute_devices)
self._input_workers = values.InputWorkers(
self._input_workers = input_lib.InputWorkers(
self._device_map, [(worker_device, compute_devices)])
# If there is only one GPU, put everything on that GPU. Otherwise, place
@ -237,13 +238,13 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended):
def _distribute_dataset(self, dataset_fn):
"""Distributes the dataset to each local GPU."""
return values.PerReplicaDataset(
return input_lib.PerReplicaDataset(
self._call_dataset_fn(dataset_fn), self._input_workers, 0,
prefetch_on_device=True)
def _make_dataset_iterator(self, dataset):
return values.DatasetIterator(dataset, self._input_workers,
self._num_replicas_in_sync)
return input_lib.DatasetIterator(dataset, self._input_workers,
self._num_replicas_in_sync)
def _make_input_fn_iterator(
self,
@ -262,7 +263,7 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended):
num_input_pipelines=num_input_pipelines,
input_pipeline_id=input_pipeline_id,
num_replicas_in_sync=self._num_replicas_in_sync)
return values.InputFunctionIterator(
return input_lib.InputFunctionIterator(
input_fn, self._input_workers, [input_context])
def _broadcast_to(self, tensor, destinations):


@ -33,6 +33,7 @@ from tensorflow.python.client import session as session_lib
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver_lib
@ -204,7 +205,8 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
(self.get_host(hid), [self.get_host_cpu_device(hid)])
for hid in range(self.num_hosts)
]
self._input_workers = values.InputWorkers(input_device_map, worker_devices)
self._input_workers = input_lib.InputWorkers(
input_device_map, worker_devices)
# TODO(sourabhbajaj): Remove this once performance of running one step
# at a time is comparable to multiple steps.
@ -304,11 +306,11 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
def _make_dataset_iterator(self, dataset):
"""Make iterators for each of the TPU hosts."""
return values.DatasetIterator(dataset, self._input_workers,
self._num_replicas_in_sync)
return input_lib.DatasetIterator(dataset, self._input_workers,
self._num_replicas_in_sync)
def _distribute_dataset(self, dataset_fn):
return values.MultiWorkerDataset(
return input_lib.MultiWorkerDataset(
functools.partial(self._call_dataset_fn, dataset_fn),
self._input_workers)
@ -339,7 +341,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
if initial_loop_values is None:
initial_loop_values = {}
initial_loop_values = nest.flatten(initial_loop_values)
ctx = values.MultiStepContext()
ctx = input_lib.MultiStepContext()
def run_fn(*args, **kwargs):
"""Single step on the TPU device."""


@ -22,28 +22,20 @@ import os
from absl.testing import parameterized
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.util import nest
class DistributedValuesTest(test.TestCase):
@ -354,444 +346,6 @@ class RegroupAndSelectDeviceTest(test.TestCase):
merged_estimator_spec))
class PerReplicaDatasetTest(test.TestCase):
config = config_pb2.ConfigProto()
config.allow_soft_placement = True
def _test_iterator(self, devices, dataset, expected_values):
device_map = values.ReplicaDeviceMap(devices)
input_workers = values.InputWorkers(device_map)
per_replica_dataset = values.PerReplicaDataset(dataset, input_workers, 0)
if context.executing_eagerly():
iterator = per_replica_dataset.make_one_shot_iterator()
else:
iterator = per_replica_dataset.make_initializable_iterator()
self.evaluate([iterator.initializer])
for expected_value in expected_values:
next_element = iterator.get_next_as_list()
computed_value = self.evaluate(next_element)
self.assertEqual(expected_value, computed_value)
with self.assertRaises(errors.OutOfRangeError):
next_element = iterator.get_next_as_list()
self.evaluate(next_element)
@test_util.run_in_graph_and_eager_modes
def testOneDevice(self):
devices = ["/device:CPU:0"]
dataset = dataset_ops.Dataset.range(10)
expected_values = [[i] for i in range(10)]
self._test_iterator(devices, dataset, expected_values)
@test_util.run_in_graph_and_eager_modes(config=config)
def testMultipleDevices(self):
if context.num_gpus() < 1 and context.executing_eagerly():
self.skipTest("A GPU is not available for this test in eager mode.")
devices = ["/device:CPU:0", "/device:GPU:0"]
dataset = dataset_ops.Dataset.range(10)
expected_values = [[i, i+1] for i in range(0, 10, 2)]
self._test_iterator(devices, dataset, expected_values)
@test_util.run_in_graph_and_eager_modes(config=config)
def testTupleDataset(self):
if context.num_gpus() < 1 and context.executing_eagerly():
self.skipTest("A GPU is not available for this test in eager mode.")
devices = ["/device:CPU:0", "/device:GPU:0"]
dataset1 = dataset_ops.Dataset.range(10)
dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)]
self._test_iterator(devices, dataset, expected_values)
@test_util.run_in_graph_and_eager_modes(config=config)
def testUnevenDatasetBatches(self):
if context.num_gpus() < 1 and context.executing_eagerly():
self.skipTest("A GPU is not available for this test in eager mode.")
devices = ["/device:CPU:0", "/device:GPU:0"]
dataset = dataset_ops.Dataset.range(11)
expected_values = [[i, i+1] for i in range(0, 10, 2)]
self._test_iterator(devices, dataset, expected_values)
def testInitializableIterator(self):
with context.graph_mode():
devices = ["/device:CPU:0"]
# Using random input since that is only allowed with initializable
# iterator.
dataset = dataset_ops.Dataset.from_tensor_slices(
random_ops.random_uniform((10,)))
device_map = values.ReplicaDeviceMap(devices)
input_workers = values.InputWorkers(device_map)
per_replica_dataset = values.PerReplicaDataset(dataset, input_workers, 0)
iterator = per_replica_dataset.make_initializable_iterator()
self.evaluate(iterator.initializer)
next_element = iterator.get_next_as_list()
for _ in range(10):
self.evaluate(next_element)
# Should fail after the input is finished.
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
# After re-initializing the iterator, should be able to iterate again.
self.evaluate(iterator.initializer)
for _ in range(10):
self.evaluate(next_element)
class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
def _test_iterator(self, sess, iterator, devices, expected_values):
next_element = iterator.get_next()
for r, device in enumerate(devices):
v = values.select_replica(r, next_element)
# The `v` here can be a tuple.
for element in nest.flatten(v):
self.assertTrue(element.device in device)
for expected_value in expected_values:
t = [values.select_replica(r, next_element) for r in range(len(devices))]
actual = sess.run(t)
self.assertEqual(expected_value, actual)
with self.assertRaises(errors.OutOfRangeError):
sess.run([values.select_replica(r, next_element)
for r in range(len(devices))])
def _test_dataset(self, dataset_fn, worker_devices, devices,
expected_values):
device_map = values.ReplicaDeviceMap(devices)
input_workers = values.InputWorkers(device_map, worker_devices)
multi_worker_dataset = values.MultiWorkerDataset(
dataset_fn, input_workers)
multi_worker_iterator = multi_worker_dataset.make_initializable_iterator()
with self.cached_session() as sess:
sess.run(multi_worker_iterator.initializer)
self._test_iterator(sess, multi_worker_iterator, devices, expected_values)
def _cpu_devices(self):
worker_devices = (
("/job:worker/replica:0/task:0",
["/job:worker/replica:0/task:0/device:CPU:0"]),
("/job:worker/replica:0/task:1",
["/job:worker/replica:0/task:1/device:CPU:0"])
)
devices = [
"/job:worker/replica:0/task:0/device:CPU:0",
"/job:worker/replica:0/task:1/device:CPU:0"
]
return worker_devices, devices
def _cpu_and_one_gpu_devices(self):
worker_devices = (
("/job:worker/replica:0/task:0", (
"/job:worker/replica:0/task:0/device:GPU:0",
"/job:worker/replica:0/task:0/device:CPU:0"
)),
("/job:worker/replica:0/task:1", (
"/job:worker/replica:0/task:1/device:GPU:0",
"/job:worker/replica:0/task:1/device:CPU:0"
))
)
devices = [
"/job:worker/replica:0/task:0/device:GPU:0",
"/job:worker/replica:0/task:0/device:CPU:0",
"/job:worker/replica:0/task:1/device:GPU:0",
"/job:worker/replica:0/task:1/device:CPU:0"
]
return worker_devices, devices
def testDataDistributionOneDevicePerWorker(self):
worker_devices, devices = self._cpu_devices()
with context.graph_mode():
dataset_fn = lambda: dataset_ops.Dataset.range(8)
self._test_dataset(
dataset_fn, worker_devices, devices,
[[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]])
def testDataDistributionTwoDevicePerWorker(self):
if context.num_gpus() < 1:
self.skipTest("A GPU is not available for this test.")
worker_devices, devices = self._cpu_and_one_gpu_devices()
with context.graph_mode():
dataset_fn = lambda: dataset_ops.Dataset.range(8)
self._test_dataset(
dataset_fn, worker_devices, devices,
[[0, 1, 0, 1], [2, 3, 2, 3], [4, 5, 4, 5], [6, 7, 6, 7]])
def testTupleDataset(self):
worker_devices, devices = self._cpu_devices()
with context.graph_mode():
def dataset_fn():
dataset1 = dataset_ops.Dataset.range(8)
dataset2 = dataset_ops.Dataset.range(8).map(lambda x: x**2)
return dataset_ops.Dataset.zip((dataset1, dataset2))
expected_values = [[(i, i**2), (i, i**2)] for i in range(8)]
self._test_dataset(dataset_fn, worker_devices, devices,
expected_values)
def testInitializableIterator(self):
worker_devices, devices = self._cpu_devices()
with context.graph_mode(), self.cached_session() as sess:
dataset_fn = lambda: dataset_ops.Dataset.range(8)
device_map = values.ReplicaDeviceMap(devices)
input_workers = values.InputWorkers(device_map, worker_devices)
multi_worker_dataset = values.MultiWorkerDataset(
dataset_fn, input_workers)
multi_worker_iterator = multi_worker_dataset.make_initializable_iterator()
sess.run(multi_worker_iterator.initializer)
self._test_iterator(
sess, multi_worker_iterator, devices,
[[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]])
# After re-initializing the iterator, should be able to iterate again.
sess.run(multi_worker_iterator.initializer)
self._test_iterator(
sess, multi_worker_iterator, devices,
[[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]])
def testValueErrorForIterator(self):
# Incompatible arguments.
d1 = "/device:GPU:0"
d2 = "/device:GPU:1"
device_map = values.ReplicaDeviceMap([d1, d2])
input_workers = values.InputWorkers(
device_map, (("w1", (d1,)), ("w2", (d2,))))
with self.assertRaises(ValueError):
values.MultiWorkerDataIterator([("w1", None)], input_workers)
def testDuplicateDevices(self):
_, devices = self._cpu_devices()
devices.append("/job:worker/replica:0/task:0/device:CPU:0")
with self.assertRaises(ValueError):
_ = values.ReplicaDeviceMap(devices)
class InputIteratorTestBase(test.TestCase):
def _test_iterator(self, input_type, dataset_fn, worker_device_pairs,
expected_values, sess=None, split_batch_by=None):
devices = nest.flatten([ds for _, ds in worker_device_pairs])
device_map = values.ReplicaDeviceMap(devices)
input_workers = values.InputWorkers(device_map, worker_device_pairs)
if input_type == "input_fn":
input_contexts = [
distribute_lib.InputContext() for _ in worker_device_pairs]
input_fn = lambda _: dataset_fn()
iterator = values.InputFunctionIterator(
input_fn, input_workers, input_contexts)
else:
iterator = values.DatasetIterator(
dataset_fn(), input_workers, split_batch_by)
evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
evaluate(control_flow_ops.group(iterator.initialize()))
for expected_value in expected_values:
next_element = iterator.get_next()
computed_value = evaluate(
[values.select_replica(r, next_element) for r in range(len(devices))])
self.assertAllEqual(expected_value, computed_value)
with self.assertRaises(errors.OutOfRangeError):
next_element = iterator.get_next()
evaluate([values.select_replica(r, next_element)
for r in range(len(devices))])
# After re-initializing the iterator, should be able to iterate again.
evaluate(control_flow_ops.group(iterator.initialize()))
for expected_value in expected_values:
next_element = iterator.get_next()
computed_value = evaluate(
[values.select_replica(r, next_element) for r in range(len(devices))])
self.assertAllEqual(expected_value, computed_value)
class InputIteratorSingleWorkerTest(InputIteratorTestBase,
parameterized.TestCase):
@combinations.generate(combinations.combine(
mode=["graph", "eager"],
input_type=["input_fn", "dataset"]))
def testOneDeviceCPU(self, input_type):
worker_device_pairs = [("", ["/device:CPU:0"])]
dataset_fn = lambda: dataset_ops.Dataset.range(10)
expected_values = [[i] for i in range(10)]
self._test_iterator(input_type, dataset_fn, worker_device_pairs,
expected_values)
@combinations.generate(combinations.combine(
mode=["graph", "eager"],
input_type=["input_fn", "dataset"],
required_gpus=1))
def testTwoDevicesOneGPUOneCPU(self, input_type):
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
dataset_fn = lambda: dataset_ops.Dataset.range(10)
expected_values = [[i, i+1] for i in range(0, 10, 2)]
self._test_iterator(input_type, dataset_fn, worker_device_pairs,
expected_values)
@combinations.generate(combinations.combine(
mode=["graph", "eager"],
input_type=["input_fn", "dataset"],
required_gpus=1))
def testTupleDataset(self, input_type):
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
def dataset_fn():
dataset1 = dataset_ops.Dataset.range(10)
dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
return dataset_ops.Dataset.zip((dataset1, dataset2))
expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)]
self._test_iterator(input_type, dataset_fn, worker_device_pairs,
expected_values)
@combinations.generate(combinations.combine(
mode=["graph", "eager"],
input_type=["input_fn", "dataset"],
required_gpus=1))
def testUnevenDatasetBatches(self, input_type):
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
dataset_fn = lambda: dataset_ops.Dataset.range(11)
expected_values = [[i, i+1] for i in range(0, 10, 2)]
self._test_iterator(input_type, dataset_fn, worker_device_pairs,
expected_values)
@combinations.generate(combinations.combine(
mode=["graph", "eager"],
input_type=["dataset"],
split_batch_by=[None, 2],
required_gpus=1))
def testBatchSplitting(self, input_type, split_batch_by):
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
batch_size = 10
dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size)
updated_batch_size = (
batch_size // split_batch_by if split_batch_by else batch_size)
expected_values = [[range(i, i+updated_batch_size),
range(i+updated_batch_size, i+2*updated_batch_size)]
for i in range(0, 100, updated_batch_size*2)]
self._test_iterator(input_type, dataset_fn, worker_device_pairs,
expected_values, sess=None,
split_batch_by=split_batch_by)
class InputIteratorMultiWorkerTest(
multi_worker_test_base.MultiWorkerTestBase, InputIteratorTestBase,
parameterized.TestCase):
def _cpu_devices(self):
return [
("/job:worker/replica:0/task:0",
["/job:worker/replica:0/task:0/device:CPU:0"]),
("/job:worker/replica:0/task:1",
["/job:worker/replica:0/task:1/device:CPU:0"])]
def _cpu_and_one_gpu_devices(self):
return [
("/job:worker/replica:0/task:0", [
"/job:worker/replica:0/task:0/device:GPU:0",
"/job:worker/replica:0/task:0/device:CPU:0"
]),
("/job:worker/replica:0/task:1", [
"/job:worker/replica:0/task:1/device:GPU:0",
"/job:worker/replica:0/task:1/device:CPU:0"
])
]
@combinations.generate(combinations.combine(
mode=["graph"],
input_type=["input_fn", "dataset"]))
def testOneDevicePerWorker(self, input_type):
worker_devices = self._cpu_devices()
with context.graph_mode(), self.cached_session() as sess:
dataset_fn = lambda: dataset_ops.Dataset.range(4)
self._test_iterator(input_type, dataset_fn, worker_devices,
[[0, 0], [1, 1], [2, 2], [3, 3]], sess)
@combinations.generate(combinations.combine(
mode=["graph"],
input_type=["input_fn", "dataset"],
required_gpus=1))
def testTwoDevicesPerWorker(self, input_type):
worker_devices = self._cpu_and_one_gpu_devices()
with context.graph_mode(), self.cached_session() as sess:
dataset_fn = lambda: dataset_ops.Dataset.range(4)
self._test_iterator(input_type, dataset_fn, worker_devices,
[[0, 1, 0, 1], [2, 3, 2, 3]], sess)
@combinations.generate(combinations.combine(
mode=["graph"],
input_type=["input_fn", "dataset"]))
def testTupleDataset(self, input_type):
worker_devices = self._cpu_devices()
with context.graph_mode(), self.cached_session() as sess:
def dataset_fn():
dataset1 = dataset_ops.Dataset.range(4)
dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2)
return dataset_ops.Dataset.zip((dataset1, dataset2))
expected_values = [[(i, i**2), (i, i**2)] for i in range(0, 4)]
self._test_iterator(input_type, dataset_fn, worker_devices,
expected_values, sess)
class SplitDatasetBatchTest(test.TestCase):
def testBatchDataset(self):
dataset = dataset_ops.Dataset.range(100).batch(20)
split_batch_by = 2
result_dataset = values._split_dataset_batch(dataset, split_batch_by)
expected_values = [range(i, i+10) for i in range(0, 100, 10)]
result = [self.evaluate(el) for el in result_dataset]
self.assertAllEqual(expected_values, result)
def testMapAndBatchDataset(self):
dataset = dataset_ops.Dataset.range(100)
dataset = dataset.apply(batching.map_and_batch(lambda x: x, 20))
split_batch_by = 2
result_dataset = values._split_dataset_batch(dataset, split_batch_by)
expected_values = [range(i, i+10) for i in range(0, 100, 10)]
result = [self.evaluate(el) for el in result_dataset]
self.assertAllEqual(expected_values, result)
def testPrefetchDataset(self):
dataset = dataset_ops.Dataset.range(100).batch(20).prefetch(1)
split_batch_by = 2
result_dataset = values._split_dataset_batch(dataset, split_batch_by)
expected_values = [range(i, i+10) for i in range(0, 100, 10)]
result = [self.evaluate(el) for el in result_dataset]
self.assertAllEqual(expected_values, result)
class MirroredVariableTest(test.TestCase, parameterized.TestCase):
config = config_pb2.ConfigProto()


@ -219,6 +219,7 @@ py_library(
":cross_device_ops",
":device_util",
":distribute_lib",
":input_lib",
":multi_worker_util",
":reduce_util",
":shared_variable_creator",
@ -253,6 +254,23 @@ py_library(
],
)
py_library(
name = "input_lib",
srcs = ["input_lib.py"],
deps = [
":device_util",
":distribute_lib",
":input_ops",
":values",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:multi_device_iterator_ops",
"//tensorflow/python/eager:context",
],
)
py_library(
name = "input_ops",
srcs = ["input_ops.py"],
@ -348,14 +366,12 @@ py_library(
deps = [
":device_util",
":distribute_lib",
":input_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:multi_device_iterator_ops",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
"@six_archive//:six",


@ -0,0 +1,707 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed inputs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_ops
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.util import nest
class InputWorkers(object):
"""A 1-to-many mapping from input worker devices to compute devices."""
def __init__(self, device_map, worker_device_pairs=None, logical_device=0):
"""Initialize an `InputWorkers` object.
Args:
device_map: A `DeviceMap` with the computation devices fed by the
input workers.
worker_device_pairs: A sequence of pairs:
`(input device, a tuple of compute devices fed by that input device)`.
logical_device: The logical device of `device_map` to feed.
"""
self._device_map = device_map
self._logical_device = logical_device
if worker_device_pairs is None:
worker_device_pairs = ((
device_util.canonicalize("/device:CPU:0"),
device_map.logical_to_actual_devices(logical_device)),)
self._input_worker_devices = tuple(d for d, _ in worker_device_pairs)
self._fed_devices = tuple(tuple(device_util.canonicalize(d) for d in f)
for _, f in worker_device_pairs)
flattened = tuple(d for l in self._fed_devices for d in l)
assert (flattened ==
device_map.logical_to_actual_devices(logical_device)), (
"flattened: %s logical device %d: %s" %
(flattened, logical_device,
device_map.logical_to_actual_devices(logical_device)))
@property
def device_map(self):
return self._device_map
@property
def logical_device(self):
return self._logical_device
@property
def num_workers(self):
return len(self._input_worker_devices)
@property
def worker_devices(self):
return self._input_worker_devices
def compute_devices_for_worker(self, worker_index):
return self._fed_devices[worker_index]
def __repr__(self):
devices = self.worker_devices
debug_repr = ",\n".join(" %d %s: %s" %
(i, devices[i], self._fed_devices[i])
for i in range(len(devices)))
return "%s:{\n%s\n device_map: %s}" % (
self.__class__.__name__, debug_repr, self._device_map)
class PerReplicaDataIterator(object):
"""An iterator (like `tf.data.Iterator`) into a `PerReplicaDataset`."""
def __init__(self, iterator, input_workers, worker_index, prefetch_on_device):
assert isinstance(input_workers, InputWorkers)
self._iterator = iterator
self._input_workers = input_workers
self._worker_index = worker_index
self._prefetch_on_device = prefetch_on_device
@property
def initializer(self):
return self._iterator.initializer
def get_next_as_list(self, name=None):
"""Scatter the input across devices."""
if self._prefetch_on_device:
data_list = self._iterator.get_next()
else:
batch = self._iterator.get_next(name=name)
data_list = []
def get_ith(i):
return lambda x: x[i]
devices = self._input_workers.compute_devices_for_worker(
self._worker_index)
for i, d in enumerate(devices):
v = nest.map_structure(get_ith(i), batch)
if context.executing_eagerly():
with ops.device(d):
v = nest.map_structure(array_ops.identity, v)
data_list.append(v)
return data_list
def get_next(self, name=None):
assert self._input_workers.num_workers == 1
data_list = self.get_next_as_list(name)
return values.regroup(self._input_workers.device_map, data_list)
@property
def output_classes(self):
return self._iterator.output_classes
@property
def output_shapes(self):
return self._iterator.output_shapes
@property
def output_types(self):
return self._iterator.output_types
class PerReplicaDataset(object):
"""Like `tf.data.Dataset` split devices, producing `PerReplica` data."""
def __init__(self, dataset, input_workers, worker_index,
prefetch_on_device=None):
assert isinstance(input_workers, InputWorkers)
assert worker_index is not None
assert worker_index is not True # pylint: disable=g-bool-id-comparison
assert worker_index is not False # pylint: disable=g-bool-id-comparison
self._input_workers = input_workers
self._worker_index = worker_index
# Default to using prefetching, unless specified.
self._prefetch_on_device = prefetch_on_device
if self._prefetch_on_device is None:
self._prefetch_on_device = True
self._dataset = dataset
if not self._prefetch_on_device:
# TODO(priyag): If dropping remainder is not appropriate, find another
# approach to distributing the dataset when not possible to divide evenly.
# Possibly not an issue when we start using PartitionedDataset.
num_replicas = len(
self._input_workers.compute_devices_for_worker(self._worker_index))
self._dataset = self._dataset.batch(num_replicas, drop_remainder=True)
else:
self._replica_devices = self._input_workers.compute_devices_for_worker(
self._worker_index)
def make_one_shot_iterator(self):
"""Get a one time use iterator for the distributed PerReplicaDataset."""
# Graph mode with one shot iterator is disabled.
if not context.executing_eagerly():
raise ValueError("Cannot create a one shot iterator. Please use "
"`make_initializable_iterator()` instead.")
if self._prefetch_on_device:
dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
self._dataset, self._replica_devices)
else:
dataset_iterator = dataset_ops.make_one_shot_iterator(self._dataset)
return PerReplicaDataIterator(
dataset_iterator,
self._input_workers,
self._worker_index,
prefetch_on_device=self._prefetch_on_device)
def make_initializable_iterator(self):
"""Get an initializable iterator for the distributed PerReplicaDataset."""
# Eager mode generates already initialized iterators. Hence we cannot create
# an initializable iterator.
if context.executing_eagerly():
raise ValueError("Cannot create initializable iterator in Eager mode. "
"Please use `make_one_shot_iterator` instead.")
if self._prefetch_on_device:
dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
self._dataset, self._replica_devices)
else:
dataset_iterator = dataset_ops.make_initializable_iterator(self._dataset)
return PerReplicaDataIterator(
dataset_iterator, self._input_workers, self._worker_index,
prefetch_on_device=self._prefetch_on_device)
class MultiWorkerDataIterator(object):
"""An iterator (like `tf.data.Iterator`) into a `MultiWorkerDataset`."""
def __init__(self, iterators, input_workers):
"""Initialize the `MultiWorkerDataIterator` object.
Args:
iterators: a list of `(worker, iterator)` pairs.
input_workers: an `InputWorkers` object.
Raises:
ValueError: if iterators and input_workers are not compatible.
"""
assert isinstance(input_workers, InputWorkers)
workers = tuple(d for d, _ in iterators)
if workers != input_workers.worker_devices:
raise ValueError("iterators and input_workers are not compatible. "
"iterator workers: %r input_workers devices: %r" %
(workers, input_workers.worker_devices))
self._iterators = tuple(i for _, i in iterators)
self._input_workers = input_workers
@property
def initializer(self):
return control_flow_ops.group(
tuple(iterator.initializer for iterator in self._iterators))
def get_iterator(self, worker):
for i, w in enumerate(self._input_workers.worker_devices):
if worker == w:
return self._iterators[i]
return None
@property
def output_shapes(self):
return self._iterators[0].output_shapes
@property
def output_types(self):
return self._iterators[0].output_types
def get_next(self, name=None):
"""Scatter the input across hosts and devices."""
replicas = []
for worker, iterator in zip(self._input_workers.worker_devices,
self._iterators):
if name is not None:
d = tf_device.DeviceSpec.from_string(worker)
new_name = "%s_%s_%d" % (name, d.job, d.task)
else:
new_name = None
with ops.device(worker):
data_per_worker = iterator.get_next_as_list(name=new_name)
# Append to replicas to get a flat list of values indexed by replica.
replicas.extend(data_per_worker)
return values.regroup(self._input_workers.device_map, replicas)
class MultiWorkerDataset(object):
"""Like a `tf.data.Dataset` that distributes data to different workers.
Each worker gets one shard of the input dataset. This currently does not work
in eager mode.
"""
def __init__(self, dataset_fn, input_workers, prefetch_on_device=None,
auto_shard=False):
"""Initialize the MultiWorkerDataset object.
Args:
dataset_fn: a function or a list of functions that returns a
`tf.data.Dataset`.
input_workers: an `InputWorkers` object.
prefetch_on_device: whether to prefetch to devices.
auto_shard: whether to auto-shard the dataset.
"""
assert isinstance(input_workers, InputWorkers)
if isinstance(dataset_fn, (list, tuple)):
if len(dataset_fn) != input_workers.num_workers:
raise ValueError("If `dataset_fn` is a list, it must have one entry "
"per worker")
# TODO(rohanj): b/120673685 to track re-enabling auto sharding.
if auto_shard:
raise ValueError("Currently autosharding is not supported.")
self._input_workers = input_workers
self._datasets = []
# TODO(yuefengz, priyag): support different set of jobs for input
# processing.
for i, worker in enumerate(input_workers.worker_devices):
with ops.device(worker):
if isinstance(dataset_fn, (list, tuple)):
worker_input = dataset_fn[i]()
else:
worker_input = dataset_fn()
dataset = PerReplicaDataset(worker_input, input_workers, i,
prefetch_on_device=prefetch_on_device)
self._datasets.append((worker, dataset))
def make_one_shot_iterator(self):
iterators = []
for worker, dataset in self._datasets:
with ops.device(worker):
iterators.append((worker, dataset_ops.make_one_shot_iterator(dataset)))
return MultiWorkerDataIterator(iterators, self._input_workers)
def make_initializable_iterator(self):
iterators = []
for worker, dataset in self._datasets:
with ops.device(worker):
iterators.append(
(worker, dataset_ops.make_initializable_iterator(dataset)))
return MultiWorkerDataIterator(iterators, self._input_workers)
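# Example (a graph-mode sketch; `input_workers` and `session` are placeholders,
# and `dataset_fn` may instead be a list with one entry per worker):
#
#   dataset_fn = lambda: dataset_ops.Dataset.range(16).batch(4)
#   multi_worker_dataset = MultiWorkerDataset(dataset_fn, input_workers)
#   iterator = multi_worker_dataset.make_initializable_iterator()
#   session.run(iterator.initializer)
#   per_replica_batch = iterator.get_next()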
class InputIterator(object):
"""An input iterator, intended to be passed to `DistributionStrategy.run`."""
def get_next(self):
"""Returns the next inputs for all replicas."""
raise NotImplementedError("must be implemented in descendants")
def initialize(self):
"""Initialize the underlying input dataset, when applicable.
In eager mode, this will create a new iterator and return it.
In graph mode, this will initialize the same underlying iterator(s).
Users are required to call this if
- This iterator was returned from a call to `make_input_fn_iterator` with an
input function that returns a dataset.
- Or this iterator was returned from a call to `make_dataset_iterator`.
Returns:
A list of initialization ops to be executed.
"""
raise NotImplementedError("must be implemented in descendants")
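# Intended calling convention (a sketch; `strategy`, `dataset` and `session`
# are placeholders). In graph mode the returned ops must be run before the
# first `get_next` call:
#
#   iterator = strategy.make_dataset_iterator(dataset)
#   session.run(iterator.initialize())
#   per_replica_batch = iterator.get_next()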
class InputIteratorImpl(InputIterator):
"""Common implementation for all input iterators."""
def __init__(self, input_workers, iterators):
assert isinstance(input_workers, InputWorkers)
if not input_workers.worker_devices:
raise ValueError("Should have at least one worker for input iterator.")
self._iterators = iterators
self._input_workers = input_workers
def get_next(self, name=None):
"""Returns the next input from the iterator for all replicas."""
replicas = []
for i, worker in enumerate(self._input_workers.worker_devices):
if name is not None:
d = tf_device.DeviceSpec.from_string(worker)
new_name = "%s_%s_%d" % (name, d.job, d.task)
else:
new_name = None
with ops.device(worker):
# Make `replicas` a flat list of values across all replicas.
replicas.extend(self._iterators[i].get_next_as_list(new_name))
return values.regroup(self._input_workers.device_map, replicas)
def initialize(self):
"""Initialze underlying iterators.
Returns:
A list of any initializer ops that should be run.
"""
init_ops = []
for it in self._iterators:
init_ops.extend(it.initialize())
return init_ops
# TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
@property
def output_classes(self):
return self._iterators[0].output_classes
# TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
@property
def output_shapes(self):
return self._iterators[0].output_shapes
# TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
@property
def output_types(self):
return self._iterators[0].output_types
# TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
def get_iterator(self, worker):
for i, w in enumerate(self._input_workers.worker_devices):
if worker == w:
return self._iterators[i]
return None
class InputFunctionIterator(InputIteratorImpl):
"""Iterator created from input function."""
def __init__(self, input_fn, input_workers, input_contexts):
"""Make an iterator for input provided via an input function.
Currently implements PER_WORKER mode, in which the `input_fn` is called
once on each worker.
TODO(priyag): Add other replication modes.
TODO(priyag): Allow taking input function that returns a callable that
returns nest of tensors.
Args:
input_fn: Input function that returns a `tf.data.Dataset` object.
input_workers: an `InputWorkers` object.
input_contexts: A list of `InputContext` instances to be passed to call(s)
to `input_fn`. Length and order should match worker order in
`worker_device_pairs`.
"""
assert isinstance(input_workers, InputWorkers)
if input_workers.num_workers != len(input_contexts):
raise ValueError(
"Number of input workers (%d) is not same as number of "
"input_contexts (%d)" %
(input_workers.num_workers, len(input_contexts)))
iterators = []
for i, ctx in enumerate(input_contexts):
worker = input_workers.worker_devices[i]
with ops.device(worker):
result = input_fn(ctx)
if not isinstance(result, dataset_ops.DatasetV2):
raise ValueError("input_fn must return a tf.data.Dataset.")
devices = input_workers.compute_devices_for_worker(i)
iterator = _SingleWorkerDatasetIterator(result, worker, devices)
iterators.append(iterator)
super(InputFunctionIterator, self).__init__(input_workers, iterators)
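# Example `input_fn` (a sketch): the `InputContext` passed to each per-worker
# call can be used to shard the data explicitly across input pipelines:
#
#   def input_fn(ctx):
#     d = dataset_ops.Dataset.range(100)
#     d = d.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
#     return d.batch(4)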
class DatasetIterator(InputIteratorImpl):
"""Iterator created from input dataset."""
def __init__(self, dataset, input_workers, split_batch_by=None):
"""Make an iterator for the dataset on given devices.
If `split_batch_by` is not None, we "split" each batch of the
dataset by `split_batch_by` value. To achieve this, we first unbatch the
input dataset and then rebatch it with the per replica batch size that is
calculated using `global_batch_size // split_batch_by`.
The currently supported datasets are as follows:
`dataset.batch()` is the last operation on the dataset OR
`dataset.apply(map_and_batch)` is the last operation on the dataset OR
`dataset.batch().prefetch()` are the last 2 operations on the dataset OR
`dataset.apply(map_and_batch).prefetch()` are the last 2 operations.
TODO(priyag): Support multi worker / host cases properly by cloning
and sharding the dataset on each worker. Current setup will only work in
some cases, such as in-graph multi worker GPU case. If the input pipeline
has random shuffling (with a different seed on each worker), each worker
will see random input from the same overall dataset in each step. Otherwise,
each worker will see the same input in each step.
Args:
dataset: `tf.data.Dataset` that will be used as the input source.
input_workers: an `InputWorkers` object.
split_batch_by: Optional integer. If present, we "split" each batch of the
dataset by `split_batch_by` value.
"""
assert isinstance(input_workers, InputWorkers)
if split_batch_by:
dataset = _split_dataset_batch(dataset, split_batch_by)
iterators = []
for i, worker in enumerate(input_workers.worker_devices):
with ops.device(worker):
worker_devices = input_workers.compute_devices_for_worker(i)
cloned_dataset = dataset
if not context.executing_eagerly():
cloned_dataset = input_ops._clone_dataset(dataset) # pylint: disable=protected-access
iterator = _SingleWorkerDatasetIterator(cloned_dataset, worker,
worker_devices)
iterators.append(iterator)
super(DatasetIterator, self).__init__(input_workers, iterators)
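# Example (a sketch with two replicas in sync): the dataset is batched with the
# global batch size and `split_batch_by` rebatches it per replica:
#
#   dataset = dataset_ops.Dataset.range(64).batch(8)  # global batch size 8
#   iterator = DatasetIterator(dataset, input_workers, split_batch_by=2)
#   # each replica now receives batches of size 8 // 2 = 4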
class _SingleWorkerDatasetIterator(object):
"""Iterator for a single `tf.data.Dataset`."""
def __init__(self, dataset, worker, devices):
"""Create iterator for the `dataset` to fetch data to worker's `devices` .
`MultiDeviceIterator` is used to prefetch input to the devices on the
given worker.
Args:
dataset: A `tf.data.Dataset` instance.
worker: Worker on which ops should be created.
devices: Distribute data from `dataset` to these devices.
"""
self._dataset = dataset
self._worker = worker
self._devices = devices
self._make_iterator()
def _make_iterator(self):
"""Make appropriate iterator on the dataset."""
with ops.device(self._worker):
self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
self._dataset, self._devices)
def get_next_as_list(self, name=None):
"""Get next element from the underlying iterator."""
del name
with ops.device(self._worker):
data_list = self._iterator.get_next()
return data_list
def initialize(self):
"""Initialze underlying iterator.
In eager execution, this simply recreates the underlying iterator.
In graph execution, it returns the initializer ops for the underlying
iterator.
Returns:
A list of any initializer ops that should be run.
"""
if context.executing_eagerly():
self._make_iterator()
return []
else:
return [self._iterator.initializer]
@property
def output_classes(self):
return self._iterator.output_classes
@property
def output_shapes(self):
return self._iterator.output_shapes
@property
def output_types(self):
return self._iterator.output_types
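# Note: `MultiDeviceIterator.get_next()` called without a target device returns
# one element per entry in `devices`, so `get_next_as_list` above already
# yields a flat per-device list for callers to regroup into per-replica values.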
def _split_dataset_batch(dataset, split_batch_by):
"""Divide a batch-ed dataset's batches into smaller batches."""
# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
# pylint: disable=protected-access
def _get_batch_dataset(d):
"""Get the underlying batch dataset from the dataset object."""
if isinstance(d, dataset_ops.DatasetV1Adapter):
d = d._dataset
if isinstance(d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset)):
return d
elif isinstance(d, dataset_ops.PrefetchDataset):
return _get_batch_dataset(d._input_dataset)
raise ValueError(
"Unable to get batched dataset from the input dataset. `batch` "
"`map_and_batch` need to be the last operations on the dataset. "
"The batch operations can be followed by a prefetch.")
batched_dataset = _get_batch_dataset(dataset)
if isinstance(batched_dataset, dataset_ops.BatchDataset):
batch_size = batched_dataset._batch_size
drop_remainder = batched_dataset._drop_remainder
elif isinstance(batched_dataset, batching._MapAndBatchDataset):
batch_size = batched_dataset._batch_size_t
drop_remainder = batched_dataset._drop_remainder_t
prefetch_buffer = None
if isinstance(dataset, dataset_ops.PrefetchDataset):
prefetch_buffer = dataset._buffer_size
elif (isinstance(dataset, dataset_ops.DatasetV1Adapter)
and isinstance(dataset._dataset, dataset_ops.PrefetchDataset)):
prefetch_buffer = dataset._dataset._buffer_size
# pylint: enable=protected-access
if tensor_util.is_tensor(batch_size):
batch_size = tensor_util.constant_value(batch_size)
if tensor_util.is_tensor(drop_remainder):
drop_remainder = tensor_util.constant_value(drop_remainder)
if batch_size % split_batch_by:
raise ValueError(
"Batch size %s cannot be sharded evenly across replicas %s" % (
batch_size, split_batch_by))
new_batch_size = batch_size // split_batch_by
dataset = dataset.apply(batching.unbatch())
dataset = dataset.batch(new_batch_size, drop_remainder=drop_remainder)
if prefetch_buffer is not None:
dataset = dataset.prefetch(prefetch_buffer)
return dataset
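# Example (a sketch): for a pipeline ending in `.batch(32).prefetch(2)`,
# `_split_dataset_batch(dataset, 4)` yields the equivalent of
#
#   dataset.apply(batching.unbatch()).batch(8, drop_remainder=False).prefetch(2)
#
# i.e. the same elements rebatched to the per-replica size 32 // 4 = 8, with
# the original prefetch buffer preserved.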
class MultiStepContext(object):
"""A context object that can be used to capture things when running steps.
This context object is useful when running multiple steps at a time using the
  `experimental_run_steps_on_iterator` API. For example, it allows the user's step
function to specify which outputs to emit at what frequency. Currently it
supports capturing output from the last step, as well as capturing non tensor
outputs. In the future it will be augmented to support other use cases such
as output each N steps.
"""
def __init__(self):
"""Initialize an output context.
Returns:
A context object.
"""
self._last_step_outputs = {}
self._last_step_outputs_reduce_ops = {}
self._non_tensor_outputs = {}
@property
def last_step_outputs(self):
"""A dictionary consisting of outputs to be captured on last step.
Keys in the dictionary are names of tensors to be captured, as specified
when `set_last_step_output` is called.
Values in the dictionary are the tensors themselves. If
`set_last_step_output` was called with a `reduce_op` for this output,
then the value is the reduced value.
Returns:
A dictionary with last step outputs.
"""
return self._last_step_outputs
def _set_last_step_outputs(self, outputs):
"""Replace the entire dictionary of last step outputs."""
if not isinstance(outputs, dict):
raise ValueError("Need a dictionary to set last_step_outputs.")
self._last_step_outputs = outputs
def set_last_step_output(self, name, output, reduce_op=None):
"""Set `output` with `name` to be outputted from the last step.
Args:
name: String, name to identify the output. Doesn't need to match tensor
name.
output: The tensors that should be outputted with `name`. See below for
actual types supported.
reduce_op: Reduction method to use to reduce outputs from multiple
replicas. Required if `set_last_step_output` is called in a replica
context. Optional in cross_replica_context.
When present, the outputs from all the replicas are reduced using the
current distribution strategy's `reduce` method. Hence, the type of
`output` must be what's supported by the corresponding `reduce` method.
        For example, if using MirroredStrategy and reduction is set, output
must be a `PerReplica` value.
        The reduce method is also recorded in the
        `_last_step_outputs_reduce_ops` dictionary so that it can later be
        determined whether the outputs were already reduced.
"""
if distribution_strategy_context.in_cross_replica_context():
self._last_step_outputs_reduce_ops[name] = reduce_op
if reduce_op is None:
self._last_step_outputs[name] = output
else:
distribution = distribution_strategy_context.get_distribution_strategy()
self._last_step_outputs[name] = distribution.reduce(reduce_op, output)
else:
assert reduce_op is not None
def merge_fn(distribution, value):
self._last_step_outputs[name] = distribution.reduce(reduce_op, value)
# Setting this inside the `merge_fn` because all replicas share the same
# context object, so it's more robust to set it only once (even if all
# the replicas are trying to set the same value).
self._last_step_outputs_reduce_ops[name] = reduce_op
distribution_strategy_context.get_replica_context().merge_call(
merge_fn, args=(output,))
@property
def non_tensor_outputs(self):
"""A dictionary consisting of any non tensor outputs to be captured."""
return self._non_tensor_outputs
def set_non_tensor_output(self, name, output):
"""Set `output` with `name` to be captured as a non tensor output."""
if distribution_strategy_context.in_cross_replica_context():
self._non_tensor_outputs[name] = output
else:
def merge_fn(distribution, value):
# NOTE(priyag): For non tensor outputs, we simply return all the values
# in a list as reduction doesn't make sense on non tensors.
self._non_tensor_outputs[name] = distribution.unwrap(value)
distribution_strategy_context.get_replica_context().merge_call(
merge_fn, args=(output,))
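# Example step function (a sketch for use with
# `experimental_run_steps_on_iterator`; `compute_loss` and `train_op` are
# hypothetical, and `reduce_util` is assumed to be imported from
# tensorflow.python.distribute):
#
#   def step_fn(ctx, inputs):
#     loss = compute_loss(inputs)  # hypothetical per-replica loss
#     ctx.set_last_step_output("loss", loss,
#                              reduce_op=reduce_util.ReduceOp.MEAN)
#     ctx.set_non_tensor_output("log_dir", "/tmp/run")  # illustrative only
#     return train_op  # hypothetical op(s) to run each step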

View File

@@ -27,6 +27,7 @@ from tensorflow.python import pywrap_tensorflow
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import shared_variable_creator
@@ -456,7 +457,7 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
"No duplicates allowed in `devices` argument: %s" % devices)
# TODO(josh11b): Require at least 2 devices?
self._device_map = values.ReplicaDeviceMap(devices)
self._input_workers = values.InputWorkers(self._device_map)
self._input_workers = input_lib.InputWorkers(self._device_map)
self._inferred_cross_device_ops = cross_device_ops_lib.choose_the_best(
devices)
@@ -489,7 +490,8 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
self._default_device = workers[0]
self._device_map = values.ReplicaDeviceMap(devices)
self._input_workers = values.InputWorkers(self._device_map, worker_devices)
self._input_workers = input_lib.InputWorkers(
self._device_map, worker_devices)
self._inferred_cross_device_ops = cross_device_ops_lib.MultiWorkerAllReduce(
workers, _infer_num_gpus_per_worker(devices))
@@ -543,16 +545,16 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
def _distribute_dataset(self, dataset_fn):
if self._local_mode:
worker_index = 0
return values.PerReplicaDataset(
return input_lib.PerReplicaDataset(
self._call_dataset_fn(dataset_fn), self._input_workers, worker_index)
else:
return values.MultiWorkerDataset(
return input_lib.MultiWorkerDataset(
functools.partial(self._call_dataset_fn, dataset_fn),
self._input_workers,
auto_shard=False)
def _make_dataset_iterator(self, dataset):
return values.DatasetIterator(
return input_lib.DatasetIterator(
dataset, self._input_workers, self._num_replicas_in_sync)
def _make_input_fn_iterator(
@@ -566,7 +568,7 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
num_input_pipelines=num_workers,
input_pipeline_id=i,
num_replicas_in_sync=self._num_replicas_in_sync))
return values.InputFunctionIterator(
return input_lib.InputFunctionIterator(
input_fn, self._input_workers, input_contexts)
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
@@ -576,7 +578,7 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
initial_loop_values = {}
initial_loop_values = nest.flatten(initial_loop_values)
ctx = values.MultiStepContext()
ctx = input_lib.MultiStepContext()
def body(i, *args):
"""A wrapper around `fn` to create the while loop body."""
del args

View File

@@ -23,17 +23,12 @@ import contextlib
import weakref
import six
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_ops
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
@@ -1409,679 +1404,6 @@ def update_regroup(extended, device_map, updates, group):
return nest.pack_sequence_as(regrouped, grouped_flat)
class InputWorkers(object):
"""A 1-to-many mapping from input worker devices to compute devices."""
def __init__(self, device_map, worker_device_pairs=None, logical_device=0):
"""Initialize an `InputWorkers` object.
Args:
device_map: A `DeviceMap` with the computation devices fed by the
input workers.
worker_device_pairs: A sequence of pairs:
`(input device, a tuple of compute devices fed by that input device)`.
logical_device: The logical device of `device_map` to feed.
"""
self._device_map = device_map
self._logical_device = logical_device
if worker_device_pairs is None:
worker_device_pairs = ((
device_util.canonicalize("/device:CPU:0"),
device_map.logical_to_actual_devices(logical_device)),)
self._input_worker_devices = tuple(d for d, _ in worker_device_pairs)
self._fed_devices = tuple(tuple(device_util.canonicalize(d) for d in f)
for _, f in worker_device_pairs)
flattened = tuple(d for l in self._fed_devices for d in l)
assert (flattened ==
device_map.logical_to_actual_devices(logical_device)), (
"flattened: %s logical device %d: %s" %
(flattened, logical_device,
device_map.logical_to_actual_devices(logical_device)))
@property
def device_map(self):
return self._device_map
@property
def logical_device(self):
return self._logical_device
@property
def num_workers(self):
return len(self._input_worker_devices)
@property
def worker_devices(self):
return self._input_worker_devices
def compute_devices_for_worker(self, worker_index):
return self._fed_devices[worker_index]
def __repr__(self):
devices = self.worker_devices
debug_repr = ",\n".join(" %d %s: %s" %
(i, devices[i], self._fed_devices[i])
for i in range(len(devices)))
return "%s:{\n%s\n device_map: %s}" % (
self.__class__.__name__, debug_repr, self._device_map)
class PerReplicaDataIterator(object):
"""An iterator (like `tf.data.Iterator`) into a `PerReplicaDataset`."""
def __init__(self, iterator, input_workers, worker_index, prefetch_on_device):
assert isinstance(input_workers, InputWorkers)
self._iterator = iterator
self._input_workers = input_workers
self._worker_index = worker_index
self._prefetch_on_device = prefetch_on_device
@property
def initializer(self):
return self._iterator.initializer
def get_next_as_list(self, name=None):
"""Scatter the input across devices."""
if self._prefetch_on_device:
data_list = self._iterator.get_next()
else:
batch = self._iterator.get_next(name=name)
data_list = []
def get_ith(i):
return lambda x: x[i]
devices = self._input_workers.compute_devices_for_worker(
self._worker_index)
for i, d in enumerate(devices):
v = nest.map_structure(get_ith(i), batch)
if context.executing_eagerly():
with ops.device(d):
v = nest.map_structure(array_ops.identity, v)
data_list.append(v)
return data_list
def get_next(self, name=None):
assert self._input_workers.num_workers == 1
data_list = self.get_next_as_list(name)
return regroup(self._input_workers.device_map, data_list)
@property
def output_classes(self):
return self._iterator.output_classes
@property
def output_shapes(self):
return self._iterator.output_shapes
@property
def output_types(self):
return self._iterator.output_types
class PerReplicaDataset(object):
"""Like `tf.data.Dataset` split devices, producing `PerReplica` data."""
def __init__(self, dataset, input_workers, worker_index,
prefetch_on_device=None):
assert isinstance(input_workers, InputWorkers)
assert worker_index is not None
assert worker_index is not True
assert worker_index is not False
self._input_workers = input_workers
self._worker_index = worker_index
# Default to using prefetching, unless specified.
self._prefetch_on_device = prefetch_on_device
if self._prefetch_on_device is None:
self._prefetch_on_device = True
self._dataset = dataset
if not self._prefetch_on_device:
# TODO(priyag): If dropping remainder is not appropriate, find another
# approach to distributing the dataset when not possible to divide evenly.
# Possibly not an issue when we start using PartitionedDataset.
num_replicas = len(
self._input_workers.compute_devices_for_worker(self._worker_index))
self._dataset = self._dataset.batch(num_replicas, drop_remainder=True)
else:
self._replica_devices = self._input_workers.compute_devices_for_worker(
self._worker_index)
def make_one_shot_iterator(self):
"""Get a one time use iterator for the distributed PerReplicaDataset."""
# Graph mode with one shot iterator is disabled.
if not context.executing_eagerly():
raise ValueError("Cannot create a one shot iterator. Please use "
"`make_initializable_iterator()` instead.")
if self._prefetch_on_device:
dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
self._dataset, self._replica_devices)
else:
dataset_iterator = dataset_ops.make_one_shot_iterator(self._dataset)
return PerReplicaDataIterator(
dataset_iterator,
self._input_workers,
self._worker_index,
prefetch_on_device=self._prefetch_on_device)
def make_initializable_iterator(self):
"""Get an initializable iterator for the distributed PerReplicaDataset."""
# Eager mode generates already initialized iterators. Hence we cannot create
# an initializable iterator.
if context.executing_eagerly():
raise ValueError("Cannot create initializable iterator in Eager mode. "
"Please use `make_one_shot_iterator` instead.")
if self._prefetch_on_device:
dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
self._dataset, self._replica_devices)
else:
dataset_iterator = dataset_ops.make_initializable_iterator(self._dataset)
return PerReplicaDataIterator(
dataset_iterator, self._input_workers, self._worker_index,
prefetch_on_device=self._prefetch_on_device)
class MultiWorkerDataIterator(object):
"""An iterator (like `tf.data.Iterator`) into a `MultiWorkerDataset`."""
def __init__(self, iterators, input_workers):
"""Initialize the `MultiWorkerDataIterator` object.
Args:
      iterators: a list of (worker, iterator) pairs.
input_workers: an `InputWorkers` object.
Raises:
ValueError: if iterators and input_workers are not compatible.
"""
assert isinstance(input_workers, InputWorkers)
workers = tuple(d for d, _ in iterators)
if workers != input_workers.worker_devices:
raise ValueError("iterators and input_workers are not compatible. "
"iterator workers: %r input_workers devices: %r" %
(workers, input_workers.worker_devices))
self._iterators = tuple(i for _, i in iterators)
self._input_workers = input_workers
@property
def initializer(self):
return control_flow_ops.group(
tuple(iterator.initializer for iterator in self._iterators))
def get_iterator(self, worker):
for i, w in enumerate(self._input_workers.worker_devices):
if worker == w:
return self._iterators[i]
return None
@property
def output_shapes(self):
return self._iterators[0].output_shapes
@property
def output_types(self):
return self._iterators[0].output_types
def get_next(self, name=None):
"""Scatter the input across hosts and devices."""
replicas = []
for worker, iterator in zip(self._input_workers.worker_devices,
self._iterators):
if name is not None:
d = tf_device.DeviceSpec.from_string(worker)
new_name = "%s_%s_%d" % (name, d.job, d.task)
else:
new_name = None
with ops.device(worker):
data_per_worker = iterator.get_next_as_list(name=new_name)
# Append to replicas to get a flat list of values indexed by replica.
replicas.extend(data_per_worker)
return regroup(self._input_workers.device_map, replicas)
class MultiWorkerDataset(object):
"""Like a `tf.data.Dataset` that distributes data to different workers.
Each worker gets one shard of the input dataset. This currently does not work
in eager mode.
"""
def __init__(self, dataset_fn, input_workers, prefetch_on_device=None,
auto_shard=False):
"""Initialize the MultiWorkerDataset object.
Args:
dataset_fn: a function or a list of functions that returns a
`tf.data.Dataset`.
input_workers: an `InputWorkers` object.
prefetch_on_device: whether to prefetch to devices.
auto_shard: whether to auto-shard the dataset.
"""
assert isinstance(input_workers, InputWorkers)
if isinstance(dataset_fn, (list, tuple)):
if len(dataset_fn) != input_workers.num_workers:
raise ValueError("If `dataset_fn` is a list, it must have one entry "
"per worker")
# TODO(rohanj): b/120673685 to track re-enabling auto sharding.
if auto_shard:
raise ValueError("Currently autosharding is not supported.")
self._input_workers = input_workers
self._datasets = []
# TODO(yuefengz, priyag): support different set of jobs for input
# processing.
for i, worker in enumerate(input_workers.worker_devices):
with ops.device(worker):
if isinstance(dataset_fn, (list, tuple)):
worker_input = dataset_fn[i]()
else:
worker_input = dataset_fn()
dataset = PerReplicaDataset(worker_input, input_workers, i,
prefetch_on_device=prefetch_on_device)
self._datasets.append((worker, dataset))
def make_one_shot_iterator(self):
iterators = []
for worker, dataset in self._datasets:
with ops.device(worker):
iterators.append((worker, dataset_ops.make_one_shot_iterator(dataset)))
return MultiWorkerDataIterator(iterators, self._input_workers)
def make_initializable_iterator(self):
iterators = []
for worker, dataset in self._datasets:
with ops.device(worker):
iterators.append(
(worker, dataset_ops.make_initializable_iterator(dataset)))
return MultiWorkerDataIterator(iterators, self._input_workers)
class InputIterator(object):
"""An input iterator, intended to be passed to `DistributionStrategy.run`."""
def get_next(self):
"""Returns the next inputs for all replicas."""
raise NotImplementedError("must be implemented in descendants")
def initialize(self):
"""Initialize the underlying input dataset, when applicable.
In eager mode, this will create a new iterator and return it.
In graph mode, this will initialize the same underlying iterator(s).
Users are required to call this if
- This iterator was returned from a call to `make_input_fn_iterator` with an
input function that returns a dataset.
- Or this iterator was returned from a call to `make_dataset_iterator`.
Returns:
A list of initialization ops to be executed.
"""
raise NotImplementedError("must be implemented in descendants")
class InputIteratorImpl(InputIterator):
"""Common implementation for all input iterators."""
def __init__(self, input_workers, iterators):
assert isinstance(input_workers, InputWorkers)
if not input_workers.worker_devices:
raise ValueError("Should have at least one worker for input iterator.")
self._iterators = iterators
self._input_workers = input_workers
def get_next(self, name=None):
"""Returns the next input from the iterator for all replicas."""
replicas = []
for i, worker in enumerate(self._input_workers.worker_devices):
if name is not None:
d = tf_device.DeviceSpec.from_string(worker)
new_name = "%s_%s_%d" % (name, d.job, d.task)
else:
new_name = None
with ops.device(worker):
# Make `replicas` a flat list of values across all replicas.
replicas.extend(self._iterators[i].get_next_as_list(new_name))
return regroup(self._input_workers.device_map, replicas)
def initialize(self):
"""Initialze underlying iterators.
Returns:
A list of any initializer ops that should be run.
"""
init_ops = []
for it in self._iterators:
init_ops.extend(it.initialize())
return init_ops
# TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
@property
def output_classes(self):
return self._iterators[0].output_classes
# TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
@property
def output_shapes(self):
return self._iterators[0].output_shapes
# TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
@property
def output_types(self):
return self._iterators[0].output_types
# TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
def get_iterator(self, worker):
for i, w in enumerate(self._input_workers.worker_devices):
if worker == w:
return self._iterators[i]
return None
class InputFunctionIterator(InputIteratorImpl):
"""Iterator created from input function."""
def __init__(self, input_fn, input_workers, input_contexts):
"""Make an iterator for input provided via an input function.
Currently implements PER_WORKER mode, in which the `input_fn` is called
once on each worker.
TODO(priyag): Add other replication modes.
TODO(priyag): Allow taking input function that returns a callable that
returns nest of tensors.
Args:
input_fn: Input function that returns a `tf.data.Dataset` object.
input_workers: an `InputWorkers` object.
input_contexts: A list of `InputContext` instances to be passed to call(s)
to `input_fn`. Length and order should match worker order in
`worker_device_pairs`.
"""
assert isinstance(input_workers, InputWorkers)
if input_workers.num_workers != len(input_contexts):
raise ValueError(
"Number of input workers (%d) is not same as number of "
"input_contexts (%d)" %
(input_workers.num_workers, len(input_contexts)))
iterators = []
for i, ctx in enumerate(input_contexts):
worker = input_workers.worker_devices[i]
with ops.device(worker):
result = input_fn(ctx)
if not isinstance(result, dataset_ops.DatasetV2):
raise ValueError("input_fn must return a tf.data.Dataset.")
devices = input_workers.compute_devices_for_worker(i)
iterator = _SingleWorkerDatasetIterator(result, worker, devices)
iterators.append(iterator)
super(InputFunctionIterator, self).__init__(input_workers, iterators)
class DatasetIterator(InputIteratorImpl):
"""Iterator created from input dataset."""
def __init__(self, dataset, input_workers, split_batch_by=None):
"""Make an iterator for the dataset on given devices.
If `split_batch_by` is not None, we "split" each batch of the
dataset by `split_batch_by` value. To achieve this, we first unbatch the
input dataset and then rebatch it with the per replica batch size that is
calculated using `global_batch_size // split_batch_by`.
The currently supported datasets are as follows:
`dataset.batch()` is the last operation on the dataset OR
`dataset.apply(map_and_batch)` is the last operation on the dataset OR
`dataset.batch().prefetch()` are the last 2 operations on the dataset OR
`dataset.apply(map_and_batch).prefetch()` are the last 2 operations.
TODO(priyag): Support multi worker / host cases properly by cloning
and sharding the dataset on each worker. Current setup will only work in
some cases, such as in-graph multi worker GPU case. If the input pipeline
has random shuffling (with a different seed on each worker), each worker
will see random input from the same overall dataset in each step. Otherwise,
each worker will see the same input in each step.
Args:
dataset: `tf.data.Dataset` that will be used as the input source.
input_workers: an `InputWorkers` object.
split_batch_by: Optional integer. If present, we "split" each batch of the
dataset by `split_batch_by` value.
"""
assert isinstance(input_workers, InputWorkers)
if split_batch_by:
dataset = _split_dataset_batch(dataset, split_batch_by)
iterators = []
for i, worker in enumerate(input_workers.worker_devices):
with ops.device(worker):
worker_devices = input_workers.compute_devices_for_worker(i)
cloned_dataset = dataset
if not context.executing_eagerly():
cloned_dataset = input_ops._clone_dataset(dataset) # pylint: disable=protected-access
iterator = _SingleWorkerDatasetIterator(cloned_dataset, worker,
worker_devices)
iterators.append(iterator)
super(DatasetIterator, self).__init__(input_workers, iterators)
class _SingleWorkerDatasetIterator(object):
"""Iterator for a single `tf.data.Dataset`."""
def __init__(self, dataset, worker, devices):
"""Create iterator for the `dataset` to fetch data to worker's `devices` .
`MultiDeviceIterator` is used to prefetch input to the devices on the
given worker.
Args:
dataset: A `tf.data.Dataset` instance.
worker: Worker on which ops should be created.
devices: Distribute data from `dataset` to these devices.
"""
self._dataset = dataset
self._worker = worker
self._devices = devices
self._make_iterator()
def _make_iterator(self):
"""Make appropriate iterator on the dataset."""
with ops.device(self._worker):
self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
self._dataset, self._devices)
def get_next_as_list(self, name=None):
"""Get next element from the underlying iterator."""
del name
with ops.device(self._worker):
data_list = self._iterator.get_next()
return data_list
def initialize(self):
"""Initialze underlying iterator.
In eager execution, this simply recreates the underlying iterator.
In graph execution, it returns the initializer ops for the underlying
iterator.
Returns:
A list of any initializer ops that should be run.
"""
if context.executing_eagerly():
self._make_iterator()
return []
else:
return [self._iterator.initializer]
@property
def output_classes(self):
return self._iterator.output_classes
@property
def output_shapes(self):
return self._iterator.output_shapes
@property
def output_types(self):
return self._iterator.output_types
def _split_dataset_batch(dataset, split_batch_by):
"""Divide a batch-ed dataset's batches into smaller batches."""
# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
# pylint: disable=protected-access
def _get_batch_dataset(d):
"""Get the underlying batch dataset from the dataset object."""
if isinstance(d, dataset_ops.DatasetV1Adapter):
d = d._dataset
if isinstance(d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset)):
return d
elif isinstance(d, dataset_ops.PrefetchDataset):
return _get_batch_dataset(d._input_dataset)
raise ValueError(
"Unable to get batched dataset from the input dataset. `batch` "
"`map_and_batch` need to be the last operations on the dataset. "
"The batch operations can be followed by a prefetch.")
batched_dataset = _get_batch_dataset(dataset)
if isinstance(batched_dataset, dataset_ops.BatchDataset):
batch_size = batched_dataset._batch_size
drop_remainder = batched_dataset._drop_remainder
elif isinstance(batched_dataset, batching._MapAndBatchDataset):
batch_size = batched_dataset._batch_size_t
drop_remainder = batched_dataset._drop_remainder_t
prefetch_buffer = None
if isinstance(dataset, dataset_ops.PrefetchDataset):
prefetch_buffer = dataset._buffer_size
elif (isinstance(dataset, dataset_ops.DatasetV1Adapter)
and isinstance(dataset._dataset, dataset_ops.PrefetchDataset)):
prefetch_buffer = dataset._dataset._buffer_size
# pylint: enable=protected-access
if tensor_util.is_tensor(batch_size):
batch_size = tensor_util.constant_value(batch_size)
if tensor_util.is_tensor(drop_remainder):
drop_remainder = tensor_util.constant_value(drop_remainder)
if batch_size % split_batch_by:
raise ValueError(
"Batch size %s cannot be sharded evenly across replicas %s" % (
batch_size, split_batch_by))
new_batch_size = batch_size // split_batch_by
dataset = dataset.apply(batching.unbatch())
dataset = dataset.batch(new_batch_size, drop_remainder=drop_remainder)
if prefetch_buffer is not None:
dataset = dataset.prefetch(prefetch_buffer)
return dataset
class MultiStepContext(object):
"""A context object that can be used to capture things when running steps.
This context object is useful when running multiple steps at a time using the
  `experimental_run_steps_on_iterator` API. For example, it allows the user's step
function to specify which outputs to emit at what frequency. Currently it
supports capturing output from the last step, as well as capturing non tensor
outputs. In the future it will be augmented to support other use cases such
as output each N steps.
"""
def __init__(self):
"""Initialize an output context.
Returns:
A context object.
"""
self._last_step_outputs = {}
self._last_step_outputs_reduce_ops = {}
self._non_tensor_outputs = {}
@property
def last_step_outputs(self):
"""A dictionary consisting of outputs to be captured on last step.
Keys in the dictionary are names of tensors to be captured, as specified
when `set_last_step_output` is called.
Values in the dictionary are the tensors themselves. If
`set_last_step_output` was called with a `reduce_op` for this output,
then the value is the reduced value.
Returns:
A dictionary with last step outputs.
"""
return self._last_step_outputs
def _set_last_step_outputs(self, outputs):
"""Replace the entire dictionary of last step outputs."""
if not isinstance(outputs, dict):
raise ValueError("Need a dictionary to set last_step_outputs.")
self._last_step_outputs = outputs
def set_last_step_output(self, name, output, reduce_op=None):
"""Set `output` with `name` to be outputted from the last step.
Args:
name: String, name to identify the output. Doesn't need to match tensor
name.
output: The tensors that should be outputted with `name`. See below for
actual types supported.
reduce_op: Reduction method to use to reduce outputs from multiple
replicas. Required if `set_last_step_output` is called in a replica
context. Optional in cross_replica_context.
When present, the outputs from all the replicas are reduced using the
current distribution strategy's `reduce` method. Hence, the type of
`output` must be what's supported by the corresponding `reduce` method.
        For example, if using MirroredStrategy and reduction is set, output
must be a `PerReplica` value.
        The reduce method is also recorded in the
        `_last_step_outputs_reduce_ops` dictionary so that it can later be
        determined whether the outputs were already reduced.
"""
if distribution_strategy_context.in_cross_replica_context():
self._last_step_outputs_reduce_ops[name] = reduce_op
if reduce_op is None:
self._last_step_outputs[name] = output
else:
distribution = distribution_strategy_context.get_distribution_strategy()
self._last_step_outputs[name] = distribution.reduce(reduce_op, output)
else:
assert reduce_op is not None
def merge_fn(distribution, value):
self._last_step_outputs[name] = distribution.reduce(reduce_op, value)
# Setting this inside the `merge_fn` because all replicas share the same
# context object, so it's more robust to set it only once (even if all
# the replicas are trying to set the same value).
self._last_step_outputs_reduce_ops[name] = reduce_op
distribution_strategy_context.get_replica_context().merge_call(
merge_fn, args=(output,))
@property
def non_tensor_outputs(self):
"""A dictionary consisting of any non tensor outputs to be captured."""
return self._non_tensor_outputs
def set_non_tensor_output(self, name, output):
"""Set `output` with `name` to be captured as a non tensor output."""
if distribution_strategy_context.in_cross_replica_context():
self._non_tensor_outputs[name] = output
else:
def merge_fn(distribution, value):
# NOTE(priyag): For non tensor outputs, we simply return all the values
# in a list as reduction doesn't make sense on non tensors.
self._non_tensor_outputs[name] = distribution.unwrap(value)
distribution_strategy_context.get_replica_context().merge_call(
merge_fn, args=(output,))
def value_container(val):
"""Returns the container that this per-replica `value` belongs to.