From 720648504039e267b984ab179d9c2b239cdf8eef Mon Sep 17 00:00:00 2001
From: Smit Hinsu
Date: Wed, 2 Oct 2019 19:12:38 -0700
Subject: [PATCH] Allow dynamic shape tensors in XLA HLO dialect

This way, XLA's static shape restrictions can be kept limited to the MLIR
translation to HLOModule, and other use cases of this dialect can have
dynamic shapes.

Also, update the TensorFlow Relu and Relu6 op lowering to XLA to restrict it
to statically shaped operands, as the shape is used to construct the splat
constant (see the sketch after the patch).

PiperOrigin-RevId: 272570959
---
 tensorflow/compiler/mlir/xla/ir/hlo_ops.td   | 10 +++---
 .../compiler/mlir/xla/tests/legalize-tf.mlir | 34 +++++++++----------
 tensorflow/compiler/mlir/xla/tests/ops.mlir  |  5 ++-
 3 files changed, 24 insertions(+), 25 deletions(-)

diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index 1a626c2502c..2c2707f60b4 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -49,17 +49,17 @@ class HLO_Op<string mnemonic, list<OpTrait> traits> :
 //===----------------------------------------------------------------------===//
 
 // Any integer tensor types
-def HLO_IntTensor : StaticShapeTensorOf<[HLO_Int]>;
+def HLO_IntTensor : TensorOf<[HLO_Int]>;
 
 // Any floating-point tensor types
-def HLO_FpTensor : StaticShapeTensorOf<[AnyFloat]>;
+def HLO_FpTensor : TensorOf<[AnyFloat]>;
 
-def HLO_PredTensor : StaticShapeTensorOf<[HLO_Pred]>;
+def HLO_PredTensor : TensorOf<[HLO_Pred]>;
 
 // Any integer or floating-point tensor types
-def HLO_IntOrFpTensor : StaticShapeTensorOf<[HLO_Int, AnyFloat]>;
+def HLO_IntOrFpTensor : TensorOf<[HLO_Int, AnyFloat]>;
 
-def HLO_Tensor : StaticShapeTensorOf<[AnyFloat, AnyInteger]>;
+def HLO_Tensor : TensorOf<[AnyFloat, AnyInteger]>;
 
 def HLO_Tuple : NestedTupleOf<[HLO_Tensor]>;
 
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index e7c770fffef..b8f24e82071 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -534,21 +534,21 @@ func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
 
 // CHECK-LABEL: func @abs_dynamic
 func @abs_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK: "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: "xla_hlo.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   %0 = "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
 // CHECK-LABEL: func @abs_rankless
 func @abs_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK: "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: "xla_hlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   %0 = "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
 
 // CHECK-LABEL: func @cast_dynamic_i2f
 func @cast_dynamic_i2f(%arg0: tensor<?xi32>) -> tensor<?xf32> {
-  // CHECK: "tf.Cast"(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
+  // CHECK: "xla_hlo.convert"(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
   %0 = "tf.Cast"(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
@@ -576,14 +576,14 @@ func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> {
 
 // CHECK-LABEL: func @ceil_dynamic
 func @ceil_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK: "tf.Ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: "xla_hlo.ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   %0 = "tf.Ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
 // CHECK-LABEL: func @ceil_rankless
 func @ceil_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK: "tf.Ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: "xla_hlo.ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   %0 = "tf.Ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
@@ -597,14 +597,14 @@ func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
 
 // CHECK-LABEL: func @cos_dynamic
 func @cos_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK: "tf.Cos"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: "xla_hlo.cos"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   %0 = "tf.Cos"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
 // CHECK-LABEL: func @cos_rankless
 func @cos_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK: "tf.Cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: "xla_hlo.cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   %0 = "tf.Cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
@@ -618,14 +618,14 @@ func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> {
 
 // CHECK-LABEL: func @exp_dynamic
 func @exp_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK: "tf.Exp"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: "xla_hlo.exp"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   %0 = "tf.Exp"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
 // CHECK-LABEL: func @exp_rankless
 func @exp_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK: "tf.Exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: "xla_hlo.exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   %0 = "tf.Exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
@@ -639,14 +639,14 @@ func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> {
 
 // CHECK-LABEL: func @floor_dynamic
 func @floor_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK: "tf.Floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: "xla_hlo.floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   %0 = "tf.Floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
 // CHECK-LABEL: func @floor_rankless
 func @floor_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK: "tf.Floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: "xla_hlo.floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   %0 = "tf.Floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
@@ -660,14 +660,14 @@ func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> {
 
 // CHECK-LABEL: func @neg_dynamic
 func @neg_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK: "tf.Neg"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: "xla_hlo.neg"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   %0 = "tf.Neg"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
 // CHECK-LABEL: func @neg_rankless
 func @neg_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK: "tf.Neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: "xla_hlo.neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   %0 = "tf.Neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
@@ -693,14 +693,14 @@ func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
 
 // CHECK-LABEL: func @rsqrt_dynamic
 func @rsqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK: "tf.Rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: "xla_hlo.rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   %0 = "tf.Rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
 // CHECK-LABEL: func @rsqrt_rankless
 func @rsqrt_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK: "tf.Rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: "xla_hlo.rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   %0 = "tf.Rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
@@ -714,14 +714,14 @@ func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> {
 
 // CHECK-LABEL: func @tanh_dynamic
 func @tanh_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK: "tf.Tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: "xla_hlo.tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   %0 = "tf.Tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
 // CHECK-LABEL: func @tanh_rankless
 func @tanh_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK: "tf.Tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: "xla_hlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   %0 = "tf.Tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir
index 4520f7615ca..915a2c1a034 100644
--- a/tensorflow/compiler/mlir/xla/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir
@@ -2,8 +2,7 @@
 
 // -----
 
-func @enforce_static_shapes(%arg0: tensor<*xf32>) -> tensor<*xf32> {
-  // expected-error@+1 {{op operand #0 must be statically shaped tensor}}
+func @dynamic_shapes(%arg0: tensor<*xf32>) -> tensor<*xf32> {
   %0 = "xla_hlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   return %0: tensor<*xf32>
 }
@@ -347,7 +346,7 @@ func @select_scalar_pred(%arg0: tensor<i1>, %arg1: tensor<2x3xi32>, %arg2: tenso
 // -----
 
 func @select_bad_pred_type(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
-  // expected-error@+1 {{must be statically shaped tensor of pred (AKA boolean or 1-bit integer)}}
+  // expected-error@+1 {{must be tensor of pred (AKA boolean or 1-bit integer) values}}
   %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi32>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   return %0 : tensor<2x3xi32>
 }
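
For illustration (not part of the patch): a rough sketch of why the Relu/Relu6
lowering mentioned in the commit message stays restricted to static shapes. The
lowering materializes a zero splat of the operand's tensor type, and a splat
attribute such as dense<0.0> can only be built for a statically shaped type
like tensor<4xf32>, never for tensor<?xf32>. The function name, shapes, and op
spellings below are illustrative assumptions, not verbatim output of the
TensorFlow-to-XLA patterns:

// Illustrative sketch only; names and op spellings are assumptions.
func @relu(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  // The zero splat needs a static shape; dense<0.0> : tensor<?xf32> is not expressible.
  %zero = constant dense<0.000000e+00> : tensor<4xf32>
  %0 = "xla_hlo.max"(%zero, %arg0) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  return %0 : tensor<4xf32>
}

Because the splat is keyed to the operand type, relaxing HLO_Tensor and friends
to TensorOf (as this patch does) is safe for the dialect itself, while the
static-shape requirement moves to the individual lowering patterns that
actually need it.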