Modifies tf.random.stateless_uniform to support unbounded int sampling (when minval=maxval=None). Adds a StatelessRandomUniformFullInt op to support this.

PiperOrigin-RevId: 294925770
Change-Id: Ibe8fe4dc837eeba145180f00d3cd6a037c55620a
This commit is contained in:
Brian Patton 2020-02-13 08:55:18 -08:00 committed by TensorFlower Gardener
parent 41313e71ac
commit 315063582a
12 changed files with 231 additions and 16 deletions

View File

@ -1996,6 +1996,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"StatelessRandomNormal",
"StatelessRandomUniform",
"StatelessRandomUniformInt",
"StatelessRandomUniformFullInt",
"StatelessTruncatedNormal",
"StatelessWhile",
"Svd",

View File

@ -111,6 +111,35 @@ xla::XlaOp StatelessRngUniform(absl::string_view device_type_string,
}
}
xla::XlaOp StatelessRngUniformFullInt(absl::string_view device_type_string,
xla::XlaOp seeds,
const xla::Shape& shape) {
xla::XlaBuilder* builder = seeds.builder();
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {});
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {});
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
ShiftLeft(ConvertElementType(seed1, xla::U64),
ConstantR0WithType(builder, xla::U64, 32));
xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
xla::PrimitiveType type = shape.element_type();
xla::RngOutput output =
GetBitGeneratorForDevice(device_type_string)(key, initial_state, shape);
switch (type) {
case xla::U32:
case xla::U64:
return output.value;
case xla::S32:
case xla::S64:
return BitcastConvertType(output.value, type);
default:
return builder->ReportError(xla::Unimplemented(
"Types other than U32, S32, U64 and S64 are not implemented by "
"StatelessRngUniformFullInt; got: %s",
xla::primitive_util::LowercasePrimitiveTypeName(type)));
}
}
namespace {
class StatelessRandomUniformOp : public XlaOpKernel {
@ -211,6 +240,47 @@ REGISTER_XLA_OP(Name("StatelessRandomUniformInt")
.TypeConstraint("Tseed", DT_INT32),
StatelessRandomUniformIntOp);
class StatelessRandomUniformFullIntOp : public XlaOpKernel {
public:
explicit StatelessRandomUniformFullIntOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx),
device_type_string_(ctx->device_type().type_string()) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
TensorShape shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
TensorShape seed_shape = ctx->InputShape(1);
OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2,
errors::InvalidArgument("seed must have shape [2], not ",
seed_shape.DebugString()));
xla::XlaOp seed = ctx->Input(1);
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));
xla::XlaOp uniform =
StatelessRngUniformFullInt(device_type_string_, seed, xla_shape);
ctx->SetOutput(0, uniform);
}
private:
DataType dtype_;
string device_type_string_;
TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformFullIntOp);
};
// TODO(phawkins): generalize to non-int32 seed types.
REGISTER_XLA_OP(Name("StatelessRandomUniformFullInt")
.CompileTimeConstantInput("shape")
.TypeConstraint("dtype", {DT_INT32, DT_INT64})
.TypeConstraint("Tseed", DT_INT32),
StatelessRandomUniformFullIntOp);
class StatelessRandomNormalOp : public XlaOpKernel {
public:
explicit StatelessRandomNormalOp(OpKernelConstruction* ctx)

View File

@ -0,0 +1,34 @@
op {
graph_op_name: "StatelessRandomUniformFullInt"
visibility: HIDDEN
in_arg {
name: "shape"
description: <<END
The shape of the output tensor.
END
}
in_arg {
name: "seed"
description: <<END
2 seeds (shape [2]).
END
}
out_arg {
name: "output"
description: <<END
Random values with specified shape.
END
}
attr {
name: "dtype"
description: <<END
The type of the output.
END
}
summary: "Outputs deterministic pseudorandom random integers from a uniform distribution."
description: <<END
The generated values are uniform integers covering the whole range of `dtype`.
The outputs are a deterministic function of `shape` and `seed`.
END
}

View File

@ -809,6 +809,7 @@ bool IsPinnableOp(const string& op_type) {
"RandomStandardNormal",
"StatelessRandomUniform",
"StatelessRandomUniformInt",
"StatelessRandomUniformFullInt",
"StatelessRandomNormal",
});

