Merge pull request #38094 from houtoms:softmax_vector_pr
PiperOrigin-RevId: 306272653
Change-Id: I5fe5cad82ec3b62361872fd41b4300000b6dd086
commit 94f88814d1
@@ -72,28 +72,94 @@ struct softmax_traits<Eigen::half> {
   using accumulator_type = float;
 };
 
-template <typename T, typename U>
+template <typename T, typename U, int kUnroll>
 __global__ void GenerateNormalizedProb(const T* logits, const U* sum_probs,
                                        const T* max_logits, T* output,
                                        const int num_rows, const int num_cols,
                                        const bool in_log_space) {
-  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
-
-  const int row = tid / num_cols;
-  const int col = tid % num_cols;
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int row, col;
 
   // TODO(jamesqin): change to half2 load when inputs are Eigen::half.
-  U input = strict_cast<U>(logits[tid]);
-  U max_val = strict_cast<U>(ldg(max_logits + row));
-  U result;
-
-  if (row < num_rows && col < num_cols) {
-    if (in_log_space) {
-      result = input - max_val - log(ldg(sum_probs + row));
-    } else {
-      result = exp(input - max_val) / ldg(sum_probs + row);
+  U input[kUnroll];
+  U max_val[kUnroll];
+  U result[kUnroll];
+  for (int i = 0; i < kUnroll; i++) {
+    row = tid / num_cols;
+    col = tid % num_cols;
+    if (row < num_rows && col < num_cols) {
+      input[i] = strict_cast<U>(logits[tid]);
+      max_val[i] = strict_cast<U>(ldg(max_logits + row));
+    }
+    tid += gridDim.x * blockDim.x;
+  }
+
+  tid = blockIdx.x * blockDim.x + threadIdx.x;
+  for (int i = 0; i < kUnroll; i++) {
+    row = tid / num_cols;
+    col = tid % num_cols;
+    if (row < num_rows && col < num_cols) {
+      if (in_log_space) {
+        result[i] = input[i] - max_val[i] - log(ldg(sum_probs + row));
+      } else {
+        result[i] = exp(input[i] - max_val[i]) / ldg(sum_probs + row);
+      }
+      output[tid] = strict_cast<T>(result[i]);
+    }
+    tid += gridDim.x * blockDim.x;
+  }
+}
+
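
The rewritten generic kernel processes kUnroll elements per thread in a grid-stride pattern and deliberately splits the work into two passes: the first loop issues all kUnroll independent global loads, the second re-walks the same indices to do the arithmetic and stores. Batching the loads lets them be in flight concurrently instead of each one serializing behind the previous element's math. Below is a minimal self-contained sketch of the same two-pass pattern on a trivial scaling op; the kernel name, the op, and the shapes are illustrative assumptions, not part of this commit.

#include <cstdio>
#include <cuda_runtime.h>

// Illustrative kernel (not from the commit): each thread handles kUnroll
// elements spaced one full grid apart. Pass 1 issues all independent loads;
// pass 2 does the math and the stores, mirroring the structure of
// GenerateNormalizedProb above.
template <typename T, int kUnroll>
__global__ void ScaleKernel(const T* in, T* out, T factor, int n) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;

  T vals[kUnroll];
  // Pass 1: gather. The loads are independent, so they can overlap
  // instead of each one waiting behind the previous element's compute.
  int t = tid;
  for (int i = 0; i < kUnroll; i++) {
    if (t < n) vals[i] = in[t];
    t += gridDim.x * blockDim.x;
  }
  // Pass 2: compute and scatter, re-walking the same indices.
  t = tid;
  for (int i = 0; i < kUnroll; i++) {
    if (t < n) out[t] = vals[i] * factor;
    t += gridDim.x * blockDim.x;
  }
}

int main() {
  const int n = 1 << 20, kUnroll = 4, threads = 128;
  const int blocks = (n + threads * kUnroll - 1) / (threads * kUnroll);
  float *in, *out;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, n * sizeof(float));
  for (int i = 0; i < n; i++) in[i] = 1.0f;
  ScaleKernel<float, kUnroll><<<blocks, threads>>>(in, out, 2.0f, n);
  cudaDeviceSynchronize();
  printf("out[0] = %f\n", out[0]);  // expect 2.0
  cudaFree(in);
  cudaFree(out);
  return 0;
}

Note that the grid must be sized so gridDim.x * blockDim.x * kUnroll covers all elements, which is exactly what the Eigen::divup change at the launch site further down does.
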
+template <>
+__global__ void GenerateNormalizedProb<Eigen::half, float, 8>(
+    const Eigen::half* logits, const float* sum_probs,
+    const Eigen::half* max_logits, Eigen::half* output, const int num_rows,
+    const int num_cols, const bool in_log_space) {
+  const int kUnroll = 8;
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int idx[kUnroll];
+  int row[kUnroll];
+
+  float input[kUnroll];
+  float max_val[kUnroll];
+  float result[kUnroll];
+
+  if (tid * kUnroll + kUnroll - 1 < num_rows * num_cols) {
+    ulonglong2 logits_d =
+        *reinterpret_cast<const ulonglong2*>(logits + tid * kUnroll);
+    Eigen::half* logits_h = reinterpret_cast<Eigen::half*>(&logits_d);
+    ulonglong2 output_d;
+    Eigen::half* output_h = reinterpret_cast<Eigen::half*>(&output_d);
+
+    for (int i = 0; i < kUnroll; i++) {
+      idx[i] = tid * kUnroll + i;
+      row[i] = idx[i] / num_cols;
+      input[i] = strict_cast<float>(logits_h[i]);
+      max_val[i] = strict_cast<float>(ldg(max_logits + row[i]));
+      if (in_log_space) {
+        result[i] = input[i] - max_val[i] - log(ldg(sum_probs + row[i]));
+      } else {
+        result[i] = exp(input[i] - max_val[i]) / ldg(sum_probs + row[i]);
+      }
+      output_h[i] = strict_cast<Eigen::half>(result[i]);
+    }
+
+    *reinterpret_cast<ulonglong2*>(output + tid * kUnroll) = output_d;
+  } else {
+    for (int i = 0; i < kUnroll; i++) {
+      if (tid * kUnroll + i < num_rows * num_cols) {
+        idx[i] = tid * kUnroll + i;
+        row[i] = idx[i] / num_cols;
+        input[i] = strict_cast<float>(logits[idx[i]]);
+        max_val[i] = strict_cast<float>(ldg(max_logits + row[i]));
+        if (in_log_space) {
+          result[i] = input[i] - max_val[i] - log(ldg(sum_probs + row[i]));
+        } else {
+          result[i] = exp(input[i] - max_val[i]) / ldg(sum_probs + row[i]);
+        }
+        output[idx[i]] = strict_cast<Eigen::half>(result[i]);
+      }
     }
-    output[tid] = strict_cast<T>(result);
   }
 }
 
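
This new specialization for <Eigen::half, float, 8> is the vectorization the PR title refers to: eight halves are exactly 16 bytes, one ulonglong2, so a whole group of logits moves in a single 128-bit transaction instead of eight 2-byte loads, with the scalar else-branch covering the tail when a thread's last element runs past the end. Here is a self-contained sketch of the same trick, assuming CUDA's __half as a stand-in for Eigen::half and a trivial doubling op in place of the softmax math.

#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Illustrative sketch (not the commit's code): eight halves are 16 bytes,
// exactly one ulonglong2, so the group is read and written as a single
// 128-bit transaction. The cast requires 16-byte-aligned pointers, which is
// what the launch site below checks before selecting the fast kernel.
__global__ void DoubleHalf8(const __half* in, __half* out, int n) {
  const int kUnroll = 8;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid * kUnroll + kUnroll - 1 >= n) return;  // tail path omitted here

  ulonglong2 in_d = *reinterpret_cast<const ulonglong2*>(in + tid * kUnroll);
  const __half* in_h = reinterpret_cast<const __half*>(&in_d);
  ulonglong2 out_d;
  __half* out_h = reinterpret_cast<__half*>(&out_d);

  for (int i = 0; i < kUnroll; i++) {
    // Widen to float for the arithmetic, as the softmax kernel does.
    out_h[i] = __float2half(2.0f * __half2float(in_h[i]));
  }
  *reinterpret_cast<ulonglong2*>(out + tid * kUnroll) = out_d;
}

int main() {
  const int n = 4096, threads = 128;  // n divisible by 8, so no tail
  const int blocks = (n + threads * 8 - 1) / (threads * 8);
  __half *in, *out;
  cudaMallocManaged(&in, n * sizeof(__half));   // CUDA allocations are
  cudaMallocManaged(&out, n * sizeof(__half));  // 256-byte aligned
  for (int i = 0; i < n; i++) in[i] = __float2half(1.5f);
  DoubleHalf8<<<blocks, threads>>>(in, out, n);
  cudaDeviceSynchronize();
  printf("out[0] = %f\n", __half2float(out[0]));  // expect 3.0
  cudaFree(in);
  cudaFree(out);
  return 0;
}
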
@@ -165,8 +231,6 @@ class SoftmaxOpGPU : public OpKernel {
         context, const_cast<T*>(max_logits.flat<T>().data()),
         reinterpret_cast<const T*>(logits_in_.flat<T>().data()), rows, cols);
 
-    const int numThreadsPerBlock = 128;
-    const int numBlocks = Eigen::divup(rows * cols, numThreadsPerBlock);
 
     gpuprim::CountingInputIterator<int> counting_iterator(0);
     using InputIterType =
@@ -184,12 +248,36 @@ class SoftmaxOpGPU : public OpKernel {
           context, const_cast<acc_type*>(sum_probs.flat<acc_type>().data()),
           input_itr, rows, cols);
 
-      TF_CHECK_OK(GpuLaunchKernel(
-          GenerateNormalizedProb<T, acc_type>, numBlocks, numThreadsPerBlock, 0,
-          cu_stream, reinterpret_cast<const T*>(logits_in_.flat<T>().data()),
-          reinterpret_cast<const acc_type*>(sum_probs.flat<acc_type>().data()),
-          reinterpret_cast<const T*>(max_logits.flat<T>().data()),
-          const_cast<T*>(softmax_out->flat<T>().data()), rows, cols, log_));
+      auto in_ptr = reinterpret_cast<uintptr_t>(logits_in_.flat<T>().data());
+      auto out_ptr = reinterpret_cast<uintptr_t>(softmax_out->flat<T>().data());
+      bool aligned = in_ptr % 16 == 0 && out_ptr % 16 == 0;
+
+      const int numThreadsPerBlock = 128;
+      if (DataTypeToEnum<T>::value == DT_HALF && aligned) {
+        const int kUnroll = 8;
+        const int numBlocks =
+            Eigen::divup(rows * cols, numThreadsPerBlock * kUnroll);
+        TF_CHECK_OK(GpuLaunchKernel(
+            GenerateNormalizedProb<T, acc_type, kUnroll>, numBlocks,
+            numThreadsPerBlock, 0, cu_stream,
+            reinterpret_cast<const T*>(logits_in_.flat<T>().data()),
+            reinterpret_cast<const acc_type*>(
+                sum_probs.flat<acc_type>().data()),
+            reinterpret_cast<const T*>(max_logits.flat<T>().data()),
+            const_cast<T*>(softmax_out->flat<T>().data()), rows, cols, log_));
+      } else {
+        const int kUnroll = 4;
+        const int numBlocks =
+            Eigen::divup(rows * cols, numThreadsPerBlock * kUnroll);
+        TF_CHECK_OK(GpuLaunchKernel(
+            GenerateNormalizedProb<T, acc_type, kUnroll>, numBlocks,
+            numThreadsPerBlock, 0, cu_stream,
+            reinterpret_cast<const T*>(logits_in_.flat<T>().data()),
+            reinterpret_cast<const acc_type*>(
+                sum_probs.flat<acc_type>().data()),
+            reinterpret_cast<const T*>(max_logits.flat<T>().data()),
+            const_cast<T*>(softmax_out->flat<T>().data()), rows, cols, log_));
+      }
     }
   }
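
At the launch site, the vectorized kernel is only selected when the type is DT_HALF and both buffers pass the 16-byte alignment test, and the grid is sized over numThreadsPerBlock * kUnroll elements per block. A host-side sketch of that dispatch arithmetic, runnable on its own; the shape, the buffer, and the divup helper (standing in for Eigen::divup) are illustrative assumptions:

#include <cstdint>
#include <cstdio>

// Mirrors Eigen::divup: ceiling division for grid sizing.
static int divup(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int rows = 1000, cols = 37;  // hypothetical logits shape
  const int numThreadsPerBlock = 128;
  // Stand-in for the half logits buffer; forced to 16-byte alignment here.
  alignas(16) static unsigned short buf[1000 * 37];

  // Same test as the in_ptr/out_ptr check above: a 16-byte-aligned buffer
  // can take the vectorized kUnroll=8 kernel, otherwise fall back to 4.
  const bool aligned = reinterpret_cast<uintptr_t>(buf) % 16 == 0;
  const int kUnroll = aligned ? 8 : 4;
  const int numBlocks = divup(rows * cols, numThreadsPerBlock * kUnroll);
  printf("aligned=%d kUnroll=%d blocks=%d\n", aligned, kUnroll, numBlocks);
  return 0;
}
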