Expose XLAOptions as an argument to tpu.rewrite

Right now, it contains a single option that controls whether requested XLA partitioning should be done with the SPMD or the MPMD partitioner.

XLA/TPU has two implementations of partitioning: SPMD, which is ideal for per-op partitioning, and MPMD, which is required for graph partitioning (e.g., GPipe).
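For illustration, a minimal sketch of how the new argument might be passed through the public v1 API (the computation, shapes, and graph setup below are hypothetical; only `XLAOptions` and the `xla_options` parameter come from this change):

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

def computation(x):
  # A toy computation; any TPU-compatible function works here.
  return tf.reduce_sum(x * x)

x = tf.placeholder(tf.float32, shape=[8, 128])

# Ask the compiler to use the SPMD partitioner instead of the default MPMD one.
output = tf.tpu.rewrite(
    computation,
    inputs=[x],
    xla_options=tf.tpu.XLAOptions(use_spmd_for_xla_partitioning=True))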

PiperOrigin-RevId: 320282466
Change-Id: I8b90394333a15da9a3f2d93ac79e56eb2764f466
Author: Yuanzhong Xu, 2020-07-08 15:51:43 -07:00 (committed by TensorFlower Gardener)
Parent: 486ac1e10a
Commit: 353af3ea66
5 changed files with 86 additions and 16 deletions


@@ -11,6 +11,7 @@ py_library(
     srcs = ["xla_sharding.py"],
     visibility = ["//visibility:public"],
     deps = [
+        "//tensorflow/compiler/tf2xla/python:xla",
         "//tensorflow/compiler/xla:xla_data_proto_py",
         "//tensorflow/compiler/xla/python_api:types",
         "//tensorflow/compiler/xla/python_api:xla_shape",


@@ -19,9 +19,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from absl import logging
+import collections
 import enum
+from absl import logging
 import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
@@ -812,6 +813,23 @@ class PaddingSpec(enum.IntEnum):
   POWER_OF_TWO = 1
 
 
+@tf_export(v1=["tpu.XLAOptions"])
+class XLAOptions(
+    collections.namedtuple("XLAOptions", [
+        "use_spmd_for_xla_partitioning",
+    ])):
+  """XLA compilation options.
+
+  Attributes:
+    use_spmd_for_xla_partitioning: Boolean. Whether to use XLA's SPMD
+      partitioner instead of MPMD partitioner when compiler partitioning is
+      requested.
+  """
+
+  def __new__(cls, use_spmd_for_xla_partitioning=False):
+    return super(XLAOptions, cls).__new__(cls, use_spmd_for_xla_partitioning)
+
+
 @tf_export(v1=["tpu.replicate"])
 def replicate(computation,
               inputs=None,
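A quick sketch of the expected namedtuple behavior of the class added above (the construction below is illustrative, using the standard namedtuple API):

import tensorflow.compat.v1 as tf

opts = tf.tpu.XLAOptions()  # field defaults to False, i.e. MPMD partitioning
assert opts.use_spmd_for_xla_partitioning is False

spmd_opts = opts._replace(use_spmd_for_xla_partitioning=True)  # namedtuple copy-with-change
assert spmd_opts.use_spmd_for_xla_partitioning is True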
@@ -819,7 +837,8 @@ def replicate(computation,
               device_assignment=None,
               name=None,
               maximum_shapes=None,
-              padding_spec=None):
+              padding_spec=None,
+              xla_options=None):
   """Builds a graph operator that runs a replicated TPU computation.
 
   Example for the basic usage that `inputs` has static shape:
@@ -884,6 +903,8 @@ def replicate(computation,
       One usage is to enable automatic bucketizing on the inputs by setting the
       value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the
       recompilation in the XLA side.
+    xla_options: An instance of `tpu.XLAOptions` which indicates the options
+      passed to XLA compiler. Use `None` for default options.
 
   Returns:
     A list of outputs, indexed by `[replica_num]` each output can be a nested
     structure same as what computation() returns with a few exceptions.
@@ -912,7 +933,8 @@ def replicate(computation,
       device_assignment,
       name,
       maximum_shapes=maximum_shapes,
-      padding_spec=padding_spec)[1]
+      padding_spec=padding_spec,
+      xla_options=xla_options)[1]
 
 
 def _ceil_to_pow_of_n(x, n):
@@ -1104,7 +1126,8 @@ def split_compile_and_replicate(computation,
                                 name=None,
                                 use_tpu=True,
                                 maximum_shapes=None,
-                                padding_spec=None):
+                                padding_spec=None,
+                                xla_options=None):
   """Builds graph operators that runs compilation and replicated computation.
 
   This is a lower level interface than replicate that returns a separate compile
@@ -1146,6 +1169,8 @@ def split_compile_and_replicate(computation,
       One usage is to enable automatic bucketizing on the inputs by setting the
       value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the
       recompilation in the XLA side.
+    xla_options: An instance of `tpu.XLAOptions` which indicates the options
+      passed to XLA compiler. Use `None` for default options.
 
   Returns:
     A list of lists with the first list corresponding to the compile op and the
@@ -1161,6 +1186,7 @@ def split_compile_and_replicate(computation,
   """
   del name
   inputs = [[]] if inputs is None else inputs
+  xla_options = xla_options or XLAOptions()
 
   metadata_kwargs = {}
   if device_assignment is not None:
@@ -1284,6 +1310,8 @@ def split_compile_and_replicate(computation,
   metadata_kwargs["step_marker_location"] = getattr(
       computation, "step_marker_location", "STEP_MARK_AT_ENTRY")
+  metadata_kwargs["use_spmd_for_xla_partitioning"] = \
+      xla_options.use_spmd_for_xla_partitioning
 
   graph = ops.get_default_graph()
@@ -1569,7 +1597,8 @@ def split_compile_and_shard(computation,
                             output_shard_axes=None,
                             infeed_queue=None,
                             device_assignment=None,
-                            name=None):
+                            name=None,
+                            xla_options=None):
   """Shards `computation` for parallel execution.
 
   `inputs` must be a list of Tensors or None (equivalent to an empty list), each
@@ -1619,6 +1648,8 @@ def split_compile_and_shard(computation,
       only one core, and there is either only one shard, or the number of shards
       is equal to the number of cores in the TPU system.
     name: (Deprecated) Does nothing.
+    xla_options: An instance of `tpu.XLAOptions` which indicates the options
+      passed to XLA compiler. Use `None` for default options.
 
   Returns:
     A tuple of (compile op, [output tensors]).
   Raises:
@@ -1662,7 +1693,8 @@ def split_compile_and_shard(computation,
       transposed_inputs,
       infeed_queue=infeed_queue,
       device_assignment=device_assignment,
-      name=name)
+      name=name,
+      xla_options=xla_options)
 
   # There must be at least one shard since num_shards > 0.
   # TODO(b/36647078) remove disable when pylint bug is fixed.
@@ -1720,7 +1752,8 @@ def shard(computation,
           output_shard_axes=None,
           infeed_queue=None,
           device_assignment=None,
-          name=None):
+          name=None,
+          xla_options=None):
   """Shards `computation` for parallel execution.
 
   `inputs` must be a list of Tensors or None (equivalent to an empty list), each
@@ -1773,6 +1806,8 @@ def shard(computation,
       only one core, and there is either only one shard, or the number of shards
       is equal to the number of cores in the TPU system.
     name: (Deprecated) Does nothing.
+    xla_options: An instance of `tpu.XLAOptions` which indicates the options
+      passed to XLA compiler. Use `None` for default options.
 
   Returns:
     A list of output tensors.
   Raises:
@@ -1789,7 +1824,8 @@ def shard(computation,
       output_shard_axes=output_shard_axes,
       infeed_queue=infeed_queue,
       device_assignment=device_assignment,
-      name=name)[1]
+      name=name,
+      xla_options=xla_options)[1]
 
 
 @tf_export(v1=["tpu.batch_parallel"])
@@ -1798,7 +1834,8 @@ def batch_parallel(computation,
                    num_shards=1,
                    infeed_queue=None,
                    device_assignment=None,
-                   name=None):
+                   name=None,
+                   xla_options=None):
   """Shards `computation` along the batch dimension for parallel execution.
 
   Convenience wrapper around shard().
@@ -1835,6 +1872,8 @@ def batch_parallel(computation,
       only one core, and there is either only one shard, or the number of shards
       is equal to the number of cores in the TPU system.
     name: (Deprecated) Does nothing.
+    xla_options: An instance of `tpu.XLAOptions` which indicates the options
+      passed to XLA compiler. Use `None` for default options.
 
   Returns:
     A list of output tensors.
   Raises:
@@ -1846,7 +1885,8 @@ def batch_parallel(computation,
       num_shards=num_shards,
       infeed_queue=infeed_queue,
       device_assignment=device_assignment,
-      name=name)
+      name=name,
+      xla_options=xla_options)
 
 
 @tf_export(v1=["tpu.rewrite"])
@@ -1854,7 +1894,8 @@ def rewrite(computation,
             inputs=None,
             infeed_queue=None,
             device_assignment=None,
-            name=None):
+            name=None,
+            xla_options=None):
   """Rewrites `computation` for execution on a TPU system.
 
   Args:
@@ -1882,6 +1923,8 @@ def rewrite(computation,
       the TPU topology. May be omitted for a single-core computation, in which
       case the core attached to task 0, TPU device 0 is used.
     name: (Deprecated) Does nothing.
+    xla_options: An instance of `tpu.XLAOptions` which indicates the options
+      passed to XLA compiler. Use `None` for default options.
 
   Returns:
     Same data structure as if computation(*inputs) is called directly with some
     exceptions for correctness. Exceptions include:
@@ -1899,7 +1942,8 @@ def rewrite(computation,
       None if inputs is None else [inputs],
       infeed_queue=infeed_queue,
       device_assignment=device_assignment,
-      name=name)[0]
+      name=name,
+      xla_options=xla_options)[0]
   # pylint: enable=indexing-exception
 
 
 # Operations that indicate some error in the user's inference graph.


@@ -0,0 +1,19 @@
+path: "tensorflow.tpu.XLAOptions"
+tf_class {
+  is_instance: "<class \'tensorflow.python.tpu.tpu.XLAOptions\'>"
+  is_instance: "<class \'tensorflow.python.tpu.tpu.XLAOptions\'>"
+  is_instance: "<type \'tuple\'>"
+  member {
+    name: "use_spmd_for_xla_partitioning"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "count"
+  }
+  member_method {
+    name: "index"
+  }
+}


@@ -8,13 +8,17 @@ tf_module {
     name: "PaddingSpec"
     mtype: "<class \'enum.EnumMeta\'>"
   }
+  member {
+    name: "XLAOptions"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "experimental"
     mtype: "<type \'module\'>"
   }
   member_method {
     name: "batch_parallel"
-    argspec: "args=[\'computation\', \'inputs\', \'num_shards\', \'infeed_queue\', \'device_assignment\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'computation\', \'inputs\', \'num_shards\', \'infeed_queue\', \'device_assignment\', \'name\', \'xla_options\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "bfloat16_scope"
@@ -38,15 +42,15 @@ tf_module {
   }
   member_method {
     name: "replicate"
-    argspec: "args=[\'computation\', \'inputs\', \'infeed_queue\', \'device_assignment\', \'name\', \'maximum_shapes\', \'padding_spec\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'computation\', \'inputs\', \'infeed_queue\', \'device_assignment\', \'name\', \'maximum_shapes\', \'padding_spec\', \'xla_options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "rewrite"
-    argspec: "args=[\'computation\', \'inputs\', \'infeed_queue\', \'device_assignment\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'computation\', \'inputs\', \'infeed_queue\', \'device_assignment\', \'name\', \'xla_options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "shard"
-    argspec: "args=[\'computation\', \'inputs\', \'num_shards\', \'input_shard_axes\', \'outputs_from_all_shards\', \'output_shard_axes\', \'infeed_queue\', \'device_assignment\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'computation\', \'inputs\', \'num_shards\', \'input_shard_axes\', \'outputs_from_all_shards\', \'output_shard_axes\', \'infeed_queue\', \'device_assignment\', \'name\', \'xla_options\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "shutdown_system"


@@ -1372,6 +1372,8 @@ renames = {
         'tf.compat.v1.tpu.shard',
     'tf.tpu.shutdown_system':
         'tf.compat.v1.tpu.shutdown_system',
+    'tf.tpu.XLAOptions':
+        'tf.compat.v1.tpu.XLAOptions',
     'tf.trace':
         'tf.linalg.trace',
     'tf.train.AdadeltaOptimizer':