diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index 44f95a4c8d4..3e6bf11e7ff 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -127,6 +127,18 @@ def HLO_IotaOp : HLO_Op<"iota", [NoSideEffect]>, BASE_HLO_IotaOp {
   let hasCustomHLOConverter = 1;
 }
 
+def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> {
+  string summary = "Create Token operator";
+
+  string description = [{
+    Produces a HLO token. Tokens are used for ordering side-effecting
+    operations. This is exported to HLO as an AfterAll operation with no
+    operands to generate a token.
+  }];
+
+  let results = (outs HLO_Token:$output);
+}
+
 //===----------------------------------------------------------------------===//
 // XLA unary elementwise op definitions.
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 11287ad3f2d..c2f0e65f52f 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -1133,8 +1133,8 @@ func @preventgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> {
 
 // CHECK-LABEL: func @infeed_dequeue_tuple
 func @infeed_dequeue_tuple() -> (tensor<3xi32>, tensor<4xf32>) {
-// CHECK: [[AFTER_ALL:%.*]] = "xla_hlo.after_all"() : () -> !xla_hlo.token
-// CHECK: [[INFEED:%.*]] = "xla_hlo.infeed"([[AFTER_ALL]]) {infeed_config = ""} : (!xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<4xf32>>, !xla_hlo.token>
+// CHECK: [[TOKEN:%.*]] = "xla_hlo.create_token"() : () -> !xla_hlo.token
+// CHECK: [[INFEED:%.*]] = "xla_hlo.infeed"([[TOKEN]]) {infeed_config = ""} : (!xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<4xf32>>, !xla_hlo.token>
 // CHECK: [[INFEED_VAL:%.*]] = "xla_hlo.get_tuple_element"([[INFEED]]) {index = 0 : i32} : (tuple<tuple<tensor<3xi32>, tensor<4xf32>>, !xla_hlo.token>) -> tuple<tensor<3xi32>, tensor<4xf32>>
 // CHECK: [[RES_1:%.*]] = "xla_hlo.get_tuple_element"([[INFEED_VAL]]) {index = 0 : i32} : (tuple<tensor<3xi32>, tensor<4xf32>>) -> tensor<3xi32>
 // CHECK: [[RES_2:%.*]] = "xla_hlo.get_tuple_element"([[INFEED_VAL]]) {index = 1 : i32} : (tuple<tensor<3xi32>, tensor<4xf32>>) -> tensor<4xf32>
@@ -1393,8 +1393,8 @@ func @one_hot(%indices: tensor<3xi32>, %on_value: tensor<f32>, %off_value: tenso
 // CHECK-SAME: [[VAL_0:%.*]]: tensor<3xi32>, [[VAL_1:%.*]]: tensor<4xf32>)
 func @outfeed_enqueue_tuple(%data_1: tensor<3xi32>, %data_2: tensor<4xf32>) -> () {
 // CHECK: [[TUPLE:%.*]] = "xla_hlo.tuple"([[VAL_0]], [[VAL_1]]) : (tensor<3xi32>, tensor<4xf32>) -> tuple<tensor<3xi32>, tensor<4xf32>>
-// CHECK: [[AFTER_ALL:%.*]] = "xla_hlo.after_all"() : () -> !xla_hlo.token
-// CHECK: "xla_hlo.outfeed"([[TUPLE]], [[AFTER_ALL]]) {outfeed_config = ""} : (tuple<tensor<3xi32>, tensor<4xf32>>, !xla_hlo.token) -> !xla_hlo.token
+// CHECK: [[TOKEN:%.*]] = "xla_hlo.create_token"() : () -> !xla_hlo.token
+// CHECK: "xla_hlo.outfeed"([[TUPLE]], [[TOKEN]]) {outfeed_config = ""} : (tuple<tensor<3xi32>, tensor<4xf32>>, !xla_hlo.token) -> !xla_hlo.token
   "tf.OutfeedEnqueueTuple"(%data_1, %data_2) : (tensor<3xi32>, tensor<4xf32>) -> ()
   return
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
index aac4e613358..8af27bb586a 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
@@ -160,6 +160,16 @@ func @main(%arg0: tensor<1xf32>) -> tensor<1x10xf32> {
 
 // -----
 
+// CHECK: HloModule
+func @main() -> !xla_hlo.token {
+  %0 = "xla_hlo.create_token"() : () -> !xla_hlo.token
+  return %0 : !xla_hlo.token
+}
+
+// CHECK: ROOT [[TOKEN:%.*]] = token[] after-all()
+
+// -----
+
 // CHECK: HloModule
 func @main(%arg0: tensor<4xi32>) -> tensor<4xi32> {
   %0 = call @callee(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 486cac5a408..95b29366917 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -2928,22 +2928,22 @@ class ConvertOneHotOp : public OpRewritePattern<TF::OneHotOp> {
   }
 };
 
-// Converts InfeedEnqueueTuple to XLA HLO after_all, infeed and
+// Converts InfeedDequeueTuple to XLA HLO create_token, infeed and
 // get_tuple_element ops.
 //
 // All HLO infeed ops expect a HLO token type operand and produce a tuple
 // containing a token. This HLO token type is used to order multiple infeed
 // operations within a computation. The token type can come from other
-// infeed/outfeed/send/recv ops or can be generated using an after_all op with
-// no operands. Here we emit an after_all op to generate the token type operand
-// of infeed.
+// infeed/outfeed/send/recv ops or can be generated using a create_token op
+// with no operands. Here we emit a create_token op to generate the token type
+// operand of infeed.
 //
 // For example the following IR:
 // %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3xi32>, tensor<4xf32>)
 //
 // would be lowered to
 //
-// %token = "xla_hlo.after_all"() : () -> !xla_hlo.token
+// %token = "xla_hlo.create_token"() : () -> !xla_hlo.token
 // %data_and_token = "xla_hlo.infeed"(%token) {infeed_config = ""} :
 //      (!xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<4xf32>>,
 //      !xla_hlo.token>
@@ -2962,21 +2962,20 @@ class ConvertInfeedDequeueTupleOp
     for (auto idx_and_output : llvm::enumerate(op.outputs())) {
       result_types[idx_and_output.index()] = (idx_and_output.value().getType());
     }
-    // Infeed takes a single token operand. Generate the token using after_all
-    // op to pass to the infeed op.
-    auto afterall = rewriter.create<AfterAllOp>(
-        op.getLoc(), xla_hlo::TokenType::get(rewriter.getContext()),
-        ValueRange());
+    // Infeed takes a single token operand. Generate the token using
+    // create_token op to pass to the infeed op.
+    auto token = rewriter.create<CreateTokenOp>(
+        op.getLoc(), xla_hlo::TokenType::get(rewriter.getContext()));
 
     // Emit infeed op.
     // The result type of infeed is a tuple(tuple(result types), token type).
     auto data_tuple_type =
         mlir::TupleType::get(result_types, rewriter.getContext());
     auto data_and_token_type = mlir::TupleType::get(
-        {data_tuple_type, afterall.getType()}, rewriter.getContext());
+        {data_tuple_type, token.getType()}, rewriter.getContext());
 
     auto data_and_token =
-        rewriter.create<InfeedOp>(op.getLoc(), data_and_token_type, afterall,
+        rewriter.create<InfeedOp>(op.getLoc(), data_and_token_type, token,
                                   /*infeed_config=*/rewriter.getStringAttr(""));
 
     // The infeed instruction produces a tuple of the infeed data and a token
@@ -2998,10 +2997,11 @@ class ConvertInfeedDequeueTupleOp
   }
 };
 
-// Converts tf.OutfeedEnqueueTuple to XLA HLO tuple, after_all and outfeed ops.
+// Converts tf.OutfeedEnqueueTuple to XLA HLO tuple, create_token and outfeed
+// ops.
 //
-// XLA HLO outfeed op expects a token, which we generate by emitting an
-// after_all op.
+// XLA HLO outfeed op expects a token, which we generate by emitting a
+// create_token op.
 //
 // For example the following IR:
 // "tf.OutfeedEnqueueTuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) ->
@@ -3011,7 +3011,7 @@ class ConvertInfeedDequeueTupleOp
 //
 // %tuple = "xla_hlo.tuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) ->
 //      tuple<tensor<3xi32>, tensor<4xf32>>
-// %token = "xla_hlo.after_all"() : () -> !xla_hlo.token
+// %token = "xla_hlo.create_token"() : () -> !xla_hlo.token
 // %outfeed_token = "xla_hlo.outfeed"(%tuple, %token) {outfeed_config = ""} :
 //      (tuple<tensor<3xi32>, tensor<4xf32>>, !xla_hlo.token) -> !xla_hlo.token
 //
@@ -3024,9 +3024,8 @@ class ConvertOutfeedEnqueueTupleOp
                                      PatternRewriter &rewriter) const override {
     auto token_type = xla_hlo::TokenType::get(rewriter.getContext());
     auto tuple = rewriter.create<TupleOp>(op.getLoc(), op.inputs());
-    auto afterall =
-        rewriter.create<AfterAllOp>(op.getLoc(), token_type, ValueRange());
-    rewriter.create<OutfeedOp>(op.getLoc(), token_type, tuple, afterall,
+    auto token = rewriter.create<CreateTokenOp>(op.getLoc(), token_type);
+    rewriter.create<OutfeedOp>(op.getLoc(), token_type, tuple, token,
                                /*outfeed_config=*/rewriter.getStringAttr(""));
     rewriter.eraseOp(op);
     return matchSuccess();
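
Usage sketch (illustrative, not part of the patch): the IR below is assembled from the CHECK lines in legalize-tf.mlir above and shows the new op ordering an infeed; the function and value names are made up. On export, per the export.mlir test, xla_hlo.create_token lowers to a token[] after-all() HLO instruction with no operands.

func @token_ordered_infeed() -> (tensor<3xi32>, tensor<4xf32>) {
  // Generate the token that orders this infeed against other side-effecting ops.
  %token = "xla_hlo.create_token"() : () -> !xla_hlo.token
  // Infeed consumes the token and returns a tuple of (data tuple, token).
  %data_and_token = "xla_hlo.infeed"(%token) {infeed_config = ""} : (!xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<4xf32>>, !xla_hlo.token>
  %data = "xla_hlo.get_tuple_element"(%data_and_token) {index = 0 : i32} : (tuple<tuple<tensor<3xi32>, tensor<4xf32>>, !xla_hlo.token>) -> tuple<tensor<3xi32>, tensor<4xf32>>
  %res_1 = "xla_hlo.get_tuple_element"(%data) {index = 0 : i32} : (tuple<tensor<3xi32>, tensor<4xf32>>) -> tensor<3xi32>
  %res_2 = "xla_hlo.get_tuple_element"(%data) {index = 1 : i32} : (tuple<tensor<3xi32>, tensor<4xf32>>) -> tensor<4xf32>
  return %res_1, %res_2 : tensor<3xi32>, tensor<4xf32>
}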