diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index b06a6f9a988..0b0d8ea33c4 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1996,6 +1996,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() { "StatelessRandomNormal", "StatelessRandomUniform", "StatelessRandomUniformInt", + "StatelessRandomUniformFullInt", "StatelessTruncatedNormal", "StatelessWhile", "Svd", diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index 38f7f1c89b2..13c3dbe489e 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -111,6 +111,35 @@ xla::XlaOp StatelessRngUniform(absl::string_view device_type_string, } } +xla::XlaOp StatelessRngUniformFullInt(absl::string_view device_type_string, + xla::XlaOp seeds, + const xla::Shape& shape) { + xla::XlaBuilder* builder = seeds.builder(); + + xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {}); + xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {}); + xla::XlaOp key = ConvertElementType(seed0, xla::U64) | + ShiftLeft(ConvertElementType(seed1, xla::U64), + ConstantR0WithType(builder, xla::U64, 32)); + xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0); + xla::PrimitiveType type = shape.element_type(); + xla::RngOutput output = + GetBitGeneratorForDevice(device_type_string)(key, initial_state, shape); + switch (type) { + case xla::U32: + case xla::U64: + return output.value; + case xla::S32: + case xla::S64: + return BitcastConvertType(output.value, type); + default: + return builder->ReportError(xla::Unimplemented( + "Types other than U32, S32, U64 and S64 are not implemented by " + "StatelessRngUniformFullInt; got: %s", + 
xla::primitive_util::LowercasePrimitiveTypeName(type))); + } +} + namespace { class StatelessRandomUniformOp : public XlaOpKernel { @@ -211,6 +240,47 @@ REGISTER_XLA_OP(Name("StatelessRandomUniformInt") .TypeConstraint("Tseed", DT_INT32), StatelessRandomUniformIntOp); +class StatelessRandomUniformFullIntOp : public XlaOpKernel { + public: + explicit StatelessRandomUniformFullIntOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx), + device_type_string_(ctx->device_type().type_string()) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); + + TensorShape seed_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2, + errors::InvalidArgument("seed must have shape [2], not ", + seed_shape.DebugString())); + + xla::XlaOp seed = ctx->Input(1); + + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape)); + xla::XlaOp uniform = + StatelessRngUniformFullInt(device_type_string_, seed, xla_shape); + + ctx->SetOutput(0, uniform); + } + + private: + DataType dtype_; + string device_type_string_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformFullIntOp); +}; + +// TODO(phawkins): generalize to non-int32 seed types. 
+REGISTER_XLA_OP(Name("StatelessRandomUniformFullInt") + .CompileTimeConstantInput("shape") + .TypeConstraint("dtype", {DT_INT32, DT_INT64}) + .TypeConstraint("Tseed", DT_INT32), + StatelessRandomUniformFullIntOp); + class StatelessRandomNormalOp : public XlaOpKernel { public: explicit StatelessRandomNormalOp(OpKernelConstruction* ctx) diff --git a/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformFullInt.pbtxt b/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformFullInt.pbtxt new file mode 100644 index 00000000000..26bb6781add --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformFullInt.pbtxt @@ -0,0 +1,34 @@ +op { + graph_op_name: "StatelessRandomUniformFullInt" + visibility: HIDDEN + in_arg { + name: "shape" + description: <<END +The shape of the output tensor. +END + } + in_arg { + name: "seed" + description: <<END +2 seeds (shape [2]). +END + } + out_arg { + name: "output" + description: <<END +Random values with specified shape. +END + } + attr { + name: "dtype" + description: <<END +The type of the output. +END + } + summary: "Outputs deterministic pseudorandom integers from a uniform distribution." + description: <<END +The generated values are uniform integers covering the whole range of `dtype`. + +The outputs are a deterministic function of `shape` and `seed`. 
+END +} diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index b79f0fc4de8..348f7774d58 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -809,6 +809,7 @@ bool IsPinnableOp(const string& op_type) { "RandomStandardNormal", "StatelessRandomUniform", "StatelessRandomUniformInt", + "StatelessRandomUniformFullInt", "StatelessRandomNormal", }); diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc index 945f520506e..7b7f5153436 100644 --- a/tensorflow/core/kernels/random_op.cc +++ b/tensorflow/core/kernels/random_op.cc @@ -364,7 +364,13 @@ class RandomGammaOp : public OpKernel { Name("RandomGamma").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \ RandomGammaOp<TYPE>) +#define REGISTER_FULL_INT(IntType) \ + template struct functor::FillPhiloxRandom< \ + CPUDevice, \ + random::UniformFullIntDistribution<random::PhiloxRandom, IntType>> + #define REGISTER_INT(IntType) \ + REGISTER_FULL_INT(IntType); \ template struct functor::FillPhiloxRandom< \ CPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \ REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \ @@ -381,9 +387,12 @@ TF_CALL_float(REGISTER); TF_CALL_double(REGISTER); TF_CALL_int32(REGISTER_INT); TF_CALL_int64(REGISTER_INT); +TF_CALL_uint32(REGISTER_FULL_INT); +TF_CALL_uint64(REGISTER_FULL_INT); #undef REGISTER #undef REGISTER_INT +#undef REGISTER_FULL_INT #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -415,7 +424,13 @@ TF_CALL_int64(REGISTER_INT); random::TruncatedNormalDistribution< \ random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>); +#define REGISTER_FULL_INT(IntType) \ + template struct functor::FillPhiloxRandom< \ + GPUDevice, \ + random::UniformFullIntDistribution<random::PhiloxRandom, IntType>> + #define REGISTER_INT(IntType) \ + REGISTER_FULL_INT(IntType); \ template struct functor::FillPhiloxRandom< \ GPUDevice, 
random::UniformDistribution<random::PhiloxRandom, IntType>>; \ REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \ @@ -432,9 +447,12 @@ TF_CALL_float(REGISTER); TF_CALL_double(REGISTER); TF_CALL_int32(REGISTER_INT); TF_CALL_int64(REGISTER_INT); +TF_CALL_uint32(REGISTER_FULL_INT); +TF_CALL_uint64(REGISTER_FULL_INT); #undef REGISTER #undef REGISTER_INT +#undef REGISTER_FULL_INT #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/random_op_gpu.cu.cc b/tensorflow/core/kernels/random_op_gpu.cu.cc index 3e8413da8fd..9d7c56e3310 100644 --- a/tensorflow/core/kernels/random_op_gpu.cu.cc +++ b/tensorflow/core/kernels/random_op_gpu.cu.cc @@ -48,6 +48,18 @@ template struct FillPhiloxRandom< GPUDevice, random::UniformDistribution<random::PhiloxRandom, int32> >; template struct FillPhiloxRandom< GPUDevice, random::UniformDistribution<random::PhiloxRandom, int64> >; +template struct FillPhiloxRandom< + GPUDevice, + random::UniformFullIntDistribution<random::PhiloxRandom, int32> >; +template struct FillPhiloxRandom< + GPUDevice, + random::UniformFullIntDistribution<random::PhiloxRandom, int64> >; +template struct FillPhiloxRandom< + GPUDevice, + random::UniformFullIntDistribution<random::PhiloxRandom, uint32> >; +template struct FillPhiloxRandom< + GPUDevice, + random::UniformFullIntDistribution<random::PhiloxRandom, uint64> >; template struct FillPhiloxRandom< GPUDevice, random::NormalDistribution<random::PhiloxRandom, Eigen::half> >; template struct FillPhiloxRandom< diff --git a/tensorflow/core/kernels/stateless_random_ops.cc b/tensorflow/core/kernels/stateless_random_ops.cc index dd451dbc2d5..167daf2ff9e 100644 --- a/tensorflow/core/kernels/stateless_random_ops.cc +++ b/tensorflow/core/kernels/stateless_random_ops.cc @@ -163,6 +163,26 @@ class StatelessRandomUniformIntOp : public StatelessRandomOpBase { } }; +template <typename Device, typename IntType> +class StatelessRandomUniformFullIntOp : public StatelessRandomOpBase { + public: + using 
StatelessRandomOpBase::StatelessRandomOpBase; + + void Fill(OpKernelContext* context, random::PhiloxRandom random, + Tensor* output) override { + // Build distribution + typedef random::UniformFullIntDistribution<random::PhiloxRandom, IntType> + Distribution; + Distribution dist; + + auto flat = output->flat<IntType>(); + // Reuse the compute kernels from the stateful random ops + functor::FillPhiloxRandom<Device, Distribution>()( + context, context->eigen_device<Device>(), random, flat.data(), + flat.size(), dist); + } +}; + // Samples from one or more Poisson distributions. template <typename T, typename U> class StatelessRandomPoissonOp : public StatelessRandomOpBase { @@ -376,7 +396,17 @@ class StatelessRandomGammaOp : public StatelessRandomOpBase { random::TruncatedNormalDistribution< \ random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >) +#define REGISTER_FULL_INT(DEVICE, TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("StatelessRandomUniformFullInt") \ + .Device(DEVICE_##DEVICE) \ + .HostMemory("shape") \ + .HostMemory("seed") \ + .TypeConstraint<TYPE>("dtype"), \ + StatelessRandomUniformFullIntOp<DEVICE##Device, TYPE>) + #define REGISTER_INT(DEVICE, TYPE) \ + REGISTER_FULL_INT(DEVICE, TYPE); \ REGISTER_KERNEL_BUILDER(Name("StatelessRandomUniformInt") \ .Device(DEVICE_##DEVICE) \ .HostMemory("shape") \ @@ -390,6 +420,8 @@ class StatelessRandomGammaOp : public StatelessRandomOpBase { #define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE) #define REGISTER_INT_CPU(TYPE) REGISTER_INT(CPU, TYPE) #define REGISTER_INT_GPU(TYPE) REGISTER_INT(GPU, TYPE) +#define REGISTER_FULL_INT_CPU(TYPE) REGISTER_FULL_INT(CPU, TYPE) +#define REGISTER_FULL_INT_GPU(TYPE) REGISTER_FULL_INT(GPU, TYPE) TF_CALL_half(REGISTER_CPU); TF_CALL_bfloat16(REGISTER_CPU); @@ -397,6 +429,8 @@ TF_CALL_float(REGISTER_CPU); TF_CALL_double(REGISTER_CPU); TF_CALL_int32(REGISTER_INT_CPU); TF_CALL_int64(REGISTER_INT_CPU); +TF_CALL_uint32(REGISTER_FULL_INT_CPU); +TF_CALL_uint64(REGISTER_FULL_INT_CPU); #define 
REGISTER_POISSON(RATE_TYPE, OUT_TYPE) \ REGISTER_KERNEL_BUILDER(Name("StatelessRandomPoisson") \ @@ -447,6 +481,8 @@ TF_CALL_float(REGISTER_GPU); TF_CALL_double(REGISTER_GPU); TF_CALL_int32(REGISTER_INT_GPU); TF_CALL_int64(REGISTER_INT_GPU); +TF_CALL_uint32(REGISTER_FULL_INT_GPU); +TF_CALL_uint64(REGISTER_FULL_INT_GPU); #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -456,6 +492,8 @@ TF_CALL_int64(REGISTER_INT_GPU); #undef REGISTER_GPU #undef REGISTER_INT_CPU #undef REGISTER_INT_GPU +#undef REGISTER_FULL_INT_CPU +#undef REGISTER_FULL_INT_GPU } // namespace diff --git a/tensorflow/core/ops/stateless_random_ops.cc b/tensorflow/core/ops/stateless_random_ops.cc index 27d0b71cf44..83342019ab6 100644 --- a/tensorflow/core/ops/stateless_random_ops.cc +++ b/tensorflow/core/ops/stateless_random_ops.cc @@ -68,6 +68,15 @@ REGISTER_OP("StatelessRandomUniformInt") return StatelessShape(c); }); +REGISTER_OP("StatelessRandomUniformFullInt") + .Input("shape: T") + .Input("seed: Tseed") + .Output("output: dtype") + .Attr("dtype: {int32, int64, uint32, uint64} = DT_UINT64") + .Attr("T: {int32, int64} = DT_INT32") + .Attr("Tseed: {int32, int64, uint32, uint64} = DT_INT64") + .SetShapeFn(StatelessShape); + REGISTER_OP("StatelessMultinomial") .Input("logits: T") .Input("num_samples: int32") diff --git a/tensorflow/python/eager/pywrap_gradient_exclusions.cc b/tensorflow/python/eager/pywrap_gradient_exclusions.cc index 883b9425e7d..6647728828c 100644 --- a/tensorflow/python/eager/pywrap_gradient_exclusions.cc +++ b/tensorflow/python/eager/pywrap_gradient_exclusions.cc @@ -312,6 +312,7 @@ bool OpGradientDoesntRequireInputIndices( {"StatelessRandomNormal", {true, {}}}, {"StatelessRandomPoisson", {true, {}}}, {"StatelessRandomUniform", {true, {}}}, + {"StatelessRandomUniformFullInt", {true, {}}}, {"StatelessRandomUniformInt", {true, {}}}, {"StatelessTruncatedNormal", {true, {}}}, {"StopGradient", {true, {}}}, @@ -768,6 +769,7 @@ bool OpGradientDoesntRequireOutputIndices( 
{"StatelessRandomNormal", {true, {}}}, {"StatelessRandomPoisson", {true, {}}}, {"StatelessRandomUniform", {true, {}}}, + {"StatelessRandomUniformFullInt", {true, {}}}, {"StatelessRandomUniformInt", {true, {}}}, {"StatelessTruncatedNormal", {true, {}}}, {"StopGradient", {true, {}}}, diff --git a/tensorflow/python/ops/stateless_random_ops.py b/tensorflow/python/ops/stateless_random_ops.py index 94c80ea002b..eb08383ba6c 100644 --- a/tensorflow/python/ops/stateless_random_ops.py +++ b/tensorflow/python/ops/stateless_random_ops.py @@ -35,6 +35,7 @@ ops.NotDifferentiable("StatelessRandomNormal") ops.NotDifferentiable("StatelessRandomPoisson") ops.NotDifferentiable("StatelessRandomUniform") ops.NotDifferentiable("StatelessRandomUniformInt") +ops.NotDifferentiable("StatelessRandomUniformFullInt") ops.NotDifferentiable("StatelessTruncatedNormal") @@ -65,44 +66,65 @@ def stateless_random_uniform(shape, `maxval - minval` significantly smaller than the range of the output (either `2**32` or `2**64`). + For full-range (i.e. inclusive of both max and min) random integers, pass + `minval=None` and `maxval=None` with an integer `dtype`. For an integer dtype + either both `minval` and `maxval` must be `None` or neither may be `None`. For + example: + ```python + ints = tf.random.stateless_uniform( + [10], seed=(2, 3), minval=None, maxval=None, dtype=tf.int32) + ``` + Args: shape: A 1-D integer Tensor or Python array. The shape of the output tensor. seed: A shape [2] integer Tensor of seeds to the random number generator. minval: A 0-D Tensor or Python value of type `dtype`. The lower bound on the - range of random values to generate. Defaults to 0. + range of random values to generate. Pass `None` for full-range integers. + Defaults to 0. maxval: A 0-D Tensor or Python value of type `dtype`. The upper bound on the range of random values to generate. Defaults to 1 if `dtype` is floating - point. + point. Pass `None` for full-range integers. 
dtype: The type of the output: `float16`, `float32`, `float64`, `int32`, or - `int64`. + `int64`. For unbounded uniform ints (`minval`, `maxval` both `None`), + `uint32` and `uint64` may be used. name: A name for the operation (optional). Returns: A tensor of the specified shape filled with random uniform values. Raises: - ValueError: If `dtype` is integral and `maxval` is not specified. + ValueError: If `dtype` is integral and only one of `minval` or `maxval` is + specified. """ dtype = dtypes.as_dtype(dtype) if dtype not in (dtypes.float16, dtypes.bfloat16, dtypes.float32, - dtypes.float64, dtypes.int32, dtypes.int64): + dtypes.float64, dtypes.int32, dtypes.int64, dtypes.uint32, + dtypes.uint64): raise ValueError("Invalid dtype %r" % dtype) - if maxval is None: - if dtype.is_integer: - raise ValueError("Must specify maxval for integer dtype %r" % dtype) + if dtype.is_integer: + if (minval is None) != (maxval is None): + raise ValueError("For integer dtype {}, minval and maxval must be both " + "`None` or both non-`None`.".format(dtype)) + if minval is not None and dtype in (dtypes.uint32, dtypes.uint64): + raise ValueError("Invalid dtype for bounded uniform integers: %r" % dtype) + elif maxval is None: maxval = 1 with ops.name_scope(name, "stateless_random_uniform", [shape, seed, minval, maxval]) as name: shape = tensor_util.shape_tensor(shape) - minval = ops.convert_to_tensor(minval, dtype=dtype, name="min") - maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max") - if dtype.is_integer: - result = gen_stateless_random_ops.stateless_random_uniform_int( - shape, seed=seed, minval=minval, maxval=maxval, name=name) + if dtype.is_integer and minval is None: + result = gen_stateless_random_ops.stateless_random_uniform_full_int( + shape, seed=seed, dtype=dtype, name=name) else: - rnd = gen_stateless_random_ops.stateless_random_uniform( - shape, seed=seed, dtype=dtype) - result = math_ops.add(rnd * (maxval - minval), minval, name=name) + minval = 
ops.convert_to_tensor(minval, dtype=dtype, name="min") + maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max") + if dtype.is_integer: + result = gen_stateless_random_ops.stateless_random_uniform_int( + shape, seed=seed, minval=minval, maxval=maxval, name=name) + else: + rnd = gen_stateless_random_ops.stateless_random_uniform( + shape, seed=seed, dtype=dtype) + result = math_ops.add(rnd * (maxval - minval), minval, name=name) tensor_util.maybe_set_static_shape(result, shape) return result diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index dded95f465b..2229fe317c3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -4308,6 +4308,10 @@ tf_module { name: "StatelessRandomUniform" argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], " } + member_method { + name: "StatelessRandomUniformFullInt" + argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'uint64\'>\", \'None\'], " + } member_method { name: "StatelessRandomUniformInt" argspec: "args=[\'shape\', \'seed\', \'minval\', \'maxval\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index dded95f465b..2229fe317c3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -4308,6 +4308,10 @@ tf_module { name: "StatelessRandomUniform" argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], " } + member_method { + name: "StatelessRandomUniformFullInt" + argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, 
keywords=None, defaults=[\"<dtype: \'uint64\'>\", \'None\'], " + } member_method { name: "StatelessRandomUniformInt" argspec: "args=[\'shape\', \'seed\', \'minval\', \'maxval\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "