Fix some corner cases in the binomial sampler, and spell out some details in its docstring.
PiperOrigin-RevId: 255658335
This commit is contained in:
parent
ab5ab73f4d
commit
84a3afbb02
tensorflow
core/kernels
python
@ -207,7 +207,17 @@ struct RandomBinomialFunctor<CPUDevice, T, U> {
|
||||
// Calculate normalized samples, then convert them.
|
||||
// Determine the method to use.
|
||||
double dcount = static_cast<double>(count);
|
||||
if (prob <= T(0.5)) {
|
||||
if (dcount <= 0.0 || prob <= T(0.0)) {
|
||||
while (sample < limit_sample) {
|
||||
output(sample) = static_cast<U>(0.0);
|
||||
sample++;
|
||||
}
|
||||
} else if (prob >= T(1.0)) {
|
||||
while (sample < limit_sample) {
|
||||
output(sample) = static_cast<U>(dcount);
|
||||
sample++;
|
||||
}
|
||||
} else if (prob <= T(0.5)) {
|
||||
double dp = static_cast<double>(prob);
|
||||
if (count * prob >= T(10)) {
|
||||
while (sample < limit_sample) {
|
||||
@ -221,7 +231,7 @@ struct RandomBinomialFunctor<CPUDevice, T, U> {
|
||||
sample++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
} else if (prob > T(0.5)) {
|
||||
T q = T(1) - prob;
|
||||
double dcount = static_cast<double>(count);
|
||||
double dq = static_cast<double>(q);
|
||||
@ -238,6 +248,14 @@ struct RandomBinomialFunctor<CPUDevice, T, U> {
|
||||
sample++;
|
||||
}
|
||||
}
|
||||
} else { // prob is NaN
|
||||
// TODO(srvasude): What should happen if prob is NaN but the output
|
||||
// type is an integer (which doesn't have a sentinel for NaN)? Fail
|
||||
// the whole batch sample? Return a specialized sentinel like -1?
|
||||
while (sample < limit_sample) {
|
||||
output(sample) = static_cast<U>(NAN);
|
||||
sample++;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -115,6 +115,16 @@ class RandomBinomialTest(test.TestCase):
|
||||
probs=np.float32(0.9))
|
||||
self.assertEqual([10], rnd.shape.as_list())
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testCornerCases(self):
|
||||
rng = stateful_random_ops.Generator.from_seed(12345)
|
||||
counts = np.array([5, 5, 5, 0, 0, 0], dtype=np.float32)
|
||||
probs = np.array([0, 1, float("nan"), -10, 10, float("nan")],
|
||||
dtype=np.float32)
|
||||
expected = np.array([0, 5, float("nan"), 0, 0, 0], dtype=np.float32)
|
||||
result = rng.binomial(
|
||||
shape=[6], counts=counts, probs=probs, dtype=np.float32)
|
||||
self.assertAllEqual(expected, self.evaluate(result))
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -543,7 +543,7 @@ class Generator(tracking.AutoTrackable):
|
||||
# Probability of success.
|
||||
probs = [0.8, 0.9]
|
||||
|
||||
rng = tf.random.experimental.Generator(seed=234)
|
||||
rng = tf.random.experimental.Generator.from_seed(seed=234)
|
||||
binomial_samples = rng.binomial(shape=[2], counts=counts, probs=probs)
|
||||
```
|
||||
|
||||
@ -551,15 +551,20 @@ class Generator(tracking.AutoTrackable):
|
||||
Args:
|
||||
shape: A 1-D integer Tensor or Python array. The shape of the output
|
||||
tensor.
|
||||
counts: A 0/1-D Tensor or Python value`. The counts of the binomial
|
||||
distribution.
|
||||
probs: A 0/1-D Tensor or Python value`. The probability of success for the
|
||||
binomial distribution.
|
||||
counts: A 0/1-D Tensor or Python value. The counts of the binomial
|
||||
distribution. Must be broadcastable with the leftmost dimension
|
||||
defined by `shape`.
|
||||
probs: A 0/1-D Tensor or Python value. The probability of success for the
|
||||
binomial distribution. Must be broadcastable with the leftmost
|
||||
dimension defined by `shape`.
|
||||
dtype: The type of the output. Default: tf.int32
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A tensor of the specified shape filled with random binomial values.
|
||||
samples: A Tensor of the specified shape filled with random binomial
|
||||
values. For each i, each samples[i, ...] is an independent draw from
|
||||
the binomial distribution on counts[i] trials with probability of
|
||||
success probs[i].
|
||||
"""
|
||||
dtype = dtypes.as_dtype(dtype)
|
||||
with ops.name_scope(name, "binomial", [shape, counts, probs]) as name:
|
||||
|
Loading…
Reference in New Issue
Block a user