Expose XLAOptions as an argument to tpu.rewrite

Right now, it contains a single option that controls whether requested XLA partitioning is done in SPMD or MPMD mode.

XLA/TPU has two implementations of partitioning: SPMD, which is ideal for per-op partitioning, and MPMD, which is required for graph partitioning (e.g., GPipe).
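For illustration only (not part of this change), a minimal sketch of the intended usage with the TF1-style API, assuming graph mode and a trivial computation:

    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()

    def computation(x):
      # Trivial per-core computation used only for the example.
      return x * 2.0

    x = tf.placeholder(tf.float32, shape=[8, 128])
    # Ask XLA to use its SPMD partitioner instead of the MPMD partitioner
    # whenever compiler partitioning is requested.
    outputs = tf.tpu.rewrite(
        computation,
        inputs=[x],
        xla_options=tf.tpu.XLAOptions(use_spmd_for_xla_partitioning=True))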

PiperOrigin-RevId: 320282466
Change-Id: I8b90394333a15da9a3f2d93ac79e56eb2764f466
Authored by Yuanzhong Xu on 2020-07-08 15:51:43 -07:00; committed by TensorFlower Gardener
parent 486ac1e10a
commit 353af3ea66
5 changed files with 86 additions and 16 deletions


@@ -11,6 +11,7 @@ py_library(
srcs = ["xla_sharding.py"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/tf2xla/python:xla",
"//tensorflow/compiler/xla:xla_data_proto_py",
"//tensorflow/compiler/xla/python_api:types",
"//tensorflow/compiler/xla/python_api:xla_shape",


@@ -19,9 +19,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import logging
import collections
import enum
from absl import logging
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -812,6 +813,23 @@ class PaddingSpec(enum.IntEnum):
POWER_OF_TWO = 1
@tf_export(v1=["tpu.XLAOptions"])
class XLAOptions(
collections.namedtuple("XLAOptions", [
"use_spmd_for_xla_partitioning",
])):
"""XLA compilation options.
Attributes:
use_spmd_for_xla_partitioning: Boolean. Whether to use XLA's SPMD
partitioner instead of MPMD partitioner when compiler partitioning is
requested.
"""
def __new__(cls, use_spmd_for_xla_partitioning=False):
return super(XLAOptions, cls).__new__(cls, use_spmd_for_xla_partitioning)
@tf_export(v1=["tpu.replicate"])
def replicate(computation,
inputs=None,
@@ -819,7 +837,8 @@ def replicate(computation,
device_assignment=None,
name=None,
maximum_shapes=None,
padding_spec=None):
padding_spec=None,
xla_options=None):
"""Builds a graph operator that runs a replicated TPU computation.
Example for the basic usage that `inputs` has static shape:
@@ -884,6 +903,8 @@ def replicate(computation,
One usage is to enable automatic bucketizing on the inputs by setting the
value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the
recompilation in the XLA side.
xla_options: An instance of `tpu.XLAOptions` which indicates the options
passed to XLA compiler. Use `None` for default options.
Returns:
A list of outputs, indexed by `[replica_num]` each output can be a nested
structure same as what computation() returns with a few exceptions.
@@ -912,7 +933,8 @@ def replicate(computation,
device_assignment,
name,
maximum_shapes=maximum_shapes,
padding_spec=padding_spec)[1]
padding_spec=padding_spec,
xla_options=xla_options)[1]
def _ceil_to_pow_of_n(x, n):
@@ -1104,7 +1126,8 @@ def split_compile_and_replicate(computation,
name=None,
use_tpu=True,
maximum_shapes=None,
padding_spec=None):
padding_spec=None,
xla_options=None):
"""Builds graph operators that runs compilation and replicated computation.
This is a lower level interface than replicate that returns a separate compile
@@ -1146,6 +1169,8 @@ def split_compile_and_replicate(computation,
One usage is to enable automatic bucketizing on the inputs by setting the
value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the
recompilation in the XLA side.
xla_options: An instance of `tpu.XLAOptions` which indicates the options
passed to XLA compiler. Use `None` for default options.
Returns:
A list of lists with the first list corresponding to the compile op and the
@@ -1161,6 +1186,7 @@ def split_compile_and_replicate(computation,
"""
del name
inputs = [[]] if inputs is None else inputs
xla_options = xla_options or XLAOptions()
metadata_kwargs = {}
if device_assignment is not None:
@@ -1284,6 +1310,8 @@ def split_compile_and_replicate(computation,
metadata_kwargs["step_marker_location"] = getattr(
computation, "step_marker_location", "STEP_MARK_AT_ENTRY")
metadata_kwargs["use_spmd_for_xla_partitioning"] = \
xla_options.use_spmd_for_xla_partitioning
graph = ops.get_default_graph()
@@ -1569,7 +1597,8 @@ def split_compile_and_shard(computation,
output_shard_axes=None,
infeed_queue=None,
device_assignment=None,
name=None):
name=None,
xla_options=None):
"""Shards `computation` for parallel execution.
`inputs` must be a list of Tensors or None (equivalent to an empty list), each
@@ -1619,6 +1648,8 @@ def split_compile_and_shard(computation,
only one core, and there is either only one shard, or the number of shards
is equal to the number of cores in the TPU system.
name: (Deprecated) Does nothing.
xla_options: An instance of `tpu.XLAOptions` which indicates the options
passed to XLA compiler. Use `None` for default options.
Returns:
A tuple of (compile op, [output tensors]).
Raises:
@@ -1662,7 +1693,8 @@ def split_compile_and_shard(computation,
transposed_inputs,
infeed_queue=infeed_queue,
device_assignment=device_assignment,
name=name)
name=name,
xla_options=xla_options)
# There must be at least one shard since num_shards > 0.
# TODO(b/36647078) remove disable when pylint bug is fixed.
@@ -1720,7 +1752,8 @@ def shard(computation,
output_shard_axes=None,
infeed_queue=None,
device_assignment=None,
name=None):
name=None,
xla_options=None):
"""Shards `computation` for parallel execution.
`inputs` must be a list of Tensors or None (equivalent to an empty list), each
@@ -1773,6 +1806,8 @@ def shard(computation,
only one core, and there is either only one shard, or the number of shards
is equal to the number of cores in the TPU system.
name: (Deprecated) Does nothing.
xla_options: An instance of `tpu.XLAOptions` which indicates the options
passed to XLA compiler. Use `None` for default options.
Returns:
A list of output tensors.
Raises:
@@ -1789,7 +1824,8 @@ def shard(computation,
output_shard_axes=output_shard_axes,
infeed_queue=infeed_queue,
device_assignment=device_assignment,
name=name)[1]
name=name,
xla_options=xla_options)[1]
@tf_export(v1=["tpu.batch_parallel"])
@@ -1798,7 +1834,8 @@ def batch_parallel(computation,
num_shards=1,
infeed_queue=None,
device_assignment=None,
name=None):
name=None,
xla_options=None):
"""Shards `computation` along the batch dimension for parallel execution.
Convenience wrapper around shard().
@@ -1835,6 +1872,8 @@ def batch_parallel(computation,
only one core, and there is either only one shard, or the number of shards
is equal to the number of cores in the TPU system.
name: (Deprecated) Does nothing.
xla_options: An instance of `tpu.XLAOptions` which indicates the options
passed to XLA compiler. Use `None` for default options.
Returns:
A list of output tensors.
Raises:
@@ -1846,7 +1885,8 @@ def batch_parallel(computation,
num_shards=num_shards,
infeed_queue=infeed_queue,
device_assignment=device_assignment,
name=name)
name=name,
xla_options=xla_options)
@tf_export(v1=["tpu.rewrite"])
@@ -1854,7 +1894,8 @@ def rewrite(computation,
inputs=None,
infeed_queue=None,
device_assignment=None,
name=None):
name=None,
xla_options=None):
"""Rewrites `computation` for execution on a TPU system.
Args:
@@ -1882,6 +1923,8 @@ def rewrite(computation,
the TPU topology. May be omitted for a single-core computation, in which
case the core attached to task 0, TPU device 0 is used.
name: (Deprecated) Does nothing.
xla_options: An instance of `tpu.XLAOptions` which indicates the options
passed to XLA compiler. Use `None` for default options.
Returns:
Same data structure as if computation(*inputs) is called directly with some
exceptions for correctness. Exceptions include:
@@ -1899,7 +1942,8 @@ def rewrite(computation,
None if inputs is None else [inputs],
infeed_queue=infeed_queue,
device_assignment=device_assignment,
name=name)[0]
name=name,
xla_options=xla_options)[0]
# pylint: enable=indexing-exception
# Operations that indicate some error in the user's inference graph.


@@ -0,0 +1,19 @@
path: "tensorflow.tpu.XLAOptions"
tf_class {
is_instance: "<class \'tensorflow.python.tpu.tpu.XLAOptions\'>"
is_instance: "<class \'tensorflow.python.tpu.tpu.XLAOptions\'>"
is_instance: "<type \'tuple\'>"
member {
name: "use_spmd_for_xla_partitioning"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "count"
}
member_method {
name: "index"
}
}
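Because `XLAOptions` is a `collections.namedtuple` subclass, the members listed above (`use_spmd_for_xla_partitioning` as a property, plus the inherited `count` and `index` methods) behave like ordinary namedtuple fields. A small illustrative sketch, assuming the v1 export added in this change:

    import tensorflow as tf

    opts = tf.compat.v1.tpu.XLAOptions()
    assert opts.use_spmd_for_xla_partitioning is False  # default from __new__
    # namedtuple instances are immutable; _replace returns a modified copy.
    spmd_opts = opts._replace(use_spmd_for_xla_partitioning=True)
    assert spmd_opts.use_spmd_for_xla_partitioning is True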


@@ -8,13 +8,17 @@ tf_module {
name: "PaddingSpec"
mtype: "<class \'enum.EnumMeta\'>"
}
member {
name: "XLAOptions"
mtype: "<type \'type\'>"
}
member {
name: "experimental"
mtype: "<type \'module\'>"
}
member_method {
name: "batch_parallel"
argspec: "args=[\'computation\', \'inputs\', \'num_shards\', \'infeed_queue\', \'device_assignment\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'None\', \'None\'], "
argspec: "args=[\'computation\', \'inputs\', \'num_shards\', \'infeed_queue\', \'device_assignment\', \'name\', \'xla_options\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "bfloat16_scope"
@@ -38,15 +42,15 @@ tf_module {
}
member_method {
name: "replicate"
argspec: "args=[\'computation\', \'inputs\', \'infeed_queue\', \'device_assignment\', \'name\', \'maximum_shapes\', \'padding_spec\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
argspec: "args=[\'computation\', \'inputs\', \'infeed_queue\', \'device_assignment\', \'name\', \'maximum_shapes\', \'padding_spec\', \'xla_options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "rewrite"
argspec: "args=[\'computation\', \'inputs\', \'infeed_queue\', \'device_assignment\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
argspec: "args=[\'computation\', \'inputs\', \'infeed_queue\', \'device_assignment\', \'name\', \'xla_options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "shard"
argspec: "args=[\'computation\', \'inputs\', \'num_shards\', \'input_shard_axes\', \'outputs_from_all_shards\', \'output_shard_axes\', \'infeed_queue\', \'device_assignment\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
argspec: "args=[\'computation\', \'inputs\', \'num_shards\', \'input_shard_axes\', \'outputs_from_all_shards\', \'output_shard_axes\', \'infeed_queue\', \'device_assignment\', \'name\', \'xla_options\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "shutdown_system"


@@ -1372,6 +1372,8 @@ renames = {
'tf.compat.v1.tpu.shard',
'tf.tpu.shutdown_system':
'tf.compat.v1.tpu.shutdown_system',
'tf.tpu.XLAOptions':
'tf.compat.v1.tpu.XLAOptions',
'tf.trace':
'tf.linalg.trace',
'tf.train.AdadeltaOptimizer':