Add Bessel functions to the public api:

- tf.math.special.bessel_i0
    - tf.math.special.bessel_i0e
    - tf.math.special.bessel_i1
    - tf.math.special.bessel_i1e
    - tf.math.special.bessel_k0
    - tf.math.special.bessel_k0e
    - tf.math.special.bessel_k1
    - tf.math.special.bessel_k1e
    - tf.math.special.bessel_j0
    - tf.math.special.bessel_j1
    - tf.math.special.bessel_y0
    - tf.math.special.bessel_y1

PiperOrigin-RevId: 317025879
Change-Id: I5c4407eda6bef0d1659b7a566979c7dbbad4ad83
This commit is contained in:
Srinivas Vasudevan 2020-06-17 21:01:24 -07:00 committed by TensorFlower Gardener
parent 56c01dc970
commit 4f341bb742
36 changed files with 1028 additions and 145 deletions

View File

@ -37,6 +37,7 @@ from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
# TODO(phawkins): provide wrappers for all XLA operators. Currently the missing
# ops include:
@ -103,8 +104,8 @@ sign = _unary_op(math_ops.sign)
tanh = _unary_op(math_ops.tanh)
# Bessel
bessel_i0e = _unary_op(math_ops.bessel_i0e)
bessel_i1e = _unary_op(math_ops.bessel_i1e)
bessel_i0e = _unary_op(special_math_ops.bessel_i0e)
bessel_i1e = _unary_op(special_math_ops.bessel_i1e)
# Binary operators

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "BesselI0"
visibility: HIDDEN
}

View File

@ -1,10 +1,4 @@
op {
graph_op_name: "BesselI0e"
summary: "Computes the Bessel i0e function of `x` element-wise."
description: <<END
Exponentially scaled modified Bessel function of order 0 defined as
`bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`.
This function is faster and numerically stabler than `bessel_i0(x)`.
END
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "BesselI1"
visibility: HIDDEN
}

View File

@ -1,10 +1,4 @@
op {
graph_op_name: "BesselI1e"
summary: "Computes the Bessel i1e function of `x` element-wise."
description: <<END
Exponentially scaled modified Bessel function of order 0 defined as
`bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`.
This function is faster and numerically stabler than `bessel_i1(x)`.
END
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "BesselJ0"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "BesselJ1"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "BesselK0"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "BesselK0e"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "BesselK1"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "BesselK1e"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "BesselY0"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "BesselY1"
visibility: HIDDEN
}

View File

@ -1,6 +0,0 @@
op {
graph_op_name: "BesselI0e"
endpoint {
name: "math.bessel_i0e"
}
}

View File

@ -1,6 +0,0 @@
op {
graph_op_name: "BesselI1e"
endpoint {
name: "math.bessel_i1e"
}
}

View File

@ -1,29 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
REGISTER3(UnaryOp, CPU, "BesselI0e", functor::bessel_i0e, Eigen::half, float,
double);
REGISTER3(UnaryOp, CPU, "BesselI1e", functor::bessel_i1e, Eigen::half, float,
double);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER3(UnaryOp, GPU, "BesselI0e", functor::bessel_i0e, Eigen::half, float,
double);
REGISTER3(UnaryOp, GPU, "BesselI1e", functor::bessel_i1e, Eigen::half, float,
double);
#endif
} // namespace tensorflow

View File

@ -943,12 +943,6 @@ struct acos : base<T, Eigen::internal::scalar_acos_op<T>> {};
template <typename T>
struct atan : base<T, Eigen::internal::scalar_atan_op<T>> {};
template <typename T>
struct bessel_i0e : base<T, Eigen::internal::scalar_bessel_i0e_op<T>> {};
template <typename T>
struct bessel_i1e : base<T, Eigen::internal::scalar_bessel_i1e_op<T>> {};
struct logical_not : base<bool, Eigen::internal::scalar_boolean_not_op<bool>> {
};

View File

@ -0,0 +1,78 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/cwise_ops_common.h"
#include "tensorflow/core/kernels/special_math/special_math_op_misc_impl.h"
namespace tensorflow {
REGISTER3(UnaryOp, CPU, "BesselI0", functor::bessel_i0, Eigen::half, float,
double);
REGISTER3(UnaryOp, CPU, "BesselI1", functor::bessel_i1, Eigen::half, float,
double);
REGISTER3(UnaryOp, CPU, "BesselI0e", functor::bessel_i0e, Eigen::half, float,
double);
REGISTER3(UnaryOp, CPU, "BesselI1e", functor::bessel_i1e, Eigen::half, float,
double);
REGISTER3(UnaryOp, CPU, "BesselK0", functor::bessel_k0, Eigen::half, float,
double);
REGISTER3(UnaryOp, CPU, "BesselK1", functor::bessel_k1, Eigen::half, float,
double);
REGISTER3(UnaryOp, CPU, "BesselK0e", functor::bessel_k0e, Eigen::half, float,
double);
REGISTER3(UnaryOp, CPU, "BesselK1e", functor::bessel_k1e, Eigen::half, float,
double);
REGISTER3(UnaryOp, CPU, "BesselJ0", functor::bessel_j0, Eigen::half, float,
double);
REGISTER3(UnaryOp, CPU, "BesselJ1", functor::bessel_j1, Eigen::half, float,
double);
REGISTER3(UnaryOp, CPU, "BesselY0", functor::bessel_y0, Eigen::half, float,
double);
REGISTER3(UnaryOp, CPU, "BesselY1", functor::bessel_y1, Eigen::half, float,
double);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER3(UnaryOp, GPU, "BesselI0", functor::bessel_i0, Eigen::half, float,
double);
REGISTER3(UnaryOp, GPU, "BesselI1", functor::bessel_i1, Eigen::half, float,
double);
REGISTER3(UnaryOp, GPU, "BesselI0e", functor::bessel_i0e, Eigen::half, float,
double);
REGISTER3(UnaryOp, GPU, "BesselI1e", functor::bessel_i1e, Eigen::half, float,
double);
REGISTER3(UnaryOp, GPU, "BesselK0", functor::bessel_k0, Eigen::half, float,
double);
REGISTER3(UnaryOp, GPU, "BesselK1", functor::bessel_k1, Eigen::half, float,
double);
REGISTER3(UnaryOp, GPU, "BesselK0e", functor::bessel_k0e, Eigen::half, float,
double);
REGISTER3(UnaryOp, GPU, "BesselK1e", functor::bessel_k1e, Eigen::half, float,
double);
REGISTER3(UnaryOp, GPU, "BesselJ0", functor::bessel_j0, Eigen::half, float,
double);
REGISTER3(UnaryOp, GPU, "BesselJ1", functor::bessel_j1, Eigen::half, float,
double);
REGISTER3(UnaryOp, GPU, "BesselY0", functor::bessel_y0, Eigen::half, float,
double);
REGISTER3(UnaryOp, GPU, "BesselY1", functor::bessel_y1, Eigen::half, float,
double);
#endif
} // namespace tensorflow

