Fix some corner cases in the binomial sampler, and spell out some details in its docstring.

PiperOrigin-RevId: 255658335
This commit is contained in:
Alexey Radul 2019-06-28 13:31:13 -07:00 committed by TensorFlower Gardener
parent ab5ab73f4d
commit 84a3afbb02
3 changed files with 41 additions and 8 deletions
tensorflow

View File

@ -207,7 +207,17 @@ struct RandomBinomialFunctor<CPUDevice, T, U> {
// Calculate normalized samples, then convert them.
// Determine the method to use.
double dcount = static_cast<double>(count);
if (prob <= T(0.5)) {
if (dcount <= 0.0 || prob <= T(0.0)) {
while (sample < limit_sample) {
output(sample) = static_cast<U>(0.0);
sample++;
}
} else if (prob >= T(1.0)) {
while (sample < limit_sample) {
output(sample) = static_cast<U>(dcount);
sample++;
}
} else if (prob <= T(0.5)) {
double dp = static_cast<double>(prob);
if (count * prob >= T(10)) {
while (sample < limit_sample) {
@ -221,7 +231,7 @@ struct RandomBinomialFunctor<CPUDevice, T, U> {
sample++;
}
}
} else {
} else if (prob > T(0.5)) {
T q = T(1) - prob;
double dcount = static_cast<double>(count);
double dq = static_cast<double>(q);
@ -238,6 +248,14 @@ struct RandomBinomialFunctor<CPUDevice, T, U> {
sample++;
}
}
} else { // prob is NaN
// TODO(srvasude): What should happen if prob is NaN but the output
// type is an integer (which doesn't have a sentinel for NaN)? Fail
// the whole batch sample? Return a specialized sentinel like -1?
while (sample < limit_sample) {
output(sample) = static_cast<U>(NAN);
sample++;
}
}
}
};

View File

@ -115,6 +115,16 @@ class RandomBinomialTest(test.TestCase):
probs=np.float32(0.9))
self.assertEqual([10], rnd.shape.as_list())
@test_util.run_v2_only
def testCornerCases(self):
rng = stateful_random_ops.Generator.from_seed(12345)
counts = np.array([5, 5, 5, 0, 0, 0], dtype=np.float32)
probs = np.array([0, 1, float("nan"), -10, 10, float("nan")],
dtype=np.float32)
expected = np.array([0, 5, float("nan"), 0, 0, 0], dtype=np.float32)
result = rng.binomial(
shape=[6], counts=counts, probs=probs, dtype=np.float32)
self.assertAllEqual(expected, self.evaluate(result))
if __name__ == "__main__":
test.main()

View File

@ -543,7 +543,7 @@ class Generator(tracking.AutoTrackable):
# Probability of success.
probs = [0.8, 0.9]
rng = tf.random.experimental.Generator(seed=234)
rng = tf.random.experimental.Generator.from_seed(seed=234)
binomial_samples = rng.binomial(shape=[2], counts=counts, probs=probs)
```
@ -551,15 +551,20 @@ class Generator(tracking.AutoTrackable):
Args:
shape: A 1-D integer Tensor or Python array. The shape of the output
tensor.
counts: A 0/1-D Tensor or Python value`. The counts of the binomial
distribution.
probs: A 0/1-D Tensor or Python value`. The probability of success for the
binomial distribution.
counts: A 0/1-D Tensor or Python value. The counts of the binomial
distribution. Must be broadcastable with the leftmost dimension
defined by `shape`.
probs: A 0/1-D Tensor or Python value. The probability of success for the
binomial distribution. Must be broadcastable with the leftmost
dimension defined by `shape`.
dtype: The type of the output. Default: tf.int32
name: A name for the operation (optional).
Returns:
A tensor of the specified shape filled with random binomial values.
samples: A Tensor of the specified shape filled with random binomial
values. For each i, each samples[i, ...] is an independent draw from
the binomial distribution on counts[i] trials with probability of
success probs[i].
"""
dtype = dtypes.as_dtype(dtype)
with ops.name_scope(name, "binomial", [shape, counts, probs]) as name: