Add Maximum & Minimum op support for Metal

PiperOrigin-RevId: 296149175
Change-Id: I3d26f756cb8f5fe0d94fac3f8515da8b2124dcc4
This commit is contained in:
Terry Heo 2020-02-20 00:47:06 -08:00 committed by TensorFlower Gardener
parent 76562fef92
commit 6343b77f13
4 changed files with 94 additions and 22 deletions

View File

@ -271,9 +271,12 @@ Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
case OperationType::MINIMUM:
case OperationType::POW:
case OperationType::SQUARED_DIFF:
case OperationType::SUB:
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type);
break;
case OperationType::SUB: {
const ElementwiseAttributes* attr =
absl::any_cast<ElementwiseAttributes>(&node->operation.attributes);
*tasks =
ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type, attr);
} break;
case OperationType::BATCH_NORMALIZATION:
case OperationType::BATCH_TO_SPACE:
case OperationType::CONST:

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h"
#include <cstddef>
#include <unordered_map>
#include <vector>
@ -29,7 +30,8 @@ namespace metal {
namespace {
std::string GetElementwiseWithTwoInputsCode(int src_count,
OperationType op_type) {
OperationType op_type,
const float* scalar) {
std::string code = R"(
#include <metal_stdlib>
using namespace metal;
@ -49,33 +51,38 @@ std::string GetElementwiseWithTwoInputsCode(int src_count,
int linear_index = (int(gid.z) * params.src_size.y + int(gid.y)) *
params.src_size.x + int(gid.x);
)";
FLT4 src_0 = src_buffer0[linear_index];
)";
if (scalar == nullptr) {
code += " FLT4 src_1 = src_buffer1[linear_index];";
} else {
code +=
absl::StrCat(" FLT4 src_1 = FLT4(", std::to_string(*scalar), ");");
}
switch (op_type) {
case OperationType::DIV: {
code +=
" FLT4 value = src_buffer0[linear_index] / "
"src_buffer1[linear_index];";
code += " FLT4 value = src_0 / src_1;";
break;
}
case OperationType::MAXIMUM: {
code += " FLT4 value = max(src_0, src_1);";
break;
}
case OperationType::MINIMUM: {
code += " FLT4 value = min(src_0, src_1);";
break;
}
case OperationType::POW: {
code +=
" FLT4 value = pow(src_buffer0[linear_index], "
"src_buffer1[linear_index]);";
code += " FLT4 value = pow(src_0, src_1);";
break;
}
case OperationType::SQUARED_DIFF: {
code += R"(
FLT4 src_0 = src_buffer0[linear_index];
FLT4 src_1 = src_buffer1[linear_index];
FLT4 value = (src_0 - src_1) * (src_0 - src_1);
)";
code += " FLT4 value = (src_0 - src_1) * (src_0 - src_1);";
break;
}
case OperationType::SUB: {
code +=
" FLT4 value = src_buffer0[linear_index] - "
"src_buffer1[linear_index];";
code += " FLT4 value = src_0 - src_1;";
break;
}
default: {
@ -92,12 +99,16 @@ std::string GetElementwiseWithTwoInputsCode(int src_count,
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithTwoInputs(
int id, std::vector<ValueId> input_ids, ValueId output_id,
OperationType op_type) {
OperationType op_type, const ElementwiseAttributes* attr) {
const float* scalar = nullptr;
if (attr) {
scalar = absl::get_if<float>(&attr->param);
}
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = false;
desc->shader_source =
GetElementwiseWithTwoInputsCode(input_ids.size(), op_type);
GetElementwiseWithTwoInputsCode(input_ids.size(), op_type, scalar);
for (int i = 0; i < input_ids.size(); ++i) {
const std::string buffer_name =

View File

@ -27,7 +27,7 @@ namespace metal {
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithTwoInputs(
int id, std::vector<ValueId> input_ids, ValueId output_id,
OperationType op_type);
OperationType op_type, const ElementwiseAttributes* attr);
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInput(
int id, ValueId input_id, ValueId output_id, OperationType op_type);

View File

@ -118,6 +118,64 @@ TensorRef<BHWC> GetTensorRef(int ref, const BHWC& shape) {
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
}
// MAXIMUM over two tensor inputs of the same shape: each output element must
// be the larger of the two corresponding input elements.
- (void)testMaximum {
  const BHWC shape(1, 2, 2, 1);
  // Tensor refs 0 and 1 feed the op; ref 2 receives the result.
  SingleOpModel model(
      {/*type=*/ToString(OperationType::MAXIMUM), /*attributes=*/{}},
      /*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape)},
      /*outputs=*/{GetTensorRef(2, shape)});
  XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.2, 2.0, -3.0}));
  XCTAssertTrue(model.PopulateTensor(1, {1.0, 2.0, 3.0, -2.0}));

  auto invoke_status = model.Invoke();
  XCTAssertTrue(invoke_status.ok(), @"%s",
                invoke_status.error_message().c_str());

  // The second input dominates at every position for this data set.
  auto compare_status =
      CompareVectors({1.0, 2.0, 3.0, -2.0}, model.GetOutput(0), 1e-6f);
  XCTAssertTrue(compare_status.ok(), @"%s",
                compare_status.error_message().c_str());
}
// MAXIMUM with a single tensor input and a scalar operand carried in
// ElementwiseAttributes::param: every element below the scalar is raised to it.
- (void)testMaximumWithScalar {
  const BHWC shape(1, 2, 2, 1);
  // The scalar (-1) takes the place of a second input tensor.
  tflite::gpu::ElementwiseAttributes scalar_attr;
  scalar_attr.param = -1.0f;
  SingleOpModel model(
      {/*type=*/ToString(OperationType::MAXIMUM), /*attributes=*/scalar_attr},
      /*inputs=*/{GetTensorRef(0, shape)},
      /*outputs=*/{GetTensorRef(1, shape)});
  XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.2, 2.0, -3.0}));

  auto invoke_status = model.Invoke();
  XCTAssertTrue(invoke_status.ok(), @"%s",
                invoke_status.error_message().c_str());

  // -6.2 and -3.0 are smaller than -1 and therefore replaced by it.
  auto compare_status =
      CompareVectors({0.0, -1.0, 2.0, -1.0}, model.GetOutput(0), 1e-6f);
  XCTAssertTrue(compare_status.ok(), @"%s",
                compare_status.error_message().c_str());
}
// MINIMUM over two tensor inputs of the same shape: each output element must
// be the smaller of the two corresponding input elements.
- (void)testMinimum {
  const BHWC shape(1, 2, 2, 1);
  // Tensor refs 0 and 1 feed the op; ref 2 receives the result.
  SingleOpModel model(
      {/*type=*/ToString(OperationType::MINIMUM), /*attributes=*/{}},
      /*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape)},
      /*outputs=*/{GetTensorRef(2, shape)});
  XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.2, 2.0, -3.0}));
  XCTAssertTrue(model.PopulateTensor(1, {1.0, 2.0, 3.0, -2.0}));

  auto invoke_status = model.Invoke();
  XCTAssertTrue(invoke_status.ok(), @"%s",
                invoke_status.error_message().c_str());

  // The first input is the smaller one at every position for this data set.
  auto compare_status =
      CompareVectors({0.0, -6.2, 2.0, -3.0}, model.GetOutput(0), 1e-6f);
  XCTAssertTrue(compare_status.ok(), @"%s",
                compare_status.error_message().c_str());
}
// MINIMUM with a single tensor input and a scalar operand carried in
// ElementwiseAttributes::param: every element above the scalar is lowered to it.
- (void)testMinimumWithScalar {
  const BHWC shape(1, 2, 2, 1);
  // The scalar (-1) takes the place of a second input tensor.
  tflite::gpu::ElementwiseAttributes scalar_attr;
  scalar_attr.param = -1.0f;
  SingleOpModel model(
      {/*type=*/ToString(OperationType::MINIMUM), /*attributes=*/scalar_attr},
      /*inputs=*/{GetTensorRef(0, shape)},
      /*outputs=*/{GetTensorRef(1, shape)});
  XCTAssertTrue(model.PopulateTensor(0, {0.0, -6.2, 2.0, -3.0}));

  auto invoke_status = model.Invoke();
  XCTAssertTrue(invoke_status.ok(), @"%s",
                invoke_status.error_message().c_str());

  // 0.0 and 2.0 are larger than -1 and therefore replaced by it.
  auto compare_status =
      CompareVectors({-1.0, -6.2, -1.0, -3.0}, model.GetOutput(0), 1e-6f);
  XCTAssertTrue(compare_status.ok(), @"%s",
                compare_status.error_message().c_str());
}
- (void)testPow {
OperationType op_type = OperationType::POW;
const BHWC shape(1, 2, 2, 1);