Fixes doc and error messages about non-scalar minval/maxval in integer uniform RNGs.
PiperOrigin-RevId: 299235545 Change-Id: If70d2abbb777f8add2e459fa5cb1ddc9008a6238
This commit is contained in:
parent
692e9bf05c
commit
07eff99852
@ -45,8 +45,18 @@ REGISTER_OP("RandomUniformInt")
|
||||
.Attr("T: {int32, int64}")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
|
||||
Status s = c->WithRank(c->input(1), 0, &unused);
|
||||
if (!s.ok()) {
|
||||
return errors::InvalidArgument(
|
||||
"minval must be a scalar; got a tensor of shape ",
|
||||
c->DebugString(c->input(1)));
|
||||
}
|
||||
s = c->WithRank(c->input(2), 0, &unused);
|
||||
if (!s.ok()) {
|
||||
return errors::InvalidArgument(
|
||||
"maxval must be a scalar; got a tensor of shape ",
|
||||
c->DebugString(c->input(2)));
|
||||
}
|
||||
return shape_inference::RandomShape(c);
|
||||
});
|
||||
|
||||
|
@ -60,8 +60,18 @@ REGISTER_OP("StatefulUniformInt")
|
||||
// Check inputs
|
||||
ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
|
||||
Status s = c->WithRank(c->input(3), 0, &unused);
|
||||
if (!s.ok()) {
|
||||
return errors::InvalidArgument(
|
||||
"minval must be a scalar; got a tensor of shape ",
|
||||
c->DebugString(c->input(3)));
|
||||
}
|
||||
s = c->WithRank(c->input(4), 0, &unused);
|
||||
if (!s.ok()) {
|
||||
return errors::InvalidArgument(
|
||||
"maxval must be a scalar; got a tensor of shape ",
|
||||
c->DebugString(c->input(4)));
|
||||
}
|
||||
// Set output
|
||||
ShapeHandle out;
|
||||
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &out));
|
||||
|
@ -63,8 +63,18 @@ REGISTER_OP("StatelessRandomUniformInt")
|
||||
.Attr("Tseed: {int32, int64} = DT_INT64")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
|
||||
Status s = c->WithRank(c->input(2), 0, &unused);
|
||||
if (!s.ok()) {
|
||||
return errors::InvalidArgument(
|
||||
"minval must be a scalar; got a tensor of shape ",
|
||||
c->DebugString(c->input(2)));
|
||||
}
|
||||
s = c->WithRank(c->input(3), 0, &unused);
|
||||
if (!s.ok()) {
|
||||
return errors::InvalidArgument(
|
||||
"maxval must be a scalar; got a tensor of shape ",
|
||||
c->DebugString(c->input(3)));
|
||||
}
|
||||
return StatelessShape(c);
|
||||
});
|
||||
|
||||
|
@ -304,11 +304,11 @@ class RandomUniformTest(RandomOpTestCommon):
|
||||
def testUniformIntsWithInvalidShape(self):
|
||||
for dtype in dtypes.int32, dtypes.int64:
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Shape must be rank 0 but is rank 1"):
|
||||
ValueError, "minval must be a scalar; got a tensor of shape"):
|
||||
random_ops.random_uniform(
|
||||
[1000], minval=[1, 2], maxval=3, dtype=dtype)
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Shape must be rank 0 but is rank 1"):
|
||||
ValueError, "maxval must be a scalar; got a tensor of shape"):
|
||||
random_ops.random_uniform(
|
||||
[1000], minval=1, maxval=[2, 3], dtype=dtype)
|
||||
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
import functools
|
||||
|
||||
import numpy as np
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import random_seed
|
||||
@ -193,6 +194,31 @@ class StatelessOpsTest(test.TestCase):
|
||||
def testDeterminismPoisson(self):
|
||||
self._test_determinism(self._poisson_cases())
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testErrors(self):
|
||||
"""Tests that proper errors are raised.
|
||||
"""
|
||||
shape = [2, 3]
|
||||
with self.assertRaisesWithPredicateMatch(
|
||||
ValueError,
|
||||
'minval must be a scalar; got a tensor of shape '):
|
||||
@def_function.function
|
||||
def f():
|
||||
stateless.stateless_random_uniform(
|
||||
shape=shape, seed=[1, 2], minval=array_ops.zeros(shape, 'int32'),
|
||||
maxval=100, dtype='int32')
|
||||
f()
|
||||
with self.assertRaisesWithPredicateMatch(
|
||||
ValueError,
|
||||
'maxval must be a scalar; got a tensor of shape '):
|
||||
@def_function.function
|
||||
def f2():
|
||||
stateless.stateless_random_uniform(
|
||||
shape=shape, seed=[1, 2], minval=0,
|
||||
maxval=array_ops.ones(shape, 'int32') * 100,
|
||||
dtype='int32')
|
||||
f2()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -253,10 +253,12 @@ def random_uniform(shape,
|
||||
Args:
|
||||
shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
|
||||
minval: A Tensor or Python value of type `dtype`, broadcastable with
|
||||
`maxval`. The lower bound on the range of random values to generate
|
||||
`shape` (for integer types, broadcasting is not supported, so it needs to
|
||||
be a scalar). The lower bound on the range of random values to generate
|
||||
(inclusive). Defaults to 0.
|
||||
maxval: A Tensor or Python value of type `dtype`, broadcastable with
|
||||
`minval`. The upper bound on the range of random values to generate
|
||||
`shape` (for integer types, broadcasting is not supported, so it needs to
|
||||
be a scalar). The upper bound on the range of random values to generate
|
||||
(exclusive). Defaults to 1 if `dtype` is floating point.
|
||||
dtype: The type of the output: `float16`, `float32`, `float64`, `int32`,
|
||||
or `int64`.
|
||||
|
@ -687,11 +687,14 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
Args:
|
||||
shape: A 1-D integer Tensor or Python array. The shape of the output
|
||||
tensor.
|
||||
minval: A 0-D Tensor or Python value of type `dtype`. The lower bound on
|
||||
the range of random values to generate. Defaults to 0.
|
||||
maxval: A 0-D Tensor or Python value of type `dtype`. The upper bound on
|
||||
the range of random values to generate. Defaults to 1 if `dtype` is
|
||||
floating point.
|
||||
minval: A Tensor or Python value of type `dtype`, broadcastable with
|
||||
`shape` (for integer types, broadcasting is not supported, so it needs
|
||||
to be a scalar). The lower bound on the range of random values to
|
||||
generate. Defaults to 0.
|
||||
maxval: A Tensor or Python value of type `dtype`, broadcastable with
|
||||
`shape` (for integer types, broadcasting is not supported, so it needs
|
||||
to be a scalar). The upper bound on the range of random values to
|
||||
generate. Defaults to 1 if `dtype` is floating point.
|
||||
dtype: The type of the output.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
|
@ -561,6 +561,23 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
"For the Philox algorithm, the size of state must be at least"):
|
||||
gen_stateful_random_ops.stateful_standard_normal_v2(
|
||||
var.handle, random.RNG_ALG_PHILOX, shape)
|
||||
with self.assertRaisesWithPredicateMatch(
|
||||
ValueError,
|
||||
"minval must be a scalar; got a tensor of shape "):
|
||||
@def_function.function
|
||||
def f():
|
||||
gen.uniform(shape=shape, minval=array_ops.zeros(shape, "int32"),
|
||||
maxval=100, dtype="int32")
|
||||
f()
|
||||
with self.assertRaisesWithPredicateMatch(
|
||||
ValueError,
|
||||
"maxval must be a scalar; got a tensor of shape "):
|
||||
@def_function.function
|
||||
def f2():
|
||||
gen.uniform(
|
||||
shape=shape, minval=0, maxval=array_ops.ones(shape, "int32") * 100,
|
||||
dtype="int32")
|
||||
f2()
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testGetGlobalGeneratorWithXla(self):
|
||||
|
@ -78,12 +78,15 @@ def stateless_random_uniform(shape,
|
||||
Args:
|
||||
shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
|
||||
seed: A shape [2] integer Tensor of seeds to the random number generator.
|
||||
minval: A 0-D Tensor or Python value of type `dtype`. The lower bound on the
|
||||
range of random values to generate. Pass `None` for full-range integers.
|
||||
Defaults to 0.
|
||||
maxval: A 0-D Tensor or Python value of type `dtype`. The upper bound on the
|
||||
range of random values to generate. Defaults to 1 if `dtype` is floating
|
||||
point. Pass `None` for full-range integers.
|
||||
minval: A Tensor or Python value of type `dtype`, broadcastable with
|
||||
`shape` (for integer types, broadcasting is not supported, so it needs to
|
||||
be a scalar). The lower bound on the range of random values to
|
||||
generate. Pass `None` for full-range integers. Defaults to 0.
|
||||
maxval: A Tensor or Python value of type `dtype`, broadcastable with
|
||||
`shape` (for integer types, broadcasting is not supported, so it needs to
|
||||
be a scalar). The upper bound on the range of random values to generate.
|
||||
Defaults to 1 if `dtype` is floating point. Pass `None` for full-range
|
||||
integers.
|
||||
dtype: The type of the output: `float16`, `float32`, `float64`, `int32`, or
|
||||
`int64`. For unbounded uniform ints (`minval`, `maxval` both `None`),
|
||||
`uint32` and `uint64` may be used.
|
||||
|
Loading…
x
Reference in New Issue
Block a user