[XLA:SPMD] Add API for the experimental mechanism to mix auto and manual partitioning
Two ops are added. XlaSpmdFullToShardShape casts a full shape (to be auto-partitioned) to a shard shape, which can then be consumed by manually partitioned code. XlaSpmdShardToFullShape casts a shard shape (manually partitioned) to a full shape, which will be consumed by ops auto-partitioned by SPMD. PiperOrigin-RevId: 309845623 Change-Id: Ic056f2965c04a2357c2bf63d642eeafdbaa19b18
This commit is contained in:
parent
262997ce0b
commit
19624f9650
@ -2078,6 +2078,8 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
"XlaSend",
|
||||
"XlaSharding",
|
||||
"XlaSort",
|
||||
"XlaSpmdFullToShardShape",
|
||||
"XlaSpmdShardToFullShape",
|
||||
"XlaSvd",
|
||||
"XlaWhile",
|
||||
"_Arg",
|
||||
|
@ -103,6 +103,7 @@ tf_kernel_library(
|
||||
"spacetodepth_op.cc",
|
||||
"sparse_to_dense_op.cc",
|
||||
"split_op.cc",
|
||||
"spmd_manual_sharding_ops.cc",
|
||||
"stack_ops.cc",
|
||||
"stateful_random_ops.cc",
|
||||
"stateless_random_ops.cc",
|
||||
|
147
tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc
Normal file
147
tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc
Normal file
@ -0,0 +1,147 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class XlaSpmdFullToShardShapeOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit XlaSpmdFullToShardShapeOp(OpKernelConstruction* ctx)
|
||||
: XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("manual_sharding", &manual_sharding_str_));
|
||||
}
|
||||
|
||||
~XlaSpmdFullToShardShapeOp() override = default;
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::XlaOp input = ctx->Input(0);
|
||||
auto input_shape_or = ctx->InputXlaShape(0);
|
||||
OP_REQUIRES_OK(ctx, input_shape_or.status());
|
||||
xla::OpSharding sharding;
|
||||
if (!sharding.ParseFromString(manual_sharding_str_)) {
|
||||
OP_REQUIRES_OK(ctx,
|
||||
xla::InvalidArgument("manual_sharding attribute was not a "
|
||||
"valid encoded xla::OpSharding "
|
||||
"proto."));
|
||||
}
|
||||
auto output_shape = input_shape_or.ValueOrDie();
|
||||
int64 rank = output_shape.rank();
|
||||
if (sharding.type() == xla::OpSharding::OTHER) {
|
||||
for (int64 i = 0; i < rank; ++i) {
|
||||
int64 partitions_i = sharding.tile_assignment_dimensions(i);
|
||||
if (partitions_i == 1) continue;
|
||||
int64 dim_size =
|
||||
xla::CeilOfRatio(output_shape.dimensions(i), partitions_i);
|
||||
output_shape.set_dimensions(i, dim_size);
|
||||
}
|
||||
}
|
||||
xla::XlaOp input_annotation;
|
||||
{
|
||||
// Annotate the full-shape input with the manual sharding.
|
||||
xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(),
|
||||
sharding);
|
||||
input_annotation =
|
||||
xla::CustomCall(ctx->builder(), /*call_target_name=*/"Sharding",
|
||||
{input}, input_shape_or.ValueOrDie());
|
||||
}
|
||||
|
||||
{
|
||||
// Annotate the shard-shape output with replicated sharding, so that the
|
||||
// partitioner will leave it as is.
|
||||
xla::OpSharding replicated;
|
||||
replicated.set_type(xla::OpSharding::REPLICATED);
|
||||
xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(),
|
||||
replicated);
|
||||
auto output = xla::CustomCall(ctx->builder(),
|
||||
/*call_target_name=*/"SPMDFullToShardShape",
|
||||
{input_annotation}, output_shape);
|
||||
ctx->SetOutput(0, output);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
string manual_sharding_str_;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaSpmdFullToShardShapeOp);
|
||||
};
|
||||
|
||||
class XlaSpmdShardToFullShapeOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit XlaSpmdShardToFullShapeOp(OpKernelConstruction* ctx)
|
||||
: XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("full_shape", &full_shape_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("manual_sharding", &manual_sharding_str_));
|
||||
}
|
||||
|
||||
~XlaSpmdShardToFullShapeOp() override = default;
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::XlaOp input = ctx->Input(0);
|
||||
auto input_shape_or = ctx->InputXlaShape(0);
|
||||
OP_REQUIRES_OK(ctx, input_shape_or.status());
|
||||
auto output_shape = TensorShapeToXLAShape(
|
||||
input_shape_or.ValueOrDie().element_type(), full_shape_);
|
||||
|
||||
xla::OpSharding sharding;
|
||||
if (!sharding.ParseFromString(manual_sharding_str_)) {
|
||||
OP_REQUIRES_OK(ctx,
|
||||
xla::InvalidArgument("manual_sharding attribute was not a "
|
||||
"valid encoded xla::OpSharding "
|
||||
"proto."));
|
||||
}
|
||||
xla::XlaOp input_annotation;
|
||||
{
|
||||
// Annotate the shard-shape input with replicated sharding, so that the
|
||||
// partitioner will leave it as is.
|
||||
xla::OpSharding replicated;
|
||||
replicated.set_type(xla::OpSharding::REPLICATED);
|
||||
xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(),
|
||||
replicated);
|
||||
input_annotation =
|
||||
xla::CustomCall(ctx->builder(), /*call_target_name=*/"Sharding",
|
||||
{input}, input_shape_or.ValueOrDie());
|
||||
}
|
||||
|
||||
{
|
||||
// Annotate the full-shape output with the manual sharding.
|
||||
xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(),
|
||||
sharding);
|
||||
ctx->SetOutput(
|
||||
0, xla::CustomCall(ctx->builder(),
|
||||
/*call_target_name=*/"SPMDShardToFullShape",
|
||||
{input_annotation}, output_shape));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
TensorShape full_shape_;
|
||||
string manual_sharding_str_;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaSpmdShardToFullShapeOp);
|
||||
};
|
||||
|
||||
// Register the kernels for the two manual-sharding boundary ops defined in
// this file.
REGISTER_XLA_OP(Name("XlaSpmdFullToShardShape"), XlaSpmdFullToShardShapeOp);
REGISTER_XLA_OP(Name("XlaSpmdShardToFullShape"), XlaSpmdShardToFullShapeOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -648,6 +648,62 @@ This op has better TPU performance since it doesn't have explicitly reshape and
|
||||
transpose operations as tf.einsum does.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("XlaSpmdFullToShardShape")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("manual_sharding: string")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      auto input_handle = c->input(0);
      if (!c->RankKnown(input_handle)) {
        return shape_inference::UnknownShape(c);
      }
      string sharding_attr;
      TF_RETURN_IF_ERROR(c->GetAttr("manual_sharding", &sharding_attr));
      // Parse the sharding proto once, outside the loop (it used to be
      // re-parsed on every dimension). A malformed attr leaves `sharding`
      // default-initialized (replicated); the kernel reports the error at
      // compile time.
      xla::OpSharding sharding;
      sharding.ParseFromString(sharding_attr);
      const bool tiled = sharding.type() == xla::OpSharding::OTHER;
      std::vector<shape_inference::DimensionHandle> dims;
      for (int64 i = 0; i < c->Rank(input_handle); ++i) {
        auto dim = c->Value(c->Dim(input_handle, i));
        // Only read tile_assignment_dimensions for tiled (OTHER) shardings;
        // the repeated field is empty for other sharding types and indexing
        // it out of range would CHECK-fail.
        if (tiled) {
          int64 partitions_i = sharding.tile_assignment_dimensions(i);
          if (dim != shape_inference::InferenceContext::kUnknownDim &&
              partitions_i != 1) {
            // Shard size is ceil(full_size / partitions).
            dim = (dim + partitions_i - 1) / partitions_i;
          }
        }
        dims.push_back(c->MakeDim(dim));
      }
      c->set_output(0, c->MakeShape(dims));
      return Status::OK();
    })
    .Doc(R"doc(
An op used by XLA SPMD partitioner to switch from automatic partitioning to
manual partitioning. It annotates the input (full-shape, to be automatically
partitioned) with the same sharding used by manual partitioning, and outputs a
shard-shaped tensor to be consumed by later manually-partitioned ops. If the
shape is not evenly partitionable, the padding region will be masked with 0s.
)doc");
|
||||
|
||||
REGISTER_OP("XlaSpmdShardToFullShape")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("manual_sharding: string")
    .Attr("full_shape: shape")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // The output shape is fully determined by the "full_shape" attr.
      TensorShape full_shape_attr;
      TF_RETURN_IF_ERROR(c->GetAttr("full_shape", &full_shape_attr));
      shape_inference::ShapeHandle full_shape_handle;
      TF_RETURN_IF_ERROR(
          c->MakeShapeFromTensorShape(full_shape_attr, &full_shape_handle));
      c->set_output(0, full_shape_handle);
      return Status::OK();
    })
    .Doc(R"doc(
An op used by XLA SPMD partitioner to switch from manual partitioning to
automatic partitioning. It converts the shard-shaped, manually partitioned input
into full-shaped tensor to be partitioned automatically with the same sharding
used by manual partitioning.
)doc");
|
||||
|
||||
REGISTER_OP("XlaSharding")
|
||||
.Input("input: T")
|
||||
.Output("output: T")
|
||||
|
@ -418,6 +418,26 @@ def _sharding_grad(op, grad):
|
||||
return [grad]
|
||||
|
||||
|
||||
# Public aliases for the generated SPMD manual-sharding boundary ops.
spmd_full_to_shard_shape = gen_xla_ops.xla_spmd_full_to_shard_shape
spmd_shard_to_full_shape = gen_xla_ops.xla_spmd_shard_to_full_shape
|
||||
|
||||
|
||||
@ops.RegisterGradient("XlaSpmdFullToShardShape")
def _spmd_full_to_shard_shape_grad(op, grad):
  """Gradient: cast the shard-shaped grad back to the forward op's full shape."""
  # Reuse the forward op's manual sharding so the backward boundary matches
  # the forward one.
  # NOTE(review): op.inputs[0].shape must be fully defined here
  # (as_list() on an unknown shape raises) -- confirm with callers.
  manual_sharding = op.get_attr("manual_sharding")
  full_shape = op.inputs[0].shape.as_list()
  shard_to_full = gen_xla_ops.xla_spmd_shard_to_full_shape(
      grad, manual_sharding=manual_sharding, full_shape=full_shape)
  return [shard_to_full]
