[MLIR:TF/XLA] Lower SigmoidGrad op to HLO
PiperOrigin-RevId: 306537968 Change-Id: Ie37d58ee0671131c1c0906882705b98247496ddd
This commit is contained in:
parent
d9666eb32b
commit
76079c00ac
@ -2064,6 +2064,17 @@ func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
|||||||
return %0 : tensor<2xf32>
|
return %0 : tensor<2xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @sigmoid_grad
|
||||||
|
func @sigmoid_grad(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||||
|
// CHECK-DAG: [[MUL0:%.+]] = xla_hlo.multiply %arg1, %arg0 : tensor<2xf32>
|
||||||
|
// CHECK-DAG: [[ONE:%.+]] = xla_hlo.constant dense<1.000000e+00> : tensor<2xf32>
|
||||||
|
// CHECK-DAG: [[SUB:%.+]] = xla_hlo.subtract [[ONE]], %arg0 : tensor<2xf32>
|
||||||
|
// CHECK-DAG: [[MUL1:%.+]] = xla_hlo.multiply [[MUL0]], [[SUB]] : tensor<2xf32>
|
||||||
|
// CHECK: return [[MUL1]]
|
||||||
|
%0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
||||||
|
return %0 : tensor<2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @sin
|
// CHECK-LABEL: @sin
|
||||||
func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||||
// CHECK: "xla_hlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
// CHECK: "xla_hlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||||
|
@ -610,3 +610,14 @@ def : Pat<(srcDstOpPair[0]:$old $shape, $seed, $seed2),
|
|||||||
(CastValueToI64 $old, $shape)),
|
(CastValueToI64 $old, $shape)),
|
||||||
[(IsShapedTensor $shape)]>;
|
[(IsShapedTensor $shape)]>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Sigmoid grad op.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
def : Pat<(TF_SigmoidGradOp AnyRankedTensor:$l, AnyRankedTensor:$r),
|
||||||
|
(HLO_MulOp
|
||||||
|
(HLO_MulOp $r, $l, (NullDenseIntElementsAttr)),
|
||||||
|
(HLO_SubOp (HLO_ConstOp (ConstantSplat<"1"> $l)), $l,
|
||||||
|
(NullDenseIntElementsAttr)),
|
||||||
|
(NullDenseIntElementsAttr)),
|
||||||
|
[(IEEEFloatTensor $l)]>;
|
||||||
|
Loading…
Reference in New Issue
Block a user