Add tf.math.ndtri and tf.math.erfinv.

PiperOrigin-RevId: 272571083
Srinivas Vasudevan 2019-10-02 19:14:14 -07:00 committed by TensorFlower Gardener
parent 7206485040
commit bbe62e94dd
19 changed files with 327 additions and 2 deletions


@@ -767,6 +767,40 @@ Status ErfGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Erf", ErfGrad);
Status ErfinvGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
auto grad = grad_inputs[0];
auto root_pi_over_two =
Cast(scope, Const(scope, std::sqrt(M_PI) / 2), grad.type());
Scope grad_scope = scope.WithControlDependencies(grad);
auto x = ConjugateHelper(grad_scope, op.input(0));
// grad * sqrt(pi) / 2 * exp(erfinv(x) ** 2)
auto dx = Mul(grad_scope, Mul(grad_scope, grad, root_pi_over_two),
Exp(grad_scope, Square(grad_scope, op.output(0))));
grad_outputs->push_back(dx);
return grad_scope.status();
}
REGISTER_GRADIENT_OP("Erfinv", ErfinvGrad);
Status NdtriGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
auto grad = grad_inputs[0];
auto root_two_pi =
Cast(scope, Const(scope, std::sqrt(2 * M_PI)), grad.type());
auto two = Cast(scope, Const(scope, 2), grad.type());
Scope grad_scope = scope.WithControlDependencies(grad);
auto x = ConjugateHelper(grad_scope, op.input(0));
// grad * sqrt(2 * pi) * exp(ndtri(x) ** 2 / 2)
auto dx = Mul(
grad_scope, Mul(grad_scope, grad, root_two_pi),
Exp(grad_scope, Div(grad_scope, Square(grad_scope, op.output(0)), two)));
grad_outputs->push_back(dx);
return grad_scope.status();
}
REGISTER_GRADIENT_OP("Ndtri", NdtriGrad);
Status LgammaGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
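Both new gradients above follow from implicit differentiation: erf(erfinv(x)) = x gives erfinv'(x) * erf'(erfinv(x)) = 1, and since erf'(y) = 2/sqrt(pi) * exp(-y^2), this yields erfinv'(x) = sqrt(pi)/2 * exp(erfinv(x)^2); the Ndtri case is the same argument applied to the standard normal CDF, giving sqrt(2*pi) * exp(ndtri(x)^2 / 2). Note that both implementations reuse the forward result op.output(0) rather than recomputing the inverse. A quick numerical sanity check of the two formulas (a sketch outside this change; it assumes SciPy, which ships both special functions):

import numpy as np
from scipy import special

eps = 1e-6

# d/dx erfinv(x) == sqrt(pi)/2 * exp(erfinv(x)**2) on (-1, 1)
x = np.linspace(-0.9, 0.9, 7)
fd = (special.erfinv(x + eps) - special.erfinv(x - eps)) / (2 * eps)
np.testing.assert_allclose(
    fd, np.sqrt(np.pi) / 2 * np.exp(special.erfinv(x) ** 2), rtol=1e-4)

# d/dp ndtri(p) == sqrt(2*pi) * exp(ndtri(p)**2 / 2) on (0, 1)
p = np.linspace(0.1, 0.9, 7)
fd = (special.ndtri(p + eps) - special.ndtri(p - eps)) / (2 * eps)
np.testing.assert_allclose(
    fd, np.sqrt(2 * np.pi) * np.exp(special.ndtri(p) ** 2 / 2), rtol=1e-4)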


@@ -89,7 +89,9 @@ class CWiseUnaryGradTest : public ::testing::Test {
COMPLEX,
ANGLE,
LGAMMA,
ERF
ERF,
ERFINV,
NDTRI
};
template <typename X_T, typename Y_T>
@@ -200,6 +202,12 @@ class CWiseUnaryGradTest : public ::testing::Test {
case ERF:
y = Erf(scope_, x);
break;
case ERFINV:
y = Erfinv(scope_, x);
break;
case NDTRI:
y = Ndtri(scope_, x);
break;
}
float max_error;
@@ -567,6 +575,20 @@ TEST_F(CWiseUnaryGradTest, Erf_Complex) {
}
}
TEST_F(CWiseUnaryGradTest, Ndtri) {
auto x_fn = [this](const int i) {
return RV({0.1, 0.2, 0.3, 0.5, 0.7, 0.9});
};
TestCWiseGrad<float, float>(NDTRI, x_fn);
}
TEST_F(CWiseUnaryGradTest, Erfinv) {
auto x_fn = [this](const int i) {
return RV({-0.9, -0.3, -0.1, 0.2, 0.6, 0.8});
};
TestCWiseGrad<float, float>(ERFINV, x_fn);
}
class MathGradTest : public ::testing::Test {
protected:
MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
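The two tests above keep their sample points strictly inside each function's domain: erfinv is defined on (-1, 1) and ndtri on (0, 1), and both gradients diverge at the endpoints. A minimal eager-mode sketch of what the tests exercise (assuming a TensorFlow build that includes this change):

import tensorflow as tf

x = tf.constant([0.1, 0.5, 0.9])
with tf.GradientTape() as tape:
  tape.watch(x)
  y = tf.math.ndtri(x)  # inverse CDF of the standard normal
dy_dx = tape.gradient(y, x)  # sqrt(2*pi) * exp(ndtri(x)**2 / 2), elementwise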


@@ -0,0 +1,3 @@
op {
graph_op_name: "Erfinv"
}


@@ -0,0 +1,3 @@
op {
graph_op_name: "Ndtri"
}


@@ -0,0 +1,32 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/cwise_ops_common.h"
#include "tensorflow/core/kernels/cwise_ops_gradients.h"
namespace tensorflow {
REGISTER2(UnaryOp, CPU, "Ndtri", functor::ndtri, float, double);
REGISTER2(UnaryOp, CPU, "Erfinv", functor::erfinv, float, double);
REGISTER2(SimpleBinaryOp, CPU, "NdtriGrad", functor::ndtri_grad, float, double);
REGISTER2(SimpleBinaryOp, CPU, "ErfinvGrad", functor::erfinv_grad, float,
double);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER2(UnaryOp, GPU, "Ndtri", functor::ndtri, float, double);
REGISTER2(UnaryOp, GPU, "Erfinv", functor::erfinv, float, double);
REGISTER2(SimpleBinaryOp, GPU, "NdtriGrad", functor::ndtri_grad, float, double);
REGISTER2(SimpleBinaryOp, GPU, "ErfinvGrad", functor::erfinv_grad, float,
double);
#endif
} // namespace tensorflow


@@ -0,0 +1,30 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
#include "tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h"
namespace tensorflow {
namespace functor {
DEFINE_UNARY1(ndtri, double);
DEFINE_UNARY1(erfinv, double);
DEFINE_SIMPLE_BINARY1(ndtri_grad, double);
DEFINE_SIMPLE_BINARY1(erfinv_grad, double);
} // namespace functor
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM


@@ -0,0 +1,30 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
#include "tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h"
namespace tensorflow {
namespace functor {
DEFINE_UNARY1(ndtri, float);
DEFINE_UNARY1(erfinv, float);
DEFINE_SIMPLE_BINARY1(ndtri_grad, float);
DEFINE_SIMPLE_BINARY1(erfinv_grad, float);
} // namespace functor
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM


@@ -728,6 +728,33 @@ struct functor_traits<xdivy_op<Scalar>> {
};
};
template <typename T>
struct scalar_erfinv_op {
EIGEN_EMPTY_STRUCT_CTOR(scalar_erfinv_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const {
T y = numext::ndtri((x + static_cast<T>(1.)) / static_cast<T>(2.));
return y / static_cast<T>(numext::sqrt(2.));
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
packetOp(const Packet& x) const {
Packet y = pndtri<Packet>(pmadd(pset1<Packet>(0.5), x, pset1<Packet>(0.5)));
return pdiv(y, psqrt(pset1<Packet>(2.)));
}
};
template <typename T>
struct functor_traits<scalar_erfinv_op<T>> {
enum {
Cost = functor_traits<scalar_ndtri_op<T>>::Cost + NumTraits<T>::AddCost,
#if TENSORFLOW_USE_ROCM
PacketAccess = false,
#else
PacketAccess = packet_traits<T>::HasNdtri,
#endif
};
};
} // end namespace internal
} // end namespace Eigen
@@ -873,6 +900,12 @@ struct erf : base<T, Eigen::internal::scalar_erf_op<T>> {};
template <typename T>
struct erfc : base<T, Eigen::internal::scalar_erfc_op<T>> {};
template <typename T>
struct ndtri : base<T, Eigen::internal::scalar_ndtri_op<T>> {};
template <typename T>
struct erfinv : base<T, Eigen::internal::scalar_erfinv_op<T>> {};
template <typename T>
struct sigmoid : base<T, Eigen::internal::scalar_logistic_op<T>> {};
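scalar_erfinv_op above does not evaluate erfinv directly: the standard normal CDF satisfies Phi(y) = (1 + erf(y / sqrt(2))) / 2, which inverts to erfinv(x) = ndtri((x + 1) / 2) / sqrt(2), and the pmadd(0.5, x, 0.5) in the packet path is exactly the (x + 1) / 2 term. A sketch verifying the identity (assumes SciPy; not part of this change):

import numpy as np
from scipy import special

x = np.linspace(-0.99, 0.99, 9)
np.testing.assert_allclose(
    special.erfinv(x),
    special.ndtri((x + 1.0) / 2.0) / np.sqrt(2.0),
    rtol=1e-10)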


@@ -174,6 +174,79 @@ struct functor_traits<scalar_rsqrt_gradient_op<T>> {
};
};
#define SQRT_PI 1.772453850905516027298167483341
#define SQRT_2PI 2.506628274631000502415765284811
// Gradient for the erfinv function
template <typename T>
struct scalar_erfinv_gradient_op {
EIGEN_EMPTY_STRUCT_CTOR(scalar_erfinv_gradient_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
operator()(const T& output, const T& output_gradient) const {
// grad * sqrt(pi) / 2 * exp(erfinv(x) ** 2)
if (output_gradient == T(0)) {
return T(0);
} else {
return static_cast<T>(0.5 * SQRT_PI) * numext::exp(output * output) *
output_gradient;
}
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
packetOp(const Packet& output, const Packet& output_gradient) const {
const Packet const_half_sqrt_pi =
pset1<Packet>(static_cast<T>(0.5 * SQRT_PI));
return pmul(pmul(const_half_sqrt_pi, pexp(pmul(output, output))),
output_gradient);
}
};
template <typename T>
struct functor_traits<scalar_erfinv_gradient_op<T>> {
enum {
#if TENSORFLOW_USE_ROCM
PacketAccess = false,
#else
PacketAccess = packet_traits<T>::HasNdtri,
#endif
Cost = 10 * NumTraits<T>::MulCost + scalar_div_cost<T, PacketAccess>::value,
};
};
// Gradient for the ndtri function
template <typename T>
struct scalar_ndtri_gradient_op {
EIGEN_EMPTY_STRUCT_CTOR(scalar_ndtri_gradient_op)
// grad * sqrt(2 * pi) * exp(ndtri(x) ** 2 / 2)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
operator()(const T& output, const T& output_gradient) const {
if (output_gradient == T(0)) {
return T(0);
} else {
return static_cast<T>(SQRT_2PI) * numext::exp(output * output / T(2)) *
output_gradient;
}
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
packetOp(const Packet& output, const Packet& output_gradient) const {
const Packet const_sqrt_two_pi = pset1<Packet>(static_cast<T>(SQRT_2PI));
return pmul(pmul(const_sqrt_two_pi,
pexp(pdiv(pmul(output, output), pset1<Packet>(2.)))),
output_gradient);
}
};
template <typename T>
struct functor_traits<scalar_ndtri_gradient_op<T>> {
enum {
#if TENSORFLOW_USE_ROCM
PacketAccess = false,
#else
PacketAccess = packet_traits<T>::HasNdtri,
#endif
Cost = 10 * NumTraits<T>::MulCost + scalar_div_cost<T, PacketAccess>::value,
};
};
} // end namespace internal
} // end namespace Eigen
@@ -234,6 +307,12 @@ struct rsqrt_grad : base<T, Eigen::internal::scalar_rsqrt_gradient_op<T>> {};
template <typename T>
struct igamma_grad_a : base<T, Eigen::internal::scalar_igamma_der_a_op<T>> {};
template <typename T>
struct erfinv_grad : base<T, Eigen::internal::scalar_erfinv_gradient_op<T>> {};
template <typename T>
struct ndtri_grad : base<T, Eigen::internal::scalar_ndtri_gradient_op<T>> {};
} // end namespace functor
} // end namespace tensorflow
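Both new gradient functors short-circuit when output_gradient is exactly zero: where the forward output is large, exp(output**2) overflows to inf, and without the guard a zero incoming gradient would multiply inf and produce NaN instead of 0 (the packet paths omit the check). A small illustration of the hazard (a sketch, not part of the change):

import numpy as np

g = np.float32(0.0)
local = np.exp(np.float32(100.0) ** 2 / 2)  # overflows to inf in float32
print(g * local)  # nan, not 0.0 -- hence the scalar-path guard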


@@ -281,7 +281,8 @@ REGISTER_OP("Lgamma").UNARY_REAL();
REGISTER_OP("Digamma").UNARY_REAL();
REGISTER_OP("Erf").UNARY_REAL();
REGISTER_OP("Erfinv").UNARY_REAL();
REGISTER_OP("Ndtri").UNARY_REAL();
REGISTER_OP("Erfc").UNARY_REAL();
REGISTER_OP("Sigmoid").UNARY_COMPLEX();


@@ -804,6 +804,24 @@ def _ErfcGrad(op, grad):
return grad * minus_two_over_root_pi * math_ops.exp(-math_ops.square(x))
@ops.RegisterGradient("Erfinv")
def _ErfinvGrad(op, grad):
"""Returns grad * sqrt(pi) / 2 * exp(erfinv(x)**2)."""
root_pi_over_two = constant_op.constant(np.sqrt(np.pi) / 2, dtype=grad.dtype)
with ops.control_dependencies([grad]):
return grad * root_pi_over_two * math_ops.exp(
math_ops.square(op.outputs[0]))
@ops.RegisterGradient("Ndtri")
def _NdtriGrad(op, grad):
"""Returns grad * sqrt(2 * pi) * exp(ndtri(x)**2 / 2)."""
root_two_pi = constant_op.constant(np.sqrt(2 * np.pi), dtype=grad.dtype)
with ops.control_dependencies([grad]):
return grad * root_two_pi * math_ops.exp(
math_ops.square(op.outputs[0]) / 2.)
@ops.RegisterGradient("Lgamma")
def _LgammaGrad(op, grad):
"""Returns grad * digamma(x)."""


@@ -99,6 +99,7 @@ class MathTest(PForTestCase, parameterized.TestCase):
math_ops.digamma,
math_ops.erf,
math_ops.erfc,
math_ops.erfinv,
math_ops.exp,
math_ops.expm1,
math_ops.inv,
@@ -107,6 +108,7 @@ class MathTest(PForTestCase, parameterized.TestCase):
math_ops.lgamma,
math_ops.log,
math_ops.log1p,
math_ops.ndtri,
]
self._test_unary_cwise_ops(real_ops, False)


@@ -2434,6 +2434,7 @@ def _convert_cast(pfor_input):
@RegisterPForWithArgs("Elu", nn_ops.elu)
@RegisterPForWithArgs("Erf", math_ops.erf)
@RegisterPForWithArgs("Erfc", math_ops.erfc)
@RegisterPForWithArgs("Erfinv", math_ops.erfinv)
@RegisterPForWithArgs("Exp", math_ops.exp)
@RegisterPForWithArgs("Expm1", math_ops.expm1)
@RegisterPForWithArgs("Floor", math_ops.floor)
@@ -2465,6 +2466,7 @@ def _convert_cast(pfor_input):
@RegisterPForWithArgs("Mod", math_ops.mod)
@RegisterPForWithArgs("Mul", math_ops.multiply)
@RegisterPForWithArgs("MulNoNan", math_ops.mul_no_nan)
@RegisterPForWithArgs("Ndtri", math_ops.ndtri)
@RegisterPForWithArgs("Neg", math_ops.negative)
@RegisterPForWithArgs("Polygamma", math_ops.polygamma)
@RegisterPForWithArgs("Pow", math_ops.pow)


@@ -303,6 +303,7 @@ _UNARY_ELEMENTWISE_OPS = [
math_ops.digamma,
math_ops.erf,
math_ops.erfc,
math_ops.erfinv,
math_ops.exp,
math_ops.expm1,
math_ops.floor,
@@ -315,6 +316,7 @@ _UNARY_ELEMENTWISE_OPS = [
math_ops.log1p,
math_ops.log_sigmoid,
math_ops.logical_not,
math_ops.ndtri,
math_ops.negative,
math_ops.real,
math_ops.reciprocal,


@@ -58,6 +58,7 @@ UNARY_FLOAT_OPS = [
math_ops.digamma,
math_ops.erf,
math_ops.erfc,
math_ops.erfinv,
math_ops.exp,
math_ops.expm1,
math_ops.floor,
@@ -69,6 +70,7 @@ UNARY_FLOAT_OPS = [
math_ops.log,
math_ops.log1p,
math_ops.log_sigmoid,
math_ops.ndtri,
math_ops.negative,
math_ops.real,
math_ops.reciprocal,


@@ -1236,6 +1236,10 @@ tf_module {
name: "erfc"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "erfinv"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "executing_eagerly"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
@@ -1712,6 +1716,10 @@ tf_module {
name: "multiply"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "ndtri"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "negative"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "


@@ -1184,6 +1184,10 @@ tf_module {
name: "Erfc"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "Erfinv"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "EuclideanNorm"
argspec: "args=[\'input\', \'axis\', \'keep_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
@@ -2348,6 +2352,10 @@ tf_module {
name: "NcclReduce"
argspec: "args=[\'input\', \'reduction\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "Ndtri"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "Neg"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "


@@ -612,6 +612,10 @@ tf_module {
name: "equal"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "erfinv"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "executing_eagerly"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
@@ -796,6 +800,10 @@ tf_module {
name: "multiply"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "ndtri"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "negative"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "


@@ -1184,6 +1184,10 @@ tf_module {
name: "Erfc"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "Erfinv"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "EuclideanNorm"
argspec: "args=[\'input\', \'axis\', \'keep_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
@@ -2348,6 +2352,10 @@ tf_module {
name: "NcclReduce"
argspec: "args=[\'input\', \'reduction\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "Ndtri"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "Neg"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "