Add CreateTokenOp to HLO dialect.
This is to separate out the use of AfterAll op to create a HLO token out of thin air and to wait on some existing tokens. The idea is to at some point only allow the main function to create tokens which would be passed to other computations. This would force the creators of side effecting ops to consider tokens and ordering while creating the op. This op is exported as an AfterAll HLO with no operands. Replace the generation of AfterAll op in tf.InfeedDequeueTuple and tf.OutfeedEnqueueTuple legalization with this CreateTokenOp. PiperOrigin-RevId: 299857684 Change-Id: I9a358f5537115ea2afac0274a38ed75f61d66b9c
This commit is contained in:
parent
ac26a80b0f
commit
8ba7e2f386
@ -127,6 +127,18 @@ def HLO_IotaOp : HLO_Op<"iota", [NoSideEffect]>, BASE_HLO_IotaOp {
|
||||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
||||
def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> {
|
||||
string summary = "Create Token operator";
|
||||
|
||||
string description = [{
|
||||
Produces a HLO token. Tokens are used for ordering side-effecting perations.
|
||||
This is exported to HLO as an AfterAll operation with no operands to
|
||||
generate a token.
|
||||
}];
|
||||
|
||||
let results = (outs HLO_Token:$output);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// XLA unary elementwise op definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1133,8 +1133,8 @@ func @preventgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> {
|
||||
|
||||
// CHECK-LABEL: func @infeed_dequeue_tuple
|
||||
func @infeed_dequeue_tuple() -> (tensor<3xi32>, tensor<4xf32>) {
|
||||
// CHECK: [[AFTER_ALL:%.*]] = "xla_hlo.after_all"() : () -> !xla_hlo.token
|
||||
// CHECK: [[INFEED:%.*]] = "xla_hlo.infeed"([[AFTER_ALL]]) {infeed_config = ""} : (!xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<4xf32>>, !xla_hlo.token>
|
||||
// CHECK: [[TOKEN:%.*]] = "xla_hlo.create_token"() : () -> !xla_hlo.token
|
||||
// CHECK: [[INFEED:%.*]] = "xla_hlo.infeed"([[TOKEN]]) {infeed_config = ""} : (!xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<4xf32>>, !xla_hlo.token>
|
||||
// CHECK: [[INFEED_VAL:%.*]] = "xla_hlo.get_tuple_element"([[INFEED]]) {index = 0 : i32} : (tuple<tuple<tensor<3xi32>, tensor<4xf32>>, !xla_hlo.token>) -> tuple<tensor<3xi32>, tensor<4xf32>>
|
||||
// CHECK: [[RES_1:%.*]] = "xla_hlo.get_tuple_element"([[INFEED_VAL]]) {index = 0 : i32} : (tuple<tensor<3xi32>, tensor<4xf32>>) -> tensor<3xi32>
|
||||
// CHECK: [[RES_2:%.*]] = "xla_hlo.get_tuple_element"([[INFEED_VAL]]) {index = 1 : i32} : (tuple<tensor<3xi32>, tensor<4xf32>>) -> tensor<4xf32>
|
||||
@ -1393,8 +1393,8 @@ func @one_hot(%indices: tensor<3xi32>, %on_value: tensor<f32>, %off_value: tenso
|
||||
// CHECK-SAME: [[VAL_0:%.*]]: tensor<3xi32>, [[VAL_1:%.*]]: tensor<4xf32>)
|
||||
func @outfeed_enqueue_tuple(%data_1: tensor<3xi32>, %data_2: tensor<4xf32>) -> () {
|
||||
// CHECK: [[TUPLE:%.*]] = "xla_hlo.tuple"([[VAL_0]], [[VAL_1]]) : (tensor<3xi32>, tensor<4xf32>) -> tuple<tensor<3xi32>, tensor<4xf32>>
|
||||
// CHECK: [[AFTER_ALL:%.*]] = "xla_hlo.after_all"() : () -> !xla_hlo.token
|
||||
// CHECK: "xla_hlo.outfeed"([[TUPLE]], [[AFTER_ALL]]) {outfeed_config = ""} : (tuple<tensor<3xi32>, tensor<4xf32>>, !xla_hlo.token) -> !xla_hlo.token
|
||||
// CHECK: [[TOKEN:%.*]] = "xla_hlo.create_token"() : () -> !xla_hlo.token
|
||||
// CHECK: "xla_hlo.outfeed"([[TUPLE]], [[TOKEN]]) {outfeed_config = ""} : (tuple<tensor<3xi32>, tensor<4xf32>>, !xla_hlo.token) -> !xla_hlo.token
|
||||
"tf.OutfeedEnqueueTuple"(%data_1, %data_2) : (tensor<3xi32>, tensor<4xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -160,6 +160,16 @@ func @main(%arg0: tensor<1xf32>) -> tensor<1x10xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: HloModule
|
||||
func @main() -> !xla_hlo.token {
|
||||
%0 = "xla_hlo.create_token"() : () -> !xla_hlo.token
|
||||
return %0 : !xla_hlo.token
|
||||
}
|
||||
|
||||
// CHECK: ROOT [[TOKEN:%.*]] = token[] after-all()
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: HloModule
|
||||
func @main(%arg0: tensor<4xi32>) -> tensor<4xi32> {
|
||||
%0 = call @callee(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
|
@ -2928,22 +2928,22 @@ class ConvertOneHotOp : public OpRewritePattern<TF::OneHotOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// Converts InfeedEnqueueTuple to XLA HLO after_all, infeed and
|
||||
// Converts InfeedDequeueTuple to XLA HLO create_token, infeed and
|
||||
// get_tuple_element ops.
|
||||
//
|
||||
// All HLO infeed ops expect a HLO token type operand and produce a tuple
|
||||
// containing a token. This HLO token type is used to order multiple infeed
|
||||
// operations within a computation. The token type can come from other
|
||||
// infeed/outfeed/send/recv ops or can be generated using an after_all op with
|
||||
// no operands. Here we emit an after_all op to generate the token type operand
|
||||
// of infeed.
|
||||
// infeed/outfeed/send/recv ops or can be generated using create_token op with
|
||||
// no operands. Here we emit a create_token op to generate the token type
|
||||
// operand of infeed.
|
||||
//
|
||||
// For example the following IR:
|
||||
// %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3xi32>, tensor<4xf32>)
|
||||
//
|
||||
// would be lowered to
|
||||
//
|
||||
// %token = "xla_hlo.after_all"() : () -> !xla_hlo.token
|
||||
// %token = "xla_hlo.create_token"() : () -> !xla_hlo.token
|
||||
// %data_and_token = "xla_hlo.infeed"(%token) {infeed_config = ""} :
|
||||
// (!xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<4xf32>>,
|
||||
// !xla_hlo.token>
|
||||
@ -2962,21 +2962,20 @@ class ConvertInfeedDequeueTupleOp
|
||||
for (auto idx_and_output : llvm::enumerate(op.outputs())) {
|
||||
result_types[idx_and_output.index()] = (idx_and_output.value().getType());
|
||||
}
|
||||
// Infeed takes a single token operand. Generate the token using after_all
|
||||
// op to pass to the infeed op.
|
||||
auto afterall = rewriter.create<AfterAllOp>(
|
||||
op.getLoc(), xla_hlo::TokenType::get(rewriter.getContext()),
|
||||
ValueRange());
|
||||
// Infeed takes a single token operand. Generate the token using
|
||||
// create_token op to pass to the infeed op.
|
||||
auto token = rewriter.create<CreateTokenOp>(
|
||||
op.getLoc(), xla_hlo::TokenType::get(rewriter.getContext()));
|
||||
|
||||
// Emit infeed op.
|
||||
// The result type of infeed is a tuple(tuple(result types), token type).
|
||||
auto data_tuple_type =
|
||||
mlir::TupleType::get(result_types, rewriter.getContext());
|
||||
auto data_and_token_type = mlir::TupleType::get(
|
||||
{data_tuple_type, afterall.getType()}, rewriter.getContext());
|
||||
{data_tuple_type, token.getType()}, rewriter.getContext());
|
||||
|
||||
auto data_and_token =
|
||||
rewriter.create<InfeedOp>(op.getLoc(), data_and_token_type, afterall,
|
||||
rewriter.create<InfeedOp>(op.getLoc(), data_and_token_type, token,
|
||||
/*infeed_config=*/rewriter.getStringAttr(""));
|
||||
|
||||
// The infeed instruction produces a tuple of the infeed data and a token
|
||||
@ -2998,10 +2997,11 @@ class ConvertInfeedDequeueTupleOp
|
||||
}
|
||||
};
|
||||
|
||||
// Converts tf.OutfeedEnqueueTuple to XLA HLO tuple, after_all and outfeed ops.
|
||||
// Converts tf.OutfeedEnqueueTuple to XLA HLO tuple, create_token and outfeed
|
||||
// ops.
|
||||
//
|
||||
// XLA HLO outfeed op expects a token, which we generate by emitting an
|
||||
// after_all op.
|
||||
// create_token op.
|
||||
//
|
||||
// For example the following IR:
|
||||
// "tf.OutfeedEnqueueTuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) ->
|
||||
@ -3011,7 +3011,7 @@ class ConvertInfeedDequeueTupleOp
|
||||
//
|
||||
// %tuple = "xla_hlo.tuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) ->
|
||||
// tuple<tensor<3xi32>, tensor<4xf32>>
|
||||
// %token = "xla_hlo.after_all"() : () -> !xla_hlo.token
|
||||
// %token = "xla_hlo.create_token"() : () -> !xla_hlo.token
|
||||
// %outfeed_token = "xla_hlo.outfeed"(%tuple, %token) {outfeed_config = ""} :
|
||||
// (tuple<tensor<3xi32>, tensor<4xf32>>, !xla_hlo.token) -> !xla_hlo.token
|
||||
//
|
||||
@ -3024,9 +3024,8 @@ class ConvertOutfeedEnqueueTupleOp
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto token_type = xla_hlo::TokenType::get(rewriter.getContext());
|
||||
auto tuple = rewriter.create<TupleOp>(op.getLoc(), op.inputs());
|
||||
auto afterall =
|
||||
rewriter.create<AfterAllOp>(op.getLoc(), token_type, ValueRange());
|
||||
rewriter.create<OutfeedOp>(op.getLoc(), token_type, tuple, afterall,
|
||||
auto token = rewriter.create<CreateTokenOp>(op.getLoc(), token_type);
|
||||
rewriter.create<OutfeedOp>(op.getLoc(), token_type, tuple, token,
|
||||
/*outfeed_config=*/rewriter.getStringAttr(""));
|
||||
rewriter.eraseOp(op);
|
||||
return matchSuccess();
|
||||
|
Loading…
Reference in New Issue
Block a user