Legalize XlaReplicaId to HLO replica-id op
Also, define shape inference function for HLO replica-id op. PiperOrigin-RevId: 345714342 Change-Id: Ida5019487412fbd31d1b852a47b67a6d912263b4
This commit is contained in:
parent
68190eb0c3
commit
46cf2ef65b
@ -490,7 +490,8 @@ def HLO_RecvOp : HLO_Op<"recv", []> {
|
||||
// MHLO parallelism related op definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>,
|
||||
def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
|
||||
BASE_HLO_ReplicaIdOp {
|
||||
let results = (outs TensorOf<[UI32]>);
|
||||
}
|
||||
|
@ -1945,6 +1945,18 @@ void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
|
||||
context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReplicaId Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ReplicaIdOp::inferReturnTypes(
|
||||
MLIRContext* context, Optional<Location>, ValueRange operands,
|
||||
DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
|
||||
inferredReturnTypes.push_back(RankedTensorType::get(
|
||||
/*shape=*/{}, IntegerType::get(32, IntegerType::Unsigned, context)));
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Case Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -5120,3 +5120,11 @@ func @stridedslice_with_i32(%arg0: tensor<i32>) -> tensor<4xf32> attributes {tf.
|
||||
%6 = "tf.StridedSlice"(%0, %5, %4, %2) {_xla_inferred_shapes = [#tf.shape<4>], begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2x4xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xf32>
|
||||
return %6 : tensor<4xf32>
|
||||
}
|
||||
|
||||
func @replica_id() -> tensor<i32> {
|
||||
// CHECK: %[[ID:.*]] = "mhlo.replica_id"() : () -> tensor<ui32>
|
||||
// CHECK: %[[RESULT:.*]] = "mhlo.convert"(%0) : (tensor<ui32>) -> tensor<i32>
|
||||
%0 = "tf.XlaReplicaId"() : () -> tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
||||
|
@ -706,6 +706,13 @@ def : Pattern<(TF_SoftplusOp AnyTensor:$features),
|
||||
(replaceWithValue $output)
|
||||
]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// XlaReplicaId op.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def : Pat<(TF_XlaReplicaIdOp),
|
||||
(TF_CastOp (HLO_ReplicaIdOp), /*truncate=*/ConstBoolAttrFalse)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// XlaGather op.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
Loading…
x
Reference in New Issue
Block a user