[TF:XLA] Implement SqrtGrad.
PiperOrigin-RevId: 167000454
This commit is contained in:
parent
96b8526273
commit
f9c5e921dd
@ -94,6 +94,12 @@ class BinaryOpsTest(XLATestCase):
|
|||||||
np.array([5, 6, 7, 8], dtype=dtype),
|
np.array([5, 6, 7, 8], dtype=dtype),
|
||||||
expected=np.array([-160, -81, -28, -4], dtype=dtype))
|
expected=np.array([-160, -81, -28, -4], dtype=dtype))
|
||||||
|
|
||||||
|
self._testBinary(
|
||||||
|
gen_math_ops._sqrt_grad,
|
||||||
|
np.array([4, 3, 2, 1], dtype=dtype),
|
||||||
|
np.array([5, 6, 7, 8], dtype=dtype),
|
||||||
|
expected=np.array([0.625, 1, 1.75, 4], dtype=dtype))
|
||||||
|
|
||||||
self._testBinary(
|
self._testBinary(
|
||||||
gen_nn_ops._softplus_grad,
|
gen_nn_ops._softplus_grad,
|
||||||
np.array([4, 3, 2, 1], dtype=dtype),
|
np.array([4, 3, 2, 1], dtype=dtype),
|
||||||
|
@ -2496,6 +2496,16 @@ TEST_F(OpTest, Sqrt) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(OpTest, SqrtGrad) {
|
||||||
|
Repeatedly([this]() {
|
||||||
|
auto dims = RandomDims();
|
||||||
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SqrtGrad")
|
||||||
|
.RandomInput(DT_FLOAT, dims)
|
||||||
|
.RandomInput(DT_FLOAT, dims)
|
||||||
|
.Attr("T", DT_FLOAT));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(OpTest, SquaredDifference) {
|
TEST_F(OpTest, SquaredDifference) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
auto dims = BroadcastableDims();
|
auto dims = BroadcastableDims();
|
||||||
|
@ -107,6 +107,10 @@ XLA_MAKE_BINARY(
|
|||||||
b->Mul(b->Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)),
|
b->Mul(b->Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)),
|
||||||
b->Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)),
|
b->Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)),
|
||||||
extend_dimensions));
|
extend_dimensions));
|
||||||
|
XLA_MAKE_BINARY(SqrtGrad,
|
||||||
|
b->Div(b->Mul(rhs,
|
||||||
|
XlaHelpers::FloatLiteral(b, input_type(0), 0.5)),
|
||||||
|
lhs, extend_dimensions));
|
||||||
|
|
||||||
static xla::ComputationDataHandle Square(xla::ComputationBuilder* builder,
|
static xla::ComputationDataHandle Square(xla::ComputationBuilder* builder,
|
||||||
const xla::ComputationDataHandle& x) {
|
const xla::ComputationDataHandle& x) {
|
||||||
|
Loading…
Reference in New Issue
Block a user