[TF:XLA] Implement SqrtGrad.
PiperOrigin-RevId: 167000454
This commit is contained in:
parent
96b8526273
commit
f9c5e921dd
@ -94,6 +94,12 @@ class BinaryOpsTest(XLATestCase):
|
||||
np.array([5, 6, 7, 8], dtype=dtype),
|
||||
expected=np.array([-160, -81, -28, -4], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
gen_math_ops._sqrt_grad,
|
||||
np.array([4, 3, 2, 1], dtype=dtype),
|
||||
np.array([5, 6, 7, 8], dtype=dtype),
|
||||
expected=np.array([0.625, 1, 1.75, 4], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
gen_nn_ops._softplus_grad,
|
||||
np.array([4, 3, 2, 1], dtype=dtype),
|
||||
|
@ -2496,6 +2496,16 @@ TEST_F(OpTest, Sqrt) {
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, SqrtGrad) {
|
||||
Repeatedly([this]() {
|
||||
auto dims = RandomDims();
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SqrtGrad")
|
||||
.RandomInput(DT_FLOAT, dims)
|
||||
.RandomInput(DT_FLOAT, dims)
|
||||
.Attr("T", DT_FLOAT));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, SquaredDifference) {
|
||||
Repeatedly([this]() {
|
||||
auto dims = BroadcastableDims();
|
||||
|
@ -107,6 +107,10 @@ XLA_MAKE_BINARY(
|
||||
b->Mul(b->Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)),
|
||||
b->Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)),
|
||||
extend_dimensions));
|
||||
XLA_MAKE_BINARY(SqrtGrad,
|
||||
b->Div(b->Mul(rhs,
|
||||
XlaHelpers::FloatLiteral(b, input_type(0), 0.5)),
|
||||
lhs, extend_dimensions));
|
||||
|
||||
static xla::ComputationDataHandle Square(xla::ComputationBuilder* builder,
|
||||
const xla::ComputationDataHandle& x) {
|
||||
|
Loading…
Reference in New Issue
Block a user