|
||||
|
||||
|
||||
@ops.RegisterGradient("XlaSpmdShardToFullShape")
def _spmd_shard_to_full_shape_grad(op, grad):
  """Gradient: cast the full-shaped grad back to the shard shape."""
  # Reuse the forward op's manual sharding so the backward boundary matches
  # the forward one.
  manual_sharding = op.get_attr("manual_sharding")
  full_to_shard = gen_xla_ops.xla_spmd_full_to_shard_shape(
      grad, manual_sharding=manual_sharding)
  return [full_to_shard]
|
||||
|
||||
|
||||
# Public aliases for generated XLA ops exposed at this module's level.
sort = gen_xla_ops.xla_sort
key_value_sort = gen_xla_ops.xla_key_value_sort
while_loop = gen_xla_ops.xla_while
|
||||
|
@ -243,3 +243,54 @@ def split(tensor,
|
||||
tensor, split_dimension, num_devices, input_shape).apply_to_tensor(
|
||||
tensor, assign_tuple_sharding=assign_tuple_sharding)
|
||||
return tensor
|
||||
|
||||
|
||||
def get_op_sharding(op):
  """Fetches the XLA sharding attribute attached to an op.

  Args:
    op: a TensorFlow op.

  Returns:
    The attribute representing XLA sharding on this op.
  """
  sharding_attr = op.get_attr('_XlaSharding')
  return sharding_attr
