Allow np dtypes as args to random_gamma. Previously, failed with AttributeError: type object 'numpy.float32' has no attribute 'as_numpy_dtype'
PiperOrigin-RevId: 263783739
This commit is contained in:
parent
50703da8ab
commit
c0219bdebd
@ -53,6 +53,10 @@ class RandomGammaTest(test.TestCase):
|
|||||||
|
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
def testNpDtypes(self):
|
||||||
|
self.evaluate(random_ops.random_gamma(
|
||||||
|
[5], alpha=np.ones([2, 1, 3]), beta=np.ones([3]), dtype=np.float32))
|
||||||
|
|
||||||
def testEmptySamplingNoError(self):
|
def testEmptySamplingNoError(self):
|
||||||
self.evaluate(random_ops.random_gamma(
|
self.evaluate(random_ops.random_gamma(
|
||||||
[5], alpha=np.ones([2, 0, 3]), beta=np.ones([3]), dtype=dtypes.float32))
|
[5], alpha=np.ones([2, 0, 3]), beta=np.ones([3]), dtype=dtypes.float32))
|
||||||
|
@ -489,7 +489,7 @@ def random_gamma(shape,
|
|||||||
alpha_broadcast = alpha + array_ops.zeros_like(beta)
|
alpha_broadcast = alpha + array_ops.zeros_like(beta)
|
||||||
seed1, seed2 = random_seed.get_seed(seed)
|
seed1, seed2 = random_seed.get_seed(seed)
|
||||||
result = math_ops.maximum(
|
result = math_ops.maximum(
|
||||||
np.finfo(dtype.as_numpy_dtype).tiny,
|
np.finfo(alpha.dtype.as_numpy_dtype).tiny,
|
||||||
gen_random_ops.random_gamma(
|
gen_random_ops.random_gamma(
|
||||||
shape, alpha_broadcast, seed=seed1, seed2=seed2) / beta)
|
shape, alpha_broadcast, seed=seed1, seed2=seed2) / beta)
|
||||||
_maybe_set_static_shape_helper(result, shape, alpha_broadcast)
|
_maybe_set_static_shape_helper(result, shape, alpha_broadcast)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user