[MLIR:TF/XLA] Lower SigmoidGrad op to HLO
PiperOrigin-RevId: 306537968 Change-Id: Ie37d58ee0671131c1c0906882705b98247496ddd
This commit is contained in:
parent
d9666eb32b
commit
76079c00ac
@ -2064,6 +2064,17 @@ func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @sigmoid_grad
|
||||
func @sigmoid_grad(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// CHECK-DAG: [[MUL0:%.+]] = xla_hlo.multiply %arg1, %arg0 : tensor<2xf32>
|
||||
// CHECK-DAG: [[ONE:%.+]] = xla_hlo.constant dense<1.000000e+00> : tensor<2xf32>
|
||||
// CHECK-DAG: [[SUB:%.+]] = xla_hlo.subtract [[ONE]], %arg0 : tensor<2xf32>
|
||||
// CHECK-DAG: [[MUL1:%.+]] = xla_hlo.multiply [[MUL0]], [[SUB]] : tensor<2xf32>
|
||||
// CHECK: return [[MUL1]]
|
||||
%0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @sin
|
||||
func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// CHECK: "xla_hlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
|
@ -610,3 +610,14 @@ def : Pat<(srcDstOpPair[0]:$old $shape, $seed, $seed2),
|
||||
(CastValueToI64 $old, $shape)),
|
||||
[(IsShapedTensor $shape)]>;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Sigmoid grad op.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def : Pat<(TF_SigmoidGradOp AnyRankedTensor:$l, AnyRankedTensor:$r),
|
||||
(HLO_MulOp
|
||||
(HLO_MulOp $r, $l, (NullDenseIntElementsAttr)),
|
||||
(HLO_SubOp (HLO_ConstOp (ConstantSplat<"1"> $l)), $l,
|
||||
(NullDenseIntElementsAttr)),
|
||||
(NullDenseIntElementsAttr)),
|
||||
[(IEEEFloatTensor $l)]>;
|
||||
|
Loading…
Reference in New Issue
Block a user