Metal MaxUnpooling kernels replaced with max unpooling tasks from common/tasks.
PiperOrigin-RevId: 353682068
Change-Id: I891d6012a135540b4d3087f4103bc5222c598549
parent 5bb0f5174d
commit 5cd9da0f25
@@ -22,7 +22,6 @@ cc_library(
        ":conv",
        ":depthwise_conv",
        ":fully_connected",
        ":max_unpooling",
        ":pooling",
        ":transpose_conv",
        ":winograd",
@@ -251,28 +250,12 @@ ios_unit_test(
    deps = [":lstm_test_lib"],
)

cc_library(
    name = "max_unpooling",
    srcs = ["max_unpooling.cc"],
    hdrs = ["max_unpooling.h"],
    deps = [
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:operations",
        "//tensorflow/lite/delegates/gpu/common:shape",
        "//tensorflow/lite/delegates/gpu/common:types",
        "//tensorflow/lite/delegates/gpu/common:util",
        "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
        "@com_google_absl//absl/strings",
    ],
)

objc_library(
    name = "max_unpooling_test_lib",
    testonly = 1,
    srcs = ["max_unpooling_test.mm"],
    sdk_frameworks = ["XCTest"],
    deps = [
        ":max_unpooling",
        ":test_util",
        "//tensorflow/lite/delegates/gpu/common/tasks:max_unpooling_test_util",
    ],
@@ -1,102 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling.h"

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"

namespace tflite {
namespace gpu {
namespace metal {
namespace {

std::string GetMaxUnpoolingCode() {
  std::string shader_source = R"(
    kernel void ComputeFunction($0
                                uint3 gid[[thread_position_in_grid]]) {
      int X = static_cast<int>(gid.x);
      int Y = static_cast<int>(gid.y);
      if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height()) {
        return;
      }

      int src_x = (X + args.offset_x) / args.stride_x;
      int src_y = (Y + args.offset_y) / args.stride_y;

      bool outside = src_x < 0 || src_y < 0 ||
          src_x >= args.src_tensor.Width() || src_y >= args.src_tensor.Height();

      int4 indexes = outside ? int4(0) : int4(args.src_indices.Read(src_x, src_y, gid.z));
      FLT4 src_color = outside ? FLT4(0.0f) : args.src_tensor.Read(src_x, src_y, gid.z);

      int t_x = X - (src_x * args.stride_x - args.offset_x);
      int t_y = Y - (src_y * args.stride_y - args.offset_y);
      int t_index = t_y * args.kernel_size_x + t_x;

      FLT4 value;
      value.x = t_index == indexes.x ? src_color.x : 0.0;
      value.y = t_index == indexes.y ? src_color.y : 0.0;
      value.z = t_index == indexes.z ? src_color.z : 0.0;
      value.w = t_index == indexes.w ? src_color.w : 0.0;

      args.dst_tensor.Write(value, X, Y, gid.z);
    }
  )";
  return shader_source;
}
}  // namespace

ComputeTaskDescriptor MaxUnpooling(const OperationDef& definition,
                                   const MaxUnpooling2DAttributes& attr) {
  ComputeTaskDescriptor desc(definition);
  desc.shader_source = GetMaxUnpoolingCode();

  desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
  desc.AddSrcTensor("src_indices", definition.src_tensors[1]);
  desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);

  desc.args.AddInt("kernel_size_x", attr.kernel.w);
  desc.args.AddInt("stride_x", attr.strides.w);
  desc.args.AddInt("stride_y", attr.strides.h);
  desc.args.AddInt("offset_x", attr.padding.prepended.w);
  desc.args.AddInt("offset_y", attr.padding.prepended.h);

  desc.resize_function = [](const std::vector<BHWC>& src_shapes,
                            const std::vector<BHWC>& dst_shapes) {
    const uint3 groups_size{8, 4, 1};
    int groups_x = DivideRoundUp(dst_shapes[0].w, groups_size.x);
    int groups_y = DivideRoundUp(dst_shapes[0].h, groups_size.y);
    int groups_z = DivideRoundUp(dst_shapes[0].c, 4);
    return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z});
  };

  return desc;
}

}  // namespace metal
}  // namespace gpu
}  // namespace tflite
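The shader deleted above encodes the usual max-unpooling scatter: each destination pixel is mapped back to the pooled cell it belongs to, and it receives the pooled value only if its position inside that cell's kernel window matches the recorded argmax index. A minimal single-channel CPU sketch of the same index math follows; it is an illustration only, every name in it is local to this sketch, and it assumes indices are stored as flat offsets within each kernel window, as the t_index comparison above implies.

#include <vector>

// Hypothetical reference for the index math in the deleted shader
// (single channel, row-major layout); not part of the TFLite sources.
std::vector<float> MaxUnpoolReference(const std::vector<float>& src,
                                      const std::vector<int>& indices,
                                      int src_w, int src_h,
                                      int dst_w, int dst_h,
                                      int kernel_w,
                                      int stride_x, int stride_y,
                                      int offset_x, int offset_y) {
  std::vector<float> dst(dst_w * dst_h, 0.0f);
  for (int y = 0; y < dst_h; ++y) {
    for (int x = 0; x < dst_w; ++x) {
      // Pooled cell this destination pixel maps back to.
      const int src_x = (x + offset_x) / stride_x;
      const int src_y = (y + offset_y) / stride_y;
      if (src_x < 0 || src_y < 0 || src_x >= src_w || src_y >= src_h) continue;
      // Position of the pixel inside that cell's kernel window, flattened
      // the same way as t_index in the shader.
      const int t_x = x - (src_x * stride_x - offset_x);
      const int t_y = y - (src_y * stride_y - offset_y);
      const int t_index = t_y * kernel_w + t_x;
      // Only the recorded argmax position receives the pooled value.
      if (t_index == indices[src_y * src_w + src_x]) {
        dst[y * dst_w + x] = src[src_y * src_w + src_x];
      }
    }
  }
  return dst;
}

With the 2x2 kernel, stride 2 and zero padding used in the test further down, this sketch reproduces the expected interleaved output.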
@@ -1,36 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MAX_UNPOOLING_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MAX_UNPOOLING_H_

#include <vector>

#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"

namespace tflite {
namespace gpu {
namespace metal {

ComputeTaskDescriptor MaxUnpooling(const OperationDef& definition,
                                   const MaxUnpooling2DAttributes& attr);

}  // namespace metal
}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MAX_UNPOOLING_H_
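The deleted declaration is superseded by the shared task in tensorflow/lite/delegates/gpu/common/tasks/max_unpooling.h. Judging only from how it is used later in this diff (CreateMaxUnpooling(op_def, attr) wrapped into a GPUOperation by SelectMaxUnpooling), the replacement entry point presumably looks roughly like this; the exact signature is not shown in this commit:

// Sketch inferred from the call sites below; not quoted from the real header.
namespace tflite {
namespace gpu {

GPUOperation CreateMaxUnpooling(const OperationDef& definition,
                                const MaxUnpooling2DAttributes& attr);

}  // namespace gpu
}  // namespace tflite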
@@ -27,15 +27,6 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"

using ::tflite::gpu::BHWC;
using ::tflite::gpu::DataType;
using ::tflite::gpu::HW;
using ::tflite::gpu::MaxUnpooling2DAttributes;
using ::tflite::gpu::OperationType;
using ::tflite::gpu::TensorRef;
using ::tflite::gpu::metal::CompareVectors;
using ::tflite::gpu::metal::SingleOpModel;

@interface MaxUnpoolingMetalTest : XCTestCase
@end
@@ -43,39 +34,6 @@ using ::tflite::gpu::metal::SingleOpModel;
  tflite::gpu::metal::MetalExecutionEnvironment exec_env_;
}

- (void)testKernel2x2Stride2x2 {
  TensorRef<BHWC> input;
  input.type = DataType::FLOAT32;
  input.ref = 0;
  input.shape = BHWC(1, 2, 2, 1);

  TensorRef<BHWC> indices;
  indices.type = DataType::INT32;
  indices.ref = 1;
  indices.shape = BHWC(1, 2, 2, 1);

  TensorRef<BHWC> output;
  output.type = DataType::FLOAT32;
  output.ref = 2;
  output.shape = BHWC(1, 4, 4, 1);

  MaxUnpooling2DAttributes attr;
  attr.kernel = HW(2, 2);
  attr.padding.prepended = HW(0, 0);
  attr.padding.appended = HW(0, 0);
  attr.strides = HW(2, 2);

  SingleOpModel model({ToString(OperationType::MAX_UNPOOLING_2D), attr}, {input, indices},
                      {output});
  XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
  XCTAssertTrue(model.PopulateTensor(1, {0, 0, 0, 0}));
  auto status = model.Invoke();
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
  status =
      CompareVectors({1, 0, 2, 0, 0, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0}, model.GetOutput(0), 1e-6f);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testMaxUnpooling {
  auto status = MaxUnpoolingTest(&exec_env_);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
@@ -41,6 +41,7 @@ cc_library(
        "//tensorflow/lite/delegates/gpu/common/tasks:concat_z",
        "//tensorflow/lite/delegates/gpu/common/tasks:elementwise",
        "//tensorflow/lite/delegates/gpu/common/tasks:lstm",
        "//tensorflow/lite/delegates/gpu/common/tasks:max_unpooling",
        "//tensorflow/lite/delegates/gpu/common/tasks:padding",
        "//tensorflow/lite/delegates/gpu/common/tasks:prelu",
        "//tensorflow/lite/delegates/gpu/common/tasks:quantize_and_dequantize",
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/tasks/concat_z.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/elementwise.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/lstm.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/max_unpooling.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/padding.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/prelu.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/quantize_and_dequantize.h"
@@ -48,7 +49,6 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/metal/kernels/conv.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/pooling.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/winograd.h"
@@ -114,6 +114,11 @@ std::unique_ptr<GPUOperation> SelectLSTM(const OperationDef& op_def,
  return absl::make_unique<GPUOperation>(CreateLSTM(op_def, gpu_info));
}

std::unique_ptr<GPUOperation> SelectMaxUnpooling(
    const MaxUnpooling2DAttributes& attr, const OperationDef& op_def) {
  return absl::make_unique<GPUOperation>(CreateMaxUnpooling(op_def, attr));
}

void SelectPadding(const PadAttributes& attr, const OperationDef& op_def,
                   std::unique_ptr<GPUOperation>* ptr) {
  GPUOperation operation = CreatePadding(op_def, attr);
@@ -395,12 +400,10 @@ absl::Status GPUOperationFromNode(const GpuInfo& gpu_info,
      return absl::OkStatus();
    }
    case OperationType::MAX_UNPOOLING_2D: {
      auto gpu_op = MaxUnpooling(
          op_def,
          absl::any_cast<MaxUnpooling2DAttributes>(node.operation.attributes));
      gpu_operation->task_desc =
          absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
      break;
      auto attr =
          absl::any_cast<MaxUnpooling2DAttributes>(node.operation.attributes);
      gpu_operation->operation = SelectMaxUnpooling(attr, op_def);
      return absl::OkStatus();
    }
    case OperationType::MEAN: {
      auto attr = absl::any_cast<MeanAttributes>(node.operation.attributes);
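Since the hunk above shows the removed Metal-specific branch and its replacement back to back, the MAX_UNPOOLING_2D case presumably reads as follows after the change (a reconstruction from the lines added in this diff, not a verbatim quote of the resulting file):

    case OperationType::MAX_UNPOOLING_2D: {
      // Cast the node attributes and pick the shared max unpooling task
      // instead of building a Metal ComputeTaskDescriptor.
      auto attr =
          absl::any_cast<MaxUnpooling2DAttributes>(node.operation.attributes);
      gpu_operation->operation = SelectMaxUnpooling(attr, op_def);
      return absl::OkStatus();
    }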