Added batch support for OpenCL converters in some cases.

PiperOrigin-RevId: 315310765
Change-Id: Icec9b3b989e2c3a796882d7cb53a9c5a27bebedf
This commit is contained in:
Raman Sarokin 2020-06-08 11:02:58 -07:00 committed by TensorFlower Gardener
parent ee4f27a524
commit 4f4f5db82f
2 changed files with 14 additions and 10 deletions

View File

@ -513,7 +513,7 @@ class InferenceRunnerImpl : public InferenceRunner {
TensorObjectDef TensorToDef(const Tensor& tensor) { TensorObjectDef TensorToDef(const Tensor& tensor) {
TensorObjectDef def; TensorObjectDef def;
def.dimensions.b = 1; def.dimensions.b = tensor.Batch();
def.dimensions.h = tensor.Height(); def.dimensions.h = tensor.Height();
def.dimensions.w = tensor.Width(); def.dimensions.w = tensor.Width();
def.dimensions.c = tensor.Channels(); def.dimensions.c = tensor.Channels();

View File

@ -44,7 +44,7 @@ class OpenClConverterImpl : public TensorObjectConverter {
kernel_.ResetBindingCounter(); kernel_.ResetBindingCounter();
RETURN_IF_ERROR(kernel_.SetMemoryAuto(input)); RETURN_IF_ERROR(kernel_.SetMemoryAuto(input));
RETURN_IF_ERROR(kernel_.SetMemoryAuto(output)); RETURN_IF_ERROR(kernel_.SetMemoryAuto(output));
int3 grid = int3(dims_.w, dims_.h, dims_.d()); int3 grid = int3(dims_.w * dims_.b, dims_.h, dims_.d());
int4 size = int4(dims_.w, dims_.h, dims_.d(), dims_.b); int4 size = int4(dims_.w, dims_.h, dims_.d(), dims_.b);
RETURN_IF_ERROR(kernel_.SetBytesAuto(size)); RETURN_IF_ERROR(kernel_.SetBytesAuto(size));
RETURN_IF_ERROR(kernel_.SetBytesAuto(dims_.c)); RETURN_IF_ERROR(kernel_.SetBytesAuto(dims_.c));
@ -105,7 +105,7 @@ class FromTensorConverter : public OpenClConverterImpl {
"__global " + ToCLDataType(output_def.object_def.data_type) + "* dst", "__global " + ToCLDataType(output_def.object_def.data_type) + "* dst",
R"( R"(
int c = d * 4; int c = d * 4;
int index = (y * size.x + x) * channels + c; int index = ((b * size.y + y) * size.x + x) * channels + c;
dst[index] = input.x; dst[index] = input.x;
if (c + 1 < channels) { if (c + 1 < channels) {
@ -143,12 +143,14 @@ const sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_
__kernel void from_tensor()" + __kernel void from_tensor()" +
src_tensor.GetDeclaration(AccessType::READ) + ", " + src_tensor.GetDeclaration(AccessType::READ) + ", " +
params_kernel.first + R"(, int4 size, int channels) { params_kernel.first + R"(, int4 size, int channels) {
int x = get_global_id(0); int linear_id = get_global_id(0);
int x = (linear_id / size.w);
int b = linear_id % size.w;
int y = get_global_id(1); int y = get_global_id(1);
int d = get_global_id(2); int d = get_global_id(2);
if (x >= size.x || y >= size.y || d >= size.z) return; if (x >= size.x || y >= size.y || d >= size.z) return;
)" + ToCLDataType(input_def.object_def.data_type, 4) + )" + ToCLDataType(input_def.object_def.data_type, 4) +
" input = " + src_tensor.ReadWHS("x", "y", "d") + ";\n" + " input = " + src_tensor.ReadWHSB("x", "y", "d", "b") + ";\n" +
params_kernel.second + "\n}"; params_kernel.second + "\n}";
queue_ = environment->queue(); queue_ = environment->queue();
dims_ = input_def.dimensions; dims_ = input_def.dimensions;
@ -218,7 +220,7 @@ class ToTensorConverter : public OpenClConverterImpl {
return std::make_pair( return std::make_pair(
"__global " + ToCLDataType(input_def.object_def.data_type) + "* src", "__global " + ToCLDataType(input_def.object_def.data_type) + "* src",
R"(int c = d * 4; R"(int c = d * 4;
int index = (y * size.x + x) * channels + c; int index = ((b * size.y + y) * size.x + x) * channels + c;
result.x = src[index]; result.x = src[index];
result.y = c + 1 < channels ? src[index + 1] : 1; result.y = c + 1 < channels ? src[index + 1] : 1;
result.z = c + 2 < channels ? src[index + 2] : 2; result.z = c + 2 < channels ? src[index + 2] : 2;
@ -247,14 +249,16 @@ __kernel void to_tensor()" +
params_kernel.first + ", " + params_kernel.first + ", " +
dst_tensor.GetDeclaration(AccessType::WRITE) + dst_tensor.GetDeclaration(AccessType::WRITE) +
R"(, int4 size, int channels) { R"(, int4 size, int channels) {
int x = get_global_id(0); int linear_id = get_global_id(0);
int x = (linear_id / size.w);
int b = linear_id % size.w;
int y = get_global_id(1); int y = get_global_id(1);
int d = get_global_id(2); int d = get_global_id(2);
if (x >= size.x || y >= size.y || d >= size.z) return; if (x >= size.x || y >= size.y || d >= size.z) return;
)" + ToCLDataType(output_def.object_def.data_type, 4) + )" + ToCLDataType(output_def.object_def.data_type, 4) +
" result;\n" + params_kernel.second + "\n " + " result;\n" + params_kernel.second + "\n " +
dst_tensor.WriteWHS("result", "x", "y", "d") + ";\n}"; dst_tensor.WriteWHSB("result", "x", "y", "d", "b") + ";\n}";
queue_ = environment->queue(); queue_ = environment->queue();
dims_ = output_def.dimensions; dims_ = output_def.dimensions;
return environment->program_cache()->GetOrCreateCLKernel( return environment->program_cache()->GetOrCreateCLKernel(
@ -350,8 +354,8 @@ class TrivialCopier : public OpenClConverterImpl {
} }
return GetOpenCLError(clEnqueueCopyBuffer( return GetOpenCLError(clEnqueueCopyBuffer(
queue_->queue(), input.memobj, output.memobj, 0, 0, queue_->queue(), input.memobj, output.memobj, 0, 0,
SizeOf(data_type_) * dims_.w * dims_.h * dims_.d() * 4, 0, nullptr, SizeOf(data_type_) * dims_.w * dims_.h * dims_.d() * dims_.b * 4, 0,
nullptr)); nullptr, nullptr));
} }
absl::Status Copy(const OpenClTexture& input, const OpenClTexture& output) { absl::Status Copy(const OpenClTexture& input, const OpenClTexture& output) {