View File

@ -1,4 +1,4 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -16,11 +16,25 @@ limitations under the License.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
#include "tensorflow/core/kernels/special_math/special_math_op_misc_impl.h"
namespace tensorflow {
namespace functor {
DEFINE_UNARY3(bessel_i0, Eigen::half, float, double);
DEFINE_UNARY3(bessel_i1, Eigen::half, float, double);
DEFINE_UNARY3(bessel_i0e, Eigen::half, float, double);
DEFINE_UNARY3(bessel_i1e, Eigen::half, float, double);
DEFINE_UNARY3(bessel_k0, Eigen::half, float, double);
DEFINE_UNARY3(bessel_k1, Eigen::half, float, double);
DEFINE_UNARY3(bessel_k0e, Eigen::half, float, double);
DEFINE_UNARY3(bessel_k1e, Eigen::half, float, double);
DEFINE_UNARY3(bessel_j0, Eigen::half, float, double);
DEFINE_UNARY3(bessel_j1, Eigen::half, float, double);
DEFINE_UNARY3(bessel_y0, Eigen::half, float, double);
DEFINE_UNARY3(bessel_y1, Eigen::half, float, double);
} // namespace functor
} // namespace tensorflow

View File

@ -685,6 +685,44 @@ struct fresnel_sin : base<T, Eigen::internal::fresnel_sin_op<T>> {};
template <typename T>
struct spence : base<T, Eigen::internal::spence_op<T>> {};
// Bessel Functions
template <typename T>
struct bessel_i0 : base<T, Eigen::internal::scalar_bessel_i0_op<T>> {};
template <typename T>
struct bessel_i0e : base<T, Eigen::internal::scalar_bessel_i0e_op<T>> {};
template <typename T>
struct bessel_i1 : base<T, Eigen::internal::scalar_bessel_i1_op<T>> {};
template <typename T>
struct bessel_i1e : base<T, Eigen::internal::scalar_bessel_i1e_op<T>> {};
template <typename T>
struct bessel_k0 : base<T, Eigen::internal::scalar_bessel_k0_op<T>> {};
template <typename T>
struct bessel_k0e : base<T, Eigen::internal::scalar_bessel_k0e_op<T>> {};
template <typename T>
struct bessel_k1 : base<T, Eigen::internal::scalar_bessel_k1_op<T>> {};
template <typename T>
struct bessel_k1e : base<T, Eigen::internal::scalar_bessel_k1e_op<T>> {};
template <typename T>
struct bessel_j0 : base<T, Eigen::internal::scalar_bessel_j0_op<T>> {};
template <typename T>
struct bessel_j1 : base<T, Eigen::internal::scalar_bessel_j1_op<T>> {};
template <typename T>
struct bessel_y0 : base<T, Eigen::internal::scalar_bessel_y0_op<T>> {};
template <typename T>
struct bessel_y1 : base<T, Eigen::internal::scalar_bessel_y1_op<T>> {};
} // end namespace functor
} // end namespace tensorflow

View File

@ -297,10 +297,6 @@ REGISTER_OP("Acos").UNARY();
REGISTER_OP("Atan").UNARY();
REGISTER_OP("BesselI0e").UNARY_REAL();
REGISTER_OP("BesselI1e").UNARY_REAL();
REGISTER_OP("_UnaryOpsComposition")
.Input("x: T")
.Output("y: T")

View File

