Added support of HWC constant tensor in ElementwiseWithOneInputAndConstantArguent.

Remove Mul (replaced by elementwise).
Changed Add (some cases handled in elementwise).

PiperOrigin-RevId: 315000693
Change-Id: Icd981170d11f17df418361a554dc08822f3ee273
This commit is contained in:
Raman Sarokin 2020-06-05 14:49:54 -07:00 committed by TensorFlower Gardener
parent ffc8ddf6bf
commit 9c9e961174
10 changed files with 122 additions and 388 deletions

View File

@ -35,7 +35,6 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/mean.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/mul.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/padding.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/pooling.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/prelu.h"
@ -55,22 +54,6 @@ namespace gpu {
namespace metal {
namespace {
bool IsWidthBroadcastedForSecondInput(const std::vector<Value*>& inputs) {
return inputs.size() == 2 &&
inputs[0]->tensor.shape.w != inputs[1]->tensor.shape.w &&
inputs[1]->tensor.shape.w == 1;
}
bool IsHeightBroadcastedForSecondInput(const std::vector<Value*>& inputs) {
return inputs.size() == 2 &&
inputs[0]->tensor.shape.h != inputs[1]->tensor.shape.h &&
inputs[1]->tensor.shape.h == 1;
}
bool IsChannelsBroadcastedForSecondInput(const std::vector<Value*>& inputs) {
return inputs.size() == 2 &&
inputs[0]->tensor.shape.c != inputs[1]->tensor.shape.c &&
inputs[1]->tensor.shape.c == 1;
}
std::vector<ComputeTaskDescriptorPtr> SelectDepthWiseConv(
int id, ValueId input_id, ValueId output_id,
const DepthwiseConvolution2DAttributes& attr,
@ -205,26 +188,22 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
auto op_type = OperationTypeFromString(node->operation.type);
switch (op_type) {
case OperationType::ADD: {
const auto srcs = graph.FindInputs(node_id);
ElementwiseBroadcastSettings broadcast;
broadcast.width = IsWidthBroadcastedForSecondInput(srcs);
broadcast.height = IsHeightBroadcastedForSecondInput(srcs);
broadcast.channels = IsChannelsBroadcastedForSecondInput(srcs);
if (broadcast.width || broadcast.height || broadcast.channels) {
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type,
broadcast);
} else {
const AddAttributes& attr =
absl::any_cast<AddAttributes>(node->operation.attributes);
const auto* hwc_tensor =
absl::get_if<tflite::gpu::Tensor<HWC, DataType::FLOAT32>>(
&attr.param);
if (hwc_tensor) {
if (inputs.size() == 1) {
if (node->operation.attributes.has_value()) {
auto attr = absl::any_cast<AddAttributes>(node->operation.attributes);
*tasks = ElementwiseWithOneInputAndConstantArguent(
node_id, inputs[0], outputs[0], options, op_type, attr.param);
} else {
return absl::UnimplementedError(
"Unsupported op: " + node->operation.type +
", no support of HWC constant tensor");
"Missing attributes for single input op: " +
node->operation.type);
}
*tasks = Add(node_id, inputs, outputs[0], attr, options);
} else if (inputs.size() == 2) {
const auto srcs = graph.FindInputs(node_id);
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0],
srcs[1]->tensor.shape, op_type);
} else { // more than 2 inputs
*tasks = Add(node_id, inputs, outputs[0], options);
}
break;
}
@ -309,31 +288,21 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
absl::any_cast<MeanAttributes>(node->operation.attributes));
break;
case OperationType::MUL:
if (node->operation.attributes.has_value()) {
const MultiplyAttributes& attr =
absl::any_cast<MultiplyAttributes>(node->operation.attributes);
const auto* hwc_tensor =
absl::get_if<tflite::gpu::Tensor<HWC, DataType::FLOAT32>>(
&attr.param);
if (hwc_tensor) {
return absl::UnimplementedError(
"Unsupported op: " + node->operation.type +
", no support of HWC constant tensor");
}
*tasks = Multiply(node_id, inputs[0], outputs[0], attr, options);
} else {
if (inputs.size() == 2) {
const auto srcs = graph.FindInputs(node_id);
ElementwiseBroadcastSettings broadcast;
broadcast.width = IsWidthBroadcastedForSecondInput(srcs);
broadcast.height = IsHeightBroadcastedForSecondInput(srcs);
broadcast.channels = IsChannelsBroadcastedForSecondInput(srcs);
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0],
op_type, broadcast);
if (inputs.size() == 1) {
if (node->operation.attributes.has_value()) {
auto attr =
absl::any_cast<MultiplyAttributes>(node->operation.attributes);
*tasks = ElementwiseWithOneInputAndConstantArguent(
node_id, inputs[0], outputs[0], options, op_type, attr.param);
} else {
return absl::UnimplementedError(
"No support of multiply with more than 2 inputs");
"Missing attributes for single input op: " +
node->operation.type);
}
} else if (inputs.size() == 2) {
const auto srcs = graph.FindInputs(node_id);
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0],
srcs[1]->tensor.shape, op_type);
}
break;
case OperationType::PAD: {
@ -413,27 +382,21 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
case OperationType::POW:
case OperationType::SQUARED_DIFF:
case OperationType::SUB: {
const ElementwiseAttributes* attr =
absl::any_cast<ElementwiseAttributes>(&node->operation.attributes);
if (attr) {
const auto* hwc_tensor =
absl::get_if<tflite::gpu::Tensor<HWC, DataType::FLOAT32>>(
&attr->param);
if (hwc_tensor) {
if (inputs.size() == 1) {
if (node->operation.attributes.has_value()) {
auto attr =
absl::any_cast<ElementwiseAttributes>(node->operation.attributes);
*tasks = ElementwiseWithOneInputAndConstantArguent(
node_id, inputs[0], outputs[0], options, op_type, attr.param);
} else {
return absl::UnimplementedError(
"Unsupported op: " + node->operation.type +
", no support of HWC constant tensor");
"Missing attributes for single input op: " +
node->operation.type);
}
*tasks = ElementwiseWithOneInputAndConstantArguent(
node_id, inputs[0], outputs[0], options, op_type, *attr);
} else {
} else if (inputs.size() == 2) {
const auto srcs = graph.FindInputs(node_id);
ElementwiseBroadcastSettings broadcast;
broadcast.width = IsWidthBroadcastedForSecondInput(srcs);
broadcast.height = IsHeightBroadcastedForSecondInput(srcs);
broadcast.channels = IsChannelsBroadcastedForSecondInput(srcs);
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type,
broadcast);
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0],
srcs[1]->tensor.shape, op_type);
}
} break;
case OperationType::BATCH_NORMALIZATION:

