Fixes doc and error messages about non-scalar minval/maxval in integer uniform RNGs.

PiperOrigin-RevId: 299235545
Change-Id: If70d2abbb777f8add2e459fa5cb1ddc9008a6238
This commit is contained in:
Peng Wang 2020-03-05 17:25:17 -08:00 committed by TensorFlower Gardener
parent 692e9bf05c
commit 07eff99852
9 changed files with 102 additions and 21 deletions

View File

@ -45,8 +45,18 @@ REGISTER_OP("RandomUniformInt")
.Attr("T: {int32, int64}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
Status s = c->WithRank(c->input(1), 0, &unused);
if (!s.ok()) {
return errors::InvalidArgument(
"minval must be a scalar; got a tensor of shape ",
c->DebugString(c->input(1)));
}
s = c->WithRank(c->input(2), 0, &unused);
if (!s.ok()) {
return errors::InvalidArgument(
"maxval must be a scalar; got a tensor of shape ",
c->DebugString(c->input(2)));
}
return shape_inference::RandomShape(c);
});

View File

@ -60,8 +60,18 @@ REGISTER_OP("StatefulUniformInt")
// Check inputs
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
Status s = c->WithRank(c->input(3), 0, &unused);
if (!s.ok()) {
return errors::InvalidArgument(
"minval must be a scalar; got a tensor of shape ",
c->DebugString(c->input(3)));
}
s = c->WithRank(c->input(4), 0, &unused);
if (!s.ok()) {
return errors::InvalidArgument(
"maxval must be a scalar; got a tensor of shape ",
c->DebugString(c->input(4)));
}
// Set output
ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &out));

View File

@ -63,8 +63,18 @@ REGISTER_OP("StatelessRandomUniformInt")
.Attr("Tseed: {int32, int64} = DT_INT64")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
Status s = c->WithRank(c->input(2), 0, &unused);
if (!s.ok()) {
return errors::InvalidArgument(
"minval must be a scalar; got a tensor of shape ",
c->DebugString(c->input(2)));
}
s = c->WithRank(c->input(3), 0, &unused);
if (!s.ok()) {
return errors::InvalidArgument(
"maxval must be a scalar; got a tensor of shape ",
c->DebugString(c->input(3)));
}
return StatelessShape(c);
});

View File

@ -304,11 +304,11 @@ class RandomUniformTest(RandomOpTestCommon):
def testUniformIntsWithInvalidShape(self):
for dtype in dtypes.int32, dtypes.int64:
with self.assertRaisesRegexp(
ValueError, "Shape must be rank 0 but is rank 1"):
ValueError, "minval must be a scalar; got a tensor of shape"):
random_ops.random_uniform(
[1000], minval=[1, 2], maxval=3, dtype=dtype)
with self.assertRaisesRegexp(
ValueError, "Shape must be rank 0 but is rank 1"):
ValueError, "maxval must be a scalar; got a tensor of shape"):
random_ops.random_uniform(
[1000], minval=1, maxval=[2, 3], dtype=dtype)

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import functools
import numpy as np
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
@ -193,6 +194,31 @@ class StatelessOpsTest(test.TestCase):
def testDeterminismPoisson(self):
self._test_determinism(self._poisson_cases())
@test_util.run_v2_only
def testErrors(self):
"""Tests that proper errors are raised.
"""
shape = [2, 3]
with self.assertRaisesWithPredicateMatch(
ValueError,
'minval must be a scalar; got a tensor of shape '):
@def_function.function
def f():
stateless.stateless_random_uniform(
shape=shape, seed=[1, 2], minval=array_ops.zeros(shape, 'int32'),
maxval=100, dtype='int32')
f()
with self.assertRaisesWithPredicateMatch(
ValueError,
'maxval must be a scalar; got a tensor of shape '):
@def_function.function
def f2():
stateless.stateless_random_uniform(
shape=shape, seed=[1, 2], minval=0,
maxval=array_ops.ones(shape, 'int32') * 100,
dtype='int32')
f2()
if __name__ == '__main__':
test.main()