View File

@ -364,7 +364,13 @@ class RandomGammaOp : public OpKernel {
Name("RandomGamma").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
RandomGammaOp<TYPE>)
#define REGISTER_FULL_INT(IntType) \
template struct functor::FillPhiloxRandom< \
CPUDevice, \
random::UniformFullIntDistribution<random::PhiloxRandom, IntType>>
#define REGISTER_INT(IntType) \
REGISTER_FULL_INT(IntType); \
template struct functor::FillPhiloxRandom< \
CPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \
REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
@ -381,9 +387,12 @@ TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
TF_CALL_int32(REGISTER_INT);
TF_CALL_int64(REGISTER_INT);
TF_CALL_uint32(REGISTER_FULL_INT);
TF_CALL_uint64(REGISTER_FULL_INT);
#undef REGISTER
#undef REGISTER_INT
#undef REGISTER_FULL_INT
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@ -415,7 +424,13 @@ TF_CALL_int64(REGISTER_INT);
random::TruncatedNormalDistribution< \
random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>);
#define REGISTER_FULL_INT(IntType) \
template struct functor::FillPhiloxRandom< \
GPUDevice, \
random::UniformFullIntDistribution<random::PhiloxRandom, IntType>>
#define REGISTER_INT(IntType) \
REGISTER_FULL_INT(IntType); \
template struct functor::FillPhiloxRandom< \
GPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \
REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
@ -432,9 +447,12 @@ TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
TF_CALL_int32(REGISTER_INT);
TF_CALL_int64(REGISTER_INT);
TF_CALL_uint32(REGISTER_FULL_INT);
TF_CALL_uint64(REGISTER_FULL_INT);
#undef REGISTER
#undef REGISTER_INT
#undef REGISTER_FULL_INT
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@ -48,6 +48,18 @@ template struct FillPhiloxRandom<
GPUDevice, random::UniformDistribution<random::PhiloxRandom, int32> >;
template struct FillPhiloxRandom<
GPUDevice, random::UniformDistribution<random::PhiloxRandom, int64> >;
template struct FillPhiloxRandom<
GPUDevice,
random::UniformFullIntDistribution<random::PhiloxRandom, int32> >;
template struct FillPhiloxRandom<
GPUDevice,
random::UniformFullIntDistribution<random::PhiloxRandom, int64> >;
template struct FillPhiloxRandom<
GPUDevice,
random::UniformFullIntDistribution<random::PhiloxRandom, uint32> >;
template struct FillPhiloxRandom<
GPUDevice,
random::UniformFullIntDistribution<random::PhiloxRandom, uint64> >;
template struct FillPhiloxRandom<
GPUDevice, random::NormalDistribution<random::PhiloxRandom, Eigen::half> >;
template struct FillPhiloxRandom<

View File

