diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index 1a626c2502c..2c2707f60b4 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -49,17 +49,17 @@ class HLO_Op traits> : //===----------------------------------------------------------------------===// // Any integer tensor types -def HLO_IntTensor : StaticShapeTensorOf<[HLO_Int]>; +def HLO_IntTensor : TensorOf<[HLO_Int]>; // Any floating-point tensor types -def HLO_FpTensor : StaticShapeTensorOf<[AnyFloat]>; +def HLO_FpTensor : TensorOf<[AnyFloat]>; -def HLO_PredTensor : StaticShapeTensorOf<[HLO_Pred]>; +def HLO_PredTensor : TensorOf<[HLO_Pred]>; // Any integer or floating-point tensor types -def HLO_IntOrFpTensor : StaticShapeTensorOf<[HLO_Int, AnyFloat]>; +def HLO_IntOrFpTensor : TensorOf<[HLO_Int, AnyFloat]>; -def HLO_Tensor : StaticShapeTensorOf<[AnyFloat, AnyInteger]>; +def HLO_Tensor : TensorOf<[AnyFloat, AnyInteger]>; def HLO_Tuple : NestedTupleOf<[HLO_Tensor]>; diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index e7c770fffef..b8f24e82071 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -534,21 +534,21 @@ func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK-LABEL: func @abs_dynamic func @abs_dynamic(%arg0: tensor) -> tensor { - // CHECK: "tf.Abs"(%arg0) : (tensor) -> tensor + // CHECK: "xla_hlo.abs"(%arg0) : (tensor) -> tensor %0 = "tf.Abs"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @abs_rankless func @abs_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "xla_hlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } // CHECK-LABEL: func @cast_dynamic_i2f func @cast_dynamic_i2f(%arg0: tensor) -> tensor { - // CHECK: "tf.Cast"(%arg0) : (tensor) -> tensor + // CHECK: "xla_hlo.convert"(%arg0) : (tensor) -> tensor %0 = "tf.Cast"(%arg0) : (tensor) -> tensor return %0 : tensor } @@ -576,14 +576,14 @@ func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK-LABEL: func @ceil_dynamic func @ceil_dynamic(%arg0: tensor) -> tensor { - // CHECK: "tf.Ceil"(%arg0) : (tensor) -> tensor + // CHECK: "xla_hlo.ceil"(%arg0) : (tensor) -> tensor %0 = "tf.Ceil"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @ceil_rankless func @ceil_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "tf.Ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "xla_hlo.ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } @@ -597,14 +597,14 @@ func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK-LABEL: func @cos_dynamic func @cos_dynamic(%arg0: tensor) -> tensor { - // CHECK: "tf.Cos"(%arg0) : (tensor) -> tensor + // CHECK: "xla_hlo.cos"(%arg0) : (tensor) -> tensor %0 = "tf.Cos"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @cos_rankless func @cos_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "tf.Cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "xla_hlo.cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } @@ -618,14 +618,14 @@ func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK-LABEL: func @exp_dynamic func @exp_dynamic(%arg0: tensor) -> tensor { - // CHECK: "tf.Exp"(%arg0) : (tensor) -> tensor + // CHECK: "xla_hlo.exp"(%arg0) : (tensor) -> tensor %0 = "tf.Exp"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @exp_rankless func @exp_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "tf.Exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "xla_hlo.exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } @@ -639,14 +639,14 @@ func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK-LABEL: func @floor_dynamic func @floor_dynamic(%arg0: tensor) -> tensor { - // CHECK: "tf.Floor"(%arg0) : (tensor) -> tensor + // CHECK: "xla_hlo.floor"(%arg0) : (tensor) -> tensor %0 = "tf.Floor"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @floor_rankless func @floor_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "tf.Floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "xla_hlo.floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } @@ -660,14 +660,14 @@ func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK-LABEL: func @neg_dynamic func @neg_dynamic(%arg0: tensor) -> tensor { - // CHECK: "tf.Neg"(%arg0) : (tensor) -> tensor + // CHECK: "xla_hlo.neg"(%arg0) : (tensor) -> tensor %0 = "tf.Neg"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @neg_rankless func @neg_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "tf.Neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "xla_hlo.neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } @@ -693,14 +693,14 @@ func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK-LABEL: func @rsqrt_dynamic func @rsqrt_dynamic(%arg0: tensor) -> tensor { - // CHECK: "tf.Rsqrt"(%arg0) : (tensor) -> tensor + // CHECK: "xla_hlo.rsqrt"(%arg0) : (tensor) -> tensor %0 = "tf.Rsqrt"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @rsqrt_rankless func @rsqrt_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "tf.Rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "xla_hlo.rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } @@ -714,14 +714,14 @@ func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK-LABEL: func @tanh_dynamic func @tanh_dynamic(%arg0: tensor) -> tensor { - // CHECK: "tf.Tanh"(%arg0) : (tensor) -> tensor + // CHECK: "xla_hlo.tanh"(%arg0) : (tensor) -> tensor %0 = "tf.Tanh"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @tanh_rankless func @tanh_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "tf.Tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "xla_hlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir index 4520f7615ca..915a2c1a034 100644 --- a/tensorflow/compiler/mlir/xla/tests/ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir @@ -2,8 +2,7 @@ // ----- -func @enforce_static_shapes(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // expected-error@+1 {{op operand #0 must be statically shaped tensor}} +func @dynamic_shapes(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "xla_hlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0: tensor<*xf32> } @@ -347,7 +346,7 @@ func @select_scalar_pred(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tenso // ----- func @select_bad_pred_type(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { - // expected-error@+1 {{must be statically shaped tensor of pred (AKA boolean or 1-bit integer)}} + // expected-error@+1 {{must be tensor of pred (AKA boolean or 1-bit integer) values}} %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi32>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> }