Internal change

PiperOrigin-RevId: 321431476
Change-Id: I9907a93b99cd08a05699096e9314c34cbd55601f
A. Unique TensorFlower 2020-07-15 13:49:59 -07:00 committed by TensorFlower Gardener
parent e42e9de4b7
commit 806a053eb5
8 changed files with 60 additions and 150 deletions

View File

@@ -2424,7 +2424,6 @@ cc_library(
     deps = [
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
-        "//tensorflow/core/framework:numeric_types",
     ],
 )

View File

@@ -183,6 +183,10 @@ struct RandomBinomialFunctor<CPUDevice, T, U> {
     // We have B1 * ... * Bk samples per batch member we need.
     auto DoWork = [num_batches, samples_per_batch, &bcast, &counts, &probs,
                    &gen, &output](int start_output, int limit_output) {
+      // Vectorized intermediate calculations for uniform rejection sampling.
+      // We always generate at most 4 samples.
+      Eigen::array<T, 4> z;
+      Eigen::array<T, 4> g;
       const bool should_bcast = bcast.IsBroadcastingRequired();
       const auto& counts_batch_indices = bcast.x_batch_indices();
       const auto& probs_batch_indices = bcast.y_batch_indices();

View File

@@ -344,7 +344,7 @@ class RandomGammaOp : public OpKernel {
           .HostMemory("shape")                                               \
           .TypeConstraint<TYPE>("dtype"),                                    \
       PhiloxRandomOp<CPUDevice, random::UniformDistribution<                 \
-                                    random::PhiloxRandom, TYPE, true>>);     \
+                                    random::PhiloxRandom, TYPE>>);           \
   REGISTER_KERNEL_BUILDER(                                                   \
       Name("RandomStandardNormal")                                           \
           .Device(DEVICE_CPU)                                                \

View File

@@ -86,13 +86,7 @@ struct FillPhiloxRandomTask<Distribution, false> {
                   int64 start_group, int64 limit_group, Distribution dist) {
     const int kGroupSize = Distribution::kResultElementCount;

-    // Decide skip strides according to different kResultElementCount:
-    //   * `1 = (4 + 3) / 4` for normal Distribution.
-    //   * `1 = (2 + 3) / 4` for double/int64 Distribution.
-    //   * `4 = (16 + 3) / 4` for vectorized float/bfloat16 Distribution.
-    const int skip_strides =
-        (kGroupSize + gen.kResultElementCount - 1) / gen.kResultElementCount;
-    gen.Skip(start_group * skip_strides);
+    gen.Skip(start_group);
     int64 offset = start_group * kGroupSize;

     // First fill all the full-size groups
@@ -172,8 +166,9 @@ void FillPhiloxRandom<CPUDevice, Distribution>::operator()(
   int64 total_group_count = (size + kGroupSize - 1) / kGroupSize;

-  const int kGroupCost = kGroupSize * (random::PhiloxRandom::kElementCost +
-                                       Distribution::kElementCost);
+  const int kGroupCost =
+      random::PhiloxRandom::kResultElementCount *
+      (random::PhiloxRandom::kElementCost + Distribution::kElementCost);
   Shard(worker_threads.num_threads, worker_threads.workers, total_group_count,
         kGroupCost,
         [&gen, data, size, dist](int64 start_group, int64 limit_group) {
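
As context for the deleted stride logic: the `1 = (4 + 3) / 4` style values in the removed comment are just ceil-division of the distribution's group size by PhiloxRandom's native output width of 4 uint32s. A standalone sketch of that arithmetic (the counts 4/2/16 come from the deleted comment, not from running TF):

#include <iostream>

// Ceil-division reproduction of the removed skip_strides formula.
int SkipStrides(int group_size, int gen_count) {
  return (group_size + gen_count - 1) / gen_count;
}

int main() {
  const int kGenCount = 4;                          // PhiloxRandom: 4 x uint32 per call
  std::cout << SkipStrides(4, kGenCount) << "\n";   // 1: normal distributions
  std::cout << SkipStrides(2, kGenCount) << "\n";   // 1: double/int64 distributions
  std::cout << SkipStrides(16, kGenCount) << "\n";  // 4: vectorized float/bfloat16
}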

View File

@@ -37,41 +37,41 @@ Tensor VecShape(int64 v) {
   }
 }

-Graph* RandomUniform(int64 n, DataType dtype) {
+Graph* RandomUniform(int64 n) {
   Graph* g = new Graph(OpRegistry::Global());
-  test::graph::RandomUniform(g, test::graph::Constant(g, VecShape(n)), dtype);
+  test::graph::RandomUniform(g, test::graph::Constant(g, VecShape(n)),
+                             DT_FLOAT);
   return g;
 }

-Graph* RandomNormal(int64 n, DataType dtype) {
+Graph* RandomNormal(int64 n) {
   Graph* g = new Graph(OpRegistry::Global());
-  test::graph::RandomGaussian(g, test::graph::Constant(g, VecShape(n)), dtype);
+  test::graph::RandomGaussian(g, test::graph::Constant(g, VecShape(n)),
+                              DT_FLOAT);
   return g;
 }

-Graph* TruncatedNormal(int64 n, DataType dtype) {
+Graph* TruncatedNormal(int64 n) {
   Graph* g = new Graph(OpRegistry::Global());
-  test::graph::TruncatedNormal(g, test::graph::Constant(g, VecShape(n)), dtype);
+  test::graph::TruncatedNormal(g, test::graph::Constant(g, VecShape(n)),
+                               DT_FLOAT);
   return g;
 }

-#define BM_RNG(DEVICE, RNG, DTYPE)                                   \
-  void BM_##DEVICE##_##RNG##_##DTYPE(int iters, int arg) {           \
+#define BM_RNG(DEVICE, RNG)                                          \
+  void BM_##DEVICE##_##RNG(int iters, int arg) {                     \
     testing::ItemsProcessed(static_cast<int64>(iters) * arg);        \
-    test::Benchmark(#DEVICE, RNG(arg, DTYPE)).Run(iters);            \
+    test::Benchmark(#DEVICE, RNG(arg)).Run(iters);                   \
   }                                                                  \
-  BENCHMARK(BM_##DEVICE##_##RNG##_##DTYPE)->Range(1 << 20, 8 << 20);
+  BENCHMARK(BM_##DEVICE##_##RNG)->Range(1 << 20, 8 << 20);

-BM_RNG(cpu, RandomUniform, DT_FLOAT);
-BM_RNG(cpu, RandomUniform, DT_BFLOAT16);
-BM_RNG(cpu, RandomNormal, DT_FLOAT);
-BM_RNG(cpu, TruncatedNormal, DT_FLOAT);
+BM_RNG(cpu, RandomUniform);
+BM_RNG(cpu, RandomNormal);
+BM_RNG(cpu, TruncatedNormal);

-#ifdef GOOGLE_CUDA
-BM_RNG(gpu, RandomUniform, DT_FLOAT);
-BM_RNG(gpu, RandomNormal, DT_FLOAT);
-BM_RNG(gpu, TruncatedNormal, DT_FLOAT);
-#endif
+BM_RNG(gpu, RandomUniform);
+BM_RNG(gpu, RandomNormal);
+BM_RNG(gpu, TruncatedNormal);

 Tensor VecAlphas(int64 n) {
   Tensor alphas(DT_DOUBLE, TensorShape({n}));
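
For readers unpacking the token-pasting above: after this rollback, BM_RNG(cpu, RandomUniform) expands to roughly the following (a sketch of the preprocessor output, not literal file content):

// Approximate expansion of BM_RNG(cpu, RandomUniform) post-rollback.
void BM_cpu_RandomUniform(int iters, int arg) {
  testing::ItemsProcessed(static_cast<int64>(iters) * arg);
  test::Benchmark("cpu", RandomUniform(arg)).Run(iters);
}
BENCHMARK(BM_cpu_RandomUniform)->Range(1 << 20, 8 << 20);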

View File

@@ -40,7 +40,6 @@ cc_library(
     deps = [
         ":exact_uniform_int",
         ":philox_random",
-        "//tensorflow/core/framework:numeric_types",
         "//tensorflow/core/lib/bfloat16",
         "//tensorflow/core/lib/gtl:array_slice",
         "//tensorflow/core/platform:logging",

View File

@@ -18,12 +18,12 @@ limitations under the License.
 #include <string.h>
-#include <algorithm>
 #include <cmath>
+#include <algorithm>
 #include <type_traits>

-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/framework/numeric_types.h"
 #include "tensorflow/core/lib/bfloat16/bfloat16.h"
 #include "tensorflow/core/lib/random/philox_random.h"
@@ -32,56 +32,13 @@ namespace random {
 // Helper function to convert a 16-bit integer to a half between [0..1).
 PHILOX_DEVICE_INLINE Eigen::half Uint16ToHalf(uint16 x);
-// Helper function to convert a 16-bit integer to a bfloat16 between [1..2).
-PHILOX_DEVICE_INLINE bfloat16 InternalUint16ToBfloat16(uint16 x);
 // Helper function to convert a 16-bit integer to a bfloat16 between [0..1).
-PHILOX_DEVICE_INLINE bfloat16 Uint16ToBfloat16(uint16 x);
-// Helper function to convert a 32-bit integer to a float between [1..2).
-PHILOX_DEVICE_INLINE float InternalUint32ToFloat(uint32 x);
+PHILOX_DEVICE_INLINE bfloat16 Uint16ToGfloat16(uint16 x);
 // Helper function to convert a 32-bit integer to a float between [0..1).
 PHILOX_DEVICE_INLINE float Uint32ToFloat(uint32 x);
 // Helper function to convert two 32-bit integers to a double between [0..1).
 PHILOX_DEVICE_INLINE double Uint64ToDouble(uint32 x0, uint32 x1);

-// Helper function to format distribution result in vectorization path,
-// it creates Eigen::Tensor and reuses packet feature with SIMD.
-// This function can only work on CPU
-template <class Distribution, class Generator>
-PHILOX_DEVICE_INLINE typename Distribution::ResultType VectorizedFormat(
-    Generator* gen, typename Distribution::FormatFunc functor) {
-  typename Generator::ResultType sample;
-  typename Distribution::ResultType result;
-  const int kResultElementCount = Distribution::kResultElementCount;
-  const int inner_count = Generator::kResultElementCount;
-  const int outer_count = kResultElementCount / inner_count;
-  int offset = 0;
-  for (int k = 0; k < outer_count; k++) {
-    sample = (*gen)();
-    for (int i = 0; i < inner_count; i++, offset++) {
-      result[offset] = (*functor)(sample[i]);
-    }
-  }
-  // Tail processing if any.
-  // Put the tail condition out of above loop to improve performance:
-  // it will be executed only once and save time on CPU.
-  if (offset < kResultElementCount) {
-    sample = (*gen)();
-    for (int i = 0; offset < kResultElementCount; i++, offset++) {
-      result[offset] = (*functor)(sample[i]);
-    }
-  }
-  typedef Eigen::TensorMap<
-      Eigen::Tensor<typename Distribution::ResultElementType, 1,
-                    Eigen::RowMajor, Eigen::DenseIndex>,
-      Eigen::Aligned>
-      Tensor;
-  auto tensor_result = Tensor(&result[0], kResultElementCount);
-  tensor_result = tensor_result - typename Distribution::ResultElementType(1.0);
-  return result;
-}

 // Computes a + b. Requires that the result is representable in the destination
 // type and that b is not maximal (i.e. b + 1 is not 0). Notably, the addend b
 // need *not* be representable in that type. (The condition on b excludes the
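
The removed helper's final step is easy to miss: it maps the result array into an Eigen tensor purely so the elementwise subtract-1.0 runs through Eigen's SIMD packet path. A minimal standalone sketch of that trick (hypothetical values; uses the public unsupported/Eigen/CXX11/Tensor header rather than TF's vendored path):

#include <iostream>

#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  // Pretend these came from a [1, 2) bit-trick conversion.
  float result[8] = {1.0f, 1.25f, 1.5f, 1.75f, 1.1f, 1.2f, 1.3f, 1.9f};
  // Wrap the raw array; arithmetic on the map goes through Eigen's
  // vectorized packet machinery instead of a scalar loop.
  Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>> tensor_result(
      result, 8);
  tensor_result = tensor_result - 1.0f;  // shift samples from [1, 2) to [0, 1)
  for (float v : result) std::cout << v << ' ';  // 0 0.25 0.5 0.75 ...
  std::cout << '\n';
}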
@@ -105,15 +62,13 @@ PHILOX_DEVICE_INLINE Int SignedAdd(Int a,
 //   actual returned sample type.
 // RealType: the data type of the real numbers that will be returned by the
 //   distribution. This could be either float or double for now.
-// IsVec: mark this UniformDistribution can be vectorized or not by SIMD on
-//   CPU. Note this should always be false on GPU.
 // This class is meant to be implemented through specialization. The default
 // is not defined by design.
-template <class Generator, typename RealType, bool IsVec = false>
+template <class Generator, typename RealType>
 class UniformDistribution;

-template <class Generator, bool IsVec>
-class UniformDistribution<Generator, Eigen::half, IsVec> {
+template <class Generator>
+class UniformDistribution<Generator, Eigen::half> {
  public:
   // The number of elements that will be returned.
   static constexpr int kResultElementCount = Generator::kResultElementCount;
@@ -136,17 +91,11 @@ class UniformDistribution<Generator, Eigen::half, IsVec> {
   }
 };

-template <class Generator, bool IsVec>
-class UniformDistribution<Generator, bfloat16, IsVec> {
+template <class Generator>
+class UniformDistribution<Generator, bfloat16> {
  public:
   // The number of elements that will be returned.
-  // Set the number to be Eigen packet size of type at least, so computations
-  // can be vectorized using SIMD on CPU.
-  static constexpr int kVectorLength = std::max(
-      static_cast<const int>(Eigen::internal::packet_traits<bfloat16>::size),
-      Generator::kResultElementCount);
-  static constexpr int kResultElementCount =
-      IsVec ? kVectorLength : Generator::kResultElementCount;
+  static constexpr int kResultElementCount = Generator::kResultElementCount;
   // Cost of generation of a single element (in cycles).
   static constexpr int kElementCost = 3;
   // Indicate that this distribution may take variable number of samples
@@ -154,37 +103,23 @@ class UniformDistribution<Generator, bfloat16, IsVec> {
   static constexpr bool kVariableSamplesPerOutput = false;
   typedef Array<bfloat16, kResultElementCount> ResultType;
   typedef bfloat16 ResultElementType;
-  // Helper definition for the format function.
-  typedef bfloat16 (*FormatFunc)(uint16);

   PHILOX_DEVICE_INLINE
   ResultType operator()(Generator* gen) {
-#ifdef __CUDA_ARCH__
-    static_assert(!IsVec, "Can't vectorize Distribution on GPU");
     typename Generator::ResultType sample = (*gen)();
     ResultType result;
     for (int i = 0; i < kResultElementCount; ++i) {
-      result[i] = Uint16ToBfloat16(sample[i]);
+      result[i] = Uint16ToGfloat16(sample[i]);
     }
     return result;
-#else
-    return VectorizedFormat<UniformDistribution<Generator, bfloat16, IsVec>,
-                            Generator>(gen, InternalUint16ToBfloat16);
-#endif  // __CUDA_ARCH__
   }
 };

-template <class Generator, bool IsVec>
-class UniformDistribution<Generator, float, IsVec> {
+template <class Generator>
+class UniformDistribution<Generator, float> {
  public:
   // The number of elements that will be returned.
-  // Set the number to be Eigen packet size of type at least, so computations
-  // can be vectorized using SIMD on CPU.
-  static constexpr int kVectorLength = std::max(
-      static_cast<const int>(Eigen::internal::packet_traits<float>::size),
-      Generator::kResultElementCount);
-  static constexpr int kResultElementCount =
-      IsVec ? kVectorLength : Generator::kResultElementCount;
+  static constexpr int kResultElementCount = Generator::kResultElementCount;
   // Cost of generation of a single element (in cycles).
   static constexpr int kElementCost = 3;
   // Indicate that this distribution may take variable number of samples
@@ -192,28 +127,20 @@ class UniformDistribution<Generator, float, IsVec> {
   static constexpr bool kVariableSamplesPerOutput = false;
   typedef Array<float, kResultElementCount> ResultType;
   typedef float ResultElementType;
-  // Helper definition for the format function.
-  typedef float (*FormatFunc)(uint32);

   PHILOX_DEVICE_INLINE
   ResultType operator()(Generator* gen) {
-#ifdef __CUDA_ARCH__
-    static_assert(!IsVec, "Can't vectorize Distribution on GPU");
     typename Generator::ResultType sample = (*gen)();
     ResultType result;
     for (int i = 0; i < kResultElementCount; ++i) {
       result[i] = Uint32ToFloat(sample[i]);
     }
     return result;
-#else
-    return VectorizedFormat<UniformDistribution<Generator, float, IsVec>,
-                            Generator>(gen, InternalUint32ToFloat);
-#endif  // __CUDA_ARCH__
   }
 };

-template <class Generator, bool IsVec>
-class UniformDistribution<Generator, double, IsVec> {
+template <class Generator>
+class UniformDistribution<Generator, double> {
  public:
   // The number of elements that will be returned.
   static constexpr int kResultElementCount = Generator::kResultElementCount / 2;
@@ -236,8 +163,8 @@ class UniformDistribution<Generator, double, IsVec> {
   }
 };

-template <class Generator, bool IsVec>
-class UniformDistribution<Generator, int32, IsVec> {
+template <class Generator>
+class UniformDistribution<Generator, int32> {
  public:
   // The number of elements that will be returned.
   static constexpr int kResultElementCount = Generator::kResultElementCount;
@@ -271,8 +198,8 @@ class UniformDistribution<Generator, int32, IsVec> {
   uint32 range_;
 };

-template <class Generator, bool IsVec>
-class UniformDistribution<Generator, int64, IsVec> {
+template <class Generator>
+class UniformDistribution<Generator, int64> {
  public:
   // The number of elements that will be returned.
   static constexpr int kResultElementCount = Generator::kResultElementCount / 2;
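
A note on the halved kResultElementCount in the double and int64 specializations kept above: each 64-bit result consumes two of the generator's 32-bit outputs (cf. the Uint64ToDouble(x0, x1) helper). A toy sketch of the counting, with the hi/lo packing chosen arbitrarily here rather than copied from TF's actual routine:

#include <cstdint>
#include <cstdio>

// Toy packing of two 32-bit generator outputs into one 64-bit value;
// the real Uint64ToDouble builds a mantissa instead of concatenating.
uint64_t PackHiLo(uint32_t x0, uint32_t x1) {
  return (static_cast<uint64_t>(x0) << 32) | x1;
}

int main() {
  const int kGeneratorOutputs = 4;                  // PhiloxRandom: 4 x uint32
  const int kInt64Results = kGeneratorOutputs / 2;  // hence the "/ 2" above
  std::printf("%d 64-bit results per call\n", kInt64Results);
  std::printf("%016llx\n", static_cast<unsigned long long>(
                               PackHiLo(0x01234567u, 0x89abcdefu)));
}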
@@ -837,9 +764,9 @@ PHILOX_DEVICE_INLINE Eigen::half Uint16ToHalf(uint16 x) {
   return result - Eigen::half(1.0);
 }

-// Helper function to convert an 16-bit integer to a bfloat16 between [1..2).
-// This can create a uniform distribution of values between [1..2).
-PHILOX_DEVICE_INLINE bfloat16 InternalUint16ToBfloat16(uint16 x) {
+// Helper function to convert an 16-bit integer to a bfloat16 between [0..1).
+// This can create a uniform distribution of values between [0..1).
+PHILOX_DEVICE_INLINE bfloat16 Uint16ToGfloat16(uint16 x) {
   // bfloat are formatted as follows (MSB first):
   //    sign(1) exponent(8) mantissa(7)
   // Conceptually construct the following:
@@ -853,20 +780,13 @@ PHILOX_DEVICE_INLINE bfloat16 InternalUint16ToBfloat16(uint16 x) {
   bfloat16 result;
   memcpy(&result, &val, sizeof(val));
   // The mantissa has an implicit leading 1, so the above code creates a value
-  // in [1, 2).
-  return result;
-}
-
-// Helper function to convert an 16-bit integer to a bfloat16 between [0..1).
-// This can create a uniform distribution of values between [0..1).
-PHILOX_DEVICE_INLINE bfloat16 Uint16ToBfloat16(uint16 x) {
-  // The minus will not cause a rounding that makes the result 1.
+  // in [1, 2). The minus will not cause a rounding that makes the result 1.
   // Instead it will just be close to 1.
-  return InternalUint16ToBfloat16(x) - bfloat16(1.0);
+  return result - bfloat16(1.0);
 }

-// Helper function to convert an 32-bit integer to a float between [1..2).
-PHILOX_DEVICE_INLINE float InternalUint32ToFloat(uint32 x) {
+// Helper function to convert an 32-bit integer to a float between [0..1).
+PHILOX_DEVICE_INLINE float Uint32ToFloat(uint32 x) {
   // IEEE754 floats are formatted as follows (MSB first):
   //    sign(1) exponent(8) mantissa(23)
   // Conceptually construct the following:
@@ -880,12 +800,7 @@ PHILOX_DEVICE_INLINE float InternalUint32ToFloat(uint32 x) {
   // Assumes that endian-ness is same for float and uint32.
   float result;
   memcpy(&result, &val, sizeof(val));
-  return result;
-}
-
-// Helper function to convert an 32-bit integer to a float between [0..1).
-PHILOX_DEVICE_INLINE float Uint32ToFloat(uint32 x) {
-  return InternalUint32ToFloat(x) - 1.0f;
+  return result - 1.0f;
 }

 // Helper function to convert two 32-bit integers to a double between [0..1).
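
For readers new to the exponent trick these helpers rely on: forcing sign = 0 and biased exponent = 127 makes the 32 random bits read as 1.mantissa, a float uniformly distributed in [1, 2), and subtracting 1.0f maps it to [0, 1). A standalone illustration of the same idea (a sketch, not TF code; assumes matching endianness for float and uint32):

#include <cstdint>
#include <cstdio>
#include <cstring>

float Uint32ToUnitFloat(uint32_t x) {
  const uint32_t man = x & 0x7fffffu;  // keep 23 random mantissa bits
  const uint32_t exp = 127u << 23;     // biased exponent for 2^0
  const uint32_t val = exp | man;      // sign(0) exponent(127) mantissa(x)
  float result;
  memcpy(&result, &val, sizeof(val));  // reinterpret bits as float in [1, 2)
  return result - 1.0f;                // shift down to [0, 1)
}

int main() {
  std::printf("%f %f %f\n", Uint32ToUnitFloat(0u),
              Uint32ToUnitFloat(0x7fffffu), Uint32ToUnitFloat(0xdeadbeefu));
}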

View File

@@ -276,9 +276,8 @@ class RandomUniformTest(RandomOpTestCommon):
   def testRange(self):
     for dt in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32,
-               dtypes.int64, dtypes.bfloat16):
-      use_gpu = (dt != dtypes.bfloat16)
-      sampler = self._Sampler(1000, minv=-2, maxv=8, dtype=dt, use_gpu=use_gpu)
+               dtypes.int64):
+      sampler = self._Sampler(1000, minv=-2, maxv=8, dtype=dt, use_gpu=True)
       x = sampler()
       self.assertTrue(-2 <= np.min(x))
       self.assertTrue(np.max(x) < 8)
@@ -364,11 +363,10 @@ class RandomUniformTest(RandomOpTestCommon):
   @test_util.run_deprecated_v1
   def testSeed(self):
     for dt in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32,
-               dtypes.int64, dtypes.bfloat16):
+               dtypes.int64):
       for seed in [345, 2**100, -2**100]:
-        use_gpu = (dt != dtypes.bfloat16)
-        sx = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=seed)
-        sy = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=seed)
+        sx = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=True, seed=seed)
+        sy = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=True, seed=seed)
         self.assertAllEqual(sx(), sy())

   @test_util.run_deprecated_v1