Adds a tf.random.stateless_binomial (stateless analogue to tf.random.Generator.binomial).

PiperOrigin-RevId: 293446442
Change-Id: I1fc995304adbf82a5f43fddd27d015b6133b5a31
Brian Patton 2020-02-05 14:04:49 -08:00 committed by TensorFlower Gardener
parent c376fe13ae
commit 1dc9305329
11 changed files with 268 additions and 11 deletions
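
A minimal usage sketch of the new API (assuming eager execution on a TensorFlow build that includes this change), illustrating the determinism the commit message refers to; the shapes and seeds here are illustrative only:

```python
import tensorflow as tf

# Same shape/seed/counts/probs -> identical samples on every call.
a = tf.random.stateless_binomial(
    shape=[10], seed=[12, 34], counts=10., probs=0.4)
b = tf.random.stateless_binomial(
    shape=[10], seed=[12, 34], counts=10., probs=0.4)
assert (a.numpy() == b.numpy()).all()
```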

View File

@@ -0,0 +1,48 @@
op {
  graph_op_name: "StatelessRandomBinomial"
  visibility: HIDDEN
  in_arg {
    name: "shape"
    description: <<END
The shape of the output tensor.
END
  }
  in_arg {
    name: "seed"
    description: <<END
2 seeds (shape [2]).
END
  }
  in_arg {
    name: "counts"
    description: <<END
The counts of the binomial distribution. Must be broadcastable with `probs`,
and broadcastable with the rightmost dimensions of `shape`.
END
  }
  in_arg {
    name: "probs"
    description: <<END
The probability of success for the binomial distribution. Must be broadcastable
with `counts` and broadcastable with the rightmost dimensions of `shape`.
END
  }
  out_arg {
    name: "output"
    description: <<END
Random values with specified shape.
END
  }
  attr {
    name: "dtype"
    description: <<END
The type of the output.
END
  }
  summary: "Outputs deterministic pseudorandom random numbers from a binomial distribution."
  description: <<END
Outputs random values from a binomial distribution.

The outputs are a deterministic function of `shape`, `seed`, `counts`, and `probs`.
END
}
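
To illustrate the broadcasting contract described in `counts`/`probs` above, a short sketch using the public wrapper added later in this change (the shapes are illustrative, not taken from the op documentation):

```python
import tensorflow as tf

counts = tf.constant([[10.], [20.], [30.]])  # shape [3, 1]
probs = tf.constant([[0.2, 0.5, 0.8]])       # shape [1, 3]
# counts/probs broadcast to [3, 3]; `shape` must end with that broadcast
# shape, and any extra leading dims are independent samples per
# (counts, probs) pair.
samples = tf.random.stateless_binomial(
    shape=[5, 3, 3], seed=[1, 2], counts=counts, probs=probs)
print(samples.shape)  # (5, 3, 3)
```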

View File

@@ -5955,6 +5955,7 @@ tf_kernel_library(
        ":random_ops",
        ":resource_variable_ops",
        ":stateful_random_ops",
        ":stateless_random_ops",
        ":training_op_helpers",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",

View File

@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
#include "tensorflow/core/kernels/stateless_random_ops.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/random/random_distributions.h"
@@ -434,19 +435,114 @@ class RandomBinomialOp : public OpKernel {
  TF_DISALLOW_COPY_AND_ASSIGN(RandomBinomialOp);
};

// Samples from a binomial distribution, using the given parameters.
template <typename Device, typename T, typename U>
class StatelessRandomBinomialOp : public OpKernel {
  // Reshape batches so each batch is this size if possible.
  static const int32 kDesiredBatchSize = 100;

 public:
  explicit StatelessRandomBinomialOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape_tensor = ctx->input(0);
    const Tensor& seed_tensor = ctx->input(1);
    const Tensor& counts_tensor = ctx->input(2);
    const Tensor& probs_tensor = ctx->input(3);

    OP_REQUIRES(ctx, seed_tensor.dims() == 1 && seed_tensor.dim_size(0) == 2,
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_tensor.shape().DebugString()));

    tensorflow::BCast bcast(counts_tensor.shape().dim_sizes(),
                            probs_tensor.shape().dim_sizes(),
                            /*fewer_dims_optimization=*/false,
                            /*return_flattened_batch_indices=*/true);
    OP_REQUIRES(ctx, bcast.IsValid(),
                errors::InvalidArgument(
                    "counts and probs must have compatible batch dimensions: ",
                    counts_tensor.shape().DebugString(), " vs. ",
                    probs_tensor.shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(shape_tensor.shape()),
        errors::InvalidArgument("Input shape should be a vector, got shape: ",
                                shape_tensor.shape().DebugString()));
    OP_REQUIRES(ctx,
                (shape_tensor.dtype() == DataType::DT_INT32 ||
                 shape_tensor.dtype() == DataType::DT_INT64),
                errors::InvalidArgument(
                    "Input shape should have dtype {int32, int64}."));

    // Let's check that the shape tensor dominates the broadcasted tensor.
    TensorShape bcast_shape = BCast::ToShape(bcast.output_shape());
    TensorShape output_shape;
    if (shape_tensor.dtype() == DataType::DT_INT32) {
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(shape_tensor.vec<int32>(),
                                                      &output_shape));
    } else {
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(shape_tensor.vec<int64>(),
                                                      &output_shape));
    }
    OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(output_shape, bcast_shape),
                errors::InvalidArgument(
                    "Shape passed in must end with broadcasted shape."));

    // Now that we have a guarantee, we can get the additional dimensions added
    // by sampling.
    int64 samples_per_batch = 1;
    const int64 num_sample_dims =
        (shape_tensor.dim_size(0) - bcast.output_shape().size());
    for (int64 i = 0; i < num_sample_dims; ++i) {
      samples_per_batch *= shape_tensor.flat<int32>()(i);
    }
    int64 num_batches = 1;
    for (int64 i = num_sample_dims; i < shape_tensor.dim_size(0); ++i) {
      num_batches *= shape_tensor.flat<int32>()(i);
    }
    const int64 num_elements = num_batches * samples_per_batch;

    Tensor* samples_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, output_shape, &samples_tensor));
    if (output_shape.num_elements() == 0) return;

    random::PhiloxRandom::Key key;
    random::PhiloxRandom::ResultType counter;
    OP_REQUIRES_OK(ctx, GenerateKey(seed_tensor, &key, &counter));

    auto philox = random::PhiloxRandom(counter, key);

    auto binomial_functor = functor::RandomBinomialFunctor<Device, T, U>();
    binomial_functor(ctx, ctx->eigen_device<Device>(), num_batches,
                     samples_per_batch, num_elements, bcast,
                     counts_tensor.flat<T>(), probs_tensor.flat<T>(), philox,
                     samples_tensor->flat<U>());
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomBinomialOp);
};

}  // namespace

#define REGISTER(RTYPE, TYPE)                                         \
  REGISTER_KERNEL_BUILDER(Name("StatefulRandomBinomial")              \
                              .Device(DEVICE_CPU)                     \
                              .HostMemory("resource")                 \
                              .HostMemory("algorithm")                \
                              .HostMemory("shape")                    \
                              .HostMemory("counts")                   \
                              .HostMemory("probs")                    \
                              .TypeConstraint<RTYPE>("dtype")         \
                              .TypeConstraint<TYPE>("T"),             \
                          RandomBinomialOp<CPUDevice, TYPE, RTYPE>);  \
  REGISTER_KERNEL_BUILDER(Name("StatelessRandomBinomial")             \
                              .Device(DEVICE_CPU)                     \
                              .HostMemory("shape")                    \
                              .HostMemory("seed")                     \
                              .HostMemory("counts")                   \
                              .HostMemory("probs")                    \
                              .TypeConstraint<RTYPE>("dtype")         \
                              .TypeConstraint<TYPE>("T"),             \
                          StatelessRandomBinomialOp<CPUDevice, TYPE, RTYPE>)

#define REGISTER_ALL(RTYPE)     \
  REGISTER(RTYPE, Eigen::half); \
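
A plain-Python restatement of the batch/sample bookkeeping performed by the kernel above, using the shapes from the docstring example added later in this change (counts `[3, 1, 2]`, probs `[1, 4, 2]`, output shape `[3, 4, 3, 4, 2]`); a sketch for illustration, not part of the commit:

```python
import numpy as np

output_shape = [3, 4, 3, 4, 2]
bcast_shape = [3, 4, 2]  # broadcast of counts [3, 1, 2] with probs [1, 4, 2]
assert output_shape[-len(bcast_shape):] == bcast_shape  # the EndsWith check

num_sample_dims = len(output_shape) - len(bcast_shape)             # 2
samples_per_batch = int(np.prod(output_shape[:num_sample_dims]))   # 3 * 4 = 12
num_batches = int(np.prod(output_shape[num_sample_dims:]))         # 3 * 4 * 2 = 24
num_elements = num_batches * samples_per_batch                     # 288
```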

View File

@@ -93,4 +93,16 @@ REGISTER_OP("StatelessMultinomial")
      return Status::OK();
    });

REGISTER_OP("StatelessRandomBinomial")
    .Input("shape: S")
    .Input("seed: Tseed")
    .Input("counts: T")
    .Input("probs: T")
    .Output("output: dtype")
    .Attr("S: {int32, int64}")
    .Attr("Tseed: {int32, int64} = DT_INT64")
    .Attr("T: {half, float, double, int32, int64} = DT_DOUBLE")
    .Attr("dtype: {half, float, double, int32, int64} = DT_INT64")
    .SetShapeFn(StatelessShape);

}  // namespace tensorflow
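
The attr split in this registration keeps the parameter dtype (`T`, shared by `counts` and `probs`) independent of the sample dtype (`dtype`). A short sketch via the public wrapper added later in this change, assuming a build that includes this commit:

```python
import tensorflow as tf

# float64 parameters (attr T) with int64 samples (attr dtype), selected
# through the wrapper's output_dtype argument.
samples = tf.random.stateless_binomial(
    shape=[5], seed=[1, 2],
    counts=tf.constant([10.], tf.float64),
    probs=tf.constant([0.3], tf.float64),
    output_dtype=tf.int64)
print(samples.dtype)  # <dtype: 'int64'>
```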

View File

@@ -307,6 +307,7 @@ bool OpGradientDoesntRequireInputIndices(
      {"StackPop", {true, {}}},
      {"StackPush", {true, {}}},
      {"StatelessMultinomial", {true, {}}},
      {"StatelessRandomBinomial", {true, {}}},
      {"StatelessRandomNormal", {true, {}}},
      {"StatelessRandomUniform", {true, {}}},
      {"StatelessRandomUniformInt", {true, {}}},
@@ -761,6 +762,7 @@ bool OpGradientDoesntRequireOutputIndices(
      {"StackPop", {true, {}}},
      {"StackPush", {true, {}}},
      {"StatelessMultinomial", {true, {}}},
      {"StatelessRandomBinomial", {true, {}}},
      {"StatelessRandomNormal", {true, {}}},
      {"StatelessRandomUniform", {true, {}}},
      {"StatelessRandomUniformInt", {true, {}}},

View File

@@ -24,6 +24,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.kernel_tests.random import util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import stateful_random_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
@@ -80,6 +81,18 @@ class RandomBinomialTest(test.TestCase):
    sy = self._Sampler(1000, counts=10., probs=0.4, dtype=dt, seed=345)
    self.assertAllEqual(self.evaluate(sx()), self.evaluate(sy()))

  def testStateless(self):
    for dt in dtypes.float16, dtypes.float32, dtypes.float64:
      sx = stateless_random_ops.stateless_random_binomial(
          shape=[1000], seed=[12, 34], counts=10., probs=0.4, output_dtype=dt)
      sy = stateless_random_ops.stateless_random_binomial(
          shape=[1000], seed=[12, 34], counts=10., probs=0.4, output_dtype=dt)
      sx0, sx1 = self.evaluate(sx), self.evaluate(sx)
      sy0, sy1 = self.evaluate(sy), self.evaluate(sy)
      self.assertAllEqual(sx0, sx1)
      self.assertAllEqual(sx0, sy0)
      self.assertAllEqual(sy0, sy1)

  def testZeroShape(self):
    rnd = stateful_random_ops.Generator.from_seed(12345).binomial([0], [], [])
    self.assertEqual([0], rnd.shape.as_list())

View File

@@ -27,6 +27,7 @@ from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

ops.NotDifferentiable("StatelessMultinomial")
ops.NotDifferentiable("StatelessRandomBinomial")
ops.NotDifferentiable("StatelessRandomNormal")
ops.NotDifferentiable("StatelessRandomUniform")
ops.NotDifferentiable("StatelessRandomUniformInt")
@@ -102,6 +103,74 @@ def stateless_random_uniform(shape,
  return result
@tf_export("random.stateless_binomial")
def stateless_random_binomial(shape,
seed,
counts,
probs,
output_dtype=dtypes.int32,
name=None):
"""Outputs deterministic pseudorandom values from a binomial distribution.
The generated values follow a binomial distribution with specified count and
probability of success parameters.
This is a stateless version of `tf.random.Generator.binomial`: if run twice
with the same seeds, it will produce the same pseudorandom numbers. The
output is consistent across multiple runs on the same hardware (and between
CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
hardware.
Example:
```python
counts = [10., 20.]
# Probability of success.
probs = [0.8]
binomial_samples = tf.random.stateless_binomial(
shape=[2], seed=[123, 456], counts=counts, probs=probs)
counts = ... # Shape [3, 1, 2]
probs = ... # Shape [1, 4, 2]
shape = [3, 4, 3, 4, 2]
# Sample shape will be [3, 4, 3, 4, 2]
binomial_samples = tf.random.stateless_binomial(
shape=shape, seed=[123, 456], counts=counts, probs=probs)
```
Args:
shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
seed: A shape [2] integer Tensor of seeds to the random number generator.
counts: Tensor. The counts of the binomial distribution. Must be
broadcastable with `probs`, and broadcastable with the rightmost
dimensions of `shape`.
probs: Tensor. The probability of success for the binomial distribution.
Must be broadcastable with `counts` and broadcastable with the rightmost
dimensions of `shape`.
output_dtype: The type of the output. Default: tf.int32
name: A name for the operation (optional).
Returns:
samples: A Tensor of the specified shape filled with random binomial
values. For each i, each samples[..., i] is an independent draw from
the binomial distribution on counts[i] trials with probability of
success probs[i].
"""
with ops.name_scope(name, "stateless_random_binomial",
[shape, seed, counts, probs]) as name:
shape = tensor_util.shape_tensor(shape)
probs = ops.convert_to_tensor(
probs, dtype_hint=dtypes.float32, name="probs")
counts = ops.convert_to_tensor(
counts, dtype_hint=probs.dtype, name="counts")
result = gen_stateless_random_ops.stateless_random_binomial(
shape=shape, seed=seed, counts=counts, probs=probs, dtype=output_dtype)
tensor_util.maybe_set_static_shape(result, shape)
return result
@tf_export("random.stateless_normal") @tf_export("random.stateless_normal")
def stateless_random_normal(shape, def stateless_random_normal(shape,
seed, seed,
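
As a side note on the `dtype_hint` conversions in the `stateless_random_binomial` wrapper above, a sketch assuming standard `tf.convert_to_tensor` semantics: plain Python numbers are accepted for `counts` and `probs`, and the samples default to int32.

```python
import tensorflow as tf

samples = tf.random.stateless_binomial(
    shape=[3], seed=[0, 1], counts=[5., 10., 20.], probs=0.5)
# probs -> float32 (dtype_hint), counts follows probs.dtype; output_dtype
# defaults to int32.
print(samples.dtype)  # <dtype: 'int32'>
```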

View File

@@ -72,6 +72,10 @@ tf_module {
    name: "shuffle"
    argspec: "args=[\'value\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "stateless_binomial"
    argspec: "args=[\'shape\', \'seed\', \'counts\', \'probs\', \'output_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
  }
  member_method {
    name: "stateless_categorical"
    argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "

View File

@@ -4280,6 +4280,10 @@ tf_module {
    name: "StatelessMultinomial"
    argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'output_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
  }
  member_method {
    name: "StatelessRandomBinomial"
    argspec: "args=[\'shape\', \'seed\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
  }
  member_method {
    name: "StatelessRandomNormal"
    argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "

View File

@@ -64,6 +64,10 @@ tf_module {
    name: "shuffle"
    argspec: "args=[\'value\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "stateless_binomial"
    argspec: "args=[\'shape\', \'seed\', \'counts\', \'probs\', \'output_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
  }
  member_method {
    name: "stateless_categorical"
    argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "

View File

@@ -4280,6 +4280,10 @@ tf_module {
    name: "StatelessMultinomial"
    argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'output_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
  }
  member_method {
    name: "StatelessRandomBinomial"
    argspec: "args=[\'shape\', \'seed\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
  }
  member_method {
    name: "StatelessRandomNormal"
    argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "