commit 2947e86977
Author: TensorFlower Gardener
Date:   2020-04-10 12:08:03 -07:00

    Merge pull request #38290 from benbarsdell:strided-batch-matmul

    PiperOrigin-RevId: 305918738
    Change-Id: I84550609b1bb5cc723be2ea3d23a3f8d9fda3fd2

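Note: the diff below teaches the GPU LaunchBatchMatMul path to issue a single strided batched GEMM whenever the batch is laid out contiguously in memory: either no broadcasting is required at all, or one operand is fully broadcast (batch size 1), in which case that operand simply gets a batch stride of 0. A minimal standalone sketch of the selection rule follows; StridedBatchPlan and ChooseStridedBatched are hypothetical names for illustration, not TensorFlow APIs.

// Standalone sketch; mirrors the logic added by this diff, not TensorFlow code.
#include <algorithm>
#include <cstdint>

struct StridedBatchPlan {
  bool use_strided_batched;               // one strided GEMM vs. pointer arrays
  uint64_t a_stride, b_stride, c_stride;  // element offsets between batch items
};

StridedBatchPlan ChooseStridedBatched(int64_t x_batch_size,
                                      int64_t y_batch_size, int64_t m,
                                      int64_t n, int64_t k) {
  const int64_t batch_size = std::max(x_batch_size, y_batch_size);
  const bool broadcasting_required = x_batch_size != y_batch_size;
  // "Full" broadcast: one side contributes a single matrix to every batch.
  const bool is_full_broadcast = std::min(x_batch_size, y_batch_size) == 1;
  StridedBatchPlan plan;
  plan.use_strided_batched =
      (!broadcasting_required || is_full_broadcast) && batch_size > 1;
  // A broadcast operand gets stride 0 so every GEMM reuses the same matrix.
  plan.a_stride = x_batch_size != 1 ? static_cast<uint64_t>(m * k) : 0;
  plan.b_stride = y_batch_size != 1 ? static_cast<uint64_t>(k * n) : 0;
  plan.c_stride = static_cast<uint64_t>(m * n);
  return plan;
}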

@@ -330,8 +330,26 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
     auto* a_base_ptr = in_x.template flat<Scalar>().data();
     auto* b_base_ptr = in_y.template flat<Scalar>().data();
     auto* c_base_ptr = out->template flat<Scalar>().data();
-    if (!bcast.IsBroadcastingRequired()) {
+    uint64 a_stride;
+    uint64 b_stride;
+    uint64 c_stride;
+    bool is_full_broadcast =
+        std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
+    bool use_strided_batched =
+        (!bcast.IsBroadcastingRequired() || is_full_broadcast) &&
+        batch_size > 1;
+    if (use_strided_batched) {
+      a_stride = bcast.x_batch_size() != 1 ? m * k : 0;
+      b_stride = bcast.y_batch_size() != 1 ? k * n : 0;
+      c_stride = m * n;
+      a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
+      b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
+      c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
+      a_ptrs.push_back(&a_device_memory.back());
+      b_ptrs.push_back(&b_device_memory.back());
+      c_ptrs.push_back(&c_device_memory.back());
+    } else if (!bcast.IsBroadcastingRequired()) {
       for (int64 i = 0; i < batch_size; ++i) {
         a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
         b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
         c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
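The hunk above builds the operand tables for both dispatch modes: in strided mode one base pointer per operand suffices, since the per-batch offset is carried by the stride, while the pointer-array fallback records one device pointer per batch entry. A hypothetical standalone sketch of that setup, with std::vector standing in for the kernel's device-memory containers:

// Illustrative sketch of the two setup modes; names are hypothetical.
#include <cstdint>
#include <vector>

void BuildOperandPointers(const float* a_base, int64_t m, int64_t k,
                          int64_t batch_size, bool use_strided_batched,
                          std::vector<const float*>* a_ptrs) {
  if (use_strided_batched) {
    // One entry; the per-batch offset is expressed via a_stride instead.
    a_ptrs->push_back(a_base);
  } else {
    // One entry per batch: the i-th matrix starts at a_base + i * m * k.
    for (int64_t i = 0; i < batch_size; ++i) {
      a_ptrs->push_back(a_base + i * m * k);
    }
  }
}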
@@ -407,6 +425,23 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
               ", k=", k));
         }
       }
+    } else if (use_strided_batched) {
+      bool blas_launch_status =
+          stream
+              ->ThenBlasGemmStridedBatched(
+                  blas_transpose_b, blas_transpose_a, n, m, k,
+                  static_cast<Coefficient>(1.0), *b_ptrs[0], adj_y ? k : n,
+                  b_stride, *a_ptrs[0], adj_x ? m : k, a_stride,
+                  static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
+                  batch_size)
+              .ok();
+      if (!blas_launch_status) {
+        context->SetStatus(errors::Internal(
+            "Blas xGEMMStridedBatched launch failed : a.shape=",
+            in_x.shape().DebugString(),
+            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
+            ", k=", k, ", batch_size=", batch_size));
+      }
     } else {
       BlasScratchAllocator scratch_allocator(context);
       bool blas_launch_status =
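The ThenBlasGemmStridedBatched call above passes b before a and swaps m and n because BLAS is column-major while TensorFlow tensors are row-major: computing C' = B' x A' in column-major terms produces the row-major C = A x B. A sketch of the same trick written directly against cuBLAS, assuming float operands and no adjoints; this illustrates the mapping, not the StreamExecutor implementation:

#include <cublas_v2.h>

// Row-major batched C[i] = A[i] * B[i] (A: m x k, B: k x n, C: m x n).
// Column-major cuBLAS computes C^T = B^T * A^T, so B is passed first and
// the m/n dimensions are swapped, matching the diff's argument order.
cublasStatus_t RowMajorGemmStridedBatched(
    cublasHandle_t handle, int m, int n, int k, const float* a,
    long long a_stride, const float* b, long long b_stride, float* c,
    long long c_stride, int batch_count) {
  const float alpha = 1.0f;
  const float beta = 0.0f;
  return cublasSgemmStridedBatched(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, /*m=*/n, /*n=*/m, /*k=*/k, &alpha,
      /*A=*/b, /*lda=*/n, b_stride, /*B=*/a, /*ldb=*/k, a_stride, &beta,
      /*C=*/c, /*ldc=*/n, c_stride, batch_count);
}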
@@ -467,7 +502,26 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
     auto* a_base_ptr = in_x.template flat<Scalar>().data();
     auto* b_base_ptr = in_y.template flat<Scalar>().data();
     auto* c_base_ptr = out->template flat<Scalar>().data();
-    if (!bcast.IsBroadcastingRequired()) {
+    uint64 a_stride;
+    uint64 b_stride;
+    uint64 c_stride;
+    bool is_full_broadcast =
+        std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
+    bool use_strided_batched =
+        (!bcast.IsBroadcastingRequired() || is_full_broadcast) &&
+        batch_size > 1;
+    if (use_strided_batched) {
+      a_stride = bcast.x_batch_size() != 1 ? m * k : 0;
+      b_stride = bcast.y_batch_size() != 1 ? k * n : 0;
+      c_stride = m * n;
+      a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
+      b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
+      c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
+      a_ptrs.push_back(&a_device_memory.back());
+      b_ptrs.push_back(&b_device_memory.back());
+      c_ptrs.push_back(&c_device_memory.back());
+    } else if (!bcast.IsBroadcastingRequired()) {
       for (int64 i = 0; i < batch_size; ++i) {
         a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
         b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
@@ -518,6 +572,23 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
             ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
             ", k=", k));
       }
+    } else if (use_strided_batched) {
+      bool blas_launch_status =
+          stream
+              ->ThenBlasGemmStridedBatched(
+                  blas_transpose_b, blas_transpose_a, n, m, k,
+                  static_cast<Coefficient>(1.0), *b_ptrs[0], adj_y ? k : n,
+                  b_stride, *a_ptrs[0], adj_x ? m : k, a_stride,
+                  static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
+                  batch_size)
+              .ok();
+      if (!blas_launch_status) {
+        context->SetStatus(errors::Internal(
+            "Blas xGEMMStridedBatched launch failed : a.shape=",
+            in_x.shape().DebugString(),
+            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
+            ", k=", k, ", batch_size=", batch_size));
+      }
     } else {
       BlasScratchAllocator scratch_allocator(context);
       bool blas_launch_status =
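The Eigen::half specialization gains the identical dispatch as the float path above. A small worked example with hypothetical shapes makes the stride rule concrete: with x_batch_size = 8, y_batch_size = 1, m = 2, k = 3, n = 4, the broadcast is "full", so a single strided call covers all eight products and B is reused via a zero stride:

// Worked example of the stride rule; shapes are hypothetical.
#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  const int64_t x_batch = 8, y_batch = 1;  // B is fully broadcast
  const int64_t m = 2, k = 3, n = 4;
  const int64_t batch_size = std::max(x_batch, y_batch);
  // x_batch != y_batch, so broadcasting is required, but it is "full".
  const bool is_full_broadcast = std::min(x_batch, y_batch) == 1;
  const bool use_strided_batched =
      (x_batch == y_batch || is_full_broadcast) && batch_size > 1;
  assert(use_strided_batched);
  assert((x_batch != 1 ? m * k : 0) == 6);  // a_stride: advance per batch
  assert((y_batch != 1 ? k * n : 0) == 0);  // b_stride: reuse the one B
  assert(m * n == 8);                       // c_stride
  return 0;
}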