diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index cf7ffc5af19..e349aefd4cb 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -94,6 +94,12 @@ class BinaryOpsTest(XLATestCase): np.array([5, 6, 7, 8], dtype=dtype), expected=np.array([-160, -81, -28, -4], dtype=dtype)) + self._testBinary( + gen_math_ops._sqrt_grad, + np.array([4, 3, 2, 1], dtype=dtype), + np.array([5, 6, 7, 8], dtype=dtype), + expected=np.array([0.625, 1, 1.75, 4], dtype=dtype)) + self._testBinary( gen_nn_ops._softplus_grad, np.array([4, 3, 2, 1], dtype=dtype), diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 6a0bdf1ed15..e380fdd7f4e 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -2496,6 +2496,16 @@ TEST_F(OpTest, Sqrt) { }); } +TEST_F(OpTest, SqrtGrad) { + Repeatedly([this]() { + auto dims = RandomDims(); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SqrtGrad") + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, SquaredDifference) { Repeatedly([this]() { auto dims = BroadcastableDims(); diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index ded20a9a3ce..f9bb1e2fb1d 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -107,6 +107,10 @@ XLA_MAKE_BINARY( b->Mul(b->Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)), b->Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)), extend_dimensions)); +XLA_MAKE_BINARY(SqrtGrad, + b->Div(b->Mul(rhs, + XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), + lhs, extend_dimensions)); static xla::ComputationDataHandle Square(xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x) {