View File

@ -27,7 +27,6 @@ cc_library(
":fully_connected",
":max_unpooling",
":mean",
":mul",
":padding",
":pooling",
":prelu",
@ -228,6 +227,7 @@ cc_library(
srcs = ["elementwise.cc"],
hdrs = ["elementwise.h"],
deps = [
"//tensorflow/lite/delegates/gpu/common:convert",
"//tensorflow/lite/delegates/gpu/common:model",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:shape",
@ -381,46 +381,6 @@ ios_unit_test(
deps = [":mean_test_lib"],
)
cc_library(
name = "mul",
srcs = ["mul.cc"],
hdrs = ["mul.h"],
deps = [
"//tensorflow/lite/delegates/gpu/common:model",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:shape",
"//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/common:util",
"//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
"//tensorflow/lite/delegates/gpu/metal:runtime_options",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:variant",
],
)
objc_library(
name = "mul_test_lib",
testonly = 1,
srcs = ["mul_test.mm"],
sdk_frameworks = ["XCTest"],
deps = [
":mul",
":test_util",
],
)
ios_unit_test(
name = "mul_test",
testonly = 1,
minimum_os_version = "9.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = tf_gpu_tests_tags() + [
"notap",
"tflite_not_portable_android",
],
deps = [":mul_test_lib"],
)
cc_library(
name = "padding",
srcs = ["padding.cc"],
@ -953,7 +913,6 @@ objc_library(
"elementwise_test.mm",
"fully_connected_test.mm",
"max_unpooling_test.mm",
"mul_test.mm",
"padding_test.mm",
"pooling_test.mm",
"prelu_test.mm",

View File

@ -52,39 +52,9 @@ std::string GetAddTableCodeFused(int src_count) {
std::vector<ComputeTaskDescriptorPtr> Add(int id,
const std::vector<ValueId> input_ids,
ValueId output_id,
const AddAttributes& attr,
const RuntimeOptions& options) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
// Add scalar
const float* add_value = absl::get_if<float>(&attr.param);
if (add_value) {
desc->is_linkable = true;
desc->shader_source =
R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid) {
return value + )" +
std::to_string(*add_value) + ";}";
desc->input_buffers = {{input_ids[0]}};
desc->output_buffer = {output_id};
return {desc};
}
// Add vector
auto broadcast = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
if (broadcast) {
desc->is_linkable = true;
desc->shader_source =
R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid,
device FLT4* const broadcast) { return value + broadcast[gid.z]; })";
desc->input_buffers = {{input_ids[0]}};
desc->output_buffer = {output_id};
desc->immutable_buffers = {
{"device FLT4* const",
GetByteBufferConverted(broadcast->data, options.storage_precision)},
};
return {desc};
}
desc->is_linkable = true;
desc->is_associative_op = true;
desc->shader_source = GetAddTableCodeFused(input_ids.size() - 1);

View File

@ -27,11 +27,9 @@ namespace tflite {
namespace gpu {
namespace metal {
// Add with broadcast.
std::vector<ComputeTaskDescriptorPtr> Add(int id,
const std::vector<ValueId> input_ids,
ValueId output_id,
const AddAttributes& attr,
const RuntimeOptions& options);
} // namespace metal

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/convert.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
@ -77,20 +78,20 @@ std::string TwoInputFunctor(OperationType op_type, const std::string& value0,
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithTwoInputs(
int id, std::vector<ValueId> input_ids, ValueId output_id,
OperationType op_type, const ElementwiseBroadcastSettings& settings) {
const BHWC& second_shape, OperationType op_type) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = true;
const std::string x_coord = settings.width ? "0" : "int(gid.x)";
const std::string y_coord = settings.height ? "0" : "int(gid.y)";
const std::string s_coord = settings.channels ? "0" : "int(gid.z)";
const std::string x_coord = second_shape.w == 1 ? "0" : "int(gid.x)";
const std::string y_coord = second_shape.h == 1 ? "0" : "int(gid.y)";
const std::string s_coord = second_shape.c == 1 ? "0" : "int(gid.z)";
std::string code =
"FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, device FLT4* "
"const second_tensor, int2 second_size) {\n";
code += " int second_index = (" + s_coord + " * second_size.y + " + y_coord +
") * second_size.x + " + x_coord + ";\n";
code += " FLT4 src_1 = second_tensor[second_index];\n";
if (settings.channels) {
if (second_shape.c == 1) {
code += " src_1.y = src_1.x;\n";
code += " src_1.z = src_1.x;\n";
code += " src_1.w = src_1.x;\n";
@ -138,13 +139,13 @@ std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInput(
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInputAndConstantArguent(
int id, ValueId input_id, ValueId output_id, const RuntimeOptions& options,
OperationType op_type, const ElementwiseAttributes& attr) {
OperationType op_type, const TensorOrScalar& attr) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = true;
auto scalar = absl::get_if<float>(&attr.param);
auto linear_buf =
absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
auto scalar = absl::get_if<float>(&attr);
auto linear_buf = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr);
auto hwc_buf = absl::get_if<Tensor<HWC, DataType::FLOAT32>>(&attr);
std::string param_desc;
if (scalar) {
param_desc += ", float scalar_val";
@ -152,6 +153,9 @@ std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInputAndConstantArguent(
if (linear_buf) {
param_desc += ", device FLT4* const linear_buf";
}
if (hwc_buf) {
param_desc += ", device FLT4* const hwc_buf, int2 hwc_size";
}
desc->shader_source =
"FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid" + param_desc +
") {\n";
@ -159,6 +163,18 @@ std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInputAndConstantArguent(
desc->shader_source += " FLT4 second_arg = FLT4(scalar_val);\n";
} else if (linear_buf) {
desc->shader_source += " FLT4 second_arg = linear_buf[gid.z];\n";
} else if (hwc_buf) {
const std::string x_coord = hwc_buf->shape.w == 1 ? "0" : "int(gid.x)";
const std::string y_coord = hwc_buf->shape.h == 1 ? "0" : "int(gid.y)";
const std::string s_coord = hwc_buf->shape.c == 1 ? "0" : "int(gid.z)";
std::string index = "(" + s_coord + " * hwc_size.y + " + y_coord +
") * hwc_size.x + " + x_coord;
desc->shader_source += " FLT4 second_arg = hwc_buf[" + index + "];\n";
if (hwc_buf->shape.c == 1) {
desc->shader_source += " second_arg.y = second_arg.x;\n";
desc->shader_source += " second_arg.z = second_arg.x;\n";
desc->shader_source += " second_arg.w = second_arg.x;\n";
}
}
desc->shader_source +=
" return " + TwoInputFunctor(op_type, "value", "second_arg") + ";\n";
@ -180,6 +196,20 @@ std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInputAndConstantArguent(
{"device FLT4* const",
GetByteBufferConverted(linear_buf->data, options.storage_precision)},
};
} else if (hwc_buf) {
std::vector<uint8_t> size_bits =
GetByteBuffer(std::vector<int>{hwc_buf->shape.w, hwc_buf->shape.h});
desc->uniform_buffers = {
{"constant int2&",
[size_bits](const std::map<ValueId, BHWC>& buffers) {
return size_bits;
}},
};
desc->immutable_buffers = {
{"device FLT4* const",
GetByteBufferConverted(ConvertToPHWC4(*hwc_buf),
options.storage_precision)},
};
}
return {desc};
}

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
@ -25,25 +26,19 @@ namespace tflite {
namespace gpu {
namespace metal {
struct ElementwiseBroadcastSettings {
bool width = false;
bool height = false;
bool channels = false;
};
// Two inputs are two runtime tensors
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithTwoInputs(
int id, std::vector<ValueId> input_ids, ValueId output_id,
OperationType op_type, const ElementwiseBroadcastSettings& settings);
// One input is one runtime tensor
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInput(
int id, ValueId input_id, ValueId output_id, OperationType op_type);
// Two inputs are two runtime tensors
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithTwoInputs(
int id, std::vector<ValueId> input_ids, ValueId output_id,
const BHWC& second_shape, OperationType op_type);
// First input is one runtime tensor and second input is constant argument
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInputAndConstantArguent(
int id, ValueId input_id, ValueId output_id, const RuntimeOptions& options,
OperationType op_type, const ElementwiseAttributes& attr);
OperationType op_type, const TensorOrScalar& attr);
} // namespace metal
} // namespace gpu

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
using ::tflite::gpu::DataType;
using ::tflite::gpu::HWC;
using ::tflite::gpu::BHWC;
using ::tflite::gpu::OperationType;
using ::tflite::gpu::TensorRef;
@ -163,6 +164,42 @@ TensorRef<BHWC> GetTensorRef(int ref, const BHWC& shape) {
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testMaximumWithConstantHWCTensor {
OperationType op_type = OperationType::MAXIMUM;
const BHWC shape(1, 2, 1, 2);
tflite::gpu::ElementwiseAttributes attr;
tflite::gpu::Tensor<HWC, DataType::FLOAT32> hwc_tensor;
hwc_tensor.shape = HWC(2, 1, 2);
hwc_tensor.data = {0.5f, 2.0f, 0.7f, 4.7f};
attr.param = hwc_tensor;
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/attr},
/*inputs=*/{GetTensorRef(0, shape)},
/*outputs=*/{GetTensorRef(1, shape)});
XCTAssertTrue(model.PopulateTensor(0, {1.0f, -6.2f, -2.0f, 3.0f}));
auto status = model.Invoke();
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
status = CompareVectors({1.0f, 2.0f, 0.7f, 4.7f}, model.GetOutput(0), 1e-6f);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testMaximumWithConstantHWCTensorBroadcastChannels {
OperationType op_type = OperationType::MAXIMUM;
const BHWC shape(1, 2, 1, 2);
tflite::gpu::ElementwiseAttributes attr;
tflite::gpu::Tensor<HWC, DataType::FLOAT32> hwc_tensor;
hwc_tensor.shape = HWC(2, 1, 1);
hwc_tensor.data = {0.5f, 2.0f};
attr.param = hwc_tensor;
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/attr},
/*inputs=*/{GetTensorRef(0, shape)},
/*outputs=*/{GetTensorRef(1, shape)});
XCTAssertTrue(model.PopulateTensor(0, {1.0f, -6.2f, -2.0f, 3.0f}));
auto status = model.Invoke();
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
status = CompareVectors({1.0f, 0.5f, 2.0f, 3.0f}, model.GetOutput(0), 1e-6f);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testMinimum {
OperationType op_type = OperationType::MINIMUM;
const BHWC shape(1, 2, 2, 1);

View File

@ -1,83 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/kernels/mul.h"
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "absl/strings/substitute.h"
#include "absl/types/variant.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
namespace tflite {
namespace gpu {
namespace metal {
std::vector<ComputeTaskDescriptorPtr> Multiply(int id, ValueId input_id,
ValueId output_id,
const MultiplyAttributes& attr,
const RuntimeOptions& options) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = true;
auto multiplier = absl::get_if<float>(&attr.param);
auto mul_buffer =
absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
const bool scalar = multiplier != nullptr;
const std::string param_desc =
scalar ? "float multiplier" : "device FLT4* const mul_buf";
std::string code =
"FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, ";
code += param_desc + ") {\n";
if (scalar) {
code += "return value * multiplier;\n";
} else {
code += "return value * mul_buf[gid.z];\n";
}
code += "}\n";
desc->shader_source = code;
desc->input_buffers = {{input_id}};
desc->output_buffer = {output_id};
if (scalar) {
std::vector<uint8_t> multiplier_bits =
GetByteBuffer(std::vector<float>{*multiplier});
desc->uniform_buffers = {
{"constant float&",
[multiplier_bits](const std::map<ValueId, BHWC>& buffers) {
return multiplier_bits;
}},
};
} else {
desc->immutable_buffers = {
{"device FLT4* const",
GetByteBufferConverted(mul_buffer->data, options.storage_precision)},
};
}
return {desc};
}
} // namespace metal
} // namespace gpu
} // namespace tflite

View File

@ -1,37 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MUL_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MUL_H_
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
namespace tflite {
namespace gpu {
namespace metal {
// Multiply operation, supports scalar and vector broadcast.
std::vector<ComputeTaskDescriptorPtr> Multiply(int id, ValueId input_id,
ValueId output_id,
const MultiplyAttributes& attr,
const RuntimeOptions& options);
} // namespace metal
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MUL_H_

View File

@ -1,98 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/kernels/add.h"
#import <XCTest/XCTest.h>
#include <string>
#include <vector>
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
using ::tflite::gpu::DataType;
using ::tflite::gpu::BHWC;
using ::tflite::gpu::Linear;
using ::tflite::gpu::MultiplyAttributes;
using ::tflite::gpu::OperationType;
using ::tflite::gpu::Tensor;
using ::tflite::gpu::TensorRef;
using ::tflite::gpu::metal::CompareVectors;
using ::tflite::gpu::metal::SingleOpModel;
@interface MulTest : XCTestCase
@end
@implementation MulTest
- (void)setUp {
[super setUp];
}
- (void)testMulScalar {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 2, 2, 1);
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 2, 2, 1);
MultiplyAttributes attr;
attr.param = 2;
SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
auto status = model.Invoke();
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
status = CompareVectors({2, 4, 6, 8}, model.GetOutput(0), 1e-6f);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testMulLinear {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 1, 2, 2);
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 1, 2, 2);
MultiplyAttributes attr;
Tensor<Linear, DataType::FLOAT32> tensor;
tensor.shape.v = 2;
tensor.id = 1;
tensor.data = {2, 3};
attr.param = std::move(tensor);
SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
auto status = model.Invoke();
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
status = CompareVectors({2, 6, 6, 12}, model.GetOutput(0), 1e-6f);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
@end