Metal: 'Add' operation supports scalar.
PiperOrigin-RevId: 258254040
This commit is contained in:
parent
7c33772967
commit
b4f842384a
@ -134,7 +134,9 @@ Status Compile(const GraphFloat32& graph, const RuntimeOptions& options,
|
||||
auto op_type = OperationTypeFromString(node->operation.type);
|
||||
switch (op_type) {
|
||||
case OperationType::ADD:
|
||||
tasks = AddTable(node_id, inputs, outputs[0]);
|
||||
tasks = Add(node_id, inputs, outputs[0],
|
||||
absl::any_cast<AddAttributes>(node->operation.attributes),
|
||||
options);
|
||||
break;
|
||||
case OperationType::CONCAT: {
|
||||
std::vector<BHWC> input_shapes;
|
||||
|
@ -66,39 +66,27 @@ std::string GetAddTableCode(int src_count) {
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::vector<ComputeTaskDescriptorPtr> Add(int id, ValueId input_id,
|
||||
std::vector<ComputeTaskDescriptorPtr> Add(int id,
|
||||
const std::vector<ValueId> input_ids,
|
||||
ValueId output_id,
|
||||
const AddAttributes& attr,
|
||||
const RuntimeOptions& options) {
|
||||
auto add_buffer =
|
||||
absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
|
||||
if (!add_buffer) {
|
||||
return {};
|
||||
}
|
||||
auto desc = std::make_shared<ComputeTaskDescriptor>();
|
||||
desc->id = id;
|
||||
desc->is_linkable = true;
|
||||
desc->shader_source =
|
||||
R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid,
|
||||
device FLT4* const add_buf) {
|
||||
return value + add_buf[gid.z];
|
||||
})";
|
||||
desc->input_buffers = {{input_id}};
|
||||
desc->output_buffer = {output_id};
|
||||
auto coeffs = options.storage_precision == RuntimeOptions::Precision::FP32
|
||||
? VectorToUint8Vector(add_buffer->data)
|
||||
: VectorFloatToHalf(add_buffer->data);
|
||||
desc->immutable_buffers = {
|
||||
{"device FLT4* const", coeffs},
|
||||
};
|
||||
return {desc};
|
||||
}
|
||||
|
||||
std::vector<ComputeTaskDescriptorPtr> AddTable(int id,
|
||||
std::vector<ValueId> input_ids,
|
||||
ValueId output_id) {
|
||||
auto desc = std::make_shared<ComputeTaskDescriptor>();
|
||||
desc->id = id;
|
||||
// Add scalar
|
||||
const float* add_value = absl::get_if<float>(&attr.param);
|
||||
if (add_value) {
|
||||
desc->is_linkable = true;
|
||||
desc->shader_source =
|
||||
R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid) {
|
||||
return value + )" +
|
||||
std::to_string(*add_value) + ";}";
|
||||
desc->input_buffers = {{input_ids[0]}};
|
||||
desc->output_buffer = {output_id};
|
||||
return {desc};
|
||||
}
|
||||
|
||||
desc->is_linkable = false;
|
||||
desc->shader_source = GetAddTableCode(input_ids.size());
|
||||
|
||||
|
@ -28,16 +28,12 @@ namespace gpu {
|
||||
namespace metal {
|
||||
|
||||
// Add with broadcast.
|
||||
std::vector<ComputeTaskDescriptorPtr> Add(int id, ValueId input_id,
|
||||
std::vector<ComputeTaskDescriptorPtr> Add(int id,
|
||||
const std::vector<ValueId> input_ids,
|
||||
ValueId output_id,
|
||||
const AddAttributes& attr,
|
||||
const RuntimeOptions& options);
|
||||
|
||||
// Add tensors.
|
||||
std::vector<ComputeTaskDescriptorPtr> AddTable(int id,
|
||||
std::vector<ValueId> input_ids,
|
||||
ValueId output_id);
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
@ -68,4 +68,24 @@ using ::tflite::gpu::OperationType;
|
||||
XCTAssertTrue(status.ok(), @"%s", status.ToString().c_str());
|
||||
}
|
||||
|
||||
// Runs a single ADD operation whose attribute carries the scalar 0.1f on one
// FLOAT32 input of shape BHWC(1, 3, 1, 2), then checks the output against the
// input values shifted by 0.1. The 1e-6f passed to CompareVectors is presumably
// the allowed per-element error -- confirm against that helper.
- (void)testInputTensorAndScalar {
  // Input and output tensors share the same shape.
  const BHWC shape(1, 3, 1, 2);

  TensorRef<BHWC> input;
  input.type = DataType::FLOAT32;
  input.ref = 0;
  input.shape = shape;

  TensorRef<BHWC> output;
  output.type = DataType::FLOAT32;
  output.ref = 1;
  output.shape = shape;

  // Scalar operand of the addition.
  AddAttributes attr;
  attr.param = 0.1f;

  SingleOpModel model({ToString(OperationType::ADD), std::move(attr)}, {input}, {output});
  XCTAssertTrue(model.PopulateTensor(0, {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}));

  auto status = model.Invoke();
  XCTAssertTrue(status.ok(), @"%s", status.ToString().c_str());

  // Each expected element is the corresponding input element plus 0.1.
  status = CompareVectors({-1.9, 0.3, 0.8, 0.9, 1.2, 2.1}, model.GetOutput(0), 1e-6f);
  XCTAssertTrue(status.ok(), @"%s", status.ToString().c_str());
}
|
||||
|
||||
@end
|
||||
|
Loading…
x
Reference in New Issue
Block a user