diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h
index 5649c068780..d5b6e7d5e17 100644
--- a/tensorflow/core/kernels/batch_matmul_op_impl.h
+++ b/tensorflow/core/kernels/batch_matmul_op_impl.h
@@ -330,8 +330,26 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
     auto* a_base_ptr = in_x.template flat<Scalar>().data();
     auto* b_base_ptr = in_y.template flat<Scalar>().data();
     auto* c_base_ptr = out->template flat<Scalar>().data();
+    uint64 a_stride;
+    uint64 b_stride;
+    uint64 c_stride;
 
-    if (!bcast.IsBroadcastingRequired()) {
+    bool is_full_broadcast =
+        std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
+    bool use_strided_batched =
+        (!bcast.IsBroadcastingRequired() || is_full_broadcast) &&
+        batch_size > 1;
+    if (use_strided_batched) {
+      a_stride = bcast.x_batch_size() != 1 ? m * k : 0;
+      b_stride = bcast.y_batch_size() != 1 ? k * n : 0;
+      c_stride = m * n;
+      a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
+      b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
+      c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
+      a_ptrs.push_back(&a_device_memory.back());
+      b_ptrs.push_back(&b_device_memory.back());
+      c_ptrs.push_back(&c_device_memory.back());
+    } else if (!bcast.IsBroadcastingRequired()) {
       for (int64 i = 0; i < batch_size; ++i) {
         a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
         b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
@@ -407,6 +425,23 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
               ", k=", k));
         }
       }
+    } else if (use_strided_batched) {
+      bool blas_launch_status =
+          stream
+              ->ThenBlasGemmStridedBatched(
+                  blas_transpose_b, blas_transpose_a, n, m, k,
+                  static_cast<Coefficient>(1.0), *b_ptrs[0], adj_y ? k : n,
+                  b_stride, *a_ptrs[0], adj_x ? m : k, a_stride,
+                  static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
+                  batch_size)
+              .ok();
+      if (!blas_launch_status) {
+        context->SetStatus(errors::Internal(
+            "Blas xGEMMStridedBatched launch failed : a.shape=",
+            in_x.shape().DebugString(),
+            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
+            ", k=", k, ", batch_size=", batch_size));
+      }
     } else {
       BlasScratchAllocator scratch_allocator(context);
       bool blas_launch_status =
@@ -467,7 +502,26 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
     auto* a_base_ptr = in_x.template flat<Scalar>().data();
     auto* b_base_ptr = in_y.template flat<Scalar>().data();
     auto* c_base_ptr = out->template flat<Scalar>().data();
-    if (!bcast.IsBroadcastingRequired()) {
+    uint64 a_stride;
+    uint64 b_stride;
+    uint64 c_stride;
+
+    bool is_full_broadcast =
+        std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
+    bool use_strided_batched =
+        (!bcast.IsBroadcastingRequired() || is_full_broadcast) &&
+        batch_size > 1;
+    if (use_strided_batched) {
+      a_stride = bcast.x_batch_size() != 1 ? m * k : 0;
+      b_stride = bcast.y_batch_size() != 1 ? k * n : 0;
+      c_stride = m * n;
+      a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
+      b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
+      c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
+      a_ptrs.push_back(&a_device_memory.back());
+      b_ptrs.push_back(&b_device_memory.back());
+      c_ptrs.push_back(&c_device_memory.back());
+    } else if (!bcast.IsBroadcastingRequired()) {
       for (int64 i = 0; i < batch_size; ++i) {
         a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
         b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
@@ -518,6 +572,23 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k));
      }
+    } else if (use_strided_batched) {
+      bool blas_launch_status =
+          stream
+              ->ThenBlasGemmStridedBatched(
+                  blas_transpose_b, blas_transpose_a, n, m, k,
+                  static_cast<Coefficient>(1.0), *b_ptrs[0], adj_y ? k : n,
+                  b_stride, *a_ptrs[0], adj_x ? m : k, a_stride,
+                  static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
+                  batch_size)
+              .ok();
+      if (!blas_launch_status) {
+        context->SetStatus(errors::Internal(
+            "Blas xGEMMStridedBatched launch failed : a.shape=",
+            in_x.shape().DebugString(),
+            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
+            ", k=", k, ", batch_size=", batch_size));
+      }
     } else {
       BlasScratchAllocator scratch_allocator(context);
       bool blas_launch_status =
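
Note (not part of the patch): the broadcasting support in this path comes entirely from the zero batch strides. When one input has batch size 1, its stride is set to 0, so the BLAS reuses that single matrix for every batch entry while the other operand and the output advance by a full matrix per entry. The sketch below shows the same idea directly against the cuBLAS C API in plain column-major layout; the helper name BroadcastMatMul and the layout choice are illustrative only, since the kernel itself goes through stream executor's ThenBlasGemmStridedBatched with the operands swapped (n, m, k and b before a) to handle row-major data.

// Illustrative sketch only; assumes column-major cuBLAS conventions and
// omits error handling. Computes C[i] = A * B[i] for i in [0, batch), i.e.
// the single A is broadcast across the batch via a zero stride.
#include <cublas_v2.h>

void BroadcastMatMul(cublasHandle_t handle,
                     const float* d_a,  // m x k, one copy
                     const float* d_b,  // k x n, batch copies
                     float* d_c,        // m x n, batch copies
                     int m, int n, int k, int batch) {
  const float alpha = 1.0f;
  const float beta = 0.0f;
  // strideA == 0 reuses the same A for every batch entry (the a_stride = 0
  // case in the patch); B and C advance by one full matrix per entry,
  // mirroring b_stride = k * n and c_stride = m * n above.
  cublasSgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
                            d_a, /*lda=*/m, /*strideA=*/0,
                            d_b, /*ldb=*/k,
                            /*strideB=*/static_cast<long long>(k) * n, &beta,
                            d_c, /*ldc=*/m,
                            /*strideC=*/static_cast<long long>(m) * n, batch);
}

The use_strided_batched predicate guards this path: it fires only when no broadcasting is needed or the broadcast is "full" (one side is a single matrix), because anything in between still needs the per-matrix pointer arrays handed to the existing scratch-allocator batched-GEMM fallback, and batch_size > 1 leaves single-matrix multiplies on the existing plain GEMM path.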