Allow dynamic shape tensors in XLA HLO dialect

This way XLA static shape restrictions can be kept limited to MLIR translation to HLOModule and other use-cases of this dialect can have dynamic shapes.

Also, update TensorFlow Relu and Relu6 op lowering to XLA to restrict it to static shape as the shape is used to construct the splat constant.

PiperOrigin-RevId: 272570959
This commit is contained in:
Smit Hinsu 2019-10-02 19:12:38 -07:00 committed by TensorFlower Gardener
parent 2a8e8706a4
commit 7206485040
3 changed files with 24 additions and 25 deletions

View File

@ -49,17 +49,17 @@ class HLO_Op<string mnemonic, list<OpTrait> traits> :
//===----------------------------------------------------------------------===//
// Any integer tensor types
def HLO_IntTensor : StaticShapeTensorOf<[HLO_Int]>;
def HLO_IntTensor : TensorOf<[HLO_Int]>;
// Any floating-point tensor types
def HLO_FpTensor : StaticShapeTensorOf<[AnyFloat]>;
def HLO_FpTensor : TensorOf<[AnyFloat]>;
def HLO_PredTensor : StaticShapeTensorOf<[HLO_Pred]>;
def HLO_PredTensor : TensorOf<[HLO_Pred]>;
// Any integer or floating-point tensor types
def HLO_IntOrFpTensor : StaticShapeTensorOf<[HLO_Int, AnyFloat]>;
def HLO_IntOrFpTensor : TensorOf<[HLO_Int, AnyFloat]>;
def HLO_Tensor : StaticShapeTensorOf<[AnyFloat, AnyInteger]>;
def HLO_Tensor : TensorOf<[AnyFloat, AnyInteger]>;
def HLO_Tuple : NestedTupleOf<[HLO_Tensor]>;

View File

@ -534,21 +534,21 @@ func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-LABEL: func @abs_dynamic
func @abs_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK: "xla_hlo.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// CHECK-LABEL: func @abs_rankless
func @abs_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: "xla_hlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @cast_dynamic_i2f
func @cast_dynamic_i2f(%arg0: tensor<?xi32>) -> tensor<?xf32> {
// CHECK: "tf.Cast"(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
// CHECK: "xla_hlo.convert"(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
%0 = "tf.Cast"(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -576,14 +576,14 @@ func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-LABEL: func @ceil_dynamic
func @ceil_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: "tf.Ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK: "xla_hlo.ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "tf.Ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// CHECK-LABEL: func @ceil_rankless
func @ceil_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: "tf.Ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: "xla_hlo.ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "tf.Ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
@ -597,14 +597,14 @@ func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-LABEL: func @cos_dynamic
func @cos_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: "tf.Cos"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK: "xla_hlo.cos"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "tf.Cos"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// CHECK-LABEL: func @cos_rankless
func @cos_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: "tf.Cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: "xla_hlo.cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "tf.Cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
@ -618,14 +618,14 @@ func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-LABEL: func @exp_dynamic
func @exp_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: "tf.Exp"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK: "xla_hlo.exp"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "tf.Exp"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// CHECK-LABEL: func @exp_rankless
func @exp_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: "tf.Exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: "xla_hlo.exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "tf.Exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
@ -639,14 +639,14 @@ func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-LABEL: func @floor_dynamic
func @floor_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: "tf.Floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK: "xla_hlo.floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "tf.Floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// CHECK-LABEL: func @floor_rankless
func @floor_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: "tf.Floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: "xla_hlo.floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "tf.Floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
@ -660,14 +660,14 @@ func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-LABEL: func @neg_dynamic
func @neg_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: "tf.Neg"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK: "xla_hlo.neg"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "tf.Neg"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// CHECK-LABEL: func @neg_rankless
func @neg_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: "tf.Neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: "xla_hlo.neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "tf.Neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
@ -693,14 +693,14 @@ func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-LABEL: func @rsqrt_dynamic
func @rsqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: "tf.Rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK: "xla_hlo.rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "tf.Rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// CHECK-LABEL: func @rsqrt_rankless
func @rsqrt_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: "tf.Rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: "xla_hlo.rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "tf.Rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
@ -714,14 +714,14 @@ func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-LABEL: func @tanh_dynamic
func @tanh_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: "tf.Tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK: "xla_hlo.tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "tf.Tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// CHECK-LABEL: func @tanh_rankless
func @tanh_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: "tf.Tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: "xla_hlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "tf.Tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

View File

@ -2,8 +2,7 @@
// -----
func @enforce_static_shapes(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// expected-error@+1 {{op operand #0 must be statically shaped tensor}}
func @dynamic_shapes(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0: tensor<*xf32>
}
@ -347,7 +346,7 @@ func @select_scalar_pred(%arg0: tensor<i1>, %arg1: tensor<2x3xi32>, %arg2: tenso
// -----
func @select_bad_pred_type(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
// expected-error@+1 {{must be statically shaped tensor of pred (AKA boolean or 1-bit integer)}}
// expected-error@+1 {{must be tensor of pred (AKA boolean or 1-bit integer) values}}
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi32>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
return %0 : tensor<2x3xi32>
}