Add GPU kernel for tf.random.stateless_gamma.
PiperOrigin-RevId: 351972885 Change-Id: Iabef759abd14aabc6e4c7f9783ba5899d824b40c
This commit is contained in:
		
							parent
							
								
									1059e04b92
								
							
						
					
					
						commit
						489133d42a
					
				@ -47,53 +47,7 @@ namespace {
 | 
				
			|||||||
static constexpr int kReservedSamplesPerOutput = 256;
 | 
					static constexpr int kReservedSamplesPerOutput = 256;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
typedef Eigen::ThreadPoolDevice CPUDevice;
 | 
					typedef Eigen::ThreadPoolDevice CPUDevice;
 | 
				
			||||||
 | 
					typedef Eigen::GpuDevice GPUDevice;
 | 
				
			||||||
// Buffer that holds multiple samples. Operator()(random::PhiloxRandom*) returns
 | 
					 | 
				
			||||||
// a single sample from this buffer. If the buffer is empty, it first generates
 | 
					 | 
				
			||||||
// new samples using the provided distribution.
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
// If the call to Distribution::operator() returns samples[0...N-1], then this
 | 
					 | 
				
			||||||
// class returns samples in the following order:
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
//   samples[N-1], samples[N-2],..., samples[1], samples[0]
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
// For comparison, random::SingleSampleAdapter returns samples in
 | 
					 | 
				
			||||||
// the following order:
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
//   samples[0], samples[1],...,samples[N-2], samples[N-1].
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
template <class Distribution>
 | 
					 | 
				
			||||||
class SampleBuffer {
 | 
					 | 
				
			||||||
 public:
 | 
					 | 
				
			||||||
  typedef typename Distribution::ResultElementType ResultElementType;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  PHILOX_DEVICE_INLINE
 | 
					 | 
				
			||||||
  explicit SampleBuffer(Distribution* distribution)
 | 
					 | 
				
			||||||
      : distribution_(distribution), remaining_numbers_(0) {}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  PHILOX_DEVICE_INLINE
 | 
					 | 
				
			||||||
  ResultElementType operator()(random::PhiloxRandom* random) {
 | 
					 | 
				
			||||||
    if (remaining_numbers_ == 0) {
 | 
					 | 
				
			||||||
      results_ = (*distribution_)(random);
 | 
					 | 
				
			||||||
      remaining_numbers_ = Distribution::kResultElementCount;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    remaining_numbers_--;
 | 
					 | 
				
			||||||
    return results_[remaining_numbers_];
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Mark this buffer as empty. The next call to operator() will fill it
 | 
					 | 
				
			||||||
  // with new random numbers.
 | 
					 | 
				
			||||||
  PHILOX_DEVICE_INLINE
 | 
					 | 
				
			||||||
  void Clear() { remaining_numbers_ = 0; }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 private:
 | 
					 | 
				
			||||||
  typedef typename Distribution::ResultType ResultType;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  Distribution* distribution_;
 | 
					 | 
				
			||||||
  ResultType results_;
 | 
					 | 
				
			||||||
  int remaining_numbers_;
 | 
					 | 
				
			||||||
};
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
};  // namespace
 | 
					};  // namespace
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -102,7 +56,8 @@ namespace functor {
 | 
				
			|||||||
template <typename T>
 | 
					template <typename T>
 | 
				
			||||||
struct StatelessRandomGammaFunctor<CPUDevice, T> {
 | 
					struct StatelessRandomGammaFunctor<CPUDevice, T> {
 | 
				
			||||||
  static Status Fill(OpKernelContext* ctx, const T* alpha_flat,
 | 
					  static Status Fill(OpKernelContext* ctx, const T* alpha_flat,
 | 
				
			||||||
                     int64 num_alphas, int64 samples_per_alpha,
 | 
					                     int64 num_samples, int64 num_alphas,
 | 
				
			||||||
 | 
					                     int64 samples_per_alpha,
 | 
				
			||||||
                     const random::PhiloxRandom& random, T* samples_flat) {
 | 
					                     const random::PhiloxRandom& random, T* samples_flat) {
 | 
				
			||||||
    typedef random::NormalDistribution<random::PhiloxRandom, double> Normal;
 | 
					    typedef random::NormalDistribution<random::PhiloxRandom, double> Normal;
 | 
				
			||||||
    typedef random::UniformDistribution<random::PhiloxRandom, double> Uniform;
 | 
					    typedef random::UniformDistribution<random::PhiloxRandom, double> Uniform;
 | 
				
			||||||
@ -124,8 +79,8 @@ struct StatelessRandomGammaFunctor<CPUDevice, T> {
 | 
				
			|||||||
      Normal normal;
 | 
					      Normal normal;
 | 
				
			||||||
      Uniform uniform;
 | 
					      Uniform uniform;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      SampleBuffer<Normal> normal_buffer(&normal);
 | 
					      RandomSampleBuffer<Normal> normal_buffer(&normal);
 | 
				
			||||||
      SampleBuffer<Uniform> uniform_buffer(&uniform);
 | 
					      RandomSampleBuffer<Uniform> uniform_buffer(&uniform);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      for (int64 output_idx = start_output; output_idx < limit_output;
 | 
					      for (int64 output_idx = start_output; output_idx < limit_output;
 | 
				
			||||||
           /* output_idx incremented within inner loop below */) {
 | 
					           /* output_idx incremented within inner loop below */) {
 | 
				
			||||||
@ -138,14 +93,14 @@ struct StatelessRandomGammaFunctor<CPUDevice, T> {
 | 
				
			|||||||
        const double alpha = static_cast<double>(alpha_flat[alpha_idx]);
 | 
					        const double alpha = static_cast<double>(alpha_flat[alpha_idx]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        DISABLE_FLOAT_EQUALITY_WARNING
 | 
					        DISABLE_FLOAT_EQUALITY_WARNING
 | 
				
			||||||
        if (alpha == static_cast<double>(1.0)) {
 | 
					        if (alpha == 1.0) {
 | 
				
			||||||
          ENABLE_FLOAT_EQUALITY_WARNING
 | 
					          ENABLE_FLOAT_EQUALITY_WARNING
 | 
				
			||||||
          // Sample from an exponential distribution.
 | 
					          // Sample from an exponential distribution.
 | 
				
			||||||
          for (int64 sample_idx = output_idx % samples_per_alpha;
 | 
					          for (int64 sample_idx = output_idx % samples_per_alpha;
 | 
				
			||||||
               sample_idx < samples_per_alpha && output_idx < limit_output;
 | 
					               sample_idx < samples_per_alpha && output_idx < limit_output;
 | 
				
			||||||
               sample_idx++, output_idx++) {
 | 
					               sample_idx++, output_idx++) {
 | 
				
			||||||
            // As we want data stable regardless of sharding
 | 
					            // As we want data stable regardless of sharding, we skip on a
 | 
				
			||||||
            // (including eventually on GPU), we skip on a per-sample basis.
 | 
					            // per-sample basis.
 | 
				
			||||||
            random::PhiloxRandom gen = random;
 | 
					            random::PhiloxRandom gen = random;
 | 
				
			||||||
            gen.Skip(kReservedSamplesPerOutput * output_idx);
 | 
					            gen.Skip(kReservedSamplesPerOutput * output_idx);
 | 
				
			||||||
            double u = uniform(&gen)[Uniform::kResultElementCount - 1];
 | 
					            double u = uniform(&gen)[Uniform::kResultElementCount - 1];
 | 
				
			||||||
@ -162,7 +117,7 @@ struct StatelessRandomGammaFunctor<CPUDevice, T> {
 | 
				
			|||||||
          //
 | 
					          //
 | 
				
			||||||
          // For alpha<1, we add one to d=alpha-1/3, and multiply the final
 | 
					          // For alpha<1, we add one to d=alpha-1/3, and multiply the final
 | 
				
			||||||
          // result by uniform()^(1/alpha)
 | 
					          // result by uniform()^(1/alpha)
 | 
				
			||||||
          const bool alpha_less_than_one = alpha < 1;
 | 
					          const bool alpha_less_than_one = alpha < 1.0;
 | 
				
			||||||
          const double d = alpha + (alpha_less_than_one ? 2.0 / 3 : -1.0 / 3);
 | 
					          const double d = alpha + (alpha_less_than_one ? 2.0 / 3 : -1.0 / 3);
 | 
				
			||||||
          const double c = 1.0 / 3 / sqrt(d);
 | 
					          const double c = 1.0 / 3 / sqrt(d);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -171,8 +126,8 @@ struct StatelessRandomGammaFunctor<CPUDevice, T> {
 | 
				
			|||||||
               sample_idx < samples_per_alpha && output_idx < limit_output;
 | 
					               sample_idx < samples_per_alpha && output_idx < limit_output;
 | 
				
			||||||
               sample_idx++, output_idx++) {
 | 
					               sample_idx++, output_idx++) {
 | 
				
			||||||
            // Since each sample may use a variable number of normal/uniform
 | 
					            // Since each sample may use a variable number of normal/uniform
 | 
				
			||||||
            // samples, and we want data stable regardless of sharding
 | 
					            // samples, and we want data stable regardless of sharding, we skip
 | 
				
			||||||
            // (including eventually on GPU), we skip on a per-sample basis.
 | 
					            // on a per-sample basis.
 | 
				
			||||||
            random::PhiloxRandom gen = random;
 | 
					            random::PhiloxRandom gen = random;
 | 
				
			||||||
            gen.Skip(kReservedSamplesPerOutput * output_idx);
 | 
					            gen.Skip(kReservedSamplesPerOutput * output_idx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -226,8 +181,8 @@ struct StatelessRandomGammaFunctor<CPUDevice, T> {
 | 
				
			|||||||
                                    Uniform::kElementCost +
 | 
					                                    Uniform::kElementCost +
 | 
				
			||||||
                                    3 * random::PhiloxRandom::kElementCost;
 | 
					                                    3 * random::PhiloxRandom::kElementCost;
 | 
				
			||||||
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
 | 
					    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
 | 
				
			||||||
    Shard(worker_threads.num_threads, worker_threads.workers,
 | 
					    Shard(worker_threads.num_threads, worker_threads.workers, num_samples,
 | 
				
			||||||
          num_alphas * samples_per_alpha, kElementCost, DoWork);
 | 
					          kElementCost, DoWork);
 | 
				
			||||||
    return Status::OK();
 | 
					    return Status::OK();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
@ -280,19 +235,21 @@ class StatelessRandomGammaOp : public OpKernel {
 | 
				
			|||||||
                    "Input alpha should have non-zero element count, got: ",
 | 
					                    "Input alpha should have non-zero element count, got: ",
 | 
				
			||||||
                    num_alphas));
 | 
					                    num_alphas));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const int64 samples_per_alpha = samples_shape.num_elements() / num_alphas;
 | 
					    const int64 num_samples = samples_shape.num_elements();
 | 
				
			||||||
 | 
					    const int64 samples_per_alpha = num_samples / num_alphas;
 | 
				
			||||||
    const auto alpha_flat = alpha_t.flat<T>().data();
 | 
					    const auto alpha_flat = alpha_t.flat<T>().data();
 | 
				
			||||||
    auto samples_flat = output->flat<T>().data();
 | 
					    auto samples_flat = output->flat<T>().data();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    OP_REQUIRES_OK(ctx, functor::StatelessRandomGammaFunctor<Device, T>::Fill(
 | 
					    OP_REQUIRES_OK(ctx, functor::StatelessRandomGammaFunctor<Device, T>::Fill(
 | 
				
			||||||
                            ctx, alpha_flat, num_alphas, samples_per_alpha,
 | 
					                            ctx, alpha_flat, num_samples, num_alphas,
 | 
				
			||||||
                            random, samples_flat));
 | 
					                            samples_per_alpha, random, samples_flat));
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomGammaOp);
 | 
					  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomGammaOp);
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#define REGISTER_GAMMA(TYPE)                                  \
 | 
					// Register CPU kernels for stateless gamma op.
 | 
				
			||||||
 | 
					#define REGISTER_GAMMA_CPU(TYPE)                              \
 | 
				
			||||||
  REGISTER_KERNEL_BUILDER(Name("StatelessRandomGammaV2")      \
 | 
					  REGISTER_KERNEL_BUILDER(Name("StatelessRandomGammaV2")      \
 | 
				
			||||||
                              .Device(DEVICE_CPU)             \
 | 
					                              .Device(DEVICE_CPU)             \
 | 
				
			||||||
                              .HostMemory("shape")            \
 | 
					                              .HostMemory("shape")            \
 | 
				
			||||||
@ -301,12 +258,32 @@ class StatelessRandomGammaOp : public OpKernel {
 | 
				
			|||||||
                              .TypeConstraint<TYPE>("dtype"), \
 | 
					                              .TypeConstraint<TYPE>("dtype"), \
 | 
				
			||||||
                          StatelessRandomGammaOp<CPUDevice, TYPE>)
 | 
					                          StatelessRandomGammaOp<CPUDevice, TYPE>)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
TF_CALL_half(REGISTER_GAMMA);
 | 
					TF_CALL_half(REGISTER_GAMMA_CPU);
 | 
				
			||||||
TF_CALL_bfloat16(REGISTER_GAMMA);
 | 
					TF_CALL_bfloat16(REGISTER_GAMMA_CPU);
 | 
				
			||||||
TF_CALL_float(REGISTER_GAMMA);
 | 
					TF_CALL_float(REGISTER_GAMMA_CPU);
 | 
				
			||||||
TF_CALL_double(REGISTER_GAMMA);
 | 
					TF_CALL_double(REGISTER_GAMMA_CPU);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#undef REGISTER_GAMMA
 | 
					#undef REGISTER_GAMMA_CPU
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Register GPU kernels for stateless gamma op.
 | 
				
			||||||
 | 
					#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#define REGISTER_GAMMA_GPU(TYPE)                              \
 | 
				
			||||||
 | 
					  REGISTER_KERNEL_BUILDER(Name("StatelessRandomGammaV2")      \
 | 
				
			||||||
 | 
					                              .Device(DEVICE_GPU)             \
 | 
				
			||||||
 | 
					                              .HostMemory("shape")            \
 | 
				
			||||||
 | 
					                              .HostMemory("seed")             \
 | 
				
			||||||
 | 
					                              .TypeConstraint<TYPE>("dtype"), \
 | 
				
			||||||
 | 
					                          StatelessRandomGammaOp<GPUDevice, TYPE>)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TF_CALL_half(REGISTER_GAMMA_GPU);
 | 
				
			||||||
 | 
					TF_CALL_bfloat16(REGISTER_GAMMA_GPU);
 | 
				
			||||||
 | 
					TF_CALL_float(REGISTER_GAMMA_GPU);
 | 
				
			||||||
 | 
					TF_CALL_double(REGISTER_GAMMA_GPU);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#undef REGISTER_GAMMA_GPU
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace
 | 
					}  // namespace
 | 
				
			||||||
}  // namespace tensorflow
 | 
					}  // namespace tensorflow
 | 
				
			||||||
 | 
				
			|||||||
@ -27,12 +27,60 @@ namespace functor {
 | 
				
			|||||||
template <typename Device, typename T>
 | 
					template <typename Device, typename T>
 | 
				
			||||||
struct StatelessRandomGammaFunctor {
 | 
					struct StatelessRandomGammaFunctor {
 | 
				
			||||||
  static Status Fill(OpKernelContext* ctx, const T* alpha_flat,
 | 
					  static Status Fill(OpKernelContext* ctx, const T* alpha_flat,
 | 
				
			||||||
                     int64 num_alphas, int64 samples_per_alpha,
 | 
					                     int64 num_samples, int64 num_alphas,
 | 
				
			||||||
 | 
					                     int64 samples_per_alpha,
 | 
				
			||||||
                     const random::PhiloxRandom& random, T* samples_flat);
 | 
					                     const random::PhiloxRandom& random, T* samples_flat);
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace functor
 | 
					}  // namespace functor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Buffer that holds multiple samples. Operator()(random::PhiloxRandom*) returns
 | 
				
			||||||
 | 
					// a single sample from this buffer. If the buffer is empty, it first generates
 | 
				
			||||||
 | 
					// new samples using the provided distribution.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// If the call to Distribution::operator() returns samples[0...N-1], then this
 | 
				
			||||||
 | 
					// class returns samples in the following order:
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//   samples[N-1], samples[N-2],..., samples[1], samples[0]
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// For comparison, random::SingleSampleAdapter returns samples in
 | 
				
			||||||
 | 
					// the following order:
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//   samples[0], samples[1],...,samples[N-2], samples[N-1].
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					template <class Distribution>
 | 
				
			||||||
 | 
					class RandomSampleBuffer {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  typedef typename Distribution::ResultElementType ResultElementType;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  PHILOX_DEVICE_INLINE
 | 
				
			||||||
 | 
					  explicit RandomSampleBuffer(Distribution* distribution)
 | 
				
			||||||
 | 
					      : distribution_(distribution), remaining_numbers_(0) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  PHILOX_DEVICE_INLINE
 | 
				
			||||||
 | 
					  ResultElementType operator()(random::PhiloxRandom* random) {
 | 
				
			||||||
 | 
					    if (remaining_numbers_ == 0) {
 | 
				
			||||||
 | 
					      results_ = (*distribution_)(random);
 | 
				
			||||||
 | 
					      remaining_numbers_ = Distribution::kResultElementCount;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    remaining_numbers_--;
 | 
				
			||||||
 | 
					    return results_[remaining_numbers_];
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Mark this buffer as empty. The next call to operator() will fill it
 | 
				
			||||||
 | 
					  // with new random numbers.
 | 
				
			||||||
 | 
					  PHILOX_DEVICE_INLINE
 | 
				
			||||||
 | 
					  void Clear() { remaining_numbers_ = 0; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 private:
 | 
				
			||||||
 | 
					  typedef typename Distribution::ResultType ResultType;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Distribution* distribution_;
 | 
				
			||||||
 | 
					  ResultType results_;
 | 
				
			||||||
 | 
					  int remaining_numbers_;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace tensorflow
 | 
					}  // namespace tensorflow
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#endif  // TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_GAMMA_OP_H_
 | 
					#endif  // TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_GAMMA_OP_H_
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										179
									
								
								tensorflow/core/kernels/stateless_random_gamma_op_gpu.cu.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										179
									
								
								tensorflow/core/kernels/stateless_random_gamma_op_gpu.cu.cc
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,179 @@
 | 
				
			|||||||
 | 
					/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					==============================================================================*/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#define EIGEN_USE_GPU
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "tensorflow/core/framework/register_types.h"
 | 
				
			||||||
 | 
					#include "tensorflow/core/kernels/stateless_random_gamma_op.h"
 | 
				
			||||||
 | 
					#include "tensorflow/core/lib/random/random_distributions.h"
 | 
				
			||||||
 | 
					#include "tensorflow/core/util/gpu_kernel_helper.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace tensorflow {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace {
 | 
				
			||||||
 | 
					typedef Eigen::GpuDevice GPUDevice;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Each attempt to generate a new draw from the Gamma distribution is 95+%
 | 
				
			||||||
 | 
					// successful, and requires 1-2 normal + 1 uniform sample.
 | 
				
			||||||
 | 
					static constexpr int kReservedSamplesPerOutput = 256;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#if EIGEN_COMP_GNUC && __cplusplus > 199711L
 | 
				
			||||||
 | 
					#define DISABLE_FLOAT_EQUALITY_WARNING \
 | 
				
			||||||
 | 
					  _Pragma("GCC diagnostic push")       \
 | 
				
			||||||
 | 
					      _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
 | 
				
			||||||
 | 
					#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
 | 
				
			||||||
 | 
					#else
 | 
				
			||||||
 | 
					#define DISABLE_FLOAT_EQUALITY_WARNING
 | 
				
			||||||
 | 
					#define ENABLE_FLOAT_EQUALITY_WARNING
 | 
				
			||||||
 | 
					#endif
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <typename T>
 | 
				
			||||||
 | 
					__global__ void __launch_bounds__(1024)
 | 
				
			||||||
 | 
					    FillKernel(int64 num_samples, int64 num_alphas, int64 samples_per_alpha,
 | 
				
			||||||
 | 
					               random::PhiloxRandom random, T* samples_flat,
 | 
				
			||||||
 | 
					               const T* alpha_flat) {
 | 
				
			||||||
 | 
					  using Eigen::numext::exp;
 | 
				
			||||||
 | 
					  using Eigen::numext::log;
 | 
				
			||||||
 | 
					  using Eigen::numext::log1p;
 | 
				
			||||||
 | 
					  using Eigen::numext::pow;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  typedef random::NormalDistribution<random::PhiloxRandom, double> Normal;
 | 
				
			||||||
 | 
					  typedef random::UniformDistribution<random::PhiloxRandom, double> Uniform;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Normal normal;
 | 
				
			||||||
 | 
					  Uniform uniform;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  RandomSampleBuffer<Normal> normal_buffer(&normal);
 | 
				
			||||||
 | 
					  RandomSampleBuffer<Uniform> uniform_buffer(&uniform);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  for (int64 output_idx : GpuGridRangeX(num_samples)) {
 | 
				
			||||||
 | 
					    int64 alpha_idx = output_idx / samples_per_alpha;
 | 
				
			||||||
 | 
					    int64 sample_idx = output_idx % samples_per_alpha;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const double alpha = static_cast<double>(alpha_flat[alpha_idx]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    DISABLE_FLOAT_EQUALITY_WARNING
 | 
				
			||||||
 | 
					    if (alpha == 1.0) {
 | 
				
			||||||
 | 
					      ENABLE_FLOAT_EQUALITY_WARNING
 | 
				
			||||||
 | 
					      // Sample from an exponential distribution.
 | 
				
			||||||
 | 
					      // As we want data stable regardless of sharding, we skip on a per-sample
 | 
				
			||||||
 | 
					      // basis.
 | 
				
			||||||
 | 
					      random::PhiloxRandom gen = random;
 | 
				
			||||||
 | 
					      gen.Skip(kReservedSamplesPerOutput * output_idx);
 | 
				
			||||||
 | 
					      double u = uniform(&gen)[Uniform::kResultElementCount - 1];
 | 
				
			||||||
 | 
					      const double res = -log1p(-u);
 | 
				
			||||||
 | 
					      // We use alpha_idx + sample_idx * num_alphas instead of output_idx
 | 
				
			||||||
 | 
					      // to generate numbers in the right order (CPU and GPU kernels
 | 
				
			||||||
 | 
					      // must generate numbers in the same order).
 | 
				
			||||||
 | 
					      samples_flat[alpha_idx + sample_idx * num_alphas] = static_cast<T>(res);
 | 
				
			||||||
 | 
					    } else {  // if alpha != 1.0
 | 
				
			||||||
 | 
					      // Transformation-rejection from pairs of uniform and normal random
 | 
				
			||||||
 | 
					      // variables. http://dl.acm.org/citation.cfm?id=358414
 | 
				
			||||||
 | 
					      //
 | 
				
			||||||
 | 
					      // The algorithm has an acceptance rate of ~95% for small alpha (~1),
 | 
				
			||||||
 | 
					      // and higher accept rates for higher alpha, so runtime is
 | 
				
			||||||
 | 
					      // O(NumAlphas * NumSamples * k) with k ~ 1 / 0.95.
 | 
				
			||||||
 | 
					      //
 | 
				
			||||||
 | 
					      // For alpha<1, we add one to d=alpha-1/3, and multiply the final
 | 
				
			||||||
 | 
					      // result by uniform()^(1/alpha)
 | 
				
			||||||
 | 
					      const bool alpha_less_than_one = alpha < 1.0;
 | 
				
			||||||
 | 
					      const double d = alpha + (alpha_less_than_one ? 2.0 / 3 : -1.0 / 3);
 | 
				
			||||||
 | 
					      const double c = 1.0 / 3 / sqrt(d);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Since each sample may use a variable number of normal/uniform
 | 
				
			||||||
 | 
					      // samples, and we want data stable regardless of sharding, we skip on a
 | 
				
			||||||
 | 
					      // per-sample basis.
 | 
				
			||||||
 | 
					      random::PhiloxRandom gen = random;
 | 
				
			||||||
 | 
					      gen.Skip(kReservedSamplesPerOutput * output_idx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // To prevent overwriting SampleBuffer's underlying array with
 | 
				
			||||||
 | 
					      // zeros (in tensorflow::random::Array constructor), we just mark
 | 
				
			||||||
 | 
					      // the buffer as empty instead of initializing a new SampleBuffer
 | 
				
			||||||
 | 
					      // object here. The next call to operator() will fill the buffer
 | 
				
			||||||
 | 
					      // with new numbers.
 | 
				
			||||||
 | 
					      normal_buffer.Clear();
 | 
				
			||||||
 | 
					      uniform_buffer.Clear();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Keep trying until we don't reject a sample. In practice, we will
 | 
				
			||||||
 | 
					      // only reject ~5% at worst, for low alpha near 1.
 | 
				
			||||||
 | 
					      while (true) {
 | 
				
			||||||
 | 
					        const double x = normal_buffer(&gen);
 | 
				
			||||||
 | 
					        double v = 1 + c * x;
 | 
				
			||||||
 | 
					        if (v <= 0) {
 | 
				
			||||||
 | 
					          continue;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        v = v * v * v;
 | 
				
			||||||
 | 
					        double u = uniform_buffer(&gen);
 | 
				
			||||||
 | 
					        // The first option in the if is a "squeeze" short-circuit to
 | 
				
			||||||
 | 
					        // dodge the two logs. Magic constant sourced from the paper
 | 
				
			||||||
 | 
					        // linked above. Upward of .91 of the area covered by the log
 | 
				
			||||||
 | 
					        // inequality is covered by the squeeze as well (larger coverage
 | 
				
			||||||
 | 
					        // for smaller values of alpha).
 | 
				
			||||||
 | 
					        if ((u < 1 - 0.0331 * (x * x) * (x * x)) ||
 | 
				
			||||||
 | 
					            (log(u) < 0.5 * x * x + d * (1 - v + log(v)))) {
 | 
				
			||||||
 | 
					          double res = d * v;
 | 
				
			||||||
 | 
					          if (alpha_less_than_one) {
 | 
				
			||||||
 | 
					            double b = uniform_buffer(&gen);
 | 
				
			||||||
 | 
					            res *= pow(b, 1 / alpha);
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					          // We use alpha_idx + sample_idx * num_alphas instead of output_idx
 | 
				
			||||||
 | 
					          // to generate numbers in the right order (CPU and GPU kernels
 | 
				
			||||||
 | 
					          // must generate numbers in the same order).
 | 
				
			||||||
 | 
					          samples_flat[alpha_idx + sample_idx * num_alphas] =
 | 
				
			||||||
 | 
					              static_cast<T>(res);
 | 
				
			||||||
 | 
					          break;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      }  // while: true
 | 
				
			||||||
 | 
					    }    // if (alpha == 1.0)
 | 
				
			||||||
 | 
					  }      // for: output_idx
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace functor {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <typename T>
 | 
				
			||||||
 | 
					struct StatelessRandomGammaFunctor<GPUDevice, T> {
 | 
				
			||||||
 | 
					  static Status Fill(OpKernelContext* ctx, const T* alpha_flat,
 | 
				
			||||||
 | 
					                     int64 num_samples, int64 num_alphas,
 | 
				
			||||||
 | 
					                     int64 samples_per_alpha,
 | 
				
			||||||
 | 
					                     const random::PhiloxRandom& random, T* samples_flat) {
 | 
				
			||||||
 | 
					    const GPUDevice& d = ctx->eigen_device<GPUDevice>();
 | 
				
			||||||
 | 
					    GpuLaunchConfig cfg = GetGpuLaunchConfig(num_samples, d);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    TF_CHECK_OK(GpuLaunchKernel(FillKernel<T>, cfg.block_count,
 | 
				
			||||||
 | 
					                                cfg.thread_per_block, 0, d.stream(),
 | 
				
			||||||
 | 
					                                num_samples, num_alphas, samples_per_alpha,
 | 
				
			||||||
 | 
					                                random, samples_flat, alpha_flat));
 | 
				
			||||||
 | 
					    return Status::OK();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace functor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#define REGISTER_GPU_SPEC(type) \
 | 
				
			||||||
 | 
					  template struct functor::StatelessRandomGammaFunctor<GPUDevice, type>;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TF_CALL_half(REGISTER_GPU_SPEC);
 | 
				
			||||||
 | 
					TF_CALL_bfloat16(REGISTER_GPU_SPEC);
 | 
				
			||||||
 | 
					TF_CALL_float(REGISTER_GPU_SPEC);
 | 
				
			||||||
 | 
					TF_CALL_double(REGISTER_GPU_SPEC);
 | 
				
			||||||
 | 
					#undef REGISTER_GPU_SPEC
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace tensorflow
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#endif  // GOOGLE_CUDA
 | 
				
			||||||
@ -216,6 +216,17 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
 | 
				
			|||||||
      pure = stateless_op(seed=preseed)
 | 
					      pure = stateless_op(seed=preseed)
 | 
				
			||||||
      self.assertAllEqual(stateful, pure)
 | 
					      self.assertAllEqual(stateful, pure)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _test_match_stateless_cpu_gpu(self, case, seed):
 | 
				
			||||||
 | 
					    # Stateless ops should produce the same result on CPUs and GPUs.
 | 
				
			||||||
 | 
					    _, stateless_op, _ = case
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with ops.device('CPU'):
 | 
				
			||||||
 | 
					      result_cpu = stateless_op(seed=seed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with ops.device(get_device().name):
 | 
				
			||||||
 | 
					      result_gpu = stateless_op(seed=seed)
 | 
				
			||||||
 | 
					      self.assertAllClose(result_cpu, result_gpu)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def _test_old_and_new_stateless_match(self, case, seed):
 | 
					  def _test_old_and_new_stateless_match(self, case, seed):
 | 
				
			||||||
    """Tests that the new stateless ops match the old stateless ones."""
 | 
					    """Tests that the new stateless ops match the old stateless ones."""
 | 
				
			||||||
    with ops.device(get_device().name):
 | 
					    with ops.device(get_device().name):
 | 
				
			||||||
@ -306,6 +317,18 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
 | 
				
			|||||||
      self.skipTest('Lacking XLA kernel')
 | 
					      self.skipTest('Lacking XLA kernel')
 | 
				
			||||||
    self._test_match(case, seed)
 | 
					    self._test_match(case, seed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @parameterized.named_parameters(
 | 
				
			||||||
 | 
					      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
 | 
				
			||||||
 | 
					      for seed_id, seed in enumerate(SEEDS)
 | 
				
			||||||
 | 
					      for case_id, case in enumerate(gamma_cases()))
 | 
				
			||||||
 | 
					  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
 | 
				
			||||||
 | 
					  def testStatelessGammaCpuGpuMatch(self, case, seed):
 | 
				
			||||||
 | 
					    if get_device().device_type != 'GPU':
 | 
				
			||||||
 | 
					      # This test compares the numbers produced by the CPU and GPU kernel for
 | 
				
			||||||
 | 
					      # stateless_random_gamma.
 | 
				
			||||||
 | 
					      self.skipTest('This test requires GPU')
 | 
				
			||||||
 | 
					    self._test_match_stateless_cpu_gpu(case, seed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  @parameterized.named_parameters(
 | 
					  @parameterized.named_parameters(
 | 
				
			||||||
      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
 | 
					      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
 | 
				
			||||||
      for seed_id, seed in enumerate(SEEDS)
 | 
					      for seed_id, seed in enumerate(SEEDS)
 | 
				
			||||||
@ -387,10 +410,6 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
 | 
				
			|||||||
      for case_id, case in enumerate(gamma_cases()))
 | 
					      for case_id, case in enumerate(gamma_cases()))
 | 
				
			||||||
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
 | 
					  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
 | 
				
			||||||
  def testDeterminismGamma(self, case, seed_type):
 | 
					  def testDeterminismGamma(self, case, seed_type):
 | 
				
			||||||
    if get_device().device_type == 'GPU':
 | 
					 | 
				
			||||||
      # This test was passing before because soft placement silently picked the
 | 
					 | 
				
			||||||
      # CPU kernels.
 | 
					 | 
				
			||||||
      self.skipTest('Lacking GPU kernel')
 | 
					 | 
				
			||||||
    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
 | 
					    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
 | 
				
			||||||
      # This test was passing before because soft placement silently picked the
 | 
					      # This test was passing before because soft placement silently picked the
 | 
				
			||||||
      # CPU kernels.
 | 
					      # CPU kernels.
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user