@ -163,6 +163,26 @@ class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
}
};
template <typename Device, typename IntType>
class StatelessRandomUniformFullIntOp : public StatelessRandomOpBase {
public:
using StatelessRandomOpBase::StatelessRandomOpBase;
void Fill(OpKernelContext* context, random::PhiloxRandom random,
Tensor* output) override {
// Build distribution
typedef random::UniformFullIntDistribution<random::PhiloxRandom, IntType>
Distribution;
Distribution dist;
auto flat = output->flat<IntType>();
// Reuse the compute kernels from the stateful random ops
functor::FillPhiloxRandom<Device, Distribution>()(
context, context->eigen_device<Device>(), random, flat.data(),
flat.size(), dist);
}
};
// Samples from one or more Poisson distributions.
template <typename T, typename U>
class StatelessRandomPoissonOp : public StatelessRandomOpBase {
@ -376,7 +396,17 @@ class StatelessRandomGammaOp : public StatelessRandomOpBase {
random::TruncatedNormalDistribution< \
random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >)
#define REGISTER_FULL_INT(DEVICE, TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("StatelessRandomUniformFullInt") \
.Device(DEVICE_##DEVICE) \
.HostMemory("shape") \
.HostMemory("seed") \
.TypeConstraint<TYPE>("dtype"), \
StatelessRandomUniformFullIntOp<DEVICE##Device, TYPE>)
#define REGISTER_INT(DEVICE, TYPE) \
REGISTER_FULL_INT(DEVICE, TYPE); \
REGISTER_KERNEL_BUILDER(Name("StatelessRandomUniformInt") \
.Device(DEVICE_##DEVICE) \
.HostMemory("shape") \
@ -390,6 +420,8 @@ class StatelessRandomGammaOp : public StatelessRandomOpBase {
#define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE)
#define REGISTER_INT_CPU(TYPE) REGISTER_INT(CPU, TYPE)
#define REGISTER_INT_GPU(TYPE) REGISTER_INT(GPU, TYPE)
#define REGISTER_FULL_INT_CPU(TYPE) REGISTER_FULL_INT(CPU, TYPE)
#define REGISTER_FULL_INT_GPU(TYPE) REGISTER_FULL_INT(GPU, TYPE)
TF_CALL_half(REGISTER_CPU);
TF_CALL_bfloat16(REGISTER_CPU);
@ -397,6 +429,8 @@ TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_int32(REGISTER_INT_CPU);
TF_CALL_int64(REGISTER_INT_CPU);
TF_CALL_uint32(REGISTER_FULL_INT_CPU);
TF_CALL_uint64(REGISTER_FULL_INT_CPU);
#define REGISTER_POISSON(RATE_TYPE, OUT_TYPE) \
REGISTER_KERNEL_BUILDER(Name("StatelessRandomPoisson") \
@ -447,6 +481,8 @@ TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
TF_CALL_int32(REGISTER_INT_GPU);
TF_CALL_int64(REGISTER_INT_GPU);
TF_CALL_uint32(REGISTER_FULL_INT_GPU);
TF_CALL_uint64(REGISTER_FULL_INT_GPU);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@ -456,6 +492,8 @@ TF_CALL_int64(REGISTER_INT_GPU);
#undef REGISTER_GPU
#undef REGISTER_INT_CPU
#undef REGISTER_INT_GPU
#undef REGISTER_FULL_INT_CPU
#undef REGISTER_FULL_INT_GPU
} // namespace

View File

@ -68,6 +68,15 @@ REGISTER_OP("StatelessRandomUniformInt")
return StatelessShape(c);
});
REGISTER_OP("StatelessRandomUniformFullInt")
.Input("shape: T")
.Input("seed: Tseed")
.Output("output: dtype")
.Attr("dtype: {int32, int64, uint32, uint64} = DT_UINT64")
.Attr("T: {int32, int64} = DT_INT32")
.Attr("Tseed: {int32, int64, uint32, uint64} = DT_INT64")
.SetShapeFn(StatelessShape);
REGISTER_OP("StatelessMultinomial")
.Input("logits: T")
.Input("num_samples: int32")

View File

@ -312,6 +312,7 @@ bool OpGradientDoesntRequireInputIndices(
{"StatelessRandomNormal", {true, {}}},
{"StatelessRandomPoisson", {true, {}}},
{"StatelessRandomUniform", {true, {}}},
{"StatelessRandomUniformFullInt", {true, {}}},
{"StatelessRandomUniformInt", {true, {}}},
{"StatelessTruncatedNormal", {true, {}}},
{"StopGradient", {true, {}}},
@ -768,6 +769,7 @@ bool OpGradientDoesntRequireOutputIndices(
{"StatelessRandomNormal", {true, {}}},
{"StatelessRandomPoisson", {true, {}}},
{"StatelessRandomUniform", {true, {}}},
{"StatelessRandomUniformFullInt", {true, {}}},
{"StatelessRandomUniformInt", {true, {}}},
{"StatelessTruncatedNormal", {true, {}}},
{"StopGradient", {true, {}}},

View File

@ -35,6 +35,7 @@ ops.NotDifferentiable("StatelessRandomNormal")
ops.NotDifferentiable("StatelessRandomPoisson")
ops.NotDifferentiable("StatelessRandomUniform")
ops.NotDifferentiable("StatelessRandomUniformInt")
ops.NotDifferentiable("StatelessRandomUniformFullInt")
ops.NotDifferentiable("StatelessTruncatedNormal")
@ -65,44 +66,65 @@ def stateless_random_uniform(shape,
`maxval - minval` significantly smaller than the range of the output (either
`2**32` or `2**64`).
For full full-range (i.e. inclusive of both max and min) random integers, pass
`minval=None` and `maxval=None` with an integer `dtype`. For an integer dtype
either both `minval` and `maxval` must be `None` or neither may be `None`. For
example:
```python
ints = tf.random.stateless_uniform(
[10], seed=(2, 3), minval=None, maxval=None, dtype=tf.int32)
```
Args:
shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
seed: A shape [2] integer Tensor of seeds to the random number generator.
minval: A 0-D Tensor or Python value of type `dtype`. The lower bound on the
range of random values to generate. Defaults to 0.
range of random values to generate. Pass `None` for full-range integers.
Defaults to 0.
maxval: A 0-D Tensor or Python value of type `dtype`. The upper bound on the
range of random values to generate. Defaults to 1 if `dtype` is floating
point.
point. Pass `None` for full-range integers.
dtype: The type of the output: `float16`, `float32`, `float64`, `int32`, or
`int64`.
`int64`. For unbounded uniform ints (`minval`, `maxval` both `None`),
`uint32` and `uint64` may be used.
name: A name for the operation (optional).
Returns:
A tensor of the specified shape filled with random uniform values.
Raises:
ValueError: If `dtype` is integral and `maxval` is not specified.
ValueError: If `dtype` is integral and only one of `minval` or `maxval` is
specified.
"""
dtype = dtypes.as_dtype(dtype)
if dtype not in (dtypes.float16, dtypes.bfloat16, dtypes.float32,
dtypes.float64, dtypes.int32, dtypes.int64):
dtypes.float64, dtypes.int32, dtypes.int64, dtypes.uint32,
dtypes.uint64):
raise ValueError("Invalid dtype %r" % dtype)
if maxval is None:
if dtype.is_integer:
raise ValueError("Must specify maxval for integer dtype %r" % dtype)
if dtype.is_integer:
if (minval is None) != (maxval is None):
raise ValueError("For integer dtype {}, minval and maxval must be both "
"`None` or both non-`None`.".format(dtype))
if minval is not None and dtype in (dtypes.uint32, dtypes.uint64):
raise ValueError("Invalid dtype for bounded uniform integers: %r" % dtype)
elif maxval is None:
maxval = 1
with ops.name_scope(name, "stateless_random_uniform",
[shape, seed, minval, maxval]) as name:
shape = tensor_util.shape_tensor(shape)
minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
if dtype.is_integer:
result = gen_stateless_random_ops.stateless_random_uniform_int(
shape, seed=seed, minval=minval, maxval=maxval, name=name)
if dtype.is_integer and minval is None:
result = gen_stateless_random_ops.stateless_random_uniform_full_int(
shape, seed=seed, dtype=dtype, name=name)
else:
rnd = gen_stateless_random_ops.stateless_random_uniform(
shape, seed=seed, dtype=dtype)
result = math_ops.add(rnd * (maxval - minval), minval, name=name)
minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
if dtype.is_integer:
result = gen_stateless_random_ops.stateless_random_uniform_int(
shape, seed=seed, minval=minval, maxval=maxval, name=name)
else:
rnd = gen_stateless_random_ops.stateless_random_uniform(
shape, seed=seed, dtype=dtype)
result = math_ops.add(rnd * (maxval - minval), minval, name=name)
tensor_util.maybe_set_static_shape(result, shape)
return result

View File

@ -4308,6 +4308,10 @@ tf_module {
name: "StatelessRandomUniform"
argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "StatelessRandomUniformFullInt"
argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'uint64\'>\", \'None\'], "
}
member_method {
name: "StatelessRandomUniformInt"
argspec: "args=[\'shape\', \'seed\', \'minval\', \'maxval\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -4308,6 +4308,10 @@ tf_module {
name: "StatelessRandomUniform"
argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "StatelessRandomUniformFullInt"
argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'uint64\'>\", \'None\'], "
}
member_method {
name: "StatelessRandomUniformInt"
argspec: "args=[\'shape\', \'seed\', \'minval\', \'maxval\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "