diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index a38a3ceb344..5204370772b 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -175,9 +175,11 @@ cc_library(
         ":tensorflow_ops_inc_gen",
         ":tf_saved_model_inc_gen",
         "//tensorflow/compiler/mlir/lite:validators",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core/platform:logging",
+        "//tensorflow/core/platform:protobuf",
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:CallOpInterfacesIncGen",
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 071605b3745..d535790d5fb 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -7367,6 +7367,25 @@ def TF_XdivyOp : TF_Op<"Xdivy", [NoSideEffect, ResultsBroadcastableShape]>,
   let hasCanonicalizer = 1;
 }
 
+def TF_XlaShardingOp : TF_Op<"XlaSharding", [NoSideEffect]> {
+  let summary = [{
+An op which shards the input based on the given sharding attribute.
+  }];
+
+  let description = [{
+  }];
+
+  let arguments = (ins
+    TF_Tensor:$input
+  );
+
+  let results = (outs
+    TF_Tensor:$output
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
 def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Returns a tensor of zeros with the same shape and type as x.";
 
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index a473c5442f0..81e327cd023 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -58,7 +58,9 @@ limitations under the License.
 #include "mlir/Support/STLExtras.h"  // TF:llvm-project
 #include "mlir/Transforms/InliningUtils.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/util/tensor_format.h"
 
 namespace mlir {
@@ -2932,5 +2934,24 @@ Operation *TensorFlowDialect::materializeConstant(OpBuilder &builder,
   return builder.create<ConstOp>(loc, type, value);
 }
 
+LogicalResult TensorFlowDialect::verifyOperationAttribute(
+    Operation *op, NamedAttribute attribute) {
+  // Check that tf._XlaSharding is a valid serialized xla::OpSharding proto.
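+  // Note: the proto is carried here in binary (wire) serialized form; the HLO
+  // dialect's "xla_hlo.sharding" attribute uses the proto text format instead.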
+  if (attribute.first.is("tf._XlaSharding")) {
+    auto sharding = attribute.second.dyn_cast<StringAttr>();
+    if (!sharding) {
+      return op->emitError() << "tf._XlaSharding must be a string attribute";
+    }
+
+    ::xla::OpSharding sharding_proto;
+    if (!sharding_proto.ParseFromString(sharding.getValue().str())) {
+      return op->emitError() << "Invalid sharding: " << sharding.getValue();
+    }
+  }
+  return success();
+}
+
 }  // namespace TF
 }  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h
index b6f1f76782f..f6f37b5d389 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h
@@ -72,6 +72,10 @@ class TensorFlowDialect : public Dialect {
   // value with the desired resultant type.
   Operation *materializeConstant(OpBuilder &builder, Attribute value,
                                  Type type, Location loc) override;
+
+  // Verify an attribute from this dialect on the given operation.
+  LogicalResult verifyOperationAttribute(Operation *op,
+                                         NamedAttribute attribute) override;
 };
 
 // TODO(b/131258166): TensorFlow's mutex.h defines a `mutex_lock` macro, whose
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
index e734d3d7c89..22dff942669 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
@@ -2301,3 +2301,11 @@ func @testParseExampleV2RaggedMismatchedOutputLengths(%serialized: tensor<32x!tf
   %result:3 = "tf.ParseExampleV2"(%serialized, %names, %empty_str_vector, %empty_str_vector, %ragged_keys) {dense_shapes = [], num_sparse = 0 : i64, result_segment_sizes = dense<[0, 0, 0, 0, 2, 1]> : vector<6xi32>} : (tensor<32x!tf.string>, tensor<32x!tf.string>, tensor<0x!tf.string>, tensor<0x!tf.string>, tensor<2x!tf.string>) -> (tensor<?x!tf.string>, tensor<?x!tf.string>, tensor<?xi64>)
   return %result#0 : tensor<?x!tf.string>
 }
+
+// -----
+
+func @testInvalidXlaSharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> {
+  // expected-error @+1 {{Invalid sharding: some-invalid-sharding}}
+  %0 = "tf.XlaSharding"(%arg0) {tf._XlaSharding = "some-invalid-sharding"} : (tensor<4x16xf32>) -> tensor<4x16xf32>
+  return %0 : tensor<4x16xf32>
+}
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index e66f31702e4..9dbe9be32db 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -114,6 +114,7 @@ cc_library(
         ":hlo",
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/client:padding",
         "//tensorflow/core:framework",
         "//tensorflow/core/kernels:conv_grad_shape_utils",
@@ -358,6 +359,8 @@ cc_library(
         ":hlo_ops_base_inc_gen",
         ":hlo_ops_inc_gen",
         ":xla_canonicalize_inc_gen",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/core/platform:protobuf",
         "@com_google_absl//absl/container:flat_hash_set",
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:Analysis",
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
index 351e3bdfa7d..f9bfb59d055 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
@@ -49,6 +49,8 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
 #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h.inc"
 #include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/protobuf.h"
 
 namespace mlir {
 #include "tensorflow/compiler/mlir/xla/ir/hlo_structs.cc.inc"
@@ -70,6 +72,26 @@ static LogicalResult Verify(T op) {
   return success();
 }
 
+LogicalResult XlaHloDialect::verifyOperationAttribute(
+    Operation* op, NamedAttribute attribute) {
+  // Check that xla_hlo.sharding is a valid xla::OpSharding proto in text form.
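+  // The attribute is attached during TF-to-HLO lowering and is read back on
+  // export by CreateOpShardingFromAttribute in mlir_hlo_to_hlo.cc.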
+  if (attribute.first.is("xla_hlo.sharding")) {
+    auto sharding = attribute.second.dyn_cast<StringAttr>();
+    if (!sharding) {
+      return op->emitError() << "xla_hlo.sharding must be a string attribute";
+    }
+
+    ::xla::OpSharding sharding_proto;
+    if (!::tensorflow::protobuf::TextFormat::ParseFromString(
+            sharding.getValue().str(), &sharding_proto)) {
+      return op->emitError() << "Invalid sharding: " << sharding.getValue();
+    }
+  }
+  return success();
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.h b/tensorflow/compiler/mlir/xla/ir/hlo_ops.h
index d0bc9619db9..399dc5c664e 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.h
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.h
@@ -52,6 +52,10 @@ class XlaHloDialect : public Dialect {
 
   // Prints a type registered to this dialect.
   void printType(Type type, DialectAsmPrinter &os) const override;
+
+  // Verify an attribute from this dialect on the given operation.
+  LogicalResult verifyOperationAttribute(Operation *op,
+                                         NamedAttribute attribute) override;
 };
 
 namespace HLOTypes {
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index 08612cf16ee..5c5956285d9 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -371,6 +371,23 @@ static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers(
   return output;
 }
 
+// Returns an OpSharding proto from the "xla_hlo.sharding" attribute of the op.
+// If the op doesn't have a sharding attribute, or the sharding attribute is
+// invalid, returns absl::nullopt.
+static absl::optional<xla::OpSharding> CreateOpShardingFromAttribute(
+    mlir::Operation* op) {
+  auto sharding = op->getAttrOfType<mlir::StringAttr>("xla_hlo.sharding");
+  if (!sharding) {
+    return absl::nullopt;
+  }
+  ::xla::OpSharding sharding_proto;
+  if (!::tensorflow::protobuf::TextFormat::ParseFromString(
+          sharding.getValue().str(), &sharding_proto)) {
+    return absl::nullopt;
+  }
+  return sharding_proto;
+}
+
 namespace mlir {
 namespace {
 class ConvertToHloModule {
diff --git a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc
index e61c8fc9724..34c4c2221ca 100644
--- a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc
+++ b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc
@@ -131,7 +131,14 @@ static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) {
   // Emit a function to generate an XLA operation for the operations with
   // auto-generated builders.
   os << "mlir::LogicalResult ExportXlaOperator(\n"
-        "mlir::Operation* op, OpLoweringContext lowering_context) {\n";
+        "mlir::Operation* op, OpLoweringContext lowering_context) {\n\n";
+
+  // Create a scoped object to assign sharding to generated XLA ops. Any HLO
+  // op can carry a "sharding" attribute.
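+  // xla::XlaScopedShardingAssignment sets the builder's sharding here and
+  // restores the previous sharding when it goes out of scope.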
+  os << "  xla::XlaScopedShardingAssignment sharding(lowering_context.builder, "
+        "CreateOpShardingFromAttribute(op));\n\n";
 
   // Retrieve all the definitions derived from HLO_Op and sort by record name.
   for (const auto* def : records.getAllDerivedDefinitions("HLO_Op")) {
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 5d7bc6d29be..27ed9205a91 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -3221,3 +3221,12 @@ func @avgpool_same_padding(%arg0: tensor<2x13x25x7xf32>) -> tensor<2x4x7x7xf32>
   %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 4, 1]} : (tensor<2x13x25x7xf32>) -> tensor<2x4x7x7xf32>
   return %0 : tensor<2x4x7x7xf32>
 }
+
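+// Note: an empty "_XlaSharding" string deserializes to the default OpSharding
+// proto, whose text form is also the empty string.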
#include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/kernels/conv_grad_shape_utils.h" #include "tensorflow/core/util/padding.h" @@ -3403,6 +3404,45 @@ class ConvertVariableShapeOp : public OpRewritePattern { } }; +// Converts an XlaSharding op to a XLA HLO shard op with sharding attributes. +class ConvertXlaShardingOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(TF::XlaShardingOp op, + PatternRewriter &rewriter) const override { + // TODO(b/148313088): define sharding attribute struct in MLIR intead of + // using a string. + auto sharding = op.getAttrOfType("_XlaSharding"); + if (!sharding) { + return matchFailure(); + } + + // _XlaSharding attribute in TF is a serialized string of the OpSharding + // proto, so convert to a text form here. + ::xla::OpSharding sharding_proto; + std::string sharding_str; + if (!sharding_proto.ParseFromString(sharding.getValue().str())) { + return matchFailure(); + } + if (!::tensorflow::protobuf::TextFormat::PrintToString(sharding_proto, + &sharding_str)) { + return matchFailure(); + } + + auto custom_call = rewriter.create( + op.getLoc(), op.getType(), op.input(), + /*call_target_name=*/rewriter.getStringAttr("Sharding"), + /*has_side_effect=*/rewriter.getBoolAttr(false), + /*backend_config=*/rewriter.getStringAttr("")); + custom_call.setAttr("xla_hlo.sharding", + rewriter.getStringAttr(sharding_str)); + rewriter.replaceOp(op, custom_call.getResult()); + + return matchSuccess(); + } +}; + #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc" LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { @@ -3432,7 +3472,8 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op, ConvertUnpackOp, ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp, ConvertUnsortedSegmentProdOp, ConvertUnsortedSegmentSumOp, - ConvertRandomShuffleOp, ConvertVariableShapeOp>(op->getContext()); + ConvertRandomShuffleOp, ConvertVariableShapeOp, ConvertXlaShardingOp>( + op->getContext()); ConversionTarget target(*context); target.addLegalDialect();