RandomGamma should return an empty Tensor for empty alpha.
Before this change, it returns an error "Input alpha should have non-zero element count" if empty `alpha` is presented with non-empty `shape`. PiperOrigin-RevId: 239255074
This commit is contained in:
parent
9369d376e7
commit
a24cfc8d09
@ -173,7 +173,7 @@ class RandomGammaOp : public OpKernel {
|
||||
Tensor* samples_t = nullptr;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, samples_shape, &samples_t));
|
||||
|
||||
if (num_samples == 0) return;
|
||||
if (samples_shape.num_elements() == 0) return;
|
||||
|
||||
using random::PhiloxRandom;
|
||||
|
||||
|
@ -53,6 +53,10 @@ class RandomGammaTest(test.TestCase):
|
||||
|
||||
return func
|
||||
|
||||
def testEmptySamplingNoError(self):
|
||||
self.evaluate(random_ops.random_gamma(
|
||||
[5], alpha=np.ones([2, 0, 3]), beta=np.ones([3]), dtype=dtypes.float32))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testMomentsFloat32(self):
|
||||
self._testMoments(dtypes.float32)
|
||||
|
Loading…
Reference in New Issue
Block a user