Reorganized the structure of the elementwise operations.

Added support for linking two-input elementwise operations.
Added broadcast parameters for the second input.
Removed ApplyMask (multiply + broadcast); its use cases are now handled by the generic two-input elementwise path.

PiperOrigin-RevId: 302545362
Change-Id: Icb9cb94aaad448a205dc7160f4b44820081d69ca
Raman Sarokin 2020-03-23 16:35:47 -07:00 committed by TensorFlower Gardener
parent 1bf7b2f0cf
commit ec5c9bebf8
10 changed files with 240 additions and 279 deletions
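
For orientation, a minimal self-contained C++ sketch of the broadcast scheme this change introduces. It assumes BHWC shapes as in the diff; DetectBroadcast and SecondIndex are hypothetical illustrative helpers that mirror the Is*BroadcastedForSecondInput predicates and the second_index expression emitted into the linkable shader, not functions added by this commit.

#include <cstdio>

struct BHWC {
  int b, h, w, c;
};

struct ElementwiseBroadcastSettings {
  bool width = false;
  bool height = false;
  bool channels = false;
};

// A dimension of the second input is broadcast when the two inputs disagree
// on that dimension and the second input's extent is 1.
ElementwiseBroadcastSettings DetectBroadcast(const BHWC& first,
                                             const BHWC& second) {
  ElementwiseBroadcastSettings s;
  s.width = first.w != second.w && second.w == 1;
  s.height = first.h != second.h && second.h == 1;
  s.channels = first.c != second.c && second.c == 1;
  return s;
}

// Linear index into the second tensor for element (x, y) of depth slice z of
// the first tensor; broadcast coordinates collapse to 0, matching the
// second_index computation generated by ElementwiseWithTwoInputs.
int SecondIndex(const ElementwiseBroadcastSettings& s, const BHWC& second,
                int x, int y, int z) {
  const int xc = s.width ? 0 : x;
  const int yc = s.height ? 0 : y;
  const int zc = s.channels ? 0 : z;
  return (zc * second.h + yc) * second.w + xc;
}

int main() {
  const BHWC first{1, 2, 2, 4};
  const BHWC second{1, 2, 2, 1};  // channels broadcast
  const ElementwiseBroadcastSettings s = DetectBroadcast(first, second);
  std::printf("broadcast w=%d h=%d c=%d, second_index=%d\n", s.width, s.height,
              s.channels, SecondIndex(s, second, 1, 1, 0));
  return 0;
}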

View File

@@ -51,6 +51,25 @@ namespace tflite {
namespace gpu {
namespace metal {
namespace {
bool IsWidthBroadcastedForSecondInput(
const std::vector<Value<TensorRef<BHWC>>*>& inputs) {
return inputs.size() == 2 &&
inputs[0]->tensor.shape.w != inputs[1]->tensor.shape.w &&
inputs[1]->tensor.shape.w == 1;
}
bool IsHeightBroadcastedForSecondInput(
const std::vector<Value<TensorRef<BHWC>>*>& inputs) {
return inputs.size() == 2 &&
inputs[0]->tensor.shape.h != inputs[1]->tensor.shape.h &&
inputs[1]->tensor.shape.h == 1;
}
bool IsChannelsBroadcastedForSecondInput(
const std::vector<Value<TensorRef<BHWC>>*>& inputs) {
return inputs.size() == 2 &&
inputs[0]->tensor.shape.c != inputs[1]->tensor.shape.c &&
inputs[1]->tensor.shape.c == 1;
}
std::vector<ComputeTaskDescriptorPtr> SelectDepthWiseConv(
int id, ValueId input_id, ValueId output_id,
const DepthwiseConvolution2DAttributes& attr,
@@ -134,11 +153,22 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
int node_id = static_cast<int>(node->id);
auto op_type = OperationTypeFromString(node->operation.type);
switch (op_type) {
case OperationType::ADD:
*tasks = Add(node_id, inputs, outputs[0],
absl::any_cast<AddAttributes>(node->operation.attributes),
options);
case OperationType::ADD: {
const auto srcs = graph.FindInputs(node_id);
ElementwiseBroadcastSettings broadcast;
broadcast.width = IsWidthBroadcastedForSecondInput(srcs);
broadcast.height = IsHeightBroadcastedForSecondInput(srcs);
broadcast.channels = IsChannelsBroadcastedForSecondInput(srcs);
if (broadcast.width || broadcast.height || broadcast.channels) {
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type,
broadcast);
} else {
*tasks = Add(node_id, inputs, outputs[0],
absl::any_cast<AddAttributes>(node->operation.attributes),
options);
}
break;
}
case OperationType::CONCAT: {
std::vector<BHWC> input_shapes;
for (auto& input : graph.FindInputs(node->id)) {
@@ -194,7 +224,18 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
absl::any_cast<MultiplyAttributes>(node->operation.attributes),
options);
} else {
*tasks = ApplyMask(node_id, inputs[0], inputs[1], outputs[0], options);
if (inputs.size() == 2) {
const auto srcs = graph.FindInputs(node_id);
ElementwiseBroadcastSettings broadcast;
broadcast.width = IsWidthBroadcastedForSecondInput(srcs);
broadcast.height = IsHeightBroadcastedForSecondInput(srcs);
broadcast.channels = IsChannelsBroadcastedForSecondInput(srcs);
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0],
op_type, broadcast);
} else {
return absl::UnimplementedError(
"No support of multiply with more than 2 inputs");
}
}
break;
case OperationType::PAD: {
@@ -269,8 +310,18 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
case OperationType::SUB: {
const ElementwiseAttributes* attr =
absl::any_cast<ElementwiseAttributes>(&node->operation.attributes);
*tasks =
ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type, attr);
if (attr) {
*tasks = ElementwiseWithOneInputAndConstantArguent(
node_id, inputs[0], outputs[0], options, op_type, *attr);
} else {
const auto srcs = graph.FindInputs(node_id);
ElementwiseBroadcastSettings broadcast;
broadcast.width = IsWidthBroadcastedForSecondInput(srcs);
broadcast.height = IsHeightBroadcastedForSecondInput(srcs);
broadcast.channels = IsChannelsBroadcastedForSecondInput(srcs);
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type,
broadcast);
}
} break;
case OperationType::BATCH_NORMALIZATION:
case OperationType::BATCH_TO_SPACE:

View File

@@ -180,10 +180,16 @@ void BuildFusableChains(const std::vector<ValueId>& input_ids,
bool fused = false;
for (auto& chain : *chains) {
// We can fuse only single output for now.
if (Contains(task_descriptor->input_buffers,
chain.back()->output_buffer.id) &&
CanFuseOperations(chain.back(), task_descriptor, output_ids,
*descriptors, chains)) {
bool can_link = false;
if (task_descriptor->is_associative_op) {
can_link = Contains(task_descriptor->input_buffers,
chain.back()->output_buffer.id);
} else {
can_link = task_descriptor->input_buffers[0].id ==
chain.back()->output_buffer.id;
}
if (can_link && CanFuseOperations(chain.back(), task_descriptor,
output_ids, *descriptors, chains)) {
chain.push_back(task_descriptor);
fused = true;
break;

View File

@@ -99,6 +99,10 @@ struct ComputeTaskDescriptor {
// $2
// output_buffer[linear_index] = value;
// }
// When the operation is associative, the input tensors can be rearranged;
// for example, ADD is associative.
bool is_associative_op = false;
std::string shader_source;
std::vector<InputBufferDescriptor> input_buffers;
// A single per-operation output is supported now.
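
The comment above motivates the new is_associative_op flag; below is a condensed, hypothetical restatement of the linking rule it enables in BuildFusableChains. CanLink, input_buffer_ids, and previous_output_id are illustrative stand-ins for the real check on task_descriptor->input_buffers and chain.back()->output_buffer.id.

#include <algorithm>
#include <vector>

bool CanLink(bool is_associative_op, const std::vector<int>& input_buffer_ids,
             int previous_output_id) {
  if (is_associative_op) {
    // Associative ops (e.g. ADD) may receive the chained value through any
    // of their inputs.
    return std::find(input_buffer_ids.begin(), input_buffer_ids.end(),
                     previous_output_id) != input_buffer_ids.end();
  }
  // Non-associative ops (e.g. SUB) substitute the chained value only for the
  // first operand, so the previous output must be input 0.
  return !input_buffer_ids.empty() &&
         input_buffer_ids[0] == previous_output_id;
}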

View File

@@ -86,6 +86,7 @@ std::vector<ComputeTaskDescriptorPtr> Add(int id,
}
desc->is_linkable = true;
desc->is_associative_op = true;
desc->shader_source = GetAddTableCodeFused(input_ids.size() - 1);
for (int i = 0; i < input_ids.size(); ++i) {

View File

@@ -19,6 +19,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
@@ -29,115 +30,93 @@ namespace metal {
namespace {
std::string GetElementwiseWithTwoInputsCode(int src_count,
OperationType op_type,
const float* scalar) {
std::string code = R"(
#include <metal_stdlib>
using namespace metal;
std::string OneInputFunctor(OperationType op_type, const std::string& value) {
const std::unordered_map<OperationType, std::string> functors{
{OperationType::ABS, "abs($0)"},
{OperationType::SIN, "sin($0)"},
{OperationType::HARD_SWISH,
"$0 * clamp($0 / 6.0f + FLT4(0.5f), FLT4(0.0f), FLT4(1.0f))"},
{OperationType::COS, "cos($0)"},
{OperationType::EXP, "exp($0)"},
{OperationType::LOG, "log($0)"},
{OperationType::SQRT, "sqrt($0)"},
{OperationType::RSQRT, "1.0 / sqrt($0)"},
{OperationType::SQUARE, "$0 * $0"},
{OperationType::SIGMOID, "1.0 / (1.0 + exp(-1.0 * $0))"},
{OperationType::TANH, "tanh($0)"},
};
struct uniforms {
int4 src_size;
};
$0
kernel void ComputeFunction(
$1
uint3 gid[[thread_position_in_grid]]) {
if (static_cast<int>(gid.x) >= params.src_size.x ||
static_cast<int>(gid.y) >= params.src_size.y) {
return;
}
int linear_index = (int(gid.z) * params.src_size.y + int(gid.y)) *
params.src_size.x + int(gid.x);
FLT4 src_0 = src_buffer0[linear_index];
)";
if (scalar == nullptr) {
code += " FLT4 src_1 = src_buffer1[linear_index];";
} else {
code += " FLT4 src_1 = FLT4(" + std::to_string(*scalar) + ");";
if (functors.find(op_type) == functors.end()) {
return "Error, unknown op";
}
switch (op_type) {
case OperationType::DIV: {
code += " FLT4 value = src_0 / src_1;";
break;
}
case OperationType::MAXIMUM: {
code += " FLT4 value = max(src_0, src_1);";
break;
}
case OperationType::MINIMUM: {
code += " FLT4 value = min(src_0, src_1);";
break;
}
case OperationType::POW: {
code += " FLT4 value = pow(src_0, src_1);";
break;
}
case OperationType::SQUARED_DIFF: {
code += " FLT4 value = (src_0 - src_1) * (src_0 - src_1);";
break;
}
case OperationType::SUB: {
code += " FLT4 value = src_0 - src_1;";
break;
}
default: {
return "";
}
}
code += R"(
$2
dst_buffer[linear_index] = value;
})";
return code;
return absl::Substitute(functors.at(op_type), value);
}
std::string TwoInputFunctor(OperationType op_type, const std::string& value0,
const std::string& value1) {
const std::unordered_map<OperationType, std::string> functors{
{OperationType::ADD, "$0 + $1"},
{OperationType::DIV, "$0 / $1"},
{OperationType::MAXIMUM, "max($0, $1)"},
{OperationType::MINIMUM, "min($0, $1)"},
{OperationType::MUL, "$0 * $1"},
{OperationType::POW, "pow($0, $1)"},
{OperationType::SQUARED_DIFF, "($0 - $1) * ($0 - $1)"},
{OperationType::SUB, "$0 - $1"},
};
if (functors.find(op_type) == functors.end()) {
return "Error, unknown op";
}
return absl::Substitute(functors.at(op_type), value0, value1);
}
} // namespace
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithTwoInputs(
int id, std::vector<ValueId> input_ids, ValueId output_id,
OperationType op_type, const ElementwiseAttributes* attr) {
const float* scalar = nullptr;
if (attr) {
scalar = absl::get_if<float>(&attr->param);
}
OperationType op_type, const ElementwiseBroadcastSettings& settings) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = false;
desc->shader_source =
GetElementwiseWithTwoInputsCode(input_ids.size(), op_type, scalar);
for (int i = 0; i < input_ids.size(); ++i) {
const std::string buffer_name =
"device FLT4* const src_buffer" + std::to_string(i);
desc->input_buffers.push_back({input_ids[i], buffer_name});
desc->is_linkable = true;
const std::string x_coord = settings.width ? "0" : "int(gid.x)";
const std::string y_coord = settings.height ? "0" : "int(gid.y)";
const std::string s_coord = settings.channels ? "0" : "int(gid.z)";
std::string code =
"FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, device FLT4* "
"const second_tensor, int2 second_size) {\n";
code += " int second_index = (" + s_coord + " * second_size.y + " + y_coord +
") * second_size.x + " + x_coord + ";\n";
code += " FLT4 src_1 = second_tensor[second_index];\n";
if (settings.channels) {
code += " src_1.y = src_1.x;\n";
code += " src_1.z = src_1.x;\n";
code += " src_1.w = src_1.x;\n";
}
code += " return " + TwoInputFunctor(op_type, "value", "src_1") + ";\n";
code += "}\n";
desc->output_buffer = {output_id, "device FLT4* dst_buffer",
[input_ids](const std::map<ValueId, BHWC>& buffers) {
return buffers.find(input_ids[0])->second;
}};
desc->shader_source = code;
desc->input_buffers = {
{input_ids[0], "device FLT4* const"},
{input_ids[1], "device FLT4* const"},
};
desc->output_buffer = {output_id};
desc->uniform_buffers = {
{"constant uniforms& params",
[input_ids](const std::map<ValueId, BHWC>& buffers) {
const auto& dimension = buffers.find(input_ids[0])->second;
std::vector<int> uniform_params = {dimension.w, dimension.h, 0, 0};
{"constant int2&",
[input_ids, output_id](const std::map<ValueId, BHWC>& buffers) {
const auto& input_dim_1 = buffers.find(input_ids[1])->second;
std::vector<int> uniform_params{
input_dim_1.w,
input_dim_1.h,
};
return GetByteBuffer(uniform_params);
}},
};
desc->resize_function = [input_ids](const std::map<ValueId, BHWC>& buffers) {
const auto& src_dim = buffers.find(input_ids[0])->second;
const uint3 groups_size{16, 16, 1};
int groups_x = IntegralDivideRoundUp(src_dim.w, groups_size.x);
int groups_y = IntegralDivideRoundUp(src_dim.h, groups_size.y);
const int dst_layers = IntegralDivideRoundUp(src_dim.c, 4);
int groups_z = IntegralDivideRoundUp(dst_layers, groups_size.z);
return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z});
};
return {desc};
}
@@ -146,29 +125,10 @@ std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInput(
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = true;
const std::unordered_map<OperationType, std::string> functors{
{OperationType::ABS, "abs(value)"},
{OperationType::SIN, "sin(value)"},
{OperationType::HARD_SWISH,
"value * clamp(value / 6.0f + FLT4(0.5f), FLT4(0.0f), FLT4(1.0f))"},
{OperationType::COS, "cos(value)"},
{OperationType::EXP, "exp(value)"},
{OperationType::LOG, "log(value)"},
{OperationType::SQRT, "sqrt(value)"},
{OperationType::RSQRT, "1.0 / sqrt(value)"},
{OperationType::SQUARE, "value * value"},
{OperationType::SIGMOID, "1.0 / (1.0 + exp(-1.0 * value))"},
{OperationType::TANH, "tanh(value)"},
};
if (functors.count(op_type) == 0) {
return {};
}
desc->shader_source =
"FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid) {\n";
desc->shader_source += " return " + functors.at(op_type) + ";\n";
desc->shader_source +=
" return " + OneInputFunctor(op_type, "value") + ";\n";
desc->shader_source += " }";
desc->input_buffers = {{input_id}};
@@ -176,6 +136,54 @@ std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInput(
return {desc};
}
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInputAndConstantArguent(
int id, ValueId input_id, ValueId output_id, const RuntimeOptions& options,
OperationType op_type, const ElementwiseAttributes& attr) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = true;
auto scalar = absl::get_if<float>(&attr.param);
auto linear_buf =
absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
std::string param_desc;
if (scalar) {
param_desc += ", float scalar_val";
}
if (linear_buf) {
param_desc += ", device FLT4* const linear_buf";
}
desc->shader_source =
"FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid" + param_desc +
") {\n";
if (scalar) {
desc->shader_source += " FLT4 second_arg = FLT4(scalar_val);\n";
} else if (linear_buf) {
desc->shader_source += " FLT4 second_arg = linear_buf[gid.z];\n";
}
desc->shader_source +=
" return " + TwoInputFunctor(op_type, "value", "second_arg") + ";\n";
desc->shader_source += " }";
desc->input_buffers = {{input_id}};
desc->output_buffer = {output_id};
if (scalar) {
std::vector<uint8_t> scalar_bits =
GetByteBuffer(std::vector<float>{*scalar});
desc->uniform_buffers = {
{"constant float&",
[scalar_bits](const std::map<ValueId, BHWC>& buffers) {
return scalar_bits;
}},
};
} else if (linear_buf) {
desc->immutable_buffers = {
{"device FLT4* const",
GetByteBufferConverted(linear_buf->data, options.storage_precision)},
};
}
return {desc};
}
} // namespace metal
} // namespace gpu
} // namespace tflite

View File

@@ -25,13 +25,26 @@ namespace tflite {
namespace gpu {
namespace metal {
struct ElementwiseBroadcastSettings {
bool width = false;
bool height = false;
bool channels = false;
};
// Two inputs are two runtime tensors
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithTwoInputs(
int id, std::vector<ValueId> input_ids, ValueId output_id,
OperationType op_type, const ElementwiseAttributes* attr);
OperationType op_type, const ElementwiseBroadcastSettings& settings);
// One input is one runtime tensor
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInput(
int id, ValueId input_id, ValueId output_id, OperationType op_type);
// The first input is a runtime tensor and the second input is a constant argument
std::vector<ComputeTaskDescriptorPtr> ElementwiseWithOneInputAndConstantArguent(
int id, ValueId input_id, ValueId output_id, const RuntimeOptions& options,
OperationType op_type, const ElementwiseAttributes& attr);
} // namespace metal
} // namespace gpu
} // namespace tflite

View File

@@ -94,7 +94,7 @@ TensorRef<BHWC> GetTensorRef(int ref, const BHWC& shape) {
- (void)testExp {
OperationType op_type = OperationType::EXP;
const BHWC shape(1, 1, 1, 5);
const BHWC shape(1, 1, 1, 7);
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape)},
/*outputs=*/{GetTensorRef(1, shape)});
@@ -312,4 +312,34 @@ TensorRef<BHWC> GetTensorRef(int ref, const BHWC& shape) {
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testMulBroadcastChannels {
OperationType op_type = OperationType::MUL;
const BHWC shape(1, 1, 2, 2);
const BHWC shape_2(1, 1, 2, 1);
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape_2)},
/*outputs=*/{GetTensorRef(2, shape)});
XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 3.0, 4.0}));
XCTAssertTrue(model.PopulateTensor(1, {2.0, 3.0}));
auto status = model.Invoke();
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
status = CompareVectors({2.0, 4.0, 9.0, 12.0}, model.GetOutput(0), 1e-6f);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testMulBroadcastWidthAndHeight {
OperationType op_type = OperationType::MUL;
const BHWC shape(1, 1, 2, 2);
const BHWC shape_2(1, 1, 1, 2);
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape_2)},
/*outputs=*/{GetTensorRef(2, shape)});
XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 3.0, 4.0}));
XCTAssertTrue(model.PopulateTensor(1, {2.0, 3.0}));
auto status = model.Invoke();
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
status = CompareVectors({2.0, 6.0, 6.0, 12.0}, model.GetOutput(0), 1e-6f);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
@end
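
A quick check of the two new MUL tests above: in testMulBroadcastChannels the second tensor has shape (1, 1, 2, 1), so its per-width values {2, 3} multiply every channel of the first input, giving (1*2, 2*2, 3*3, 4*3) = {2, 4, 9, 12}; in testMulBroadcastWidthAndHeight the second tensor has shape (1, 1, 1, 2), so its per-channel values {2, 3} apply at every spatial position, giving (1*2, 2*3, 3*2, 4*3) = {2, 6, 6, 12}.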

View File

@@ -35,102 +35,6 @@ limitations under the License.
namespace tflite {
namespace gpu {
namespace metal {
namespace {
std::string GetApplyMaskCode() {
std::string shader_source = R"(
#include <metal_stdlib>
using namespace metal;
struct uniforms {
int4 src_0_size;
int4 src_1_size;
int4 dst_size;
};
$0
kernel void ComputeFunction(
$1
uint3 gid[[thread_position_in_grid]]) {
int X = static_cast<int>(gid.x);
int Y = static_cast<int>(gid.y);
if (X >= params.dst_size.x || Y >= params.dst_size.y) {
return;
}
int src_0_index = (gid.z * params.src_0_size.y + static_cast<int>(gid.y)) *
params.src_0_size.x + static_cast<int>(gid.x);
int src_1_index = 0;
if (params.dst_size.z == 1) {
// [H, W, C] x [H, W, 0][0]
src_1_index = static_cast<int>(gid.y) * params.src_1_size.x +
static_cast<int>(gid.x);
} else if (params.src_0_size.y == params.src_1_size.y &&
params.src_0_size.x == params.src_1_size.x) {
// [H, W, C] x [H, W, C]
src_1_index = src_0_index;
} else {
// [H, W, C] x [0, 0, C]
src_1_index = gid.z * params.src_1_size.y * params.src_1_size.x ;
}
FLT4 value = src_buffer_0[src_0_index] * src_buffer_1[src_1_index];
int linear_index = (gid.z * params.dst_size.y + static_cast<int>(gid.y)) *
params.dst_size.x + static_cast<int>(gid.x);
$2
dst_buffer[linear_index] = value;
}
)";
return shader_source;
}
} // namespace
std::vector<ComputeTaskDescriptorPtr> ApplyMask(int id, ValueId input_id_0,
ValueId input_id_1,
ValueId output_id,
const RuntimeOptions& options) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = false;
desc->shader_source = GetApplyMaskCode();
desc->input_buffers = {
{input_id_0, "device FLT4* const src_buffer_0"}, // data
{input_id_1, "device FLT4* const src_buffer_1"}, // mask
};
desc->output_buffer = {
output_id, "device FLT4* dst_buffer",
[input_id_0, input_id_1](const std::map<ValueId, BHWC>& buffers) {
return buffers.find(input_id_0)->second;
}};
desc->uniform_buffers = {
{"constant uniforms& params",
[input_id_0, input_id_1,
output_id](const std::map<ValueId, BHWC>& buffers) {
const auto& input_dim_0 = buffers.find(input_id_0)->second;
const auto& input_dim_1 = buffers.find(input_id_1)->second;
const auto& output_dim = buffers.find(output_id)->second;
std::vector<int> uniform_params{
input_dim_0.w, input_dim_0.h, input_dim_0.c, 0,
input_dim_1.w, input_dim_1.h, input_dim_1.c, 0,
output_dim.w, output_dim.h, output_dim.c, 0,
};
return GetByteBuffer(uniform_params);
}},
};
desc->resize_function = [input_id_0,
input_id_1](const std::map<ValueId, BHWC>& buffers) {
const auto& src_shape = buffers.find(input_id_0)->second;
const uint3 groups_size{16, 16, 1};
int groups_x = IntegralDivideRoundUp(src_shape.w, groups_size.x);
int groups_y = IntegralDivideRoundUp(src_shape.h, groups_size.y);
int groups_z = IntegralDivideRoundUp(src_shape.c, 4);
return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z});
};
return {desc};
}
std::vector<ComputeTaskDescriptorPtr> Multiply(int id, ValueId input_id,
ValueId output_id,
const MultiplyAttributes& attr,

View File

@@ -30,11 +30,6 @@ std::vector<ComputeTaskDescriptorPtr> Multiply(int id, ValueId input_id,
ValueId output_id,
const MultiplyAttributes& attr,
const RuntimeOptions& options);
std::vector<ComputeTaskDescriptorPtr> ApplyMask(int id, ValueId input_id_0,
ValueId input_id_1,
ValueId output_id,
const RuntimeOptions& options);
} // namespace metal
} // namespace gpu
} // namespace tflite

View File

@@ -95,55 +95,4 @@ using ::tflite::gpu::metal::SingleOpModel;
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testApplyMaskChannel1 {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 1, 2, 2);
TensorRef<BHWC> mask;
mask.type = DataType::FLOAT32;
mask.ref = 1;
mask.shape = BHWC(1, 1, 2, 1);
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 2;
output.shape = BHWC(1, 1, 2, 2);
SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
XCTAssertTrue(model.PopulateTensor(1, {2, 3}));
auto status = model.Invoke();
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
status = CompareVectors({2, 4, 9, 12}, model.GetOutput(0), 1e-6f);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testApplyMaskEqualsToInputChannel {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 1, 2, 2);
TensorRef<BHWC> mask;
mask.type = DataType::FLOAT32;
mask.ref = 1;
mask.shape = BHWC(1, 1, 2, 2);
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 2;
output.shape = BHWC(1, 1, 2, 2);
SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
XCTAssertTrue(model.PopulateTensor(1, {1, 2, 3, 4}));
auto status = model.Invoke();
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
status = CompareVectors({1, 4, 9, 16}, model.GetOutput(0), 1e-6f);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
@end