Added support for HWC constant tensors in ElementwiseWithOneInputAndConstantArguent.
Removed Mul (replaced by elementwise). Changed Add (some cases are now handled by elementwise). PiperOrigin-RevId: 315000693 Change-Id: Icd981170d11f17df418361a554dc08822f3ee273
This commit is contained in:
parent
ffc8ddf6bf
commit
9c9e961174
@ -35,7 +35,6 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/mean.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/mul.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/padding.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/pooling.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/prelu.h"
|
||||
@ -55,22 +54,6 @@ namespace gpu {
|
||||
namespace metal {
|
||||
namespace {
|
||||
|
||||
// Returns true only when exactly two inputs are given, the second input has
// width 1, and the first input has a different width.
bool IsWidthBroadcastedForSecondInput(const std::vector<Value*>& inputs) {
  if (inputs.size() != 2) {
    return false;
  }
  const auto first_width = inputs[0]->tensor.shape.w;
  const auto second_width = inputs[1]->tensor.shape.w;
  return first_width != second_width && second_width == 1;
}
|
||||
// Returns true only when exactly two inputs are given, the second input has
// height 1, and the first input has a different height.
bool IsHeightBroadcastedForSecondInput(const std::vector<Value*>& inputs) {
  if (inputs.size() != 2) {
    return false;
  }
  const auto first_height = inputs[0]->tensor.shape.h;
  const auto second_height = inputs[1]->tensor.shape.h;
  return first_height != second_height && second_height == 1;
}
|
||||
// Returns true only when exactly two inputs are given, the second input has a
// single channel, and the first input has a different channel count.
bool IsChannelsBroadcastedForSecondInput(const std::vector<Value*>& inputs) {
  if (inputs.size() != 2) {
    return false;
  }
  const auto first_channels = inputs[0]->tensor.shape.c;
  const auto second_channels = inputs[1]->tensor.shape.c;
  return first_channels != second_channels && second_channels == 1;
}
|
||||
|
||||
std::vector<ComputeTaskDescriptorPtr> SelectDepthWiseConv(
|
||||
int id, ValueId input_id, ValueId output_id,
|
||||
const DepthwiseConvolution2DAttributes& attr,
|
||||
@ -205,26 +188,22 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
|
||||
auto op_type = OperationTypeFromString(node->operation.type);
|
||||
switch (op_type) {
|
||||
case OperationType::ADD: {
|
||||
const auto srcs = graph.FindInputs(node_id);
|
||||
ElementwiseBroadcastSettings broadcast;
|
||||
broadcast.width = IsWidthBroadcastedForSecondInput(srcs);
|
||||
broadcast.height = IsHeightBroadcastedForSecondInput(srcs);
|
||||
broadcast.channels = IsChannelsBroadcastedForSecondInput(srcs);
|
||||
if (broadcast.width || broadcast.height || broadcast.channels) {
|
||||
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type,
|
||||
broadcast);
|
||||
} else {
|
||||
const AddAttributes& attr =
|
||||
absl::any_cast<AddAttributes>(node->operation.attributes);
|
||||
const auto* hwc_tensor =
|
||||
absl::get_if<tflite::gpu::Tensor<HWC, DataType::FLOAT32>>(
|
||||
&attr.param);
|
||||
if (hwc_tensor) {
|
||||
if (inputs.size() == 1) {
|
||||
if (node->operation.attributes.has_value()) {
|
||||
auto attr = absl::any_cast<AddAttributes>(node->operation.attributes);
|
||||
*tasks = ElementwiseWithOneInputAndConstantArguent(
|
||||
node_id, inputs[0], outputs[0], options, op_type, attr.param);
|
||||
} else {
|
||||
return absl::UnimplementedError(
|
||||
"Unsupported op: " + node->operation.type +
|
||||
", no support of HWC constant tensor");
|
||||
"Missing attributes for single input op: " +
|
||||
node->operation.type);
|
||||
}
|
||||
*tasks = Add(node_id, inputs, outputs[0], attr, options);
|
||||
} else if (inputs.size() == 2) {
|
||||
const auto srcs = graph.FindInputs(node_id);
|
||||
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0],
|
||||
srcs[1]->tensor.shape, op_type);
|
||||
} else { // more than 2 inputs
|
||||
*tasks = Add(node_id, inputs, outputs[0], options);
|
||||
}
|
||||
break;
|
||||
}
|
||||
@ -309,31 +288,21 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
|
||||
absl::any_cast<MeanAttributes>(node->operation.attributes));
|
||||
break;
|
||||
case OperationType::MUL:
|
||||
if (node->operation.attributes.has_value()) {
|
||||
const MultiplyAttributes& attr =
|
||||
absl::any_cast<MultiplyAttributes>(node->operation.attributes);
|
||||
const auto* hwc_tensor =
|
||||
absl::get_if<tflite::gpu::Tensor<HWC, DataType::FLOAT32>>(
|
||||
&attr.param);
|
||||
if (hwc_tensor) {
|
||||
return absl::UnimplementedError(
|
||||
"Unsupported op: " + node->operation.type +
|
||||
", no support of HWC constant tensor");
|
||||
}
|
||||
*tasks = Multiply(node_id, inputs[0], outputs[0], attr, options);
|
||||
} else {
|
||||
if (inputs.size() == 2) {
|
||||
const auto srcs = graph.FindInputs(node_id);
|
||||
ElementwiseBroadcastSettings broadcast;
|
||||
broadcast.width = IsWidthBroadcastedForSecondInput(srcs);
|
||||
broadcast.height = IsHeightBroadcastedForSecondInput(srcs);
|
||||
broadcast.channels = IsChannelsBroadcastedForSecondInput(srcs);
|
||||
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0],
|
||||
op_type, broadcast);
|
||||
if (inputs.size() == 1) {
|
||||
if (node->operation.attributes.has_value()) {
|
||||
auto attr =
|
||||
absl::any_cast<MultiplyAttributes>(node->operation.attributes);
|
||||
*tasks = ElementwiseWithOneInputAndConstantArguent(
|
||||
node_id, inputs[0], outputs[0], options, op_type, attr.param);
|
||||
} else {
|
||||
return absl::UnimplementedError(
|
||||
"No support of multiply with more than 2 inputs");
|
||||
"Missing attributes for single input op: " +
|
||||
node->operation.type);
|
||||
}
|
||||
} else if (inputs.size() == 2) {
|
||||
const auto srcs = graph.FindInputs(node_id);
|
||||
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0],
|
||||
srcs[1]->tensor.shape, op_type);
|
||||
}
|
||||
break;
|
||||
case OperationType::PAD: {
|
||||
@ -413,27 +382,21 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
|
||||
case OperationType::POW:
|
||||
case OperationType::SQUARED_DIFF:
|
||||
case OperationType::SUB: {
|
||||
const ElementwiseAttributes* attr =
|
||||
absl::any_cast<ElementwiseAttributes>(&node->operation.attributes);
|
||||
if (attr) {
|
||||
const auto* hwc_tensor =
|
||||
absl::get_if<tflite::gpu::Tensor<HWC, DataType::FLOAT32>>(
|
||||
&attr->param);
|
||||
if (hwc_tensor) {
|
||||
if (inputs.size() == 1) {
|
||||
if (node->operation.attributes.has_value()) {
|
||||
auto attr =
|
||||
absl::any_cast<ElementwiseAttributes>(node->operation.attributes);
|
||||
*tasks = ElementwiseWithOneInputAndConstantArguent(
|
||||
node_id, inputs[0], outputs[0], options, op_type, attr.param);
|
||||
} else {
|
||||
return absl::UnimplementedError(
|
||||
"Unsupported op: " + node->operation.type +
|
||||
", no support of HWC constant tensor");
|
||||
"Missing attributes for single input op: " +
|
||||
node->operation.type);
|
||||
}
|
||||
*tasks = ElementwiseWithOneInputAndConstantArguent(
|
||||
node_id, inputs[0], outputs[0], options, op_type, *attr);
|
||||
} else {
|
||||
} else if (inputs.size() == 2) {
|
||||
const auto srcs = graph.FindInputs(node_id);
|
||||
ElementwiseBroadcastSettings broadcast;
|
||||
broadcast.width = IsWidthBroadcastedForSecondInput(srcs);
|
||||
broadcast.height = IsHeightBroadcastedForSecondInput(srcs);
|
||||
broadcast.channels = IsChannelsBroadcastedForSecondInput(srcs);
|
||||
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type,
|
||||
broadcast);
|
||||
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0],
|
||||
srcs[1]->tensor.shape, op_type);
|
||||
}
|
||||
} break;
|
||||
case OperationType::BATCH_NORMALIZATION:
|
||||
|
@ -27,7 +27,6 @@ cc_library(
|
||||
":fully_connected",
|
||||
":max_unpooling",
|
||||
":mean",
|
||||
":mul",
|
||||
":padding",
|
||||
":pooling",
|
||||
":prelu",
|
||||
@ -228,6 +227,7 @@ cc_library(
|
||||
srcs = ["elementwise.cc"],
|
||||
hdrs = ["elementwise.h"],
|
||||
deps = [
|
||||
"//tensorflow/lite/delegates/gpu/common:convert",
|
||||
"//tensorflow/lite/delegates/gpu/common:model",
|
||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||
"//tensorflow/lite/delegates/gpu/common:shape",
|
||||
@ -381,46 +381,6 @@ ios_unit_test(
|
||||
deps = [":mean_test_lib"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mul",
|
||||
srcs = ["mul.cc"],
|
||||
hdrs = ["mul.h"],
|
||||
deps = [
|
||||
"//tensorflow/lite/delegates/gpu/common:model",
|
||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||
"//tensorflow/lite/delegates/gpu/common:shape",
|
||||
"//tensorflow/lite/delegates/gpu/common:types",
|
||||
"//tensorflow/lite/delegates/gpu/common:util",
|
||||
"//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
|
||||
"//tensorflow/lite/delegates/gpu/metal:runtime_options",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:variant",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "mul_test_lib",
|
||||
testonly = 1,
|
||||
srcs = ["mul_test.mm"],
|
||||
sdk_frameworks = ["XCTest"],
|
||||
deps = [
|
||||
":mul",
|
||||
":test_util",
|
||||
],
|
||||
)
|
||||
|
||||
ios_unit_test(
|
||||
name = "mul_test",
|
||||
testonly = 1,
|
||||
minimum_os_version = "9.0",
|
||||
runner = tflite_ios_lab_runner("IOS_LATEST"),
|
||||
tags = tf_gpu_tests_tags() + [
|
||||
"notap",
|
||||
"tflite_not_portable_android",
|
||||
],
|
||||
deps = [":mul_test_lib"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "padding",
|
||||
srcs = ["padding.cc"],
|
||||
@ -953,7 +913,6 @@ objc_library(
|
||||
"elementwise_test.mm",
|
||||
"fully_connected_test.mm",
|
||||
"max_unpooling_test.mm",
|
||||
"mul_test.mm",
|
||||
"padding_test.mm",
|
||||
"pooling_test.mm",
|
||||
"prelu_test.mm",
|
||||
|
@ -52,39 +52,9 @@ std::string GetAddTableCodeFused(int src_count) {
|
||||
std::vector<ComputeTaskDescriptorPtr> Add(int id,
|
||||
const std::vector<ValueId> input_ids,
|
||||
ValueId output_id,
|
||||
const AddAttributes& attr,
|
||||
const RuntimeOptions& options) {
|
||||
auto desc = std::make_shared<ComputeTaskDescriptor>();
|
||||
desc->id = id;
|
||||
|
||||
// Add scalar
|
||||
const float* add_value = absl::get_if<float>(&attr.param);
|
||||
if (add_value) {
|
||||
desc->is_linkable = true;
|
||||
desc->shader_source =
|
||||
R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid) {
|
||||
return value + )" +
|
||||
std::to_string(*add_value) + ";}";
|
||||
desc->input_buffers = {{input_ids[0]}};
|
||||
desc->output_buffer = {output_id};
|
||||
return {desc};
|
||||
}
|
||||
// Add vector
|
||||
auto broadcast = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
|
||||
if (broadcast) {
|
||||
desc->is_linkable = true;
|
||||
desc->shader_source =
|
||||
R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid,
|
||||
device FLT4* const broadcast) { return value + broadcast[gid.z]; })";
|
||||
desc->input_buffers = {{input_ids[0]}};
|
||||
desc->output_buffer = {output_id};
|
||||
desc->immutable_buffers = {
|
||||
{"device FLT4* const",
|
||||
GetByteBufferConverted(broadcast->data, options.storage_precision)},
|
||||
};
|
||||
return {desc};
|
||||
}
|
||||
|
||||
desc->is_linkable = true;
|
||||
desc->is_associative_op = true;
|
||||
desc->shader_source = GetAddTableCodeFused(input_ids.size() - 1);
|
||||
|
@ -27,11 +27,9 @@ namespace tflite {
|
||||
namespace gpu {
|
||||
namespace metal {
|
||||
|
||||
// Add with broadcast.
|
||||
std::vector<ComputeTaskDescriptorPtr> Add(int id,
|
||||
const std::vector<ValueId> input_ids,
|
||||
ValueId output_id,
|
||||
const AddAttributes& attr,
|
||||
const RuntimeOptions& options);
|
||||
|
||||
} // namespace metal
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/convert.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
||||
@ -77,20 +78,20 @@ std::string TwoInputFunctor(OperationType op_type, const std::string& value0,
|
||||
|
||||
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithTwoInputs(
|
||||
int id, std::vector<ValueId> input_ids, ValueId output_id,
|
||||
OperationType op_type, const ElementwiseBroadcastSettings& settings) {
|
||||
const BHWC& second_shape, OperationType op_type) {
|
||||
auto desc = std::make_shared<ComputeTaskDescriptor>();
|
||||
desc->id = id;
|
||||
desc->is_linkable = true;
|
||||
const std::string x_coord = settings.width ? "0" : "int(gid.x)";
|
||||
const std::string y_coord = settings.height ? "0" : "int(gid.y)";
|
||||
const std::string s_coord = settings.channels ? "0" : "int(gid.z)";
|
||||
const std::string x_coord = second_shape.w == 1 ? "0" : "int(gid.x)";
|
||||
const std::string y_coord = second_shape.h == 1 ? "0" : "int(gid.y)";
|
||||
const std::string s_coord = second_shape.c == 1 ? "0" : "int(gid.z)";
|
||||
std::string code =
|
||||
"FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, device FLT4* "
|
||||
"const second_tensor, int2 second_size) {\n";
|
||||
code += " int second_index = (" + s_coord + " * second_size.y + " + y_coord +
|
||||
") * second_size.x + " + x_coord + ";\n";
|
||||
code += " FLT4 src_1 = second_tensor[second_index];\n";
|
||||
if (settings.channels) {
|
||||
if (second_shape.c == 1) {
|
||||
code += " src_1.y = src_1.x;\n";
|
||||
code += " src_1.z = src_1.x;\n";
|
||||
code += " src_1.w = src_1.x;\n";
|
||||
@ -138,13 +139,13 @@ std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInput(
|
||||
|
||||
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInputAndConstantArguent(
|
||||
int id, ValueId input_id, ValueId output_id, const RuntimeOptions& options,
|
||||
OperationType op_type, const ElementwiseAttributes& attr) {
|
||||
OperationType op_type, const TensorOrScalar& attr) {
|
||||
auto desc = std::make_shared<ComputeTaskDescriptor>();
|
||||
desc->id = id;
|
||||
desc->is_linkable = true;
|
||||
auto scalar = absl::get_if<float>(&attr.param);
|
||||
auto linear_buf =
|
||||
absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
|
||||
auto scalar = absl::get_if<float>(&attr);
|
||||
auto linear_buf = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr);
|
||||
auto hwc_buf = absl::get_if<Tensor<HWC, DataType::FLOAT32>>(&attr);
|
||||
std::string param_desc;
|
||||
if (scalar) {
|
||||
param_desc += ", float scalar_val";
|
||||
@ -152,6 +153,9 @@ std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInputAndConstantArguent(
|
||||
if (linear_buf) {
|
||||
param_desc += ", device FLT4* const linear_buf";
|
||||
}
|
||||
if (hwc_buf) {
|
||||
param_desc += ", device FLT4* const hwc_buf, int2 hwc_size";
|
||||
}
|
||||
desc->shader_source =
|
||||
"FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid" + param_desc +
|
||||
") {\n";
|
||||
@ -159,6 +163,18 @@ std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInputAndConstantArguent(
|
||||
desc->shader_source += " FLT4 second_arg = FLT4(scalar_val);\n";
|
||||
} else if (linear_buf) {
|
||||
desc->shader_source += " FLT4 second_arg = linear_buf[gid.z];\n";
|
||||
} else if (hwc_buf) {
|
||||
const std::string x_coord = hwc_buf->shape.w == 1 ? "0" : "int(gid.x)";
|
||||
const std::string y_coord = hwc_buf->shape.h == 1 ? "0" : "int(gid.y)";
|
||||
const std::string s_coord = hwc_buf->shape.c == 1 ? "0" : "int(gid.z)";
|
||||
std::string index = "(" + s_coord + " * hwc_size.y + " + y_coord +
|
||||
") * hwc_size.x + " + x_coord;
|
||||
desc->shader_source += " FLT4 second_arg = hwc_buf[" + index + "];\n";
|
||||
if (hwc_buf->shape.c == 1) {
|
||||
desc->shader_source += " second_arg.y = second_arg.x;\n";
|
||||
desc->shader_source += " second_arg.z = second_arg.x;\n";
|
||||
desc->shader_source += " second_arg.w = second_arg.x;\n";
|
||||
}
|
||||
}
|
||||
desc->shader_source +=
|
||||
" return " + TwoInputFunctor(op_type, "value", "second_arg") + ";\n";
|
||||
@ -180,6 +196,20 @@ std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInputAndConstantArguent(
|
||||
{"device FLT4* const",
|
||||
GetByteBufferConverted(linear_buf->data, options.storage_precision)},
|
||||
};
|
||||
} else if (hwc_buf) {
|
||||
std::vector<uint8_t> size_bits =
|
||||
GetByteBuffer(std::vector<int>{hwc_buf->shape.w, hwc_buf->shape.h});
|
||||
desc->uniform_buffers = {
|
||||
{"constant int2&",
|
||||
[size_bits](const std::map<ValueId, BHWC>& buffers) {
|
||||
return size_bits;
|
||||
}},
|
||||
};
|
||||
desc->immutable_buffers = {
|
||||
{"device FLT4* const",
|
||||
GetByteBufferConverted(ConvertToPHWC4(*hwc_buf),
|
||||
options.storage_precision)},
|
||||
};
|
||||
}
|
||||
return {desc};
|
||||
}
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
||||
|
||||
@ -25,25 +26,19 @@ namespace tflite {
|
||||
namespace gpu {
|
||||
namespace metal {
|
||||
|
||||
// Per-axis flags consumed by ElementwiseWithTwoInputs. A flag that is set
// makes the generated shader read the second tensor at coordinate 0 along
// that axis instead of at the current thread coordinate.
struct ElementwiseBroadcastSettings {
  bool width{false};     // Second tensor is read at x == 0.
  bool height{false};    // Second tensor is read at y == 0.
  bool channels{false};  // Second tensor is read at slice 0.
};
|
||||
|
||||
// Two inputs are two runtime tensors
|
||||
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithTwoInputs(
|
||||
int id, std::vector<ValueId> input_ids, ValueId output_id,
|
||||
OperationType op_type, const ElementwiseBroadcastSettings& settings);
|
||||
|
||||
// One input is one runtime tensor
|
||||
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInput(
|
||||
int id, ValueId input_id, ValueId output_id, OperationType op_type);
|
||||
|
||||
// Two inputs are two runtime tensors
|
||||
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithTwoInputs(
|
||||
int id, std::vector<ValueId> input_ids, ValueId output_id,
|
||||
const BHWC& second_shape, OperationType op_type);
|
||||
|
||||
// First input is one runtime tensor and second input is constant argument
|
||||
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInputAndConstantArguent(
|
||||
int id, ValueId input_id, ValueId output_id, const RuntimeOptions& options,
|
||||
OperationType op_type, const ElementwiseAttributes& attr);
|
||||
OperationType op_type, const TensorOrScalar& attr);
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
|
||||
|
||||
using ::tflite::gpu::DataType;
|
||||
using ::tflite::gpu::HWC;
|
||||
using ::tflite::gpu::BHWC;
|
||||
using ::tflite::gpu::OperationType;
|
||||
using ::tflite::gpu::TensorRef;
|
||||
@ -163,6 +164,42 @@ TensorRef<BHWC> GetTensorRef(int ref, const BHWC& shape) {
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
// MAXIMUM of a runtime tensor and a constant HWC tensor whose height, width
// and channel count all match the runtime input.
- (void)testMaximumWithConstantHWCTensor {
  tflite::gpu::Tensor<HWC, DataType::FLOAT32> constant_operand;
  constant_operand.shape = HWC(2, 1, 2);
  constant_operand.data = {0.5f, 2.0f, 0.7f, 4.7f};
  tflite::gpu::ElementwiseAttributes attributes;
  attributes.param = constant_operand;

  const BHWC io_shape(1, 2, 1, 2);
  const OperationType operation = OperationType::MAXIMUM;
  SingleOpModel model(
      {/*type=*/ToString(operation), /*attributes=*/attributes},
      /*inputs=*/{GetTensorRef(0, io_shape)},
      /*outputs=*/{GetTensorRef(1, io_shape)});
  XCTAssertTrue(model.PopulateTensor(0, {1.0f, -6.2f, -2.0f, 3.0f}));

  auto status = model.Invoke();
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
  // Element-wise maximum of the input and the constant values.
  status = CompareVectors({1.0f, 2.0f, 0.7f, 4.7f}, model.GetOutput(0), 1e-6f);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
|
||||
|
||||
// MAXIMUM of a two-channel runtime tensor and a constant HWC tensor that has a
// single channel, so each constant value is applied to both channels of the
// corresponding row.
- (void)testMaximumWithConstantHWCTensorBroadcastChannels {
  tflite::gpu::Tensor<HWC, DataType::FLOAT32> constant_operand;
  constant_operand.shape = HWC(2, 1, 1);
  constant_operand.data = {0.5f, 2.0f};
  tflite::gpu::ElementwiseAttributes attributes;
  attributes.param = constant_operand;

  const BHWC io_shape(1, 2, 1, 2);
  const OperationType operation = OperationType::MAXIMUM;
  SingleOpModel model(
      {/*type=*/ToString(operation), /*attributes=*/attributes},
      /*inputs=*/{GetTensorRef(0, io_shape)},
      /*outputs=*/{GetTensorRef(1, io_shape)});
  XCTAssertTrue(model.PopulateTensor(0, {1.0f, -6.2f, -2.0f, 3.0f}));

  auto status = model.Invoke();
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
  // Row 0 is compared against 0.5, row 1 against 2.0.
  status = CompareVectors({1.0f, 0.5f, 2.0f, 3.0f}, model.GetOutput(0), 1e-6f);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
|
||||
|
||||
- (void)testMinimum {
|
||||
OperationType op_type = OperationType::MINIMUM;
|
||||
const BHWC shape(1, 2, 2, 1);
|
||||
|
@ -1,83 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/mul.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "absl/types/variant.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace metal {
|
||||
std::vector<ComputeTaskDescriptorPtr> Multiply(int id, ValueId input_id,
|
||||
ValueId output_id,
|
||||
const MultiplyAttributes& attr,
|
||||
const RuntimeOptions& options) {
|
||||
auto desc = std::make_shared<ComputeTaskDescriptor>();
|
||||
desc->id = id;
|
||||
desc->is_linkable = true;
|
||||
auto multiplier = absl::get_if<float>(&attr.param);
|
||||
auto mul_buffer =
|
||||
absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
|
||||
const bool scalar = multiplier != nullptr;
|
||||
const std::string param_desc =
|
||||
scalar ? "float multiplier" : "device FLT4* const mul_buf";
|
||||
std::string code =
|
||||
"FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, ";
|
||||
code += param_desc + ") {\n";
|
||||
if (scalar) {
|
||||
code += "return value * multiplier;\n";
|
||||
} else {
|
||||
code += "return value * mul_buf[gid.z];\n";
|
||||
}
|
||||
code += "}\n";
|
||||
desc->shader_source = code;
|
||||
desc->input_buffers = {{input_id}};
|
||||
desc->output_buffer = {output_id};
|
||||
if (scalar) {
|
||||
std::vector<uint8_t> multiplier_bits =
|
||||
GetByteBuffer(std::vector<float>{*multiplier});
|
||||
desc->uniform_buffers = {
|
||||
{"constant float&",
|
||||
[multiplier_bits](const std::map<ValueId, BHWC>& buffers) {
|
||||
return multiplier_bits;
|
||||
}},
|
||||
};
|
||||
} else {
|
||||
desc->immutable_buffers = {
|
||||
{"device FLT4* const",
|
||||
GetByteBufferConverted(mul_buffer->data, options.storage_precision)},
|
||||
};
|
||||
}
|
||||
return {desc};
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
@ -1,37 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MUL_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MUL_H_
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace metal {
|
||||
|
||||
// Multiply operation, supports scalar and vector broadcast.
|
||||
std::vector<ComputeTaskDescriptorPtr> Multiply(int id, ValueId input_id,
|
||||
ValueId output_id,
|
||||
const MultiplyAttributes& attr,
|
||||
const RuntimeOptions& options);
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MUL_H_
|
@ -1,98 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/add.h"
|
||||
|
||||
#import <XCTest/XCTest.h>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
|
||||
|
||||
using ::tflite::gpu::DataType;
|
||||
using ::tflite::gpu::BHWC;
|
||||
using ::tflite::gpu::Linear;
|
||||
using ::tflite::gpu::MultiplyAttributes;
|
||||
using ::tflite::gpu::OperationType;
|
||||
using ::tflite::gpu::Tensor;
|
||||
using ::tflite::gpu::TensorRef;
|
||||
using ::tflite::gpu::metal::CompareVectors;
|
||||
using ::tflite::gpu::metal::SingleOpModel;
|
||||
|
||||
@interface MulTest : XCTestCase
|
||||
@end
|
||||
|
||||
@implementation MulTest
|
||||
- (void)setUp {
|
||||
[super setUp];
|
||||
}
|
||||
|
||||
// MUL with a scalar attribute: every input element is multiplied by 2.
- (void)testMulScalar {
  const BHWC io_shape(1, 2, 2, 1);

  TensorRef<BHWC> src;
  src.type = DataType::FLOAT32;
  src.ref = 0;
  src.shape = io_shape;

  TensorRef<BHWC> dst;
  dst.type = DataType::FLOAT32;
  dst.ref = 1;
  dst.shape = io_shape;

  MultiplyAttributes attributes;
  attributes.param = 2;

  SingleOpModel model({ToString(OperationType::MUL), attributes}, {src}, {dst});
  XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));

  auto status = model.Invoke();
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
  status = CompareVectors({2, 4, 6, 8}, model.GetOutput(0), 1e-6f);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
|
||||
|
||||
// MUL with a Linear tensor attribute: channel 0 is multiplied by 2 and
// channel 1 by 3.
- (void)testMulLinear {
  const BHWC io_shape(1, 1, 2, 2);

  TensorRef<BHWC> src;
  src.type = DataType::FLOAT32;
  src.ref = 0;
  src.shape = io_shape;

  TensorRef<BHWC> dst;
  dst.type = DataType::FLOAT32;
  dst.ref = 1;
  dst.shape = io_shape;

  Tensor<Linear, DataType::FLOAT32> per_channel_factors;
  per_channel_factors.shape.v = 2;
  per_channel_factors.id = 1;
  per_channel_factors.data = {2, 3};
  MultiplyAttributes attributes;
  attributes.param = std::move(per_channel_factors);

  SingleOpModel model({ToString(OperationType::MUL), attributes}, {src}, {dst});
  XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));

  auto status = model.Invoke();
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
  status = CompareVectors({2, 6, 6, 12}, model.GetOutput(0), 1e-6f);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
|
||||
|
||||
@end
|
Loading…
Reference in New Issue
Block a user