Added batch support for OpenCL converters in some cases.
PiperOrigin-RevId: 315310765
Change-Id: Icec9b3b989e2c3a796882d7cb53a9c5a27bebedf
This commit is contained in:
parent
ee4f27a524
commit
4f4f5db82f
|
@ -513,7 +513,7 @@ class InferenceRunnerImpl : public InferenceRunner {
|
|||
|
||||
TensorObjectDef TensorToDef(const Tensor& tensor) {
|
||||
TensorObjectDef def;
|
||||
def.dimensions.b = 1;
|
||||
def.dimensions.b = tensor.Batch();
|
||||
def.dimensions.h = tensor.Height();
|
||||
def.dimensions.w = tensor.Width();
|
||||
def.dimensions.c = tensor.Channels();
|
||||
|
|
|
@ -44,7 +44,7 @@ class OpenClConverterImpl : public TensorObjectConverter {
|
|||
kernel_.ResetBindingCounter();
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(input));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(output));
|
||||
int3 grid = int3(dims_.w, dims_.h, dims_.d());
|
||||
int3 grid = int3(dims_.w * dims_.b, dims_.h, dims_.d());
|
||||
int4 size = int4(dims_.w, dims_.h, dims_.d(), dims_.b);
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(size));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dims_.c));
|
||||
|
@ -105,7 +105,7 @@ class FromTensorConverter : public OpenClConverterImpl {
|
|||
"__global " + ToCLDataType(output_def.object_def.data_type) + "* dst",
|
||||
R"(
|
||||
int c = d * 4;
|
||||
int index = (y * size.x + x) * channels + c;
|
||||
int index = ((b * size.y + y) * size.x + x) * channels + c;
|
||||
|
||||
dst[index] = input.x;
|
||||
if (c + 1 < channels) {
|
||||
|
@ -143,12 +143,14 @@ const sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_
|
|||
__kernel void from_tensor()" +
|
||||
src_tensor.GetDeclaration(AccessType::READ) + ", " +
|
||||
params_kernel.first + R"(, int4 size, int channels) {
|
||||
int x = get_global_id(0);
|
||||
int linear_id = get_global_id(0);
|
||||
int x = (linear_id / size.w);
|
||||
int b = linear_id % size.w;
|
||||
int y = get_global_id(1);
|
||||
int d = get_global_id(2);
|
||||
if (x >= size.x || y >= size.y || d >= size.z) return;
|
||||
)" + ToCLDataType(input_def.object_def.data_type, 4) +
|
||||
" input = " + src_tensor.ReadWHS("x", "y", "d") + ";\n" +
|
||||
" input = " + src_tensor.ReadWHSB("x", "y", "d", "b") + ";\n" +
|
||||
params_kernel.second + "\n}";
|
||||
queue_ = environment->queue();
|
||||
dims_ = input_def.dimensions;
|
||||
|
@ -218,7 +220,7 @@ class ToTensorConverter : public OpenClConverterImpl {
|
|||
return std::make_pair(
|
||||
"__global " + ToCLDataType(input_def.object_def.data_type) + "* src",
|
||||
R"(int c = d * 4;
|
||||
int index = (y * size.x + x) * channels + c;
|
||||
int index = ((b * size.y + y) * size.x + x) * channels + c;
|
||||
result.x = src[index];
|
||||
result.y = c + 1 < channels ? src[index + 1] : 1;
|
||||
result.z = c + 2 < channels ? src[index + 2] : 2;
|
||||
|
@ -247,14 +249,16 @@ __kernel void to_tensor()" +
|
|||
params_kernel.first + ", " +
|
||||
dst_tensor.GetDeclaration(AccessType::WRITE) +
|
||||
R"(, int4 size, int channels) {
|
||||
int x = get_global_id(0);
|
||||
int linear_id = get_global_id(0);
|
||||
int x = (linear_id / size.w);
|
||||
int b = linear_id % size.w;
|
||||
int y = get_global_id(1);
|
||||
int d = get_global_id(2);
|
||||
|
||||
if (x >= size.x || y >= size.y || d >= size.z) return;
|
||||
)" + ToCLDataType(output_def.object_def.data_type, 4) +
|
||||
" result;\n" + params_kernel.second + "\n " +
|
||||
dst_tensor.WriteWHS("result", "x", "y", "d") + ";\n}";
|
||||
dst_tensor.WriteWHSB("result", "x", "y", "d", "b") + ";\n}";
|
||||
queue_ = environment->queue();
|
||||
dims_ = output_def.dimensions;
|
||||
return environment->program_cache()->GetOrCreateCLKernel(
|
||||
|
@ -350,8 +354,8 @@ class TrivialCopier : public OpenClConverterImpl {
|
|||
}
|
||||
return GetOpenCLError(clEnqueueCopyBuffer(
|
||||
queue_->queue(), input.memobj, output.memobj, 0, 0,
|
||||
SizeOf(data_type_) * dims_.w * dims_.h * dims_.d() * 4, 0, nullptr,
|
||||
nullptr));
|
||||
SizeOf(data_type_) * dims_.w * dims_.h * dims_.d() * dims_.b * 4, 0,
|
||||
nullptr, nullptr));
|
||||
}
|
||||
|
||||
absl::Status Copy(const OpenClTexture& input, const OpenClTexture& output) {
|
||||
|
|
Loading…
Reference in New Issue