Add sharding op to TF and HLO dialect
PiperOrigin-RevId: 292534463
Change-Id: Ibae20d6bd108b3f3dc2bd6085f6f5ad712f3d20e
parent ed87850352
commit 6b0350800f
@@ -175,9 +175,11 @@ cc_library(
         ":tensorflow_ops_inc_gen",
         ":tf_saved_model_inc_gen",
         "//tensorflow/compiler/mlir/lite:validators",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core/platform:logging",
+        "//tensorflow/core/platform:protobuf",
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:CallOpInterfacesIncGen",
@@ -7367,6 +7367,25 @@ def TF_XdivyOp : TF_Op<"Xdivy", [NoSideEffect, ResultsBroadcastableShape]>,
   let hasCanonicalizer = 1;
 }
 
+def TF_XlaShardingOp : TF_Op<"XlaSharding", [NoSideEffect]> {
+  let summary = [{
+An op which shards the input based on the given sharding attribute.
+  }];
+
+  let description = [{
+  }];
+
+  let arguments = (ins
+    TF_Tensor:$input
+  );
+
+  let results = (outs
+    TF_Tensor:$output
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
 def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Returns a tensor of zeros with the same shape and type as x.";
 
@@ -58,7 +58,9 @@ limitations under the License.
 #include "mlir/Support/STLExtras.h"  // TF:llvm-project
 #include "mlir/Transforms/InliningUtils.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/util/tensor_format.h"
 
 namespace mlir {
@@ -2932,5 +2934,22 @@ Operation *TensorFlowDialect::materializeConstant(OpBuilder &builder,
   return builder.create<ConstOp>(loc, type, value);
 }
 
+LogicalResult TensorFlowDialect::verifyOperationAttribute(
+    Operation *op, NamedAttribute attribute) {
+  // Check that _XlaSharding holds valid serialized bytes of an OpSharding proto.
+  if (attribute.first.is("tf._XlaSharding")) {
+    auto sharding = attribute.second.dyn_cast<mlir::StringAttr>();
+    if (!sharding) {
+      return op->emitError() << "tf._XlaSharding must be a string attribute";
+    }
+
+    ::xla::OpSharding sharding_proto;
+    if (!sharding_proto.ParseFromString(sharding.getValue().str())) {
+      return op->emitError() << "Invalid sharding: " << sharding.getValue();
+    }
+  }
+  return success();
+}
+
 }  // namespace TF
 }  // namespace mlir
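Note: the tf._XlaSharding attribute verified above carries the OpSharding proto in binary serialized form. Below is a minimal sketch (not part of this change) of producing such a payload, assuming only the generated proto API from xla_data.pb.h; the field values mirror the export test later in this diff.

#include <string>
#include "tensorflow/compiler/xla/xla_data.pb.h"

// Builds a 1x2 tiled sharding over devices 0 and 1 and returns the binary
// bytes that a tf._XlaSharding string attribute is expected to hold.
std::string MakeTileShardingBytes() {
  ::xla::OpSharding sharding;
  sharding.set_type(::xla::OpSharding::OTHER);  // tiled (non-replicated)
  sharding.add_tile_assignment_dimensions(1);   // dim 0: not split
  sharding.add_tile_assignment_dimensions(2);   // dim 1: split in two
  sharding.add_tile_assignment_devices(0);
  sharding.add_tile_assignment_devices(1);
  std::string bytes;
  sharding.SerializeToString(&bytes);           // binary wire format
  return bytes;
}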
@@ -72,6 +72,10 @@ class TensorFlowDialect : public Dialect {
   // value with the desired resultant type.
   Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
                                  Location loc) override;
+
+  // Verify an attribute from this dialect on the given operation.
+  LogicalResult verifyOperationAttribute(Operation *op,
+                                         NamedAttribute attribute) override;
 };
 
 // TODO(b/131258166): TensorFlow's mutex.h defines a `mutex_lock` macro, whose
@@ -2301,3 +2301,11 @@ func @testParseExampleV2RaggedMismatchedOutputLengths(%serialized: tensor<32x!tf
   %result:3 = "tf.ParseExampleV2"(%serialized, %names, %empty_str_vector, %empty_str_vector, %ragged_keys) {dense_shapes = [], num_sparse = 0 : i64, result_segment_sizes = dense<[0, 0, 0, 0, 2, 1]> : vector<6xi32>} : (tensor<32x!tf.string>, tensor<32x!tf.string>, tensor<0x!tf.string>, tensor<0x!tf.string>, tensor<2x!tf.string>) -> (tensor<?xf32>, tensor<?x!tf.string>, tensor<?xi32>)
   return %result#0 : tensor<?xf32>
 }
+
+// -----
+
+func @testInvalidXlaSharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> {
+  // expected-error @+1 {{Invalid sharding: some-invalid-sharding}}
+  %0 = "tf.XlaSharding"(%arg0) {tf._XlaSharding = "some-invalid-sharding"} : (tensor<4x16xf32>) -> tensor<4x16xf32>
+  return %0 : tensor<4x16xf32>
+}
@@ -114,6 +114,7 @@ cc_library(
         ":hlo",
         "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/client:padding",
         "//tensorflow/core:framework",
         "//tensorflow/core/kernels:conv_grad_shape_utils",
@@ -358,6 +359,8 @@ cc_library(
         ":hlo_ops_base_inc_gen",
         ":hlo_ops_inc_gen",
         ":xla_canonicalize_inc_gen",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/core/platform:protobuf",
         "@com_google_absl//absl/container:flat_hash_set",
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:Analysis",
@@ -49,6 +49,8 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
 #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h.inc"
 #include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/protobuf.h"
 
 namespace mlir {
 #include "tensorflow/compiler/mlir/xla/ir/hlo_structs.cc.inc"
@@ -70,6 +72,24 @@ static LogicalResult Verify(T op) {
   return success();
 }
 
+LogicalResult XlaHloDialect::verifyOperationAttribute(
+    Operation* op, NamedAttribute attribute) {
+  // Check that the sharding attribute is a valid text-format OpSharding string.
+  if (attribute.first.is("xla_hlo.sharding")) {
+    auto sharding = attribute.second.dyn_cast<mlir::StringAttr>();
+    if (!sharding) {
+      return op->emitError() << "xla_hlo.sharding must be a string attribute";
+    }
+
+    ::xla::OpSharding sharding_proto;
+    if (sharding && !::tensorflow::protobuf::TextFormat::ParseFromString(
+                        sharding.getValue().str(), &sharding_proto)) {
+      return op->emitError() << "Invalid sharding: " << sharding.getValue();
+    }
+  }
+  return success();
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
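Note: unlike the TF dialect, where the attribute holds binary bytes, xla_hlo.sharding stores the proto in protobuf text format, which is why this verifier goes through TextFormat. A hedged sketch of the same validity check as a standalone helper, assuming only xla_data.pb.h and the protobuf alias from platform/protobuf.h:

#include <string>
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/protobuf.h"

// Returns true iff `text` parses as a text-format OpSharding proto, i.e. the
// same condition the xla_hlo.sharding verifier above enforces.
bool IsValidHloShardingText(const std::string& text) {
  ::xla::OpSharding proto;
  return ::tensorflow::protobuf::TextFormat::ParseFromString(text, &proto);
}

For example, "type: OTHER\ntile_assignment_dimensions: 1\n..." (as in the translate test later in this diff) parses, while "some-invalid-sharding" does not.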
@@ -52,6 +52,10 @@ class XlaHloDialect : public Dialect {
 
   // Prints a type registered to this dialect.
   void printType(Type type, DialectAsmPrinter &os) const override;
+
+  // Verify an attribute from this dialect on the given operation.
+  LogicalResult verifyOperationAttribute(Operation *op,
+                                         NamedAttribute attribute) override;
 };
 
 namespace HLOTypes {
@@ -371,6 +371,23 @@ static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers(
   return output;
 }
 
+// Returns an OpSharding proto from the "sharding" attribute of the op. If the
+// op doesn't have a sharding attribute or the sharding attribute is invalid,
+// returns absl::nullopt.
+static absl::optional<xla::OpSharding> CreateOpShardingFromAttribute(
+    mlir::Operation* op) {
+  auto sharding = op->getAttrOfType<mlir::StringAttr>("xla_hlo.sharding");
+  if (!sharding) {
+    return absl::nullopt;
+  }
+  ::xla::OpSharding sharding_proto;
+  if (!::tensorflow::protobuf::TextFormat::ParseFromString(
+          sharding.getValue().str(), &sharding_proto)) {
+    return absl::nullopt;
+  }
+  return sharding_proto;
+}
+
 namespace mlir {
 namespace {
 class ConvertToHloModule {
@@ -131,7 +131,12 @@ static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) {
   // Emit a function to generate an XLA operation for the operations with
   // auto-generated builders.
   os << "mlir::LogicalResult ExportXlaOperator(\n"
-        "mlir::Operation* op, OpLoweringContext lowering_context) {\n";
+        "mlir::Operation* op, OpLoweringContext lowering_context) {\n\n";
 
+  // Create a scoped object to assign sharding to generated XLA ops. Any HLO
+  // can have an attribute of "sharding".
+  os << "  xla::XlaScopedShardingAssignment sharding(lowering_context.builder, "
+        "CreateOpShardingFromAttribute(op));\n\n";
+
   // Retrieve all the definitions derived from HLO_Op and sort by record name.
   for (const auto* def : records.getAllDerivedDefinitions("HLO_Op")) {
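Note: xla::XlaScopedShardingAssignment is an RAII helper; emitting it at the top of the generated ExportXlaOperator makes every XLA op built inside the function inherit the sharding extracted from the MLIR op. A simplified model of the pattern follows; the Builder struct here is a hypothetical stand-in, not the real XlaBuilder API.

#include <utility>
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

// Hypothetical stand-in for the builder's mutable sharding state.
struct Builder {
  absl::optional<::xla::OpSharding> sharding;
};

// Install a sharding for the lifetime of a scope; restore the previous
// value on scope exit so nested scopes compose correctly.
class ScopedShardingAssignment {
 public:
  ScopedShardingAssignment(Builder* b, absl::optional<::xla::OpSharding> s)
      : builder_(b), previous_(std::move(b->sharding)) {
    builder_->sharding = std::move(s);
  }
  ~ScopedShardingAssignment() { builder_->sharding = std::move(previous_); }

 private:
  Builder* builder_;
  absl::optional<::xla::OpSharding> previous_;
};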
@@ -3221,3 +3221,10 @@ func @avgpool_same_padding(%arg0: tensor<2x13x25x7xf32>) -> tensor<2x4x7x7xf32>
   %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 4, 1]} : (tensor<2x13x25x7xf32>) -> tensor<2x4x7x7xf32>
   return %0 : tensor<2x4x7x7xf32>
 }
+
+// CHECK-LABEL: xla_sharding
+func @xla_sharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> {
+  // CHECK-NEXT: "xla_hlo.custom_call"(%arg0) {backend_config = "", call_target_name = "Sharding", has_side_effect = false, xla_hlo.sharding = ""}
+  %0 = "tf.XlaSharding"(%arg0) {_XlaSharding = ""} : (tensor<4x16xf32>) -> tensor<4x16xf32>
+  return %0 : tensor<4x16xf32>
+}
@@ -837,3 +837,11 @@ func @sort_wrong_block_arg_type(%input0: tensor<16x16xf32>, %input1: tensor<16x1
   }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
   return
 }
+
+// -----
+
+func @custom_call_invalid_sharding(%input0: tensor<16x16xf32>) -> tensor<16x16xf32> {
+  // expected-error @+1 {{Invalid sharding: some-invalid-sharding}}
+  %0 = "xla_hlo.custom_call"(%input0) {backend_config = "", call_target_name = "Sharding", xla_hlo.sharding = "some-invalid-sharding"} : (tensor<16x16xf32>) -> tensor<16x16xf32>
+  return %0 : tensor<16x16xf32>
+}
@@ -910,6 +910,18 @@ func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
 // CHECK: ROOT %{{.*}} = (f32[16,16], s32[16,16]) sort(f32[16,16] %[[MAIN_ARG0]], s32[16,16] %[[MAIN_ARG1]]), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]]
 
+
+// -----
+
+// CHECK: HloModule
+func @main(%arg0: tensor<16x16xf32>) -> tensor<16x16xf32> {
+  %0 = "xla_hlo.custom_call"(%arg0) {backend_config = "", call_target_name = "Sharding", xla_hlo.sharding = "type: OTHER\ntile_assignment_dimensions: 1\ntile_assignment_dimensions: 2\ntile_assignment_devices: 0\ntile_assignment_devices: 1"} : (tensor<16x16xf32>) -> tensor<16x16xf32>
+  return %0 : tensor<16x16xf32>
+}
+
+// CHECK: ENTRY
+// CHECK: %[[ARG0:.*]] = f32[16,16] parameter(0)
+// CHECK: ROOT %[[RESULT:.*]] = f32[16,16] custom-call(f32[16,16] %[[ARG0]]), custom_call_target="Sharding", sharding={devices=[1,2]0,1}
 
 // -----
 
 // Tests that the exported HLO module keeps parameter replication annotation.
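Note: the ROOT CHECK above shows how the text-format proto surfaces in HLO syntax: a [1,2] tile assignment over devices 0 and 1 prints as sharding={devices=[1,2]0,1}. A hedged sketch reproducing that string through XLA's HloSharding helper, assuming the FromProto/ToString API in service/hlo_sharding.h:

#include <iostream>
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

int main() {
  ::xla::OpSharding proto;
  proto.set_type(::xla::OpSharding::OTHER);
  proto.add_tile_assignment_dimensions(1);
  proto.add_tile_assignment_dimensions(2);
  proto.add_tile_assignment_devices(0);
  proto.add_tile_assignment_devices(1);
  auto sharding = ::xla::HloSharding::FromProto(proto);
  if (sharding.ok()) {
    // Expected to print: {devices=[1,2]0,1}
    std::cout << sharding.ValueOrDie().ToString() << "\n";
  }
  return 0;
}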
@@ -45,6 +45,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h"
 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
 #include "tensorflow/compiler/xla/client/padding.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/kernels/conv_grad_shape_utils.h"
 #include "tensorflow/core/util/padding.h"
@@ -3403,6 +3404,45 @@ class ConvertVariableShapeOp : public OpRewritePattern<TF::VariableShapeOp> {
   }
 };
 
+// Converts an XlaSharding op to an XLA HLO shard op with sharding attributes.
+class ConvertXlaShardingOp : public OpRewritePattern<TF::XlaShardingOp> {
+ public:
+  using OpRewritePattern::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(TF::XlaShardingOp op,
+                                     PatternRewriter &rewriter) const override {
+    // TODO(b/148313088): define sharding attribute struct in MLIR instead of
+    // using a string.
+    auto sharding = op.getAttrOfType<StringAttr>("_XlaSharding");
+    if (!sharding) {
+      return matchFailure();
+    }
+
+    // The _XlaSharding attribute in TF is a serialized string of the OpSharding
+    // proto, so convert to a text form here.
+    ::xla::OpSharding sharding_proto;
+    std::string sharding_str;
+    if (!sharding_proto.ParseFromString(sharding.getValue().str())) {
+      return matchFailure();
+    }
+    if (!::tensorflow::protobuf::TextFormat::PrintToString(sharding_proto,
+                                                           &sharding_str)) {
+      return matchFailure();
+    }
+
+    auto custom_call = rewriter.create<xla_hlo::CustomCallOp>(
+        op.getLoc(), op.getType(), op.input(),
+        /*call_target_name=*/rewriter.getStringAttr("Sharding"),
+        /*has_side_effect=*/rewriter.getBoolAttr(false),
+        /*backend_config=*/rewriter.getStringAttr(""));
+    custom_call.setAttr("xla_hlo.sharding",
+                        rewriter.getStringAttr(sharding_str));
+    rewriter.replaceOp(op, custom_call.getResult());
+
+    return matchSuccess();
+  }
+};
+
 #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
 
 LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
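Note: ConvertXlaShardingOp above bridges the two encodings: TF carries OpSharding as serialized bytes in _XlaSharding, while the HLO custom call stores text format in xla_hlo.sharding. The conversion it performs, isolated as a hedged standalone helper:

#include <string>
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/protobuf.h"

// Decode binary OpSharding bytes and re-encode them as text format.
// Returns false if either step fails (the pattern then bails out with
// matchFailure and leaves the op untouched).
bool BinaryShardingToText(const std::string& bytes, std::string* text) {
  ::xla::OpSharding proto;
  if (!proto.ParseFromString(bytes)) return false;
  return ::tensorflow::protobuf::TextFormat::PrintToString(proto, text);
}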
@@ -3432,7 +3472,8 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
       ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
       ConvertUnpackOp, ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
       ConvertUnsortedSegmentProdOp, ConvertUnsortedSegmentSumOp,
-      ConvertRandomShuffleOp, ConvertVariableShapeOp>(op->getContext());
+      ConvertRandomShuffleOp, ConvertVariableShapeOp, ConvertXlaShardingOp>(
+      op->getContext());
 
   ConversionTarget target(*context);
   target.addLegalDialect<XlaHloDialect>();