Add sharding op to TF and HLO dialect

PiperOrigin-RevId: 292534463
Change-Id: Ibae20d6bd108b3f3dc2bd6085f6f5ad712f3d20e
HyoukJoong Lee 2020-01-31 06:29:42 -08:00 committed by TensorFlower Gardener
parent ed87850352
commit 6b0350800f
14 changed files with 171 additions and 2 deletions


@ -175,9 +175,11 @@ cc_library(
":tensorflow_ops_inc_gen",
":tf_saved_model_inc_gen",
"//tensorflow/compiler/mlir/lite:validators",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:protobuf",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:CallOpInterfacesIncGen",


@ -7367,6 +7367,25 @@ def TF_XdivyOp : TF_Op<"Xdivy", [NoSideEffect, ResultsBroadcastableShape]>,
let hasCanonicalizer = 1;
}

def TF_XlaShardingOp : TF_Op<"XlaSharding", [NoSideEffect]> {
  let summary = [{
An op which shards the input based on the given sharding attribute.
  }];

  let description = [{
  }];

  let arguments = (ins
    TF_Tensor:$input
  );

  let results = (outs
    TF_Tensor:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Returns a tensor of zeros with the same shape and type as x.";


@ -58,7 +58,9 @@ limitations under the License.
#include "mlir/Support/STLExtras.h" // TF:llvm-project
#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/tensor_format.h"
namespace mlir {
@ -2932,5 +2934,22 @@ Operation *TensorFlowDialect::materializeConstant(OpBuilder &builder,
return builder.create<ConstOp>(loc, type, value);
}
LogicalResult TensorFlowDialect::verifyOperationAttribute(
    Operation *op, NamedAttribute attribute) {
  // Check that the tf._XlaSharding attribute holds valid serialized bytes of
  // an xla::OpSharding proto.
  if (attribute.first.is("tf._XlaSharding")) {
    auto sharding = attribute.second.dyn_cast<mlir::StringAttr>();
    if (!sharding) {
      return op->emitError() << "tf._XlaSharding must be a string attribute";
    }
    ::xla::OpSharding sharding_proto;
    if (!sharding_proto.ParseFromString(sharding.getValue().str())) {
      return op->emitError() << "Invalid sharding: " << sharding.getValue();
    }
  }
  return success();
}
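
For reference, the value this check accepts is the binary serialization of an xla::OpSharding proto. A minimal standalone sketch (not part of this change; the helper name is illustrative) that produces such a string for a replicated sharding:

#include <string>

#include "tensorflow/compiler/xla/xla_data.pb.h"

// Builds a replicated OpSharding and serializes it to the binary form that
// the tf._XlaSharding attribute is expected to carry.
std::string MakeReplicatedShardingBytes() {
  ::xla::OpSharding sharding;
  sharding.set_type(::xla::OpSharding::REPLICATED);
  return sharding.SerializeAsString();
}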
} // namespace TF
} // namespace mlir


@ -72,6 +72,10 @@ class TensorFlowDialect : public Dialect {
// value with the desired resultant type.
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
Location loc) override;
// Verify an attribute from this dialect on the given operation.
LogicalResult verifyOperationAttribute(Operation *op,
NamedAttribute attribute) override;
};
// TODO(b/131258166): TensorFlow's mutex.h defines a `mutex_lock` macro, whose


@ -2301,3 +2301,11 @@ func @testParseExampleV2RaggedMismatchedOutputLengths(%serialized: tensor<32x!tf
%result:3 = "tf.ParseExampleV2"(%serialized, %names, %empty_str_vector, %empty_str_vector, %ragged_keys) {dense_shapes = [], num_sparse = 0 : i64, result_segment_sizes = dense<[0, 0, 0, 0, 2, 1]> : vector<6xi32>} : (tensor<32x!tf.string>, tensor<32x!tf.string>, tensor<0x!tf.string>, tensor<0x!tf.string>, tensor<2x!tf.string>) -> (tensor<?xf32>, tensor<?x!tf.string>, tensor<?xi32>)
return %result#0 : tensor<?xf32>
}
// -----
func @testInvalidXlaSharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> {
// expected-error @+1 {{Invalid sharding: some-invalid-sharding}}
%0 = "tf.XlaSharding"(%arg0) {tf._XlaSharding = "some-invalid-sharding"} : (tensor<4x16xf32>) -> tensor<4x16xf32>
return %0 : tensor<4x16xf32>
}


@ -114,6 +114,7 @@ cc_library(
":hlo",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:padding",
"//tensorflow/core:framework",
"//tensorflow/core/kernels:conv_grad_shape_utils",
@ -358,6 +359,8 @@ cc_library(
":hlo_ops_base_inc_gen",
":hlo_ops_inc_gen",
":xla_canonicalize_inc_gen",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/core/platform:protobuf",
"@com_google_absl//absl/container:flat_hash_set",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",


@ -49,6 +49,8 @@ limitations under the License.
#include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h.inc"
#include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/protobuf.h"
namespace mlir {
#include "tensorflow/compiler/mlir/xla/ir/hlo_structs.cc.inc"
@ -70,6 +72,24 @@ static LogicalResult Verify(T op) {
return success();
}
LogicalResult XlaHloDialect::verifyOperationAttribute(
    Operation* op, NamedAttribute attribute) {
  // Check that the xla_hlo.sharding attribute is a valid text-format
  // xla::OpSharding proto.
  if (attribute.first.is("xla_hlo.sharding")) {
    auto sharding = attribute.second.dyn_cast<mlir::StringAttr>();
    if (!sharding) {
      return op->emitError() << "xla_hlo.sharding must be a string attribute";
    }
    ::xla::OpSharding sharding_proto;
    if (!::tensorflow::protobuf::TextFormat::ParseFromString(
            sharding.getValue().str(), &sharding_proto)) {
      return op->emitError() << "Invalid sharding: " << sharding.getValue();
    }
  }
  return success();
}
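
In contrast to the TF attribute above, xla_hlo.sharding holds the text-format rendering of the same proto, as the tests below show. A hedged sketch (illustrative helper, not part of this change) that produces the text used in the export test:

#include <string>

#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/protobuf.h"

// Builds a [1,2]-tiled OpSharding across devices 0 and 1 and renders it as
// the text-format string expected by the xla_hlo.sharding attribute.
std::string MakeTiledShardingText() {
  ::xla::OpSharding sharding;
  sharding.set_type(::xla::OpSharding::OTHER);
  sharding.add_tile_assignment_dimensions(1);
  sharding.add_tile_assignment_dimensions(2);
  sharding.add_tile_assignment_devices(0);
  sharding.add_tile_assignment_devices(1);
  std::string text;
  ::tensorflow::protobuf::TextFormat::PrintToString(sharding, &text);
  return text;
}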
namespace {
//===----------------------------------------------------------------------===//


@ -52,6 +52,10 @@ class XlaHloDialect : public Dialect {
// Prints a type registered to this dialect.
void printType(Type type, DialectAsmPrinter &os) const override;
// Verify an attribute from this dialect on the given operation.
LogicalResult verifyOperationAttribute(Operation *op,
NamedAttribute attribute) override;
};
namespace HLOTypes {


@ -371,6 +371,23 @@ static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers(
return output;
}
// Returns an OpSharding proto from the "xla_hlo.sharding" attribute of the op.
// If the op doesn't have a sharding attribute or the sharding attribute is
// invalid, returns absl::nullopt.
static absl::optional<xla::OpSharding> CreateOpShardingFromAttribute(
    mlir::Operation* op) {
  auto sharding = op->getAttrOfType<mlir::StringAttr>("xla_hlo.sharding");
  if (!sharding) {
    return absl::nullopt;
  }
  ::xla::OpSharding sharding_proto;
  if (!::tensorflow::protobuf::TextFormat::ParseFromString(
          sharding.getValue().str(), &sharding_proto)) {
    return absl::nullopt;
  }
  return sharding_proto;
}
namespace mlir {
namespace {
class ConvertToHloModule {


@ -131,7 +131,12 @@ static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) {
// Emit a function to generate an XLA operation for the operations with
// auto-generated builders.
os << "mlir::LogicalResult ExportXlaOperator(\n"
"mlir::Operation* op, OpLoweringContext lowering_context) {\n";
"mlir::Operation* op, OpLoweringContext lowering_context) {\n\n";
// Create a scoped object that assigns the sharding to the XLA ops generated
// for this operation. Any HLO op can carry a "sharding" attribute.
os << " xla::XlaScopedShardingAssignment sharding(lowering_context.builder, "
"CreateOpShardingFromAttribute(op));\n\n";
// Retrieve all the definitions derived from HLO_Op and sort by record name.
for (const auto* def : records.getAllDerivedDefinitions("HLO_Op")) {


@ -3221,3 +3221,10 @@ func @avgpool_same_padding(%arg0: tensor<2x13x25x7xf32>) -> tensor<2x4x7x7xf32>
%0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 4, 1]} : (tensor<2x13x25x7xf32>) -> tensor<2x4x7x7xf32>
return %0 : tensor<2x4x7x7xf32>
}
// CHECK-LABEL: xla_sharding
func @xla_sharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> {
// CHECK-NEXT: "xla_hlo.custom_call"(%arg0) {backend_config = "", call_target_name = "Sharding", has_side_effect = false, xla_hlo.sharding = ""}
%0 = "tf.XlaSharding"(%arg0) {_XlaSharding = ""} : (tensor<4x16xf32>) -> tensor<4x16xf32>
return %0 : tensor<4x16xf32>
}


@ -837,3 +837,11 @@ func @sort_wrong_block_arg_type(%input0: tensor<16x16xf32>, %input1: tensor<16x1
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
return
}
// -----
func @custom_call_invalid_sharding(%input0: tensor<16x16xf32>) -> tensor<16x16xf32> {
// expected-error @+1 {{Invalid sharding: some-invalid-sharding}}
%0 = "xla_hlo.custom_call"(%input0) {backend_config = "", call_target_name = "Sharding", xla_hlo.sharding = "some-invalid-sharding"} : (tensor<16x16xf32>) -> tensor<16x16xf32>
return %0 : tensor<16x16xf32>
}


@ -910,6 +910,18 @@ func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
// CHECK: ROOT %{{.*}} = (f32[16,16], s32[16,16]) sort(f32[16,16] %[[MAIN_ARG0]], s32[16,16] %[[MAIN_ARG1]]), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]]
// -----
// CHECK: HloModule
func @main(%arg0: tensor<16x16xf32>) -> tensor<16x16xf32> {
%0 = "xla_hlo.custom_call"(%arg0) {backend_config = "", call_target_name = "Sharding", xla_hlo.sharding = "type: OTHER\ntile_assignment_dimensions: 1\ntile_assignment_dimensions: 2\ntile_assignment_devices: 0\ntile_assignment_devices: 1"} : (tensor<16x16xf32>) -> tensor<16x16xf32>
return %0 : tensor<16x16xf32>
}
// CHECK: ENTRY
// CHECK: %[[ARG0:.*]] = f32[16,16] parameter(0)
// CHECK: ROOT %[[RESULT:.*]] = f32[16,16] custom-call(f32[16,16] %[[ARG0]]), custom_call_target="Sharding", sharding={devices=[1,2]0,1}
// -----
// Tests that the exported HLO module keeps parameter replication annotation.


@ -45,6 +45,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include "tensorflow/core/util/padding.h"
@ -3403,6 +3404,45 @@ class ConvertVariableShapeOp : public OpRewritePattern<TF::VariableShapeOp> {
}
};
// Converts a TF XlaSharding op to an XLA HLO custom-call op that carries the
// sharding as an attribute.
class ConvertXlaShardingOp : public OpRewritePattern<TF::XlaShardingOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::XlaShardingOp op,
                                     PatternRewriter &rewriter) const override {
    // TODO(b/148313088): define sharding attribute struct in MLIR instead of
    // using a string.
    auto sharding = op.getAttrOfType<StringAttr>("_XlaSharding");
    if (!sharding) {
      return matchFailure();
    }

    // The _XlaSharding attribute in TF is a binary-serialized OpSharding
    // proto, so convert it to text form here.
    ::xla::OpSharding sharding_proto;
    std::string sharding_str;
    if (!sharding_proto.ParseFromString(sharding.getValue().str())) {
      return matchFailure();
    }
    if (!::tensorflow::protobuf::TextFormat::PrintToString(sharding_proto,
                                                           &sharding_str)) {
      return matchFailure();
    }

    auto custom_call = rewriter.create<xla_hlo::CustomCallOp>(
        op.getLoc(), op.getType(), op.input(),
        /*call_target_name=*/rewriter.getStringAttr("Sharding"),
        /*has_side_effect=*/rewriter.getBoolAttr(false),
        /*backend_config=*/rewriter.getStringAttr(""));
    custom_call.setAttr("xla_hlo.sharding",
                        rewriter.getStringAttr(sharding_str));
    rewriter.replaceOp(op, custom_call.getResult());

    return matchSuccess();
  }
};
#include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
@ -3432,7 +3472,8 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
ConvertUnpackOp, ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
ConvertUnsortedSegmentProdOp, ConvertUnsortedSegmentSumOp,
ConvertRandomShuffleOp, ConvertVariableShapeOp>(op->getContext());
ConvertRandomShuffleOp, ConvertVariableShapeOp, ConvertXlaShardingOp>(
op->getContext());
ConversionTarget target(*context);
target.addLegalDialect<XlaHloDialect>();