RandomGamma should return an empty Tensor for empty alpha.

Before this change, it returns an error "Input alpha should have non-zero element count" if empty `alpha` is presented with non-empty `shape`.

PiperOrigin-RevId: 239255074
This commit is contained in:
Brian Patton 2019-03-19 13:20:21 -07:00 committed by TensorFlower Gardener
parent 9369d376e7
commit a24cfc8d09
2 changed files with 5 additions and 1 deletions

View File

@ -173,7 +173,7 @@ class RandomGammaOp : public OpKernel {
Tensor* samples_t = nullptr; Tensor* samples_t = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, samples_shape, &samples_t)); OP_REQUIRES_OK(ctx, ctx->allocate_output(0, samples_shape, &samples_t));
if (num_samples == 0) return; if (samples_shape.num_elements() == 0) return;
using random::PhiloxRandom; using random::PhiloxRandom;

View File

@ -53,6 +53,10 @@ class RandomGammaTest(test.TestCase):
return func return func
def testEmptySamplingNoError(self):
self.evaluate(random_ops.random_gamma(
[5], alpha=np.ones([2, 0, 3]), beta=np.ones([3]), dtype=dtypes.float32))
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testMomentsFloat32(self): def testMomentsFloat32(self):
self._testMoments(dtypes.float32) self._testMoments(dtypes.float32)