Metal Mean kernels replaced with reduce tasks from common/tasks.

Metal Mean now supports reduction over any combination of axes.
Metal now also supports the reduction ops with any combination of axes.

PiperOrigin-RevId: 353507802
Change-Id: Ib032b83e6075c1004e5249b6b8c8f66a13b2cf79
This commit is contained in:
Raman Sarokin 2021-01-24 06:36:01 -08:00 committed by TensorFlower Gardener
parent 32e64f9b89
commit 2cff21749d
6 changed files with 23 additions and 291 deletions

View File

@ -23,7 +23,6 @@ cc_library(
":depthwise_conv",
":fully_connected",
":max_unpooling",
":mean",
":pooling",
":resize",
":transpose_conv",
@ -291,45 +290,6 @@ ios_unit_test(
deps = [":max_unpooling_test_lib"],
)
# C++ library target for the Metal Mean kernel (mean.cc / mean.h).
cc_library(
name = "mean",
srcs = ["mean.cc"],
hdrs = ["mean.h"],
deps = [
":util",
"//tensorflow/lite/delegates/gpu/common:model",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:tensor",
"//tensorflow/lite/delegates/gpu/common:util",
"//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
"@com_google_absl//absl/types:variant",
],
)
# Test-only Objective-C++ library wrapping mean_test.mm; links against XCTest
# and depends on the ":mean" kernel library and the shared ":test_util" target.
objc_library(
name = "mean_test_lib",
testonly = 1,
srcs = ["mean_test.mm"],
sdk_frameworks = ["XCTest"],
deps = [
":mean",
":test_util",
],
)
# iOS unit test target that runs ":mean_test_lib" (minimum iOS 11.0) on the
# runner returned by tflite_ios_lab_runner("IOS_LATEST").
ios_unit_test(
name = "mean_test",
testonly = 1,
minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = [
"no_mac", # TODO(b/171882133)
"notap",
"tflite_not_portable_android",
],
deps = [":mean_test_lib"],
)
objc_library(
name = "padding_test_lib",
testonly = 1,
@ -776,7 +736,6 @@ objc_library(
"fully_connected_test.mm",
"lstm_test.mm",
"max_unpooling_test.mm",
"mean_test.mm",
"padding_test.mm",
"pooling_test.mm",
"prelu_test.mm",

View File

@ -1,134 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/kernels/mean.h"
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "absl/strings/substitute.h"
#include "absl/types/variant.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/util.h"
namespace tflite {
namespace gpu {
namespace metal {
// Generates the Metal source of the mean kernel for a threadgroup of
// work_group_size.x * work_group_size.y threads.
//
// The emitted kernel does the following, in order:
//   1. Each thread walks the src_tensor Width x Height plane of slice gid.z
//      with a stride of the work group size, accumulates what it reads into
//      its own cell of the threadgroup array `accum`, and scales that cell by
//      args.inv_multiplier_x.
//   2. Partial sums are folded four at a time, one pass per loop iteration
//      below, with a threadgroup barrier after every pass.
//   3. The remaining partial sums are added up serially, scaled by
//      args.inv_multiplier_y and written to dst_tensor at (0, 0, gid.z).
//
// The "$0" in the kernel signature is emitted verbatim; presumably it is a
// placeholder substituted later by the code that consumes this string
// (not visible in this file) -- TODO confirm.
std::string GetMeanCode(const int3& work_group_size) {
  const int threads_per_group = work_group_size.x * work_group_size.y;
  const std::string stride_x = std::to_string(work_group_size.x);
  const std::string stride_y = std::to_string(work_group_size.y);

  // Kernel prologue: thread indices and the early-out for surplus slices.
  std::string shader = R"(
kernel void ComputeFunction($0
uint tid[[thread_index_in_threadgroup]],
uint3 tid3d[[thread_position_in_threadgroup]],
uint3 gid[[thread_position_in_grid]]) {
int local_x = static_cast<int>(tid3d.x);
int local_y = static_cast<int>(tid3d.y);
int local_id = static_cast<int>(tid);
int S = static_cast<int>(gid.z);
if (S >= args.dst_tensor.Slices()) return;
)";

  // Per-thread strided accumulation over the spatial plane.
  shader += " threadgroup float4 accum[";
  shader += std::to_string(threads_per_group);
  shader += "];\n";
  shader += " accum[local_id] = float4(0.0f);\n";
  shader += " for (int s_y = local_y; s_y < args.src_tensor.Height(); s_y += ";
  shader += stride_y;
  shader += ") {\n";
  shader += " for (int s_x = local_x; s_x < args.src_tensor.Width(); s_x += ";
  shader += stride_x;
  shader += ") {\n";
  shader += " accum[local_id] += float4(args.src_tensor.Read(s_x, s_y, S));\n";
  shader += " }\n }\n";
  shader +=
      " accum[local_id] *= args.inv_multiplier_x;\n"
      " threadgroup_barrier(mem_flags::mem_threadgroup);\n";

  // Unrolled 4-to-1 reduction passes. `active` is the number of threads that
  // take part in the current pass, `stride` the distance between the cells
  // that still hold live partial sums.
  int stride = 1;
  int active = threads_per_group / 4;
  while (active >= 8) {
    shader += " if (local_id < " + std::to_string(active) + ") {\n";
    shader += " int t = local_id * " + std::to_string(stride * 4) + ";\n";
    shader += " float4 sum = accum[t + " + std::to_string(stride) + "];\n";
    for (int k = 2; k <= 3; ++k) {
      shader += " sum += accum[t + " + std::to_string(stride * k) + "];\n";
    }
    shader += " accum[t] += sum;\n";
    shader += " }\n";
    shader += " threadgroup_barrier(mem_flags::mem_threadgroup);\n";
    active /= 4;
    stride *= 4;
  }

  // Serial tail: sum the partial results the passes above left behind.
  shader += " float4 sum = accum[0];\n";
  const int tail_count = active * 4;
  for (int i = 1; i < tail_count; ++i) {
    shader += " sum += accum[" + std::to_string(stride * i) + "];\n";
  }
  shader += " FLT4 value = FLT4(sum * args.inv_multiplier_y);\n";
  shader += R"(
args.dst_tensor.Write(value, 0, 0, gid.z);
}
)";
  return shader;
}
// Creates the compute task descriptor for a mean over the HEIGHT and WIDTH
// axes.
//
// definition: operation definition; its first source and destination tensor
//             descriptors are bound as "src_tensor" and "dst_tensor".
// attr:       mean attributes; only attr.dims == {HEIGHT, WIDTH} is handled.
//
// Returns a default-constructed descriptor for any other set of axes.
ComputeTaskDescriptor Mean(const OperationDef& definition,
                           const MeanAttributes& attr) {
  const std::set<Axis> supported_dims({Axis::HEIGHT, Axis::WIDTH});
  if (attr.dims != supported_dims) {
    // Mean calculation is supported only for height and width
    return {};
  }

  const int3 group_dims = int3(16, 16, 1);

  ComputeTaskDescriptor desc(definition);
  desc.shader_source = GetMeanCode(group_dims);
  desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
  desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
  desc.args.AddFloat("inv_multiplier_x");
  desc.args.AddFloat("inv_multiplier_y");

  // The averaging factor 1 / (w * h) is split into two factors whose product
  // is the full factor: inv_multiplier_x is applied by the shader before the
  // threadgroup reduction and inv_multiplier_y after it.
  // NOTE(review): presumably split this way to keep the partial sums small --
  // confirm.
  desc.update_function = {
      [group_dims](const std::vector<BHWC>& src_shapes,
                   const std::vector<BHWC>& dst_shapes,
                   ArgumentsBinder* args) -> absl::Status {
        const double threads_in_group = group_dims.x * group_dims.y;
        const double plane_size = src_shapes[0].w * src_shapes[0].h;
        const double elements_per_thread = plane_size / threads_in_group;
        RETURN_IF_ERROR(
            args->SetFloat("inv_multiplier_x", 1.0 / elements_per_thread));
        RETURN_IF_ERROR(
            args->SetFloat("inv_multiplier_y", 1.0 / threads_in_group));
        return absl::OkStatus();
      }};

  // One threadgroup in X and Y; the Z grid dimension covers the destination
  // slices (channels rounded up to multiples of 4).
  desc.resize_function = [group_dims](const std::vector<BHWC>& src_shapes,
                                      const std::vector<BHWC>& dst_shapes) {
    const int slice_count = DivideRoundUp(dst_shapes[0].c, 4);
    const int groups_z = DivideRoundUp(slice_count, group_dims.z);
    return std::make_pair(group_dims, uint3{1, 1, groups_z});
  };

  return desc;
}
} // namespace metal
} // namespace gpu
} // namespace tflite

View File

@ -1,34 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MEAN_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MEAN_H_
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
namespace tflite {
namespace gpu {
namespace metal {
// Builds the Metal compute task descriptor for the mean operation described
// by `attr`, using the tensor descriptors from `definition`.
ComputeTaskDescriptor Mean(const OperationDef& definition,
const MeanAttributes& attr);
} // namespace metal
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MEAN_H_

View File

@ -1,70 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/kernels/mean.h"
#import <XCTest/XCTest.h>
#include <string>
#include <vector>
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
using ::tflite::gpu::Axis;
using ::tflite::gpu::BHWC;
using ::tflite::gpu::DataType;
using ::tflite::gpu::OperationType;
using ::tflite::gpu::MeanAttributes;
using ::tflite::gpu::TensorRef;
using ::tflite::gpu::metal::CompareVectors;
using ::tflite::gpu::metal::SingleOpModel;
// XCTest suite for the Metal mean kernel.
@interface MeanTest : XCTestCase
@end

@implementation MeanTest

- (void)setUp {
  [super setUp];
}

// Smoke test: the mean of a 1x2x2x1 tensor holding {1, 2, 3, 4}, reduced over
// HEIGHT and WIDTH, must be the single value 2.5 (tolerance 1e-6).
- (void)testMeanSmoke {
  TensorRef<BHWC> src;
  src.ref = 0;
  src.type = DataType::FLOAT32;
  src.shape = BHWC(1, 2, 2, 1);

  TensorRef<BHWC> dst;
  dst.ref = 1;
  dst.type = DataType::FLOAT32;
  dst.shape = BHWC(1, 1, 1, 1);

  MeanAttributes mean_attr;
  mean_attr.dims = {Axis::HEIGHT, Axis::WIDTH};

  SingleOpModel model({ToString(OperationType::MEAN), mean_attr}, {src}, {dst});
  XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 3.0, 4.0}));

  auto status = model.Invoke();
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());

  status = CompareVectors({2.5}, model.GetOutput(0), 1e-6f);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

