Metal Mean kernels replaced with reduce tasks from common/tasks.
Mean on Metal, like the other reduction ops, now supports any combination of reduction axes.

PiperOrigin-RevId: 353507802
Change-Id: Ib032b83e6075c1004e5249b6b8c8f66a13b2cf79
parent 32e64f9b89
commit 2cff21749d
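For orientation before the diff: the Metal backend now maps MEAN onto the shared Reduce task instead of the Metal-only Mean kernel. Below is a condensed sketch of the new selector path, pieced together from the hunks that follow (the surrounding selector plumbing is elided, so treat it as illustrative rather than a compilable excerpt):

// The new helper added to the Metal operation selector (see diff below).
std::unique_ptr<GPUOperation> SelectReduce(const std::set<Axis>& axis_to_reduce,
                                           const BHWC& src_shape,
                                           OperationType op_type,
                                           const OperationDef& op_def,
                                           const GpuInfo& gpu_info) {
  // Reduce and CreateReduce come from //tensorflow/lite/delegates/gpu/common/tasks:reduce.
  return absl::make_unique<Reduce>(
      CreateReduce(axis_to_reduce, src_shape, op_type, op_def, gpu_info));
}

// Inside GPUOperationFromNode(), MEAN is now dispatched like the other
// reduction ops; any set of axes in attr.dims is accepted, whereas the
// removed Metal kernel required exactly {Axis::HEIGHT, Axis::WIDTH}.
case OperationType::MEAN: {
  auto attr = absl::any_cast<MeanAttributes>(node.operation.attributes);
  gpu_operation->operation = SelectReduce(
      attr.dims, inputs[0]->tensor.shape, op_type, op_def, gpu_info);
  return absl::OkStatus();
}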
@@ -23,7 +23,6 @@ cc_library(
        ":depthwise_conv",
        ":fully_connected",
        ":max_unpooling",
        ":mean",
        ":pooling",
        ":resize",
        ":transpose_conv",
@@ -291,45 +290,6 @@ ios_unit_test(
    deps = [":max_unpooling_test_lib"],
)

cc_library(
    name = "mean",
    srcs = ["mean.cc"],
    hdrs = ["mean.h"],
    deps = [
        ":util",
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:operations",
        "//tensorflow/lite/delegates/gpu/common:tensor",
        "//tensorflow/lite/delegates/gpu/common:util",
        "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
        "@com_google_absl//absl/types:variant",
    ],
)

objc_library(
    name = "mean_test_lib",
    testonly = 1,
    srcs = ["mean_test.mm"],
    sdk_frameworks = ["XCTest"],
    deps = [
        ":mean",
        ":test_util",
    ],
)

ios_unit_test(
    name = "mean_test",
    testonly = 1,
    minimum_os_version = "11.0",
    runner = tflite_ios_lab_runner("IOS_LATEST"),
    tags = [
        "no_mac",  # TODO(b/171882133)
        "notap",
        "tflite_not_portable_android",
    ],
    deps = [":mean_test_lib"],
)

objc_library(
    name = "padding_test_lib",
    testonly = 1,
@@ -776,7 +736,6 @@ objc_library(
        "fully_connected_test.mm",
        "lstm_test.mm",
        "max_unpooling_test.mm",
        "mean_test.mm",
        "padding_test.mm",
        "pooling_test.mm",
        "prelu_test.mm",
@@ -1,134 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/metal/kernels/mean.h"

#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "absl/strings/substitute.h"
#include "absl/types/variant.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/util.h"

namespace tflite {
namespace gpu {
namespace metal {

std::string GetMeanCode(const int3& work_group_size) {
  const std::string wg_x = std::to_string(work_group_size.x);
  const std::string wg_y = std::to_string(work_group_size.y);
  std::string c = R"(
kernel void ComputeFunction($0
uint tid[[thread_index_in_threadgroup]],
uint3 tid3d[[thread_position_in_threadgroup]],
uint3 gid[[thread_position_in_grid]]) {
int local_x = static_cast<int>(tid3d.x);
int local_y = static_cast<int>(tid3d.y);
int local_id = static_cast<int>(tid);
int S = static_cast<int>(gid.z);
if (S >= args.dst_tensor.Slices()) return;
)";
  c += " threadgroup float4 accum[" +
       std::to_string(work_group_size.x * work_group_size.y) + "];\n";
  c += " accum[local_id] = float4(0.0f);\n";
  c += " for (int s_y = local_y; s_y < args.src_tensor.Height(); s_y += " +
       wg_y + ") {\n";
  c += " for (int s_x = local_x; s_x < args.src_tensor.Width(); s_x += " +
       wg_x + ") {\n";
  c += " accum[local_id] += float4(args.src_tensor.Read(s_x, s_y, S));\n";
  c += " }\n";
  c += " }\n";
  c += " accum[local_id] *= args.inv_multiplier_x;\n";
  c += " threadgroup_barrier(mem_flags::mem_threadgroup);\n";
  const int total_size = work_group_size.x * work_group_size.y;
  int offset = 1;
  int reminder = total_size / 4;
  for (; reminder >= 8; reminder /= 4, offset *= 4) {
    c += " if (local_id < " + std::to_string(reminder) + ") {\n";
    c += " int t = local_id * " + std::to_string(offset * 4) + ";\n";
    c += " float4 sum = accum[t + " + std::to_string(offset) + "];\n";
    c += " sum += accum[t + " + std::to_string(offset * 2) + "];\n";
    c += " sum += accum[t + " + std::to_string(offset * 3) + "];\n";
    c += " accum[t] += sum;\n";
    c += " }\n";
    c += " threadgroup_barrier(mem_flags::mem_threadgroup);\n";
  }
  c += " float4 sum = accum[0];\n";
  reminder *= 4;
  for (int i = 1; i < reminder; ++i) {
    c += " sum += accum[" + std::to_string(offset * i) + "];\n";
  }
  c += " FLT4 value = FLT4(sum * args.inv_multiplier_y);\n";
  c += R"(
args.dst_tensor.Write(value, 0, 0, gid.z);
}
)";
  return c;
}

ComputeTaskDescriptor Mean(const OperationDef& definition,
                           const MeanAttributes& attr) {
  if (attr.dims != std::set<Axis>({Axis::HEIGHT, Axis::WIDTH})) {
    // Mean calculation is supported only for height and width
    return {};
  }

  const int3 work_group_size = int3(16, 16, 1);

  ComputeTaskDescriptor desc(definition);
  std::string code = GetMeanCode(work_group_size);
  desc.shader_source = code;

  desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
  desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);

  desc.args.AddFloat("inv_multiplier_x");
  desc.args.AddFloat("inv_multiplier_y");

  desc.update_function = {
      [work_group_size](const std::vector<BHWC>& src_shapes,
                        const std::vector<BHWC>& dst_shapes,
                        ArgumentsBinder* args) -> absl::Status {
        const double total_size = src_shapes[0].w * src_shapes[0].h;
        const double size_0 = work_group_size.x * work_group_size.y;
        const double size_1 = total_size / size_0;
        RETURN_IF_ERROR(args->SetFloat("inv_multiplier_x", 1.0 / size_1));
        RETURN_IF_ERROR(args->SetFloat("inv_multiplier_y", 1.0 / size_0));
        return absl::OkStatus();
      }};

  desc.resize_function = [work_group_size](
                             const std::vector<BHWC>& src_shapes,
                             const std::vector<BHWC>& dst_shapes) {
    const int dst_slices = DivideRoundUp(dst_shapes[0].c, 4);
    const int groups_z = DivideRoundUp(dst_slices, work_group_size.z);
    return std::make_pair(work_group_size, uint3{1, 1, groups_z});
  };
  return desc;
}

} // namespace metal
} // namespace gpu
} // namespace tflite
@@ -1,34 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MEAN_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MEAN_H_

#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"

namespace tflite {
namespace gpu {
namespace metal {

ComputeTaskDescriptor Mean(const OperationDef& definition,
                           const MeanAttributes& attr);

} // namespace metal
} // namespace gpu
} // namespace tflite

#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MEAN_H_
@@ -1,70 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/metal/kernels/mean.h"

#import <XCTest/XCTest.h>

#include <string>
#include <vector>

#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"

using ::tflite::gpu::Axis;
using ::tflite::gpu::BHWC;
using ::tflite::gpu::DataType;
using ::tflite::gpu::OperationType;
using ::tflite::gpu::MeanAttributes;
using ::tflite::gpu::TensorRef;
using ::tflite::gpu::metal::CompareVectors;
using ::tflite::gpu::metal::SingleOpModel;

@interface MeanTest : XCTestCase
@end

@implementation MeanTest
- (void)setUp {
  [super setUp];
}

- (void)testMeanSmoke {
  TensorRef<BHWC> input;
  input.type = DataType::FLOAT32;
  input.ref = 0;
  input.shape = BHWC(1, 2, 2, 1);

  TensorRef<BHWC> output;
  output.type = DataType::FLOAT32;
  output.ref = 1;
  output.shape = BHWC(1, 1, 1, 1);

  MeanAttributes attr;
  attr.dims = {Axis::HEIGHT, Axis::WIDTH};

  SingleOpModel model({ToString(OperationType::MEAN), attr}, {input}, {output});
  XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 3.0, 4.0}));
  auto status = model.Invoke();
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
  status = CompareVectors({2.5}, model.GetOutput(0), 1e-6f);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

@end
@@ -44,6 +44,7 @@ cc_library(
        "//tensorflow/lite/delegates/gpu/common/tasks:padding",
        "//tensorflow/lite/delegates/gpu/common/tasks:prelu",
        "//tensorflow/lite/delegates/gpu/common/tasks:quantize_and_dequantize",
        "//tensorflow/lite/delegates/gpu/common/tasks:reduce",
        "//tensorflow/lite/delegates/gpu/common/tasks:relu",
        "//tensorflow/lite/delegates/gpu/common/tasks:reshape",
        "//tensorflow/lite/delegates/gpu/common/tasks:reshapex4",
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/tasks/padding.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/prelu.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/quantize_and_dequantize.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/reduce.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/relu.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/reshape.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/reshapex4.h"
@@ -47,7 +48,6 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/mean.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/pooling.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/resize.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h"
@@ -120,6 +120,15 @@ void SelectPadding(const PadAttributes& attr, const OperationDef& op_def,
  *ptr = absl::make_unique<GPUOperation>(std::move(operation));
}

std::unique_ptr<GPUOperation> SelectReduce(const std::set<Axis>& axis_to_reduce,
                                           const BHWC& src_shape,
                                           OperationType op_type,
                                           const OperationDef& op_def,
                                           const GpuInfo& gpu_info) {
  return absl::make_unique<Reduce>(
      CreateReduce(axis_to_reduce, src_shape, op_type, op_def, gpu_info));
}

void SelectReshape(int src_channels, int dst_channels,
                   const OperationDef& op_def,
                   std::unique_ptr<GPUOperation>* ptr) {
@@ -387,13 +396,9 @@ absl::Status GPUOperationFromNode(const GpuInfo& gpu_info,
    }
    case OperationType::MEAN: {
      auto attr = absl::any_cast<MeanAttributes>(node.operation.attributes);
      if (attr.dims != std::set<Axis>({Axis::HEIGHT, Axis::WIDTH})) {
        return absl::UnimplementedError("Mean supports HW axis only in Metal");
      }
      auto gpu_op = Mean(op_def, attr);
      gpu_operation->task_desc =
          absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
      break;
      gpu_operation->operation = SelectReduce(
          attr.dims, inputs[0]->tensor.shape, op_type, op_def, gpu_info);
      return absl::OkStatus();
    }
    case OperationType::PAD: {
      auto attr = absl::any_cast<PadAttributes>(node.operation.attributes);
@@ -429,6 +434,15 @@ absl::Status GPUOperationFromNode(const GpuInfo& gpu_info,
          absl::make_unique<GPUOperation>(CreatePReLU(gpu_info, op_def, attr));
      return absl::OkStatus();
    }
    case OperationType::REDUCE_MAXIMUM:
    case OperationType::REDUCE_MINIMUM:
    case OperationType::REDUCE_PRODUCT:
    case OperationType::REDUCE_SUM: {
      auto attr = absl::any_cast<ReduceAttributes>(node.operation.attributes);
      gpu_operation->operation = SelectReduce(
          attr.dims, inputs[0]->tensor.shape, op_type, op_def, gpu_info);
      return absl::OkStatus();
    }
    case OperationType::RELU: {
      auto attr = absl::any_cast<ReLUAttributes>(node.operation.attributes);
      gpu_operation->operation =
@@ -541,10 +555,6 @@ absl::Status GPUOperationFromNode(const GpuInfo& gpu_info,
    case OperationType::CONSTANT:
    // TODO(b/162763635): implement MeanStddevNormalization for Metal.
    case OperationType::MEAN_STDDEV_NORMALIZATION:
    case OperationType::REDUCE_MAXIMUM:
    case OperationType::REDUCE_MINIMUM:
    case OperationType::REDUCE_PRODUCT:
    case OperationType::REDUCE_SUM:
    case OperationType::SPACE_TO_BATCH:
      return absl::UnimplementedError("Unsupported op: " + node.operation.type);
    default: