Added support of HWC constant tensor in ElementwiseWithOneInputAndConstantArguent.
Remove Mul (replaced by elementwise). Changed Add (some cases handled in elementwise). PiperOrigin-RevId: 315000693 Change-Id: Icd981170d11f17df418361a554dc08822f3ee273
This commit is contained in:
parent
ffc8ddf6bf
commit
9c9e961174
@ -35,7 +35,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h"
|
#include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling.h"
|
#include "tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/mean.h"
|
#include "tensorflow/lite/delegates/gpu/metal/kernels/mean.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/mul.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/padding.h"
|
#include "tensorflow/lite/delegates/gpu/metal/kernels/padding.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/pooling.h"
|
#include "tensorflow/lite/delegates/gpu/metal/kernels/pooling.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/prelu.h"
|
#include "tensorflow/lite/delegates/gpu/metal/kernels/prelu.h"
|
||||||
@ -55,22 +54,6 @@ namespace gpu {
|
|||||||
namespace metal {
|
namespace metal {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
bool IsWidthBroadcastedForSecondInput(const std::vector<Value*>& inputs) {
|
|
||||||
return inputs.size() == 2 &&
|
|
||||||
inputs[0]->tensor.shape.w != inputs[1]->tensor.shape.w &&
|
|
||||||
inputs[1]->tensor.shape.w == 1;
|
|
||||||
}
|
|
||||||
bool IsHeightBroadcastedForSecondInput(const std::vector<Value*>& inputs) {
|
|
||||||
return inputs.size() == 2 &&
|
|
||||||
inputs[0]->tensor.shape.h != inputs[1]->tensor.shape.h &&
|
|
||||||
inputs[1]->tensor.shape.h == 1;
|
|
||||||
}
|
|
||||||
bool IsChannelsBroadcastedForSecondInput(const std::vector<Value*>& inputs) {
|
|
||||||
return inputs.size() == 2 &&
|
|
||||||
inputs[0]->tensor.shape.c != inputs[1]->tensor.shape.c &&
|
|
||||||
inputs[1]->tensor.shape.c == 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<ComputeTaskDescriptorPtr> SelectDepthWiseConv(
|
std::vector<ComputeTaskDescriptorPtr> SelectDepthWiseConv(
|
||||||
int id, ValueId input_id, ValueId output_id,
|
int id, ValueId input_id, ValueId output_id,
|
||||||
const DepthwiseConvolution2DAttributes& attr,
|
const DepthwiseConvolution2DAttributes& attr,
|
||||||
@ -205,26 +188,22 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
|
|||||||
auto op_type = OperationTypeFromString(node->operation.type);
|
auto op_type = OperationTypeFromString(node->operation.type);
|
||||||
switch (op_type) {
|
switch (op_type) {
|
||||||
case OperationType::ADD: {
|
case OperationType::ADD: {
|
||||||
const auto srcs = graph.FindInputs(node_id);
|
if (inputs.size() == 1) {
|
||||||
ElementwiseBroadcastSettings broadcast;
|
if (node->operation.attributes.has_value()) {
|
||||||
broadcast.width = IsWidthBroadcastedForSecondInput(srcs);
|
auto attr = absl::any_cast<AddAttributes>(node->operation.attributes);
|
||||||
broadcast.height = IsHeightBroadcastedForSecondInput(srcs);
|
*tasks = ElementwiseWithOneInputAndConstantArguent(
|
||||||
broadcast.channels = IsChannelsBroadcastedForSecondInput(srcs);
|
node_id, inputs[0], outputs[0], options, op_type, attr.param);
|
||||||
if (broadcast.width || broadcast.height || broadcast.channels) {
|
} else {
|
||||||
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type,
|
|
||||||
broadcast);
|
|
||||||
} else {
|
|
||||||
const AddAttributes& attr =
|
|
||||||
absl::any_cast<AddAttributes>(node->operation.attributes);
|
|
||||||
const auto* hwc_tensor =
|
|
||||||
absl::get_if<tflite::gpu::Tensor<HWC, DataType::FLOAT32>>(
|
|
||||||
&attr.param);
|
|
||||||
if (hwc_tensor) {
|
|
||||||
return absl::UnimplementedError(
|
return absl::UnimplementedError(
|
||||||
"Unsupported op: " + node->operation.type +
|
"Missing attributes for single input op: " +
|
||||||
", no support of HWC constant tensor");
|
node->operation.type);
|
||||||
}
|
}
|
||||||
*tasks = Add(node_id, inputs, outputs[0], attr, options);
|
} else if (inputs.size() == 2) {
|
||||||
|
const auto srcs = graph.FindInputs(node_id);
|
||||||
|
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0],
|
||||||
|
srcs[1]->tensor.shape, op_type);
|
||||||
|
} else { // more than 2 inputs
|
||||||
|
*tasks = Add(node_id, inputs, outputs[0], options);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -309,31 +288,21 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
|
|||||||
absl::any_cast<MeanAttributes>(node->operation.attributes));
|
absl::any_cast<MeanAttributes>(node->operation.attributes));
|
||||||
break;
|
break;
|
||||||
case OperationType::MUL:
|
case OperationType::MUL:
|
||||||
if (node->operation.attributes.has_value()) {
|
if (inputs.size() == 1) {
|
||||||
const MultiplyAttributes& attr =
|
if (node->operation.attributes.has_value()) {
|
||||||
absl::any_cast<MultiplyAttributes>(node->operation.attributes);
|
auto attr =
|
||||||
const auto* hwc_tensor =
|
absl::any_cast<MultiplyAttributes>(node->operation.attributes);
|
||||||
absl::get_if<tflite::gpu::Tensor<HWC, DataType::FLOAT32>>(
|
*tasks = ElementwiseWithOneInputAndConstantArguent(
|
||||||
&attr.param);
|
node_id, inputs[0], outputs[0], options, op_type, attr.param);
|
||||||
if (hwc_tensor) {
|
|
||||||
return absl::UnimplementedError(
|
|
||||||
"Unsupported op: " + node->operation.type +
|
|
||||||
", no support of HWC constant tensor");
|
|
||||||
}
|
|
||||||
*tasks = Multiply(node_id, inputs[0], outputs[0], attr, options);
|
|
||||||
} else {
|
|
||||||
if (inputs.size() == 2) {
|
|
||||||
const auto srcs = graph.FindInputs(node_id);
|
|
||||||
ElementwiseBroadcastSettings broadcast;
|
|
||||||
broadcast.width = IsWidthBroadcastedForSecondInput(srcs);
|
|
||||||
broadcast.height = IsHeightBroadcastedForSecondInput(srcs);
|
|
||||||
broadcast.channels = IsChannelsBroadcastedForSecondInput(srcs);
|
|
||||||
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0],
|
|
||||||
op_type, broadcast);
|
|
||||||
} else {
|
} else {
|
||||||
return absl::UnimplementedError(
|
return absl::UnimplementedError(
|
||||||
"No support of multiply with more than 2 inputs");
|
"Missing attributes for single input op: " +
|
||||||
|
node->operation.type);
|
||||||
}
|
}
|
||||||
|
} else if (inputs.size() == 2) {
|
||||||
|
const auto srcs = graph.FindInputs(node_id);
|
||||||
|
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0],
|
||||||
|
srcs[1]->tensor.shape, op_type);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case OperationType::PAD: {
|
case OperationType::PAD: {
|
||||||
@ -413,27 +382,21 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
|
|||||||
case OperationType::POW:
|
case OperationType::POW:
|
||||||
case OperationType::SQUARED_DIFF:
|
case OperationType::SQUARED_DIFF:
|
||||||
case OperationType::SUB: {
|
case OperationType::SUB: {
|
||||||
const ElementwiseAttributes* attr =
|
if (inputs.size() == 1) {
|
||||||
absl::any_cast<ElementwiseAttributes>(&node->operation.attributes);
|
if (node->operation.attributes.has_value()) {
|
||||||
if (attr) {
|
auto attr =
|
||||||
const auto* hwc_tensor =
|
absl::any_cast<ElementwiseAttributes>(node->operation.attributes);
|
||||||
absl::get_if<tflite::gpu::Tensor<HWC, DataType::FLOAT32>>(
|
*tasks = ElementwiseWithOneInputAndConstantArguent(
|
||||||
&attr->param);
|
node_id, inputs[0], outputs[0], options, op_type, attr.param);
|
||||||
if (hwc_tensor) {
|
} else {
|
||||||
return absl::UnimplementedError(
|
return absl::UnimplementedError(
|
||||||
"Unsupported op: " + node->operation.type +
|
"Missing attributes for single input op: " +
|
||||||
", no support of HWC constant tensor");
|
node->operation.type);
|
||||||
}
|
}
|
||||||
*tasks = ElementwiseWithOneInputAndConstantArguent(
|
} else if (inputs.size() == 2) {
|
||||||
node_id, inputs[0], outputs[0], options, op_type, *attr);
|
|
||||||
} else {
|
|
||||||
const auto srcs = graph.FindInputs(node_id);
|
const auto srcs = graph.FindInputs(node_id);
|
||||||
ElementwiseBroadcastSettings broadcast;
|
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0],
|
||||||
broadcast.width = IsWidthBroadcastedForSecondInput(srcs);
|
srcs[1]->tensor.shape, op_type);
|
||||||
broadcast.height = IsHeightBroadcastedForSecondInput(srcs);
|
|
||||||
broadcast.channels = IsChannelsBroadcastedForSecondInput(srcs);
|
|
||||||
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type,
|
|
||||||
broadcast);
|
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case OperationType::BATCH_NORMALIZATION:
|
case OperationType::BATCH_NORMALIZATION:
|
||||||
|
@ -27,7 +27,6 @@ cc_library(
|
|||||||
":fully_connected",
|
":fully_connected",
|
||||||
":max_unpooling",
|
":max_unpooling",
|
||||||
":mean",
|
":mean",
|
||||||
":mul",
|
|
||||||
":padding",
|
":padding",
|
||||||
":pooling",
|
":pooling",
|
||||||
":prelu",
|
":prelu",
|
||||||
@ -228,6 +227,7 @@ cc_library(
|
|||||||
srcs = ["elementwise.cc"],
|
srcs = ["elementwise.cc"],
|
||||||
hdrs = ["elementwise.h"],
|
hdrs = ["elementwise.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//tensorflow/lite/delegates/gpu/common:convert",
|
||||||
"//tensorflow/lite/delegates/gpu/common:model",
|
"//tensorflow/lite/delegates/gpu/common:model",
|
||||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||||
"//tensorflow/lite/delegates/gpu/common:shape",
|
"//tensorflow/lite/delegates/gpu/common:shape",
|
||||||
@ -381,46 +381,6 @@ ios_unit_test(
|
|||||||
deps = [":mean_test_lib"],
|
deps = [":mean_test_lib"],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "mul",
|
|
||||||
srcs = ["mul.cc"],
|
|
||||||
hdrs = ["mul.h"],
|
|
||||||
deps = [
|
|
||||||
"//tensorflow/lite/delegates/gpu/common:model",
|
|
||||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
|
||||||
"//tensorflow/lite/delegates/gpu/common:shape",
|
|
||||||
"//tensorflow/lite/delegates/gpu/common:types",
|
|
||||||
"//tensorflow/lite/delegates/gpu/common:util",
|
|
||||||
"//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
|
|
||||||
"//tensorflow/lite/delegates/gpu/metal:runtime_options",
|
|
||||||
"@com_google_absl//absl/strings",
|
|
||||||
"@com_google_absl//absl/types:variant",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
objc_library(
|
|
||||||
name = "mul_test_lib",
|
|
||||||
testonly = 1,
|
|
||||||
srcs = ["mul_test.mm"],
|
|
||||||
sdk_frameworks = ["XCTest"],
|
|
||||||
deps = [
|
|
||||||
":mul",
|
|
||||||
":test_util",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
ios_unit_test(
|
|
||||||
name = "mul_test",
|
|
||||||
testonly = 1,
|
|
||||||
minimum_os_version = "9.0",
|
|
||||||
runner = tflite_ios_lab_runner("IOS_LATEST"),
|
|
||||||
tags = tf_gpu_tests_tags() + [
|
|
||||||
"notap",
|
|
||||||
"tflite_not_portable_android",
|
|
||||||
],
|
|
||||||
deps = [":mul_test_lib"],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "padding",
|
name = "padding",
|
||||||
srcs = ["padding.cc"],
|
srcs = ["padding.cc"],
|
||||||
@ -953,7 +913,6 @@ objc_library(
|
|||||||
"elementwise_test.mm",
|
"elementwise_test.mm",
|
||||||
"fully_connected_test.mm",
|
"fully_connected_test.mm",
|
||||||
"max_unpooling_test.mm",
|
"max_unpooling_test.mm",
|
||||||
"mul_test.mm",
|
|
||||||
"padding_test.mm",
|
"padding_test.mm",
|
||||||
"pooling_test.mm",
|
"pooling_test.mm",
|
||||||
"prelu_test.mm",
|
"prelu_test.mm",
|
||||||
|
@ -52,39 +52,9 @@ std::string GetAddTableCodeFused(int src_count) {
|
|||||||
std::vector<ComputeTaskDescriptorPtr> Add(int id,
|
std::vector<ComputeTaskDescriptorPtr> Add(int id,
|
||||||
const std::vector<ValueId> input_ids,
|
const std::vector<ValueId> input_ids,
|
||||||
ValueId output_id,
|
ValueId output_id,
|
||||||
const AddAttributes& attr,
|
|
||||||
const RuntimeOptions& options) {
|
const RuntimeOptions& options) {
|
||||||
auto desc = std::make_shared<ComputeTaskDescriptor>();
|
auto desc = std::make_shared<ComputeTaskDescriptor>();
|
||||||
desc->id = id;
|
desc->id = id;
|
||||||
|
|
||||||
// Add scalar
|
|
||||||
const float* add_value = absl::get_if<float>(&attr.param);
|
|
||||||
if (add_value) {
|
|
||||||
desc->is_linkable = true;
|
|
||||||
desc->shader_source =
|
|
||||||
R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid) {
|
|
||||||
return value + )" +
|
|
||||||
std::to_string(*add_value) + ";}";
|
|
||||||
desc->input_buffers = {{input_ids[0]}};
|
|
||||||
desc->output_buffer = {output_id};
|
|
||||||
return {desc};
|
|
||||||
}
|
|
||||||
// Add vector
|
|
||||||
auto broadcast = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
|
|
||||||
if (broadcast) {
|
|
||||||
desc->is_linkable = true;
|
|
||||||
desc->shader_source =
|
|
||||||
R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid,
|
|
||||||
device FLT4* const broadcast) { return value + broadcast[gid.z]; })";
|
|
||||||
desc->input_buffers = {{input_ids[0]}};
|
|
||||||
desc->output_buffer = {output_id};
|
|
||||||
desc->immutable_buffers = {
|
|
||||||
{"device FLT4* const",
|
|
||||||
GetByteBufferConverted(broadcast->data, options.storage_precision)},
|
|
||||||
};
|
|
||||||
return {desc};
|
|
||||||
}
|
|
||||||
|
|
||||||
desc->is_linkable = true;
|
desc->is_linkable = true;
|
||||||
desc->is_associative_op = true;
|
desc->is_associative_op = true;
|
||||||
desc->shader_source = GetAddTableCodeFused(input_ids.size() - 1);
|
desc->shader_source = GetAddTableCodeFused(input_ids.size() - 1);
|
||||||
|
@ -27,11 +27,9 @@ namespace tflite {
|
|||||||
namespace gpu {
|
namespace gpu {
|
||||||
namespace metal {
|
namespace metal {
|
||||||
|
|
||||||
// Add with broadcast.
|
|
||||||
std::vector<ComputeTaskDescriptorPtr> Add(int id,
|
std::vector<ComputeTaskDescriptorPtr> Add(int id,
|
||||||
const std::vector<ValueId> input_ids,
|
const std::vector<ValueId> input_ids,
|
||||||
ValueId output_id,
|
ValueId output_id,
|
||||||
const AddAttributes& attr,
|
|
||||||
const RuntimeOptions& options);
|
const RuntimeOptions& options);
|
||||||
|
|
||||||
} // namespace metal
|
} // namespace metal
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/strings/substitute.h"
|
#include "absl/strings/substitute.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/convert.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
||||||
@ -77,20 +78,20 @@ std::string TwoInputFunctor(OperationType op_type, const std::string& value0,
|
|||||||
|
|
||||||
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithTwoInputs(
|
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithTwoInputs(
|
||||||
int id, std::vector<ValueId> input_ids, ValueId output_id,
|
int id, std::vector<ValueId> input_ids, ValueId output_id,
|
||||||
OperationType op_type, const ElementwiseBroadcastSettings& settings) {
|
const BHWC& second_shape, OperationType op_type) {
|
||||||
auto desc = std::make_shared<ComputeTaskDescriptor>();
|
auto desc = std::make_shared<ComputeTaskDescriptor>();
|
||||||
desc->id = id;
|
desc->id = id;
|
||||||
desc->is_linkable = true;
|
desc->is_linkable = true;
|
||||||
const std::string x_coord = settings.width ? "0" : "int(gid.x)";
|
const std::string x_coord = second_shape.w == 1 ? "0" : "int(gid.x)";
|
||||||
const std::string y_coord = settings.height ? "0" : "int(gid.y)";
|
const std::string y_coord = second_shape.h == 1 ? "0" : "int(gid.y)";
|
||||||
const std::string s_coord = settings.channels ? "0" : "int(gid.z)";
|
const std::string s_coord = second_shape.c == 1 ? "0" : "int(gid.z)";
|
||||||
std::string code =
|
std::string code =
|
||||||
"FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, device FLT4* "
|
"FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, device FLT4* "
|
||||||
"const second_tensor, int2 second_size) {\n";
|
"const second_tensor, int2 second_size) {\n";
|
||||||
code += " int second_index = (" + s_coord + " * second_size.y + " + y_coord +
|
code += " int second_index = (" + s_coord + " * second_size.y + " + y_coord +
|
||||||
") * second_size.x + " + x_coord + ";\n";
|
") * second_size.x + " + x_coord + ";\n";
|
||||||
code += " FLT4 src_1 = second_tensor[second_index];\n";
|
code += " FLT4 src_1 = second_tensor[second_index];\n";
|
||||||
if (settings.channels) {
|
if (second_shape.c == 1) {
|
||||||
code += " src_1.y = src_1.x;\n";
|
code += " src_1.y = src_1.x;\n";
|
||||||
code += " src_1.z = src_1.x;\n";
|
code += " src_1.z = src_1.x;\n";
|
||||||
code += " src_1.w = src_1.x;\n";
|
code += " src_1.w = src_1.x;\n";
|
||||||
@ -138,13 +139,13 @@ std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInput(
|
|||||||
|
|
||||||
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInputAndConstantArguent(
|
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInputAndConstantArguent(
|
||||||
int id, ValueId input_id, ValueId output_id, const RuntimeOptions& options,
|
int id, ValueId input_id, ValueId output_id, const RuntimeOptions& options,
|
||||||
OperationType op_type, const ElementwiseAttributes& attr) {
|
OperationType op_type, const TensorOrScalar& attr) {
|
||||||
auto desc = std::make_shared<ComputeTaskDescriptor>();
|
auto desc = std::make_shared<ComputeTaskDescriptor>();
|
||||||
desc->id = id;
|
desc->id = id;
|
||||||
desc->is_linkable = true;
|
desc->is_linkable = true;
|
||||||
auto scalar = absl::get_if<float>(&attr.param);
|
auto scalar = absl::get_if<float>(&attr);
|
||||||
auto linear_buf =
|
auto linear_buf = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr);
|
||||||
absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
|
auto hwc_buf = absl::get_if<Tensor<HWC, DataType::FLOAT32>>(&attr);
|
||||||
std::string param_desc;
|
std::string param_desc;
|
||||||
if (scalar) {
|
if (scalar) {
|
||||||
param_desc += ", float scalar_val";
|
param_desc += ", float scalar_val";
|
||||||
@ -152,6 +153,9 @@ std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInputAndConstantArguent(
|
|||||||
if (linear_buf) {
|
if (linear_buf) {
|
||||||
param_desc += ", device FLT4* const linear_buf";
|
param_desc += ", device FLT4* const linear_buf";
|
||||||
}
|
}
|
||||||
|
if (hwc_buf) {
|
||||||
|
param_desc += ", device FLT4* const hwc_buf, int2 hwc_size";
|
||||||
|
}
|
||||||
desc->shader_source =
|
desc->shader_source =
|
||||||
"FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid" + param_desc +
|
"FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid" + param_desc +
|
||||||
") {\n";
|
") {\n";
|
||||||
@ -159,6 +163,18 @@ std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInputAndConstantArguent(
|
|||||||
desc->shader_source += " FLT4 second_arg = FLT4(scalar_val);\n";
|
desc->shader_source += " FLT4 second_arg = FLT4(scalar_val);\n";
|
||||||
} else if (linear_buf) {
|
} else if (linear_buf) {
|
||||||
desc->shader_source += " FLT4 second_arg = linear_buf[gid.z];\n";
|
desc->shader_source += " FLT4 second_arg = linear_buf[gid.z];\n";
|
||||||
|
} else if (hwc_buf) {
|
||||||
|
const std::string x_coord = hwc_buf->shape.w == 1 ? "0" : "int(gid.x)";
|
||||||
|
const std::string y_coord = hwc_buf->shape.h == 1 ? "0" : "int(gid.y)";
|
||||||
|
const std::string s_coord = hwc_buf->shape.c == 1 ? "0" : "int(gid.z)";
|
||||||
|
std::string index = "(" + s_coord + " * hwc_size.y + " + y_coord +
|
||||||
|
") * hwc_size.x + " + x_coord;
|
||||||
|
desc->shader_source += " FLT4 second_arg = hwc_buf[" + index + "];\n";
|
||||||
|
if (hwc_buf->shape.c == 1) {
|
||||||
|
desc->shader_source += " second_arg.y = second_arg.x;\n";
|
||||||
|
desc->shader_source += " second_arg.z = second_arg.x;\n";
|
||||||
|
desc->shader_source += " second_arg.w = second_arg.x;\n";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
desc->shader_source +=
|
desc->shader_source +=
|
||||||
" return " + TwoInputFunctor(op_type, "value", "second_arg") + ";\n";
|
" return " + TwoInputFunctor(op_type, "value", "second_arg") + ";\n";
|
||||||
@ -180,6 +196,20 @@ std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInputAndConstantArguent(
|
|||||||
{"device FLT4* const",
|
{"device FLT4* const",
|
||||||
GetByteBufferConverted(linear_buf->data, options.storage_precision)},
|
GetByteBufferConverted(linear_buf->data, options.storage_precision)},
|
||||||
};
|
};
|
||||||
|
} else if (hwc_buf) {
|
||||||
|
std::vector<uint8_t> size_bits =
|
||||||
|
GetByteBuffer(std::vector<int>{hwc_buf->shape.w, hwc_buf->shape.h});
|
||||||
|
desc->uniform_buffers = {
|
||||||
|
{"constant int2&",
|
||||||
|
[size_bits](const std::map<ValueId, BHWC>& buffers) {
|
||||||
|
return size_bits;
|
||||||
|
}},
|
||||||
|
};
|
||||||
|
desc->immutable_buffers = {
|
||||||
|
{"device FLT4* const",
|
||||||
|
GetByteBufferConverted(ConvertToPHWC4(*hwc_buf),
|
||||||
|
options.storage_precision)},
|
||||||
|
};
|
||||||
}
|
}
|
||||||
return {desc};
|
return {desc};
|
||||||
}
|
}
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
||||||
|
|
||||||
@ -25,25 +26,19 @@ namespace tflite {
|
|||||||
namespace gpu {
|
namespace gpu {
|
||||||
namespace metal {
|
namespace metal {
|
||||||
|
|
||||||
struct ElementwiseBroadcastSettings {
|
|
||||||
bool width = false;
|
|
||||||
bool height = false;
|
|
||||||
bool channels = false;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Two inputs are two runtime tensors
|
|
||||||
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithTwoInputs(
|
|
||||||
int id, std::vector<ValueId> input_ids, ValueId output_id,
|
|
||||||
OperationType op_type, const ElementwiseBroadcastSettings& settings);
|
|
||||||
|
|
||||||
// One input is one runtime tensor
|
// One input is one runtime tensor
|
||||||
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInput(
|
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInput(
|
||||||
int id, ValueId input_id, ValueId output_id, OperationType op_type);
|
int id, ValueId input_id, ValueId output_id, OperationType op_type);
|
||||||
|
|
||||||
|
// Two inputs are two runtime tensors
|
||||||
|
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithTwoInputs(
|
||||||
|
int id, std::vector<ValueId> input_ids, ValueId output_id,
|
||||||
|
const BHWC& second_shape, OperationType op_type);
|
||||||
|
|
||||||
// First input is one runtime tensor and second input is constant argument
|
// First input is one runtime tensor and second input is constant argument
|
||||||
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInputAndConstantArguent(
|
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInputAndConstantArguent(
|
||||||
int id, ValueId input_id, ValueId output_id, const RuntimeOptions& options,
|
int id, ValueId input_id, ValueId output_id, const RuntimeOptions& options,
|
||||||
OperationType op_type, const ElementwiseAttributes& attr);
|
OperationType op_type, const TensorOrScalar& attr);
|
||||||
|
|
||||||
} // namespace metal
|
} // namespace metal
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
|
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
|
||||||
|
|
||||||
using ::tflite::gpu::DataType;
|
using ::tflite::gpu::DataType;
|
||||||
|
using ::tflite::gpu::HWC;
|
||||||
using ::tflite::gpu::BHWC;
|
using ::tflite::gpu::BHWC;
|
||||||
using ::tflite::gpu::OperationType;
|
using ::tflite::gpu::OperationType;
|
||||||
using ::tflite::gpu::TensorRef;
|
using ::tflite::gpu::TensorRef;
|
||||||
@ -163,6 +164,42 @@ TensorRef<BHWC> GetTensorRef(int ref, const BHWC& shape) {
|
|||||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
- (void)testMaximumWithConstantHWCTensor {
|
||||||
|
OperationType op_type = OperationType::MAXIMUM;
|
||||||
|
const BHWC shape(1, 2, 1, 2);
|
||||||
|
tflite::gpu::ElementwiseAttributes attr;
|
||||||
|
tflite::gpu::Tensor<HWC, DataType::FLOAT32> hwc_tensor;
|
||||||
|
hwc_tensor.shape = HWC(2, 1, 2);
|
||||||
|
hwc_tensor.data = {0.5f, 2.0f, 0.7f, 4.7f};
|
||||||
|
attr.param = hwc_tensor;
|
||||||
|
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/attr},
|
||||||
|
/*inputs=*/{GetTensorRef(0, shape)},
|
||||||
|
/*outputs=*/{GetTensorRef(1, shape)});
|
||||||
|
XCTAssertTrue(model.PopulateTensor(0, {1.0f, -6.2f, -2.0f, 3.0f}));
|
||||||
|
auto status = model.Invoke();
|
||||||
|
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||||
|
status = CompareVectors({1.0f, 2.0f, 0.7f, 4.7f}, model.GetOutput(0), 1e-6f);
|
||||||
|
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testMaximumWithConstantHWCTensorBroadcastChannels {
|
||||||
|
OperationType op_type = OperationType::MAXIMUM;
|
||||||
|
const BHWC shape(1, 2, 1, 2);
|
||||||
|
tflite::gpu::ElementwiseAttributes attr;
|
||||||
|
tflite::gpu::Tensor<HWC, DataType::FLOAT32> hwc_tensor;
|
||||||
|
hwc_tensor.shape = HWC(2, 1, 1);
|
||||||
|
hwc_tensor.data = {0.5f, 2.0f};
|
||||||
|
attr.param = hwc_tensor;
|
||||||
|
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/attr},
|
||||||
|
/*inputs=*/{GetTensorRef(0, shape)},
|
||||||
|
/*outputs=*/{GetTensorRef(1, shape)});
|
||||||
|
XCTAssertTrue(model.PopulateTensor(0, {1.0f, -6.2f, -2.0f, 3.0f}));
|
||||||
|
auto status = model.Invoke();
|
||||||
|
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||||
|
status = CompareVectors({1.0f, 0.5f, 2.0f, 3.0f}, model.GetOutput(0), 1e-6f);
|
||||||
|
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||||
|
}
|
||||||
|
|
||||||
- (void)testMinimum {
|
- (void)testMinimum {
|
||||||
OperationType op_type = OperationType::MINIMUM;
|
OperationType op_type = OperationType::MINIMUM;
|
||||||
const BHWC shape(1, 2, 2, 1);
|
const BHWC shape(1, 2, 2, 1);
|
||||||
|
@ -1,83 +0,0 @@
|
|||||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/mul.h"
|
|
||||||
|
|
||||||
#include <cstdint>
|
|
||||||
#include <map>
|
|
||||||
#include <memory>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "absl/strings/substitute.h"
|
|
||||||
#include "absl/types/variant.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
|
|
||||||
|
|
||||||
namespace tflite {
|
|
||||||
namespace gpu {
|
|
||||||
namespace metal {
|
|
||||||
std::vector<ComputeTaskDescriptorPtr> Multiply(int id, ValueId input_id,
|
|
||||||
ValueId output_id,
|
|
||||||
const MultiplyAttributes& attr,
|
|
||||||
const RuntimeOptions& options) {
|
|
||||||
auto desc = std::make_shared<ComputeTaskDescriptor>();
|
|
||||||
desc->id = id;
|
|
||||||
desc->is_linkable = true;
|
|
||||||
auto multiplier = absl::get_if<float>(&attr.param);
|
|
||||||
auto mul_buffer =
|
|
||||||
absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
|
|
||||||
const bool scalar = multiplier != nullptr;
|
|
||||||
const std::string param_desc =
|
|
||||||
scalar ? "float multiplier" : "device FLT4* const mul_buf";
|
|
||||||
std::string code =
|
|
||||||
"FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, ";
|
|
||||||
code += param_desc + ") {\n";
|
|
||||||
if (scalar) {
|
|
||||||
code += "return value * multiplier;\n";
|
|
||||||
} else {
|
|
||||||
code += "return value * mul_buf[gid.z];\n";
|
|
||||||
}
|
|
||||||
code += "}\n";
|
|
||||||
desc->shader_source = code;
|
|
||||||
desc->input_buffers = {{input_id}};
|
|
||||||
desc->output_buffer = {output_id};
|
|
||||||
if (scalar) {
|
|
||||||
std::vector<uint8_t> multiplier_bits =
|
|
||||||
GetByteBuffer(std::vector<float>{*multiplier});
|
|
||||||
desc->uniform_buffers = {
|
|
||||||
{"constant float&",
|
|
||||||
[multiplier_bits](const std::map<ValueId, BHWC>& buffers) {
|
|
||||||
return multiplier_bits;
|
|
||||||
}},
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
desc->immutable_buffers = {
|
|
||||||
{"device FLT4* const",
|
|
||||||
GetByteBufferConverted(mul_buffer->data, options.storage_precision)},
|
|
||||||
};
|
|
||||||
}
|
|
||||||
return {desc};
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace metal
|
|
||||||
} // namespace gpu
|
|
||||||
} // namespace tflite
|
|
@ -1,37 +0,0 @@
|
|||||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MUL_H_
|
|
||||||
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MUL_H_
|
|
||||||
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
|
|
||||||
|
|
||||||
namespace tflite {
|
|
||||||
namespace gpu {
|
|
||||||
namespace metal {
|
|
||||||
|
|
||||||
// Multiply operation, supports scalar and vector broadcast.
|
|
||||||
std::vector<ComputeTaskDescriptorPtr> Multiply(int id, ValueId input_id,
|
|
||||||
ValueId output_id,
|
|
||||||
const MultiplyAttributes& attr,
|
|
||||||
const RuntimeOptions& options);
|
|
||||||
} // namespace metal
|
|
||||||
} // namespace gpu
|
|
||||||
} // namespace tflite
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MUL_H_
|
|
@ -1,98 +0,0 @@
|
|||||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/add.h"
|
|
||||||
|
|
||||||
#import <XCTest/XCTest.h>
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
|
|
||||||
|
|
||||||
using ::tflite::gpu::DataType;
|
|
||||||
using ::tflite::gpu::BHWC;
|
|
||||||
using ::tflite::gpu::Linear;
|
|
||||||
using ::tflite::gpu::MultiplyAttributes;
|
|
||||||
using ::tflite::gpu::OperationType;
|
|
||||||
using ::tflite::gpu::Tensor;
|
|
||||||
using ::tflite::gpu::TensorRef;
|
|
||||||
using ::tflite::gpu::metal::CompareVectors;
|
|
||||||
using ::tflite::gpu::metal::SingleOpModel;
|
|
||||||
|
|
||||||
@interface MulTest : XCTestCase
|
|
||||||
@end
|
|
||||||
|
|
||||||
@implementation MulTest
|
|
||||||
- (void)setUp {
|
|
||||||
[super setUp];
|
|
||||||
}
|
|
||||||
|
|
||||||
- (void)testMulScalar {
|
|
||||||
TensorRef<BHWC> input;
|
|
||||||
input.type = DataType::FLOAT32;
|
|
||||||
input.ref = 0;
|
|
||||||
input.shape = BHWC(1, 2, 2, 1);
|
|
||||||
|
|
||||||
TensorRef<BHWC> output;
|
|
||||||
output.type = DataType::FLOAT32;
|
|
||||||
output.ref = 1;
|
|
||||||
output.shape = BHWC(1, 2, 2, 1);
|
|
||||||
|
|
||||||
MultiplyAttributes attr;
|
|
||||||
attr.param = 2;
|
|
||||||
|
|
||||||
SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output});
|
|
||||||
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
|
|
||||||
auto status = model.Invoke();
|
|
||||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
|
||||||
status = CompareVectors({2, 4, 6, 8}, model.GetOutput(0), 1e-6f);
|
|
||||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
- (void)testMulLinear {
|
|
||||||
TensorRef<BHWC> input;
|
|
||||||
input.type = DataType::FLOAT32;
|
|
||||||
input.ref = 0;
|
|
||||||
input.shape = BHWC(1, 1, 2, 2);
|
|
||||||
|
|
||||||
TensorRef<BHWC> output;
|
|
||||||
output.type = DataType::FLOAT32;
|
|
||||||
output.ref = 1;
|
|
||||||
output.shape = BHWC(1, 1, 2, 2);
|
|
||||||
|
|
||||||
MultiplyAttributes attr;
|
|
||||||
Tensor<Linear, DataType::FLOAT32> tensor;
|
|
||||||
tensor.shape.v = 2;
|
|
||||||
tensor.id = 1;
|
|
||||||
tensor.data = {2, 3};
|
|
||||||
attr.param = std::move(tensor);
|
|
||||||
|
|
||||||
SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output});
|
|
||||||
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
|
|
||||||
auto status = model.Invoke();
|
|
||||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
|
||||||
status = CompareVectors({2, 6, 6, 12}, model.GetOutput(0), 1e-6f);
|
|
||||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
@end
|
|
Loading…
Reference in New Issue
Block a user