Added batch support for OpenCL converters in some cases.
PiperOrigin-RevId: 315310765 Change-Id: Icec9b3b989e2c3a796882d7cb53a9c5a27bebedf
This commit is contained in:
parent
ee4f27a524
commit
4f4f5db82f
|
@ -513,7 +513,7 @@ class InferenceRunnerImpl : public InferenceRunner {
|
||||||
|
|
||||||
TensorObjectDef TensorToDef(const Tensor& tensor) {
|
TensorObjectDef TensorToDef(const Tensor& tensor) {
|
||||||
TensorObjectDef def;
|
TensorObjectDef def;
|
||||||
def.dimensions.b = 1;
|
def.dimensions.b = tensor.Batch();
|
||||||
def.dimensions.h = tensor.Height();
|
def.dimensions.h = tensor.Height();
|
||||||
def.dimensions.w = tensor.Width();
|
def.dimensions.w = tensor.Width();
|
||||||
def.dimensions.c = tensor.Channels();
|
def.dimensions.c = tensor.Channels();
|
||||||
|
|
|
@ -44,7 +44,7 @@ class OpenClConverterImpl : public TensorObjectConverter {
|
||||||
kernel_.ResetBindingCounter();
|
kernel_.ResetBindingCounter();
|
||||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(input));
|
RETURN_IF_ERROR(kernel_.SetMemoryAuto(input));
|
||||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(output));
|
RETURN_IF_ERROR(kernel_.SetMemoryAuto(output));
|
||||||
int3 grid = int3(dims_.w, dims_.h, dims_.d());
|
int3 grid = int3(dims_.w * dims_.b, dims_.h, dims_.d());
|
||||||
int4 size = int4(dims_.w, dims_.h, dims_.d(), dims_.b);
|
int4 size = int4(dims_.w, dims_.h, dims_.d(), dims_.b);
|
||||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(size));
|
RETURN_IF_ERROR(kernel_.SetBytesAuto(size));
|
||||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dims_.c));
|
RETURN_IF_ERROR(kernel_.SetBytesAuto(dims_.c));
|
||||||
|
@ -105,7 +105,7 @@ class FromTensorConverter : public OpenClConverterImpl {
|
||||||
"__global " + ToCLDataType(output_def.object_def.data_type) + "* dst",
|
"__global " + ToCLDataType(output_def.object_def.data_type) + "* dst",
|
||||||
R"(
|
R"(
|
||||||
int c = d * 4;
|
int c = d * 4;
|
||||||
int index = (y * size.x + x) * channels + c;
|
int index = ((b * size.y + y) * size.x + x) * channels + c;
|
||||||
|
|
||||||
dst[index] = input.x;
|
dst[index] = input.x;
|
||||||
if (c + 1 < channels) {
|
if (c + 1 < channels) {
|
||||||
|
@ -143,12 +143,14 @@ const sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_
|
||||||
__kernel void from_tensor()" +
|
__kernel void from_tensor()" +
|
||||||
src_tensor.GetDeclaration(AccessType::READ) + ", " +
|
src_tensor.GetDeclaration(AccessType::READ) + ", " +
|
||||||
params_kernel.first + R"(, int4 size, int channels) {
|
params_kernel.first + R"(, int4 size, int channels) {
|
||||||
int x = get_global_id(0);
|
int linear_id = get_global_id(0);
|
||||||
|
int x = (linear_id / size.w);
|
||||||
|
int b = linear_id % size.w;
|
||||||
int y = get_global_id(1);
|
int y = get_global_id(1);
|
||||||
int d = get_global_id(2);
|
int d = get_global_id(2);
|
||||||
if (x >= size.x || y >= size.y || d >= size.z) return;
|
if (x >= size.x || y >= size.y || d >= size.z) return;
|
||||||
)" + ToCLDataType(input_def.object_def.data_type, 4) +
|
)" + ToCLDataType(input_def.object_def.data_type, 4) +
|
||||||
" input = " + src_tensor.ReadWHS("x", "y", "d") + ";\n" +
|
" input = " + src_tensor.ReadWHSB("x", "y", "d", "b") + ";\n" +
|
||||||
params_kernel.second + "\n}";
|
params_kernel.second + "\n}";
|
||||||
queue_ = environment->queue();
|
queue_ = environment->queue();
|
||||||
dims_ = input_def.dimensions;
|
dims_ = input_def.dimensions;
|
||||||
|
@ -218,7 +220,7 @@ class ToTensorConverter : public OpenClConverterImpl {
|
||||||
return std::make_pair(
|
return std::make_pair(
|
||||||
"__global " + ToCLDataType(input_def.object_def.data_type) + "* src",
|
"__global " + ToCLDataType(input_def.object_def.data_type) + "* src",
|
||||||
R"(int c = d * 4;
|
R"(int c = d * 4;
|
||||||
int index = (y * size.x + x) * channels + c;
|
int index = ((b * size.y + y) * size.x + x) * channels + c;
|
||||||
result.x = src[index];
|
result.x = src[index];
|
||||||
result.y = c + 1 < channels ? src[index + 1] : 1;
|
result.y = c + 1 < channels ? src[index + 1] : 1;
|
||||||
result.z = c + 2 < channels ? src[index + 2] : 2;
|
result.z = c + 2 < channels ? src[index + 2] : 2;
|
||||||
|
@ -247,14 +249,16 @@ __kernel void to_tensor()" +
|
||||||
params_kernel.first + ", " +
|
params_kernel.first + ", " +
|
||||||
dst_tensor.GetDeclaration(AccessType::WRITE) +
|
dst_tensor.GetDeclaration(AccessType::WRITE) +
|
||||||
R"(, int4 size, int channels) {
|
R"(, int4 size, int channels) {
|
||||||
int x = get_global_id(0);
|
int linear_id = get_global_id(0);
|
||||||
|
int x = (linear_id / size.w);
|
||||||
|
int b = linear_id % size.w;
|
||||||
int y = get_global_id(1);
|
int y = get_global_id(1);
|
||||||
int d = get_global_id(2);
|
int d = get_global_id(2);
|
||||||
|
|
||||||
if (x >= size.x || y >= size.y || d >= size.z) return;
|
if (x >= size.x || y >= size.y || d >= size.z) return;
|
||||||
)" + ToCLDataType(output_def.object_def.data_type, 4) +
|
)" + ToCLDataType(output_def.object_def.data_type, 4) +
|
||||||
" result;\n" + params_kernel.second + "\n " +
|
" result;\n" + params_kernel.second + "\n " +
|
||||||
dst_tensor.WriteWHS("result", "x", "y", "d") + ";\n}";
|
dst_tensor.WriteWHSB("result", "x", "y", "d", "b") + ";\n}";
|
||||||
queue_ = environment->queue();
|
queue_ = environment->queue();
|
||||||
dims_ = output_def.dimensions;
|
dims_ = output_def.dimensions;
|
||||||
return environment->program_cache()->GetOrCreateCLKernel(
|
return environment->program_cache()->GetOrCreateCLKernel(
|
||||||
|
@ -350,8 +354,8 @@ class TrivialCopier : public OpenClConverterImpl {
|
||||||
}
|
}
|
||||||
return GetOpenCLError(clEnqueueCopyBuffer(
|
return GetOpenCLError(clEnqueueCopyBuffer(
|
||||||
queue_->queue(), input.memobj, output.memobj, 0, 0,
|
queue_->queue(), input.memobj, output.memobj, 0, 0,
|
||||||
SizeOf(data_type_) * dims_.w * dims_.h * dims_.d() * 4, 0, nullptr,
|
SizeOf(data_type_) * dims_.w * dims_.h * dims_.d() * dims_.b * 4, 0,
|
||||||
nullptr));
|
nullptr, nullptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status Copy(const OpenClTexture& input, const OpenClTexture& output) {
|
absl::Status Copy(const OpenClTexture& input, const OpenClTexture& output) {
|
||||||
|
|
Loading…
Reference in New Issue