Allow np dtypes as args to random_gamma. Previously, failed with AttributeError: type object 'numpy.float32' has no attribute 'as_numpy_dtype'
PiperOrigin-RevId: 263783739
This commit is contained in:
parent
50703da8ab
commit
c0219bdebd
@ -53,6 +53,10 @@ class RandomGammaTest(test.TestCase):
|
||||
|
||||
return func
|
||||
|
||||
def testNpDtypes(self):
|
||||
self.evaluate(random_ops.random_gamma(
|
||||
[5], alpha=np.ones([2, 1, 3]), beta=np.ones([3]), dtype=np.float32))
|
||||
|
||||
def testEmptySamplingNoError(self):
|
||||
self.evaluate(random_ops.random_gamma(
|
||||
[5], alpha=np.ones([2, 0, 3]), beta=np.ones([3]), dtype=dtypes.float32))
|
||||
|
@ -489,7 +489,7 @@ def random_gamma(shape,
|
||||
alpha_broadcast = alpha + array_ops.zeros_like(beta)
|
||||
seed1, seed2 = random_seed.get_seed(seed)
|
||||
result = math_ops.maximum(
|
||||
np.finfo(dtype.as_numpy_dtype).tiny,
|
||||
np.finfo(alpha.dtype.as_numpy_dtype).tiny,
|
||||
gen_random_ops.random_gamma(
|
||||
shape, alpha_broadcast, seed=seed1, seed2=seed2) / beta)
|
||||
_maybe_set_static_shape_helper(result, shape, alpha_broadcast)
|
||||
|
Loading…
Reference in New Issue
Block a user