Enable complex types for TanhGrad legalization
PiperOrigin-RevId: 351335238 Change-Id: Ic424040882b8807eca1e01c7fac678ae6c5d30b1
This commit is contained in:
parent
36485830be
commit
8f0920289a
@ -508,7 +508,7 @@ func @tanhgrad_float(%y : tensor<*xf32>, %dy: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK-LABEL: func @tanhgrad_complex
|
||||
// CHECK-SAME: (%[[Y:.*]]: tensor<*xcomplex<f32>>, %[[DY:.*]]: tensor<*xcomplex<f32>>)
|
||||
func @tanhgrad_complex(%y : tensor<*xcomplex<f32>>, %dy: tensor<*xcomplex<f32>>) -> tensor<*xcomplex<f32>> {
|
||||
// CHECK: tf.TanhGrad
|
||||
// CHECK-NOT: tf.TanhGrad
|
||||
%0 = "tf.TanhGrad"(%y, %dy) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> tensor<*xcomplex<f32>>
|
||||
|
||||
return %0 : tensor<*xcomplex<f32>>
|
||||
|
@ -259,9 +259,8 @@ def : Pat<(TF_RsqrtGradOp $lhs, $rhs),
|
||||
|
||||
// grad = dy * (1 - y**2)
|
||||
|
||||
// TODO(hinsu): Support complex input types.
|
||||
def LowerTanhGradOp :
|
||||
Pat<(TF_TanhGradOp TF_FloatTensor:$y, TF_FloatTensor:$dy),
|
||||
Pat<(TF_TanhGradOp $y, $dy),
|
||||
(TF_MulOp $dy,
|
||||
(TF_SubOp (TF_ConstOp (GetScalarOfType<1> $y)),
|
||||
(TF_SquareOp $y)))>;
|
||||
|
@ -514,8 +514,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||
expected=expected,
|
||||
equality_test=NextAfterEqualityTest)
|
||||
|
||||
@test_util.disable_mlir_bridge(
|
||||
"Complex types not supported in CreateDenseElementsAttrFromLiteral")
|
||||
def testComplexOps(self):
|
||||
for dtype in self.complex_types:
|
||||
ctypes = {np.complex64: np.float32, np.complex128: np.float64}
|
||||
|
Loading…
Reference in New Issue
Block a user