Support Mean operation in the most naive way.
PiperOrigin-RevId: 291009742 Change-Id: I13d0afa5287af5418f76058b0f4706e5f68e7a53
This commit is contained in:
parent
49055157ec
commit
de90b76101
@ -1,4 +1,4 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -2385,6 +2385,51 @@ class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser {
|
||||
private:
|
||||
};
|
||||
|
||||
class MeanOperationParser : public TFLiteOperationParser {
|
||||
public:
|
||||
Status IsSupported(const TfLiteContext* context,
|
||||
const TfLiteNode* tflite_node,
|
||||
const TfLiteRegistration* registration) final {
|
||||
return CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
|
||||
/*outputs=*/1);
|
||||
}
|
||||
|
||||
Status Parse(const TfLiteNode* tflite_node,
|
||||
const TfLiteRegistration* registration, GraphFloat32* graph,
|
||||
ObjectReader* reader) final {
|
||||
auto* node = graph->NewNode();
|
||||
node->operation.type = ToString(OperationType::MEAN);
|
||||
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
||||
RETURN_IF_ERROR(reader->AddOutputs(node));
|
||||
|
||||
MeanAttributes attr;
|
||||
Tensor<Linear, DataType::INT32> channel;
|
||||
RETURN_IF_ERROR(reader->ReadTensor(1, &channel));
|
||||
for (int i = 0; i < channel.data.size(); i++) {
|
||||
std::string unsupported;
|
||||
switch (channel.data[i]) {
|
||||
case 1:
|
||||
attr.dims.insert(Axis::HEIGHT);
|
||||
break;
|
||||
case 2:
|
||||
attr.dims.insert(Axis::WIDTH);
|
||||
break;
|
||||
case 0:
|
||||
unsupported = unsupported.empty() ? "batch" : unsupported;
|
||||
ABSL_FALLTHROUGH_INTENDED;
|
||||
case 3:
|
||||
unsupported = unsupported.empty() ? "channels" : unsupported;
|
||||
ABSL_FALLTHROUGH_INTENDED;
|
||||
default:
|
||||
return UnimplementedError(
|
||||
absl::StrCat("Unsupported mean dimension: ", unsupported));
|
||||
}
|
||||
}
|
||||
node->operation.attributes = attr;
|
||||
return OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
class UnsupportedOperationParser : public TFLiteOperationParser {
|
||||
public:
|
||||
Status IsSupported(const TfLiteContext* context,
|
||||
@ -2433,6 +2478,8 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
|
||||
return absl::make_unique<LSTMOperationParser>();
|
||||
case kTfLiteBuiltinMaxPool2d:
|
||||
return absl::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
|
||||
case kTfLiteBuiltinMean:
|
||||
return absl::make_unique<MeanOperationParser>();
|
||||
case kTfLiteBuiltinMirrorPad:
|
||||
return absl::make_unique<PadOperationParser>(/*mirror_pad=*/true);
|
||||
case kTfLiteBuiltinMul:
|
||||
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -102,6 +102,8 @@ std::string ToString(enum OperationType op) {
|
||||
return "lstm";
|
||||
case OperationType::MAX_UNPOOLING_2D:
|
||||
return "max_unpooling";
|
||||
case OperationType::MEAN:
|
||||
return "mean";
|
||||
case OperationType::MUL:
|
||||
return "mul";
|
||||
case OperationType::MULTIPLY_SCALAR:
|
||||
@ -171,6 +173,7 @@ OperationType OperationTypeFromString(const std::string& name) {
|
||||
{"log", OperationType::LOG},
|
||||
{"lstm", OperationType::LSTM},
|
||||
{"max_unpooling", OperationType::MAX_UNPOOLING_2D},
|
||||
{"mean", OperationType::MEAN},
|
||||
{"mul", OperationType::MUL},
|
||||
{"multiply_scalar", OperationType::MULTIPLY_SCALAR},
|
||||
{"pad", OperationType::PAD},
|
||||
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -50,6 +50,7 @@ enum class OperationType {
|
||||
LOG,
|
||||
LSTM,
|
||||
MAX_UNPOOLING_2D,
|
||||
MEAN,
|
||||
MUL,
|
||||
MULTIPLY_SCALAR,
|
||||
PAD,
|
||||
@ -166,6 +167,11 @@ struct MaxUnpooling3DAttributes {
|
||||
Padding3D padding;
|
||||
};
|
||||
|
||||
struct MeanAttributes {
|
||||
// The vector of dimensions to calculate mean along.
|
||||
std::set<Axis> dims;
|
||||
};
|
||||
|
||||
struct ConcatAttributes {
|
||||
// Defines axis by which to concat on.
|
||||
Axis axis = Axis::UNKNOWN;
|
||||
|
@ -303,6 +303,36 @@ cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mean",
|
||||
srcs = ["mean.cc"],
|
||||
hdrs = ["mean.h"],
|
||||
deps = [
|
||||
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common:types",
|
||||
"//tensorflow/lite/delegates/gpu/gl:node_shader",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "mean_test",
|
||||
srcs = ["mean_test.cc"],
|
||||
tags = [
|
||||
"notap",
|
||||
"tflite_not_portable_ios",
|
||||
],
|
||||
deps = [
|
||||
":mean",
|
||||
":test_util",
|
||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mul",
|
||||
srcs = ["mul.cc"],
|
||||
@ -641,6 +671,7 @@ TFLITE_GPU_BINARY_RELEASE_OPERATORS = [
|
||||
"pooling",
|
||||
"prelu",
|
||||
"relu",
|
||||
"mean",
|
||||
"reshape",
|
||||
"slice",
|
||||
"softmax",
|
||||
|
81
tensorflow/lite/delegates/gpu/gl/kernels/mean.cc
Normal file
81
tensorflow/lite/delegates/gpu/gl/kernels/mean.cc
Normal file
@ -0,0 +1,81 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/gl/kernels/mean.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace gl {
|
||||
namespace {
|
||||
|
||||
class Mean : public NodeShader {
|
||||
public:
|
||||
Status GenerateCode(const GenerationContext& ctx,
|
||||
GeneratedCode* generated_code) const final {
|
||||
auto attr = absl::any_cast<MeanAttributes>(ctx.node->operation.attributes);
|
||||
if (attr.dims != std::set<Axis>({Axis::HEIGHT, Axis::WIDTH})) {
|
||||
return InvalidArgumentError(
|
||||
"Mean calculation is supported only for height and width.");
|
||||
}
|
||||
|
||||
auto input = ctx.graph->FindInputs(ctx.node->id)[0];
|
||||
|
||||
std::vector<Variable> parameters = {
|
||||
{"input_data_0_h", input->tensor.shape.h},
|
||||
{"input_data_0_w", input->tensor.shape.w}};
|
||||
|
||||
std::string source = R"(
|
||||
vec4 sum = vec4(0.0);
|
||||
float size = float($input_data_0_w$ * $input_data_0_h$);
|
||||
for (int w = 0; w < $input_data_0_w$; w++) {
|
||||
for (int h = 0; h < $input_data_0_h$; h++) {
|
||||
sum += $input_data_0[w, h, gid.z]$;
|
||||
}
|
||||
}
|
||||
value_0 = sum / size;
|
||||
)";
|
||||
*generated_code = {
|
||||
/*parameters=*/std::move(parameters),
|
||||
/*objects=*/{},
|
||||
/*shared_variables=*/{},
|
||||
/*workload=*/uint3(),
|
||||
/*workgroup=*/uint3(1, 1, 4),
|
||||
/*source_code=*/std::move(source),
|
||||
/*input=*/IOStructure::ONLY_DEFINITIONS,
|
||||
/*output=*/IOStructure::AUTO,
|
||||
};
|
||||
return OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<NodeShader> NewMeanNodeShader() {
|
||||
return absl::make_unique<Mean>();
|
||||
}
|
||||
|
||||
} // namespace gl
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
34
tensorflow/lite/delegates/gpu/gl/kernels/mean.h
Normal file
34
tensorflow/lite/delegates/gpu/gl/kernels/mean.h
Normal file
@ -0,0 +1,34 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_MEAN_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_MEAN_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/node_shader.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace gl {
|
||||
|
||||
std::unique_ptr<NodeShader> NewMeanNodeShader();
|
||||
|
||||
} // namespace gl
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_MEAN_H_
|
54
tensorflow/lite/delegates/gpu/gl/kernels/mean_test.cc
Normal file
54
tensorflow/lite/delegates/gpu/gl/kernels/mean_test.cc
Normal file
@ -0,0 +1,54 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/gl/kernels/mean.h"
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/kernels/test_util.h"
|
||||
|
||||
using ::testing::FloatNear;
|
||||
using ::testing::Pointwise;
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace gl {
|
||||
namespace {
|
||||
|
||||
TEST(MeanTest, Smoke) {
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 2, 2, 1);
|
||||
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 2;
|
||||
output.shape = BHWC(1, 1, 1, 1);
|
||||
|
||||
MeanAttributes attr;
|
||||
attr.dims = {Axis::HEIGHT, Axis::WIDTH};
|
||||
|
||||
SingleOpModel model({ToString(OperationType::MEAN), attr}, {input}, {output});
|
||||
ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 3.0, 4.0}));
|
||||
ASSERT_OK(model.Invoke(*NewMeanNodeShader()));
|
||||
EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2.5}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gl
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -35,6 +35,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/gl/kernels/elementwise.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/kernels/lstm.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/kernels/mean.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/kernels/mul.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/kernels/pad.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/kernels/pooling.h"
|
||||
@ -80,6 +81,7 @@ class Registry : public NodeShader {
|
||||
insert_op(Type::DEPTHWISE_CONVOLUTION, NewDepthwiseConvolutionNodeShader);
|
||||
insert_op(Type::FULLY_CONNECTED, NewFullyConnectedNodeShader);
|
||||
insert_op(Type::LSTM, NewLstmNodeShader);
|
||||
insert_op(Type::MEAN, NewMeanNodeShader);
|
||||
insert_op(Type::MULTIPLY_SCALAR, NewMultiplyScalarNodeShader);
|
||||
insert_op(Type::PAD, NewPadNodeShader);
|
||||
insert_op(Type::POOLING_2D, NewPoolingNodeShader);
|
||||
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -33,6 +33,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/mean.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/mul.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/padding.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/pooling.h"
|
||||
@ -194,6 +195,10 @@ Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
|
||||
node_id, inputs[0], inputs[1], outputs[0],
|
||||
absl::any_cast<MaxUnpooling2DAttributes>(node->operation.attributes));
|
||||
break;
|
||||
case OperationType::MEAN:
|
||||
*tasks = Mean(node_id, inputs[0], outputs[0],
|
||||
absl::any_cast<MeanAttributes>(node->operation.attributes));
|
||||
break;
|
||||
case OperationType::MULTIPLY_SCALAR:
|
||||
*tasks = Multiply(
|
||||
node_id, inputs[0], outputs[0],
|
||||
|
@ -20,6 +20,7 @@ cc_library(
|
||||
":elementwise",
|
||||
":fully_connected",
|
||||
":max_unpooling",
|
||||
":mean",
|
||||
":mul",
|
||||
":padding",
|
||||
":pooling",
|
||||
@ -313,6 +314,44 @@ ios_unit_test(
|
||||
deps = [":max_unpooling_test_lib"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mean",
|
||||
srcs = ["mean.cc"],
|
||||
hdrs = ["mean.h"],
|
||||
deps = [
|
||||
":util",
|
||||
"//tensorflow/lite/delegates/gpu/common:model",
|
||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||
"//tensorflow/lite/delegates/gpu/common:tensor",
|
||||
"//tensorflow/lite/delegates/gpu/common:util",
|
||||
"//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
|
||||
"//tensorflow/lite/delegates/gpu/metal:runtime_options",
|
||||
"@com_google_absl//absl/types:variant",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "mean_test_lib",
|
||||
testonly = 1,
|
||||
srcs = ["mean_test.mm"],
|
||||
sdk_frameworks = ["XCTest"],
|
||||
deps = [
|
||||
":mean",
|
||||
":test_util",
|
||||
],
|
||||
)
|
||||
|
||||
ios_unit_test(
|
||||
name = "mean_test",
|
||||
testonly = 1,
|
||||
minimum_os_version = "10.0",
|
||||
tags = [
|
||||
"notap",
|
||||
"tflite_not_portable_android",
|
||||
],
|
||||
deps = [":mean_test_lib"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mul",
|
||||
srcs = ["mul.cc"],
|
||||
|
137
tensorflow/lite/delegates/gpu/metal/kernels/mean.cc
Normal file
137
tensorflow/lite/delegates/gpu/metal/kernels/mean.cc
Normal file
@ -0,0 +1,137 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/mean.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "absl/types/variant.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace metal {
|
||||
|
||||
std::string GetMeanCode() {
|
||||
std::string shader_source = R"(
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
struct uniforms {
|
||||
int4 src_size;
|
||||
int4 dst_size;
|
||||
};
|
||||
|
||||
$0
|
||||
kernel void ComputeFunction(
|
||||
$1
|
||||
uint3 gid[[thread_position_in_grid]]) {
|
||||
if (static_cast<int>(gid.x) >= params.dst_size.x ||
|
||||
static_cast<int>(gid.y) >= params.dst_size.y ||
|
||||
static_cast<int>(gid.z) >= params.dst_size.z) {
|
||||
return;
|
||||
}
|
||||
|
||||
float4 sum = float4(0.0);
|
||||
float size = float( params.src_size.x * params.src_size.y);
|
||||
for (int w = 0; w < params.src_size.x; w++) {
|
||||
for (int h = 0; h < params.src_size.y; h++) {
|
||||
const int buffer_index =
|
||||
(gid.z * params.src_size.y + h) * params.src_size.x + w;
|
||||
sum += src_buffer[buffer_index];
|
||||
}
|
||||
}
|
||||
sum /= size;
|
||||
const int linear_index =
|
||||
(gid.z * params.dst_size.y + int(gid.y)) * params.dst_size.x + int(gid.x);
|
||||
|
||||
FLT4 value = FLT4(sum);
|
||||
$2
|
||||
output_buffer[linear_index] = value;
|
||||
}
|
||||
)";
|
||||
return shader_source;
|
||||
}
|
||||
|
||||
std::vector<ComputeTaskDescriptorPtr> Mean(int id, ValueId input_id,
|
||||
ValueId output_id,
|
||||
const MeanAttributes& attr) {
|
||||
if (attr.dims != std::set<Axis>({Axis::HEIGHT, Axis::WIDTH})) {
|
||||
// Mean calculation is supported only for height and width
|
||||
return {};
|
||||
}
|
||||
|
||||
auto desc = std::make_shared<ComputeTaskDescriptor>();
|
||||
desc->id = id;
|
||||
desc->is_linkable = false;
|
||||
std::string code = GetMeanCode();
|
||||
desc->shader_source = code;
|
||||
|
||||
desc->input_buffers = {
|
||||
{input_id, "device FLT4* const src_buffer"},
|
||||
};
|
||||
|
||||
desc->output_buffer = {output_id, "device FLT4* output_buffer",
|
||||
[input_id](const std::map<ValueId, BHWC>& buffers) {
|
||||
const auto& input_dimension =
|
||||
buffers.find(input_id)->second;
|
||||
return BHWC(1, 1, 1, input_dimension.c);
|
||||
}};
|
||||
desc->uniform_buffers = {
|
||||
{"constant uniforms& params",
|
||||
[input_id, output_id](const std::map<ValueId, BHWC>& buffers) {
|
||||
const auto& dimension = buffers.find(input_id)->second;
|
||||
const auto& output_dimension = buffers.find(output_id)->second;
|
||||
std::vector<int> uniform_params = {
|
||||
dimension.w,
|
||||
dimension.h,
|
||||
IntegralDivideRoundUp(dimension.c, 4),
|
||||
0,
|
||||
output_dimension.w,
|
||||
output_dimension.h,
|
||||
IntegralDivideRoundUp(dimension.c, 4),
|
||||
0};
|
||||
return GetByteBuffer(uniform_params);
|
||||
}},
|
||||
};
|
||||
|
||||
desc->resize_function = [output_id](const std::map<ValueId, BHWC>& buffers) {
|
||||
BHWC dst_shape = buffers.find(output_id)->second;
|
||||
const uint3 grid =
|
||||
uint3(dst_shape.w, dst_shape.h, IntegralDivideRoundUp(dst_shape.c, 4));
|
||||
const uint3 groups_size = GetWorkGroupSizeForGrid(grid);
|
||||
int groups_x = IntegralDivideRoundUp(grid.x, groups_size.x);
|
||||
int groups_y = IntegralDivideRoundUp(grid.y, groups_size.y);
|
||||
int groups_z = IntegralDivideRoundUp(grid.z, groups_size.z);
|
||||
return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z});
|
||||
};
|
||||
return {desc};
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
36
tensorflow/lite/delegates/gpu/metal/kernels/mean.h
Normal file
36
tensorflow/lite/delegates/gpu/metal/kernels/mean.h
Normal file
@ -0,0 +1,36 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MEAN_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MEAN_H_
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace metal {
|
||||
|
||||
std::vector<ComputeTaskDescriptorPtr> Mean(int id, ValueId input_id,
|
||||
ValueId output_id,
|
||||
const MeanAttributes& attr);
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MEAN_H_
|
70
tensorflow/lite/delegates/gpu/metal/kernels/mean_test.mm
Normal file
70
tensorflow/lite/delegates/gpu/metal/kernels/mean_test.mm
Normal file
@ -0,0 +1,70 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/mean.h"
|
||||
|
||||
#import <XCTest/XCTest.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
|
||||
|
||||
using ::tflite::gpu::Axis;
|
||||
using ::tflite::gpu::BHWC;
|
||||
using ::tflite::gpu::DataType;
|
||||
using ::tflite::gpu::OperationType;
|
||||
using ::tflite::gpu::MeanAttributes;
|
||||
using ::tflite::gpu::TensorRef;
|
||||
using ::tflite::gpu::metal::CompareVectors;
|
||||
using ::tflite::gpu::metal::SingleOpModel;
|
||||
|
||||
@interface MeanTest : XCTestCase
|
||||
@end
|
||||
|
||||
@implementation MeanTest
|
||||
- (void)setUp {
|
||||
[super setUp];
|
||||
}
|
||||
|
||||
- (void)testMeanSmoke {
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 2, 2, 1);
|
||||
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 1;
|
||||
output.shape = BHWC(1, 1, 1, 1);
|
||||
|
||||
MeanAttributes attr;
|
||||
attr.dims = {Axis::HEIGHT, Axis::WIDTH};
|
||||
|
||||
SingleOpModel model({ToString(OperationType::MEAN), attr}, {input}, {output});
|
||||
XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 3.0, 4.0}));
|
||||
auto status = model.Invoke();
|
||||
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
|
||||
status = CompareVectors({2.5}, model.GetOutput(0), 1e-6f);
|
||||
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
|
||||
}
|
||||
|
||||
@end
|
Loading…
x
Reference in New Issue
Block a user