View File

@ -253,10 +253,12 @@ def random_uniform(shape,
Args:
shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
minval: A Tensor or Python value of type `dtype`, broadcastable with
`maxval`. The lower bound on the range of random values to generate
`shape` (for integer types, broadcasting is not supported, so it needs to
be a scalar). The lower bound on the range of random values to generate
(inclusive). Defaults to 0.
maxval: A Tensor or Python value of type `dtype`, broadcastable with
`minval`. The upper bound on the range of random values to generate
`shape` (for integer types, broadcasting is not supported, so it needs to
be a scalar). The upper bound on the range of random values to generate
(exclusive). Defaults to 1 if `dtype` is floating point.
dtype: The type of the output: `float16`, `float32`, `float64`, `int32`,
or `int64`.

View File

@ -687,11 +687,14 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
Args:
shape: A 1-D integer Tensor or Python array. The shape of the output
tensor.
minval: A 0-D Tensor or Python value of type `dtype`. The lower bound on
the range of random values to generate. Defaults to 0.
maxval: A 0-D Tensor or Python value of type `dtype`. The upper bound on
the range of random values to generate. Defaults to 1 if `dtype` is
floating point.
minval: A Tensor or Python value of type `dtype`, broadcastable with
`shape` (for integer types, broadcasting is not supported, so it needs
to be a scalar). The lower bound on the range of random values to
generate. Defaults to 0.
maxval: A Tensor or Python value of type `dtype`, broadcastable with
`shape` (for integer types, broadcasting is not supported, so it needs
to be a scalar). The upper bound on the range of random values to
generate. Defaults to 1 if `dtype` is floating point.
dtype: The type of the output.
name: A name for the operation (optional).

View File

@ -561,6 +561,23 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
"For the Philox algorithm, the size of state must be at least"):
gen_stateful_random_ops.stateful_standard_normal_v2(
var.handle, random.RNG_ALG_PHILOX, shape)
with self.assertRaisesWithPredicateMatch(
ValueError,
"minval must be a scalar; got a tensor of shape "):
@def_function.function
def f():
gen.uniform(shape=shape, minval=array_ops.zeros(shape, "int32"),
maxval=100, dtype="int32")
f()
with self.assertRaisesWithPredicateMatch(
ValueError,
"maxval must be a scalar; got a tensor of shape "):
@def_function.function
def f2():
gen.uniform(
shape=shape, minval=0, maxval=array_ops.ones(shape, "int32") * 100,
dtype="int32")
f2()
@test_util.run_v2_only
def testGetGlobalGeneratorWithXla(self):

View File

@ -78,12 +78,15 @@ def stateless_random_uniform(shape,
Args:
shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
seed: A shape [2] integer Tensor of seeds to the random number generator.
minval: A 0-D Tensor or Python value of type `dtype`. The lower bound on the
range of random values to generate. Pass `None` for full-range integers.
Defaults to 0.
maxval: A 0-D Tensor or Python value of type `dtype`. The upper bound on the
range of random values to generate. Defaults to 1 if `dtype` is floating
point. Pass `None` for full-range integers.
minval: A Tensor or Python value of type `dtype`, broadcastable with
`shape` (for integer types, broadcasting is not supported, so it needs to
be a scalar). The lower bound on the range of random values to
generate. Pass `None` for full-range integers. Defaults to 0.
maxval: A Tensor or Python value of type `dtype`, broadcastable with
`shape` (for integer types, broadcasting is not supported, so it needs to
be a scalar). The upper bound on the range of random values to generate.
Defaults to 1 if `dtype` is floating point. Pass `None` for full-range
integers.
dtype: The type of the output: `float16`, `float32`, `float64`, `int32`, or
`int64`. For unbounded uniform ints (`minval`, `maxval` both `None`),
`uint32` and `uint64` may be used.