Add xla_lhlo.dynamic_broadcast_in_dim operation.

Also change the type of the dynamic dimensions operand to vector of Integer, as index type is not supported in vectors.

PiperOrigin-RevId: 295141631
Change-Id: Ie8b6d5adec65d70243a3b132ffc807cafd212b42
This commit is contained in:
Stephan Herhut 2020-02-14 07:24:54 -08:00 committed by TensorFlower Gardener
parent 33c5c0b880
commit 77deb9292c
5 changed files with 56 additions and 11 deletions

View File

@ -60,10 +60,6 @@ def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>;
def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>;
def HLO_DimensionTensor : ShapedContainerType<
[Index], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
"a 1D tensor of dimensions">;
// In general, static shaped tensor constraints should be avoided unless
// it is for a legacy op which is only correct with static shapes.
def HLO_StaticShapeTensor : StaticShapeTensorOf<[
@ -778,7 +774,7 @@ def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim",
[NoSideEffect]>, BASE_HLO_DynamicBroadcastInDimOp {
let arguments = (ins
HLO_Tensor:$operand,
HLO_DimensionTensor:$output_dimensions,
HLO_BASE_DimensionTensor:$output_dimensions,
BroadcastDimAttr:$broadcast_dimensions
);

View File

@ -21,6 +21,19 @@ include "mlir/IR/OpBase.td"
def HLO_Int : IntOfWidths<[8, 16, 32, 64]>;
def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;
// The broadcasting dimensions correspond to a tuple that describes how a
// smaller rank shape is broadcast into a larger rank shape. For example,
// given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means
// matching the matrix to dimensions 1 and 2 of the cuboid.
def BroadcastDimAttr : OptionalAttr<I64ElementsAttr>;
// Dynamic representation of a shape vector as a tensor. Ideally this would be
// an index type (as it stores indices) but that is currently disallowed in
// MLIR.
def HLO_BASE_DimensionTensor : ShapedContainerType<
[AnyInteger], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
"a 1D tensor of dimensions">;
//===----------------------------------------------------------------------===//
// XLA nullary op definitions.
//===----------------------------------------------------------------------===//
@ -318,12 +331,6 @@ class BASE_HLO_TanhOp {
// XLA binary elementwise op definitions.
//===----------------------------------------------------------------------===//
// The broadcasting dimensions correspond to a tuple that describes how a
// smaller rank shape is broadcast into a larger rank shape. For example,
// given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means
// matching the matrix to dimensions 1 and 2 of the cuboid.
def BroadcastDimAttr : OptionalAttr<I64ElementsAttr>;
class BASE_HLO_AddOp {
string summary = "Addition operator";

View File

@ -242,6 +242,16 @@ def LHLO_BroadcastInDimOp : LHLO_Op<"broadcast_in_dim",
);
}
def HLO_DynamicBroadcastInDimOp : LHLO_Op<"dynamic_broadcast_in_dim",
[NoSideEffect]>, BASE_HLO_DynamicBroadcastInDimOp {
let arguments = (ins
LHLO_Buffer:$operand,
HLO_BASE_DimensionTensor:$output_dimensions,
LHLO_Buffer:$output,
BroadcastDimAttr:$broadcast_dimensions
);
}
def LHLO_ClampOp : LHLO_Op<"clamp", []>, BASE_HLO_ClampOp {
let arguments = (ins
LHLO_Buffer:$min,

View File

@ -136,6 +136,30 @@ func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// -----
// CHECK-LABEL: func @broadcast_in_dim_memref
func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) -> () {
"xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @broadcast_in_dim_zero_rank_memref
func @broadcast_in_dim_zero_rank_memref(%arg0: memref<i32>, %out: memref<1x2x3xi32>) -> () {
"xla_lhlo.broadcast_in_dim"(%arg0, %out) : (memref<i32>, memref<1x2x3xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @dynamic_broadcast_in_dim_memref
func @dynamic_broadcast_in_dim_memref(%arg0: memref<?x?xi32>, %out: memref<?x?x?xi32>, %shape: tensor<3xi64>) -> () {
"xla_lhlo.dynamic_broadcast_in_dim"(%arg0, %shape, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<?x?xi32>, tensor<3xi64>, memref<?x?x?xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @reduce_memref
func @reduce_memref(%input: memref<10xf32>, %init: memref<f32>, %out: memref<1xf32>) -> () {
"xla_lhlo.reduce"(%input, %init, %out) ( {

View File

@ -108,6 +108,14 @@ func @broadcast_in_dim_zero_rank(%arg0: tensor<i32>) -> tensor<1x2x3xi32> {
// -----
// CHECK-LABEL: func @dynamic_broadcast_in_dim
func @dynamic_broadcast_in_dim(%arg0: tensor<?x?xi32>, %shape: tensor<3xi64>) -> tensor<?x?x?xi32> {
%0 = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x?xi32>, tensor<3xi64>) -> tensor<?x?x?xi32>
return %0 : tensor<?x?x?xi32>
}
// -----
func @broadcast_in_dim_bad_dimension_rank(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> {
// expected-error@+1 {{broadcast_dimensions has rank 2 instead of rank 1}}
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[[1,1],[1,1]]> : tensor<2x2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32>