Legalize XlaReplicaId to HLO replica-id op

Also, define shape inference function for HLO replica-id op.

PiperOrigin-RevId: 345714342
Change-Id: Ida5019487412fbd31d1b852a47b67a6d912263b4
This commit is contained in:
Smit Hinsu 2020-12-04 11:04:02 -08:00 committed by TensorFlower Gardener
parent 68190eb0c3
commit 46cf2ef65b
4 changed files with 29 additions and 1 deletions

View File

@ -490,7 +490,8 @@ def HLO_RecvOp : HLO_Op<"recv", []> {
// MHLO parallelism related op definitions.
//===----------------------------------------------------------------------===//
def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>,
def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
BASE_HLO_ReplicaIdOp {
let results = (outs TensorOf<[UI32]>);
}

View File

@ -1945,6 +1945,18 @@ void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
context);
}
//===----------------------------------------------------------------------===//
// ReplicaId Op
//===----------------------------------------------------------------------===//
LogicalResult ReplicaIdOp::inferReturnTypes(
MLIRContext* context, Optional<Location>, ValueRange operands,
DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
inferredReturnTypes.push_back(RankedTensorType::get(
/*shape=*/{}, IntegerType::get(32, IntegerType::Unsigned, context)));
return success();
}
//===----------------------------------------------------------------------===//
// Case Op
//===----------------------------------------------------------------------===//

View File

@ -5120,3 +5120,11 @@ func @stridedslice_with_i32(%arg0: tensor<i32>) -> tensor<4xf32> attributes {tf.
%6 = "tf.StridedSlice"(%0, %5, %4, %2) {_xla_inferred_shapes = [#tf.shape<4>], begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2x4xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xf32>
return %6 : tensor<4xf32>
}
func @replica_id() -> tensor<i32> {
// CHECK: %[[ID:.*]] = "mhlo.replica_id"() : () -> tensor<ui32>
// CHECK: %[[RESULT:.*]] = "mhlo.convert"(%0) : (tensor<ui32>) -> tensor<i32>
%0 = "tf.XlaReplicaId"() : () -> tensor<i32>
return %0 : tensor<i32>
}

View File

@ -706,6 +706,13 @@ def : Pattern<(TF_SoftplusOp AnyTensor:$features),
(replaceWithValue $output)
]>;
//===----------------------------------------------------------------------===//
// XlaReplicaId op.
//===----------------------------------------------------------------------===//
def : Pat<(TF_XlaReplicaIdOp),
(TF_CastOp (HLO_ReplicaIdOp), /*truncate=*/ConstBoolAttrFalse)>;
//===----------------------------------------------------------------------===//
// XlaGather op.
//===----------------------------------------------------------------------===//