diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 6b2a4f1d92b..42fb73f17e2 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -97,6 +97,20 @@ func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x return %0: tensor<1x2xi32> } +// CHECK-LABEL: func @maximum +func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.max %arg0, %arg1 : tensor<4xf32> + %0 = "tf.Maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @minimum +func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.min %arg0, %arg1 : tensor<4xf32> + %0 = "tf.Minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + // CHECK-LABEL: func @mul func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK-NEXT: %0 = xla_hlo.mul %arg0, %arg0 : tensor<2xi32> diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 6d78f37e598..f2893507b3b 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -80,8 +80,10 @@ class DirectBinaryPat (ToOp $l, $r, (BinBroadcastDimensions $l, $r))>; foreach fromToBinPair = [[TF_AddOp, HLO_AddOp], - [TF_AddV2Op, HLO_AddOp], + [TF_AddV2Op, HLO_AddOp], [TF_DivOp, HLO_DivOp], + [TF_MaximumOp, HLO_MaxOp], + [TF_MinimumOp, HLO_MinOp], [TF_MulOp, HLO_MulOp], [TF_RealDivOp, HLO_DivOp], [TF_SubOp, HLO_SubOp]] in