Allow dynamic shape tensors in XLA HLO dialect
This way XLA static shape restrictions can be kept limited to MLIR translation to HLOModule and other use-cases of this dialect can have dynamic shapes. Also, update TensorFlow Relu and Relu6 op lowering to XLA to restrict it to static shape as the shape is used to construct the splat constant. PiperOrigin-RevId: 272570959
This commit is contained in:
parent
2a8e8706a4
commit
7206485040
@ -49,17 +49,17 @@ class HLO_Op<string mnemonic, list<OpTrait> traits> :
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Any integer tensor types
|
||||
def HLO_IntTensor : StaticShapeTensorOf<[HLO_Int]>;
|
||||
def HLO_IntTensor : TensorOf<[HLO_Int]>;
|
||||
|
||||
// Any floating-point tensor types
|
||||
def HLO_FpTensor : StaticShapeTensorOf<[AnyFloat]>;
|
||||
def HLO_FpTensor : TensorOf<[AnyFloat]>;
|
||||
|
||||
def HLO_PredTensor : StaticShapeTensorOf<[HLO_Pred]>;
|
||||
def HLO_PredTensor : TensorOf<[HLO_Pred]>;
|
||||
|
||||
// Any integer or floating-point tensor types
|
||||
def HLO_IntOrFpTensor : StaticShapeTensorOf<[HLO_Int, AnyFloat]>;
|
||||
def HLO_IntOrFpTensor : TensorOf<[HLO_Int, AnyFloat]>;
|
||||
|
||||
def HLO_Tensor : StaticShapeTensorOf<[AnyFloat, AnyInteger]>;
|
||||
def HLO_Tensor : TensorOf<[AnyFloat, AnyInteger]>;
|
||||
|
||||
def HLO_Tuple : NestedTupleOf<[HLO_Tensor]>;
|
||||
|
||||
|
@ -534,21 +534,21 @@ func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
|
||||
// CHECK-LABEL: func @abs_dynamic
|
||||
func @abs_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK: "xla_hlo.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @abs_rankless
|
||||
func @abs_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: "xla_hlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%0 = "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @cast_dynamic_i2f
|
||||
func @cast_dynamic_i2f(%arg0: tensor<?xi32>) -> tensor<?xf32> {
|
||||
// CHECK: "tf.Cast"(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
|
||||
// CHECK: "xla_hlo.convert"(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
|
||||
%0 = "tf.Cast"(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
@ -576,14 +576,14 @@ func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
|
||||
// CHECK-LABEL: func @ceil_dynamic
|
||||
func @ceil_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: "tf.Ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK: "xla_hlo.ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tf.Ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @ceil_rankless
|
||||
func @ceil_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: "tf.Ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: "xla_hlo.ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%0 = "tf.Ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
@ -597,14 +597,14 @@ func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
|
||||
// CHECK-LABEL: func @cos_dynamic
|
||||
func @cos_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: "tf.Cos"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK: "xla_hlo.cos"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tf.Cos"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @cos_rankless
|
||||
func @cos_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: "tf.Cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: "xla_hlo.cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%0 = "tf.Cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
@ -618,14 +618,14 @@ func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
|
||||
// CHECK-LABEL: func @exp_dynamic
|
||||
func @exp_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: "tf.Exp"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK: "xla_hlo.exp"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tf.Exp"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @exp_rankless
|
||||
func @exp_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: "tf.Exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: "xla_hlo.exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%0 = "tf.Exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
@ -639,14 +639,14 @@ func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
|
||||
// CHECK-LABEL: func @floor_dynamic
|
||||
func @floor_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: "tf.Floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK: "xla_hlo.floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tf.Floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @floor_rankless
|
||||
func @floor_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: "tf.Floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: "xla_hlo.floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%0 = "tf.Floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
@ -660,14 +660,14 @@ func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
|
||||
// CHECK-LABEL: func @neg_dynamic
|
||||
func @neg_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: "tf.Neg"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK: "xla_hlo.neg"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tf.Neg"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @neg_rankless
|
||||
func @neg_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: "tf.Neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: "xla_hlo.neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%0 = "tf.Neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
@ -693,14 +693,14 @@ func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
|
||||
// CHECK-LABEL: func @rsqrt_dynamic
|
||||
func @rsqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: "tf.Rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK: "xla_hlo.rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tf.Rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @rsqrt_rankless
|
||||
func @rsqrt_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: "tf.Rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: "xla_hlo.rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%0 = "tf.Rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
@ -714,14 +714,14 @@ func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
|
||||
// CHECK-LABEL: func @tanh_dynamic
|
||||
func @tanh_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: "tf.Tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK: "xla_hlo.tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tf.Tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @tanh_rankless
|
||||
func @tanh_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: "tf.Tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: "xla_hlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%0 = "tf.Tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
@ -2,8 +2,7 @@
|
||||
|
||||
// -----
|
||||
|
||||
func @enforce_static_shapes(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// expected-error@+1 {{op operand #0 must be statically shaped tensor}}
|
||||
func @dynamic_shapes(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = "xla_hlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0: tensor<*xf32>
|
||||
}
|
||||
@ -347,7 +346,7 @@ func @select_scalar_pred(%arg0: tensor<i1>, %arg1: tensor<2x3xi32>, %arg2: tenso
|
||||
// -----
|
||||
|
||||
func @select_bad_pred_type(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
|
||||
// expected-error@+1 {{must be statically shaped tensor of pred (AKA boolean or 1-bit integer)}}
|
||||
// expected-error@+1 {{must be tensor of pred (AKA boolean or 1-bit integer) values}}
|
||||
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi32>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
return %0 : tensor<2x3xi32>
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user