# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper library for handling infeed between hosts and TPUs.
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import itertools
|
|
|
|
import numpy as np
|
|
from six.moves import xrange # pylint: disable=redefined-builtin
|
|
|
|
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import tensor_shape
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.tpu import tpu
|
|
from tensorflow.python.tpu import tpu_sharding
|
|
from tensorflow.python.tpu.ops import tpu_ops
|
|
|
|
from tensorflow.python.util import nest
|
|
|
|
|
|
def partition_or_replicate_on_host(tensor, dims):
  """Partitions or replicates the input tensor.

  The ops inside this function are placed on the host side.

  Args:
    tensor: The input tensor which will be partitioned or replicated.
    dims: A list of integers describing how to partition the input tensor.

  Returns:
    An iterator of `Tensor`s or a list of partitioned tensors.
  """
  if dims is None:
    return itertools.repeat(tensor)
  dims = np.array(dims)
  output = [tensor]
  shape_list = np.array(tensor.shape.as_list())
  quotients, remainders = np.divmod(shape_list, dims)
  for axis, (quotient, remainder, dim, original_size) in enumerate(
      zip(quotients, remainders, dims, shape_list)):
    if dim <= 1:
      continue
    if remainder > 0:
      # For each dimension, when it cannot be evenly partitioned, XLA assumes
      # tensors are partitioned in a greedy manner by using
      # ceil_ratio(size/dim) first. E.g. a 2D tensor with shape (5, 14) and
      # dims (2, 4). Since 5 % 2 = 1 and 14 % 4 = 2, [5, 14] =>
      # [[(3, 4), (3, 4), (3, 4), (3, 2)],
      #  [(2, 4), (2, 4), (2, 4), (2, 2)]]
      ceil_ratio = quotient + 1
      num_full_slots, left_over = np.divmod(original_size, ceil_ratio)
      num_or_size_splits = [ceil_ratio] * num_full_slots + [left_over]
      if len(num_or_size_splits) < dim:
        num_or_size_splits += [0] * (dim - len(num_or_size_splits))
      new_output = []
      for x in output:
        new_output.append(
            array_ops.split(
                x, num_or_size_splits=num_or_size_splits, axis=axis))
      output = new_output
    else:
      output = [array_ops.split(x, int(dim), axis=axis) for x in output]
    output = nest.flatten(output)
  return output
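
# Illustrative sketch: how the greedy split above sizes one axis when it
# cannot be divided evenly. The axis size and partition count are arbitrary
# example values, not anything required by the module.
def _example_uneven_split_sizes(axis_size=14, num_partitions=4):
  """Returns the per-slice sizes for one unevenly partitioned axis."""
  quotient, remainder = divmod(axis_size, num_partitions)
  if remainder == 0:
    return [quotient] * num_partitions
  ceil_ratio = quotient + 1
  num_full_slots, left_over = divmod(axis_size, ceil_ratio)
  sizes = [ceil_ratio] * num_full_slots + [left_over]
  if len(sizes) < num_partitions:
    # Pad with empty slices so every partition gets an entry.
    sizes += [0] * (num_partitions - len(sizes))
  return sizes  # axis_size=14, num_partitions=4 -> [4, 4, 4, 2]
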
def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
  """Tags appropriate XLA sharding attribute to the dequeued tensor.

  The sharding attribute of the dequeued tensor will be a tuple.

  Args:
    tensor: The dequeued tensor on TPU.
    dims: A list of integers describing how the tensor is partitioned.

  Returns:
    The same tensor with the xla_sharding attribute.
  """
  if dims is None:
    return xla_sharding.replicate(tensor, assign_tuple_sharding=True)
  elif np.prod(dims) == 1:
    return xla_sharding.assign_device(tensor, 0, assign_tuple_sharding=True)
  else:
    tile_assignment = np.arange(np.prod(dims)).reshape(dims)
    return xla_sharding.tile(
        tensor=tensor,
        tile_assignment=tile_assignment,
        assign_tuple_sharding=True)


def tag_sharding_attribute_for_dequeued_tensors(dequeues, dims):
  """Tags appropriate XLA sharding attribute to the dequeued tensors.

  Args:
    dequeues: A list of dequeued tensors on TPU.
    dims: A list of integers describing how each tensor is partitioned.

  Returns:
    The same dequeues with appropriate xla_sharding attribute.
  """
  nest.assert_shallow_structure(dequeues, dims)
  return nest.map_structure_up_to(
      dequeues, _tag_sharding_attribute_for_dequeued_tensor, dequeues, dims)
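
# Illustrative sketch: the tile assignment that
# _tag_sharding_attribute_for_dequeued_tensor builds for a partitioned tensor.
# With dims=[2, 4] (an assumed example), logical cores 0..7 are laid out
# row-major over the tensor's partitions.
def _example_tile_assignment(dims=(2, 4)):
  return np.arange(np.prod(dims)).reshape(dims)
  # -> [[0, 1, 2, 3],
  #     [4, 5, 6, 7]]
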
class InfeedQueue(object):
  """A helper object to build a device infeed queue.

  The InfeedQueue builds the host-side and device-side Ops to enqueue and
  dequeue elements, respectively, and ensures that their types and
  shapes match.
  """

  def __init__(self,
               number_of_tuple_elements=None,
               tuple_types=None,
               tuple_shapes=None,
               shard_dimensions=None,
               name=None):
    """Creates a new InfeedQueue with the given configuration.

    The configuration need not be fully specified at creation since it
    can be modified subsequently by methods that set the values
    explicitly or infer them from the shapes of inputs.

    Args:
      number_of_tuple_elements: the number of Tensors fed atomically through the
        queue, must be present unless it can be inferred from other arguments.
      tuple_types: if not None, a list of types of the elements of the queue.
      tuple_shapes: if not None, a list of shapes of the elements of the queue.
      shard_dimensions: if not None, a list of dimensions on which the
        elements of the queue should be sharded during automatic
        parallelization.
      name: the name of the queue.

    Raises:
      ValueError: if number_of_tuple_elements <= 0; or
        number_of_tuple_elements, tuple_types, tuple_shapes, and
        shard_dimensions are all None; or the length of tuple_types,
        tuple_shapes, or shard_dimensions is not equal to
        number_of_tuple_elements; or any element of shard_dimensions
        can't be converted to a Dimension.
      TypeError: if any element of tuple_types or tuple_shapes can't
        be converted to a dtype or TensorShape, respectively.
    """
    self._frozen = False
    self._generated_enqueue_ops = False
    self._generated_dequeue_op = False
    self._name = "InfeedQueue" if name is None else name
    if number_of_tuple_elements is None:
      if tuple_types is not None:
        number_of_tuple_elements = len(tuple_types)
      elif tuple_shapes is not None:
        number_of_tuple_elements = len(tuple_shapes)
      elif shard_dimensions is not None:
        number_of_tuple_elements = len(shard_dimensions)
      else:
        raise ValueError(
            "number of tuple elements cannot be inferred from InfeedQueue "
            "constructor")
    if number_of_tuple_elements <= 0:
      raise ValueError("number_of_tuple_elements %d must be > 0" %
                       number_of_tuple_elements)
    # Make an empty sharding policy for each tuple element.
    self._sharding_policies = [
        tpu_sharding.ShardingPolicy()
        for _ in xrange(number_of_tuple_elements)
    ]
    if tuple_types is not None:
      self.set_tuple_types(tuple_types)
    else:
      self._tuple_types = None
    if tuple_shapes is not None:
      self.set_tuple_shapes(tuple_shapes)
    else:
      self._tuple_shapes = None
    if shard_dimensions is not None:
      self.set_shard_dimensions(shard_dimensions)
    self._validate()

  def _validate(self):
    """Checks that the configuration is self-consistent.

    Raises:
      ValueError: if the shapes and sharding policies don't match.
    """
    if self.tuple_shapes is not None:
      for (policy, shape) in zip(self._sharding_policies, self._tuple_shapes):
        # Raise an error if the policy is incompatible with the shape.
        _ = policy.get_sharded_shape(shape)

  @property
  def number_of_tuple_elements(self):
    """Returns the number of InfeedQueue tuple elements."""
    return len(self._sharding_policies)

  @property
  def tuple_types(self):
    """Returns the types of the InfeedQueue tuple elements."""
    return self._tuple_types

  def set_tuple_types(self, tuple_types):
    """Sets the type of each element of the queue.

    tuple_types must be a list of length
    self.number_of_tuple_elements, and each element must be
    convertible to a dtype.

    Args:
      tuple_types: the types of each queue element.

    Raises:
      ValueError: if tuple_types is not of length
        self.number_of_tuple_elements.
      TypeError: if an element of tuple_types cannot be converted to a
        dtype.
    """
    if len(tuple_types) != self.number_of_tuple_elements:
      raise ValueError("tuple_types is %s, but must be a list of length %d" %
                       (str(tuple_types), self.number_of_tuple_elements))
    if self._frozen:
      for (frozen, updated) in zip(self._tuple_types, tuple_types):
        if frozen != updated:
          raise ValueError(
              "Trying to update InfeedQueue with frozen configuration with an "
              "incompatible type. Frozen types are %s, updated types are %s" % (
                  str(self._tuple_types), str(tuple_types)))
    else:
      try:
        self._tuple_types = [dtypes.as_dtype(t) for t in tuple_types]
      except (TypeError) as e:
        raise TypeError(
            "tuple_types is %s, but must be a list of elements each "
            "convertible to dtype: got error %s" % (str(tuple_types), str(e)))

  @property
  def tuple_shapes(self):
    """Returns the shapes of the InfeedQueue tuple elements."""
    return self._tuple_shapes

  def set_tuple_shapes(self, tuple_shapes):
    """Sets the shape of each element of the queue.

    tuple_shapes must be a list of length
    self.number_of_tuple_elements, and each element must be
    convertible to a TensorShape.

    Args:
      tuple_shapes: the shapes of each queue element.

    Raises:
      ValueError: if tuple_shapes is not of length
        self.number_of_tuple_elements.
      TypeError: if an element of tuple_shapes cannot be converted to
        a TensorShape.
    """
    if len(tuple_shapes) != self.number_of_tuple_elements:
      raise ValueError("tuple_shapes is %s, but must be a list of length %d" %
                       (str(tuple_shapes), self.number_of_tuple_elements))
    try:
      tuple_shapes = [tensor_shape.as_shape(shape) for shape in tuple_shapes]
    except (ValueError, TypeError) as e:
      raise TypeError(
          "tuple_shapes is %s, but must be a list of elements each "
          "convertible to TensorShape: got error %s" % (str(tuple_shapes),
                                                        str(e)))
    if self._frozen:
      for (frozen, updated) in zip(self._tuple_shapes, tuple_shapes):
        if frozen != updated:
          raise ValueError(
              "Trying to update InfeedQueue with frozen configuration with an "
              "incompatible shape. Frozen shapes are %s, updated shapes are %s"
              % (str(self._tuple_shapes), str(tuple_shapes)))
    else:
      self._tuple_shapes = tuple_shapes
    self._validate()

  @property
  def sharding_policies(self):
    """Returns the sharding policies of the InfeedQueue tuple elements."""
    return self._sharding_policies

  @property
  def shard_dimensions(self):
    """Gets the shard dimension of each tuple element.

    Returns:
      A list of length number_of_tuple_elements, where each list entry
      is the shard dimension of that tuple element or None if the
      shard dimension has not been set.
    """
    # The number of shards is always the same for all the policies.
    return [policy.shard_dimension for policy in self._sharding_policies]

  def set_shard_dimensions(self, shard_dimensions):
    """Sets the shard_dimension of each element of the queue.

    shard_dimensions must be a list of length
    self.number_of_tuple_elements, and each element must be
    convertible to a Dimension compatible with self.tuple_shapes.

    Args:
      shard_dimensions: the dimensions of each queue element.

    Raises:
      ValueError: if shard_dimensions is not of length
        self.number_of_tuple_elements; or an element of
        shard_dimensions cannot be converted to a Dimension; or an
        element of shard_dimensions is a Dimension that is out of
        range for the corresponding tuple element shape.
    """
    if len(shard_dimensions) != self.number_of_tuple_elements:
      raise ValueError("shard_dimensions is %s, but must be a list of length %d"
                       % (str(shard_dimensions),
                          self.number_of_tuple_elements))
    for (policy, dimension) in zip(self._sharding_policies, shard_dimensions):
      policy.set_shard_dimension(dimension)
    self._validate()

  @property
  def number_of_shards(self):
    """Gets the number of shards to use for the InfeedQueue.

    Returns:
      Number of shards or None if the number of shards has not been set.
    """
    # The number of shards is always the same for all the policies.
    return self._sharding_policies[0].number_of_shards

  def set_number_of_shards(self, number_of_shards):
    """Sets the number of shards to use for the InfeedQueue.

    Args:
      number_of_shards: number of ways to shard the InfeedQueue.

    Raises:
      ValueError: if number_of_shards is not > 0; or the policies have
        been frozen and number_of_shards was already set to something
        else.
    """
    for policy in self._sharding_policies:
      policy.set_number_of_shards(number_of_shards)
    self._validate()

  def set_configuration_from_input_tensors(self, input_tensors):
    """Sets the shapes and types of the queue tuple elements.

    input_tensors is a list of Tensors whose types and shapes are used
    to set the queue configuration.

    Args:
      input_tensors: list of Tensors of the same types and shapes as
        the desired queue Tuple.

    Raises:
      ValueError: if input_tensors is not a list of length
        self.number_of_tuple_elements
    """
    if len(input_tensors) != self.number_of_tuple_elements:
      raise ValueError("input_tensors is %s, but should be a list of %d Tensors"
                       % (str(input_tensors), self.number_of_tuple_elements))
    self.set_tuple_shapes([t.shape for t in input_tensors])
    self.set_tuple_types([t.dtype for t in input_tensors])

  def set_configuration_from_sharded_input_tensors(self, input_tensors):
    """Sets the shapes and types of the queue tuple elements.

    input_tensors is a list of lists of Tensors whose types and shapes are used
    to set the queue configuration. The length of the outer list is the number
    of shards required, and each inner list is the tuple of Tensors to use to
    determine the types and shapes of the corresponding shard. This method
    depends on the shard dimension, and calling it freezes the shard policy.

    Args:
      input_tensors: list of lists of Tensors. The outer list length corresponds
        to the desired number of shards, and each inner list is the size
        and shape of the desired configuration of the corresponding shard.

    Raises:
      ValueError: if any inner list is not a list of length
        self.number_of_tuple_elements; or the inner lists do not combine to
        form a consistent unsharded shape.
      TypeError: if the types of the Tensors in the inner lists do not match.
    """
    if not self._frozen:
      # Unset the tuple shapes in case the configuration becomes
      # transiently inconsistent.
      self._tuple_shapes = None
    number_of_shards = len(input_tensors)
    self.set_number_of_shards(number_of_shards)
    for t in input_tensors:
      if len(t) != self.number_of_tuple_elements:
        raise ValueError(
            "input_tensors is %s but must be a list of lists, where each inner"
            " list has length number_of_tuple_elements=%d" % (
                str(input_tensors), self.number_of_tuple_elements))
    # Transpose the inputs to make a list of shard shapes for each tuple
    # element.
    sharded_shapes = [[t[i].shape for t in input_tensors]
                      for i in xrange(self.number_of_tuple_elements)]
    # For each tuple, get the unsharded shape using that tuple's policy.
    unsharded_shapes = [
        policy.get_unsharded_shape(s)
        for (policy, s) in zip(self._sharding_policies, sharded_shapes)
    ]
    self.set_tuple_shapes(unsharded_shapes)
    for i in xrange(1, self.number_of_shards):
      for (t1, t2) in zip(input_tensors[0], input_tensors[i]):
        if t1.dtype != t2.dtype:
          raise TypeError(
              "types of the tuple elements of input_tensors %s are not "
              "consistent" % str(input_tensors))
    self.set_tuple_types([t.dtype for t in input_tensors[0]])

  def freeze(self):
    """Freezes the InfeedQueue so it can no longer be modified.

    The configuration is implicitly frozen before any host-side or
    device-side Ops are generated. The configuration cannot be frozen
    until the types and shapes of the tuple elements have been set.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set.
    """
    self._frozen = True
    if self._tuple_types is None:
      raise ValueError(
          "Can't freeze an InfeedQueue without setting all tuple types.")
    if self._tuple_shapes is None:
      raise ValueError(
          "Can't freeze an InfeedQueue without setting all tuple shapes.")
    for shape in self._tuple_shapes:
      if shape.dims is None:
        raise ValueError(
            "Can't freeze an InfeedQueue without setting all tuple shapes.")
    for policy in self._sharding_policies:
      policy.freeze()
    self._validate()

  def generate_dequeue_op(self, tpu_device=0):
    """Generates the device-side Op to dequeue a tuple from the queue.

    Implicitly freezes the queue configuration if it is not already
    frozen, which will raise errors if the shapes and types have not
    been fully specified.

    Args:
      tpu_device: The TPU device ordinal where the infeed instruction should be
        placed. If None, no explicit placement will be performed, and it is up
        to the user to call this API from within a proper TPU device scope.
        The XLA code will fail if the TPU dequeue instruction is not bound to
        any device.

    Returns:
      A list of Outputs corresponding to a shard of infeed dequeued
      into XLA, suitable for use within a replicated block.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set; or if a dequeue op has already been generated.
    """
    self.freeze()
    if self._generated_dequeue_op:
      raise ValueError("Can't generate two dequeue Ops from the same queue")
    self._generated_dequeue_op = True
    full_name = "%s/dequeue" % self._name
    sharded_shapes = [
        policy.get_sharded_shape(shape)
        for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
    ]
    if tpu_device is not None:
      with ops.device(tpu.core(tpu_device)):
        return tpu_ops.infeed_dequeue_tuple(
            dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
    else:
      return tpu_ops.infeed_dequeue_tuple(
          dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)

  def _generate_enqueue_op(self,
                           inputs,
                           name_prefix,
                           index,
                           device=None,
                           tpu_ordinal=-1):
    """Generate a host-side Op to enqueue a tuple to the queue.

    If device is None the inputs are all required to have the same
    device specification, and the enqueue Op is colocated with
    inputs[0]. Otherwise the enqueue Op is placed on 'device'.

    Args:
      inputs: a list of Tensors with the types and shapes of the tuple elements.
      name_prefix: the base name for the Op.
      index: the shard index, used to uniquify the Op name.
      device: device to place the Op on, or None if it should be
        colocated with the inputs.
      tpu_ordinal: ordinal of the TPU device on the host to use for
        infeed if device is a CPU device. Should be set to -1 if device
        is a TPU device.

    Returns:
      An Op corresponding to a shard of infeed enqueued at the host,
      suitable for use within a replicated block.

    Raises:
      ValueError: if device is None and inputs do not all have the
        same device specification.
    """
    full_name = "%s/%d" % (name_prefix, index)
    shapes = [t.shape for t in inputs]
    if device is None:
      devices = [t.device for t in inputs]
      for i in xrange(1, self.number_of_tuple_elements):
        if devices[0] != devices[i]:
          raise ValueError(
              "input devices for shard %d are %s, but should all be the same" %
              (index, str(devices)))
      with ops.colocate_with(inputs[0]):
        return tpu_ops.infeed_enqueue_tuple(
            inputs=inputs,
            shapes=shapes,
            name=full_name,
            device_ordinal=tpu_ordinal)
    else:
      with ops.device(device):
        return tpu_ops.infeed_enqueue_tuple(
            inputs=inputs,
            shapes=shapes,
            name=full_name,
            device_ordinal=tpu_ordinal)

  def generate_enqueue_ops(self,
                           sharded_inputs,
                           tpu_ordinal_function=None,
                           placement_function=None):
    """Generates the host-side Ops to enqueue the shards of a tuple.

    sharded_inputs is a list, one for each shard, of lists of
    Tensors. sharded_inputs[i] is the tuple of Tensors to use to feed
    shard i of the queue. Returns the host-side Ops that must be run to
    enqueue the sharded tuple. The Op for shard i is colocated with the inputs
    for shard i.

    Implicitly freezes the queue configuration if it is not already
    frozen. If the configuration has already been frozen, and is not
    compatible with the types and shapes of sharded_inputs, an error
    will be raised.

    Args:
      sharded_inputs: a list of lists of Tensors. The length of the outer list
        determines the number of shards. Each inner list indicates the types
        and shapes of the tuples in the corresponding shard.
      tpu_ordinal_function: if not None, a function that takes the
        shard index as input and returns the ordinal of the TPU device
        the shard's infeed should be placed on. tpu_ordinal_function must be
        set if the inputs are placed on CPU devices.
      placement_function: if not None, a function that takes the shard index as
        input and returns the host device where the enqueue op should be
        placed.

    Returns:
      A list of host-side Ops, one for each shard, that when executed together
      will enqueue a full-size element of infeed.

    Raises:
      ValueError: if the queue configuration has previously been frozen and the
        shapes of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the shapes of the elements of sharded_inputs
        don't form a consistent unsharded tuple; or if the elements of a tuple
        have different device constraints.
      TypeError: if the queue configuration has previously been frozen and the
        types of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the types of the elements of sharded_inputs
        don't form a consistent unsharded tuple.
    """
    self.set_configuration_from_sharded_input_tensors(sharded_inputs)
    self.freeze()
    if self._generated_enqueue_ops:
      raise ValueError("Can't generate two enqueue Ops from the same queue")
    self._generated_enqueue_ops = True
    if tpu_ordinal_function is None:
      tpu_ordinal_function = lambda index: -1
    name_prefix = "%s/enqueue" % self._name
    return [
        self._generate_enqueue_op(
            shard,
            name_prefix,
            index,
            tpu_ordinal=tpu_ordinal_function(index),
            device=placement_function(index) if placement_function else None)
        for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
    ]

  # TODO(misard) Generalize this to the case of systems that don't
  # have 8 devices per host, and figure out what to do with
  # model-parallelism.
  def _default_placement_function(self, index):
    return "/task:%d/device:CPU:0" % (index // 8)

  def _default_ordinal_function(self, index):
    return index % 8

  # TODO(b/36470756) remove this from tutorials once we have a better story
  # for automatic placement of input pipelines.
  def split_inputs_and_generate_enqueue_ops(self,
                                            inputs,
                                            device_assignment=None,
                                            placement_function=None,
                                            tpu_ordinal_function=None):
    """POORLY-PERFORMING ON MULTI-HOST SYSTEMS.

    Generates the host-side Ops to enqueue a tuple.

    This method performs poorly because it takes an entire input on a single
    host, splits it, and distributes it to all of the cores. It is present only
    to simplify tutorial examples.

    inputs is a list of Tensors to use to feed the queue. Each input is split
    into self.number_of_shards shards. Returns an Op for each shard to enqueue
    the shard. The Op for shard i is placed on device placement_function(i).

    Implicitly freezes the queue configuration if it is not already
    frozen. If the configuration has already been frozen, and is not
    compatible with the types and shapes of inputs, an error
    will be raised.

    Args:
      inputs: a list of Tensors which indicates the types and shapes of the
        queue tuple.
      device_assignment: if not `None`, a TPU `DeviceAssignment`. If
        device_assignment is not `None`, but `placement_function` and
        `ordinal_function` are None, then `device_assignment` will be used to
        place infeeds on the first k TPU shards, where k is the number of shards
        in the queue. If all three are `None`, then default placement and
        ordinal functions are used.
      placement_function: if not None, a function that takes the shard
        index as input and returns a device string indicating which
        device the shard's infeed should be placed on. If placement_function
        and tpu_ordinal_function are None, inputs are sharded round-robin
        across the devices in the system.
      tpu_ordinal_function: if not None, a function that takes the
        shard index as input and returns the ordinal of the TPU device
        the shard's infeed should be placed on. If placement_function
        and tpu_ordinal_function are None, inputs are sharded round-robin
        across the devices in the system.

    Returns:
      A list of host-side Ops, one for each shard, that when executed together
      will enqueue a full-size element of infeed.

    Raises:
      ValueError: if the queue configuration has previously been frozen and the
        shapes of the elements of inputs are not compatible with the frozen
        configuration.
      TypeError: if the queue configuration has previously been frozen and the
        types of the elements of inputs are not compatible with the frozen
        configuration.
    """
    if device_assignment is None:
      if placement_function is None:
        placement_function = self._default_placement_function
      if tpu_ordinal_function is None:
        tpu_ordinal_function = self._default_ordinal_function
    else:

      def _placement_function_from_map(index):
        return device_assignment.host_device(replica=index)

      def _ordinal_function_from_map(index):
        return device_assignment.tpu_ordinal(replica=index)

      if placement_function is None:
        placement_function = _placement_function_from_map
      if tpu_ordinal_function is None:
        tpu_ordinal_function = _ordinal_function_from_map
    self.set_configuration_from_input_tensors(inputs)
    self.freeze()
    if self._generated_enqueue_ops:
      raise ValueError("Can't generate two enqueue Ops from the same queue")
    self._generated_enqueue_ops = True
    split_name_prefix = "%s/split" % self._name
    if self.number_of_shards == 1:
      transposed_sharded_inputs = [[inp] for inp in inputs]
    else:

      def split_fn(inp, num_shards, axis, name):
        with ops.colocate_with(inp):
          return array_ops.split(inp, num_shards, axis=axis, name=name)

      transposed_sharded_inputs = [
          split_fn(
              inp,
              self.number_of_shards,
              axis=policy.shard_dimension,
              name="%s/%d" % (split_name_prefix, index))
          for (inp, policy, index) in zip(inputs, self._sharding_policies,
                                          xrange(self.number_of_tuple_elements))
      ]
    sharded_inputs = [[shard[i] for shard in transposed_sharded_inputs]
                      for i in xrange(self.number_of_shards)]
    name_prefix = "%s/enqueue" % self._name
    return [
        self._generate_enqueue_op(
            shard,
            name_prefix,
            index,
            device=placement_function(index),
            tpu_ordinal=tpu_ordinal_function(index))
        for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
    ]
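
# Usage sketch for InfeedQueue (the two-shard layout, the tensor arguments and
# the identity ordinal function are assumed purely for illustration; a real
# program supplies its own tensors, placement and ordinal functions):
def _example_infeed_queue_usage(shard0_tensors, shard1_tensors):
  """Builds per-shard enqueue ops and the matching device-side dequeue op."""
  queue = InfeedQueue(number_of_tuple_elements=len(shard0_tensors))
  # One inner list per shard; types and shapes are inferred from the tensors.
  enqueue_ops = queue.generate_enqueue_ops(
      [shard0_tensors, shard1_tensors],
      tpu_ordinal_function=lambda index: index)
  # The dequeue op is generated once and consumed inside the TPU computation.
  dequeued_tensors = queue.generate_dequeue_op()
  return enqueue_ops, dequeued_tensors
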
class _PartitionedInfeedQueue(InfeedQueue):
  """A helper object to build a device infeed queue with input partition.

  Args:
    number_of_tuple_elements: the number of Tensors fed atomically through the
      queue, must be present unless it can be inferred from other arguments.
    device_assignment: A TPU `DeviceAssignment` which is used to place all the
      partitions to different TPU infeed queues.
    host_id: The id of the host machine.
    input_partition_dims: A nested list/tuple of integers. Each inner
      list/tuple describes how to partition the corresponding input tensor.
    tuple_types: If not None, a list of types of the elements of the queue.
    tuple_shapes: If not None, a list of shapes of the elements of the queue.
    name: The name of the queue.
  """

  def __init__(self,
               number_of_tuple_elements,
               device_assignment,
               host_id,
               input_partition_dims=None,
               tuple_types=None,
               tuple_shapes=None,
               name=None):
    super(_PartitionedInfeedQueue, self).__init__(
        number_of_tuple_elements=number_of_tuple_elements,
        tuple_types=tuple_types,
        tuple_shapes=None,
        shard_dimensions=None,
        name="PartitionedInfeedQueue" if name is None else name)
    self._input_partition_dims = input_partition_dims
    self._host_id = host_id
    self._device_assignment = device_assignment

  def generate_dequeue_op(self, tpu_device=0):
    """Generate TPU dequeue ops.

    Args:
      tpu_device: The TPU device ordinal where the infeed instruction should be
        placed.

    Returns:
      A list of Outputs corresponding to a partition of infeed dequeued
      into XLA, suitable for use within a replicated block.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set; or if a dequeue op has already been generated.
    """
    self.freeze()
    if self._generated_dequeue_op:
      raise ValueError("Can't generate two dequeue Ops from the same queue")
    self._generated_dequeue_op = True
    full_name = "%s/dequeue" % self._name
    sharded_shapes = [
        policy.get_sharded_shape(shape)
        for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
    ]
    with ops.device(tpu.core(tpu_device)):
      values = tpu_ops.infeed_dequeue_tuple(
          dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
    return tag_sharding_attribute_for_dequeued_tensors(
        values, self._input_partition_dims)

  def generate_enqueue_ops(self, sharded_inputs):
    """Generates the host-side Ops to enqueue the partitioned inputs.

    sharded_inputs is a list, one for each replica, of lists of
    Tensors. sharded_inputs[i] is the tuple of Tensors to use to feed
    replica i.
    sharded_inputs[i][j] is partitioned by self._input_partition_dims[j].

    For example, if sharded_inputs[i][j] is a 2-D Tensor:
    [[A, B, C, D],
     [E, F, G, H]]
    and self._input_partition_dims[j] is [2, 4],

    then sharded_inputs[i][j] will be partitioned and flattened into:
    [A, B, C, D, E, F, G, H] and fed into the logical core ids:
    [0, 1, 2, 3, 4, 5, 6, 7] respectively.

    Args:
      sharded_inputs: a list of lists of Tensors. The length of the
        outer list determines the number of shards. Each inner list indicates
        the types and shapes of the tuples in the corresponding shard.

    Returns:
      A list of host-side Ops, one for each shard, that when executed together
      will enqueue a full-size element of infeed.

    Raises:
      ValueError: if the queue configuration has previously been frozen and the
        shapes of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the shapes of the elements of sharded_inputs
        don't form a consistent unsharded tuple; or if the elements of a tuple
        have different device constraints; or if the partition dims are invalid.
      TypeError: if the queue configuration has previously been frozen and the
        types of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the types of the elements of sharded_inputs
        don't form a consistent unsharded tuple.
    """
    self.set_configuration_from_sharded_input_tensors(sharded_inputs)
    number_of_replicas = len(sharded_inputs)
    number_of_tuple_elements = len(sharded_inputs[0])

    assert len(self._input_partition_dims) == number_of_tuple_elements
    enqueue_ops = []

    for replica_index in range(number_of_replicas):
      flattened_inputs = sharded_inputs[replica_index]
      inputs_part_dims_flat = nest.flatten_up_to(flattened_inputs,
                                                 self._input_partition_dims)
      inputs_parted_iters = [
          iter(self._check_dims_and_partition_or_replicate_on_host(x, dims))
          for x, dims in zip(sharded_inputs[replica_index],
                             inputs_part_dims_flat)
      ]

      # Find the replica_id of the host's logical core 0.
      # The self._host_id is guaranteed to contain the logical core 0,
      # even when num_cores_per_replica > num_cores_per_host -- the function
      # caller makes sure that this host_id will be receiving data (calls
      # input_fn).
      replica_id = self._device_assignment.lookup_replicas(
          task_id=self._host_id, logical_core=0)[replica_index]
      for logical_core in xrange(self._device_assignment.num_cores_per_replica):
        # Place different partitions on different logical cores.
        # Since there can be multiple hosts per replica, we need to find
        # the actual host (device) of this logical core.
        device = self._device_assignment.host_device(
            replica=replica_id, logical_core=logical_core)

        with ops.device(device):
          ordinal = self._device_assignment.tpu_ordinal(
              replica=replica_id, logical_core=logical_core)
          infeed_inputs = []
          for it in inputs_parted_iters:
            input_for_device = next(it, None)
            if input_for_device is not None:
              infeed_inputs.append(input_for_device)

          if infeed_inputs:
            enqueue_ops.append(
                tpu_ops.infeed_enqueue_tuple(
                    inputs=infeed_inputs,
                    shapes=[x.shape for x in infeed_inputs],
                    name="enqueue/replica_{0}/input_{1}".format(
                        replica_index, logical_core),
                    device_ordinal=ordinal))
    return enqueue_ops

  def _check_input_partition_dims(self, tensor, dims):
    """Checks that input partition dims are valid for the `Tensor`.

    Args:
      tensor: Input tensor for partitioning.
      dims: A list of integers describing how to partition the input tensor.

    Raises:
      ValueError: If the tensor can't be partitioned by dims or the
        num_cores_per_replica doesn't match the number of
        partitions (dims.prod()).
    """
    # No partitioning specified, so don't perform further checks.
    if dims is None:
      return

    dims = np.array(dims)

    if (dims < 1).any():
      raise ValueError("All input partition dims must be >= 1.")

    # No partitioning, so don't perform further checks.
    if dims.prod() == 1:
      return

    if dims.prod() != self._device_assignment.num_cores_per_replica:
      raise ValueError(
          "The product of the input partition dims should equal "
          "num_cores_per_replica. (dims = {}, num_cores_per_replica "
          "= {})".format(dims, self._device_assignment.num_cores_per_replica))
    if dims.shape[0] != tensor.shape.ndims:
      raise ValueError(
          "Input partition dims must have the same number of dimensions "
          "as the `Tensor` to be partitioned. (tensor shape = {}, input "
          "partition dims = {}).".format(tensor.shape.as_list(), dims))

    tensor.shape.assert_is_fully_defined()

  def _check_dims_and_partition_or_replicate_on_host(self, tensor, dims):
    """Checks dims and partitions or replicates the input tensor.

    The ops inside this function are placed on the host side.

    Args:
      tensor: The input tensor which will be partitioned or replicated.
      dims: A list of integers describing how to partition the input tensor.

    Returns:
      An iterator of `Tensor`s or a list of partitioned tensors.
    """
    self._check_input_partition_dims(tensor, dims)
    return partition_or_replicate_on_host(tensor, dims)
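
# Illustrative sketch of the partition-to-core mapping described in
# _PartitionedInfeedQueue.generate_enqueue_ops: with input_partition_dims of
# [2, 4] (an assumed example) for one 2-D input, slice (i, j) of the host-side
# tensor is fed to logical core i * 4 + j, matching the row-major tile
# assignment applied on the dequeue side.
def _example_partition_core_mapping(dims=(2, 4)):
  mapping = {}
  for i in range(dims[0]):
    for j in range(dims[1]):
      mapping[(i, j)] = i * dims[1] + j
  return mapping  # {(0, 0): 0, (0, 1): 1, ..., (1, 3): 7}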