Replacing GetCudaLaunchConfig and CudaLaunchKernel with their Gpu equivalents.

PiperOrigin-RevId: 256377258
A. Unique TensorFlower 2019-07-03 08:58:41 -07:00 committed by TensorFlower Gardener
parent 4bc0fe3cf4
commit 6f0584298b
13 changed files with 79 additions and 79 deletions
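
For reference, the pattern applied throughout this change is sketched below. This is a minimal illustration only, not taken from any file in this diff: the ScaleKernel name, its arguments, and the LaunchScale wrapper are hypothetical, and the helpers are assumed to come from tensorflow/core/util/gpu_kernel_helper.h. Argument order and semantics of the launch helpers are unchanged by the rename.

#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

// Hypothetical elementwise kernel, written as a plain grid-stride loop so it
// covers the virtual_thread_count produced by GetGpuLaunchConfig.
template <typename T>
__global__ void ScaleKernel(int nthreads, const T* in, T alpha, T* out) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < nthreads;
       i += blockDim.x * gridDim.x) {
    out[i] = alpha * in[i];
  }
}

template <typename T>
void LaunchScale(const Eigen::GpuDevice& d, const T* in, T alpha, T* out,
                 int count) {
  // Before this change: GetCudaLaunchConfig(count, d) / CudaLaunchKernel(...).
  // After: the Gpu-prefixed spellings, with identical arguments.
  GpuLaunchConfig config = GetGpuLaunchConfig(count, d);
  TF_CHECK_OK(GpuLaunchKernel(ScaleKernel<T>, config.block_count,
                              config.thread_per_block, /*shared_memory=*/0,
                              d.stream(), config.virtual_thread_count, in,
                              alpha, out));
}

}  // namespace tensorflow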


@@ -81,14 +81,14 @@ void BiasGPU<T>::compute(const GPUDevice& d, const T* input, const T* bias,
   if (total_count == 0) {
     return;
   }
-  GpuLaunchConfig config = GetCudaLaunchConfig(total_count, d);
+  GpuLaunchConfig config = GetGpuLaunchConfig(total_count, d);
   if (data_format == FORMAT_NHWC) {
-    TF_CHECK_OK(CudaLaunchKernel(BiasNHWCKernel<T>, config.block_count,
+    TF_CHECK_OK(GpuLaunchKernel(BiasNHWCKernel<T>, config.block_count,
                                 config.thread_per_block, 0, d.stream(),
                                 config.virtual_thread_count, input, bias,
                                 output, bias_size));
   } else {
-    TF_CHECK_OK(CudaLaunchKernel(BiasNCHWKernel<T>, config.block_count,
+    TF_CHECK_OK(GpuLaunchKernel(BiasNCHWKernel<T>, config.block_count,
                                 config.thread_per_block, 0, d.stream(),
                                 config.virtual_thread_count, input, bias,
                                 output, bias_size, image_size));
@@ -204,7 +204,7 @@ void BiasGradGPU<T>::compute(const GPUDevice& d, const T* output_backprop,
     return;
   }
   static constexpr int32 kWarpSize = 32;
-  GpuLaunchConfig config = GetCudaLaunchConfig(total_count, d);
+  GpuLaunchConfig config = GetGpuLaunchConfig(total_count, d);
   const int max_shared_memory_size = d.sharedMemPerBlock() / 2;
   int32 shared_memory_size = 0;
@@ -214,7 +214,7 @@ void BiasGradGPU<T>::compute(const GPUDevice& d, const T* output_backprop,
   // Check if we have enough shared memory.
   if (shared_memory_size <= max_shared_memory_size) {
     if (data_format == FORMAT_NHWC) {
-      TF_CHECK_OK(CudaLaunchKernel(BiasGradNHWC_SharedAtomics<T>,
+      TF_CHECK_OK(GpuLaunchKernel(BiasGradNHWC_SharedAtomics<T>,
                                   config.block_count, config.thread_per_block,
                                   shared_memory_size, d.stream(), total_count,
                                   output_backprop, bias_backprop, bias_size));
@@ -225,21 +225,21 @@ void BiasGradGPU<T>::compute(const GPUDevice& d, const T* output_backprop,
       if (config.thread_per_block < kWarpSize) {
         config.thread_per_block = kWarpSize;
       }
-      TF_CHECK_OK(CudaLaunchKernel(
-          BiasGradNCHW_SharedAtomics<T>, config.block_count,
-          config.thread_per_block, 0, d.stream(), output_backprop,
-          bias_backprop, batch, bias_size, image_size, group_size));
+      TF_CHECK_OK(GpuLaunchKernel(BiasGradNCHW_SharedAtomics<T>,
+                                  config.block_count, config.thread_per_block,
+                                  0, d.stream(), output_backprop, bias_backprop,
+                                  batch, bias_size, image_size, group_size));
     }
   } else {
     // Note that even if we don't have enough shared memory to fit the entire
     // output block, it is possible to process one group of elements at a time.
     // But for now, we simply fall back to the naive implementation.
     if (data_format == FORMAT_NHWC) {
-      TF_CHECK_OK(CudaLaunchKernel(
+      TF_CHECK_OK(GpuLaunchKernel(
           BiasGradNHWC_Naive<T>, config.block_count, config.thread_per_block, 0,
           d.stream(), total_count, output_backprop, bias_backprop, bias_size));
     } else {
-      TF_CHECK_OK(CudaLaunchKernel(BiasGradNCHW_Naive<T>, config.block_count,
+      TF_CHECK_OK(GpuLaunchKernel(BiasGradNCHW_Naive<T>, config.block_count,
                                   config.thread_per_block, 0, d.stream(),
                                   total_count, output_backprop, bias_backprop,
                                   bias_size, image_size));

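Several of the bias-gradient launches above pass a dynamic shared-memory size as the third launch argument and gate the shared-memory path on d.sharedMemPerBlock(). A hedged sketch of that form follows; the SharedHistogramKernel and LaunchSharedHistogram names are made up for illustration, and the helpers are assumed to come from tensorflow/core/util/gpu_kernel_helper.h.

#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

// Hypothetical kernel that accumulates per-bin partial sums in dynamically
// sized shared memory before flushing them to global memory with atomics.
__global__ void SharedHistogramKernel(int nthreads, const float* in, int bins,
                                      float* out) {
  extern __shared__ float partial[];  // sized by the launch's shared-mem arg
  for (int i = threadIdx.x; i < bins; i += blockDim.x) partial[i] = 0.f;
  __syncthreads();
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < nthreads;
       i += blockDim.x * gridDim.x) {
    atomicAdd(&partial[i % bins], in[i]);
  }
  __syncthreads();
  for (int i = threadIdx.x; i < bins; i += blockDim.x) {
    atomicAdd(&out[i], partial[i]);
  }
}

void LaunchSharedHistogram(const Eigen::GpuDevice& d, const float* in,
                           int count, int bins, float* out) {
  GpuLaunchConfig config = GetGpuLaunchConfig(count, d);
  const size_t shared_memory_size = bins * sizeof(float);
  // Mirror the check in the hunk above: only take the shared-memory path when
  // the request fits comfortably within the per-block limit.
  if (shared_memory_size <= d.sharedMemPerBlock() / 2) {
    TF_CHECK_OK(GpuLaunchKernel(SharedHistogramKernel, config.block_count,
                                config.thread_per_block, shared_memory_size,
                                d.stream(), count, in, bins, out));
  }
}

}  // namespace tensorflow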

@@ -656,10 +656,10 @@ Status LaunchDepthwiseConv2dGPUSmall(OpKernelContext* ctx,
       kBlockDepth * (tile_pixels + filter_pixels) * sizeof(S);
   const int num_outputs = args.out_rows * args.out_cols * block_count;
   auto device = ctx->eigen_gpu_device();
-  GpuLaunchConfig config = GetCudaLaunchConfigFixedBlockSize(
+  GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(
       num_outputs, device, kernel, shared_memory_size,
       block_dim.x * block_dim.y * block_dim.z);
-  TF_CHECK_OK(CudaLaunchKernel(kernel, config.block_count, block_dim,
+  TF_CHECK_OK(GpuLaunchKernel(kernel, config.block_count, block_dim,
                               shared_memory_size, device.stream(), args, input,
                               filter, output));
   return Status::OK();
@@ -751,10 +751,10 @@ Status LaunchDepthwiseConv2dGPU(OpKernelContext* ctx, const DepthwiseArgs& args,
       kKnownDepthMultiplier < 0
           ? std::numeric_limits<int>::max()
          : device.getNumGpuMultiProcessors();
-  TF_CHECK_OK(CudaLaunchKernel(kernel,
+  TF_CHECK_OK(GpuLaunchKernel(kernel,
                               std::min(max_block_count, config.block_count),
-                              config.thread_per_block, 0, device.stream(),
-                              args, input, filter, output, num_outputs));
+                              config.thread_per_block, 0, device.stream(), args,
+                              input, filter, output, num_outputs));
   return Status::OK();
 }
@@ -969,7 +969,7 @@ Status LaunchDepthwiseConv2dBackpropInputGPU(OpKernelContext* ctx,
   auto device = ctx->eigen_gpu_device();
   GpuLaunchConfig config =
       GetGpuLaunchConfig(num_in_backprop, device, kernel, 0, 0);
-  TF_CHECK_OK(CudaLaunchKernel(
+  TF_CHECK_OK(GpuLaunchKernel(
       kernel, config.block_count, config.thread_per_block, 0, device.stream(),
       args, out_backprop, filter, in_backprop, num_in_backprop));
   return Status::OK();
@@ -1611,10 +1611,10 @@ Status TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
                            " is not supported");
   }
   const int num_out_backprop = args.out_rows * args.out_cols * block_count;
-  GpuLaunchConfig config = GetCudaLaunchConfigFixedBlockSize(
+  GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(
       num_out_backprop, device, kernel, shared_memory_size,
       block_dim.x * block_dim.y * block_dim.z);
-  TF_CHECK_OK(CudaLaunchKernel(kernel, config.block_count, block_dim,
+  TF_CHECK_OK(GpuLaunchKernel(kernel, config.block_count, block_dim,
                               shared_memory_size, device.stream(), args,
                               out_backprop, input, filter_backprop));
   return Status::OK();
@@ -1717,7 +1717,7 @@ Status LaunchDepthwiseConv2dBackpropFilterGPU(
   auto device = ctx->eigen_gpu_device();
   GpuLaunchConfig config =
       GetGpuLaunchConfig(num_out_backprop, device, kernel, 0, 0);
-  TF_CHECK_OK(CudaLaunchKernel(
+  TF_CHECK_OK(GpuLaunchKernel(
       kernel, config.block_count, config.thread_per_block, 0, device.stream(),
       args, out_backprop, input, filter_backprop, num_out_backprop));
   return Status::OK();

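The depthwise hunks above also use the occupancy-based overload, GetGpuLaunchConfig(count, device, kernel, dynamic_shared_memory_size, block_size_limit), which sizes the block for a specific kernel. A minimal sketch of that overload follows; CopyKernel and LaunchCopy are hypothetical and not part of this diff, and the header path is an assumption.

#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

// Hypothetical grid-stride kernel used only to illustrate the overload.
__global__ void CopyKernel(int nthreads, const float* in, float* out) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < nthreads;
       i += blockDim.x * gridDim.x) {
    out[i] = in[i];
  }
}

void LaunchCopy(const Eigen::GpuDevice& device, const float* in, float* out,
                int count) {
  // Passing the kernel lets the helper consult the occupancy API, so
  // thread_per_block reflects CopyKernel's register and shared-memory usage.
  // Here: 0 bytes of dynamic shared memory, 0 = no explicit block-size limit.
  GpuLaunchConfig config =
      GetGpuLaunchConfig(count, device, CopyKernel,
                         /*dynamic_shared_memory_size=*/0,
                         /*block_size_limit=*/0);
  TF_CHECK_OK(GpuLaunchKernel(CopyKernel, config.block_count,
                              config.thread_per_block, 0, device.stream(),
                              count, in, out));
}

}  // namespace tensorflow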

@@ -128,9 +128,9 @@ struct DeterminantFromPivotedLUFunctor<GPUDevice, Scalar> {
                   int* info) {
     const int64 num_matrices = output.size();
     const int64 n = lu_factor.dimension(2);
-    GpuLaunchConfig config = GetCudaLaunchConfig(num_matrices, device);
-    TF_CHECK_OK(CudaLaunchKernel(
+    GpuLaunchConfig config = GetGpuLaunchConfig(num_matrices, device);
+    TF_CHECK_OK(GpuLaunchKernel(
         DeterminantFromPivotedLUKernel<Scalar, /*compute_log_abs_det=*/false>,
         config.block_count, config.thread_per_block, 0, device.stream(),
         config.virtual_thread_count, n, lu_factor.data(), pivots, nullptr,
@@ -151,8 +151,8 @@ struct LogDeterminantFromPivotedLUFunctor<GPUDevice, Scalar> {
                   typename TTypes<Scalar, 1>::Tensor log_abs_det) {
     const int64 num_matrices = sign.size();
     const int64 n = lu_factor.dimension(2);
-    GpuLaunchConfig config = GetCudaLaunchConfig(num_matrices, device);
-    TF_CHECK_OK(CudaLaunchKernel(
+    GpuLaunchConfig config = GetGpuLaunchConfig(num_matrices, device);
+    TF_CHECK_OK(GpuLaunchKernel(
         DeterminantFromPivotedLUKernel<Scalar, /*compute_log_abs_det=*/true>,
         config.block_count, config.thread_per_block, 0, device.stream(),
         config.virtual_thread_count, n, lu_factor.data(), pivots, sign.data(),


@@ -118,9 +118,9 @@ struct InTopKFunctor<GPUDevice, T, TargetT> {
     const auto& d = context->eigen_device<GPUDevice>();
     // Compute a mask for all predictions.
-    CudaLaunchConfig config = GetCudaLaunchConfig(num_targets * num_classes, d);
-    OP_REQUIRES_OK(context, CudaLaunchKernel(
-                                ComputePredictionMaskKernel<T, TargetT>,
+    CudaLaunchConfig config = GetGpuLaunchConfig(num_targets * num_classes, d);
+    OP_REQUIRES_OK(
+        context, GpuLaunchKernel(ComputePredictionMaskKernel<T, TargetT>,
                                  config.block_count, config.thread_per_block, 0,
                                  d.stream(), predictions.data(), targets.data(),
                                  predictions_mask.flat<int64>().data(),


@@ -222,8 +222,8 @@ class LuOpGpu : public AsyncOpKernel {
     int* pivots_ptr = pivots.flat<int>().data();
     Tidx* permutation_indices_ptr =
         permutation_indices->template flat<Tidx>().data();
-    GpuLaunchConfig cfgPivots = GetCudaLaunchConfig(batch_size, device);
-    TF_CHECK_OK(CudaLaunchKernel(
+    GpuLaunchConfig cfgPivots = GetGpuLaunchConfig(batch_size, device);
+    TF_CHECK_OK(GpuLaunchKernel(
         ComputePermutationFromTranspositionsKernel<Tidx>, cfgPivots.block_count,
         cfgPivots.thread_per_block, 0, device.stream(), cfgPivots, num_rows,
         pivots_ptr, permutation_indices_ptr));


@@ -241,7 +241,7 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
       DataType::DT_INT32, TensorShape({max_nms_mask_size}), &d_nms_mask));
   // reset data sensitive tensors
   auto device = context->eigen_gpu_device();
-  auto config = GetCudaLaunchConfig(d_nms_mask.NumElements(), device);
+  auto config = GetGpuLaunchConfig(d_nms_mask.NumElements(), device);
   TF_CHECK_OK(GpuLaunchKernel(SetZero<int>, config.block_count,
                               config.thread_per_block, 0, device.stream(),
                               config.virtual_thread_count,
@@ -425,7 +425,7 @@ class NonMaxSuppressionV2GPUOp : public OpKernel {
                                 &d_sorted_boxes));
     // this will return sorted scores and their indices
-    auto config = GetCudaLaunchConfig(num_boxes, device);
+    auto config = GetGpuLaunchConfig(num_boxes, device);
     // initialize box and score indices
     TF_CHECK_OK(GpuLaunchKernel(Iota<int>, config.block_count,
                                 config.thread_per_block, 0, device.stream(),
@@ -472,7 +472,7 @@ class NonMaxSuppressionV2GPUOp : public OpKernel {
                    context->allocate_output(0, TensorShape({num_outputs}),
                                             &output_indices));
     if (num_outputs == 0) return;
-    config = GetCudaLaunchConfig(num_outputs, device);
+    config = GetGpuLaunchConfig(num_outputs, device);
     TF_CHECK_OK(GpuLaunchKernel(
         IndexMultiSelect<int, int>, config.block_count, config.thread_per_block,
         0, device.stream(), config.virtual_thread_count,


@@ -104,9 +104,9 @@ struct ReluGrad<Device, Eigen::half> {
     if (count == 0) return;
     int32 half2_count = Eigen::divup(count, 2);
     constexpr int32 kThreadInBlock = 512;
-    GpuLaunchConfig config = GetCudaLaunchConfigFixedBlockSize(
+    GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(
         half2_count, d, ReluGradHalfKernel, 0, kThreadInBlock);
-    TF_CHECK_OK(CudaLaunchKernel(
+    TF_CHECK_OK(GpuLaunchKernel(
         ReluGradHalfKernel, config.block_count, config.thread_per_block, 0,
         d.stream(), gradient.data(), feature.data(), backprop.data(), count));
   }
@@ -133,9 +133,9 @@ struct Relu<Device, qint8> {
     int32 vect_count = Eigen::divup(count, 4);
     constexpr int32 kThreadInBlock = 512;
-    GpuLaunchConfig config = GetCudaLaunchConfigFixedBlockSize(
+    GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(
         vect_count, d, Relu_int8x4_kernel, 0, kThreadInBlock);
-    TF_CHECK_OK(CudaLaunchKernel(
+    TF_CHECK_OK(GpuLaunchKernel(
         Relu_int8x4_kernel, config.block_count, config.thread_per_block, 0,
         d.stream(), vect_count, reinterpret_cast<const int32*>(input.data()),
         reinterpret_cast<int32*>(output.data())));

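The Relu hunks above show the third variant, GetGpuLaunchConfigFixedBlockSize, which keeps a caller-chosen block size and only derives the grid size. A hedged sketch of that usage; NegateKernel and LaunchNegate are hypothetical names, and the header path is an assumption.

#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

// Hypothetical kernel; the fixed 512-thread block mirrors kThreadInBlock in
// the Relu hunks above.
__global__ void NegateKernel(int nthreads, const float* in, float* out) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < nthreads;
       i += blockDim.x * gridDim.x) {
    out[i] = -in[i];
  }
}

void LaunchNegate(const Eigen::GpuDevice& d, const float* in, float* out,
                  int count) {
  constexpr int kThreadInBlock = 512;
  // The block size is fixed by the caller; only block_count is derived from
  // count and the kernel's occupancy.
  GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(
      count, d, NegateKernel, /*dynamic_shared_memory_size=*/0, kThreadInBlock);
  TF_CHECK_OK(GpuLaunchKernel(NegateKernel, config.block_count,
                              config.thread_per_block, 0, d.stream(), count,
                              in, out));
}

}  // namespace tensorflow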

@@ -71,7 +71,7 @@ struct Roll<GPUDevice, T> {
     d.memcpyHostToDevice(thres_buf, threshold.data(), thres_bytes);
     d.memcpyHostToDevice(range_buf, dim_range.data(), range_bytes);
-    CudaLaunchConfig cfg = GetCudaLaunchConfig(num_elements, d);
+    CudaLaunchConfig cfg = GetGpuLaunchConfig(num_elements, d);
     TF_CHECK_OK(GpuLaunchKernel(RollKernel<T>, cfg.block_count,
                                 cfg.thread_per_block, 0, d.stream(),


@@ -70,12 +70,12 @@ void UpdateVariableAndFill_Philox<GPUDevice, Distribution>::operator()(
   // maximize occupancy
   const int kGroupSize = Distribution::kResultElementCount;
   int work_element_count = (output_size + kGroupSize - 1) / kGroupSize;
-  GpuLaunchConfig cfg = GetCudaLaunchConfig(work_element_count, d,
-                                            FillKernel<Distribution>, 0, 0);
+  GpuLaunchConfig cfg =
+      GetGpuLaunchConfig(work_element_count, d, FillKernel<Distribution>, 0, 0);
   int zero = 0;
   cudaMemcpyToSymbol(thread_counter, &zero, sizeof(int));
-  TF_CHECK_OK(CudaLaunchKernel(
+  TF_CHECK_OK(GpuLaunchKernel(
       FillKernel<Distribution>, cfg.block_count, cfg.thread_per_block, 0,
       d.stream(), dist, state_size, output_size, state_data, output_data));
 }


@@ -196,14 +196,14 @@ class SvdOpGpu : public AsyncOpKernel {
       const GPUDevice& d = context->eigen_device<GPUDevice>();
       d.memset(outputV_ptr, 0, batch_size * sizeof(Scalar));
       Gpu2DLaunchConfig cfg2D = GetCuda2DLaunchConfig(batch_size, m, d);
-      TF_CHECK_OK(CudaLaunchKernel(ComputeValueOfVKernel<Scalar>,
+      TF_CHECK_OK(GpuLaunchKernel(ComputeValueOfVKernel<Scalar>,
                                   cfg2D.block_count, cfg2D.thread_per_block, 0,
                                   d.stream(), cfg2D, m, full_matrices_ ? m : p,
-                                  input_copy.flat<Scalar>().data(),
-                                  outputU_ptr, outputS_ptr, outputV_ptr));
+                                  input_copy.flat<Scalar>().data(), outputU_ptr,
+                                  outputS_ptr, outputV_ptr));
       // 2. clamp V to -1 or +1
-      GpuLaunchConfig cfg1D = GetCudaLaunchConfig(batch_size, d);
-      TF_CHECK_OK(CudaLaunchKernel(ExtractSignOfVKernel<Scalar>,
+      GpuLaunchConfig cfg1D = GetGpuLaunchConfig(batch_size, d);
+      TF_CHECK_OK(GpuLaunchKernel(ExtractSignOfVKernel<Scalar>,
                                   cfg1D.block_count, cfg1D.thread_per_block, 0,
                                   d.stream(), cfg1D, outputV_ptr));
     }

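The SVD hunk above also shows the 2D case: Gpu2DLaunchConfig (still obtained via GetCuda2DLaunchConfig in this commit) carries dim3 block_count and thread_per_block, which GpuLaunchKernel forwards to the launch unchanged. A sketch with a hypothetical TransposeKernel, hand-writing the 2D grid-stride loops rather than relying on any particular loop macro; the header path and wrapper are assumptions.

#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

// Hypothetical 2D kernel: each (x, y) pair handles one element of an
// nrows x ncols transpose.
__global__ void TransposeKernel(int nrows, int ncols, const float* in,
                                float* out) {
  for (int r = blockIdx.x * blockDim.x + threadIdx.x; r < nrows;
       r += blockDim.x * gridDim.x) {
    for (int c = blockIdx.y * blockDim.y + threadIdx.y; c < ncols;
         c += blockDim.y * gridDim.y) {
      out[c * nrows + r] = in[r * ncols + c];
    }
  }
}

void LaunchTranspose(const Eigen::GpuDevice& d, const float* in, float* out,
                     int nrows, int ncols) {
  Gpu2DLaunchConfig cfg2D = GetCuda2DLaunchConfig(nrows, ncols, d);
  // block_count and thread_per_block are dim3 here.
  TF_CHECK_OK(GpuLaunchKernel(TransposeKernel, cfg2D.block_count,
                              cfg2D.thread_per_block, 0, d.stream(), nrows,
                              ncols, in, out));
}

}  // namespace tensorflow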

@@ -402,7 +402,7 @@ cudaError LaunchTopKKernel(const cudaStream_t& stream, int num_shards,
   // We are limited by the amount of shared memory we have per block.
   auto shared_memory_size = (num_shards + 1) * k * sizeof(Entry<T>);
-  TF_CHECK_OK(CudaLaunchKernel(TopKKernel<T>, batch_size, num_shards,
+  TF_CHECK_OK(GpuLaunchKernel(TopKKernel<T>, batch_size, num_shards,
                               shared_memory_size, stream, input, length, k,
                               sorted, output, indices));
   return cudaGetLastError();


@@ -77,7 +77,7 @@ class TridiagonalMatMulOpGpu : public OpKernel {
     OP_REQUIRES_OK(context, context->allocate_output(0, rhs.shape(), &output));
     const Eigen::GpuDevice& device = context->eigen_device<Eigen::GpuDevice>();
-    CudaLaunchConfig cfg = GetCudaLaunchConfig(1, device);
+    CudaLaunchConfig cfg = GetGpuLaunchConfig(1, device);
     TF_CHECK_OK(GpuLaunchKernel(
         TridiagonalMatMulKernel<Scalar>, cfg.block_count, cfg.thread_per_block,
         0, device.stream(), batch_size, m, n, superdiag.flat<Scalar>().data(),


@@ -228,10 +228,10 @@ class TridiagonalSolveOpGpuLinalg : public LinearAlgebraOp<Scalar> {
   void SolveForSizeOneOrTwo(OpKernelContext* context, const Scalar* diagonals,
                             const Scalar* rhs, Scalar* output, int m, int k) {
     const Eigen::GpuDevice& device = context->eigen_device<Eigen::GpuDevice>();
-    GpuLaunchConfig cfg = GetCudaLaunchConfig(1, device);
+    GpuLaunchConfig cfg = GetGpuLaunchConfig(1, device);
     bool* not_invertible_dev;
     cudaMalloc(&not_invertible_dev, sizeof(bool));
-    TF_CHECK_OK(CudaLaunchKernel(SolveForSizeOneOrTwoKernel<Scalar>,
+    TF_CHECK_OK(GpuLaunchKernel(SolveForSizeOneOrTwoKernel<Scalar>,
                                 cfg.block_count, cfg.thread_per_block, 0,
                                 device.stream(), m, diagonals, rhs, k, output,
                                 not_invertible_dev));