Merge pull request #38290 from benbarsdell:strided-batch-matmul
PiperOrigin-RevId: 305918738
Change-Id: I84550609b1bb5cc723be2ea3d23a3f8d9fda3fd2
commit 2947e86977
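For context, the primitive this PR switches to, strided batched GEMM, multiplies a whole batch of matrix pairs in a single launch, with each operand advancing by a fixed element stride per batch entry; a stride of zero replays the same matrix for every entry, which is what makes it a fit for broadcasting. A minimal reference sketch of those semantics (hypothetical helper, plain row-major C++, not the StreamExecutor API):

#include <cstdint>

// Reference semantics of a strided batched GEMM (hypothetical helper,
// row-major storage, no transposes): batch i multiplies the m x k matrix
// at a + i * a_stride by the k x n matrix at b + i * b_stride into the
// m x n matrix at c + i * c_stride. A stride of 0 reuses one matrix for
// every batch entry.
void GemmStridedBatchedReference(const float* a, int64_t a_stride,
                                 const float* b, int64_t b_stride,
                                 float* c, int64_t c_stride, int64_t m,
                                 int64_t n, int64_t k, int64_t batch_size) {
  for (int64_t i = 0; i < batch_size; ++i) {
    const float* a_i = a + i * a_stride;
    const float* b_i = b + i * b_stride;
    float* c_i = c + i * c_stride;
    for (int64_t r = 0; r < m; ++r) {
      for (int64_t col = 0; col < n; ++col) {
        float acc = 0.0f;
        for (int64_t l = 0; l < k; ++l) {
          acc += a_i[r * k + l] * b_i[l * n + col];
        }
        c_i[r * n + col] = acc;
      }
    }
  }
}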
@@ -330,8 +330,26 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
     auto* a_base_ptr = in_x.template flat<Scalar>().data();
     auto* b_base_ptr = in_y.template flat<Scalar>().data();
     auto* c_base_ptr = out->template flat<Scalar>().data();
+    uint64 a_stride;
+    uint64 b_stride;
+    uint64 c_stride;
 
-    if (!bcast.IsBroadcastingRequired()) {
+    bool is_full_broadcast =
+        std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
+    bool use_strided_batched =
+        (!bcast.IsBroadcastingRequired() || is_full_broadcast) &&
+        batch_size > 1;
+    if (use_strided_batched) {
+      a_stride = bcast.x_batch_size() != 1 ? m * k : 0;
+      b_stride = bcast.y_batch_size() != 1 ? k * n : 0;
+      c_stride = m * n;
+      a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
+      b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
+      c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
+      a_ptrs.push_back(&a_device_memory.back());
+      b_ptrs.push_back(&b_device_memory.back());
+      c_ptrs.push_back(&c_device_memory.back());
+    } else if (!bcast.IsBroadcastingRequired()) {
       for (int64 i = 0; i < batch_size; ++i) {
         a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
         b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
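The dispatch rule added above is worth restating in isolation. The sketch below (hypothetical free function; IsBroadcastingRequired is approximated by comparing batch sizes, where the real code consults MatMulBCast) shows how a fully broadcast operand collapses to a stride of 0:

#include <algorithm>
#include <cstdint>

// Sketch of the stride selection above: strided-batched GEMM is used when
// no broadcasting is needed, or when one operand is fully broadcast
// (batch size 1); the broadcast side then gets a stride of 0 so the same
// matrix is read for every batch entry.
struct StridePlan {
  bool use_strided_batched;
  int64_t a_stride, b_stride, c_stride;
};

StridePlan PlanStrides(int64_t x_batch_size, int64_t y_batch_size,
                       int64_t batch_size, int64_t m, int64_t n, int64_t k) {
  bool broadcasting_required = x_batch_size != y_batch_size;  // approximation
  bool is_full_broadcast = std::min(x_batch_size, y_batch_size) == 1;
  bool use = (!broadcasting_required || is_full_broadcast) && batch_size > 1;
  return {use,
          x_batch_size != 1 ? m * k : 0,  // A broadcast -> stride 0
          y_batch_size != 1 ? k * n : 0,  // B broadcast -> stride 0
          m * n};                         // each batch writes its own C
}

For example, x_batch_size = 8 with y_batch_size = 1 gives a_stride = m * k, b_stride = 0, and c_stride = m * n, so one strided-batched launch replaces eight separate GEMM calls.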
@@ -407,6 +425,23 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
               ", k=", k));
         }
       }
+    } else if (use_strided_batched) {
+      bool blas_launch_status =
+          stream
+              ->ThenBlasGemmStridedBatched(
+                  blas_transpose_b, blas_transpose_a, n, m, k,
+                  static_cast<Coefficient>(1.0), *b_ptrs[0], adj_y ? k : n,
+                  b_stride, *a_ptrs[0], adj_x ? m : k, a_stride,
+                  static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
+                  batch_size)
+              .ok();
+      if (!blas_launch_status) {
+        context->SetStatus(errors::Internal(
+            "Blas xGEMMStridedBatched launch failed : a.shape=",
+            in_x.shape().DebugString(),
+            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
+            ", k=", k, ", batch_size=", batch_size));
+      }
     } else {
       BlasScratchAllocator scratch_allocator(context);
       bool blas_launch_status =
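The operand swap in the call above (b before a, with n and m exchanged) is the usual row-major to column-major trick: BLAS is column-major, so the row-major product C = A * B is issued as C^T = B^T * A^T. On CUDA this lowers to roughly the following raw cuBLAS call (a sketch assuming float, no adjoints, strides in elements, and error handling elided; cublasSgemmStridedBatched is the underlying cuBLAS entry point):

#include <cublas_v2.h>

// Sketch: row-major C[m,n] = A[m,k] * B[k,n], batched with strides, issued
// through column-major cuBLAS by computing C^T = B^T * A^T. The leading
// dimensions match the diff: n for B (adj_y ? k : n with no adjoint),
// k for A (adj_x ? m : k), and n for C.
void StridedBatchedMatMul(cublasHandle_t handle, const float* a,
                          const float* b, float* c, int m, int n, int k,
                          long long a_stride, long long b_stride,
                          long long c_stride, int batch_size) {
  const float alpha = 1.0f;
  const float beta = 0.0f;
  cublasSgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                            /*m=*/n, /*n=*/m, /*k=*/k, &alpha,
                            /*A=*/b, /*lda=*/n, b_stride,
                            /*B=*/a, /*ldb=*/k, a_stride, &beta,
                            /*C=*/c, /*ldc=*/n, c_stride, batch_size);
}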
@@ -467,7 +502,26 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
     auto* b_base_ptr = in_y.template flat<Scalar>().data();
     auto* c_base_ptr = out->template flat<Scalar>().data();
 
-    if (!bcast.IsBroadcastingRequired()) {
+    uint64 a_stride;
+    uint64 b_stride;
+    uint64 c_stride;
+
+    bool is_full_broadcast =
+        std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
+    bool use_strided_batched =
+        (!bcast.IsBroadcastingRequired() || is_full_broadcast) &&
+        batch_size > 1;
+    if (use_strided_batched) {
+      a_stride = bcast.x_batch_size() != 1 ? m * k : 0;
+      b_stride = bcast.y_batch_size() != 1 ? k * n : 0;
+      c_stride = m * n;
+      a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
+      b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
+      c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
+      a_ptrs.push_back(&a_device_memory.back());
+      b_ptrs.push_back(&b_device_memory.back());
+      c_ptrs.push_back(&c_device_memory.back());
+    } else if (!bcast.IsBroadcastingRequired()) {
       for (int64 i = 0; i < batch_size; ++i) {
         a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
         b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
@@ -518,6 +572,23 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
             ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
             ", k=", k));
       }
+    } else if (use_strided_batched) {
+      bool blas_launch_status =
+          stream
+              ->ThenBlasGemmStridedBatched(
+                  blas_transpose_b, blas_transpose_a, n, m, k,
+                  static_cast<Coefficient>(1.0), *b_ptrs[0], adj_y ? k : n,
+                  b_stride, *a_ptrs[0], adj_x ? m : k, a_stride,
+                  static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
+                  batch_size)
+              .ok();
+      if (!blas_launch_status) {
+        context->SetStatus(errors::Internal(
+            "Blas xGEMMStridedBatched launch failed : a.shape=",
+            in_x.shape().DebugString(),
+            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
+            ", k=", k, ", batch_size=", batch_size));
+      }
     } else {
       BlasScratchAllocator scratch_allocator(context);
       bool blas_launch_status =