diff --git a/tensorflow/lite/delegates/gpu/cl/api.cc b/tensorflow/lite/delegates/gpu/cl/api.cc index e82f67392e8..ffe0fb68881 100644 --- a/tensorflow/lite/delegates/gpu/cl/api.cc +++ b/tensorflow/lite/delegates/gpu/cl/api.cc @@ -513,7 +513,7 @@ class InferenceRunnerImpl : public InferenceRunner { TensorObjectDef TensorToDef(const Tensor& tensor) { TensorObjectDef def; - def.dimensions.b = 1; + def.dimensions.b = tensor.Batch(); def.dimensions.h = tensor.Height(); def.dimensions.w = tensor.Width(); def.dimensions.c = tensor.Channels(); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc b/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc index e3170f068e9..4d1b274a0aa 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc @@ -44,7 +44,7 @@ class OpenClConverterImpl : public TensorObjectConverter { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(input)); RETURN_IF_ERROR(kernel_.SetMemoryAuto(output)); - int3 grid = int3(dims_.w, dims_.h, dims_.d()); + int3 grid = int3(dims_.w * dims_.b, dims_.h, dims_.d()); int4 size = int4(dims_.w, dims_.h, dims_.d(), dims_.b); RETURN_IF_ERROR(kernel_.SetBytesAuto(size)); RETURN_IF_ERROR(kernel_.SetBytesAuto(dims_.c)); @@ -105,7 +105,7 @@ class FromTensorConverter : public OpenClConverterImpl { "__global " + ToCLDataType(output_def.object_def.data_type) + "* dst", R"( int c = d * 4; - int index = (y * size.x + x) * channels + c; + int index = ((b * size.y + y) * size.x + x) * channels + c; dst[index] = input.x; if (c + 1 < channels) { @@ -143,12 +143,14 @@ const sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_ __kernel void from_tensor()" + src_tensor.GetDeclaration(AccessType::READ) + ", " + params_kernel.first + R"(, int4 size, int channels) { - int x = get_global_id(0); + int linear_id = get_global_id(0); + int x = (linear_id / size.w); + int b = linear_id % size.w; int y = get_global_id(1); int d = get_global_id(2); if (x >= size.x || y >= size.y || d >= size.z) return; )" + ToCLDataType(input_def.object_def.data_type, 4) + - " input = " + src_tensor.ReadWHS("x", "y", "d") + ";\n" + + " input = " + src_tensor.ReadWHSB("x", "y", "d", "b") + ";\n" + params_kernel.second + "\n}"; queue_ = environment->queue(); dims_ = input_def.dimensions; @@ -218,7 +220,7 @@ class ToTensorConverter : public OpenClConverterImpl { return std::make_pair( "__global " + ToCLDataType(input_def.object_def.data_type) + "* src", R"(int c = d * 4; - int index = (y * size.x + x) * channels + c; + int index = ((b * size.y + y) * size.x + x) * channels + c; result.x = src[index]; result.y = c + 1 < channels ? src[index + 1] : 1; result.z = c + 2 < channels ? src[index + 2] : 2; @@ -247,14 +249,16 @@ __kernel void to_tensor()" + params_kernel.first + ", " + dst_tensor.GetDeclaration(AccessType::WRITE) + R"(, int4 size, int channels) { - int x = get_global_id(0); + int linear_id = get_global_id(0); + int x = (linear_id / size.w); + int b = linear_id % size.w; int y = get_global_id(1); int d = get_global_id(2); if (x >= size.x || y >= size.y || d >= size.z) return; )" + ToCLDataType(output_def.object_def.data_type, 4) + " result;\n" + params_kernel.second + "\n " + - dst_tensor.WriteWHS("result", "x", "y", "d") + ";\n}"; + dst_tensor.WriteWHSB("result", "x", "y", "d", "b") + ";\n}"; queue_ = environment->queue(); dims_ = output_def.dimensions; return environment->program_cache()->GetOrCreateCLKernel( @@ -350,8 +354,8 @@ class TrivialCopier : public OpenClConverterImpl { } return GetOpenCLError(clEnqueueCopyBuffer( queue_->queue(), input.memobj, output.memobj, 0, 0, - SizeOf(data_type_) * dims_.w * dims_.h * dims_.d() * 4, 0, nullptr, - nullptr)); + SizeOf(data_type_) * dims_.w * dims_.h * dims_.d() * dims_.b * 4, 0, + nullptr, nullptr)); } absl::Status Copy(const OpenClTexture& input, const OpenClTexture& output) {