Add xla_lhlo.dynamic_broadcast_in_dim operation.
Also change the type of the dynamic dimensions operand to vector of Integer, as index type is not supported in vectors. PiperOrigin-RevId: 295141631 Change-Id: Ie8b6d5adec65d70243a3b132ffc807cafd212b42
This commit is contained in:
parent
33c5c0b880
commit
77deb9292c
tensorflow/compiler/mlir/xla
@ -60,10 +60,6 @@ def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>;
|
||||
|
||||
def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>;
|
||||
|
||||
def HLO_DimensionTensor : ShapedContainerType<
|
||||
[Index], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
|
||||
"a 1D tensor of dimensions">;
|
||||
|
||||
// In general, static shaped tensor constraints should be avoided unless
|
||||
// it is for a legacy op which is only correct with static shapes.
|
||||
def HLO_StaticShapeTensor : StaticShapeTensorOf<[
|
||||
@ -778,7 +774,7 @@ def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim",
|
||||
[NoSideEffect]>, BASE_HLO_DynamicBroadcastInDimOp {
|
||||
let arguments = (ins
|
||||
HLO_Tensor:$operand,
|
||||
HLO_DimensionTensor:$output_dimensions,
|
||||
HLO_BASE_DimensionTensor:$output_dimensions,
|
||||
BroadcastDimAttr:$broadcast_dimensions
|
||||
);
|
||||
|
||||
|
@ -21,6 +21,19 @@ include "mlir/IR/OpBase.td"
|
||||
def HLO_Int : IntOfWidths<[8, 16, 32, 64]>;
|
||||
def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;
|
||||
|
||||
// The broadcasting dimensions correspond to a tuple that describes how a
|
||||
// smaller rank shape is broadcast into a larger rank shape. For example,
|
||||
// given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means
|
||||
// matching the matrix to dimensions 1 and 2 of the cuboid.
|
||||
def BroadcastDimAttr : OptionalAttr<I64ElementsAttr>;
|
||||
|
||||
// Dynamic representation of a shape vector as a tensor. Ideally this would be
|
||||
// an index type (as it stores indices) but that is currently disallowed in
|
||||
// MLIR.
|
||||
def HLO_BASE_DimensionTensor : ShapedContainerType<
|
||||
[AnyInteger], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
|
||||
"a 1D tensor of dimensions">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// XLA nullary op definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -318,12 +331,6 @@ class BASE_HLO_TanhOp {
|
||||
// XLA binary elementwise op definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// The broadcasting dimensions correspond to a tuple that describes how a
|
||||
// smaller rank shape is broadcast into a larger rank shape. For example,
|
||||
// given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means
|
||||
// matching the matrix to dimensions 1 and 2 of the cuboid.
|
||||
def BroadcastDimAttr : OptionalAttr<I64ElementsAttr>;
|
||||
|
||||
class BASE_HLO_AddOp {
|
||||
string summary = "Addition operator";
|
||||
|
||||
|
@ -242,6 +242,16 @@ def LHLO_BroadcastInDimOp : LHLO_Op<"broadcast_in_dim",
|
||||
);
|
||||
}
|
||||
|
||||
def HLO_DynamicBroadcastInDimOp : LHLO_Op<"dynamic_broadcast_in_dim",
|
||||
[NoSideEffect]>, BASE_HLO_DynamicBroadcastInDimOp {
|
||||
let arguments = (ins
|
||||
LHLO_Buffer:$operand,
|
||||
HLO_BASE_DimensionTensor:$output_dimensions,
|
||||
LHLO_Buffer:$output,
|
||||
BroadcastDimAttr:$broadcast_dimensions
|
||||
);
|
||||
}
|
||||
|
||||
def LHLO_ClampOp : LHLO_Op<"clamp", []>, BASE_HLO_ClampOp {
|
||||
let arguments = (ins
|
||||
LHLO_Buffer:$min,
|
||||
|
@ -136,6 +136,30 @@ func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @broadcast_in_dim_memref
|
||||
func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) -> () {
|
||||
"xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @broadcast_in_dim_zero_rank_memref
|
||||
func @broadcast_in_dim_zero_rank_memref(%arg0: memref<i32>, %out: memref<1x2x3xi32>) -> () {
|
||||
"xla_lhlo.broadcast_in_dim"(%arg0, %out) : (memref<i32>, memref<1x2x3xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @dynamic_broadcast_in_dim_memref
|
||||
func @dynamic_broadcast_in_dim_memref(%arg0: memref<?x?xi32>, %out: memref<?x?x?xi32>, %shape: tensor<3xi64>) -> () {
|
||||
"xla_lhlo.dynamic_broadcast_in_dim"(%arg0, %shape, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<?x?xi32>, tensor<3xi64>, memref<?x?x?xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @reduce_memref
|
||||
func @reduce_memref(%input: memref<10xf32>, %init: memref<f32>, %out: memref<1xf32>) -> () {
|
||||
"xla_lhlo.reduce"(%input, %init, %out) ( {
|
||||
|
@ -108,6 +108,14 @@ func @broadcast_in_dim_zero_rank(%arg0: tensor<i32>) -> tensor<1x2x3xi32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @dynamic_broadcast_in_dim
|
||||
func @dynamic_broadcast_in_dim(%arg0: tensor<?x?xi32>, %shape: tensor<3xi64>) -> tensor<?x?x?xi32> {
|
||||
%0 = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x?xi32>, tensor<3xi64>) -> tensor<?x?x?xi32>
|
||||
return %0 : tensor<?x?x?xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @broadcast_in_dim_bad_dimension_rank(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> {
|
||||
// expected-error@+1 {{broadcast_dimensions has rank 2 instead of rank 1}}
|
||||
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[[1,1],[1,1]]> : tensor<2x2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32>
|
||||
|
Loading…
Reference in New Issue
Block a user