@end

View File

@ -44,6 +44,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/common/tasks:padding",
"//tensorflow/lite/delegates/gpu/common/tasks:prelu",
"//tensorflow/lite/delegates/gpu/common/tasks:quantize_and_dequantize",
"//tensorflow/lite/delegates/gpu/common/tasks:reduce",
"//tensorflow/lite/delegates/gpu/common/tasks:relu",
"//tensorflow/lite/delegates/gpu/common/tasks:reshape",
"//tensorflow/lite/delegates/gpu/common/tasks:reshapex4",

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/tasks/padding.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/prelu.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/quantize_and_dequantize.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/reduce.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/relu.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/reshape.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/reshapex4.h"
@ -47,7 +48,6 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/mean.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/pooling.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/resize.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h"
@ -120,6 +120,15 @@ void SelectPadding(const PadAttributes& attr, const OperationDef& op_def,
*ptr = absl::make_unique<GPUOperation>(std::move(operation));
}
std::unique_ptr<GPUOperation> SelectReduce(const std::set<Axis>& axis_to_reduce,
const BHWC& src_shape,
OperationType op_type,
const OperationDef& op_def,
const GpuInfo& gpu_info) {
return absl::make_unique<Reduce>(
CreateReduce(axis_to_reduce, src_shape, op_type, op_def, gpu_info));
}
void SelectReshape(int src_channels, int dst_channels,
const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr) {
@ -387,13 +396,9 @@ absl::Status GPUOperationFromNode(const GpuInfo& gpu_info,
}
case OperationType::MEAN: {
auto attr = absl::any_cast<MeanAttributes>(node.operation.attributes);
if (attr.dims != std::set<Axis>({Axis::HEIGHT, Axis::WIDTH})) {
return absl::UnimplementedError("Mean supports HW axis only in Metal");
}
auto gpu_op = Mean(op_def, attr);
gpu_operation->task_desc =
absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
break;
gpu_operation->operation = SelectReduce(
attr.dims, inputs[0]->tensor.shape, op_type, op_def, gpu_info);
return absl::OkStatus();
}
case OperationType::PAD: {
auto attr = absl::any_cast<PadAttributes>(node.operation.attributes);
@ -429,6 +434,15 @@ absl::Status GPUOperationFromNode(const GpuInfo& gpu_info,
absl::make_unique<GPUOperation>(CreatePReLU(gpu_info, op_def, attr));
return absl::OkStatus();
}
case OperationType::REDUCE_MAXIMUM:
case OperationType::REDUCE_MINIMUM:
case OperationType::REDUCE_PRODUCT:
case OperationType::REDUCE_SUM: {
auto attr = absl::any_cast<ReduceAttributes>(node.operation.attributes);
gpu_operation->operation = SelectReduce(
attr.dims, inputs[0]->tensor.shape, op_type, op_def, gpu_info);
return absl::OkStatus();
}
case OperationType::RELU: {
auto attr = absl::any_cast<ReLUAttributes>(node.operation.attributes);
gpu_operation->operation =
@ -541,10 +555,6 @@ absl::Status GPUOperationFromNode(const GpuInfo& gpu_info,
case OperationType::CONSTANT:
// TODO(b/162763635): implement MeanStddevNormalization for Metal.
case OperationType::MEAN_STDDEV_NORMALIZATION:
case OperationType::REDUCE_MAXIMUM:
case OperationType::REDUCE_MINIMUM:
case OperationType::REDUCE_PRODUCT:
case OperationType::REDUCE_SUM:
case OperationType::SPACE_TO_BATCH:
return absl::UnimplementedError("Unsupported op: " + node.operation.type);
default: