diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc index 996950b65f3..e39e5f2eb3b 100644 --- a/tensorflow/core/kernels/random_op.cc +++ b/tensorflow/core/kernels/random_op.cc @@ -173,7 +173,7 @@ class RandomGammaOp : public OpKernel { Tensor* samples_t = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, samples_shape, &samples_t)); - if (num_samples == 0) return; + if (samples_shape.num_elements() == 0) return; using random::PhiloxRandom; diff --git a/tensorflow/python/kernel_tests/random/random_gamma_test.py b/tensorflow/python/kernel_tests/random/random_gamma_test.py index 5cc13f67777..2fbfdc0a963 100644 --- a/tensorflow/python/kernel_tests/random/random_gamma_test.py +++ b/tensorflow/python/kernel_tests/random/random_gamma_test.py @@ -53,6 +53,10 @@ class RandomGammaTest(test.TestCase): return func + def testEmptySamplingNoError(self): + self.evaluate(random_ops.random_gamma( + [5], alpha=np.ones([2, 0, 3]), beta=np.ones([3]), dtype=dtypes.float32)) + @test_util.run_deprecated_v1 def testMomentsFloat32(self): self._testMoments(dtypes.float32)