Add GPU kernel for tf.random.stateless_gamma.
PiperOrigin-RevId: 351972885 Change-Id: Iabef759abd14aabc6e4c7f9783ba5899d824b40c
This commit is contained in: parent 1059e04b92, commit 489133d42a
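For context (not part of the diff below), a minimal usage sketch of the op this commit accelerates: with this change, tf.random.stateless_gamma can be placed on a GPU device instead of falling back to the CPU kernel. The device string 'GPU:0' and the example shapes are illustrative assumptions, not taken from the commit.

import tensorflow as tf

# Illustrative only; assumes a machine with at least one visible GPU.
with tf.device('GPU:0'):
  samples = tf.random.stateless_gamma(
      shape=[10, 2],     # 10 draws for each of the 2 alphas
      seed=[12, 34],     # stateless ops take an explicit 2-element seed
      alpha=[0.5, 1.5])  # Gamma shape parameters
print(samples.shape)  # (10, 2)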
@@ -47,53 +47,7 @@ namespace {
static constexpr int kReservedSamplesPerOutput = 256;

typedef Eigen::ThreadPoolDevice CPUDevice;

// Buffer that holds multiple samples. Operator()(random::PhiloxRandom*) returns
// a single sample from this buffer. If the buffer is empty, it first generates
// new samples using the provided distribution.
//
// If the call to Distribution::operator() returns samples[0...N-1], then this
// class returns samples in the following order:
//
// samples[N-1], samples[N-2],..., samples[1], samples[0]
//
// For comparison, random::SingleSampleAdapter returns samples in
// the following order:
//
// samples[0], samples[1],...,samples[N-2], samples[N-1].
//
template <class Distribution>
class SampleBuffer {
 public:
  typedef typename Distribution::ResultElementType ResultElementType;

  PHILOX_DEVICE_INLINE
  explicit SampleBuffer(Distribution* distribution)
      : distribution_(distribution), remaining_numbers_(0) {}

  PHILOX_DEVICE_INLINE
  ResultElementType operator()(random::PhiloxRandom* random) {
    if (remaining_numbers_ == 0) {
      results_ = (*distribution_)(random);
      remaining_numbers_ = Distribution::kResultElementCount;
    }

    remaining_numbers_--;
    return results_[remaining_numbers_];
  }

  // Mark this buffer as empty. The next call to operator() will fill it
  // with new random numbers.
  PHILOX_DEVICE_INLINE
  void Clear() { remaining_numbers_ = 0; }

 private:
  typedef typename Distribution::ResultType ResultType;

  Distribution* distribution_;
  ResultType results_;
  int remaining_numbers_;
};
typedef Eigen::GpuDevice GPUDevice;

};  // namespace
@@ -102,7 +56,8 @@ namespace functor {
template <typename T>
struct StatelessRandomGammaFunctor<CPUDevice, T> {
  static Status Fill(OpKernelContext* ctx, const T* alpha_flat,
                     int64 num_alphas, int64 samples_per_alpha,
                     int64 num_samples, int64 num_alphas,
                     int64 samples_per_alpha,
                     const random::PhiloxRandom& random, T* samples_flat) {
    typedef random::NormalDistribution<random::PhiloxRandom, double> Normal;
    typedef random::UniformDistribution<random::PhiloxRandom, double> Uniform;
@@ -124,8 +79,8 @@ struct StatelessRandomGammaFunctor<CPUDevice, T> {
    Normal normal;
    Uniform uniform;

    SampleBuffer<Normal> normal_buffer(&normal);
    SampleBuffer<Uniform> uniform_buffer(&uniform);
    RandomSampleBuffer<Normal> normal_buffer(&normal);
    RandomSampleBuffer<Uniform> uniform_buffer(&uniform);

    for (int64 output_idx = start_output; output_idx < limit_output;
         /* output_idx incremented within inner loop below */) {
@@ -138,14 +93,14 @@ struct StatelessRandomGammaFunctor<CPUDevice, T> {
      const double alpha = static_cast<double>(alpha_flat[alpha_idx]);

      DISABLE_FLOAT_EQUALITY_WARNING
      if (alpha == static_cast<double>(1.0)) {
      if (alpha == 1.0) {
        ENABLE_FLOAT_EQUALITY_WARNING
        // Sample from an exponential distribution.
        for (int64 sample_idx = output_idx % samples_per_alpha;
             sample_idx < samples_per_alpha && output_idx < limit_output;
             sample_idx++, output_idx++) {
          // As we want data stable regardless of sharding
          // (including eventually on GPU), we skip on a per-sample basis.
          // As we want data stable regardless of sharding, we skip on a
          // per-sample basis.
          random::PhiloxRandom gen = random;
          gen.Skip(kReservedSamplesPerOutput * output_idx);
          double u = uniform(&gen)[Uniform::kResultElementCount - 1];
@@ -162,7 +117,7 @@ struct StatelessRandomGammaFunctor<CPUDevice, T> {
        //
        // For alpha<1, we add one to d=alpha-1/3, and multiply the final
        // result by uniform()^(1/alpha)
        const bool alpha_less_than_one = alpha < 1;
        const bool alpha_less_than_one = alpha < 1.0;
        const double d = alpha + (alpha_less_than_one ? 2.0 / 3 : -1.0 / 3);
        const double c = 1.0 / 3 / sqrt(d);
@@ -171,8 +126,8 @@ struct StatelessRandomGammaFunctor<CPUDevice, T> {
             sample_idx < samples_per_alpha && output_idx < limit_output;
             sample_idx++, output_idx++) {
          // Since each sample may use a variable number of normal/uniform
          // samples, and we want data stable regardless of sharding
          // (including eventually on GPU), we skip on a per-sample basis.
          // samples, and we want data stable regardless of sharding, we skip
          // on a per-sample basis.
          random::PhiloxRandom gen = random;
          gen.Skip(kReservedSamplesPerOutput * output_idx);
@@ -226,8 +181,8 @@ struct StatelessRandomGammaFunctor<CPUDevice, T> {
        Uniform::kElementCost +
        3 * random::PhiloxRandom::kElementCost;
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers,
          num_alphas * samples_per_alpha, kElementCost, DoWork);
    Shard(worker_threads.num_threads, worker_threads.workers, num_samples,
          kElementCost, DoWork);
    return Status::OK();
  }
};
@@ -280,19 +235,21 @@ class StatelessRandomGammaOp : public OpKernel {
                    "Input alpha should have non-zero element count, got: ",
                    num_alphas));

    const int64 samples_per_alpha = samples_shape.num_elements() / num_alphas;
    const int64 num_samples = samples_shape.num_elements();
    const int64 samples_per_alpha = num_samples / num_alphas;
    const auto alpha_flat = alpha_t.flat<T>().data();
    auto samples_flat = output->flat<T>().data();

    OP_REQUIRES_OK(ctx, functor::StatelessRandomGammaFunctor<Device, T>::Fill(
                            ctx, alpha_flat, num_alphas, samples_per_alpha,
                            random, samples_flat));
                            ctx, alpha_flat, num_samples, num_alphas,
                            samples_per_alpha, random, samples_flat));
  }

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomGammaOp);
};

#define REGISTER_GAMMA(TYPE) \
// Register CPU kernels for stateless gamma op.
#define REGISTER_GAMMA_CPU(TYPE) \
  REGISTER_KERNEL_BUILDER(Name("StatelessRandomGammaV2") \
                              .Device(DEVICE_CPU) \
                              .HostMemory("shape") \
@@ -301,12 +258,32 @@ class StatelessRandomGammaOp : public OpKernel {
                              .TypeConstraint<TYPE>("dtype"), \
                          StatelessRandomGammaOp<CPUDevice, TYPE>)

TF_CALL_half(REGISTER_GAMMA);
TF_CALL_bfloat16(REGISTER_GAMMA);
TF_CALL_float(REGISTER_GAMMA);
TF_CALL_double(REGISTER_GAMMA);
TF_CALL_half(REGISTER_GAMMA_CPU);
TF_CALL_bfloat16(REGISTER_GAMMA_CPU);
TF_CALL_float(REGISTER_GAMMA_CPU);
TF_CALL_double(REGISTER_GAMMA_CPU);

#undef REGISTER_GAMMA
#undef REGISTER_GAMMA_CPU

// Register GPU kernels for stateless gamma op.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GAMMA_GPU(TYPE) \
  REGISTER_KERNEL_BUILDER(Name("StatelessRandomGammaV2") \
                              .Device(DEVICE_GPU) \
                              .HostMemory("shape") \
                              .HostMemory("seed") \
                              .TypeConstraint<TYPE>("dtype"), \
                          StatelessRandomGammaOp<GPUDevice, TYPE>)

TF_CALL_half(REGISTER_GAMMA_GPU);
TF_CALL_bfloat16(REGISTER_GAMMA_GPU);
TF_CALL_float(REGISTER_GAMMA_GPU);
TF_CALL_double(REGISTER_GAMMA_GPU);

#undef REGISTER_GAMMA_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace
}  // namespace tensorflow
@@ -27,12 +27,60 @@ namespace functor {
template <typename Device, typename T>
struct StatelessRandomGammaFunctor {
  static Status Fill(OpKernelContext* ctx, const T* alpha_flat,
                     int64 num_alphas, int64 samples_per_alpha,
                     int64 num_samples, int64 num_alphas,
                     int64 samples_per_alpha,
                     const random::PhiloxRandom& random, T* samples_flat);
};

}  // namespace functor

// Buffer that holds multiple samples. Operator()(random::PhiloxRandom*) returns
// a single sample from this buffer. If the buffer is empty, it first generates
// new samples using the provided distribution.
//
// If the call to Distribution::operator() returns samples[0...N-1], then this
// class returns samples in the following order:
//
// samples[N-1], samples[N-2],..., samples[1], samples[0]
//
// For comparison, random::SingleSampleAdapter returns samples in
// the following order:
//
// samples[0], samples[1],...,samples[N-2], samples[N-1].
//
template <class Distribution>
class RandomSampleBuffer {
 public:
  typedef typename Distribution::ResultElementType ResultElementType;

  PHILOX_DEVICE_INLINE
  explicit RandomSampleBuffer(Distribution* distribution)
      : distribution_(distribution), remaining_numbers_(0) {}

  PHILOX_DEVICE_INLINE
  ResultElementType operator()(random::PhiloxRandom* random) {
    if (remaining_numbers_ == 0) {
      results_ = (*distribution_)(random);
      remaining_numbers_ = Distribution::kResultElementCount;
    }

    remaining_numbers_--;
    return results_[remaining_numbers_];
  }

  // Mark this buffer as empty. The next call to operator() will fill it
  // with new random numbers.
  PHILOX_DEVICE_INLINE
  void Clear() { remaining_numbers_ = 0; }

 private:
  typedef typename Distribution::ResultType ResultType;

  Distribution* distribution_;
  ResultType results_;
  int remaining_numbers_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_GAMMA_OP_H_
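To make the ordering described above concrete, here is a rough Python analogue of RandomSampleBuffer (an illustrative sketch, not part of the commit): a batch of N samples is drained from the last element back to the first, the reverse of what random::SingleSampleAdapter would return.

class ReversedSampleBuffer:
  """Python sketch of RandomSampleBuffer: drains each batch in reverse."""

  def __init__(self, distribution):
    # `distribution` is any callable mapping a generator to a list of samples.
    self._distribution = distribution
    self._results = []
    self._remaining = 0

  def __call__(self, gen):
    if self._remaining == 0:
      self._results = self._distribution(gen)
      self._remaining = len(self._results)
    self._remaining -= 1
    return self._results[self._remaining]

  def clear(self):
    # Mark the buffer as empty; the next call refills it.
    self._remaining = 0

# If distribution(gen) returns [s0, s1, s2, s3], successive calls yield
# s3, s2, s1, s0.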
tensorflow/core/kernels/stateless_random_gamma_op_gpu.cu.cc (new file, 179 lines)
@@ -0,0 +1,179 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/stateless_random_gamma_op.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

namespace {
typedef Eigen::GpuDevice GPUDevice;

// Each attempt to generate a new draw from the Gamma distribution is 95+%
// successful, and requires 1-2 normal + 1 uniform sample.
static constexpr int kReservedSamplesPerOutput = 256;

#if EIGEN_COMP_GNUC && __cplusplus > 199711L
#define DISABLE_FLOAT_EQUALITY_WARNING \
  _Pragma("GCC diagnostic push")       \
      _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#else
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#endif
template <typename T>
__global__ void __launch_bounds__(1024)
    FillKernel(int64 num_samples, int64 num_alphas, int64 samples_per_alpha,
               random::PhiloxRandom random, T* samples_flat,
               const T* alpha_flat) {
  using Eigen::numext::exp;
  using Eigen::numext::log;
  using Eigen::numext::log1p;
  using Eigen::numext::pow;

  typedef random::NormalDistribution<random::PhiloxRandom, double> Normal;
  typedef random::UniformDistribution<random::PhiloxRandom, double> Uniform;

  Normal normal;
  Uniform uniform;

  RandomSampleBuffer<Normal> normal_buffer(&normal);
  RandomSampleBuffer<Uniform> uniform_buffer(&uniform);

  for (int64 output_idx : GpuGridRangeX(num_samples)) {
    int64 alpha_idx = output_idx / samples_per_alpha;
    int64 sample_idx = output_idx % samples_per_alpha;

    const double alpha = static_cast<double>(alpha_flat[alpha_idx]);

    DISABLE_FLOAT_EQUALITY_WARNING
    if (alpha == 1.0) {
      ENABLE_FLOAT_EQUALITY_WARNING
      // Sample from an exponential distribution.
      // As we want data stable regardless of sharding, we skip on a per-sample
      // basis.
      random::PhiloxRandom gen = random;
      gen.Skip(kReservedSamplesPerOutput * output_idx);
      double u = uniform(&gen)[Uniform::kResultElementCount - 1];
      const double res = -log1p(-u);
      // We use alpha_idx + sample_idx * num_alphas instead of output_idx
      // to generate numbers in the right order (CPU and GPU kernels
      // must generate numbers in the same order).
      samples_flat[alpha_idx + sample_idx * num_alphas] = static_cast<T>(res);
    } else {  // if alpha != 1.0
      // Transformation-rejection from pairs of uniform and normal random
      // variables. http://dl.acm.org/citation.cfm?id=358414
      //
      // The algorithm has an acceptance rate of ~95% for small alpha (~1),
      // and higher accept rates for higher alpha, so runtime is
      // O(NumAlphas * NumSamples * k) with k ~ 1 / 0.95.
      //
      // For alpha<1, we add one to d=alpha-1/3, and multiply the final
      // result by uniform()^(1/alpha)
      const bool alpha_less_than_one = alpha < 1.0;
      const double d = alpha + (alpha_less_than_one ? 2.0 / 3 : -1.0 / 3);
      const double c = 1.0 / 3 / sqrt(d);

      // Since each sample may use a variable number of normal/uniform
      // samples, and we want data stable regardless of sharding, we skip on a
      // per-sample basis.
      random::PhiloxRandom gen = random;
      gen.Skip(kReservedSamplesPerOutput * output_idx);

      // To prevent overwriting SampleBuffer's underlying array with
      // zeros (in tensorflow::random::Array constructor), we just mark
      // the buffer as empty instead of initializing a new SampleBuffer
      // object here. The next call to operator() will fill the buffer
      // with new numbers.
      normal_buffer.Clear();
      uniform_buffer.Clear();

      // Keep trying until we don't reject a sample. In practice, we will
      // only reject ~5% at worst, for low alpha near 1.
      while (true) {
        const double x = normal_buffer(&gen);
        double v = 1 + c * x;
        if (v <= 0) {
          continue;
        }
        v = v * v * v;
        double u = uniform_buffer(&gen);
        // The first option in the if is a "squeeze" short-circuit to
        // dodge the two logs. Magic constant sourced from the paper
        // linked above. Upward of .91 of the area covered by the log
        // inequality is covered by the squeeze as well (larger coverage
        // for smaller values of alpha).
        if ((u < 1 - 0.0331 * (x * x) * (x * x)) ||
            (log(u) < 0.5 * x * x + d * (1 - v + log(v)))) {
          double res = d * v;
          if (alpha_less_than_one) {
            double b = uniform_buffer(&gen);
            res *= pow(b, 1 / alpha);
          }
          // We use alpha_idx + sample_idx * num_alphas instead of output_idx
          // to generate numbers in the right order (CPU and GPU kernels
          // must generate numbers in the same order).
          samples_flat[alpha_idx + sample_idx * num_alphas] =
              static_cast<T>(res);
          break;
        }
      }  // while: true
    }    // if (alpha == 1.0)
  }      // for: output_idx
}

}  // namespace

namespace functor {

template <typename T>
struct StatelessRandomGammaFunctor<GPUDevice, T> {
  static Status Fill(OpKernelContext* ctx, const T* alpha_flat,
                     int64 num_samples, int64 num_alphas,
                     int64 samples_per_alpha,
                     const random::PhiloxRandom& random, T* samples_flat) {
    const GPUDevice& d = ctx->eigen_device<GPUDevice>();
    GpuLaunchConfig cfg = GetGpuLaunchConfig(num_samples, d);

    TF_CHECK_OK(GpuLaunchKernel(FillKernel<T>, cfg.block_count,
                                cfg.thread_per_block, 0, d.stream(),
                                num_samples, num_alphas, samples_per_alpha,
                                random, samples_flat, alpha_flat));
    return Status::OK();
  }
};

}  // namespace functor

#define REGISTER_GPU_SPEC(type) \
  template struct functor::StatelessRandomGammaFunctor<GPUDevice, type>;

TF_CALL_half(REGISTER_GPU_SPEC);
TF_CALL_bfloat16(REGISTER_GPU_SPEC);
TF_CALL_float(REGISTER_GPU_SPEC);
TF_CALL_double(REGISTER_GPU_SPEC);
#undef REGISTER_GPU_SPEC

}  // namespace tensorflow

#endif  // GOOGLE_CUDA
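For readers skimming the kernel above: it implements the transformation-rejection sampler referenced by the link in its comments, with an exponential special case for alpha == 1 and a "boost" step for alpha < 1. Below is a rough NumPy sketch of the same per-sample logic (an illustration under stated assumptions: it uses NumPy's generator rather than counter-based Philox skipping, so it will not reproduce the kernel's exact bit stream).

import numpy as np

def sample_gamma(alpha, rng):
  """Sketch of one Gamma(alpha, 1) draw via transformation-rejection."""
  if alpha == 1.0:
    # Exponential special case: -log1p(-u) with u ~ Uniform[0, 1).
    return -np.log1p(-rng.uniform())
  alpha_less_than_one = alpha < 1.0
  d = alpha + (2.0 / 3 if alpha_less_than_one else -1.0 / 3)
  c = 1.0 / 3 / np.sqrt(d)
  while True:  # ~95% of attempts are accepted for alpha near 1.
    x = rng.standard_normal()
    v = 1 + c * x
    if v <= 0:
      continue
    v = v * v * v
    u = rng.uniform()
    # "Squeeze" test first to avoid the two logs in the common case.
    if (u < 1 - 0.0331 * (x * x) * (x * x)) or (
        np.log(u) < 0.5 * x * x + d * (1 - v + np.log(v))):
      res = d * v
      if alpha_less_than_one:
        # Boost step for alpha < 1: multiply by uniform()^(1/alpha).
        res *= rng.uniform() ** (1 / alpha)
      return res

# Example: sample_gamma(0.5, np.random.default_rng(0))
#
# Output layout note: for alpha index a and per-alpha sample index s, both the
# CPU and GPU kernels write to samples_flat[a + s * num_alphas], which is why
# they produce numbers in the same order.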
@@ -216,6 +216,17 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
      pure = stateless_op(seed=preseed)
      self.assertAllEqual(stateful, pure)

  def _test_match_stateless_cpu_gpu(self, case, seed):
    # Stateless ops should produce the same result on CPUs and GPUs.
    _, stateless_op, _ = case

    with ops.device('CPU'):
      result_cpu = stateless_op(seed=seed)

    with ops.device(get_device().name):
      result_gpu = stateless_op(seed=seed)
    self.assertAllClose(result_cpu, result_gpu)

  def _test_old_and_new_stateless_match(self, case, seed):
    """Tests that the new stateless ops match the old stateless ones."""
    with ops.device(get_device().name):
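Outside the test harness, the same CPU-vs-GPU consistency check can be sketched directly (illustrative only; the device strings are assumptions and a visible GPU is required):

import tensorflow as tf

seed = [12, 34]
with tf.device('CPU:0'):
  cpu = tf.random.stateless_gamma([10, 2], seed=seed, alpha=[0.5, 1.5])
with tf.device('GPU:0'):
  gpu = tf.random.stateless_gamma([10, 2], seed=seed, alpha=[0.5, 1.5])
# Both kernels draw numbers in the same order, so the results should agree
# up to floating-point tolerance.
print(tf.reduce_max(tf.abs(cpu - gpu)).numpy())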
@@ -306,6 +317,18 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
      self.skipTest('Lacking XLA kernel')
    self._test_match(case, seed)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(gamma_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testStatelessGammaCpuGpuMatch(self, case, seed):
    if get_device().device_type != 'GPU':
      # This test compares the numbers produced by the CPU and GPU kernel for
      # stateless_random_gamma.
      self.skipTest('This test requires GPU')
    self._test_match_stateless_cpu_gpu(case, seed)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
@@ -387,10 +410,6 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
      for case_id, case in enumerate(gamma_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testDeterminismGamma(self, case, seed_type):
    if get_device().device_type == 'GPU':
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest('Lacking GPU kernel')
    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
      # This test was passing before because soft placement silently picked the
      # CPU kernels.