Metal Pooling kernels replaced with pooling tasks from common/tasks.
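
Pooling on Metal now goes through the shared GPUOperation produced by
CreatePooling in common/tasks, replacing the Metal-specific
ComputeTaskDescriptor built by the removed kernels/pooling files. A minimal
sketch of the new path, using names from the diff below; the attribute values
and the op_def are assumed placeholders:

// Sketch only: selecting a max pooling operation through the common task.
Pooling2DAttributes attr;
attr.type = PoolingType::MAX;
attr.kernel = HW(2, 2);
attr.strides = HW(2, 2);
attr.padding.prepended = HW(0, 0);
attr.padding.appended = HW(0, 0);
std::unique_ptr<GPUOperation> op =
    absl::make_unique<GPUOperation>(CreatePooling(op_def, attr));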

PiperOrigin-RevId: 353686250
Change-Id: Ib65111414fcef072bd71383c069e222183c24e5b
Raman Sarokin 2021-01-25 11:08:29 -08:00 committed by TensorFlower Gardener
parent 0bb358ee70
commit 5a251d2975
6 changed files with 9 additions and 352 deletions


@@ -22,7 +22,6 @@ cc_library(
":conv",
":depthwise_conv",
":fully_connected",
":pooling",
":transpose_conv",
":winograd",
],
@@ -296,28 +295,12 @@ ios_unit_test(
deps = [":padding_test_lib"],
)
cc_library(
name = "pooling",
srcs = ["pooling.cc"],
hdrs = ["pooling.h"],
deps = [
":util",
"//tensorflow/lite/delegates/gpu/common:model",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:shape",
"//tensorflow/lite/delegates/gpu/common:util",
"//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
"@com_google_absl//absl/strings",
],
)
objc_library(
name = "pooling_test_lib",
testonly = 1,
srcs = ["pooling_test.mm"],
sdk_frameworks = ["XCTest"],
deps = [
":pooling",
":test_util",
"//tensorflow/lite/delegates/gpu/common/tasks:pooling_test_util",
],


@@ -1,182 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/kernels/pooling.h"
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/util.h"
namespace tflite {
namespace gpu {
namespace metal {
namespace {
std::string GetMaxPoolingCode() {
std::string shader_source = R"(
kernel void ComputeFunction($0
uint3 gid[[thread_position_in_grid]]) {
if (static_cast<int>(gid.x) >= args.dst_tensor.Width() ||
static_cast<int>(gid.y) >= args.dst_tensor.Height() ||
static_cast<int>(gid.z) >= args.dst_tensor.Slices()) {
return;
}
FLT4 maximum = FLT4(-10000.0);
for (int ky = 0; ky < args.kernel_size_y; ++ky) {
for (int kx = 0; kx < args.kernel_size_x; ++kx) {
int c_x = int(gid.x) * args.stride_x - args.offset_x + kx;
int c_y = int(gid.y) * args.stride_y - args.offset_y + ky;
bool outside = c_x < 0 || c_y < 0 || c_x >= args.src_tensor.Width() ||
c_y >= args.src_tensor.Height();
FLT4 src_color = outside ? FLT4(-10000.0) : args.src_tensor.Read(c_x, c_y, gid.z);
maximum = max(maximum, src_color);
}
}
args.dst_tensor.Write(maximum, gid.x, gid.y, gid.z);
}
)";
return shader_source;
}
std::string GetMaxPoolingIndicesCode() {
std::string shader_source = R"(
kernel void ComputeFunction($0
uint3 gid[[thread_position_in_grid]]) {
if (static_cast<int>(gid.x) >= args.dst_tensor.Width() ||
static_cast<int>(gid.y) >= args.dst_tensor.Height() ||
static_cast<int>(gid.z) >= args.dst_tensor.Slices()) {
return;
}
FLT4 maximum = FLT4(-10000.0);
ushort4 indexes = ushort4(0);
ushort index_counter = 0;
for (int ky = 0; ky < args.kernel_size_y; ++ky) {
for (int kx = 0; kx < args.kernel_size_x; ++kx) {
int c_x = int(gid.x) * args.stride_x - args.offset_x + kx;
int c_y = int(gid.y) * args.stride_y - args.offset_y + ky;
bool outside = c_x < 0 || c_y < 0 || c_x >= args.src_tensor.Width() ||
c_y >= args.src_tensor.Height();
FLT4 src_color = outside ? FLT4(-10000.0) : args.src_tensor.Read(c_x, c_y, gid.z);
if (src_color.x > maximum.x) {
indexes.x = index_counter;
maximum.x = src_color.x;
}
if (src_color.y > maximum.y) {
indexes.y = index_counter;
maximum.y = src_color.y;
}
if (src_color.z > maximum.z) {
indexes.z = index_counter;
maximum.z = src_color.z;
}
if (src_color.w > maximum.w) {
indexes.w = index_counter;
maximum.w = src_color.w;
}
index_counter++;
}
}
FLT4 value = static_cast<FLT4>(indexes);
args.dst_tensor.Write(value, gid.x, gid.y, gid.z);
}
)";
return shader_source;
}
std::string GetAveragePoolingCode() {
std::string shader_source = R"(
kernel void ComputeFunction($0
uint tid[[thread_index_in_threadgroup]],
uint3 gid[[thread_position_in_grid]]) {
if (static_cast<int>(gid.x) >= args.dst_tensor.Width() ||
static_cast<int>(gid.y) >= args.dst_tensor.Height() ||
static_cast<int>(gid.z) >= args.dst_tensor.Slices()) {
return;
}
float4 sum = float4(0.0f);
float window_size = 0.0f;
for (int ky = 0; ky < args.kernel_size_y; ++ky) {
for (int kx = 0; kx < args.kernel_size_x; ++kx) {
int c_x = int(gid.x) * args.stride_x - args.offset_x + kx;
int c_y = int(gid.y) * args.stride_y - args.offset_y + ky;
bool outside = c_x < 0 || c_y < 0 || c_x >= args.src_tensor.Width() ||
c_y >= args.src_tensor.Height();
float4 src_color = outside ? float4(0.0f) : float4(args.src_tensor.Read(c_x, c_y, gid.z));
window_size += outside ? 0.0f : 1.0f;
sum += src_color;
}
}
// If window_size == 0, the window covered nothing, which indicates an
// incorrectly constructed operation; NaNs are the expected output.
FLT4 value = FLT4(sum / window_size);
args.dst_tensor.Write(value, gid.x, gid.y, gid.z);
}
)";
return shader_source;
}
} // namespace
ComputeTaskDescriptor Pooling(const OperationDef& definition,
const Pooling2DAttributes& attr,
bool generate_indices) {
ComputeTaskDescriptor desc(definition);
if (attr.type == PoolingType::MAX) {
desc.shader_source =
generate_indices ? GetMaxPoolingIndicesCode() : GetMaxPoolingCode();
} else if (attr.type == PoolingType::AVERAGE) {
desc.shader_source = GetAveragePoolingCode();
}
desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
desc.args.AddInt("kernel_size_x", attr.kernel.w);
desc.args.AddInt("kernel_size_y", attr.kernel.h);
desc.args.AddInt("stride_x", attr.strides.w);
desc.args.AddInt("stride_y", attr.strides.h);
desc.args.AddInt("offset_x", attr.padding.prepended.w);
desc.args.AddInt("offset_y", attr.padding.prepended.h);
desc.resize_function = [](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
const uint3 grid = uint3(dst_shapes[0].w, dst_shapes[0].h,
DivideRoundUp(dst_shapes[0].c, 4));
const uint3 groups_size = GetWorkGroupSizeForGrid(grid);
int groups_x = DivideRoundUp(grid.x, groups_size.x);
int groups_y = DivideRoundUp(grid.y, groups_size.y);
int groups_z = DivideRoundUp(grid.z, groups_size.z);
return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z});
};
return desc;
}
} // namespace metal
} // namespace gpu
} // namespace tflite
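
For reference, the dispatch math in the removed resize function, worked in
isolation. A minimal sketch: DivideRoundUp is integer ceiling division, and
the 8x4x1 work group here stands in for whatever GetWorkGroupSizeForGrid
would actually pick.

#include <cstdio>

// DivideRoundUp as used above: ceil(a / b) for positive integers.
int DivideRoundUp(int a, int b) { return (a + b - 1) / b; }

int main() {
  // Example destination shape BHWC(1, 40, 40, 24): channels are packed into
  // 4-wide slices, so the grid is (W, H, ceil(C / 4)) = (40, 40, 6).
  const int grid_x = 40, grid_y = 40, grid_z = DivideRoundUp(24, 4);
  // Assumed work group size; the real code asks GetWorkGroupSizeForGrid.
  const int wg_x = 8, wg_y = 4, wg_z = 1;
  std::printf("thread groups: %d x %d x %d\n",
              DivideRoundUp(grid_x, wg_x),   // 5
              DivideRoundUp(grid_y, wg_y),   // 10
              DivideRoundUp(grid_z, wg_z));  // 6
}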


@@ -1,37 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_POOLING_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_POOLING_H_
#include <vector>
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
namespace tflite {
namespace gpu {
namespace metal {
ComputeTaskDescriptor Pooling(const OperationDef& definition,
const Pooling2DAttributes& attr,
bool generate_indices);
} // namespace metal
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_POOLING_H_


@@ -27,16 +27,6 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
using ::tflite::gpu::BHWC;
using ::tflite::gpu::DataType;
using ::tflite::gpu::HW;
using ::tflite::gpu::OperationType;
using ::tflite::gpu::Pooling2DAttributes;
using ::tflite::gpu::PoolingType;
using ::tflite::gpu::TensorRef;
using ::tflite::gpu::metal::CompareVectors;
using ::tflite::gpu::metal::SingleOpModel;
@interface PoolingTest : XCTestCase
@end
@@ -44,92 +34,6 @@ using ::tflite::gpu::metal::SingleOpModel;
tflite::gpu::metal::MetalExecutionEnvironment exec_env_;
}
- (void)testPoolingMaxKernel2x2Stride2x2WithIndices {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 4, 4, 1);
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 2, 2, 1);
TensorRef<BHWC> indices;
indices.type = DataType::INT32;
indices.ref = 2;
indices.shape = BHWC(1, 2, 2, 1);
Pooling2DAttributes attr;
attr.kernel = HW(2, 2);
attr.padding.prepended = HW(0, 0);
attr.padding.appended = HW(0, 0);
attr.strides = HW(2, 2);
attr.type = PoolingType::MAX;
attr.output_indices = true;
SingleOpModel model({ToString(OperationType::POOLING_2D), attr}, {input}, {output, indices});
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 1, 2, 3, 4, 3, 4, 7, 8, 7, 8, 5, 6, 5, 6}));
auto status = model.Invoke();
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
status = CompareVectors({4, 4, 8, 8}, model.GetOutput(0), 1e-6f);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
status = CompareVectors({3, 3, 1, 1}, model.GetOutput(1), 1e-6f);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
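// Worked check of the test above: index_counter enumerates window positions
// row-major from 0. The top-left 2x2 window of the input is {1, 2, 3, 4},
// so the maximum 4 sits at position 3; the bottom-left window is
// {7, 8, 5, 6}, so the maximum 8 sits at position 1. That reproduces the
// expected output {4, 4, 8, 8} and indices {3, 3, 1, 1}.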
- (void)testPoolingMaxKernel2x2Stride2x2WithoutIndices {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 4, 4, 1);
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 2, 2, 1);
Pooling2DAttributes attr;
attr.kernel = HW(2, 2);
attr.padding.prepended = HW(0, 0);
attr.padding.appended = HW(0, 0);
attr.strides = HW(2, 2);
attr.type = PoolingType::MAX;
SingleOpModel model({ToString(OperationType::POOLING_2D), attr}, {input}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 1, 2, 3, 4, 3, 4, 7, 8, 7, 8, 5, 6, 5, 6}));
auto status = model.Invoke();
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
status = CompareVectors({4, 4, 8, 8}, model.GetOutput(0), 1e-6f);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testPoolingAverageKernel2x2Stride2x2 {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 4, 4, 1);
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 2, 2, 1);
Pooling2DAttributes attr;
attr.kernel = HW(2, 2);
attr.padding.prepended = HW(0, 0);
attr.padding.appended = HW(0, 0);
attr.strides = HW(2, 2);
attr.type = PoolingType::AVERAGE;
SingleOpModel model({ToString(OperationType::POOLING_2D), attr}, {input}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4}));
auto status = model.Invoke();
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
status = CompareVectors({1, 2, 3, 4}, model.GetOutput(0), 1e-6f);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testAveragePooling {
auto status = AveragePoolingTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());


@@ -43,6 +43,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/common/tasks:lstm",
"//tensorflow/lite/delegates/gpu/common/tasks:max_unpooling",
"//tensorflow/lite/delegates/gpu/common/tasks:padding",
"//tensorflow/lite/delegates/gpu/common/tasks:pooling",
"//tensorflow/lite/delegates/gpu/common/tasks:prelu",
"//tensorflow/lite/delegates/gpu/common/tasks:quantize_and_dequantize",
"//tensorflow/lite/delegates/gpu/common/tasks:reduce",


@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/tasks/lstm.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/max_unpooling.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/padding.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/pooling.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/prelu.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/quantize_and_dequantize.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/reduce.h"
@@ -49,7 +50,6 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/metal/kernels/conv.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/pooling.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/winograd.h"
#include "tensorflow/lite/delegates/gpu/metal/selectors/default_selector.h"
@@ -125,6 +125,11 @@ void SelectPadding(const PadAttributes& attr, const OperationDef& op_def,
*ptr = absl::make_unique<GPUOperation>(std::move(operation));
}
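// Pooling now follows the same pattern as the neighboring Select* helpers
// (SelectPadding above, SelectReduce below): the shared CreatePooling task
// from common/tasks is wrapped in a generic GPUOperation, replacing the
// Metal-specific ComputeTaskDescriptor path removed from
// GPUOperationFromNode below.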
std::unique_ptr<GPUOperation> SelectPooling(const Pooling2DAttributes& attr,
const OperationDef& op_def) {
return absl::make_unique<GPUOperation>(CreatePooling(op_def, attr));
}
std::unique_ptr<GPUOperation> SelectReduce(const std::set<Axis>& axis_to_reduce,
const BHWC& src_shape,
OperationType op_type,
@@ -419,25 +424,8 @@ absl::Status GPUOperationFromNode(const GpuInfo& gpu_info,
case OperationType::POOLING_2D: {
auto attr =
absl::any_cast<Pooling2DAttributes>(node.operation.attributes);
auto pooling_op_def = op_def;
pooling_op_def.dst_tensors = {op_def.dst_tensors[0]};
auto gpu_op = Pooling(op_def, attr, false);
gpu_subgraph->operations[0].task_desc =
absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
gpu_subgraph->operations[0].input_ids = {static_cast<int>(inputs[0]->id)};
gpu_subgraph->operations[0].output_ids = {
static_cast<int>(outputs[0]->id)};
if (attr.type == PoolingType::MAX && attr.output_indices) {
gpu_subgraph->operations.push_back({});
auto gpu_ind_op = Pooling(op_def, attr, true);
gpu_subgraph->operations[1].task_desc =
absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_ind_op));
gpu_subgraph->operations[1].input_ids = {
static_cast<int>(inputs[0]->id)};
gpu_subgraph->operations[1].output_ids = {
static_cast<int>(outputs[1]->id)};
}
break;
gpu_operation->operation = SelectPooling(attr, op_def);
return absl::OkStatus();
}
case OperationType::PRELU: {
auto attr = absl::any_cast<PReLUAttributes>(node.operation.attributes);