diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index edcec281802..ae95f89e3eb 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1776,9 +1776,9 @@ absl::flat_hash_map>* GetWhitelistTable() { "Lgamma", "Digamma", // Binary "Add", "AddV2", "Sub", "Mul", "Div", "Atan2", "Complex", "DivNoNan", - "MulNoNan", "FloorDiv", "Xlogy", "Xdivy", "FloorMod", "BitwiseAnd", - "BitwiseOr", "BitwiseXor", "LeftShift", "RightShift", "LogicalAnd", - "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv", + "MulNoNan", "FloorDiv", "Xlogy", "Xlog1py", "Xdivy", "FloorMod", + "BitwiseAnd", "BitwiseOr", "BitwiseXor", "LeftShift", "RightShift", + "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv", "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "TruncateDiv", "TruncateMod", "Equal", "NotEqual", "Greater", "GreaterEqual", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", "SoftsignGrad", diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 444948c4078..6276bddba82 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -241,6 +241,15 @@ class BinaryOpsTest(xla_test.XLATestCase): rtol=1e-4, atol=1e-6) + self._testBinary( + gen_math_ops.xlog1py, + np.array([0, 4, 3, 2, 1, 0], dtype=dtype), + np.array([-1, 5, 6, 7, 8, float("NaN")], dtype=dtype), + expected=np.array([0, 7.167038, 5.837730, 4.158883, 2.197225, 0], + dtype=dtype), + rtol=1e-4, + atol=1e-6) + def testIntOps(self): for dtype in self.signed_int_types: self._testBinary( diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 19c09b07959..f4a85b8da8a 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -151,6 +151,15 @@ xla::XlaOp XlogyImpl(xla::XlaOp x, xla::XlaOp y, } XLA_MAKE_BINARY(Xlogy, XlogyImpl(lhs, rhs, broadcast_helper)); +xla::XlaOp Xlog1pyImpl(xla::XlaOp x, xla::XlaOp y, + const BCast& broadcast_helper) { + auto non_zero = xla::Mul(x, xla::Log1p(y)); + auto zero = xla::ZerosLike(x); + auto x_is_zero = xla::Eq(x, zero); + return xla::Select(x_is_zero, zero, non_zero); +} +XLA_MAKE_BINARY(Xlog1py, Xlog1pyImpl(lhs, rhs, broadcast_helper)); + xla::XlaOp XdivyImpl(xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); diff --git a/tensorflow/core/api_def/base_api/api_def_Xlog1py.pbtxt b/tensorflow/core/api_def/base_api/api_def_Xlog1py.pbtxt new file mode 100644 index 00000000000..773ab38bfdb --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_Xlog1py.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "Xlog1py" + summary: "Returns 0 if x == 0, and x * log1p(y) otherwise, elementwise." +} diff --git a/tensorflow/core/api_def/python_api/api_def_Xlog1py.pbtxt b/tensorflow/core/api_def/python_api/api_def_Xlog1py.pbtxt new file mode 100644 index 00000000000..8d33cb940ee --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_Xlog1py.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "Xlog1py" + visibility: HIDDEN +} diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 80db46e3ec6..306b0c0540a 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -6472,6 +6472,7 @@ filegroup( "cwise_op_tan.cc", "cwise_op_tanh.cc", "cwise_op_xlogy.cc", + "cwise_op_xlog1py.cc", "cwise_op_xdivy.cc", "data_format_ops.cc", "decode_wav_op.cc", diff --git a/tensorflow/core/kernels/cwise_op_gpu_xlog1py.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_xlog1py.cu.cc new file mode 100644 index 00000000000..0838336867d --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_xlog1py.cu.cc @@ -0,0 +1,31 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +#if GOOGLE_CUDA +DEFINE_BINARY5(xlog1py, Eigen::half, float, double, complex64, complex128); +#elif TENSORFLOW_USE_ROCM +// TODO(ROCm): enable complex64 / complex128 after compiler fix. +DEFINE_BINARY3(xlog1py, Eigen::half, float, double); +#endif +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/cwise_op_xlog1py.cc b/tensorflow/core/kernels/cwise_op_xlog1py.cc new file mode 100644 index 00000000000..f00d73e3038 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_xlog1py.cc @@ -0,0 +1,41 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER5(BinaryOp, CPU, "Xlog1py", functor::xlog1py, float, Eigen::half, + double, complex64, complex128); + +#if TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNEL(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("Xlog1py").Device(DEVICE_SYCL).TypeConstraint("T"), \ + BinaryOp>); +REGISTER_SYCL_KERNEL(Eigen::half); +REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); +REGISTER_SYCL_KERNEL(complex64); +REGISTER_SYCL_KERNEL(complex128); +#undef REGISTER_SYCL_KERNEL + +#endif // TENSORFLOW_USE_SYCL + +#if GOOGLE_CUDA +REGISTER5(BinaryOp, GPU, "Xlog1py", functor::xlog1py, float, Eigen::half, + double, complex64, complex128); +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h index 446187c4e9b..73217c01d18 100644 --- a/tensorflow/core/kernels/cwise_ops.h +++ b/tensorflow/core/kernels/cwise_ops.h @@ -703,6 +703,41 @@ struct functor_traits> { }; }; +template +struct xlog1py_op { + EIGEN_EMPTY_STRUCT_CTOR(xlog1py_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar + operator()(const Scalar& x, const Scalar& y) const { + if (x == Scalar(0.)) { + return Scalar(0.); + } + return x * numext::log1p(y); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, + const Packet& y) const { + Packet zeros = pzero(x); + Packet mask = pcmp_eq(x, zeros); + scalar_log1p_op log1p_op; + Packet log1p_y = log1p_op.packetOp(y); + Packet x_log1p_y = pmul(x, log1p_y); + return pselect(mask, x, x_log1p_y); + } +}; + +template +struct functor_traits> { + enum { + Cost = functor_traits>::Cost + + Eigen::NumTraits::MulCost, +#if TENSORFLOW_USE_ROCM + PacketAccess = false, +#else + PacketAccess = functor_traits>::PacketAccess +#endif + }; +}; + template struct xdivy_op { EIGEN_EMPTY_STRUCT_CTOR(xdivy_op) @@ -1141,6 +1176,9 @@ struct xdivy : base> {}; template struct xlogy : base> {}; +template +struct xlog1py : base> {}; + template struct less : base, bool> {}; diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc index 4194f13261c..18f884da3c9 100644 --- a/tensorflow/core/ops/math_grad.cc +++ b/tensorflow/core/ops/math_grad.cc @@ -579,6 +579,25 @@ Status XlogyGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Xlogy", XlogyGrad); +Status Xlog1pyGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForBinaryCwise(g, { + FDH::Const("const", 1.0f), + {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}}, + {{"zeros"}, "ZerosLike", {"x"}}, + {{"yp1"}, "Add", {"y", "one"}}, + {{"is_x_zero"}, "NotEqual", {"x", "zeros"}}, + {{"is_zero_cast"}, "Cast", {"is_x_zero"}, + {{"SrcT", DT_BOOL}, {"DstT", "$T"}}}, + {{"safe_log1py"}, "Xlog1py", {"is_zero_cast", "y"}}, + {{"xlog1pygrad"}, "Xdivy", {"x", "yp1"}}, + {{"gx"}, "Mul", {"safe_log1py", "dz"}}, + {{"gy"}, "Mul", {"xlog1pygrad", "dz"}}, + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Xlog1py", Xlog1pyGrad); + Status XdivyGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForBinaryCwise(g, { diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc index a4ecdcb78b7..ef839de92c9 100644 --- a/tensorflow/core/ops/math_grad_test.cc +++ b/tensorflow/core/ops/math_grad_test.cc @@ -962,6 +962,29 @@ TEST_F(MathGradTest, Xlogy) { TensorShape({2, 1}))); } +TEST_F(MathGradTest, Xlog1py) { + auto x = test::AsTensor({0.f, 0.f, 2.f, 3.f, 4.f, 5.f}, + TensorShape({2, 3})); + auto y = test::AsTensor({.5f, 2.f}, TensorShape({2, 1})); + Tensor dx; + Tensor dy; + auto g = [](float x, float y) -> float { + return x == 0. ? 0. : std::log1p(y); + }; + auto h = [](float x, float y) -> float { + return x == 0. ? 0. : x / (y + 1.); + }; + SymGrad("Xlog1py", x, y, &dx, &dy); + test::ExpectClose( + dx, test::AsTensor({g(0.f, .5f), g(0.f, 0.f), g(2.f, .5f), + g(3.f, 2.f), g(4.f, 2.f), g(5.f, 2.f)}, + TensorShape({2, 3}))); + test::ExpectClose( + dy, test::AsTensor({h(0.f, .5f) + h(0.f, 0.f) + h(2.f, .5f), + h(3.f, 2.f) + h(4.f, 2.f) + h(5.f, 2.f)}, + TensorShape({2, 1}))); +} + TEST_F(MathGradTest, Xdivy) { auto x = test::AsTensor({0.f, 0.f, 2.f, 3.f, 4.f, 5.f}, TensorShape({2, 3})); diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index d8be0b265c4..00bd2026f6a 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -522,6 +522,13 @@ REGISTER_OP("Xlogy") .Attr("T: {half, float, double, complex64, complex128}") .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); +REGISTER_OP("Xlog1py") + .Input("x: T") + .Input("y: T") + .Output("z: T") + .Attr("T: {half, float, double, complex64, complex128}") + .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); + REGISTER_OP("Xdivy") .Input("x: T") .Input("y: T") diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index e6b565b75d0..61d0cb64ba4 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -691,6 +691,23 @@ def _XLogyGrad(op, grad): array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy)) +@ops.RegisterGradient("Xlog1py") +def _XLog1pyGrad(op, grad): + """Returns gradient of xlog1py(x, y) with respect to x and y.""" + x = op.inputs[0] + y = op.inputs[1] + sx = array_ops.shape(x) + sy = array_ops.shape(y) + rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) + with ops.control_dependencies([grad]): + not_zero_x = math_ops.cast( + math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype) + partial_x = gen_math_ops.xlog1py(not_zero_x, y) + partial_y = gen_math_ops.xdivy(x, y + 1.) + return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx), + array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy)) + + @ops.RegisterGradient("Xdivy") def _XDivyGrad(op, grad): """Returns gradient of xdivy(x, y) with respect to x and y.""" diff --git a/tensorflow/python/ops/math_grad_test.py b/tensorflow/python/ops/math_grad_test.py index 9079f4b9b19..4a07d2949a8 100644 --- a/tensorflow/python/ops/math_grad_test.py +++ b/tensorflow/python/ops/math_grad_test.py @@ -489,6 +489,56 @@ class XlogyTest(test.TestCase): self.assertAllClose(zero, xlogy_ygrad) +class Xlog1pyTest(test.TestCase): + + def _xlog1py_gradients(self, x, y): + xlog1py_xgrad = self.evaluate( + gradients.gradients(math_ops.xlog1py(x, y), x)[0]) + xlog1py_ygrad = self.evaluate( + gradients.gradients(math_ops.xlog1py(x, y), y)[0]) + return xlog1py_xgrad, xlog1py_ygrad + + @test_util.run_deprecated_v1 + def testNonZeroValuesGrad(self): + for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]: + x = constant_op.constant(0.1, dtype=dtype) + y = constant_op.constant(3.1, dtype=dtype) + xlog1py_xgrad, xlog1py_ygrad = self._xlog1py_gradients(x, y) + xlog1py_expected_xgrad = self.evaluate(math_ops.log1p(y)) + xlog1py_expected_ygrad = self.evaluate(x / (1. + y)) + self.assertAllClose(xlog1py_expected_xgrad, xlog1py_xgrad) + self.assertAllClose(xlog1py_expected_ygrad, xlog1py_ygrad) + + @test_util.run_deprecated_v1 + def testZeroXGrad(self): + for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]: + x = constant_op.constant(0., dtype=dtype) + y = constant_op.constant(3.1, dtype=dtype) + xlog1py_xgrad, xlog1py_ygrad = self._xlog1py_gradients(x, y) + zero = self.evaluate(x) + self.assertAllClose(zero, xlog1py_xgrad) + self.assertAllClose(zero, xlog1py_ygrad) + + @test_util.run_deprecated_v1 + def testNegOneYGrad(self): + for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]: + x = constant_op.constant(0.1, dtype=dtype) + y = constant_op.constant(-1., dtype=dtype) + xlog1py_xgrad, xlog1py_ygrad = self._xlog1py_gradients(x, y) + self.assertAllClose(-np.inf, xlog1py_xgrad) + self.assertAllClose(np.inf, xlog1py_ygrad) + + @test_util.run_deprecated_v1 + def testZeroXNegOneYGrad(self): + for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]: + x = constant_op.constant(0., dtype=dtype) + y = constant_op.constant(-1., dtype=dtype) + xlog1py_xgrad, xlog1py_ygrad = self._xlog1py_gradients(x, y) + zero = self.evaluate(x) + self.assertAllClose(zero, xlog1py_xgrad) + self.assertAllClose(zero, xlog1py_ygrad) + + class XdivyTest(test.TestCase): def _xdivy_gradients(self, x, y): diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 4b6d3300212..360bf2b91dd 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -4290,6 +4290,43 @@ def reciprocal_no_nan(x, name=None): return gen_math_ops.div_no_nan(one, x, name=scope) +@tf_export("math.xlog1py") +@dispatch.add_dispatch_support +def xlog1py(x, y, name=None): + r"""Compute x * log1p(y). + + Given `x` and `y`, compute `x * log1p(y)`. This function safely returns + zero when `x = 0`, no matter what the value of `y` is. + + Example: + + >>> tf.math.xlog1py(0., 1.) + + >>> tf.math.xlog1py(1., 1.) + + >>> tf.math.xlog1py(2., 2.) + + >>> tf.math.xlog1py(0., -1.) + + + Args: + x: A `tf.Tensor` of type `bfloat16`, `half`, `float32`, `float64`, + `complex64`, `complex128` + y: A `tf.Tensor` of type `bfloat16`, `half`, `float32`, `float64`, + `complex64`, `complex128` + name: A name for the operation (optional). + + Returns: + `x * log1p(y)`. + + @compatibility(scipy) + Equivalent to scipy.special.xlog1py + @end_compatibility + """ + with ops.name_scope(name, "xlog1py", [x]): + return gen_math_ops.xlog1py(x, y) + + @tf_export("math.erfinv") @dispatch.add_dispatch_support def erfinv(x, name=None): diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index 37669bfab8f..f5289e59459 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -560,6 +560,40 @@ class XlogyTest(test_util.TensorFlowTestCase): self.assertAllClose(xtimes_logy, xlogy_tf_np[1]) +@test_util.run_all_in_graph_and_eager_modes +class Xlog1pyTest(test_util.TensorFlowTestCase): + + def testXlog1pyNoNeg1(self): + for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]: + x = constant_op.constant([[0.1, 0.2, 3.5], [-2., -5., 30.]], dtype=dtype) + y = constant_op.constant([[-0.1, -0.2, 3.5], [3.1, -0.9, 2.]], + dtype=dtype) + with test_util.use_gpu(): + xlog1py = self.evaluate(math_ops.xlog1py(x, y)) + xtimeslog1py = self.evaluate(x * math_ops.log1p(y)) + self.assertAllClose(xlog1py, xtimeslog1py) + + def testXlog1pyWithNegOne(self): + for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]: + x = constant_op.constant(np.zeros((2, 3)), dtype=dtype) + y = constant_op.constant([[0.1, 0.2, 3.5], [-1., 1., 2.]], dtype=dtype) + with test_util.use_gpu(): + xlog1py_tf_np = self.evaluate(math_ops.xlog1py(x, y)) + zeros_np = self.evaluate(array_ops.zeros_like(y)) + self.assertAllClose(xlog1py_tf_np, zeros_np) + + def testXlog1pyWithZeroBroadcast(self): + for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]: + x = constant_op.constant([[0.], [1.]], dtype=dtype) + y = constant_op.constant([[-0.1, -0.2, -1.], [0., 1., 2.]], dtype=dtype) + with test_util.use_gpu(): + xlog1py_tf_np = self.evaluate(math_ops.xlog1py(x, y)) + zeros_np = self.evaluate(array_ops.zeros_like(y[0])) + xtimes_log1py = self.evaluate(math_ops.log1p(y[1])) + self.assertAllClose(zeros_np, xlog1py_tf_np[0]) + self.assertAllClose(xtimes_log1py, xlog1py_tf_np[1]) + + @test_util.run_all_in_graph_and_eager_modes class XdivyTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py index c1965a8a0fd..c6caf2b7f17 100644 --- a/tensorflow/python/ops/parallel_for/pfor.py +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -2564,6 +2564,7 @@ def _convert_cast(pfor_input): @RegisterPForWithArgs("TruncateMod", math_ops.truncate_mod) @RegisterPForWithArgs("Xdivy", math_ops.xdivy) @RegisterPForWithArgs("Xlogy", math_ops.xlogy) +@RegisterPForWithArgs("Xlog1py", math_ops.xlog1py) @RegisterPForWithArgs("Zeta", math_ops.zeta) def _convert_cwise(pfor_input, op_type, op_func): # Note that ops handled here do not have attributes except those listed below diff --git a/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt index c24b1c38179..e4ab4e8f88a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt @@ -496,6 +496,10 @@ tf_module { name: "xdivy" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "xlog1py" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "xlogy" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 386848c1e2f..dc4552d62aa 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -4944,6 +4944,10 @@ tf_module { name: "Xdivy" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "Xlog1py" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "Xlogy" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt index 33828112832..d68ca9759d4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt @@ -496,6 +496,10 @@ tf_module { name: "xdivy" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "xlog1py" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "xlogy" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 386848c1e2f..dc4552d62aa 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -4944,6 +4944,10 @@ tf_module { name: "Xdivy" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "Xlog1py" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "Xlogy" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "