[TF:XLA] Implement SqrtGrad.

PiperOrigin-RevId: 167000454
This commit is contained in:
Peter Hawkins 2017-08-30 08:57:22 -07:00 committed by TensorFlower Gardener
parent 96b8526273
commit f9c5e921dd
3 changed files with 20 additions and 0 deletions

View File

@ -94,6 +94,12 @@ class BinaryOpsTest(XLATestCase):
np.array([5, 6, 7, 8], dtype=dtype),
expected=np.array([-160, -81, -28, -4], dtype=dtype))
self._testBinary(
gen_math_ops._sqrt_grad,
np.array([4, 3, 2, 1], dtype=dtype),
np.array([5, 6, 7, 8], dtype=dtype),
expected=np.array([0.625, 1, 1.75, 4], dtype=dtype))
self._testBinary(
gen_nn_ops._softplus_grad,
np.array([4, 3, 2, 1], dtype=dtype),

View File

@ -2496,6 +2496,16 @@ TEST_F(OpTest, Sqrt) {
});
}
TEST_F(OpTest, SqrtGrad) {
Repeatedly([this]() {
auto dims = RandomDims();
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SqrtGrad")
.RandomInput(DT_FLOAT, dims)
.RandomInput(DT_FLOAT, dims)
.Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, SquaredDifference) {
Repeatedly([this]() {
auto dims = BroadcastableDims();

View File

@ -107,6 +107,10 @@ XLA_MAKE_BINARY(
b->Mul(b->Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)),
b->Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)),
extend_dimensions));
XLA_MAKE_BINARY(SqrtGrad,
b->Div(b->Mul(rhs,
XlaHelpers::FloatLiteral(b, input_type(0), 0.5)),
lhs, extend_dimensions));
static xla::ComputationDataHandle Square(xla::ComputationBuilder* builder,
const xla::ComputationDataHandle& x) {