Add tf.math.xlog1py, a safe way to compute x * log1p(y)
PiperOrigin-RevId: 288971952 Change-Id: I3850da3b37f006b11198d203a1b73f3cb336b833
This commit is contained in:
parent
093e476572
commit
19986377f2
@ -1776,9 +1776,9 @@ absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable() {
|
||||
"Lgamma", "Digamma",
|
||||
// Binary
|
||||
"Add", "AddV2", "Sub", "Mul", "Div", "Atan2", "Complex", "DivNoNan",
|
||||
"MulNoNan", "FloorDiv", "Xlogy", "Xdivy", "FloorMod", "BitwiseAnd",
|
||||
"BitwiseOr", "BitwiseXor", "LeftShift", "RightShift", "LogicalAnd",
|
||||
"LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
|
||||
"MulNoNan", "FloorDiv", "Xlogy", "Xlog1py", "Xdivy", "FloorMod",
|
||||
"BitwiseAnd", "BitwiseOr", "BitwiseXor", "LeftShift", "RightShift",
|
||||
"LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
|
||||
"ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "TruncateDiv",
|
||||
"TruncateMod", "Equal", "NotEqual", "Greater", "GreaterEqual",
|
||||
"Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", "SoftsignGrad",
|
||||
|
@ -241,6 +241,15 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||
rtol=1e-4,
|
||||
atol=1e-6)
|
||||
|
||||
self._testBinary(
|
||||
gen_math_ops.xlog1py,
|
||||
np.array([0, 4, 3, 2, 1, 0], dtype=dtype),
|
||||
np.array([-1, 5, 6, 7, 8, float("NaN")], dtype=dtype),
|
||||
expected=np.array([0, 7.167038, 5.837730, 4.158883, 2.197225, 0],
|
||||
dtype=dtype),
|
||||
rtol=1e-4,
|
||||
atol=1e-6)
|
||||
|
||||
def testIntOps(self):
|
||||
for dtype in self.signed_int_types:
|
||||
self._testBinary(
|
||||
|
@ -151,6 +151,15 @@ xla::XlaOp XlogyImpl(xla::XlaOp x, xla::XlaOp y,
|
||||
}
|
||||
XLA_MAKE_BINARY(Xlogy, XlogyImpl(lhs, rhs, broadcast_helper));
|
||||
|
||||
xla::XlaOp Xlog1pyImpl(xla::XlaOp x, xla::XlaOp y,
|
||||
const BCast& broadcast_helper) {
|
||||
auto non_zero = xla::Mul(x, xla::Log1p(y));
|
||||
auto zero = xla::ZerosLike(x);
|
||||
auto x_is_zero = xla::Eq(x, zero);
|
||||
return xla::Select(x_is_zero, zero, non_zero);
|
||||
}
|
||||
XLA_MAKE_BINARY(Xlog1py, Xlog1pyImpl(lhs, rhs, broadcast_helper));
|
||||
|
||||
xla::XlaOp XdivyImpl(xla::XlaOp x, xla::XlaOp y,
|
||||
const BCast& broadcast_helper) {
|
||||
std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
|
||||
|
4
tensorflow/core/api_def/base_api/api_def_Xlog1py.pbtxt
Normal file
4
tensorflow/core/api_def/base_api/api_def_Xlog1py.pbtxt
Normal file
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "Xlog1py"
|
||||
summary: "Returns 0 if x == 0, and x * log1p(y) otherwise, elementwise."
|
||||
}
|
4
tensorflow/core/api_def/python_api/api_def_Xlog1py.pbtxt
Normal file
4
tensorflow/core/api_def/python_api/api_def_Xlog1py.pbtxt
Normal file
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "Xlog1py"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -6472,6 +6472,7 @@ filegroup(
|
||||
"cwise_op_tan.cc",
|
||||
"cwise_op_tanh.cc",
|
||||
"cwise_op_xlogy.cc",
|
||||
"cwise_op_xlog1py.cc",
|
||||
"cwise_op_xdivy.cc",
|
||||
"data_format_ops.cc",
|
||||
"decode_wav_op.cc",
|
||||
|
31
tensorflow/core/kernels/cwise_op_gpu_xlog1py.cu.cc
Normal file
31
tensorflow/core/kernels/cwise_op_gpu_xlog1py.cu.cc
Normal file
@ -0,0 +1,31 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
#if GOOGLE_CUDA
|
||||
DEFINE_BINARY5(xlog1py, Eigen::half, float, double, complex64, complex128);
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
// TODO(ROCm): enable complex64 / complex128 after compiler fix.
|
||||
DEFINE_BINARY3(xlog1py, Eigen::half, float, double);
|
||||
#endif
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
41
tensorflow/core/kernels/cwise_op_xlog1py.cc
Normal file
41
tensorflow/core/kernels/cwise_op_xlog1py.cc
Normal file
@ -0,0 +1,41 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/cwise_ops_common.h"
|
||||
|
||||
namespace tensorflow {
|
||||
REGISTER5(BinaryOp, CPU, "Xlog1py", functor::xlog1py, float, Eigen::half,
|
||||
double, complex64, complex128);
|
||||
|
||||
#if TENSORFLOW_USE_SYCL
|
||||
#define REGISTER_SYCL_KERNEL(TYPE) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Xlog1py").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
|
||||
BinaryOp<SYCLDevice, functor::xlog1py<TYPE>>);
|
||||
REGISTER_SYCL_KERNEL(Eigen::half);
|
||||
REGISTER_SYCL_KERNEL(float);
|
||||
REGISTER_SYCL_KERNEL(double);
|
||||
REGISTER_SYCL_KERNEL(complex64);
|
||||
REGISTER_SYCL_KERNEL(complex128);
|
||||
#undef REGISTER_SYCL_KERNEL
|
||||
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER5(BinaryOp, GPU, "Xlog1py", functor::xlog1py, float, Eigen::half,
|
||||
double, complex64, complex128);
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
} // namespace tensorflow
|
@ -703,6 +703,41 @@ struct functor_traits<xlogy_op<Scalar>> {
|
||||
};
|
||||
};
|
||||
|
||||
template <typename Scalar>
|
||||
struct xlog1py_op {
|
||||
EIGEN_EMPTY_STRUCT_CTOR(xlog1py_op)
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
|
||||
operator()(const Scalar& x, const Scalar& y) const {
|
||||
if (x == Scalar(0.)) {
|
||||
return Scalar(0.);
|
||||
}
|
||||
return x * numext::log1p(y);
|
||||
}
|
||||
template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x,
|
||||
const Packet& y) const {
|
||||
Packet zeros = pzero(x);
|
||||
Packet mask = pcmp_eq(x, zeros);
|
||||
scalar_log1p_op<Scalar> log1p_op;
|
||||
Packet log1p_y = log1p_op.packetOp(y);
|
||||
Packet x_log1p_y = pmul(x, log1p_y);
|
||||
return pselect(mask, x, x_log1p_y);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Scalar>
|
||||
struct functor_traits<xlog1py_op<Scalar>> {
|
||||
enum {
|
||||
Cost = functor_traits<scalar_log1p_op<Scalar>>::Cost +
|
||||
Eigen::NumTraits<Scalar>::MulCost,
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
PacketAccess = false,
|
||||
#else
|
||||
PacketAccess = functor_traits<scalar_log1p_op<Scalar>>::PacketAccess
|
||||
#endif
|
||||
};
|
||||
};
|
||||
|
||||
template <typename Scalar>
|
||||
struct xdivy_op {
|
||||
EIGEN_EMPTY_STRUCT_CTOR(xdivy_op)
|
||||
@ -1141,6 +1176,9 @@ struct xdivy : base<T, Eigen::internal::xdivy_op<T>> {};
|
||||
template <typename T>
|
||||
struct xlogy : base<T, Eigen::internal::xlogy_op<T>> {};
|
||||
|
||||
template <typename T>
|
||||
struct xlog1py : base<T, Eigen::internal::xlog1py_op<T>> {};
|
||||
|
||||
template <typename T>
|
||||
struct less : base<T, Eigen::internal::less<T>, bool> {};
|
||||
|
||||
|
@ -579,6 +579,25 @@ Status XlogyGrad(const AttrSlice& attrs, FunctionDef* g) {
|
||||
}
|
||||
REGISTER_OP_GRADIENT("Xlogy", XlogyGrad);
|
||||
|
||||
Status Xlog1pyGrad(const AttrSlice& attrs, FunctionDef* g) {
|
||||
// clang-format off
|
||||
return GradForBinaryCwise(g, {
|
||||
FDH::Const("const", 1.0f),
|
||||
{{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
|
||||
{{"zeros"}, "ZerosLike", {"x"}},
|
||||
{{"yp1"}, "Add", {"y", "one"}},
|
||||
{{"is_x_zero"}, "NotEqual", {"x", "zeros"}},
|
||||
{{"is_zero_cast"}, "Cast", {"is_x_zero"},
|
||||
{{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
|
||||
{{"safe_log1py"}, "Xlog1py", {"is_zero_cast", "y"}},
|
||||
{{"xlog1pygrad"}, "Xdivy", {"x", "yp1"}},
|
||||
{{"gx"}, "Mul", {"safe_log1py", "dz"}},
|
||||
{{"gy"}, "Mul", {"xlog1pygrad", "dz"}},
|
||||
});
|
||||
// clang-format on
|
||||
}
|
||||
REGISTER_OP_GRADIENT("Xlog1py", Xlog1pyGrad);
|
||||
|
||||
Status XdivyGrad(const AttrSlice& attrs, FunctionDef* g) {
|
||||
// clang-format off
|
||||
return GradForBinaryCwise(g, {
|
||||
|
@ -962,6 +962,29 @@ TEST_F(MathGradTest, Xlogy) {
|
||||
TensorShape({2, 1})));
|
||||
}
|
||||
|
||||
TEST_F(MathGradTest, Xlog1py) {
|
||||
auto x = test::AsTensor<float>({0.f, 0.f, 2.f, 3.f, 4.f, 5.f},
|
||||
TensorShape({2, 3}));
|
||||
auto y = test::AsTensor<float>({.5f, 2.f}, TensorShape({2, 1}));
|
||||
Tensor dx;
|
||||
Tensor dy;
|
||||
auto g = [](float x, float y) -> float {
|
||||
return x == 0. ? 0. : std::log1p(y);
|
||||
};
|
||||
auto h = [](float x, float y) -> float {
|
||||
return x == 0. ? 0. : x / (y + 1.);
|
||||
};
|
||||
SymGrad("Xlog1py", x, y, &dx, &dy);
|
||||
test::ExpectClose(
|
||||
dx, test::AsTensor<float>({g(0.f, .5f), g(0.f, 0.f), g(2.f, .5f),
|
||||
g(3.f, 2.f), g(4.f, 2.f), g(5.f, 2.f)},
|
||||
TensorShape({2, 3})));
|
||||
test::ExpectClose(
|
||||
dy, test::AsTensor<float>({h(0.f, .5f) + h(0.f, 0.f) + h(2.f, .5f),
|
||||
h(3.f, 2.f) + h(4.f, 2.f) + h(5.f, 2.f)},
|
||||
TensorShape({2, 1})));
|
||||
}
|
||||
|
||||
TEST_F(MathGradTest, Xdivy) {
|
||||
auto x = test::AsTensor<float>({0.f, 0.f, 2.f, 3.f, 4.f, 5.f},
|
||||
TensorShape({2, 3}));
|
||||
|
@ -522,6 +522,13 @@ REGISTER_OP("Xlogy")
|
||||
.Attr("T: {half, float, double, complex64, complex128}")
|
||||
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
|
||||
|
||||
REGISTER_OP("Xlog1py")
|
||||
.Input("x: T")
|
||||
.Input("y: T")
|
||||
.Output("z: T")
|
||||
.Attr("T: {half, float, double, complex64, complex128}")
|
||||
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
|
||||
|
||||
REGISTER_OP("Xdivy")
|
||||
.Input("x: T")
|
||||
.Input("y: T")
|
||||
|
@ -691,6 +691,23 @@ def _XLogyGrad(op, grad):
|
||||
array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))
|
||||
|
||||
|
||||
@ops.RegisterGradient("Xlog1py")
|
||||
def _XLog1pyGrad(op, grad):
|
||||
"""Returns gradient of xlog1py(x, y) with respect to x and y."""
|
||||
x = op.inputs[0]
|
||||
y = op.inputs[1]
|
||||
sx = array_ops.shape(x)
|
||||
sy = array_ops.shape(y)
|
||||
rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
|
||||
with ops.control_dependencies([grad]):
|
||||
not_zero_x = math_ops.cast(
|
||||
math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype)
|
||||
partial_x = gen_math_ops.xlog1py(not_zero_x, y)
|
||||
partial_y = gen_math_ops.xdivy(x, y + 1.)
|
||||
return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx),
|
||||
array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))
|
||||
|
||||
|
||||
@ops.RegisterGradient("Xdivy")
|
||||
def _XDivyGrad(op, grad):
|
||||
"""Returns gradient of xdivy(x, y) with respect to x and y."""
|
||||
|
@ -489,6 +489,56 @@ class XlogyTest(test.TestCase):
|
||||
self.assertAllClose(zero, xlogy_ygrad)
|
||||
|
||||
|
||||
class Xlog1pyTest(test.TestCase):
|
||||
|
||||
def _xlog1py_gradients(self, x, y):
|
||||
xlog1py_xgrad = self.evaluate(
|
||||
gradients.gradients(math_ops.xlog1py(x, y), x)[0])
|
||||
xlog1py_ygrad = self.evaluate(
|
||||
gradients.gradients(math_ops.xlog1py(x, y), y)[0])
|
||||
return xlog1py_xgrad, xlog1py_ygrad
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testNonZeroValuesGrad(self):
|
||||
for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
|
||||
x = constant_op.constant(0.1, dtype=dtype)
|
||||
y = constant_op.constant(3.1, dtype=dtype)
|
||||
xlog1py_xgrad, xlog1py_ygrad = self._xlog1py_gradients(x, y)
|
||||
xlog1py_expected_xgrad = self.evaluate(math_ops.log1p(y))
|
||||
xlog1py_expected_ygrad = self.evaluate(x / (1. + y))
|
||||
self.assertAllClose(xlog1py_expected_xgrad, xlog1py_xgrad)
|
||||
self.assertAllClose(xlog1py_expected_ygrad, xlog1py_ygrad)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testZeroXGrad(self):
|
||||
for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
|
||||
x = constant_op.constant(0., dtype=dtype)
|
||||
y = constant_op.constant(3.1, dtype=dtype)
|
||||
xlog1py_xgrad, xlog1py_ygrad = self._xlog1py_gradients(x, y)
|
||||
zero = self.evaluate(x)
|
||||
self.assertAllClose(zero, xlog1py_xgrad)
|
||||
self.assertAllClose(zero, xlog1py_ygrad)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testNegOneYGrad(self):
|
||||
for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
|
||||
x = constant_op.constant(0.1, dtype=dtype)
|
||||
y = constant_op.constant(-1., dtype=dtype)
|
||||
xlog1py_xgrad, xlog1py_ygrad = self._xlog1py_gradients(x, y)
|
||||
self.assertAllClose(-np.inf, xlog1py_xgrad)
|
||||
self.assertAllClose(np.inf, xlog1py_ygrad)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testZeroXNegOneYGrad(self):
|
||||
for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
|
||||
x = constant_op.constant(0., dtype=dtype)
|
||||
y = constant_op.constant(-1., dtype=dtype)
|
||||
xlog1py_xgrad, xlog1py_ygrad = self._xlog1py_gradients(x, y)
|
||||
zero = self.evaluate(x)
|
||||
self.assertAllClose(zero, xlog1py_xgrad)
|
||||
self.assertAllClose(zero, xlog1py_ygrad)
|
||||
|
||||
|
||||
class XdivyTest(test.TestCase):
|
||||
|
||||
def _xdivy_gradients(self, x, y):
|
||||
|
@ -4290,6 +4290,43 @@ def reciprocal_no_nan(x, name=None):
|
||||
return gen_math_ops.div_no_nan(one, x, name=scope)
|
||||
|
||||
|
||||
@tf_export("math.xlog1py")
|
||||
@dispatch.add_dispatch_support
|
||||
def xlog1py(x, y, name=None):
|
||||
r"""Compute x * log1p(y).
|
||||
|
||||
Given `x` and `y`, compute `x * log1p(y)`. This function safely returns
|
||||
zero when `x = 0`, no matter what the value of `y` is.
|
||||
|
||||
Example:
|
||||
|
||||
>>> tf.math.xlog1py(0., 1.)
|
||||
<tf.Tensor: shape=(), dtype=float32, numpy=0.>
|
||||
>>> tf.math.xlog1py(1., 1.)
|
||||
<tf.Tensor: shape=(), dtype=float32, numpy=0.6931472>
|
||||
>>> tf.math.xlog1py(2., 2.)
|
||||
<tf.Tensor: shape=(), dtype=float32, numpy=2.1972246>
|
||||
>>> tf.math.xlog1py(0., -1.)
|
||||
<tf.Tensor: shape=(), dtype=float32, numpy=0.>
|
||||
|
||||
Args:
|
||||
x: A `tf.Tensor` of type `bfloat16`, `half`, `float32`, `float64`,
|
||||
`complex64`, `complex128`
|
||||
y: A `tf.Tensor` of type `bfloat16`, `half`, `float32`, `float64`,
|
||||
`complex64`, `complex128`
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
`x * log1p(y)`.
|
||||
|
||||
@compatibility(scipy)
|
||||
Equivalent to scipy.special.xlog1py
|
||||
@end_compatibility
|
||||
"""
|
||||
with ops.name_scope(name, "xlog1py", [x]):
|
||||
return gen_math_ops.xlog1py(x, y)
|
||||
|
||||
|
||||
@tf_export("math.erfinv")
|
||||
@dispatch.add_dispatch_support
|
||||
def erfinv(x, name=None):
|
||||
|
@ -560,6 +560,40 @@ class XlogyTest(test_util.TensorFlowTestCase):
|
||||
self.assertAllClose(xtimes_logy, xlogy_tf_np[1])
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class Xlog1pyTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testXlog1pyNoNeg1(self):
|
||||
for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
|
||||
x = constant_op.constant([[0.1, 0.2, 3.5], [-2., -5., 30.]], dtype=dtype)
|
||||
y = constant_op.constant([[-0.1, -0.2, 3.5], [3.1, -0.9, 2.]],
|
||||
dtype=dtype)
|
||||
with test_util.use_gpu():
|
||||
xlog1py = self.evaluate(math_ops.xlog1py(x, y))
|
||||
xtimeslog1py = self.evaluate(x * math_ops.log1p(y))
|
||||
self.assertAllClose(xlog1py, xtimeslog1py)
|
||||
|
||||
def testXlog1pyWithNegOne(self):
|
||||
for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
|
||||
x = constant_op.constant(np.zeros((2, 3)), dtype=dtype)
|
||||
y = constant_op.constant([[0.1, 0.2, 3.5], [-1., 1., 2.]], dtype=dtype)
|
||||
with test_util.use_gpu():
|
||||
xlog1py_tf_np = self.evaluate(math_ops.xlog1py(x, y))
|
||||
zeros_np = self.evaluate(array_ops.zeros_like(y))
|
||||
self.assertAllClose(xlog1py_tf_np, zeros_np)
|
||||
|
||||
def testXlog1pyWithZeroBroadcast(self):
|
||||
for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
|
||||
x = constant_op.constant([[0.], [1.]], dtype=dtype)
|
||||
y = constant_op.constant([[-0.1, -0.2, -1.], [0., 1., 2.]], dtype=dtype)
|
||||
with test_util.use_gpu():
|
||||
xlog1py_tf_np = self.evaluate(math_ops.xlog1py(x, y))
|
||||
zeros_np = self.evaluate(array_ops.zeros_like(y[0]))
|
||||
xtimes_log1py = self.evaluate(math_ops.log1p(y[1]))
|
||||
self.assertAllClose(zeros_np, xlog1py_tf_np[0])
|
||||
self.assertAllClose(xtimes_log1py, xlog1py_tf_np[1])
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class XdivyTest(test_util.TensorFlowTestCase):
|
||||
|
||||
|
@ -2564,6 +2564,7 @@ def _convert_cast(pfor_input):
|
||||
@RegisterPForWithArgs("TruncateMod", math_ops.truncate_mod)
|
||||
@RegisterPForWithArgs("Xdivy", math_ops.xdivy)
|
||||
@RegisterPForWithArgs("Xlogy", math_ops.xlogy)
|
||||
@RegisterPForWithArgs("Xlog1py", math_ops.xlog1py)
|
||||
@RegisterPForWithArgs("Zeta", math_ops.zeta)
|
||||
def _convert_cwise(pfor_input, op_type, op_func):
|
||||
# Note that ops handled here do not have attributes except those listed below
|
||||
|
@ -496,6 +496,10 @@ tf_module {
|
||||
name: "xdivy"
|
||||
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "xlog1py"
|
||||
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "xlogy"
|
||||
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -4944,6 +4944,10 @@ tf_module {
|
||||
name: "Xdivy"
|
||||
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Xlog1py"
|
||||
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Xlogy"
|
||||
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -496,6 +496,10 @@ tf_module {
|
||||
name: "xdivy"
|
||||
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "xlog1py"
|
||||
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "xlogy"
|
||||
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -4944,6 +4944,10 @@ tf_module {
|
||||
name: "Xdivy"
|
||||
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Xlog1py"
|
||||
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Xlogy"
|
||||
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user