Add Maximum & Minimum op support for Metal
PiperOrigin-RevId: 296149175 Change-Id: I3d26f756cb8f5fe0d94fac3f8515da8b2124dcc4
This commit is contained in:
parent
76562fef92
commit
6343b77f13
@ -271,9 +271,12 @@ Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
|
||||
case OperationType::MINIMUM:
|
||||
case OperationType::POW:
|
||||
case OperationType::SQUARED_DIFF:
|
||||
case OperationType::SUB:
|
||||
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type);
|
||||
break;
|
||||
case OperationType::SUB: {
|
||||
const ElementwiseAttributes* attr =
|
||||
absl::any_cast<ElementwiseAttributes>(&node->operation.attributes);
|
||||
*tasks =
|
||||
ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type, attr);
|
||||
} break;
|
||||
case OperationType::BATCH_NORMALIZATION:
|
||||
case OperationType::BATCH_TO_SPACE:
|
||||
case OperationType::CONST:
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
@ -29,7 +30,8 @@ namespace metal {
|
||||
namespace {
|
||||
|
||||
std::string GetElementwiseWithTwoInputsCode(int src_count,
|
||||
OperationType op_type) {
|
||||
OperationType op_type,
|
||||
const float* scalar) {
|
||||
std::string code = R"(
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
@ -49,33 +51,38 @@ std::string GetElementwiseWithTwoInputsCode(int src_count,
|
||||
|
||||
int linear_index = (int(gid.z) * params.src_size.y + int(gid.y)) *
|
||||
params.src_size.x + int(gid.x);
|
||||
)";
|
||||
FLT4 src_0 = src_buffer0[linear_index];
|
||||
)";
|
||||
|
||||
if (scalar == nullptr) {
|
||||
code += " FLT4 src_1 = src_buffer1[linear_index];";
|
||||
} else {
|
||||
code +=
|
||||
absl::StrCat(" FLT4 src_1 = FLT4(", std::to_string(*scalar), ");");
|
||||
}
|
||||
switch (op_type) {
|
||||
case OperationType::DIV: {
|
||||
code +=
|
||||
" FLT4 value = src_buffer0[linear_index] / "
|
||||
"src_buffer1[linear_index];";
|
||||
code += " FLT4 value = src_0 / src_1;";
|
||||
break;
|
||||
}
|
||||
case OperationType::MAXIMUM: {
|
||||
code += " FLT4 value = max(src_0, src_1);";
|
||||
break;
|
||||
}
|
||||
case OperationType::MINIMUM: {
|
||||
code += " FLT4 value = min(src_0, src_1);";
|
||||
break;
|
||||
}
|
||||
case OperationType::POW: {
|
||||
code +=
|
||||
" FLT4 value = pow(src_buffer0[linear_index], "
|
||||
"src_buffer1[linear_index]);";
|
||||
code += " FLT4 value = pow(src_0, src_1);";
|
||||
break;
|
||||
}
|
||||
case OperationType::SQUARED_DIFF: {
|
||||
code += R"(
|
||||
FLT4 src_0 = src_buffer0[linear_index];
|
||||
FLT4 src_1 = src_buffer1[linear_index];
|
||||
FLT4 value = (src_0 - src_1) * (src_0 - src_1);
|
||||
)";
|
||||
code += " FLT4 value = (src_0 - src_1) * (src_0 - src_1);";
|
||||
break;
|
||||
}
|
||||
case OperationType::SUB: {
|
||||
code +=
|
||||
" FLT4 value = src_buffer0[linear_index] - "
|
||||
"src_buffer1[linear_index];";
|
||||
code += " FLT4 value = src_0 - src_1;";
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
@ -92,12 +99,16 @@ std::string GetElementwiseWithTwoInputsCode(int src_count,
|
||||
|
||||
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithTwoInputs(
|
||||
int id, std::vector<ValueId> input_ids, ValueId output_id,
|
||||
OperationType op_type) {
|
||||
OperationType op_type, const ElementwiseAttributes* attr) {
|
||||
const float* scalar = nullptr;
|
||||
if (attr) {
|
||||
scalar = absl::get_if<float>(&attr->param);
|
||||
}
|
||||
auto desc = std::make_shared<ComputeTaskDescriptor>();
|
||||
desc->id = id;
|
||||
desc->is_linkable = false;
|
||||
desc->shader_source =
|
||||
GetElementwiseWithTwoInputsCode(input_ids.size(), op_type);
|
||||
GetElementwiseWithTwoInputsCode(input_ids.size(), op_type, scalar);
|
||||
|
||||
for (int i = 0; i < input_ids.size(); ++i) {
|
||||
const std::string buffer_name =
|
||||
|
@ -27,7 +27,7 @@ namespace metal {
|
||||
|
||||
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithTwoInputs(
|
||||
int id, std::vector<ValueId> input_ids, ValueId output_id,
|
||||
OperationType op_type);
|
||||
OperationType op_type, const ElementwiseAttributes* attr);
|
||||
|
||||
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInput(
|
||||
int id, ValueId input_id, ValueId output_id, OperationType op_type);
|
||||
|
@ -118,6 +118,64 @@ TensorRef<BHWC> GetTensorRef(int ref, const BHWC& shape) {
|
||||
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
|
||||
}
|
||||
|
||||
- (void)testMaximum {
|
||||
OperationType op_type = OperationType::MAXIMUM;
|
||||
const BHWC shape(1, 2, 2, 1);
|
||||
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
||||
/*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape)},
|
||||
/*outputs=*/{GetTensorRef(2, shape)});
|
||||
XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.2, 2.0, -3.0}));
|
||||
XCTAssertTrue(model.PopulateTensor(1, {1.0, 2.0, 3.0, -2.0}));
|
||||
auto status = model.Invoke();
|
||||
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
|
||||
status = CompareVectors({1.0, 2.0, 3.0, -2.0}, model.GetOutput(0), 1e-6f);
|
||||
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
|
||||
}
|
||||
|
||||
- (void)testMaximumWithScalar {
|
||||
OperationType op_type = OperationType::MAXIMUM;
|
||||
const BHWC shape(1, 2, 2, 1);
|
||||
tflite::gpu::ElementwiseAttributes attr;
|
||||
attr.param = -1.0f;
|
||||
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/attr},
|
||||
/*inputs=*/{GetTensorRef(0, shape)},
|
||||
/*outputs=*/{GetTensorRef(1, shape)});
|
||||
XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.2, 2.0, -3.0}));
|
||||
auto status = model.Invoke();
|
||||
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
|
||||
status = CompareVectors({0.0, -1.0, 2.0, -1.0}, model.GetOutput(0), 1e-6f);
|
||||
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
|
||||
}
|
||||
|
||||
- (void)testMinimum {
|
||||
OperationType op_type = OperationType::MINIMUM;
|
||||
const BHWC shape(1, 2, 2, 1);
|
||||
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
||||
/*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape)},
|
||||
/*outputs=*/{GetTensorRef(2, shape)});
|
||||
XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.2, 2.0, -3.0}));
|
||||
XCTAssertTrue(model.PopulateTensor(1, {1.0, 2.0, 3.0, -2.0}));
|
||||
auto status = model.Invoke();
|
||||
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
|
||||
status = CompareVectors({0.0, -6.2, 2.0, -3.0}, model.GetOutput(0), 1e-6f);
|
||||
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
|
||||
}
|
||||
|
||||
- (void)testMinimumWithScalar {
|
||||
OperationType op_type = OperationType::MINIMUM;
|
||||
const BHWC shape(1, 2, 2, 1);
|
||||
tflite::gpu::ElementwiseAttributes attr;
|
||||
attr.param = -1.0f;
|
||||
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/attr},
|
||||
/*inputs=*/{GetTensorRef(0, shape)},
|
||||
/*outputs=*/{GetTensorRef(1, shape)});
|
||||
XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.2, 2.0, -3.0}));
|
||||
auto status = model.Invoke();
|
||||
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
|
||||
status = CompareVectors({-1.0, -6.2, -1.0, -3.0}, model.GetOutput(0), 1e-6f);
|
||||
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
|
||||
}
|
||||
|
||||
- (void)testPow {
|
||||
OperationType op_type = OperationType::POW;
|
||||
const BHWC shape(1, 2, 2, 1);
|
||||
|
Loading…
x
Reference in New Issue
Block a user