Enable complex types for TanhGrad legalization

PiperOrigin-RevId: 351335238
Change-Id: Ic424040882b8807eca1e01c7fac678ae6c5d30b1
This commit is contained in:
Smit Hinsu 2021-01-12 03:11:28 -08:00 committed by TensorFlower Gardener
parent 36485830be
commit 8f0920289a
3 changed files with 2 additions and 5 deletions

View File

@ -508,7 +508,7 @@ func @tanhgrad_float(%y : tensor<*xf32>, %dy: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-LABEL: func @tanhgrad_complex
// CHECK-SAME: (%[[Y:.*]]: tensor<*xcomplex<f32>>, %[[DY:.*]]: tensor<*xcomplex<f32>>)
func @tanhgrad_complex(%y : tensor<*xcomplex<f32>>, %dy: tensor<*xcomplex<f32>>) -> tensor<*xcomplex<f32>> {
// CHECK: tf.TanhGrad
// CHECK-NOT: tf.TanhGrad
%0 = "tf.TanhGrad"(%y, %dy) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> tensor<*xcomplex<f32>>
return %0 : tensor<*xcomplex<f32>>

View File

@ -259,9 +259,8 @@ def : Pat<(TF_RsqrtGradOp $lhs, $rhs),
// grad = dy * (1 - y**2)
// TODO(hinsu): Support complex input types.
def LowerTanhGradOp :
Pat<(TF_TanhGradOp TF_FloatTensor:$y, TF_FloatTensor:$dy),
Pat<(TF_TanhGradOp $y, $dy),
(TF_MulOp $dy,
(TF_SubOp (TF_ConstOp (GetScalarOfType<1> $y)),
(TF_SquareOp $y)))>;

View File

@ -514,8 +514,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
expected=expected,
equality_test=NextAfterEqualityTest)
@test_util.disable_mlir_bridge(
"Complex types not supported in CreateDenseElementsAttrFromLiteral")
def testComplexOps(self):
for dtype in self.complex_types:
ctypes = {np.complex64: np.float32, np.complex128: np.float64}