@ -20,34 +20,33 @@ limitations under the License.
namespace tensorflow {
REGISTER_OP("Dawsn")
.Input("x: T")
.Output("y: T")
.Attr("T: {float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
#define UNARY_REAL() \
Input("x: T") \
.Output("y: T") \
.Attr("T: {bfloat16, half, float, double}") \
.SetShapeFn(shape_inference::UnchangedShape)
REGISTER_OP("Expint")
.Input("x: T")
.Output("y: T")
.Attr("T: {float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("Dawsn").UNARY_REAL();
REGISTER_OP("Expint").UNARY_REAL();
REGISTER_OP("FresnelCos").UNARY_REAL();
REGISTER_OP("FresnelSin").UNARY_REAL();
REGISTER_OP("Spence").UNARY_REAL();
REGISTER_OP("FresnelCos")
.Input("x: T")
.Output("y: T")
.Attr("T: {float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
// Bessel functions
REGISTER_OP("FresnelSin")
.Input("x: T")
.Output("y: T")
.Attr("T: {float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("BesselI0").UNARY_REAL();
REGISTER_OP("BesselI1").UNARY_REAL();
REGISTER_OP("BesselI0e").UNARY_REAL();
REGISTER_OP("BesselI1e").UNARY_REAL();
REGISTER_OP("Spence")
.Input("x: T")
.Output("y: T")
.Attr("T: {float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("BesselK0").UNARY_REAL();
REGISTER_OP("BesselK1").UNARY_REAL();
REGISTER_OP("BesselK0e").UNARY_REAL();
REGISTER_OP("BesselK1e").UNARY_REAL();
REGISTER_OP("BesselJ0").UNARY_REAL();
REGISTER_OP("BesselJ1").UNARY_REAL();
REGISTER_OP("BesselY0").UNARY_REAL();
REGISTER_OP("BesselY1").UNARY_REAL();
} // namespace tensorflow

View File

@ -44,6 +44,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.platform import test
@ -90,8 +91,9 @@ def _unary_real_test_combinations():
("Asinh", math_ops.asinh),
("Atan", math_ops.atan),
("Atanh", math_ops.atanh),
("BesselI0e", math_ops.bessel_i0e),
("BesselI1e", math_ops.bessel_i1e),
# TODO(b/157272291): Add testing for more special functions.
("BesselI0e", special_math_ops.bessel_i0e),
("BesselI1e", special_math_ops.bessel_i1e),
("Ceil", math_ops.ceil),
("Cos", math_ops.cos),
("Cosh", math_ops.cosh),

View File

@ -411,7 +411,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
const tensorflow::string &op_name) {
static std::array<OpIndexInfo, 461> a = {{
static std::array<OpIndexInfo, 465> a = {{
{"Abs"},
{"AccumulateNV2"},
{"Acos"},
@ -443,6 +443,10 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
{"BatchNormWithGlobalNormalization"},
{"BatchToSpace"},
{"BatchToSpaceND"},
{"BesselI0"},
{"BesselJ0"},
{"BesselK0"},
{"BesselY0"},
{"Betainc"},
{"BiasAdd"},
{"BiasAddGrad"},

View File

@ -33,6 +33,7 @@ from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
from tensorflow.python.ops import special_math_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
@ -228,8 +229,8 @@ class UnaryOpTest(test.TestCase):
self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
try:
from scipy import special # pylint: disable=g-import-not-at-top
self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
self._compareBoth(x, special.i0e, special_math_ops.bessel_i0e)
self._compareBoth(x, special.i1e, special_math_ops.bessel_i1e)
except ImportError as e:
tf_logging.warn("Cannot test special functions: %s" % str(e))
@ -281,8 +282,8 @@ class UnaryOpTest(test.TestCase):
self._compareBoth(x, np.arctan, math_ops.atan)
try:
from scipy import special # pylint: disable=g-import-not-at-top
self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
self._compareBoth(x, special.i0e, special_math_ops.bessel_i0e)
self._compareBoth(x, special.i1e, special_math_ops.bessel_i1e)
except ImportError as e:
tf_logging.warn("Cannot test special functions: %s" % str(e))
@ -335,8 +336,8 @@ class UnaryOpTest(test.TestCase):
self._compareBoth(k, np.tan, math_ops.tan)
try:
from scipy import special # pylint: disable=g-import-not-at-top
self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
self._compareBoth(x, special.i0e, special_math_ops.bessel_i0e)
self._compareBoth(x, special.i1e, special_math_ops.bessel_i1e)
except ImportError as e:
tf_logging.warn("Cannot test special functions: %s" % str(e))
@ -375,13 +376,6 @@ class UnaryOpTest(test.TestCase):
math_ops.lgamma)
self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
try:
from scipy import special # pylint: disable=g-import-not-at-top
self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
except ImportError as e:
tf_logging.warn("Cannot test special functions: %s" % str(e))
self._compareBothSparse(x, np.abs, math_ops.abs)
self._compareBothSparse(x, np.negative, math_ops.negative)
self._compareBothSparse(x, np.square, math_ops.square)

View File

@ -29,6 +29,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops
def _safe_shape_div(x, y):
@ -875,16 +876,42 @@ def _SpenceGrad(op, grad):
return grad * partial_x
@ops.RegisterGradient("BesselI0")
def _BesselI0Grad(op, grad):
"""Compute gradient of bessel_i0(x) with respect to its argument."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
partial_x = special_math_ops.bessel_i1(x)
return grad * partial_x
@ops.RegisterGradient("BesselI0e")
def _BesselI0eGrad(op, grad):
"""Compute gradient of bessel_i0e(x) with respect to its argument."""
x = op.inputs[0]
y = op.outputs[0]
with ops.control_dependencies([grad]):
partial_x = (math_ops.bessel_i1e(x) - math_ops.sign(x) * y)
partial_x = (special_math_ops.bessel_i1e(x) - math_ops.sign(x) * y)
return grad * partial_x
@ops.RegisterGradient("BesselI1")
def _BesselI1Grad(op, grad):
"""Compute gradient of bessel_i1(x) with respect to its argument."""
x = op.inputs[0]
y = op.outputs[0]
with ops.control_dependencies([grad]):
# For x = 0, the correct gradient is 1.0.
# However, the main branch gives NaN because of the division by x, so
# we impute the gradient manually.
# An alternative solution is to express the gradient via bessel_i0 and
# bessel_i2, but the latter is not yet implemented in Eigen.
dy_dx = array_ops.where_v2(
math_ops.equal(x, 0.), math_ops.cast(1., x.dtype),
special_math_ops.bessel_i0(x) - math_ops.div(y, x))
return grad * dy_dx
@ops.RegisterGradient("BesselI1e")
def _BesselI1eGrad(op, grad):
"""Compute gradient of bessel_i1e(x) with respect to its argument."""
@ -896,16 +923,104 @@ def _BesselI1eGrad(op, grad):
# we impute the gradient manually.
# An alternative solution is to express the gradient via bessel_i0e and
# bessel_i2e, but the latter is not yet implemented in Eigen.
eps = np.finfo(x.dtype.as_numpy_dtype).eps
zeros = array_ops.zeros_like(x)
x_is_not_tiny = math_ops.abs(x) > eps
safe_x = array_ops.where_v2(x_is_not_tiny, x, eps + zeros)
dy_dx = math_ops.bessel_i0e(safe_x) - y * (
math_ops.sign(safe_x) + math_ops.reciprocal(safe_x))
dy_dx = array_ops.where_v2(x_is_not_tiny, dy_dx, 0.5 + zeros)
dy_dx = array_ops.where_v2(
math_ops.equal(x, 0.), math_ops.cast(0.5, x.dtype),
special_math_ops.bessel_i0e(x) - y *
(math_ops.sign(x) + math_ops.reciprocal(x)))
return grad * dy_dx
@ops.RegisterGradient("BesselK0")
def _BesselK0Grad(op, grad):
"""Compute gradient of bessel_k0(x) with respect to its argument."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
partial_x = -special_math_ops.bessel_k1(x)
return grad * partial_x
@ops.RegisterGradient("BesselK0e")
def _BesselK0eGrad(op, grad):
"""Compute gradient of bessel_k0e(x) with respect to its argument."""
x = op.inputs[0]
y = op.outputs[0]
with ops.control_dependencies([grad]):
partial_x = (y - special_math_ops.bessel_k1e(x))
return grad * partial_x
@ops.RegisterGradient("BesselK1")
def _BesselK1Grad(op, grad):
"""Compute gradient of bessel_k1(x) with respect to its argument."""
x = op.inputs[0]
y = op.outputs[0]
with ops.control_dependencies([grad]):
# At 0., this is NaN which is fine since the derivative is undefined
# at 0.
partial_x = -special_math_ops.bessel_k0(x) - math_ops.div(y, x)
return grad * partial_x
@ops.RegisterGradient("BesselK1e")
def _BesselK1eGrad(op, grad):
"""Compute gradient of bessel_k1e(x) with respect to its argument."""
x = op.inputs[0]
y = op.outputs[0]
with ops.control_dependencies([grad]):
# At 0., this is NaN which is fine since the derivative is undefined
# at 0.
partial_x = (
y * (1. - math_ops.reciprocal(x)) - special_math_ops.bessel_k0e(x))
return grad * partial_x
@ops.RegisterGradient("BesselJ0")
def _BesselJ0Grad(op, grad):
"""Compute gradient of bessel_j0(x) with respect to its argument."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
partial_x = -special_math_ops.bessel_j1(x)
return grad * partial_x
@ops.RegisterGradient("BesselJ1")
def _BesselJ1Grad(op, grad):
"""Compute gradient of bessel_j1(x) with respect to its argument."""
x = op.inputs[0]
y = op.outputs[0]
with ops.control_dependencies([grad]):
# For x = 0, the correct gradient is 0.5.
# However, the main branch gives NaN because of the division by x, so
# we impute the gradient manually.
# An alternative solution is to express the gradient via bessel_i0e and
# bessel_i2e, but the latter is not yet implemented in Eigen.
dy_dx = array_ops.where_v2(
math_ops.equal(x, 0.), math_ops.cast(0.5, x.dtype),
special_math_ops.bessel_j0(x) - math_ops.div(y, x))
return grad * dy_dx
@ops.RegisterGradient("BesselY0")
def _BesselY0Grad(op, grad):
"""Compute gradient of bessel_y0(x) with respect to its argument."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
partial_x = -special_math_ops.bessel_y1(x)
return grad * partial_x
@ops.RegisterGradient("BesselY1")
def _BesselY1Grad(op, grad):
"""Compute gradient of bessel_y1(x) with respect to its argument."""
x = op.inputs[0]
y = op.outputs[0]
with ops.control_dependencies([grad]):
# At 0., this is NaN which is fine since the derivative is undefined
# at 0.
partial_x = special_math_ops.bessel_y0(x) - math_ops.div(y, x)
return grad * partial_x
@ops.RegisterGradient("Igamma")
def _IgammaGrad(op, grad):
"""Returns gradient of igamma(a, x) with respect to a and x."""

View File

@ -90,8 +90,6 @@ class MathTest(PForTestCase, parameterized.TestCase):
math_ops.asinh,
math_ops.atan,
math_ops.atanh,
math_ops.bessel_i0e,
math_ops.bessel_i1e,
math_ops.cos,
math_ops.cosh,
math_ops.digamma,
@ -107,6 +105,8 @@ class MathTest(PForTestCase, parameterized.TestCase):
math_ops.log,
math_ops.log1p,
math_ops.ndtri,
special_math_ops.bessel_i0e,
special_math_ops.bessel_i1e,
]
self._test_unary_cwise_ops(real_ops, False)

View File

@ -2703,8 +2703,18 @@ def _convert_cast(pfor_input):
@RegisterPForWithArgs("Atan", math_ops.atan)
@RegisterPForWithArgs("Atan2", math_ops.atan2)
@RegisterPForWithArgs("Atanh", math_ops.atanh)
@RegisterPForWithArgs("BesselI0e", math_ops.bessel_i0e)
@RegisterPForWithArgs("BesselI1e", math_ops.bessel_i1e)
@RegisterPForWithArgs("BesselI0", special_math_ops.bessel_i0)
@RegisterPForWithArgs("BesselI1", special_math_ops.bessel_i1)
@RegisterPForWithArgs("BesselI0e", special_math_ops.bessel_i0e)
@RegisterPForWithArgs("BesselI1e", special_math_ops.bessel_i1e)
@RegisterPForWithArgs("BesselK0", special_math_ops.bessel_k0)
@RegisterPForWithArgs("BesselK1", special_math_ops.bessel_k1)
@RegisterPForWithArgs("BesselK0e", special_math_ops.bessel_k0e)
@RegisterPForWithArgs("BesselK1e", special_math_ops.bessel_k1e)
@RegisterPForWithArgs("BesselJ0", special_math_ops.bessel_j0)
@RegisterPForWithArgs("BesselJ1", special_math_ops.bessel_j1)
@RegisterPForWithArgs("BesselY0", special_math_ops.bessel_y0)
@RegisterPForWithArgs("BesselY1", special_math_ops.bessel_y1)
@RegisterPForWithArgs("BitwiseAnd", bitwise_ops.bitwise_and)
@RegisterPForWithArgs("BitwiseOr", bitwise_ops.bitwise_or)
@RegisterPForWithArgs("BitwiseXor", bitwise_ops.bitwise_xor)

View File

@ -27,6 +27,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@ -87,8 +88,8 @@ def kaiser_window(window_length, beta=12., dtype=dtypes.float32, name=None):
halflen_float = math_ops.cast(halflen_float, dtype=dtype)
num = beta * math_ops.sqrt(
one - math_ops.pow(arg, two) / math_ops.pow(halflen_float, two))
window = math_ops.exp(num - beta) * (math_ops.bessel_i0e(num) /
math_ops.bessel_i0e(beta))
window = math_ops.exp(num - beta) * (
special_math_ops.bessel_i0e(num) / special_math_ops.bessel_i0e(beta))
return window

View File

@ -39,6 +39,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_sparse_ops import *
@ -2928,8 +2929,9 @@ _UNARY_OPS = [
math_ops.sqrt,
math_ops.erf,
math_ops.tanh,
math_ops.bessel_i0e,
math_ops.bessel_i1e,
# TODO(b/157272291) Add dispatchers for rest of special functions.
special_math_ops.bessel_i0e,
special_math_ops.bessel_i1e,
]
for unary_op in _UNARY_OPS:
_UnaryMapValueDispatcher(unary_op).register(unary_op)

View File

@ -250,7 +250,7 @@ def spence(x, name=None):
return gen_special_math_ops.spence(x)
@tf_export('math.bessel_i0')
@tf_export('math.bessel_i0', 'math.special.bessel_i0')
@dispatch.add_dispatch_support
def bessel_i0(x, name=None):
"""Computes the Bessel i0 function of `x` element-wise.
@ -259,6 +259,9 @@ def bessel_i0(x, name=None):
It is preferable to use the numerically stabler function `i0e(x)` instead.
>>> tf.math.special.bessel_i0([-1., -0.5, 0.5, 1.]).numpy()
array([1.26606588, 1.06348337, 1.06348337, 1.26606588], dtype=float32)
Args:
x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
`float32`, `float64`.
@ -272,10 +275,36 @@ def bessel_i0(x, name=None):
@end_compatibility
"""
with ops.name_scope(name, 'bessel_i0', [x]):
return math_ops.exp(math_ops.abs(x)) * math_ops.bessel_i0e(x)
return gen_special_math_ops.bessel_i0(x)
@tf_export('math.bessel_i1')
@tf_export('math.bessel_i0e', 'math.special.bessel_i0e')
@dispatch.add_dispatch_support
def bessel_i0e(x, name=None):
"""Computes the Bessel i0e function of `x` element-wise.
Modified Bessel function of order 0.
>>> tf.math.special.bessel_i0e([-1., -0.5, 0.5, 1.]).numpy()
array([0.46575961, 0.64503527, 0.64503527, 0.46575961], dtype=float32)
Args:
x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
`float32`, `float64`.
name: A name for the operation (optional).
Returns:
A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
@compatibility(scipy)
Equivalent to scipy.special.i0e
@end_compatibility
"""
with ops.name_scope(name, 'bessel_i0e', [x]):
return gen_special_math_ops.bessel_i0e(x)
@tf_export('math.bessel_i1', 'math.special.bessel_i1')
@dispatch.add_dispatch_support
def bessel_i1(x, name=None):
"""Computes the Bessel i1 function of `x` element-wise.
@ -284,6 +313,9 @@ def bessel_i1(x, name=None):
It is preferable to use the numerically stabler function `i1e(x)` instead.
>>> tf.math.special.bessel_i1([-1., -0.5, 0.5, 1.]).numpy()
array([-0.5651591 , -0.25789431, 0.25789431, 0.5651591 ], dtype=float32)
Args:
x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
`float32`, `float64`.
@ -297,7 +329,245 @@ def bessel_i1(x, name=None):
@end_compatibility
"""
with ops.name_scope(name, 'bessel_i1', [x]):
return math_ops.exp(math_ops.abs(x)) * math_ops.bessel_i1e(x)
return gen_special_math_ops.bessel_i1(x)
@tf_export('math.bessel_i1e', 'math.special.bessel_i1e')
@dispatch.add_dispatch_support
def bessel_i1e(x, name=None):
"""Computes the Bessel i1e function of `x` element-wise.
Modified Bessel function of order 1.
>>> tf.math.special.bessel_i1e([-1., -0.5, 0.5, 1.]).numpy()
array([-0.20791042, -0.15642083, 0.15642083, 0.20791042], dtype=float32)
Args:
x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
`float32`, `float64`.
name: A name for the operation (optional).
Returns:
A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
@compatibility(scipy)
Equivalent to scipy.special.i1e
@end_compatibility
"""
with ops.name_scope(name, 'bessel_i1e', [x]):
return gen_special_math_ops.bessel_i1e(x)
@tf_export('math.special.bessel_k0')
@dispatch.add_dispatch_support
def bessel_k0(x, name=None):
"""Computes the Bessel k0 function of `x` element-wise.
Modified Bessel function of order 0.
It is preferable to use the numerically stabler function `k0e(x)` instead.
>>> tf.math.special.bessel_k0([0.5, 1., 2., 4.]).numpy()
array([0.92441907, 0.42102444, 0.11389387, 0.01115968], dtype=float32)
Args:
x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
`float32`, `float64`.
name: A name for the operation (optional).
Returns:
A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
@compatibility(scipy)
Equivalent to scipy.special.k0
@end_compatibility
"""
with ops.name_scope(name, 'bessel_k0', [x]):
return gen_special_math_ops.bessel_k0(x)
@tf_export('math.special.bessel_k0e')
@dispatch.add_dispatch_support
def bessel_k0e(x, name=None):
"""Computes the Bessel k0e function of `x` element-wise.
Modified Bessel function of order 0.
>>> tf.math.special.bessel_k0e([0.5, 1., 2., 4.]).numpy()
array([1.52410939, 1.14446308, 0.84156822, 0.60929767], dtype=float32)
Args:
x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
`float32`, `float64`.
name: A name for the operation (optional).
Returns:
A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
@compatibility(scipy)
Equivalent to scipy.special.k0e
@end_compatibility
"""
with ops.name_scope(name, 'bessel_k0e', [x]):
return gen_special_math_ops.bessel_k0e(x)
@tf_export('math.special.bessel_k1')
@dispatch.add_dispatch_support
def bessel_k1(x, name=None):
"""Computes the Bessel k1 function of `x` element-wise.
Modified Bessel function of order 1.
It is preferable to use the numerically stabler function `k1e(x)` instead.
>>> tf.math.special.bessel_k1([0.5, 1., 2., 4.]).numpy()
array([1.65644112, 0.60190723, 0.13986588, 0.0124835 ], dtype=float32)
Args:
x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
`float32`, `float64`.
name: A name for the operation (optional).
Returns:
A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
@compatibility(scipy)
Equivalent to scipy.special.k1
@end_compatibility
"""
with ops.name_scope(name, 'bessel_k1', [x]):
return gen_special_math_ops.bessel_k1(x)
@tf_export('math.special.bessel_k1e')
@dispatch.add_dispatch_support
def bessel_k1e(x, name=None):
"""Computes the Bessel k1e function of `x` element-wise.
Modified Bessel function of order 1.
>>> tf.math.special.bessel_k1e([0.5, 1., 2., 4.]).numpy()
array([2.73100971, 1.63615349, 1.03347685, 0.68157595], dtype=float32)
Args:
x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
`float32`, `float64`.
name: A name for the operation (optional).
Returns:
A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
@compatibility(scipy)
Equivalent to scipy.special.k1e
@end_compatibility
"""
with ops.name_scope(name, 'bessel_k1e', [x]):
return gen_special_math_ops.bessel_k1e(x)
@tf_export('math.special.bessel_j0')
@dispatch.add_dispatch_support
def bessel_j0(x, name=None):
"""Computes the Bessel j0 function of `x` element-wise.
Modified Bessel function of order 0.
>>> tf.math.special.bessel_j0([0.5, 1., 2., 4.]).numpy()
array([ 0.93846981, 0.76519769, 0.22389078, -0.39714981], dtype=float32)
Args:
x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
`float32`, `float64`.
name: A name for the operation (optional).
Returns:
A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
@compatibility(scipy)
Equivalent to scipy.special.j0
@end_compatibility
"""
with ops.name_scope(name, 'bessel_j0', [x]):
return gen_special_math_ops.bessel_j0(x)
@tf_export('math.special.bessel_j1')
@dispatch.add_dispatch_support
def bessel_j1(x, name=None):
"""Computes the Bessel j1 function of `x` element-wise.
Modified Bessel function of order 1.
>>> tf.math.special.bessel_j1([0.5, 1., 2., 4.]).numpy()
array([ 0.24226846, 0.44005059, 0.57672481, -0.06604333], dtype=float32)
Args:
x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
`float32`, `float64`.
name: A name for the operation (optional).
Returns:
A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
@compatibility(scipy)
Equivalent to scipy.special.j1
@end_compatibility
"""
with ops.name_scope(name, 'bessel_j1', [x]):
return gen_special_math_ops.bessel_j1(x)
@tf_export('math.special.bessel_y0')
@dispatch.add_dispatch_support
def bessel_y0(x, name=None):
"""Computes the Bessel y0 function of `x` element-wise.
Modified Bessel function of order 0.
>>> tf.math.special.bessel_y0([0.5, 1., 2., 4.]).numpy()
array([-0.44451873, 0.08825696, 0.51037567, -0.01694074], dtype=float32)
Args:
x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
`float32`, `float64`.
name: A name for the operation (optional).
Returns:
A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
@compatibility(scipy)
Equivalent to scipy.special.y0
@end_compatibility
"""
with ops.name_scope(name, 'bessel_y0', [x]):
return gen_special_math_ops.bessel_y0(x)
@tf_export('math.special.bessel_y1')
@dispatch.add_dispatch_support
def bessel_y1(x, name=None):
"""Computes the Bessel y1 function of `x` element-wise.
Modified Bessel function of order 1.
>>> tf.math.special.bessel_y1([0.5, 1., 2., 4.]).numpy()
array([-1.47147239, -0.78121282, -0.10703243, 0.39792571], dtype=float32)
Args:
x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
`float32`, `float64`.
name: A name for the operation (optional).
Returns:
A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
@compatibility(scipy)
Equivalent to scipy.special.y1
@end_compatibility
"""
with ops.name_scope(name, 'bessel_y1', [x]):
return gen_special_math_ops.bessel_y1(x)
@ops.RegisterGradient('XlaEinsum')

View File

@ -403,34 +403,236 @@ class SpenceTest(test.TestCase, parameterized.TestCase):
self.assertAllClose([[[-1.]]], analytical)
class BesselTest(test.TestCase):
@test_util.run_all_in_graph_and_eager_modes
class BesselTest(test.TestCase, parameterized.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_bessel_i0(self):
x_single = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32)
x_double = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64)
def test_besseli_boundary(self):
self.assertAllClose(1., special_math_ops.bessel_i0(0.))
self.assertAllClose(1., special_math_ops.bessel_i0e(0.))
self.assertAllClose(0., special_math_ops.bessel_i1(0.))
self.assertAllClose(0., special_math_ops.bessel_i1e(0.))
self.assertTrue(np.isnan(self.evaluate(special_math_ops.bessel_i0(np.nan))))
self.assertTrue(
np.isnan(self.evaluate(special_math_ops.bessel_i0e(np.nan))))
self.assertTrue(np.isnan(self.evaluate(special_math_ops.bessel_i1(np.nan))))
self.assertTrue(
np.isnan(self.evaluate(special_math_ops.bessel_i1e(np.nan))))
@test_util.run_in_graph_and_eager_modes
def test_besselj_boundary(self):
self.assertAllClose(1., special_math_ops.bessel_j0(0.))
self.assertAllClose(0., special_math_ops.bessel_j1(0.))
self.assertTrue(np.isnan(self.evaluate(special_math_ops.bessel_j0(np.nan))))
self.assertTrue(np.isnan(self.evaluate(special_math_ops.bessel_j1(np.nan))))
@test_util.run_in_graph_and_eager_modes
def test_besselk_boundary(self):
self.assertTrue(np.isinf(self.evaluate(special_math_ops.bessel_k0(0.))))
self.assertTrue(np.isinf(self.evaluate(special_math_ops.bessel_k0e(0.))))
self.assertTrue(np.isinf(self.evaluate(special_math_ops.bessel_k1(0.))))
self.assertTrue(np.isinf(self.evaluate(special_math_ops.bessel_k1e(0.))))
self.assertTrue(np.isnan(self.evaluate(special_math_ops.bessel_k0(np.nan))))
self.assertTrue(
np.isnan(self.evaluate(special_math_ops.bessel_k0e(np.nan))))
self.assertTrue(np.isnan(self.evaluate(special_math_ops.bessel_k1(np.nan))))
self.assertTrue(
np.isnan(self.evaluate(special_math_ops.bessel_k1e(np.nan))))
@parameterized.parameters(np.float32, np.float64)
def test_i0j0_even(self, dtype):
x = np.random.uniform(-100., 100., size=int(1e4)).astype(dtype)
self.assertAllClose(
self.evaluate(special_math_ops.bessel_i0(x)),
self.evaluate(special_math_ops.bessel_i0(-x)))
self.assertAllClose(
self.evaluate(special_math_ops.bessel_i0e(x)),
self.evaluate(special_math_ops.bessel_i0e(-x)))
self.assertAllClose(
self.evaluate(special_math_ops.bessel_j0(x)),
self.evaluate(special_math_ops.bessel_j0(-x)))
@parameterized.parameters(np.float32, np.float64)
def test_i1j1_odd(self, dtype):
x = np.random.uniform(-100., 100., size=int(1e4)).astype(dtype)
self.assertAllClose(
self.evaluate(special_math_ops.bessel_i1(x)),
self.evaluate(-special_math_ops.bessel_i1(-x)))
self.assertAllClose(
self.evaluate(special_math_ops.bessel_i1e(x)),
self.evaluate(-special_math_ops.bessel_i1e(-x)))
self.assertAllClose(
self.evaluate(special_math_ops.bessel_j1(x)),
self.evaluate(-special_math_ops.bessel_j1(-x)))
@parameterized.parameters(np.float32, np.float64)
def test_besseli_small(self, dtype):
x = np.random.uniform(-1., 1., size=int(1e4)).astype(dtype)
try:
from scipy import special # pylint: disable=g-import-not-at-top
self.assertAllClose(special.i0(x_single),
self.evaluate(special_math_ops.bessel_i0(x_single)))
self.assertAllClose(special.i0(x_double),
self.evaluate(special_math_ops.bessel_i0(x_double)))
self.assertAllClose(
special.i0(x), self.evaluate(special_math_ops.bessel_i0(x)))
self.assertAllClose(
special.i1(x), self.evaluate(special_math_ops.bessel_i1(x)))
self.assertAllClose(
special.i0e(x), self.evaluate(special_math_ops.bessel_i0e(x)))
self.assertAllClose(
special.i1e(x), self.evaluate(special_math_ops.bessel_i1e(x)))
except ImportError as e:
tf_logging.warn('Cannot test special functions: %s' % str(e))
@test_util.run_in_graph_and_eager_modes
def test_bessel_i1(self):
x_single = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32)
x_double = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64)
@parameterized.parameters(np.float32, np.float64)
def test_besselj_small(self, dtype):
x = np.random.uniform(-1., 1., size=int(1e4)).astype(dtype)
try:
from scipy import special # pylint: disable=g-import-not-at-top
self.assertAllClose(special.i1(x_single),
self.evaluate(special_math_ops.bessel_i1(x_single)))
self.assertAllClose(special.i1(x_double),
self.evaluate(special_math_ops.bessel_i1(x_double)))
self.assertAllClose(
special.j0(x), self.evaluate(special_math_ops.bessel_j0(x)))
self.assertAllClose(
special.j1(x), self.evaluate(special_math_ops.bessel_j1(x)))
except ImportError as e:
tf_logging.warn('Cannot test special functions: %s' % str(e))
@parameterized.parameters(np.float32, np.float64)
def test_besselk_small(self, dtype):
x = np.random.uniform(np.finfo(dtype).eps, 1., size=int(1e4)).astype(dtype)
try:
from scipy import special # pylint: disable=g-import-not-at-top
self.assertAllClose(
special.k0(x), self.evaluate(special_math_ops.bessel_k0(x)))
self.assertAllClose(
special.k0e(x), self.evaluate(special_math_ops.bessel_k0e(x)))
self.assertAllClose(
special.k1(x), self.evaluate(special_math_ops.bessel_k1(x)))
self.assertAllClose(
special.k1e(x), self.evaluate(special_math_ops.bessel_k1e(x)))
except ImportError as e:
tf_logging.warn('Cannot test special functions: %s' % str(e))
@parameterized.parameters(np.float32, np.float64)
def test_bessely_small(self, dtype):
x = np.random.uniform(np.finfo(dtype).eps, 1., size=int(1e4)).astype(dtype)
try:
from scipy import special # pylint: disable=g-import-not-at-top
self.assertAllClose(
special.y0(x), self.evaluate(special_math_ops.bessel_y0(x)))
self.assertAllClose(
special.y1(x), self.evaluate(special_math_ops.bessel_y1(x)))
except ImportError as e:
tf_logging.warn('Cannot test special functions: %s' % str(e))
@parameterized.parameters(np.float32, np.float64)
def test_besseli_larger(self, dtype):
x = np.random.uniform(1., 20., size=int(1e4)).astype(dtype)
try:
from scipy import special # pylint: disable=g-import-not-at-top
self.assertAllClose(
special.i0e(x), self.evaluate(special_math_ops.bessel_i0e(x)))
self.assertAllClose(
special.i1e(x), self.evaluate(special_math_ops.bessel_i1e(x)))
except ImportError as e:
tf_logging.warn('Cannot test special functions: %s' % str(e))
@parameterized.parameters(np.float32, np.float64)
def test_besselj_larger(self, dtype):
x = np.random.uniform(1., 30., size=int(1e4)).astype(dtype)
try:
from scipy import special # pylint: disable=g-import-not-at-top
self.assertAllClose(
special.j0(x), self.evaluate(special_math_ops.bessel_j0(x)))
self.assertAllClose(
special.j1(x), self.evaluate(special_math_ops.bessel_j1(x)))
except ImportError as e:
tf_logging.warn('Cannot test special functions: %s' % str(e))
@parameterized.parameters(np.float32, np.float64)
def test_besselk_larger(self, dtype):
x = np.random.uniform(1., 30., size=int(1e4)).astype(dtype)
try:
from scipy import special # pylint: disable=g-import-not-at-top
self.assertAllClose(
special.k0(x), self.evaluate(special_math_ops.bessel_k0(x)))
self.assertAllClose(
special.k0e(x), self.evaluate(special_math_ops.bessel_k0e(x)))
self.assertAllClose(
special.k1(x), self.evaluate(special_math_ops.bessel_k1(x)))
self.assertAllClose(
special.k1e(x), self.evaluate(special_math_ops.bessel_k1e(x)))
except ImportError as e:
tf_logging.warn('Cannot test special functions: %s' % str(e))
@parameterized.parameters(np.float32, np.float64)
def test_bessely_larger(self, dtype):
x = np.random.uniform(1., 30., size=int(1e4)).astype(dtype)
try:
from scipy import special # pylint: disable=g-import-not-at-top
self.assertAllClose(
special.y0(x), self.evaluate(special_math_ops.bessel_y0(x)))
self.assertAllClose(
special.y1(x), self.evaluate(special_math_ops.bessel_y1(x)))
except ImportError as e:
tf_logging.warn('Cannot test special functions: %s' % str(e))
def test_besseli_gradient(self):
inputs = [np.random.uniform(-10., 10., size=int(1e2))]
analytical, numerical = gradient_checker_v2.compute_gradient(
special_math_ops.bessel_i0, inputs)
self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-3)
analytical, numerical = gradient_checker_v2.compute_gradient(
special_math_ops.bessel_i0e, inputs)
self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)
analytical, numerical = gradient_checker_v2.compute_gradient(
special_math_ops.bessel_i1, inputs)
self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-3)
analytical, numerical = gradient_checker_v2.compute_gradient(
special_math_ops.bessel_i1e, inputs)
self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)
def test_besselj_gradient(self):
inputs = [np.random.uniform(-50., 50., size=int(1e2))]
analytical, numerical = gradient_checker_v2.compute_gradient(
special_math_ops.bessel_j0, inputs)
self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)
analytical, numerical = gradient_checker_v2.compute_gradient(
special_math_ops.bessel_j1, inputs)
self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)
def test_besselk_gradient(self):
inputs = [np.random.uniform(1., 50., size=int(1e2))]
analytical, numerical = gradient_checker_v2.compute_gradient(
special_math_ops.bessel_k0, inputs)
self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)
analytical, numerical = gradient_checker_v2.compute_gradient(
special_math_ops.bessel_k0e, inputs)
self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)
analytical, numerical = gradient_checker_v2.compute_gradient(
special_math_ops.bessel_k1, inputs)
self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)
analytical, numerical = gradient_checker_v2.compute_gradient(
special_math_ops.bessel_k1e, inputs)
self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)
def test_bessely_gradient(self):
inputs = [np.random.uniform(1., 50., size=int(1e2))]
analytical, numerical = gradient_checker_v2.compute_gradient(
special_math_ops.bessel_y0, inputs)
self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)
analytical, numerical = gradient_checker_v2.compute_gradient(
special_math_ops.bessel_y1, inputs)
self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)
@test_util.run_all_in_graph_and_eager_modes
class EinsumTest(test.TestCase):

