Expose XlaOptions as an argument to tpu.rewrite
Right now, it contains a single option controlling whether requested XLA partitioning is done in SPMD or MPMD mode. XLA/TPU has two implementations of partitioning: SPMD, which is ideal for per-op partitioning, and MPMD, which is required for graph partitioning (e.g., GPipe).

PiperOrigin-RevId: 320282466
Change-Id: I8b90394333a15da9a3f2d93ac79e56eb2764f466
This commit is contained in:
parent 486ac1e10a
commit 353af3ea66
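To illustrate the new argument, a minimal usage sketch follows (the computation body, the placeholder shape, and the choice to turn on SPMD below are illustrative assumptions, not part of this change):

import tensorflow.compat.v1 as tf


def computation(x):
  # Hypothetical computation body; any TPU-compatible graph works here.
  return tf.matmul(x, tf.ones([128, 10]))


x = tf.placeholder(tf.float32, shape=[8, 128])

# Request XLA's SPMD partitioner instead of the default MPMD partitioner
# whenever compiler partitioning is requested.
opts = tf.tpu.XLAOptions(use_spmd_for_xla_partitioning=True)

# Passing xla_options=None (or omitting it) keeps the previous behavior.
outputs = tf.tpu.rewrite(computation, [x], xla_options=opts)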
@@ -11,6 +11,7 @@ py_library(
    srcs = ["xla_sharding.py"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/compiler/tf2xla/python:xla",
        "//tensorflow/compiler/xla:xla_data_proto_py",
        "//tensorflow/compiler/xla/python_api:types",
        "//tensorflow/compiler/xla/python_api:xla_shape",
@@ -19,9 +19,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import logging
import collections
import enum

from absl import logging
import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

@@ -812,6 +813,23 @@ class PaddingSpec(enum.IntEnum):
  POWER_OF_TWO = 1


@tf_export(v1=["tpu.XLAOptions"])
class XLAOptions(
    collections.namedtuple("XLAOptions", [
        "use_spmd_for_xla_partitioning",
    ])):
  """XLA compilation options.

  Attributes:
    use_spmd_for_xla_partitioning: Boolean. Whether to use XLA's SPMD
      partitioner instead of MPMD partitioner when compiler partitioning is
      requested.
  """

  def __new__(cls, use_spmd_for_xla_partitioning=False):
    return super(XLAOptions, cls).__new__(cls, use_spmd_for_xla_partitioning)


@tf_export(v1=["tpu.replicate"])
def replicate(computation,
              inputs=None,
@@ -819,7 +837,8 @@ def replicate(computation,
              device_assignment=None,
              name=None,
              maximum_shapes=None,
              padding_spec=None):
              padding_spec=None,
              xla_options=None):
  """Builds a graph operator that runs a replicated TPU computation.

  Example for the basic usage that `inputs` has static shape:
@@ -884,6 +903,8 @@ def replicate(computation,
      One usage is to enable automatic bucketizing on the inputs by setting the
      value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the
      recompilation in the XLA side.
    xla_options: An instance of `tpu.XLAOptions` which indicates the options
      passed to XLA compiler. Use `None` for default options.
  Returns:
    A list of outputs, indexed by `[replica_num]` each output can be a nested
    structure same as what computation() returns with a few exceptions.
@@ -912,7 +933,8 @@ def replicate(computation,
      device_assignment,
      name,
      maximum_shapes=maximum_shapes,
      padding_spec=padding_spec)[1]
      padding_spec=padding_spec,
      xla_options=xla_options)[1]


def _ceil_to_pow_of_n(x, n):
@@ -1104,7 +1126,8 @@ def split_compile_and_replicate(computation,
                                name=None,
                                use_tpu=True,
                                maximum_shapes=None,
                                padding_spec=None):
                                padding_spec=None,
                                xla_options=None):
  """Builds graph operators that runs compilation and replicated computation.

  This is a lower level interface than replicate that returns a separate compile
@@ -1146,6 +1169,8 @@ def split_compile_and_replicate(computation,
      One usage is to enable automatic bucketizing on the inputs by setting the
      value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the
      recompilation in the XLA side.
    xla_options: An instance of `tpu.XLAOptions` which indicates the options
      passed to XLA compiler. Use `None` for default options.

  Returns:
    A list of lists with the first list corresponding to the compile op and the
@@ -1161,6 +1186,7 @@ def split_compile_and_replicate(computation,
  """
  del name
  inputs = [[]] if inputs is None else inputs
  xla_options = xla_options or XLAOptions()

  metadata_kwargs = {}
  if device_assignment is not None:
@@ -1284,6 +1310,8 @@ def split_compile_and_replicate(computation,

  metadata_kwargs["step_marker_location"] = getattr(
      computation, "step_marker_location", "STEP_MARK_AT_ENTRY")
  metadata_kwargs["use_spmd_for_xla_partitioning"] = \
      xla_options.use_spmd_for_xla_partitioning

  graph = ops.get_default_graph()

@@ -1569,7 +1597,8 @@ def split_compile_and_shard(computation,
                            output_shard_axes=None,
                            infeed_queue=None,
                            device_assignment=None,
                            name=None):
                            name=None,
                            xla_options=None):
  """Shards `computation` for parallel execution.

  `inputs` must be a list of Tensors or None (equivalent to an empty list), each
@@ -1619,6 +1648,8 @@ def split_compile_and_shard(computation,
      only one core, and there is either only one shard, or the number of shards
      is equal to the number of cores in the TPU system.
    name: (Deprecated) Does nothing.
    xla_options: An instance of `tpu.XLAOptions` which indicates the options
      passed to XLA compiler. Use `None` for default options.
  Returns:
    A tuple of (compile op, [output tensors]).
  Raises:
@@ -1662,7 +1693,8 @@ def split_compile_and_shard(computation,
      transposed_inputs,
      infeed_queue=infeed_queue,
      device_assignment=device_assignment,
      name=name)
      name=name,
      xla_options=xla_options)

  # There must be at least one shard since num_shards > 0.
  # TODO(b/36647078) remove disable when pylint bug is fixed.
@@ -1720,7 +1752,8 @@ def shard(computation,
          output_shard_axes=None,
          infeed_queue=None,
          device_assignment=None,
          name=None):
          name=None,
          xla_options=None):
  """Shards `computation` for parallel execution.

  `inputs` must be a list of Tensors or None (equivalent to an empty list), each
@@ -1773,6 +1806,8 @@ def shard(computation,
      only one core, and there is either only one shard, or the number of shards
      is equal to the number of cores in the TPU system.
    name: (Deprecated) Does nothing.
    xla_options: An instance of `tpu.XLAOptions` which indicates the options
      passed to XLA compiler. Use `None` for default options.
  Returns:
    A list of output tensors.
  Raises:
@@ -1789,7 +1824,8 @@ def shard(computation,
      output_shard_axes=output_shard_axes,
      infeed_queue=infeed_queue,
      device_assignment=device_assignment,
      name=name)[1]
      name=name,
      xla_options=xla_options)[1]


@tf_export(v1=["tpu.batch_parallel"])
@@ -1798,7 +1834,8 @@ def batch_parallel(computation,
                   num_shards=1,
                   infeed_queue=None,
                   device_assignment=None,
                   name=None):
                   name=None,
                   xla_options=None):
  """Shards `computation` along the batch dimension for parallel execution.

  Convenience wrapper around shard().
@@ -1835,6 +1872,8 @@ def batch_parallel(computation,
      only one core, and there is either only one shard, or the number of shards
      is equal to the number of cores in the TPU system.
    name: (Deprecated) Does nothing.
    xla_options: An instance of `tpu.XLAOptions` which indicates the options
      passed to XLA compiler. Use `None` for default options.
  Returns:
    A list of output tensors.
  Raises:
@@ -1846,7 +1885,8 @@ def batch_parallel(computation,
      num_shards=num_shards,
      infeed_queue=infeed_queue,
      device_assignment=device_assignment,
      name=name)
      name=name,
      xla_options=xla_options)


@tf_export(v1=["tpu.rewrite"])
@@ -1854,7 +1894,8 @@ def rewrite(computation,
            inputs=None,
            infeed_queue=None,
            device_assignment=None,
            name=None):
            name=None,
            xla_options=None):
  """Rewrites `computation` for execution on a TPU system.

  Args:
@@ -1882,6 +1923,8 @@ def rewrite(computation,
      the TPU topology. May be omitted for a single-core computation, in which
      case the core attached to task 0, TPU device 0 is used.
    name: (Deprecated) Does nothing.
    xla_options: An instance of `tpu.XLAOptions` which indicates the options
      passed to XLA compiler. Use `None` for default options.
  Returns:
    Same data structure as if computation(*inputs) is called directly with some
    exceptions for correctness. Exceptions include:
@@ -1899,7 +1942,8 @@ def rewrite(computation,
      None if inputs is None else [inputs],
      infeed_queue=infeed_queue,
      device_assignment=device_assignment,
      name=name)[0]
      name=name,
      xla_options=xla_options)[0]
  # pylint: enable=indexing-exception

# Operations that indicate some error in the user's inference graph.
@@ -0,0 +1,19 @@
path: "tensorflow.tpu.XLAOptions"
tf_class {
  is_instance: "<class \'tensorflow.python.tpu.tpu.XLAOptions\'>"
  is_instance: "<class \'tensorflow.python.tpu.tpu.XLAOptions\'>"
  is_instance: "<type \'tuple\'>"
  member {
    name: "use_spmd_for_xla_partitioning"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
  }
  member_method {
    name: "count"
  }
  member_method {
    name: "index"
  }
}
@@ -8,13 +8,17 @@ tf_module {
    name: "PaddingSpec"
    mtype: "<class \'enum.EnumMeta\'>"
  }
  member {
    name: "XLAOptions"
    mtype: "<type \'type\'>"
  }
  member {
    name: "experimental"
    mtype: "<type \'module\'>"
  }
  member_method {
    name: "batch_parallel"
    argspec: "args=[\'computation\', \'inputs\', \'num_shards\', \'infeed_queue\', \'device_assignment\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'None\', \'None\'], "
    argspec: "args=[\'computation\', \'inputs\', \'num_shards\', \'infeed_queue\', \'device_assignment\', \'name\', \'xla_options\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "bfloat16_scope"
@@ -38,15 +42,15 @@ tf_module {
  }
  member_method {
    name: "replicate"
    argspec: "args=[\'computation\', \'inputs\', \'infeed_queue\', \'device_assignment\', \'name\', \'maximum_shapes\', \'padding_spec\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
    argspec: "args=[\'computation\', \'inputs\', \'infeed_queue\', \'device_assignment\', \'name\', \'maximum_shapes\', \'padding_spec\', \'xla_options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "rewrite"
    argspec: "args=[\'computation\', \'inputs\', \'infeed_queue\', \'device_assignment\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
    argspec: "args=[\'computation\', \'inputs\', \'infeed_queue\', \'device_assignment\', \'name\', \'xla_options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "shard"
    argspec: "args=[\'computation\', \'inputs\', \'num_shards\', \'input_shard_axes\', \'outputs_from_all_shards\', \'output_shard_axes\', \'infeed_queue\', \'device_assignment\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
    argspec: "args=[\'computation\', \'inputs\', \'num_shards\', \'input_shard_axes\', \'outputs_from_all_shards\', \'output_shard_axes\', \'infeed_queue\', \'device_assignment\', \'name\', \'xla_options\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "shutdown_system"
@@ -1372,6 +1372,8 @@ renames = {
        'tf.compat.v1.tpu.shard',
    'tf.tpu.shutdown_system':
        'tf.compat.v1.tpu.shutdown_system',
    'tf.tpu.XLAOptions':
        'tf.compat.v1.tpu.XLAOptions',
    'tf.trace':
        'tf.linalg.trace',
    'tf.train.AdadeltaOptimizer':
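One closing note on defaults, sketched from the `__new__` signature and the `xla_options = xla_options or XLAOptions()` fallback in the diff above: callers that never pass `xla_options` keep the existing MPMD behavior.

import tensorflow.compat.v1 as tf

# use_spmd_for_xla_partitioning defaults to False, i.e. the MPMD partitioner,
# so existing callers of tpu.rewrite/replicate/shard/batch_parallel see no change.
opts = tf.tpu.XLAOptions()
assert opts.use_spmd_for_xla_partitioning is False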