|
||||
|
||||
|
||||
def auto_to_manual_spmd_partition(tensor, manual_sharding):
  """Casts a full-shaped tensor to shard shape for manual partitioning.

  The input is a full-shaped tensor that the SPMD partitioner would otherwise
  partition automatically; the result is the corresponding shard-shaped
  tensor, ready to be consumed by manually partitioned ops.

  Args:
    tensor: A tf.Tensor in full shape.
    manual_sharding: a serialized string of OpSharding to be used in manual
      partitioning.

  Returns:
    A shard-shaped tensor to be consumed by manually partitioned ops.
  """
  sharded = tf2xla.spmd_full_to_shard_shape(
      tensor, manual_sharding=manual_sharding)
  return sharded
|
||||
|
||||
|
||||
def manual_to_auto_spmd_partition(tensor, manual_sharding, full_shape):
  """Casts a shard-shaped tensor back to full shape for auto partitioning.

  The input is a shard-shaped tensor that was manually partitioned in
  SPMD-style; the result is the corresponding full-shaped tensor, to be
  partitioned automatically by the SPMD partitioner.

  Args:
    tensor: A tf.Tensor in shard shape.
    manual_sharding: a serialized string of OpSharding to be used in manual
      partitioning.
    full_shape: the shape of tensor before partitioning.

  Returns:
    A full-shaped tensor to be partitioned automatically by the SPMD
    partitioner.
  """
  full = tf2xla.spmd_shard_to_full_shape(
      tensor, manual_sharding=manual_sharding, full_shape=full_shape)
  return full
|
||||
|
@ -50,7 +50,7 @@ auto OpGradientInfoInit(const T &a) {
|
||||
|
||||
absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
|
||||
const tensorflow::string &op_name) {
|
||||
static std::array<OpIndexInfo, 347> a = {{
|
||||
static std::array<OpIndexInfo, 348> a = {{
|
||||
{"Acosh"},
|
||||
{"AllToAll", 1, {0}},
|
||||
{"ApproximateEqual"},
|
||||
@ -396,6 +396,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
|
||||
{"WholeFileReader"},
|
||||
{"XlaClusterOutput"},
|
||||
{"XlaSharding"},
|
||||
{"XlaSpmdShardToFullShape"},
|
||||
{"ZerosLike"},
|
||||
{"VarHandleOp"},
|
||||
}};
|
||||
@ -410,7 +411,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
|
||||
|
||||
absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
|
||||
const tensorflow::string &op_name) {
|
||||
static std::array<OpIndexInfo, 459> a = {{
|
||||
static std::array<OpIndexInfo, 461> a = {{
|
||||
{"Abs"},
|
||||
{"AccumulateNV2"},
|
||||
{"Acos"},
|
||||
@ -865,6 +866,8 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
|
||||
{"XlaClusterOutput"},
|
||||
{"XlaEinsum"},
|
||||
{"XlaSharding"},
|
||||
{"XlaSpmdFullToShardShape"},
|
||||
{"XlaSpmdShardToFullShape"},
|
||||
{"Xlog1py"},
|
||||
{"Xlogy"},
|
||||
{"ZerosLike"},
|
||||
|
Loading…
Reference in New Issue
Block a user