[XLA:SPMD] Add API for the experimental mechanism to mix auto and manual partitioning

Two ops are added:

XlaSpmdFullToShardShape: casts a full-shaped tensor (to be auto-partitioned) to its shard shape, which can then be consumed by manually partitioned code.

XlaSpmdShardToFullShape: casts a shard-shaped tensor (manually partitioned) back to its full shape, which will be consumed by ops auto-partitioned by SPMD.
PiperOrigin-RevId: 309845623
Change-Id: Ic056f2965c04a2357c2bf63d642eeafdbaa19b18
Yuanzhong Xu 2020-05-04 16:54:11 -07:00 committed by TensorFlower Gardener
parent 262997ce0b
commit 19624f9650
7 changed files with 282 additions and 2 deletions
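The intended usage pattern, as a hedged Python sketch (module paths and shapes are illustrative assumptions; the casts only take effect inside an XLA-compiled computation, e.g. on TPU):

  import tensorflow as tf
  from tensorflow.compiler.xla import xla_data_pb2
  from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding

  # Tile an [8, 16] tensor across 2 devices along dimension 0.
  op_sharding = xla_data_pb2.OpSharding(
      type=xla_data_pb2.OpSharding.OTHER,
      tile_assignment_dimensions=[2, 1],
      tile_assignment_devices=[0, 1])
  manual = op_sharding.SerializeToString()

  full = tf.ones([8, 16])
  # Each core sees a [4, 16] shard; code between the two casts is written
  # per-shard and is left untouched by the SPMD partitioner.
  shard = xla_sharding.auto_to_manual_spmd_partition(full, manual)
  shard = shard * 2.0  # manually partitioned computation
  result = xla_sharding.manual_to_auto_spmd_partition(
      shard, manual, full_shape=full.shape.as_list())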

@@ -2078,6 +2078,8 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"XlaSend",
"XlaSharding",
"XlaSort",
"XlaSpmdFullToShardShape",
"XlaSpmdShardToFullShape",
"XlaSvd",
"XlaWhile",
"_Arg",

@@ -103,6 +103,7 @@ tf_kernel_library(
"spacetodepth_op.cc",
"sparse_to_dense_op.cc",
"split_op.cc",
"spmd_manual_sharding_ops.cc",
"stack_ops.cc",
"stateful_random_ops.cc",
"stateless_random_ops.cc",

@@ -0,0 +1,147 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
namespace {
class XlaSpmdFullToShardShapeOp : public XlaOpKernel {
public:
explicit XlaSpmdFullToShardShapeOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("manual_sharding", &manual_sharding_str_));
}
~XlaSpmdFullToShardShapeOp() override = default;
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaOp input = ctx->Input(0);
auto input_shape_or = ctx->InputXlaShape(0);
OP_REQUIRES_OK(ctx, input_shape_or.status());
xla::OpSharding sharding;
if (!sharding.ParseFromString(manual_sharding_str_)) {
OP_REQUIRES_OK(ctx,
xla::InvalidArgument("manual_sharding attribute was not a "
"valid encoded xla::OpSharding "
"proto."));
}
auto output_shape = input_shape_or.ValueOrDie();
int64 rank = output_shape.rank();
if (sharding.type() == xla::OpSharding::OTHER) {
for (int64 i = 0; i < rank; ++i) {
int64 partitions_i = sharding.tile_assignment_dimensions(i);
if (partitions_i == 1) continue;
int64 dim_size =
xla::CeilOfRatio(output_shape.dimensions(i), partitions_i);
output_shape.set_dimensions(i, dim_size);
}
}
xla::XlaOp input_annotation;
{
// Annotate the full-shape input with the manual sharding.
xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(),
sharding);
input_annotation =
xla::CustomCall(ctx->builder(), /*call_target_name=*/"Sharding",
{input}, input_shape_or.ValueOrDie());
}
{
// Annotate the shard-shape output with replicated sharding, so that the
// partitioner will leave it as is.
xla::OpSharding replicated;
replicated.set_type(xla::OpSharding::REPLICATED);
xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(),
replicated);
auto output = xla::CustomCall(ctx->builder(),
/*call_target_name=*/"SPMDFullToShardShape",
{input_annotation}, output_shape);
ctx->SetOutput(0, output);
}
}
private:
string manual_sharding_str_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaSpmdFullToShardShapeOp);
};
class XlaSpmdShardToFullShapeOp : public XlaOpKernel {
public:
explicit XlaSpmdShardToFullShapeOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("full_shape", &full_shape_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("manual_sharding", &manual_sharding_str_));
}
~XlaSpmdShardToFullShapeOp() override = default;
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaOp input = ctx->Input(0);
auto input_shape_or = ctx->InputXlaShape(0);
OP_REQUIRES_OK(ctx, input_shape_or.status());
auto output_shape = TensorShapeToXLAShape(
input_shape_or.ValueOrDie().element_type(), full_shape_);
xla::OpSharding sharding;
if (!sharding.ParseFromString(manual_sharding_str_)) {
OP_REQUIRES_OK(ctx,
xla::InvalidArgument("manual_sharding attribute was not a "
"valid encoded xla::OpSharding "
"proto."));
}
xla::XlaOp input_annotation;
{
// Annotate the shard-shape input with replicated sharding, so that the
// partitioner will leave it as is.
xla::OpSharding replicated;
replicated.set_type(xla::OpSharding::REPLICATED);
xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(),
replicated);
input_annotation =
xla::CustomCall(ctx->builder(), /*call_target_name=*/"Sharding",
{input}, input_shape_or.ValueOrDie());
}
{
// Annotate the full-shape output with the manual sharding.
xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(),
sharding);
ctx->SetOutput(
0, xla::CustomCall(ctx->builder(),
/*call_target_name=*/"SPMDShardToFullShape",
{input_annotation}, output_shape));
}
}
private:
TensorShape full_shape_;
string manual_sharding_str_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaSpmdShardToFullShapeOp);
};
REGISTER_XLA_OP(Name("XlaSpmdFullToShardShape"), XlaSpmdFullToShardShapeOp);
REGISTER_XLA_OP(Name("XlaSpmdShardToFullShape"), XlaSpmdShardToFullShapeOp);
} // namespace
} // namespace tensorflow
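The shard-shape logic in XlaSpmdFullToShardShapeOp above is a per-dimension ceiling division over the tile assignment. A minimal standalone illustration (function name assumed, not part of this commit):

  def shard_shape(full_shape, tile_assignment_dimensions):
    # Mirrors xla::CeilOfRatio: round each tiled dimension up, so an
    # uneven split produces a padded final shard.
    return [(d + p - 1) // p
            for d, p in zip(full_shape, tile_assignment_dimensions)]

  # A [9, 7] tensor tiled 2 ways along dim 0 gives [5, 7] shards; the
  # second shard carries one padded row.
  assert shard_shape([9, 7], [2, 1]) == [5, 7]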

@@ -648,6 +648,62 @@ This op has better TPU performance since it doesn't have explicit reshape and
transpose operations as tf.einsum does.
)doc");
REGISTER_OP("XlaSpmdFullToShardShape")
.Input("input: T")
.Output("output: T")
.Attr("T: type")
.Attr("manual_sharding: string")
.SetShapeFn([](shape_inference::InferenceContext* c) {
auto input_handle = c->input(0);
if (!c->RankKnown(input_handle)) {
return shape_inference::UnknownShape(c);
}
string sharding_attr;
TF_RETURN_IF_ERROR(c->GetAttr("manual_sharding", &sharding_attr));
xla::OpSharding sharding;
if (!sharding.ParseFromString(sharding_attr)) {
return errors::InvalidArgument(
"manual_sharding attribute was not a valid encoded "
"xla::OpSharding proto.");
}
std::vector<shape_inference::DimensionHandle> dims;
for (int64 i = 0; i < c->Rank(input_handle); ++i) {
auto dim = c->Value(c->Dim(input_handle, i));
if (sharding.type() == xla::OpSharding::OTHER &&
dim != shape_inference::InferenceContext::kUnknownDim) {
// Only tiled shardings change per-dimension sizes; round up so an
// uneven split yields a padded final shard.
int64 partitions_i = sharding.tile_assignment_dimensions(i);
if (partitions_i != 1) dim = (dim + partitions_i - 1) / partitions_i;
}
dims.push_back(c->MakeDim(dim));
}
c->set_output(0, c->MakeShape(dims));
return Status::OK();
})
.Doc(R"doc(
An op used by the XLA SPMD partitioner to switch from automatic partitioning to
manual partitioning. It annotates the input (full-shape, to be automatically
partitioned) with the same sharding used by manual partitioning, and outputs a
shard-shaped tensor to be consumed by later manually partitioned ops. If the
shape is not evenly partitionable, the padding region will be masked with 0s.
)doc");
REGISTER_OP("XlaSpmdShardToFullShape")
.Input("input: T")
.Output("output: T")
.Attr("T: type")
.Attr("manual_sharding: string")
.Attr("full_shape: shape")
.SetShapeFn([](shape_inference::InferenceContext* c) {
TensorShape shape_attr;
TF_RETURN_IF_ERROR(c->GetAttr("full_shape", &shape_attr));
shape_inference::ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
c->set_output(0, s);
return Status::OK();
})
.Doc(R"doc(
An op used by the XLA SPMD partitioner to switch from manual partitioning to
automatic partitioning. It converts the shard-shaped, manually partitioned input
into a full-shaped tensor to be partitioned automatically with the same sharding
used by manual partitioning.
)doc");
REGISTER_OP("XlaSharding")
.Input("input: T")
.Output("output: T")
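The "padding region will be masked with 0s" rule in the XlaSpmdFullToShardShape doc above, illustrated with NumPy (a sketch of the semantics, not the implementation):

  import numpy as np

  # Full shape [5] tiled across 2 devices: shard size is ceil(5 / 2) == 3.
  full = np.arange(5)
  shard_size = -(-full.size // 2)
  padded = np.zeros(shard_size * 2, dtype=full.dtype)
  padded[:full.size] = full
  shards = padded.reshape(2, shard_size)
  # shards[0] == [0, 1, 2]; shards[1] == [3, 4, 0], where the trailing
  # 0 is masked padding, not data.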

@@ -418,6 +418,26 @@ def _sharding_grad(op, grad):
return [grad]
spmd_full_to_shard_shape = gen_xla_ops.xla_spmd_full_to_shard_shape
spmd_shard_to_full_shape = gen_xla_ops.xla_spmd_shard_to_full_shape
@ops.RegisterGradient("XlaSpmdFullToShardShape")
def _spmd_full_to_shard_shape_grad(op, grad):
s2f = gen_xla_ops.xla_spmd_shard_to_full_shape(
grad,
manual_sharding=op.get_attr("manual_sharding"),
full_shape=op.inputs[0].shape.as_list())
return [s2f]
@ops.RegisterGradient("XlaSpmdShardToFullShape")
def _spmd_shard_to_full_shape_grad(op, grad):
f2s = gen_xla_ops.xla_spmd_full_to_shard_shape(
grad, manual_sharding=op.get_attr("manual_sharding"))
return [f2s]
sort = gen_xla_ops.xla_sort
key_value_sort = gen_xla_ops.xla_key_value_sort
while_loop = gen_xla_ops.xla_while
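The two gradient registrations above make the casts inverses under backprop: the incoming full-shaped gradient is cast to shard shape by the XlaSpmdShardToFullShape gradient, flows back through the per-shard computation, and is cast to full shape (recovered from op.inputs[0]) by the XlaSpmdFullToShardShape gradient. A hedged sketch, reusing the setup from the first example:

  def forward(x, manual):
    shard = xla_sharding.auto_to_manual_spmd_partition(x, manual)
    y = shard * 2.0  # manually partitioned computation
    return xla_sharding.manual_to_auto_spmd_partition(
        y, manual, full_shape=x.shape.as_list())

  # Under tf.gradients / GradientTape, d(forward)/dx has the same full
  # shape as x: each cast's gradient is the opposite cast with the same
  # manual_sharding attribute.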

@@ -243,3 +243,54 @@ def split(tensor,
tensor, split_dimension, num_devices, input_shape).apply_to_tensor(
tensor, assign_tuple_sharding=assign_tuple_sharding)
return tensor
def get_op_sharding(op):
"""Returns sharding attribute of an op.
Args:
op: a TensorFlow op.
Returns:
The attribute representing XLA sharding on this op.
"""
return op.get_attr('_XlaSharding')
def auto_to_manual_spmd_partition(tensor, manual_sharding):
"""Switches from automatic SPMD partitioning to manual partitioning.
Converts a full-shaped tensor (to be automatically partitioned by SPMD
partitioner) to a shard-shaped tensor to be consumed by manually partitioned
ops.
Args:
tensor: A tf.Tensor in full shape.
manual_sharding: a serialized string of OpSharding to be used in manual
partitioning.
Returns:
A shard-shaped tensor to be consumed by manually partitioned ops.
"""
return tf2xla.spmd_full_to_shard_shape(
tensor, manual_sharding=manual_sharding)
def manual_to_auto_spmd_partition(tensor, manual_sharding, full_shape):
"""Switches from manual partitioning to automatic SPMD partitioning.
Converts a shard-shaped tensor (manually partitioned in SPMD-style) to a
full-shaped tensor to be partitioned automatically by the SPMD partitioner.
Args:
tensor: A tf.Tensor in shard shape.
manual_sharding: a serialized string of OpSharding to be used in manual
partitioning.
full_shape: the shape of tensor before partitioning.
Returns:
A full-shaped tensor to be partitioned automatically by the SPMD
partitioner.
"""
return tf2xla.spmd_shard_to_full_shape(
tensor, manual_sharding=manual_sharding, full_shape=full_shape)
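For completeness, a hedged sketch of reading an annotation back with the new get_op_sharding helper (assumes split() attaches the _XlaSharding attribute via apply_to_tensor, as in this file):

  import tensorflow as tf
  from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding

  # Sharding annotations are graph attributes, so build inside a graph.
  with tf.Graph().as_default():
    x = tf.ones([8, 16])
    y = xla_sharding.split(x, split_dimension=0, num_devices=2)
    # split() attaches a serialized xla.OpSharding proto to y.op; the
    # helper returns that raw attribute value.
    sharding_attr = xla_sharding.get_op_sharding(y.op)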

@@ -50,7 +50,7 @@ auto OpGradientInfoInit(const T &a) {
absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
const tensorflow::string &op_name) {
static std::array<OpIndexInfo, 347> a = {{
static std::array<OpIndexInfo, 348> a = {{
{"Acosh"},
{"AllToAll", 1, {0}},
{"ApproximateEqual"},
@@ -396,6 +396,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
{"WholeFileReader"},
{"XlaClusterOutput"},
{"XlaSharding"},
{"XlaSpmdShardToFullShape"},
{"ZerosLike"},
{"VarHandleOp"},
}};
@@ -410,7 +411,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
const tensorflow::string &op_name) {
static std::array<OpIndexInfo, 459> a = {{
static std::array<OpIndexInfo, 461> a = {{
{"Abs"},
{"AccumulateNV2"},
{"Acos"},
@@ -865,6 +866,8 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
{"XlaClusterOutput"},
{"XlaEinsum"},
{"XlaSharding"},
{"XlaSpmdFullToShardShape"},
{"XlaSpmdShardToFullShape"},
{"Xlog1py"},
{"Xlogy"},
{"ZerosLike"},