Allow dynamic shape tensors in XLA HLO dialect
This keeps XLA's static-shape restriction limited to the MLIR translation to HLOModule, so other use cases of this dialect can work with dynamic shapes. Also update the TensorFlow Relu and Relu6 op lowering to XLA to require static shapes, since the result shape is used to construct the splat constant.

PiperOrigin-RevId: 272570959
parent 2a8e8706a4
commit 7206485040
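For context on the Relu/Relu6 restriction mentioned in the message: the lowering materializes its zero (and, for Relu6, six) bound as a splat constant of the result type, and a splat attribute needs a statically shaped tensor type. The sketch below is only an illustrative approximation of the lowered IR, not the exact pattern produced by the legalization pass; the function name @relu_static is made up, and the op forms mirror the xla_hlo tests further down.

// Hypothetical shape of the lowered IR for a statically shaped Relu.
func @relu_static(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  // The splat constant needs the full shape, so tensor<4xf32> works here,
  // while tensor<?xf32> or tensor<*xf32> would not.
  %zero = "xla_hlo.constant"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
  %0 = "xla_hlo.max"(%arg0, %zero) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  return %0 : tensor<4xf32>
}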
@ -49,17 +49,17 @@ class HLO_Op<string mnemonic, list<OpTrait> traits> :
 //===----------------------------------------------------------------------===//

 // Any integer tensor types
-def HLO_IntTensor : StaticShapeTensorOf<[HLO_Int]>;
+def HLO_IntTensor : TensorOf<[HLO_Int]>;

 // Any floating-point tensor types
-def HLO_FpTensor : StaticShapeTensorOf<[AnyFloat]>;
+def HLO_FpTensor : TensorOf<[AnyFloat]>;

-def HLO_PredTensor : StaticShapeTensorOf<[HLO_Pred]>;
+def HLO_PredTensor : TensorOf<[HLO_Pred]>;

 // Any integer or floating-point tensor types
-def HLO_IntOrFpTensor : StaticShapeTensorOf<[HLO_Int, AnyFloat]>;
+def HLO_IntOrFpTensor : TensorOf<[HLO_Int, AnyFloat]>;

-def HLO_Tensor : StaticShapeTensorOf<[AnyFloat, AnyInteger]>;
+def HLO_Tensor : TensorOf<[AnyFloat, AnyInteger]>;

 def HLO_Tuple : NestedTupleOf<[HLO_Tensor]>;

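Swapping StaticShapeTensorOf for TensorOf widens each constraint from statically shaped tensors to any tensor of the listed element types. For illustration, the relaxed constraints now also admit types like the following (the first kind was already allowed):

tensor<2x3xf32>   // static shape: accepted before and after this change
tensor<?x3xf32>   // ranked with a dynamic dimension: newly accepted
tensor<*xf32>     // unranked: newly accepted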
@ -534,21 +534,21 @@ func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {

 // CHECK-LABEL: func @abs_dynamic
 func @abs_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK: "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: "xla_hlo.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   %0 = "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }

 // CHECK-LABEL: func @abs_rankless
 func @abs_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK: "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: "xla_hlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   %0 = "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }

 // CHECK-LABEL: func @cast_dynamic_i2f
 func @cast_dynamic_i2f(%arg0: tensor<?xi32>) -> tensor<?xf32> {
-  // CHECK: "tf.Cast"(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
+  // CHECK: "xla_hlo.convert"(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
   %0 = "tf.Cast"(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
@ -576,14 +576,14 @@ func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> {

 // CHECK-LABEL: func @ceil_dynamic
 func @ceil_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK: "tf.Ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: "xla_hlo.ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   %0 = "tf.Ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }

 // CHECK-LABEL: func @ceil_rankless
 func @ceil_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK: "tf.Ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: "xla_hlo.ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   %0 = "tf.Ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
@ -597,14 +597,14 @@ func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> {

 // CHECK-LABEL: func @cos_dynamic
 func @cos_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK: "tf.Cos"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: "xla_hlo.cos"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   %0 = "tf.Cos"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }

 // CHECK-LABEL: func @cos_rankless
 func @cos_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK: "tf.Cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: "xla_hlo.cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   %0 = "tf.Cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
@ -618,14 +618,14 @@ func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> {

 // CHECK-LABEL: func @exp_dynamic
 func @exp_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK: "tf.Exp"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: "xla_hlo.exp"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   %0 = "tf.Exp"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }

 // CHECK-LABEL: func @exp_rankless
 func @exp_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK: "tf.Exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: "xla_hlo.exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   %0 = "tf.Exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
@ -639,14 +639,14 @@ func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> {

 // CHECK-LABEL: func @floor_dynamic
 func @floor_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK: "tf.Floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: "xla_hlo.floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   %0 = "tf.Floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }

 // CHECK-LABEL: func @floor_rankless
 func @floor_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK: "tf.Floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: "xla_hlo.floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   %0 = "tf.Floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
@ -660,14 +660,14 @@ func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> {

 // CHECK-LABEL: func @neg_dynamic
 func @neg_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK: "tf.Neg"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: "xla_hlo.neg"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   %0 = "tf.Neg"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }

 // CHECK-LABEL: func @neg_rankless
 func @neg_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK: "tf.Neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: "xla_hlo.neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   %0 = "tf.Neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
@ -693,14 +693,14 @@ func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {

 // CHECK-LABEL: func @rsqrt_dynamic
 func @rsqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK: "tf.Rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: "xla_hlo.rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   %0 = "tf.Rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }

 // CHECK-LABEL: func @rsqrt_rankless
 func @rsqrt_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK: "tf.Rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: "xla_hlo.rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   %0 = "tf.Rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
@ -714,14 +714,14 @@ func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> {

 // CHECK-LABEL: func @tanh_dynamic
 func @tanh_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK: "tf.Tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: "xla_hlo.tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   %0 = "tf.Tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }

 // CHECK-LABEL: func @tanh_rankless
 func @tanh_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK: "tf.Tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: "xla_hlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   %0 = "tf.Tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
@ -2,8 +2,7 @@

 // -----

-func @enforce_static_shapes(%arg0: tensor<*xf32>) -> tensor<*xf32> {
-  // expected-error@+1 {{op operand #0 must be statically shaped tensor}}
+func @dynamic_shapes(%arg0: tensor<*xf32>) -> tensor<*xf32> {
   %0 = "xla_hlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   return %0: tensor<*xf32>
 }
@ -347,7 +346,7 @@ func @select_scalar_pred(%arg0: tensor<i1>, %arg1: tensor<2x3xi32>, %arg2: tenso
 // -----

 func @select_bad_pred_type(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
-  // expected-error@+1 {{must be statically shaped tensor of pred (AKA boolean or 1-bit integer)}}
+  // expected-error@+1 {{must be tensor of pred (AKA boolean or 1-bit integer) values}}
   %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi32>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   return %0 : tensor<2x3xi32>
 }