diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index 1cc49af52b9..8b798a1df5f 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -2385,6 +2385,51 @@ class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser { private: }; +class MeanOperationParser : public TFLiteOperationParser { + public: + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + return CheckInputsOutputs(context, tflite_node, /*inputs=*/1, + /*outputs=*/1); + } + + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { + auto* node = graph->NewNode(); + node->operation.type = ToString(OperationType::MEAN); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + RETURN_IF_ERROR(reader->AddOutputs(node)); + + MeanAttributes attr; + Tensor channel; + RETURN_IF_ERROR(reader->ReadTensor(1, &channel)); + for (int i = 0; i < channel.data.size(); i++) { + std::string unsupported; + switch (channel.data[i]) { + case 1: + attr.dims.insert(Axis::HEIGHT); + break; + case 2: + attr.dims.insert(Axis::WIDTH); + break; + case 0: + unsupported = unsupported.empty() ? "batch" : unsupported; + ABSL_FALLTHROUGH_INTENDED; + case 3: + unsupported = unsupported.empty() ? "channels" : unsupported; + ABSL_FALLTHROUGH_INTENDED; + default: + return UnimplementedError( + absl::StrCat("Unsupported mean dimension: ", unsupported)); + } + } + node->operation.attributes = attr; + return OkStatus(); + } +}; + class UnsupportedOperationParser : public TFLiteOperationParser { public: Status IsSupported(const TfLiteContext* context, @@ -2433,6 +2478,8 @@ std::unique_ptr NewOperationParser( return absl::make_unique(); case kTfLiteBuiltinMaxPool2d: return absl::make_unique(PoolingType::MAX); + case kTfLiteBuiltinMean: + return absl::make_unique(); case kTfLiteBuiltinMirrorPad: return absl::make_unique(/*mirror_pad=*/true); case kTfLiteBuiltinMul: diff --git a/tensorflow/lite/delegates/gpu/common/operations.cc b/tensorflow/lite/delegates/gpu/common/operations.cc index 0ccfad4014b..38412aeec08 100644 --- a/tensorflow/lite/delegates/gpu/common/operations.cc +++ b/tensorflow/lite/delegates/gpu/common/operations.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -102,6 +102,8 @@ std::string ToString(enum OperationType op) { return "lstm"; case OperationType::MAX_UNPOOLING_2D: return "max_unpooling"; + case OperationType::MEAN: + return "mean"; case OperationType::MUL: return "mul"; case OperationType::MULTIPLY_SCALAR: @@ -171,6 +173,7 @@ OperationType OperationTypeFromString(const std::string& name) { {"log", OperationType::LOG}, {"lstm", OperationType::LSTM}, {"max_unpooling", OperationType::MAX_UNPOOLING_2D}, + {"mean", OperationType::MEAN}, {"mul", OperationType::MUL}, {"multiply_scalar", OperationType::MULTIPLY_SCALAR}, {"pad", OperationType::PAD}, diff --git a/tensorflow/lite/delegates/gpu/common/operations.h b/tensorflow/lite/delegates/gpu/common/operations.h index 5698fe5c57b..f87a64382cb 100644 --- a/tensorflow/lite/delegates/gpu/common/operations.h +++ b/tensorflow/lite/delegates/gpu/common/operations.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -50,6 +50,7 @@ enum class OperationType { LOG, LSTM, MAX_UNPOOLING_2D, + MEAN, MUL, MULTIPLY_SCALAR, PAD, @@ -166,6 +167,11 @@ struct MaxUnpooling3DAttributes { Padding3D padding; }; +struct MeanAttributes { + // The vector of dimensions to calculate mean along. + std::set dims; +}; + struct ConcatAttributes { // Defines axis by which to concat on. Axis axis = Axis::UNKNOWN; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/BUILD b/tensorflow/lite/delegates/gpu/gl/kernels/BUILD index afa31bfa7ce..45d27923aff 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/BUILD +++ b/tensorflow/lite/delegates/gpu/gl/kernels/BUILD @@ -303,6 +303,36 @@ cc_test( ], ) +cc_library( + name = "mean", + srcs = ["mean.cc"], + hdrs = ["mean.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:node_shader", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "mean_test", + srcs = ["mean_test.cc"], + tags = [ + "notap", + "tflite_not_portable_ios", + ], + deps = [ + ":mean", + ":test_util", + "//tensorflow/lite/delegates/gpu/common:operations", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "mul", srcs = ["mul.cc"], @@ -641,6 +671,7 @@ TFLITE_GPU_BINARY_RELEASE_OPERATORS = [ "pooling", "prelu", "relu", + "mean", "reshape", "slice", "softmax", diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/mean.cc b/tensorflow/lite/delegates/gpu/gl/kernels/mean.cc new file mode 100644 index 00000000000..aaceb61b4e1 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/mean.cc @@ -0,0 +1,81 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/kernels/mean.h" + +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +class Mean : public NodeShader { + public: + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { + auto attr = absl::any_cast(ctx.node->operation.attributes); + if (attr.dims != std::set({Axis::HEIGHT, Axis::WIDTH})) { + return InvalidArgumentError( + "Mean calculation is supported only for height and width."); + } + + auto input = ctx.graph->FindInputs(ctx.node->id)[0]; + + std::vector parameters = { + {"input_data_0_h", input->tensor.shape.h}, + {"input_data_0_w", input->tensor.shape.w}}; + + std::string source = R"( + vec4 sum = vec4(0.0); + float size = float($input_data_0_w$ * $input_data_0_h$); + for (int w = 0; w < $input_data_0_w$; w++) { + for (int h = 0; h < $input_data_0_h$; h++) { + sum += $input_data_0[w, h, gid.z]$; + } + } + value_0 = sum / size; + )"; + *generated_code = { + /*parameters=*/std::move(parameters), + /*objects=*/{}, + /*shared_variables=*/{}, + /*workload=*/uint3(), + /*workgroup=*/uint3(1, 1, 4), + /*source_code=*/std::move(source), + /*input=*/IOStructure::ONLY_DEFINITIONS, + /*output=*/IOStructure::AUTO, + }; + return OkStatus(); + } +}; + +} // namespace + +std::unique_ptr NewMeanNodeShader() { + return absl::make_unique(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/mean.h b/tensorflow/lite/delegates/gpu/gl/kernels/mean.h new file mode 100644 index 00000000000..af2628fbb25 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/mean.h @@ -0,0 +1,34 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_MEAN_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_MEAN_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" + +namespace tflite { +namespace gpu { +namespace gl { + +std::unique_ptr NewMeanNodeShader(); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_MEAN_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/mean_test.cc b/tensorflow/lite/delegates/gpu/gl/kernels/mean_test.cc new file mode 100644 index 00000000000..63569ff8b68 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/mean_test.cc @@ -0,0 +1,54 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/kernels/mean.h" + +#include +#include +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/gl/kernels/test_util.h" + +using ::testing::FloatNear; +using ::testing::Pointwise; + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +TEST(MeanTest, Smoke) { + TensorRef input; + input.type = DataType::FLOAT32; + input.ref = 0; + input.shape = BHWC(1, 2, 2, 1); + + TensorRef output; + output.type = DataType::FLOAT32; + output.ref = 2; + output.shape = BHWC(1, 1, 1, 1); + + MeanAttributes attr; + attr.dims = {Axis::HEIGHT, Axis::WIDTH}; + + SingleOpModel model({ToString(OperationType::MEAN), attr}, {input}, {output}); + ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 3.0, 4.0})); + ASSERT_OK(model.Invoke(*NewMeanNodeShader())); + EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2.5})); +} + +} // namespace +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc b/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc index 3744a772530..4743342b36c 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/gl/kernels/elementwise.h" #include "tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.h" #include "tensorflow/lite/delegates/gpu/gl/kernels/lstm.h" +#include "tensorflow/lite/delegates/gpu/gl/kernels/mean.h" #include "tensorflow/lite/delegates/gpu/gl/kernels/mul.h" #include "tensorflow/lite/delegates/gpu/gl/kernels/pad.h" #include "tensorflow/lite/delegates/gpu/gl/kernels/pooling.h" @@ -80,6 +81,7 @@ class Registry : public NodeShader { insert_op(Type::DEPTHWISE_CONVOLUTION, NewDepthwiseConvolutionNodeShader); insert_op(Type::FULLY_CONNECTED, NewFullyConnectedNodeShader); insert_op(Type::LSTM, NewLstmNodeShader); + insert_op(Type::MEAN, NewMeanNodeShader); insert_op(Type::MULTIPLY_SCALAR, NewMultiplyScalarNodeShader); insert_op(Type::PAD, NewPadNodeShader); insert_op(Type::POOLING_2D, NewPoolingNodeShader); diff --git a/tensorflow/lite/delegates/gpu/metal/api.cc b/tensorflow/lite/delegates/gpu/metal/api.cc index c6977b21f57..4885a695de2 100644 --- a/tensorflow/lite/delegates/gpu/metal/api.cc +++ b/tensorflow/lite/delegates/gpu/metal/api.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling.h" +#include "tensorflow/lite/delegates/gpu/metal/kernels/mean.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/mul.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/padding.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/pooling.h" @@ -194,6 +195,10 @@ Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, node_id, inputs[0], inputs[1], outputs[0], absl::any_cast(node->operation.attributes)); break; + case OperationType::MEAN: + *tasks = Mean(node_id, inputs[0], outputs[0], + absl::any_cast(node->operation.attributes)); + break; case OperationType::MULTIPLY_SCALAR: *tasks = Multiply( node_id, inputs[0], outputs[0], diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD index 8cf7dba27e4..0be2017012c 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD +++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD @@ -20,6 +20,7 @@ cc_library( ":elementwise", ":fully_connected", ":max_unpooling", + ":mean", ":mul", ":padding", ":pooling", @@ -313,6 +314,44 @@ ios_unit_test( deps = [":max_unpooling_test_lib"], ) +cc_library( + name = "mean", + srcs = ["mean.cc"], + hdrs = ["mean.h"], + deps = [ + ":util", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:tensor", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor", + "//tensorflow/lite/delegates/gpu/metal:runtime_options", + "@com_google_absl//absl/types:variant", + ], +) + +objc_library( + name = "mean_test_lib", + testonly = 1, + srcs = ["mean_test.mm"], + sdk_frameworks = ["XCTest"], + deps = [ + ":mean", + ":test_util", + ], +) + +ios_unit_test( + name = "mean_test", + testonly = 1, + minimum_os_version = "10.0", + tags = [ + "notap", + "tflite_not_portable_android", + ], + deps = [":mean_test_lib"], +) + cc_library( name = "mul", srcs = ["mul.cc"], diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/mean.cc b/tensorflow/lite/delegates/gpu/metal/kernels/mean.cc new file mode 100644 index 00000000000..8c888d0bca1 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/mean.cc @@ -0,0 +1,137 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/metal/kernels/mean.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/substitute.h" +#include "absl/types/variant.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/kernels/util.h" +#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" + +namespace tflite { +namespace gpu { +namespace metal { + +std::string GetMeanCode() { + std::string shader_source = R"( + #include + using namespace metal; + struct uniforms { + int4 src_size; + int4 dst_size; + }; + + $0 + kernel void ComputeFunction( + $1 + uint3 gid[[thread_position_in_grid]]) { + if (static_cast(gid.x) >= params.dst_size.x || + static_cast(gid.y) >= params.dst_size.y || + static_cast(gid.z) >= params.dst_size.z) { + return; + } + + float4 sum = float4(0.0); + float size = float( params.src_size.x * params.src_size.y); + for (int w = 0; w < params.src_size.x; w++) { + for (int h = 0; h < params.src_size.y; h++) { + const int buffer_index = + (gid.z * params.src_size.y + h) * params.src_size.x + w; + sum += src_buffer[buffer_index]; + } + } + sum /= size; + const int linear_index = + (gid.z * params.dst_size.y + int(gid.y)) * params.dst_size.x + int(gid.x); + + FLT4 value = FLT4(sum); + $2 + output_buffer[linear_index] = value; + } + )"; + return shader_source; +} + +std::vector Mean(int id, ValueId input_id, + ValueId output_id, + const MeanAttributes& attr) { + if (attr.dims != std::set({Axis::HEIGHT, Axis::WIDTH})) { + // Mean calculation is supported only for height and width + return {}; + } + + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + std::string code = GetMeanCode(); + desc->shader_source = code; + + desc->input_buffers = { + {input_id, "device FLT4* const src_buffer"}, + }; + + desc->output_buffer = {output_id, "device FLT4* output_buffer", + [input_id](const std::map& buffers) { + const auto& input_dimension = + buffers.find(input_id)->second; + return BHWC(1, 1, 1, input_dimension.c); + }}; + desc->uniform_buffers = { + {"constant uniforms& params", + [input_id, output_id](const std::map& buffers) { + const auto& dimension = buffers.find(input_id)->second; + const auto& output_dimension = buffers.find(output_id)->second; + std::vector uniform_params = { + dimension.w, + dimension.h, + IntegralDivideRoundUp(dimension.c, 4), + 0, + output_dimension.w, + output_dimension.h, + IntegralDivideRoundUp(dimension.c, 4), + 0}; + return GetByteBuffer(uniform_params); + }}, + }; + + desc->resize_function = [output_id](const std::map& buffers) { + BHWC dst_shape = buffers.find(output_id)->second; + const uint3 grid = + uint3(dst_shape.w, dst_shape.h, IntegralDivideRoundUp(dst_shape.c, 4)); + const uint3 groups_size = GetWorkGroupSizeForGrid(grid); + int groups_x = IntegralDivideRoundUp(grid.x, groups_size.x); + int groups_y = IntegralDivideRoundUp(grid.y, groups_size.y); + int groups_z = IntegralDivideRoundUp(grid.z, groups_size.z); + return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); + }; + return {desc}; +} + +} // namespace metal +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/mean.h b/tensorflow/lite/delegates/gpu/metal/kernels/mean.h new file mode 100644 index 00000000000..5f6a0493181 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/mean.h @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MEAN_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MEAN_H_ + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" + +namespace tflite { +namespace gpu { +namespace metal { + +std::vector Mean(int id, ValueId input_id, + ValueId output_id, + const MeanAttributes& attr); + +} // namespace metal +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MEAN_H_ diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/mean_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/mean_test.mm new file mode 100644 index 00000000000..69eed7d86b0 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/mean_test.mm @@ -0,0 +1,70 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/metal/kernels/mean.h" + +#import + +#include + +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h" +#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" + +using ::tflite::gpu::Axis; +using ::tflite::gpu::BHWC; +using ::tflite::gpu::DataType; +using ::tflite::gpu::OperationType; +using ::tflite::gpu::MeanAttributes; +using ::tflite::gpu::TensorRef; +using ::tflite::gpu::metal::CompareVectors; +using ::tflite::gpu::metal::SingleOpModel; + +@interface MeanTest : XCTestCase +@end + +@implementation MeanTest +- (void)setUp { + [super setUp]; +} + +- (void)testMeanSmoke { + TensorRef input; + input.type = DataType::FLOAT32; + input.ref = 0; + input.shape = BHWC(1, 2, 2, 1); + + TensorRef output; + output.type = DataType::FLOAT32; + output.ref = 1; + output.shape = BHWC(1, 1, 1, 1); + + MeanAttributes attr; + attr.dims = {Axis::HEIGHT, Axis::WIDTH}; + + SingleOpModel model({ToString(OperationType::MEAN), attr}, {input}, {output}); + XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 3.0, 4.0})); + auto status = model.Invoke(); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + status = CompareVectors({2.5}, model.GetOutput(0), 1e-6f); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); +} + +@end