Metal: 'Add' operation supports scalar.

PiperOrigin-RevId: 258254040
This commit is contained in:
A. Unique TensorFlower 2019-07-15 15:52:32 -07:00 committed by TensorFlower Gardener
parent 7c33772967
commit b4f842384a
4 changed files with 40 additions and 34 deletions

View File

@ -134,7 +134,9 @@ Status Compile(const GraphFloat32& graph, const RuntimeOptions& options,
auto op_type = OperationTypeFromString(node->operation.type);
switch (op_type) {
case OperationType::ADD:
tasks = AddTable(node_id, inputs, outputs[0]);
tasks = Add(node_id, inputs, outputs[0],
absl::any_cast<AddAttributes>(node->operation.attributes),
options);
break;
case OperationType::CONCAT: {
std::vector<BHWC> input_shapes;

View File

@ -66,39 +66,27 @@ std::string GetAddTableCode(int src_count) {
}
} // namespace
std::vector<ComputeTaskDescriptorPtr> Add(int id, ValueId input_id,
std::vector<ComputeTaskDescriptorPtr> Add(int id,
const std::vector<ValueId> input_ids,
ValueId output_id,
const AddAttributes& attr,
const RuntimeOptions& options) {
auto add_buffer =
absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
if (!add_buffer) {
return {};
}
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = true;
desc->shader_source =
R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid,
device FLT4* const add_buf) {
return value + add_buf[gid.z];
})";
desc->input_buffers = {{input_id}};
desc->output_buffer = {output_id};
auto coeffs = options.storage_precision == RuntimeOptions::Precision::FP32
? VectorToUint8Vector(add_buffer->data)
: VectorFloatToHalf(add_buffer->data);
desc->immutable_buffers = {
{"device FLT4* const", coeffs},
};
return {desc};
}
std::vector<ComputeTaskDescriptorPtr> AddTable(int id,
std::vector<ValueId> input_ids,
ValueId output_id) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
// Add scalar
const float* add_value = absl::get_if<float>(&attr.param);
if (add_value) {
desc->is_linkable = true;
desc->shader_source =
R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid) {
return value + )" +
std::to_string(*add_value) + ";}";
desc->input_buffers = {{input_ids[0]}};
desc->output_buffer = {output_id};
return {desc};
}
desc->is_linkable = false;
desc->shader_source = GetAddTableCode(input_ids.size());

View File

@ -28,16 +28,12 @@ namespace gpu {
namespace metal {
// Add with broadcast.
std::vector<ComputeTaskDescriptorPtr> Add(int id, ValueId input_id,
std::vector<ComputeTaskDescriptorPtr> Add(int id,
const std::vector<ValueId> input_ids,
ValueId output_id,
const AddAttributes& attr,
const RuntimeOptions& options);
// Add tensors.
std::vector<ComputeTaskDescriptorPtr> AddTable(int id,
std::vector<ValueId> input_ids,
ValueId output_id);
} // namespace metal
} // namespace gpu
} // namespace tflite

View File

@ -68,4 +68,24 @@ using ::tflite::gpu::OperationType;
XCTAssertTrue(status.ok(), @"%s", status.ToString().c_str());
}
- (void)testInputTensorAndScalar {
AddAttributes attr;
attr.param = 0.1f;
TensorRef<BHWC> input, output;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 3, 1, 2);
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 3, 1, 2);
SingleOpModel model({ToString(OperationType::ADD), std::move(attr)}, {input}, {output});
XCTAssertTrue(model.PopulateTensor(0, {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}));
auto status = model.Invoke();
XCTAssertTrue(status.ok(), @"%s", status.ToString().c_str());
status = CompareVectors({-1.9, 0.3, 0.8, 0.9, 1.2, 2.1}, model.GetOutput(0), 1e-6f);
XCTAssertTrue(status.ok(), @"%s", status.ToString().c_str());
}
@end