Internal change

PiperOrigin-RevId: 191024677
Yuefeng Zhou 2018-03-29 19:56:47 -07:00 committed by TensorFlower Gardener
parent 566f9041e1
commit 9451c12d62
11 changed files with 587 additions and 127 deletions

@@ -360,8 +360,10 @@ py_library(
":cross_tower_utils",
":values",
"//tensorflow/python:array_ops",
"//tensorflow/python:device_lib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:training",
"//tensorflow/python/eager:context",
"@six_archive//:six",

@@ -22,6 +22,7 @@ import six
from tensorflow.contrib.distribute.python import cross_tower_utils
from tensorflow.contrib.distribute.python import values as value_lib
from tensorflow.python.client import device_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -231,7 +232,7 @@ class ReductionToOneDeviceCrossTowerOps(CrossTowerOps):
def _group_value_by_device(per_device_values):
"""Group values into sublists by their devices.
This grouping is needed to call the allreduce library.
This grouping is needed to call the all-reduce library.
Args:
per_device_values: a list of PerDevice objects.
@@ -251,10 +252,20 @@ def _group_value_by_device(per_device_values):
def _ungroup_and_make_mirrored(grouped_reduced, destinations, method_string):
"""Ungroup results from allreduce and make Mirrored objects.
"""Ungroup results from all-reduce and make Mirrored objects.
Each allreduce result would be divided by the number of destinations before
Each all-reduce result will be divided by the number of destinations before
Mirrored objects are created if method_string is "mean".
Args:
grouped_reduced: a list of lists, each sublist has components for each
device, paired with a None. It is the result from
cross_tower_utils.aggregate_gradients_using*.
destinations: a list of device strings for returned Mirrored objects.
method_string: "mean" or "sum".
Returns:
a list of Mirrored objects.
"""
index = [{} for _ in range(len(grouped_reduced[0]))]
for d, per_device_reduced in enumerate(grouped_reduced):
@@ -266,23 +277,171 @@ def _ungroup_and_make_mirrored(grouped_reduced, destinations, method_string):
return [value_lib.Mirrored(v) for v in index]
class ConcatAndSplitPacker(object):
"""Concatenate and split tensors for reduction."""
def __init__(self, num_packs=1):
"""Initialize the ConcatAndSplitPacker object.
Args:
num_packs: specifies the number of split packs that will be
formed.
Raises:
ValueError: if num_packs is not greater than 0.
"""
if num_packs <= 0:
raise ValueError("num_packs must be greater than zero.")
self.num_packs = num_packs
def pack(self, grouped_grads_and_vars):
"""Pack tensors."""
self.grouped_grads_and_vars = grouped_grads_and_vars
self.all_tower_shapes = []
self.all_tower_sizes = []
device_grad_packs = []
for tower_grads_and_vars in grouped_grads_and_vars:
with ops.colocate_with(tower_grads_and_vars[0][0]):
# Flatten all the grads.
flat_grads = [
array_ops.reshape(g, [-1]) for g, _ in tower_grads_and_vars
]
# Remember the original shape of all the grads.
tower_shapes = [array_ops.shape(g) for g, _ in tower_grads_and_vars]
# Remember the original sizes of all the grads.
tower_sizes = [array_ops.size(g) for g, _ in tower_grads_and_vars]
# Concat all the flat grads into a big flat tensor.
concat_grads = array_ops.concat(flat_grads, 0)
# Split the big tensor into num_splits packs. In cases where the
# total size is not divisible by num_splits, the last pack gets
# more elements.
# TODO(zhengxq): it is also possible to optimize away all the concat
# as well.
num_splits = self.num_packs
total_grad_size = array_ops.size(concat_grads)
split_size = total_grad_size // num_splits
split_size_last = total_grad_size - split_size * (num_splits - 1)
split_sizes = [split_size] * (num_splits - 1) + [split_size_last]
grad_packs = array_ops.split(concat_grads, split_sizes)
# Ready to aggregate the repacked gradients, with fake variables.
# TODO(zhengxq): It is hacky to have to use fake variables.
# We should remove the need for variables in
# aggregate_gradients_using*.
device_grad_packs.append(zip(grad_packs, [None] * num_splits))
self.all_tower_shapes.append(tower_shapes)
self.all_tower_sizes.append(tower_sizes)
return device_grad_packs
def unpack(self, summed_device_grad_packs):
"""Reverse the pack."""
aggregated_device_grads = []
for (summed_tower_grad_packs,
tower_grads_and_vars, tower_shapes, tower_sizes) in zip(
summed_device_grad_packs, self.grouped_grads_and_vars,
self.all_tower_shapes, self.all_tower_sizes):
# pylint: enable=line-too-long
# Reverse the packing operations in the previous steps. Form the
# summed gradients back into their original shapes.
with ops.colocate_with(summed_tower_grad_packs[0][0]):
# Form a list of the summed grad packs.
device_grad_packs = [g for g, _ in summed_tower_grad_packs]
# Concat them back into a big flat tensor.
device_grads_concat = array_ops.concat(device_grad_packs, 0)
# Split the tensors back into their original sizes.
grads_with_sizes = array_ops.split(device_grads_concat, tower_sizes)
# Reshape the tensors back into their original shapes.
grads_with_shapes = [
array_ops.reshape(grad, shape)
for shape, grad in zip(tower_shapes, grads_with_sizes)
]
# Form the list with the original list of variables.
summed_tower_grads = [
(g, v) for g, (_, v) in zip(grads_with_shapes, tower_grads_and_vars)
]
aggregated_device_grads.append(summed_tower_grads)
return aggregated_device_grads
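To make the pack arithmetic above concrete, here is a minimal sketch (plain Python, hypothetical sizes) of how `pack` computes the split sizes when the concatenated gradient does not divide evenly; the last pack absorbs the remainder.

# Sketch of ConcatAndSplitPacker's split-size computation with
# hypothetical numbers (no TensorFlow needed for the arithmetic).
num_packs = 3
total_grad_size = 10  # size of the concatenated flat gradient
split_size = total_grad_size // num_packs                          # 3
split_size_last = total_grad_size - split_size * (num_packs - 1)   # 4
split_sizes = [split_size] * (num_packs - 1) + [split_size_last]
assert split_sizes == [3, 3, 4]
assert sum(split_sizes) == total_grad_size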
class AggregateSmallTensorPacker(object):
"""Concatenate small gradient tensors together for reduction."""
def __init__(self,
agg_small_grads_max_bytes=1048576,
agg_small_grads_max_group=16):
"""Initialize the AggregateSmallTensorPacker object.
Args:
agg_small_grads_max_bytes: largest tensor eligible for aggregation,
in number of bytes.
agg_small_grads_max_group: largest permitted aggregation of small
tensors.
Raises:
ValueError: if `agg_small_grads_max_bytes` or `agg_small_grads_max_group`
is not greater than 0.
"""
if agg_small_grads_max_bytes <= 0 or agg_small_grads_max_group <= 0:
raise ValueError("agg_small_grads_max_bytes and agg_small_grads_max_group"
" should both be greater than zero.")
self.agg_small_grads_max_bytes = agg_small_grads_max_bytes
self.agg_small_grads_max_group = agg_small_grads_max_group
def pack(self, grouped_grads_and_vars):
"""Aggregate small tensors."""
if (self.agg_small_grads_max_bytes > 0 and
self.agg_small_grads_max_group > 0):
tower_grads, self.packing = cross_tower_utils.pack_small_tensors(
grouped_grads_and_vars,
max_bytes=self.agg_small_grads_max_bytes,
max_group=self.agg_small_grads_max_group)
return tower_grads
def unpack(self, summed_device_grad_packs):
"""Reverse the aggregation process."""
return cross_tower_utils.unpack_small_tensors(summed_device_grad_packs,
self.packing)
class AllReduceCrossTowerOps(CrossTowerOps):
"""Reduction using all reduce."""
def __init__(self, all_reduce_alg="nccl", gradient_repacking=1):
"""Initialize this subclass of CrossTowerOps with allreduce.
def __init__(self,
all_reduce_alg="nccl",
num_packs=1,
agg_small_grads_max_bytes=0,
agg_small_grads_max_group=10):
"""All-reduce implementation of CrossTowerOps.
Gradients would be repacked for more efficient cross-device transportation.
Before performing all-reduce, tensors will be repacked or aggregated for
more efficient cross-device transportation:
1) If `num_packs` is non-zero, pack values into
`num_packs` splits.
2) Otherwise, if `agg_small_grads_max_bytes` > 0 and
`agg_small_grads_max_group` > 0, aggregate values smaller than
`agg_small_grads_max_bytes` into groups with at most
`agg_small_grads_max_group` values.
3) Otherwise, no repacking or grouping will happen.
Args:
all_reduce_alg: the allreduce algorithm to use, currently only "nccl" or
all_reduce_alg: the all-reduce algorithm to use, currently only "nccl" or
"hierarchical_copy" are supported.
gradient_repacking: If zero, no gradient repacking would be done. If
non-zero value it specifies the number of split packs that will be
formed.
num_packs: see above.
agg_small_grads_max_bytes: see above.
agg_small_grads_max_group: see above.
"""
self.all_reduce_alg = all_reduce_alg
self.gradient_repacking = gradient_repacking
self.num_packs = num_packs
self.agg_small_grads_max_bytes = agg_small_grads_max_bytes
self.agg_small_grads_max_group = agg_small_grads_max_group
super(AllReduceCrossTowerOps, self).__init__()
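As a usage sketch only (argument values are illustrative, not recommendations), the three modes described in the docstring correspond to constructor calls such as:

from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib

# 1) Concat-and-split repacking into two packs.
pack_ops = cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", num_packs=2)

# 2) Aggregation of small tensors (num_packs must be 0 for this branch).
agg_ops = cross_tower_ops_lib.AllReduceCrossTowerOps(
    "hierarchical_copy", num_packs=0,
    agg_small_grads_max_bytes=1024, agg_small_grads_max_group=10)

# 3) No repacking or grouping at all.
plain_ops = cross_tower_ops_lib.AllReduceCrossTowerOps(
    "nccl", num_packs=0, agg_small_grads_max_bytes=0,
    agg_small_grads_max_group=0)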
def _reduce(self, method_string, per_device_value, destinations):
@@ -312,99 +471,115 @@ class AllReduceCrossTowerOps(CrossTowerOps):
def _batch_all_reduce(self, method_string, per_device_values):
"""All reduce algorithm in a batch."""
logging.info("batch_all_reduce invoked for batches size = %d with algorithm"
" = %s and gradient repacking = %d", len(per_device_values),
self.all_reduce_alg, self.gradient_repacking)
destinations = per_device_values[0].devices
grouped = _group_value_by_device(per_device_values)
if self.gradient_repacking == 0:
if self.all_reduce_alg == "nccl":
reduced = cross_tower_utils.aggregate_gradients_using_nccl(grouped)
else:
# TODO(yuefengz): check that gpu ids in `destinations` are in ascending
# order.
reduced = (
cross_tower_utils.aggregate_gradients_using_hierarchical_copy(
destinations, grouped))
if self.num_packs > 0:
logging.info(
"batch_all_reduce invoked for batches size = %d with "
"algorithm = %s and num_packs = %d", len(per_device_values),
self.all_reduce_alg, self.num_packs)
tensor_packer = ConcatAndSplitPacker(self.num_packs)
device_grad_packs = tensor_packer.pack(grouped)
elif (self.agg_small_grads_max_bytes > 0 and
self.agg_small_grads_max_group > 0):
logging.info(
"batch_all_reduce invoked for batches size = %d with "
"algorithm = %s, agg_small_grads_max_bytes = %d and "
"agg_small_grads_max_group = %d", len(per_device_values),
self.all_reduce_alg, self.agg_small_grads_max_bytes,
self.agg_small_grads_max_group)
tensor_packer = AggregateSmallTensorPacker(100, 10)
device_grad_packs = tensor_packer.pack(grouped)
else:
device_grad_packs = []
all_tower_shapes = []
all_tower_sizes = []
for tower_grads_and_vars in grouped:
with ops.colocate_with(tower_grads_and_vars[0][0]):
# Flatten all the grads.
flat_grads = [
array_ops.reshape(g, [-1]) for g, _ in tower_grads_and_vars
]
# Remember the original shape of all the grads.
tower_shapes = [array_ops.shape(g) for g, _ in tower_grads_and_vars]
# Remember the original sizes of all the grads.
tower_sizes = [array_ops.size(g) for g, _ in tower_grads_and_vars]
# Concat all the flat grads into a big flat tensor.
concat_grads = array_ops.concat(flat_grads, 0)
logging.info(
"batch_all_reduce invoked for batches size = %d with algorithm = %s",
len(per_device_values), self.all_reduce_alg)
tensor_packer = None
device_grad_packs = grouped
# Split the big tensor into num_splits packs. In cases where the
# total size is not divisible by num_splits, the last pack gets
# more elements.
# TODO(zhengxq): it is possible to optimize away the additional
# data movement by copying along the original variable boundary.
# TODO(zhengxq): it is also possible to optimize away all the concat
# as well.
num_splits = self.gradient_repacking
total_grad_size = array_ops.size(concat_grads)
split_size = total_grad_size // num_splits
split_size_last = total_grad_size - split_size * (num_splits - 1)
split_sizes = [split_size] * (num_splits - 1) + [split_size_last]
grad_packs = array_ops.split(concat_grads, split_sizes)
# The actual aggregation of the repacked gradients. Note that they are
# sharded among different aggregation trees. So it is important to strike
# the balance on num_splits.
if self.all_reduce_alg == "nccl":
reduced = cross_tower_utils.aggregate_gradients_using_nccl(
device_grad_packs)
else:
# TODO(yuefengz): check that gpu ids in `destinations` are in ascending
# order.
reduced = (
cross_tower_utils.aggregate_gradients_using_hierarchical_copy(
destinations, device_grad_packs))
# Ready to aggregate the repacked gradients, with fake variables.
# TODO(zhengxq): It is hacky to have to use fake variables.
# We should remove the need for variables in
# aggregate_gradients_using*.
device_grad_packs.append(zip(grad_packs, [None] * num_splits))
all_tower_shapes.append(tower_shapes)
all_tower_sizes.append(tower_sizes)
if tensor_packer:
reduced = tensor_packer.unpack(reduced)
# The actual aggregation of the repacked gradients. Note that they are
# sharded among different aggregation trees. So it is important to
# strike the balance on num_splits.
if self.all_reduce_alg == "nccl":
summed_device_grad_packs = (
cross_tower_utils.aggregate_gradients_using_nccl(device_grad_packs))
else:
summed_device_grad_packs = (
cross_tower_utils.aggregate_gradients_using_hierarchical_copy(
destinations, device_grad_packs))
aggregated_device_grads = []
for (summed_tower_grad_packs, tower_grads_and_vars, tower_shapes,
tower_sizes) in zip(summed_device_grad_packs, grouped,
all_tower_shapes, all_tower_sizes):
# pylint: enable=line-too-long
# Reverse the packing operations in the previous steps. Form the
# summed gradients back into their original shapes.
with ops.colocate_with(summed_tower_grad_packs[0][0]):
# Form a list of the summed grad packs.
device_grad_packs = [g for g, _ in summed_tower_grad_packs]
# Concat them back into a big flat tensor.
device_grads_concat = array_ops.concat(device_grad_packs, 0)
# Split the tensors back into their original sizes.
grads_with_sizes = array_ops.split(device_grads_concat, tower_sizes)
# Reshape the tensors back into their original shapes.
grads_with_shapes = [
array_ops.reshape(grad, shape)
for shape, grad in zip(tower_shapes, grads_with_sizes)
]
# Form the list with the original list of variables.
summed_tower_grads = [
(g, v)
for g, (_, v) in zip(grads_with_shapes, tower_grads_and_vars)
]
aggregated_device_grads.append(summed_tower_grads)
reduced = aggregated_device_grads
return _ungroup_and_make_mirrored(reduced, per_device_values[0].devices,
method_string)
_dgx1_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7],
[0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]]
def _has_dgx1_like_links(gpu_links):
if not gpu_links:
return False
# TODO(yuefengz): figure out the right topology for hierarchical copy if
# the number of gpus is less than 8.
if len(gpu_links) < 8:
return False
for i, (gpu_link, dgx1_link) in enumerate(zip(gpu_links, _dgx1_links)):
if (set(gpu_link) != set(dgx1_link) and
set(gpu_link) != set(dgx1_link + [i])):
return False
return True
def _choose_all_reduce_algorithm(device_links):
if _has_dgx1_like_links(device_links):
logging.info("Configured hierarchical_copy with num_packs=%d",
len(device_links))
return AllReduceCrossTowerOps(
"hierarchical_copy", num_packs=len(device_links))
else:
logging.info("Configured nccl all-reduce.")
return AllReduceCrossTowerOps("nccl", num_packs=1)
def choose_the_best(devices, session_config=None):
"""Find the best subclass of CrossTowerOps given a tensorflow session.
Args:
devices: a list of devices passed for distribute strategy.
session_config: a tensorflow session config or None. If None, it will make the
decision based on all local devices.
Returns:
a subclass of CrossTowerOps.
"""
requested_devices = set([device_util.canonicalize(d) for d in devices])
machine_devices = device_lib.list_local_devices(session_config=session_config)
using_devices = []
for d in machine_devices:
if device_util.canonicalize(d.name) in requested_devices:
using_devices.append(d)
else:
logging.info(
"Device is available but not used by distribute strategy: %s", d.name)
if len(using_devices) != len(requested_devices):
logging.warning("Not all devices in distribute strategy are visible by "
"TensorFlow sessions.")
return ReductionToOneDeviceCrossTowerOps()
if any([d.device_type.lower() != "gpu" for d in using_devices]):
logging.warning("Not all devices in DistributionStrategy are visible to "
"TensorFlow session.")
return ReductionToOneDeviceCrossTowerOps()
device_links = [[] for _ in range(len(using_devices))]
for i, device in enumerate(using_devices):
for link in device.locality.links.link:
device_links[i].append(link.device_id)
return _choose_all_reduce_algorithm(device_links)
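A hedged usage sketch of `choose_the_best` follows; the device strings and ConfigProto contents are assumptions, and on a machine without the requested GPUs the function falls back to `ReductionToOneDeviceCrossTowerOps` as shown above.

from tensorflow.core.protobuf import config_pb2
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib

devices = ["/device:GPU:0", "/device:GPU:1"]  # assumed device names
session_config = config_pb2.ConfigProto(allow_soft_placement=True)
cross_tower_ops = cross_tower_ops_lib.choose_the_best(
    devices, session_config=session_config)
# Returns an AllReduceCrossTowerOps ("hierarchical_copy" on DGX-1-like
# link topologies, "nccl" otherwise) when all requested devices are GPUs
# visible to the session.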

@@ -101,22 +101,22 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
mode=["graph", "eager"])
allreduce_combinations = combinations.combine(
cross_tower_ops=[
combinations.NamedObject("AllReduce",
cross_tower_ops_lib.AllReduceCrossTowerOps(
"nccl", 1)),
combinations.NamedObject("HierarchicalCopy",
cross_tower_ops_lib.AllReduceCrossTowerOps(
"hierarchical_copy", 8)),
combinations.NamedObject("AllReduceNoGradientRepacking",
cross_tower_ops_lib.AllReduceCrossTowerOps(
"nccl", 0)),
combinations.NamedObject("HierarchicalCopyNoGradientRepacking",
cross_tower_ops_lib.AllReduceCrossTowerOps(
"hierarchical_copy", 0))
],
distribution=[
combinations.mirrored_strategy_with_two_gpus
combinations.NamedObject(
"AllReduce",
cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 1, 0, 0)),
combinations.NamedObject(
"HierarchicalCopy",
cross_tower_ops_lib.AllReduceCrossTowerOps(
"hierarchical_copy", 8, 0, 0)),
combinations.NamedObject(
"AllReduceNoGradientRepacking",
cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 0, 0, 0)),
combinations.NamedObject(
"HierarchicalCopyAggregateSmallTensors",
cross_tower_ops_lib.AllReduceCrossTowerOps(
"hierarchical_copy", 0, 100, 10))
],
distribution=[combinations.mirrored_strategy_with_two_gpus],
mode=["graph", "eager"])
@combinations.generate(reduction_to_one_combinations + allreduce_combinations)
@@ -180,6 +180,42 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
cross_tower_ops.broadcast(constant_op.constant(1.), destinations),
_fake_mirrored(1., destinations))
def testChooseAlgorithm(self):
device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7],
[0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]]
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
self.assertTrue(
isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps))
self.assertEqual(result.all_reduce_alg, "hierarchical_copy")
self.assertEqual(result.num_packs, 8)
# if there are only 4 devices
device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]]
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
self.assertTrue(
isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps))
self.assertEqual(result.all_reduce_alg, "nccl")
self.assertEqual(result.num_packs, 1)
# if device links contain each device itself
device_links = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 5], [0, 1, 2, 3, 6],
[0, 1, 2, 3, 7], [0, 4, 5, 6, 7], [1, 4, 5, 6, 7],
[2, 4, 5, 6, 7], [3, 4, 5, 6, 7]]
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
self.assertTrue(
isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps))
self.assertEqual(result.all_reduce_alg, "hierarchical_copy")
self.assertEqual(result.num_packs, 8)
# if not dgx1-like links
device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7],
[1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]]
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
self.assertTrue(
isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps))
self.assertEqual(result.all_reduce_alg, "nccl")
self.assertEqual(result.num_packs, 1)
if __name__ == "__main__":
test.main()

@@ -18,7 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections as pycoll
from tensorflow.contrib import nccl
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -151,3 +154,186 @@ def aggregate_single_gradient_using_copy(grad_and_vars, use_mean,
return (grad, v), has_nan_or_inf
else:
return (grad, v), None
def extract_ranges(index_list, range_size_limit=32):
"""Extract consecutive ranges and singles from index_list.
Args:
index_list: List of monotone increasing non-negative integers.
range_size_limit: Largest size range to return. If a larger
consecutive range exists, it will be returned as multiple
ranges.
Returns:
(ranges, singles) where ranges is a list of [first, last] pairs of
consecutive elements in index_list, and singles is all of the
other elements, in original order.
"""
if not index_list:
return [], []
first = index_list[0]
last = first
ranges = []
singles = []
for i in index_list[1:]:
if i == last + 1 and (last - first) <= range_size_limit:
last = i
else:
if last > first:
ranges.append([first, last])
else:
singles.append(first)
first = i
last = i
if last > first:
ranges.append([first, last])
else:
singles.append(first)
return ranges, singles
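For example (indices chosen arbitrarily), consecutive runs become ranges and isolated indices stay singles:

# extract_ranges keeps [0, 1, 2] and [7, 8] as ranges and 5 as a single.
ranges, singles = extract_ranges([0, 1, 2, 5, 7, 8])
assert ranges == [[0, 2], [7, 8]]
assert singles == [5]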
GradPackTuple = pycoll.namedtuple('GradPackTuple', 'indices vars shapes')
def pack_range(key, packing, grad_vars, rng):
"""Form the concatenation of a specified range of gradient tensors.
Args:
key: Value under which to store meta-data in packing that will be used
later to restore the grad_var list structure.
packing: Dict holding data describing packed ranges of small tensors.
grad_vars: List of (grad, var) pairs for one tower.
rng: A pair of integers giving the first, last indices of a consecutive
range of tensors to be packed.
Returns:
A tensor that is the concatenation of all the specified small tensors.
"""
to_pack = grad_vars[rng[0]:rng[1] + 1]
members = []
variables = []
restore_shapes = []
with ops.name_scope('pack'):
for g, v in to_pack:
variables.append(v)
restore_shapes.append(g.shape)
with ops.device(g.device):
members.append(array_ops.reshape(g, [-1]))
packing[key] = GradPackTuple(
indices=range(rng[0], rng[1] + 1),
vars=variables,
shapes=restore_shapes)
with ops.device(members[0].device):
return array_ops.concat(members, 0)
def unpack_grad_tuple(gv, gpt):
"""Unpack a previously packed collection of gradient tensors.
Args:
gv: A (grad, var) pair to be unpacked.
gpt: A GradPackTuple describing the packing operation that produced gv.
Returns:
A list of (grad, var) pairs corresponding to the values that were
originally packed into gv, maybe following subsequent operations like
reduction.
"""
elt_widths = [x.num_elements() for x in gpt.shapes]
with ops.device(gv[0][0].device):
with ops.name_scope('unpack'):
splits = array_ops.split(gv[0], elt_widths)
unpacked_gv = []
for idx, s in enumerate(splits):
unpacked_gv.append((array_ops.reshape(s, gpt.shapes[idx]),
gpt.vars[idx]))
return unpacked_gv
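The following is a hypothetical round trip through `pack_range` and `unpack_grad_tuple` (TF 1.x graph mode; the gradients are constants and the variable slots are stand-in strings), included only to illustrate how `packing` carries the information needed to reverse the pack.

import tensorflow as tf
from tensorflow.contrib.distribute.python import cross_tower_utils

with tf.Graph().as_default():
  grad_vars = [(tf.constant([1., 2., 3.]), "var0"),
               (tf.constant([[4., 5.], [6., 7.]]), "var1")]
  packing = {}
  packed = cross_tower_utils.pack_range("0:0", packing, grad_vars, [0, 1])
  # `packed` is a flat tensor with 3 + 4 = 7 elements; packing["0:0"]
  # records the original shapes and variables.
  unpacked = cross_tower_utils.unpack_grad_tuple((packed, None),
                                                 packing["0:0"])
  # `unpacked` is [(grad with shape [3], "var0"),
  #                (grad with shape [2, 2], "var1")].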
def pack_small_tensors(tower_grads, max_bytes=0, max_group=0):
"""Concatenate small gradient tensors together for reduction.
Args:
tower_grads: List of lists of (gradient, variable) tuples.
max_bytes: Int giving max number of bytes in a tensor that
may be considered small.
max_group: Int giving max number of small tensors that may be
concatenated into one new tensor.
Returns:
new_tower_grads, packing where new_tower_grads is identical to
tower_grads except that all feasible small_tensors have been removed
from their places and concatenated into larger tensors that are
now in the front of the list for each tower, and packing contains
the data necessary to restore the tower_grads structure.
Look through the first tower for gradients of the same type (float),
and small size, that are all sequential. For each such group,
replace by a new tensor that is a flattened concatenation. Note
that the corresponding variable will be absent, which doesn't matter
because it isn't used during all-reduce.
Requires:
Every gv_list in towers must have isomorphic structure including identical
tensor sizes and types.
"""
small_indices = []
large_indices = []
for idx, (g, _) in enumerate(tower_grads[0]):
if g.dtype == dtypes.float32 and (4 * g.shape.num_elements()) <= max_bytes:
small_indices.append(idx)
else:
large_indices.append(idx)
small_ranges, small_singles = extract_ranges(
small_indices, range_size_limit=max_group)
large_indices = sorted(large_indices + small_singles)
num_gv = len(tower_grads[0])
packing = {}
if small_ranges:
new_tower_grads = []
for dev_idx, gv_list in enumerate(tower_grads):
assert len(gv_list) == num_gv
new_gv_list = []
for r in small_ranges:
key = '%d:%d' % (dev_idx, len(new_gv_list))
new_gv_list.append((pack_range(key, packing, gv_list, r),
'packing_var_placeholder'))
for i in large_indices:
new_gv_list.append(gv_list[i])
new_tower_grads.append(new_gv_list)
return new_tower_grads, packing
else:
return tower_grads, None
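As a quick sanity check of the size rule used above (element counts are made up), a float32 gradient qualifies as small when 4 bytes per element times its element count is at most max_bytes:

max_bytes = 1048576                    # 1 MB, AggregateSmallTensorPacker's default
num_elements = [262144, 262145, 1000]  # hypothetical gradient sizes
eligible = [4 * n <= max_bytes for n in num_elements]
assert eligible == [True, False, True]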
def unpack_small_tensors(tower_grads, packing):
"""Undo the structure alterations to tower_grads done by pack_small_tensors.
Args:
tower_grads: List of List of (grad, var) tuples.
packing: A dict generated by pack_small_tensors describing the changes
it made to tower_grads.
Returns:
new_tower_grads: identical to tower_grads except that concatenations
of small tensors have been split apart and returned to their original
positions, paired with their original variables.
"""
if not packing:
return tower_grads
new_tower_grads = []
num_devices = len(tower_grads)
num_packed = len(packing.keys()) // num_devices
for dev_idx, gv_list in enumerate(tower_grads):
gv_list = list(gv_list)
new_gv_list = gv_list[num_packed:]
for i in xrange(0, num_packed):
k = '%d:%d' % (dev_idx, i)
gpt = packing[k]
gv = unpack_grad_tuple(gv_list[i], gpt)
for gi, idx in enumerate(gpt.indices):
assert idx == gpt.indices[gi]
new_gv_list.insert(idx, gv[gi])
new_tower_grads.append(new_gv_list)
return new_tower_grads

@@ -78,9 +78,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
[device_util.canonicalize(d) for d in devices])
self._device_index = values.PerDevice(
dict((d, i) for i, d in enumerate(devices)))
self.cross_tower_ops = (
cross_tower_ops or
cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps())
self._cross_tower_ops = cross_tower_ops
self._prefetch_on_device = prefetch_on_device
def _create_variable(self, next_creator, *args, **kwargs):
@@ -149,7 +147,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
def _broadcast(self, tensor, destinations):
# TODO(josh11b): In eager mode, use one thread per device, or async mode.
return self.cross_tower_ops.broadcast(tensor, destinations or self._devices)
return self._get_cross_tower_ops().broadcast(tensor, destinations or
self._devices)
def _call_for_each_tower(self, fn, *args, **kwargs):
"""Run `fn` in separate threads, once per tower/worker device.
@@ -272,16 +271,28 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
# in addition to PerDevice data.
return values.PerDevice({k: values.MapOutput(v) for k, v in index.items()})
def configure(self, session_config=None):
if self._cross_tower_ops is None:
self._cross_tower_ops = cross_tower_ops_lib.choose_the_best(
self._devices, session_config=session_config)
def _get_cross_tower_ops(self):
if self._cross_tower_ops is None:
self._cross_tower_ops = (
cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps())
return self._cross_tower_ops
def _reduce(self, method_string, value, destinations):
if len(self._devices) == 1 and not isinstance(value, values.PerDevice):
value = values.PerDevice({self._devices[0]: value})
assert isinstance(value, values.PerDevice)
return self.cross_tower_ops.reduce(
return self._get_cross_tower_ops().reduce(
method_string, value, destinations=destinations)
def _batch_reduce(self, method_string, value_destination_pairs):
return self.cross_tower_ops.batch_reduce(method_string,
value_destination_pairs)
return self._get_cross_tower_ops().batch_reduce(method_string,
value_destination_pairs)
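A hedged sketch of the new deferred selection (device names and the session config are assumptions): no cross-tower ops object is created at construction time; `configure` picks one from the session config, and `_get_cross_tower_ops` falls back to reduction-to-one-device if `configure` was never called.

from tensorflow.core.protobuf import config_pb2
from tensorflow.contrib.distribute.python import mirrored_strategy

strategy = mirrored_strategy.MirroredStrategy(
    devices=["/device:GPU:0", "/device:GPU:1"])  # assumed devices
# No cross_tower_ops argument, so nothing is chosen yet.
strategy.configure(config_pb2.ConfigProto())
# configure() delegates to cross_tower_ops_lib.choose_the_best(); if it is
# never called, the first reduce uses ReductionToOneDeviceCrossTowerOps.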
def _update(self, var, fn, *args, **kwargs):
# TODO(josh11b): Also support TowerLocalVariables here? If so, args and

@@ -3253,6 +3253,7 @@ cuda_py_tests(
":client_testlib",
":framework_test_lib",
":platform_test",
"//tensorflow/core:protos_all_py",
],
)

@@ -15,19 +15,39 @@ limitations under the License.
%include "tensorflow/python/platform/base.i"
%typemap(in) const tensorflow::ConfigProto& (tensorflow::ConfigProto temp) {
char* c_string;
Py_ssize_t py_size;
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
// Python has raised an error (likely TypeError or UnicodeEncodeError).
SWIG_fail;
}
if (!temp.ParseFromString(string(c_string, py_size))) {
PyErr_SetString(
PyExc_TypeError,
"The ConfigProto could not be parsed as a valid protocol buffer");
SWIG_fail;
}
$1 = &temp;
}
%{
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
namespace swig {
static std::vector<string> ListDevices(TF_Status* out_status) {
static std::vector<string> ListDevicesWithSessionConfig(
const tensorflow::ConfigProto& config, TF_Status* out_status) {
std::vector<string> output;
SessionOptions options;
options.config = config;
std::vector<Device*> devices;
Status status = DeviceFactory::AddDevices(
options, "" /* name_prefix */, &devices);
@@ -35,7 +55,8 @@ static std::vector<string> ListDevices(TF_Status* out_status) {
Set_TF_Status_from_Status(out_status, status);
}
std::vector<std::unique_ptr<Device>> device_holder(devices.begin(), devices.end());
std::vector<std::unique_ptr<Device>> device_holder(devices.begin(),
devices.end());
for (const Device* device : devices) {
const DeviceAttributes& attr = device->attributes();
@@ -53,6 +74,11 @@ static std::vector<string> ListDevices(TF_Status* out_status) {
return output;
}
std::vector<string> ListDevices(TF_Status* out_status) {
tensorflow::ConfigProto session_config;
return ListDevicesWithSessionConfig(session_config, out_status);
}
} // namespace swig
} // namespace tensorflow
@@ -62,21 +88,28 @@ static std::vector<string> ListDevices(TF_Status* out_status) {
%unignore tensorflow;
%unignore tensorflow::swig;
%unignore tensorflow::swig::ListDevicesWithSessionConfig;
%unignore tensorflow::swig::ListDevices;
// Wrap this function
namespace tensorflow {
namespace swig {
std::vector<string> ListDevices(TF_Status* out_status);
static std::vector<string> ListDevicesWithSessionConfig(
const tensorflow::ConfigProto& config, TF_Status* out_status);
} // namespace swig
} // namespace tensorflow
%insert("python") %{
def list_devices():
def list_devices(session_config=None):
from tensorflow.python.framework import errors
with errors.raise_exception_on_not_ok_status() as status:
return ListDevices(status)
if session_config:
return ListDevicesWithSessionConfig(session_config.SerializeToString(),
status)
else:
return ListDevices(status)
%}
%unignoreall

@@ -22,9 +22,12 @@ from tensorflow.core.framework import device_attributes_pb2
from tensorflow.python import pywrap_tensorflow
def list_local_devices():
def list_local_devices(session_config=None):
"""List the available devices available in the local process.
Args:
session_config: a session config proto or None to use the default config.
Returns:
A list of `DeviceAttribute` protocol buffers.
"""
@@ -33,4 +36,7 @@ def list_local_devices():
m.ParseFromString(pb_str)
return m
return [_convert(s) for s in pywrap_tensorflow.list_devices()]
return [
_convert(s)
for s in pywrap_tensorflow.list_devices(session_config=session_config)
]
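Usage sketch (the ConfigProto contents are an illustration only): the session config can now constrain which devices are listed, for example hiding GPUs by setting their device count to zero.

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import device_lib

all_devices = device_lib.list_local_devices()  # default behaviour, unchanged
cpu_only = device_lib.list_local_devices(
    config_pb2.ConfigProto(device_count={"GPU": 0}))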

@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import device_lib
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
@@ -31,6 +32,10 @@ class DeviceLibTest(test_util.TensorFlowTestCase):
self.assertGreater(len(devices), 0)
self.assertEqual(devices[0].device_type, "CPU")
devices = device_lib.list_local_devices(config_pb2.ConfigProto())
self.assertGreater(len(devices), 0)
self.assertEqual(devices[0].device_type, "CPU")
# GPU test
if test.is_gpu_available():
self.assertGreater(len(devices), 1)

@@ -859,6 +859,7 @@ class Estimator(object):
saving_listeners)
def _train_model_distributed(self, input_fn, hooks, saving_listeners):
self._distribution.configure(self._session_config)
worker_hooks = []
with ops.Graph().as_default() as g:
with self._distribution.scope():

@@ -1105,6 +1105,10 @@ class _DefaultDistributionStrategy(DistributionStrategy):
# in contrib.
return dataset.make_one_shot_iterator()
def configure(self, session_config=None):
"""Find the best configuration given a tensorflow session config."""
del session_config
def _broadcast(self, tensor, destinations):
if destinations is None:
return tensor