[MLIR][XLA] Remove output_dimensions
arg from LHLO DynamicBroadcastInDimOp.
It is not needed since we have access to the output buffer. PiperOrigin-RevId: 295802211 Change-Id: I078c7b91f837e80131a8dde5bb735a8ca72ee876
This commit is contained in:
parent
60fb12820e
commit
49a83c96b0
@ -60,6 +60,13 @@ def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>;
|
||||
|
||||
def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>;
|
||||
|
||||
// Dynamic representation of a shape vector as a tensor. Ideally this would be
|
||||
// an index type (as it stores indices) but that is currently disallowed in
|
||||
// MLIR.
|
||||
def HLO_DimensionTensor : ShapedContainerType<
|
||||
[AnyInteger], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
|
||||
"a 1D tensor of dimensions">;
|
||||
|
||||
// In general, static shaped tensor constraints should be avoided unless
|
||||
// it is for a legacy op which is only correct with static shapes.
|
||||
def HLO_StaticShapeTensor : StaticShapeTensorOf<[
|
||||
@ -771,10 +778,22 @@ def HLO_BroadcastInDimOp : HLO_Op<"broadcast_in_dim",
|
||||
}
|
||||
|
||||
def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim",
|
||||
[NoSideEffect]>, BASE_HLO_DynamicBroadcastInDimOp {
|
||||
[NoSideEffect]> {
|
||||
string summary = "Broadcast a tensor into the given dynamic shape by adding dimensions.";
|
||||
string description = [{
|
||||
This is a generalization of the BroadcastInDimOp which accepts its output
|
||||
dimensions as an argument. It should eventually supercede the statically
|
||||
shaped original, but is being phased as a separate op in order to support
|
||||
compatibility with lowerings and translations that precede dynamic
|
||||
shapes.
|
||||
|
||||
Note that the `broadcast_dimensions` attribute is optional and if omitted,
|
||||
it is assumed to be an ordered, right-aligned mapping from input to
|
||||
output dimensions.
|
||||
}];
|
||||
let arguments = (ins
|
||||
HLO_Tensor:$operand,
|
||||
HLO_BASE_DimensionTensor:$output_dimensions,
|
||||
HLO_DimensionTensor:$output_dimensions,
|
||||
BroadcastDimAttr:$broadcast_dimensions
|
||||
);
|
||||
|
||||
|
@ -27,13 +27,6 @@ def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;
|
||||
// matching the matrix to dimensions 1 and 2 of the cuboid.
|
||||
def BroadcastDimAttr : OptionalAttr<I64ElementsAttr>;
|
||||
|
||||
// Dynamic representation of a shape vector as a tensor. Ideally this would be
|
||||
// an index type (as it stores indices) but that is currently disallowed in
|
||||
// MLIR.
|
||||
def HLO_BASE_DimensionTensor : ShapedContainerType<
|
||||
[AnyInteger], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
|
||||
"a 1D tensor of dimensions">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// XLA nullary op definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -817,22 +810,6 @@ class BASE_HLO_BroadcastInDimOp {
|
||||
}];
|
||||
}
|
||||
|
||||
class BASE_HLO_DynamicBroadcastInDimOp {
|
||||
string summary = "Broadcast a tensor into the given dynamic shape by adding dimensions.";
|
||||
|
||||
string description = [{
|
||||
This is a generalization of the BroadcastInDimOp which accepts its output
|
||||
dimensions as an argument. It should eventually supercede the statically
|
||||
shaped original, but is being phased as a separate op in order to support
|
||||
compatibility with lowerings and translations that precede dynamic
|
||||
shapes.
|
||||
|
||||
Note that the `broadcast_dimensions` attribute is optional and if omitted,
|
||||
it is assumed to be an ordered, right-aligned mapping from input to
|
||||
output dimensions.
|
||||
}];
|
||||
}
|
||||
|
||||
class BASE_HLO_CholeskyOp {
|
||||
string summary = "Cholesky operator";
|
||||
|
||||
|
@ -242,16 +242,6 @@ def LHLO_BroadcastInDimOp : LHLO_Op<"broadcast_in_dim",
|
||||
);
|
||||
}
|
||||
|
||||
def HLO_DynamicBroadcastInDimOp : LHLO_Op<"dynamic_broadcast_in_dim",
|
||||
[NoSideEffect]>, BASE_HLO_DynamicBroadcastInDimOp {
|
||||
let arguments = (ins
|
||||
LHLO_Buffer:$operand,
|
||||
HLO_BASE_DimensionTensor:$output_dimensions,
|
||||
LHLO_Buffer:$output,
|
||||
BroadcastDimAttr:$broadcast_dimensions
|
||||
);
|
||||
}
|
||||
|
||||
def LHLO_ClampOp : LHLO_Op<"clamp", []>, BASE_HLO_ClampOp {
|
||||
let arguments = (ins
|
||||
LHLO_Buffer:$min,
|
||||
|
@ -152,13 +152,6 @@ func @broadcast_in_dim_zero_rank_memref(%arg0: memref<i32>, %out: memref<1x2x3xi
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @dynamic_broadcast_in_dim_memref
|
||||
func @dynamic_broadcast_in_dim_memref(%arg0: memref<?x?xi32>, %out: memref<?x?x?xi32>, %shape: tensor<3xi64>) -> () {
|
||||
"xla_lhlo.dynamic_broadcast_in_dim"(%arg0, %shape, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<?x?xi32>, tensor<3xi64>, memref<?x?x?xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @reduce_memref
|
||||
func @reduce_memref(%input: memref<10xf32>, %init: memref<f32>, %out: memref<1xf32>) -> () {
|
||||
|
Loading…
x
Reference in New Issue
Block a user