patch cl/312773551
PR #38585: Fix invalid shape issue in random.uniform Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/38585 Note: This PR is a resubmission from #34399 This PR tries to address the issue raised in #34363 where invalid shape passed to minval/maxval (expected to be 0-D) does not raise an error. The issue was that in most of the scenarios the shape was checked inside the C++ kernel ops. However, in one condition math_ops.add was used which will implicitly do broadcast when necessarily. This results in maxval/minval's shape getting carried. This PR adds the shape check before math_ops.add, to make sure the shape is guaranteed. This PR fixes #34363. Signed-off-by: Yong Tang yong.tang.github@outlook.com Copybara import of the project: -- 1c480a2175ed7d8a86210882bfbb0ed45f0730d6 by Yong Tang <yong.tang.github@outlook.com>: Fix invalid shape issue in random.uniform This PR tries to address the issue raised in 34363 where invalid shape passed to minval/maxval (expected to be 0-D) does not raise an error. The issue was that in most of the scenarios the shape was checked inside the C++ kernel ops. However, in one condition math_ops.add was used which will implicitly do broadcast when necessarily. This results in maxval/minval's shape getting carried. This PR adds the shape check before math_ops.add, to make sure the shape is guaranteed. This PR fixes 34363. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> -- 81dca0016c1efbfc99d7f22e2ac6d26e0c5099b5 by Yong Tang <yong.tang.github@outlook.com>: Add test case for invalid shape issue in random.uniform Signed-off-by: Yong Tang <yong.tang.github@outlook.com> -- be3dee4337f45883326536bc2fad7539cd1a2244 by Yong Tang <yong.tang.github@outlook.com>: Use explicit broadcast_to to prevent shape overflow Signed-off-by: Yong Tang <yong.tang.github@outlook.com> RELNOTES=n/a PiperOrigin-RevId: 313446121 Change-Id: I34b076d79c13a7db040bf46aa5b2f2b43075c55f
This commit is contained in:
parent
23971655a4
commit
426869b50f
@ -23,6 +23,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.framework import test_util
|
||||
@ -415,6 +416,13 @@ class RandomUniformTest(RandomOpTestCommon):
|
||||
use_gpu=use_gpu,
|
||||
graph_seed=965)
|
||||
|
||||
def testUniformWithInvalidMaxMindShape(self):
|
||||
# Test case for GitHub issue 34363.
|
||||
with self.assertRaises(
|
||||
(errors.InvalidArgumentError, errors.UnknownError, ValueError)):
|
||||
array = array_ops.zeros(shape=(1,))
|
||||
random_ops.random_uniform(shape=(), minval=array)
|
||||
|
||||
|
||||
class RandomShapeTest(test.TestCase):
|
||||
|
||||
|
@ -304,6 +304,12 @@ def random_uniform(shape,
|
||||
if not maxval_is_one:
|
||||
result = math_ops.multiply(result, maxval)
|
||||
else:
|
||||
# Use explicit "broadcast_to" so that any shape incompatibility
|
||||
# are returned with InvalidArgument error.
|
||||
# This prevent "slient broadcast" that may cause the shape of
|
||||
# result "overflow" when minval or maxval is larger than expected shape
|
||||
maxval = array_ops.broadcast_to(maxval, shape)
|
||||
minval = array_ops.broadcast_to(minval, shape)
|
||||
result = math_ops.add(result * (maxval - minval), minval, name=name)
|
||||
# TODO(b/132092188): C++ shape inference inside functional ops does not
|
||||
# cross FuncGraph boundaries since that information is only available in
|
||||
|
Loading…
Reference in New Issue
Block a user