Merge pull request #38094 from houtoms:softmax_vector_pr

PiperOrigin-RevId: 306272653
Change-Id: I5fe5cad82ec3b62361872fd41b4300000b6dd086
This commit is contained in:
TensorFlower Gardener 2020-04-13 11:31:31 -07:00
commit 94f88814d1

View File

@ -72,28 +72,94 @@ struct softmax_traits<Eigen::half> {
using accumulator_type = float;
};
// Computes the final normalized softmax (or log-softmax) output:
//   out = exp(logit - max) / sum      (in_log_space == false)
//   out = logit - max - log(sum)      (in_log_space == true)
// Each thread processes kUnroll elements in a grid-stride pattern
// (successive elements of one thread are gridDim.x * blockDim.x apart),
// first gathering inputs into registers and then computing/writing results,
// so the loads of `logits` can be issued back-to-back.
// `sum_probs` holds per-row sums in the accumulator type U; `max_logits`
// holds per-row maxima. Out-of-range tail threads are guarded per element.
template <typename T, typename U, int kUnroll>
__global__ void GenerateNormalizedProb(const T* logits, const U* sum_probs,
                                       const T* max_logits, T* output,
                                       const int num_rows, const int num_cols,
                                       const bool in_log_space) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int row, col;
  // TODO(jamesqin): change to half2 load when inputs are Eigen::half.
  U input[kUnroll];
  U max_val[kUnroll];
  U result[kUnroll];
  // Phase 1: load kUnroll inputs and their row maxima into registers.
  for (int i = 0; i < kUnroll; i++) {
    row = tid / num_cols;
    col = tid % num_cols;
    if (row < num_rows && col < num_cols) {
      input[i] = strict_cast<U>(logits[tid]);
      max_val[i] = strict_cast<U>(ldg(max_logits + row));
    }
    tid += gridDim.x * blockDim.x;
  }
  // Phase 2: revisit the same indices, normalize, and write the outputs.
  tid = blockIdx.x * blockDim.x + threadIdx.x;
  for (int i = 0; i < kUnroll; i++) {
    row = tid / num_cols;
    col = tid % num_cols;
    if (row < num_rows && col < num_cols) {
      if (in_log_space) {
        result[i] = input[i] - max_val[i] - log(ldg(sum_probs + row));
      } else {
        result[i] = exp(input[i] - max_val[i]) / ldg(sum_probs + row);
      }
      output[tid] = strict_cast<T>(result[i]);
    }
    tid += gridDim.x * blockDim.x;
  }
}
// Specialization for Eigen::half with kUnroll == 8: loads/stores 8 halves
// (16 bytes) at a time through ulonglong2 to get vectorized global-memory
// transactions. Callers must guarantee `logits` and `output` are 16-byte
// aligned (the host side checks pointer alignment before selecting this
// kernel). Each thread owns the contiguous range
// [tid * kUnroll, tid * kUnroll + 7]; the tail that would read past
// num_rows * num_cols falls back to per-element scalar loads/stores.
template <>
__global__ void GenerateNormalizedProb<Eigen::half, float, 8>(
    const Eigen::half* logits, const float* sum_probs,
    const Eigen::half* max_logits, Eigen::half* output, const int num_rows,
    const int num_cols, const bool in_log_space) {
  const int kUnroll = 8;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int idx[kUnroll];
  int row[kUnroll];
  float input[kUnroll];
  float max_val[kUnroll];
  float result[kUnroll];
  if (tid * kUnroll + kUnroll - 1 < num_rows * num_cols) {
    // Fast path: all 8 elements are in range, so do one 16-byte vector load
    // and one 16-byte vector store.
    ulonglong2 logits_d =
        *reinterpret_cast<const ulonglong2*>(logits + tid * kUnroll);
    Eigen::half* logits_h = reinterpret_cast<Eigen::half*>(&logits_d);
    ulonglong2 output_d;
    Eigen::half* output_h = reinterpret_cast<Eigen::half*>(&output_d);
    for (int i = 0; i < kUnroll; i++) {
      idx[i] = tid * kUnroll + i;
      row[i] = idx[i] / num_cols;
      input[i] = strict_cast<float>(logits_h[i]);
      max_val[i] = strict_cast<float>(ldg(max_logits + row[i]));
      if (in_log_space) {
        result[i] = input[i] - max_val[i] - log(ldg(sum_probs + row[i]));
      } else {
        result[i] = exp(input[i] - max_val[i]) / ldg(sum_probs + row[i]);
      }
      output_h[i] = strict_cast<Eigen::half>(result[i]);
    }
    *reinterpret_cast<ulonglong2*>(output + tid * kUnroll) = output_d;
  } else {
    // Tail path: fewer than 8 valid elements remain; guard each one and use
    // scalar loads/stores so we never touch memory past the buffer end.
    for (int i = 0; i < kUnroll; i++) {
      if (tid * kUnroll + i < num_rows * num_cols) {
        idx[i] = tid * kUnroll + i;
        row[i] = idx[i] / num_cols;
        input[i] = strict_cast<float>(logits[idx[i]]);
        max_val[i] = strict_cast<float>(ldg(max_logits + row[i]));
        if (in_log_space) {
          result[i] = input[i] - max_val[i] - log(ldg(sum_probs + row[i]));
        } else {
          result[i] = exp(input[i] - max_val[i]) / ldg(sum_probs + row[i]);
        }
        output[idx[i]] = strict_cast<Eigen::half>(result[i]);
      }
    }
  }
}
@ -165,8 +231,6 @@ class SoftmaxOpGPU : public OpKernel {
context, const_cast<T*>(max_logits.flat<T>().data()),
reinterpret_cast<const T*>(logits_in_.flat<T>().data()), rows, cols);
const int numThreadsPerBlock = 128;
const int numBlocks = Eigen::divup(rows * cols, numThreadsPerBlock);
gpuprim::CountingInputIterator<int> counting_iterator(0);
using InputIterType =
@ -184,12 +248,36 @@ class SoftmaxOpGPU : public OpKernel {
context, const_cast<acc_type*>(sum_probs.flat<acc_type>().data()),
input_itr, rows, cols);
TF_CHECK_OK(GpuLaunchKernel(
GenerateNormalizedProb<T, acc_type>, numBlocks, numThreadsPerBlock, 0,
cu_stream, reinterpret_cast<const T*>(logits_in_.flat<T>().data()),
reinterpret_cast<const acc_type*>(sum_probs.flat<acc_type>().data()),
reinterpret_cast<const T*>(max_logits.flat<T>().data()),
const_cast<T*>(softmax_out->flat<T>().data()), rows, cols, log_));
auto in_ptr = reinterpret_cast<uintptr_t>(logits_in_.flat<T>().data());
auto out_ptr = reinterpret_cast<uintptr_t>(softmax_out->flat<T>().data());
bool aligned = in_ptr % 16 == 0 && out_ptr % 16 == 0;
const int numThreadsPerBlock = 128;
if (DataTypeToEnum<T>::value == DT_HALF && aligned) {
const int kUnroll = 8;
const int numBlocks =
Eigen::divup(rows * cols, numThreadsPerBlock * kUnroll);
TF_CHECK_OK(GpuLaunchKernel(
GenerateNormalizedProb<T, acc_type, kUnroll>, numBlocks,
numThreadsPerBlock, 0, cu_stream,
reinterpret_cast<const T*>(logits_in_.flat<T>().data()),
reinterpret_cast<const acc_type*>(
sum_probs.flat<acc_type>().data()),
reinterpret_cast<const T*>(max_logits.flat<T>().data()),
const_cast<T*>(softmax_out->flat<T>().data()), rows, cols, log_));
} else {
const int kUnroll = 4;
const int numBlocks =
Eigen::divup(rows * cols, numThreadsPerBlock * kUnroll);
TF_CHECK_OK(GpuLaunchKernel(
GenerateNormalizedProb<T, acc_type, kUnroll>, numBlocks,
numThreadsPerBlock, 0, cu_stream,
reinterpret_cast<const T*>(logits_in_.flat<T>().data()),
reinterpret_cast<const acc_type*>(
sum_probs.flat<acc_type>().data()),
reinterpret_cast<const T*>(max_logits.flat<T>().data()),
const_cast<T*>(softmax_out->flat<T>().data()), rows, cols, log_));
}
}
}