diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 9c23d5b3332..ffd000a668e 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -2064,6 +2064,17 @@ func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> {
   return %0 : tensor<2xf32>
 }
 
+// CHECK-LABEL: @sigmoid_grad
+func @sigmoid_grad(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
+  // CHECK-DAG: [[MUL0:%.+]] = xla_hlo.multiply %arg1, %arg0 : tensor<2xf32>
+  // CHECK-DAG: [[ONE:%.+]] = xla_hlo.constant dense<1.000000e+00> : tensor<2xf32>
+  // CHECK-DAG: [[SUB:%.+]] = xla_hlo.subtract [[ONE]], %arg0 : tensor<2xf32>
+  // CHECK-DAG: [[MUL1:%.+]] = xla_hlo.multiply [[MUL0]], [[SUB]] : tensor<2xf32>
+  // CHECK: return [[MUL1]]
+  %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
 // CHECK-LABEL: @sin
 func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> {
   // CHECK: "xla_hlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
index bb505b6d3d6..6a36f3e7f2b 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
@@ -610,3 +610,14 @@ def : Pat<(srcDstOpPair[0]:$old $shape, $seed, $seed2),
                          (CastValueToI64 $old, $shape)),
           [(IsShapedTensor $shape)]>;
 }
+
+//===----------------------------------------------------------------------===//
+// Sigmoid grad op.
+//===----------------------------------------------------------------------===//
+def : Pat<(TF_SigmoidGradOp AnyRankedTensor:$l, AnyRankedTensor:$r),
+          (HLO_MulOp
+              (HLO_MulOp $r, $l, (NullDenseIntElementsAttr)),
+              (HLO_SubOp (HLO_ConstOp (ConstantSplat<"1"> $l)), $l,
+                  (NullDenseIntElementsAttr)),
+              (NullDenseIntElementsAttr)),
+          [(IEEEFloatTensor $l)]>;