[XLA] Support parameterized truncated normal.

Add ParameterizedTruncatedNormal to the XLA client library, and uses it to
implement the standard version of TruncatedNormal.

Add XlaOpKernel for ParameterizedTruncatedNormal.

Add compiler test for parameterized truncated normal.

PiperOrigin-RevId: 258860922
This commit is contained in:
Bixia Zheng 2019-07-18 15:47:53 -07:00 committed by TensorFlower Gardener
parent 1bacd3b540
commit b1a6b315a6
4 changed files with 171 additions and 63 deletions
tensorflow/compiler

View File

@ -40,7 +40,7 @@ class RandomOpsTest(xla_test.XLATestCase):
def _testRngIsNotConstant(self, rng, dtype):
# Tests that 'rng' does not always return the same value.
with self.session() as sess:
with self.session():
with self.test_scope():
x = rng(dtype)
@ -103,7 +103,7 @@ class RandomOpsTest(xla_test.XLATestCase):
if (self.device in ["XLA_GPU", "XLA_CPU"
]) and (dtype in [dtypes.bfloat16, dtypes.half]):
continue
with self.session() as sess:
with self.session():
with self.test_scope():
x = random_ops.random_uniform(
shape=[1000], dtype=dtype, minval=-2, maxval=33)
@ -116,60 +116,97 @@ class RandomOpsTest(xla_test.XLATestCase):
def rng(dtype):
return random_ops.truncated_normal(shape=[2], dtype=dtype)
# TODO(b/34339814): make this test work with 16 bit float types.
# TODO(b/34339814): make this test work with 16 bit float types.
for dtype in self._random_types() & {np.float32, np.float64}:
self._testRngIsNotConstant(rng, dtype)
def _checkTruncatedNormalIsInRange(self, x, a, b, mu, sigma, count,
stat_test):
def normal_cdf(x):
return .5 * math.erfc(-x / math.sqrt(2))
def normal_pdf(x):
return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi)
def probit(x):
return self.evaluate(special_math.ndtri(x))
y = self.evaluate(x)
alpha = (a - mu) / sigma
beta = (b - mu) / sigma
z = normal_cdf(beta) - normal_cdf(alpha)
self.assertEqual((y >= a).sum(), count)
self.assertEqual((y <= b).sum(), count)
# Skip statistical test for low probability regions.
if not stat_test:
return
# For more information on these calculations, see:
# Burkardt, John. "The Truncated Normal Distribution".
# Department of Scientific Computing website. Florida State University.
expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma
actual_mean = np.mean(y)
self.assertAllClose(actual_mean, expected_mean, atol=2e-3, rtol=2e-3)
expected_median = mu + probit(
(normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma
actual_median = np.median(y)
self.assertAllClose(actual_median, expected_median, atol=1e-2)
expected_variance = sigma**2 * (1 + (
(alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - (
(normal_pdf(alpha) - normal_pdf(beta)) / z)**2)
actual_variance = np.var(y)
self.assertAllClose(
actual_variance, expected_variance, atol=2e-3, rtol=2e-3)
def testTruncatedNormalIsInRange(self):
count = 10000000
# TODO(b/34339814): make this test work with 16 bit float types.
for dtype in self._random_types() & {np.float32, np.float64}:
with self.session() as sess:
with self.session():
with self.test_scope():
x = random_ops.truncated_normal(shape=[count], dtype=dtype)
y = self.evaluate(x)
self._checkTruncatedNormalIsInRange(
x, a=-2, b=2, mu=0, sigma=1, count=count, stat_test=True)
def normal_cdf(x):
return .5 * math.erfc(-x / math.sqrt(2))
def _implParameterizedTruncatedNormalIsInRange(self, a, b, mu, sigma, count,
stat_test):
# TODO(b/34339814): make this test work with 16 bit float types.
for dtype in self._random_types() & {np.float32, np.float64}:
with self.session():
with self.test_scope():
x = random_ops.parameterized_truncated_normal(
shape=[count],
dtype=dtype,
means=mu,
stddevs=sigma,
minvals=a,
maxvals=b)
self._checkTruncatedNormalIsInRange(
x, a=a, b=b, mu=mu, sigma=sigma, count=count, stat_test=stat_test)
def normal_pdf(x):
return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi)
def probit(x, sess=sess):
return self.evaluate(special_math.ndtri(x))
a = -2.
b = 2.
mu = 0.
sigma = 1.
alpha = (a - mu) / sigma
beta = (b - mu) / sigma
z = normal_cdf(beta) - normal_cdf(alpha)
self.assertEqual((y >= a).sum(), count)
self.assertEqual((y <= b).sum(), count)
# For more information on these calculations, see:
# Burkardt, John. "The Truncated Normal Distribution".
# Department of Scientific Computing website. Florida State University.
expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma
actual_mean = np.mean(y)
self.assertAllClose(actual_mean, expected_mean, atol=2e-3)
expected_median = mu + probit(
(normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma
actual_median = np.median(y)
self.assertAllClose(actual_median, expected_median, atol=1e-2)
expected_variance = sigma**2 * (1 + (
(alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - (
(normal_pdf(alpha) - normal_pdf(beta)) / z)**2)
actual_variance = np.var(y)
self.assertAllClose(actual_variance, expected_variance, rtol=2*1e-3)
def testParameterizedTruncatedNormalIsInRange(self):
count = 10000000
self._implParameterizedTruncatedNormalIsInRange(
a=-10, b=20, mu=5, sigma=5, count=count, stat_test=True)
# the region is on the left side of the parent normal distribution
self._implParameterizedTruncatedNormalIsInRange(
a=-10, b=-4, mu=0, sigma=1, count=count, stat_test=False)
self._implParameterizedTruncatedNormalIsInRange(
a=-np.infty, b=-4, mu=0, sigma=1, count=count, stat_test=False)
# the region is on the right side of the parent normal distribution
self._implParameterizedTruncatedNormalIsInRange(
a=4, b=10, mu=0, sigma=1, count=count, stat_test=False)
self._implParameterizedTruncatedNormalIsInRange(
a=4, b=np.infty, mu=0, sigma=1, count=count, stat_test=False)
def testShuffle1d(self):
with self.session() as sess:
with self.session():
with self.test_scope():
x = math_ops.range(1 << 16)
shuffle = random_ops.random_shuffle(x)
@ -180,7 +217,7 @@ class RandomOpsTest(xla_test.XLATestCase):
self.assertAllEqual(set(result), set(expected))
def testShuffle2d(self):
with self.session() as sess:
with self.session():
with self.test_scope():
x = array_ops.diag(math_ops.range(20))
shuffle = random_ops.random_shuffle(x)

View File

@ -296,5 +296,40 @@ REGISTER_XLA_OP(Name("TruncatedNormal")
.TypeConstraint("dtype", {DT_FLOAT, DT_DOUBLE}),
TruncatedNormalOp);
class ParameterizedTruncatedNormalOp : public XlaOpKernel {
public:
explicit ParameterizedTruncatedNormalOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
const DataType dtype = output_type(0);
TensorShape shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape));
xla::XlaBuilder* b = ctx->builder();
xla::XlaOp one = xla::One(b, xla_shape.element_type());
xla::XlaOp min_positive =
xla::MinPositiveNormalValue(b, xla_shape.element_type());
xla::XlaOp uniform = xla::RngUniform(min_positive, one, xla_shape);
xla::XlaOp means = ctx->Input(1);
xla::XlaOp stddevs = ctx->Input(2);
xla::XlaOp minvals = ctx->Input(3);
xla::XlaOp maxvals = ctx->Input(4);
ctx->SetOutput(0, ParameterizedTruncatedNormal(uniform, means, stddevs,
minvals, maxvals));
}
};
REGISTER_XLA_OP(Name("ParameterizedTruncatedNormal")
.CompileTimeConstantInput("shape")
.TypeConstraint("dtype", {DT_FLOAT, DT_DOUBLE}),
ParameterizedTruncatedNormalOp);
} // namespace
} // namespace tensorflow

View File

@ -27,29 +27,58 @@ limitations under the License.
namespace tensorflow {
xla::XlaOp TruncatedNormal(xla::XlaOp uniform) {
auto normal_cdf = [](double x) {
return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
};
const double kA = -2.0;
const double kB = 2.0;
const double kMu = 0.0;
const double kSigma = 1.0;
const double kAlpha = (kA - kMu) / kSigma;
const double kBeta = (kB - kMu) / kSigma;
const double kAlphaNormalCdf = normal_cdf(kAlpha);
const double kBetaNormalCdf = normal_cdf(kBeta);
const double kZ = kBetaNormalCdf - kAlphaNormalCdf;
return ParameterizedTruncatedNormal(
uniform, xla::ScalarLike(uniform, kMu), xla::ScalarLike(uniform, kSigma),
xla::ScalarLike(uniform, kA), xla::ScalarLike(uniform, kB));
}
// Implements the sampling of truncated normal distribution using the
// inversed cumulative distribution function (CDF) method as described in
// https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf.
xla::XlaOp ParameterizedTruncatedNormal(xla::XlaOp uniform, xla::XlaOp mu,
xla::XlaOp sigma, xla::XlaOp a,
xla::XlaOp b) {
xla::XlaOp one = xla::ScalarLike(uniform, 1.0);
xla::XlaOp two = xla::ScalarLike(uniform, 2.0);
xla::XlaOp sqrt_2 = xla::ScalarLike(uniform, std::sqrt(2.0));
xla::XlaOp z = xla::ScalarLike(uniform, kZ);
xla::XlaOp alpha_normal_cdf = xla::ScalarLike(uniform, kAlphaNormalCdf);
auto p = alpha_normal_cdf + z * uniform;
// probit(p) = sqrt(2) * erfinv(2*p-1)
return sqrt_2 * xla::ErfInv(two * p - one);
auto normal_cdf = [&](xla::XlaOp x) {
return (one + xla::Erf(x / sqrt_2)) / two;
};
// Calculate the cumulative probabilities for the lower and upper bound, a and
// b.
xla::XlaOp alpha = (a - mu) / sigma;
xla::XlaOp beta = (b - mu) / sigma;
xla::XlaOp alpha_normal_cdf = normal_cdf(alpha);
xla::XlaOp beta_normal_cdf = normal_cdf(beta);
// Convert the random uniform value in range (0, 1) (uniform) to a value in
// range (alpha_normal_cdf, beta_normal_cdf) that represents the cumulative
// probability (p) of a value (x) in the truncated normal distribution.
xla::XlaOp p =
alpha_normal_cdf + (beta_normal_cdf - alpha_normal_cdf) * uniform;
// Calculate x using the inversed cumulative distribution function:
// x = inversed_cdf(mu, sigma; p) = mu + sigma * sqrt(2) * erfinv(2*p-1)
// Clamp the input of erfinv to (-1, 1) because 2*p-1 may produce +/-1 due to
// computation precision.
xla::XlaOp v = two * p - one;
xla::PrimitiveType primitive_type =
uniform.builder()->GetShape(uniform).ConsumeValueOrDie().element_type();
xla::XlaOp epsilon = xla::Epsilon(uniform.builder(), primitive_type);
v = xla::Clamp(-one + epsilon, v, one - epsilon);
xla::XlaOp x = mu + sigma * sqrt_2 * xla::ErfInv(v);
// The value of x may be out of the range of (a, b), this typically happens
// when the region of the truncated normal has a very small probability.
x = xla::Clamp(a, x, b);
return x;
}
} // namespace tensorflow

View File

@ -22,12 +22,19 @@ limitations under the License.
namespace tensorflow {
// Builds an array filled with values sampled from a truncated normal
// distribution such that no values are greater than two or less than negative
// two.
// Builds an array of values sampled from a truncated normal distribution:
//
// The "uniform" parameter must be an array of random numbers distributed in
// (0,1).
// uniform: an array of random numbers in uniform distribution (0, 1).
// mu: the mean of the normal distribution.
// sigma: the standard deviation of the normal distribution.
// a: the lower bound of the generated values.
// b: the upper bound of the generated values.
xla::XlaOp ParameterizedTruncatedNormal(xla::XlaOp uniform, xla::XlaOp mu,
xla::XlaOp sigma, xla::XlaOp a,
xla::XlaOp b);
// A specialized version of ParameterizedTruncatedNormal, with mu=0, sigma=1,
// a=-2 and b=2.
xla::XlaOp TruncatedNormal(xla::XlaOp uniform);
} // namespace tensorflow