View File

@ -1,5 +1,53 @@
path: "tensorflow.math.special"
tf_module {
member_method {
name: "bessel_i0"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "bessel_i0e"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "bessel_i1"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "bessel_i1e"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "bessel_j0"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "bessel_j1"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "bessel_k0"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "bessel_k0e"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "bessel_k1"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "bessel_k1e"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "bessel_y0"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "bessel_y1"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "dawsn"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -428,14 +428,54 @@ tf_module {
name: "BatchToSpaceND"
argspec: "args=[\'input\', \'block_shape\', \'crops\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BesselI0"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BesselI0e"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BesselI1"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BesselI1e"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BesselJ0"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BesselJ1"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BesselK0"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BesselK0e"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BesselK1"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BesselK1e"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BesselY0"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BesselY1"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "Betainc"
argspec: "args=[\'a\', \'b\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -1,5 +1,53 @@
path: "tensorflow.math.special"
tf_module {
member_method {
name: "bessel_i0"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "bessel_i0e"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "bessel_i1"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "bessel_i1e"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "bessel_j0"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "bessel_j1"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "bessel_k0"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "bessel_k0e"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "bessel_k1"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "bessel_k1e"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "bessel_y0"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "bessel_y1"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "dawsn"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -428,14 +428,54 @@ tf_module {
name: "BatchToSpaceND"
argspec: "args=[\'input\', \'block_shape\', \'crops\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BesselI0"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BesselI0e"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BesselI1"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BesselI1e"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BesselJ0"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BesselJ1"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BesselK0"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BesselK0e"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BesselK1"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BesselK1e"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BesselY0"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BesselY1"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "Betainc"
argspec: "args=[\'a\', \'b\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "