94 lines
3.6 KiB
C++
94 lines
3.6 KiB
C++
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
|
|
|
|
#ifndef TENSORFLOW_STREAM_EXECUTOR_RNG_H_
|
|
#define TENSORFLOW_STREAM_EXECUTOR_RNG_H_
|
|
|
|
#include <limits.h>
|
|
#include <complex>
|
|
|
|
#include "tensorflow/stream_executor/platform/logging.h"
|
|
#include "tensorflow/stream_executor/platform/port.h"
|
|
|
|
namespace stream_executor {
|
|
|
|
// Forward declarations. This header only mentions these types as pointer
// parameters, so their complete definitions are not required here.
template <typename ElemT>
class DeviceMemory;

class Stream;
|
|
|
|
namespace rng {
|
|
|
|
// Random-number-generation support interface -- this can be derived from a GPU
|
|
// executor when the underlying platform has an RNG library implementation
|
|
// available. See StreamExecutor::AsRng().
|
|
// When a seed is not specified, the backing RNG will be initialized with the
|
|
// default seed for that implementation.
|
|
//
|
|
// Thread-hostile: see StreamExecutor class comment for details on
|
|
// thread-hostility.
|
|
class RngSupport {
|
|
public:
|
|
static constexpr int kMinSeedBytes = 16;
|
|
static constexpr int kMaxSeedBytes = INT_MAX;
|
|
|
|
// Releases any random-number-generation resources associated with this
|
|
// support object in the underlying platform implementation.
|
|
virtual ~RngSupport() {}
|
|
|
|
// Populates a GPU memory allocation with random values appropriate for the
|
|
// DeviceMemory element type; i.e. populates DeviceMemory<float> with random
|
|
// float values.
|
|
virtual bool DoPopulateRandUniform(Stream *stream,
|
|
DeviceMemory<float> *v) = 0;
|
|
virtual bool DoPopulateRandUniform(Stream *stream,
|
|
DeviceMemory<double> *v) = 0;
|
|
virtual bool DoPopulateRandUniform(Stream *stream,
|
|
DeviceMemory<std::complex<float>> *v) = 0;
|
|
virtual bool DoPopulateRandUniform(Stream *stream,
|
|
DeviceMemory<std::complex<double>> *v) = 0;
|
|
|
|
// Populates a GPU memory allocation with random values sampled from a
|
|
// Gaussian distribution with the given mean and standard deviation.
|
|
virtual bool DoPopulateRandGaussian(Stream *stream, float mean, float stddev,
|
|
DeviceMemory<float> *v) {
|
|
LOG(ERROR)
|
|
<< "platform's random number generator does not support gaussian";
|
|
return false;
|
|
}
|
|
virtual bool DoPopulateRandGaussian(Stream *stream, double mean,
|
|
double stddev, DeviceMemory<double> *v) {
|
|
LOG(ERROR)
|
|
<< "platform's random number generator does not support gaussian";
|
|
return false;
|
|
}
|
|
|
|
// Specifies the seed used to initialize the RNG.
|
|
// This call does not transfer ownership of the buffer seed; its data should
|
|
// not be altered for the lifetime of this call. At least 16 bytes of seed
|
|
// data must be provided, but not all seed data will necessarily be used.
|
|
// seed: Pointer to seed data. Must not be null.
|
|
// seed_bytes: Size of seed buffer in bytes. Must be >= 16.
|
|
virtual bool SetSeed(Stream *stream, const uint8 *seed,
|
|
uint64 seed_bytes) = 0;
|
|
|
|
protected:
|
|
static bool CheckSeed(const uint8 *seed, uint64 seed_bytes);
|
|
};
|
|
|
|
} // namespace rng
|
|
} // namespace stream_executor
|
|
|
|
#endif // TENSORFLOW_STREAM_EXECUTOR_RNG_H_
|