Fixed shape inference in xlog1py. Modified binary ops tests to run on 1d & 2d inputs so that broadcasting is also tested implicitly. Verified that test failed on xlog1py prior to change in binary_ops.cc.

PiperOrigin-RevId: 317939721
Change-Id: I6f1f8e501028b84933c152d7315fe67fee5b9b46
This commit is contained in:
A. Unique TensorFlower 2020-06-23 14:28:35 -07:00 committed by TensorFlower Gardener
parent b88bebf1ed
commit 4f6e48e9fd
2 changed files with 7 additions and 6 deletions
tensorflow/compiler

View File

@ -229,16 +229,16 @@ class BinaryOpsTest(xla_test.XLATestCase):
self._testBinary(
gen_math_ops.xdivy,
np.array([0, 4, 3, 2, 1, 0], dtype=dtype),
np.array([0, 5, 6, 7, 8, float("NaN")], dtype=dtype),
expected=np.array([0, 0.8, 0.5, 0.285714, 0.125, 0], dtype=dtype),
np.array([[0, 5, 6, 7, 8, float("NaN")]], dtype=dtype),
expected=np.array([[0, 0.8, 0.5, 0.285714, 0.125, 0]], dtype=dtype),
rtol=1e-6,
atol=1e-6)
self._testBinary(
gen_math_ops.xlogy,
np.array([0, 4, 3, 2, 1, 0], dtype=dtype),
np.array([0, 5, 6, 7, 8, float("NaN")], dtype=dtype),
expected=np.array([0, 6.437752, 5.375278, 3.89182, 2.079442, 0],
np.array([[0, 5, 6, 7, 8, float("NaN")]], dtype=dtype),
expected=np.array([[0, 6.437752, 5.375278, 3.89182, 2.079442, 0]],
dtype=dtype),
rtol=1e-4,
atol=1e-6)
@ -246,8 +246,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
self._testBinary(
gen_math_ops.xlog1py,
np.array([0, 4, 3, 2, 1, 0], dtype=dtype),
np.array([-1, 5, 6, 7, 8, float("NaN")], dtype=dtype),
expected=np.array([0, 7.167038, 5.837730, 4.158883, 2.197225, 0],
np.array([[-1, 5, 6, 7, 8, float("NaN")]], dtype=dtype),
expected=np.array([[0, 7.167038, 5.837730, 4.158883, 2.197225, 0]],
dtype=dtype),
rtol=1e-4,
atol=1e-6)

View File

@ -153,6 +153,7 @@ XLA_MAKE_BINARY(Xlogy, XlogyImpl(lhs, rhs, broadcast_helper));
xla::XlaOp Xlog1pyImpl(xla::XlaOp x, xla::XlaOp y,
const BCast& broadcast_helper) {
std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
auto non_zero = xla::Mul(x, xla::Log1p(y));
auto zero = xla::ZerosLike(non_zero);
auto x_is_zero = xla::Eq(x, zero);