Allow np dtypes as args to random_gamma. Previously, failed with AttributeError: type object 'numpy.float32' has no attribute 'as_numpy_dtype'

PiperOrigin-RevId: 263783739
This commit is contained in:
Brian Patton 2019-08-16 09:16:59 -07:00 committed by TensorFlower Gardener
parent 50703da8ab
commit c0219bdebd
2 changed files with 5 additions and 1 deletions

View File

@ -53,6 +53,10 @@ class RandomGammaTest(test.TestCase):
return func
def testNpDtypes(self):
self.evaluate(random_ops.random_gamma(
[5], alpha=np.ones([2, 1, 3]), beta=np.ones([3]), dtype=np.float32))
def testEmptySamplingNoError(self):
self.evaluate(random_ops.random_gamma(
[5], alpha=np.ones([2, 0, 3]), beta=np.ones([3]), dtype=dtypes.float32))

View File

@ -489,7 +489,7 @@ def random_gamma(shape,
alpha_broadcast = alpha + array_ops.zeros_like(beta)
seed1, seed2 = random_seed.get_seed(seed)
result = math_ops.maximum(
np.finfo(dtype.as_numpy_dtype).tiny,
np.finfo(alpha.dtype.as_numpy_dtype).tiny,
gen_random_ops.random_gamma(
shape, alpha_broadcast, seed=seed1, seed2=seed2) / beta)
_maybe_set_static_shape_helper(result, shape, alpha_broadcast)