Expose XlaOptions as an argument to tpu.rewrite
Right now, it contains a single option that controls whether requested XLA partitioning is done in SPMD or MPMD mode. XLA/TPU has two implementations of partitioning: SPMD, which is ideal for per-op partitioning, and MPMD, which is required for graph partitioning (e.g., GPipe).

PiperOrigin-RevId: 320282466
Change-Id: I8b90394333a15da9a3f2d93ac79e56eb2764f466

commit 353af3ea66
parent 486ac1e10a
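As a rough usage sketch (not part of this change; the computation body and input shape below are illustrative placeholders), the new argument can be passed to `tpu.rewrite` through the TF1 API to request the SPMD partitioner. Leaving `xla_options=None` keeps the default `XLAOptions()`, i.e. `use_spmd_for_xla_partitioning=False`:

# Illustrative sketch only: request XLA's SPMD partitioner through the new
# xla_options argument on tpu.rewrite. The model and shapes are placeholders,
# not part of the commit.
import tensorflow.compat.v1 as tf

def computation(x):
  # Placeholder model body for the example.
  return tf.reduce_sum(x * x)

x = tf.placeholder(tf.float32, shape=[8, 128])
output = tf.tpu.rewrite(
    computation,
    inputs=[x],
    xla_options=tf.tpu.XLAOptions(use_spmd_for_xla_partitioning=True))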
@@ -11,6 +11,7 @@ py_library(
     srcs = ["xla_sharding.py"],
     visibility = ["//visibility:public"],
     deps = [
+        "//tensorflow/compiler/tf2xla/python:xla",
         "//tensorflow/compiler/xla:xla_data_proto_py",
         "//tensorflow/compiler/xla/python_api:types",
         "//tensorflow/compiler/xla/python_api:xla_shape",
@@ -19,9 +19,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from absl import logging
+import collections
 import enum
 
+from absl import logging
 import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
@@ -812,6 +813,23 @@ class PaddingSpec(enum.IntEnum):
   POWER_OF_TWO = 1
 
 
+@tf_export(v1=["tpu.XLAOptions"])
+class XLAOptions(
+    collections.namedtuple("XLAOptions", [
+        "use_spmd_for_xla_partitioning",
+    ])):
+  """XLA compilation options.
+
+  Attributes:
+    use_spmd_for_xla_partitioning: Boolean. Whether to use XLA's SPMD
+      partitioner instead of MPMD partitioner when compiler partitioning is
+      requested.
+  """
+
+  def __new__(cls, use_spmd_for_xla_partitioning=False):
+    return super(XLAOptions, cls).__new__(cls, use_spmd_for_xla_partitioning)
+
+
 @tf_export(v1=["tpu.replicate"])
 def replicate(computation,
               inputs=None,
@@ -819,7 +837,8 @@ def replicate(computation,
               device_assignment=None,
               name=None,
               maximum_shapes=None,
-              padding_spec=None):
+              padding_spec=None,
+              xla_options=None):
   """Builds a graph operator that runs a replicated TPU computation.
 
   Example for the basic usage that `inputs` has static shape:
@@ -884,6 +903,8 @@ def replicate(computation,
       One usage is to enable automatic bucketizing on the inputs by setting the
       value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the
       recompilation in the XLA side.
+    xla_options: An instance of `tpu.XLAOptions` which indicates the options
+      passed to XLA compiler. Use `None` for default options.
   Returns:
     A list of outputs, indexed by `[replica_num]` each output can be a nested
     structure same as what computation() returns with a few exceptions.
@@ -912,7 +933,8 @@ def replicate(computation,
       device_assignment,
       name,
       maximum_shapes=maximum_shapes,
-      padding_spec=padding_spec)[1]
+      padding_spec=padding_spec,
+      xla_options=xla_options)[1]
 
 
 def _ceil_to_pow_of_n(x, n):
@@ -1104,7 +1126,8 @@ def split_compile_and_replicate(computation,
                                 name=None,
                                 use_tpu=True,
                                 maximum_shapes=None,
-                                padding_spec=None):
+                                padding_spec=None,
+                                xla_options=None):
   """Builds graph operators that runs compilation and replicated computation.
 
   This is a lower level interface than replicate that returns a separate compile
@@ -1146,6 +1169,8 @@ def split_compile_and_replicate(computation,
       One usage is to enable automatic bucketizing on the inputs by setting the
       value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the
       recompilation in the XLA side.
+    xla_options: An instance of `tpu.XLAOptions` which indicates the options
+      passed to XLA compiler. Use `None` for default options.
 
   Returns:
     A list of lists with the first list corresponding to the compile op and the
@@ -1161,6 +1186,7 @@ def split_compile_and_replicate(computation,
   """
   del name
   inputs = [[]] if inputs is None else inputs
+  xla_options = xla_options or XLAOptions()
 
   metadata_kwargs = {}
   if device_assignment is not None:
@@ -1284,6 +1310,8 @@ def split_compile_and_replicate(computation,
 
   metadata_kwargs["step_marker_location"] = getattr(
       computation, "step_marker_location", "STEP_MARK_AT_ENTRY")
+  metadata_kwargs["use_spmd_for_xla_partitioning"] = \
+      xla_options.use_spmd_for_xla_partitioning
 
   graph = ops.get_default_graph()
 
@@ -1569,7 +1597,8 @@ def split_compile_and_shard(computation,
                             output_shard_axes=None,
                             infeed_queue=None,
                             device_assignment=None,
-                            name=None):
+                            name=None,
+                            xla_options=None):
   """Shards `computation` for parallel execution.
 
   `inputs` must be a list of Tensors or None (equivalent to an empty list), each
@@ -1619,6 +1648,8 @@ def split_compile_and_shard(computation,
       only one core, and there is either only one shard, or the number of shards
       is equal to the number of cores in the TPU system.
     name: (Deprecated) Does nothing.
+    xla_options: An instance of `tpu.XLAOptions` which indicates the options
+      passed to XLA compiler. Use `None` for default options.
   Returns:
     A tuple of (compile op, [output tensors]).
   Raises:
@@ -1662,7 +1693,8 @@ def split_compile_and_shard(computation,
       transposed_inputs,
       infeed_queue=infeed_queue,
       device_assignment=device_assignment,
-      name=name)
+      name=name,
+      xla_options=xla_options)
 
   # There must be at least one shard since num_shards > 0.
   # TODO(b/36647078) remove disable when pylint bug is fixed.
@@ -1720,7 +1752,8 @@ def shard(computation,
           output_shard_axes=None,
           infeed_queue=None,
           device_assignment=None,
-          name=None):
+          name=None,
+          xla_options=None):
   """Shards `computation` for parallel execution.
 
   `inputs` must be a list of Tensors or None (equivalent to an empty list), each
@@ -1773,6 +1806,8 @@ def shard(computation,
       only one core, and there is either only one shard, or the number of shards
       is equal to the number of cores in the TPU system.
     name: (Deprecated) Does nothing.
+    xla_options: An instance of `tpu.XLAOptions` which indicates the options
+      passed to XLA compiler. Use `None` for default options.
   Returns:
     A list of output tensors.
   Raises:
@@ -1789,7 +1824,8 @@ def shard(computation,
       output_shard_axes=output_shard_axes,
       infeed_queue=infeed_queue,
       device_assignment=device_assignment,
-      name=name)[1]
+      name=name,
+      xla_options=xla_options)[1]
 
 
 @tf_export(v1=["tpu.batch_parallel"])
@@ -1798,7 +1834,8 @@ def batch_parallel(computation,
                    num_shards=1,
                    infeed_queue=None,
                    device_assignment=None,
-                   name=None):
+                   name=None,
+                   xla_options=None):
   """Shards `computation` along the batch dimension for parallel execution.
 
   Convenience wrapper around shard().
@@ -1835,6 +1872,8 @@ def batch_parallel(computation,
       only one core, and there is either only one shard, or the number of shards
       is equal to the number of cores in the TPU system.
     name: (Deprecated) Does nothing.
+    xla_options: An instance of `tpu.XLAOptions` which indicates the options
+      passed to XLA compiler. Use `None` for default options.
   Returns:
     A list of output tensors.
   Raises:
@@ -1846,7 +1885,8 @@ def batch_parallel(computation,
       num_shards=num_shards,
       infeed_queue=infeed_queue,
       device_assignment=device_assignment,
-      name=name)
+      name=name,
+      xla_options=xla_options)
 
 
 @tf_export(v1=["tpu.rewrite"])
@@ -1854,7 +1894,8 @@ def rewrite(computation,
             inputs=None,
             infeed_queue=None,
             device_assignment=None,
-            name=None):
+            name=None,
+            xla_options=None):
   """Rewrites `computation` for execution on a TPU system.
 
   Args:
@@ -1882,6 +1923,8 @@ def rewrite(computation,
       the TPU topology. May be omitted for a single-core computation, in which
       case the core attached to task 0, TPU device 0 is used.
     name: (Deprecated) Does nothing.
+    xla_options: An instance of `tpu.XLAOptions` which indicates the options
+      passed to XLA compiler. Use `None` for default options.
   Returns:
     Same data structure as if computation(*inputs) is called directly with some
     exceptions for correctness. Exceptions include:
@@ -1899,7 +1942,8 @@ def rewrite(computation,
       None if inputs is None else [inputs],
       infeed_queue=infeed_queue,
       device_assignment=device_assignment,
-      name=name)[0]
+      name=name,
+      xla_options=xla_options)[0]
   # pylint: enable=indexing-exception
 
 # Operations that indicate some error in the user's inference graph.
@@ -0,0 +1,19 @@
+path: "tensorflow.tpu.XLAOptions"
+tf_class {
+  is_instance: "<class \'tensorflow.python.tpu.tpu.XLAOptions\'>"
+  is_instance: "<class \'tensorflow.python.tpu.tpu.XLAOptions\'>"
+  is_instance: "<type \'tuple\'>"
+  member {
+    name: "use_spmd_for_xla_partitioning"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "count"
+  }
+  member_method {
+    name: "index"
+  }
+}
@@ -8,13 +8,17 @@ tf_module {
     name: "PaddingSpec"
     mtype: "<class \'enum.EnumMeta\'>"
   }
+  member {
+    name: "XLAOptions"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "experimental"
     mtype: "<type \'module\'>"
   }
   member_method {
     name: "batch_parallel"
-    argspec: "args=[\'computation\', \'inputs\', \'num_shards\', \'infeed_queue\', \'device_assignment\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'computation\', \'inputs\', \'num_shards\', \'infeed_queue\', \'device_assignment\', \'name\', \'xla_options\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "bfloat16_scope"
@@ -38,15 +42,15 @@ tf_module {
   }
   member_method {
     name: "replicate"
-    argspec: "args=[\'computation\', \'inputs\', \'infeed_queue\', \'device_assignment\', \'name\', \'maximum_shapes\', \'padding_spec\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'computation\', \'inputs\', \'infeed_queue\', \'device_assignment\', \'name\', \'maximum_shapes\', \'padding_spec\', \'xla_options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "rewrite"
-    argspec: "args=[\'computation\', \'inputs\', \'infeed_queue\', \'device_assignment\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'computation\', \'inputs\', \'infeed_queue\', \'device_assignment\', \'name\', \'xla_options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "shard"
-    argspec: "args=[\'computation\', \'inputs\', \'num_shards\', \'input_shard_axes\', \'outputs_from_all_shards\', \'output_shard_axes\', \'infeed_queue\', \'device_assignment\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'computation\', \'inputs\', \'num_shards\', \'input_shard_axes\', \'outputs_from_all_shards\', \'output_shard_axes\', \'infeed_queue\', \'device_assignment\', \'name\', \'xla_options\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "shutdown_system"
@@ -1372,6 +1372,8 @@ renames = {
         'tf.compat.v1.tpu.shard',
     'tf.tpu.shutdown_system':
         'tf.compat.v1.tpu.shutdown_system',
+    'tf.tpu.XLAOptions':
+        'tf.compat.v1.tpu.XLAOptions',
     'tf.trace':
         'tf.linalg.trace',
     'tf.train.AdadeltaOptimizer':
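Assuming the `renames = {` hunk above is the tf_upgrade_v2 renames table (an assumption from the hunk context, not stated in the diff), the new entry means the upgrade script would rewrite TF1-style references along these lines:

# Hypothetical before/after for a line rewritten by the tf_upgrade_v2 script,
# based on the new renames entry above.
opts = tf.tpu.XLAOptions(use_spmd_for_xla_partitioning=True)            # before
opts = tf.compat.v1.tpu.XLAOptions(use_spmd_for_xla_partitioning=True)  # after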