iOS Metal delegate: the ADD operation now supports broadcasting a constant vector.

PiperOrigin-RevId: 272298006
This commit is contained in:
A. Unique TensorFlower 2019-10-01 14:42:27 -07:00 committed by TensorFlower Gardener
parent 264073b659
commit cb98da75f2
2 changed files with 49 additions and 2 deletions
tensorflow/lite/delegates/gpu/metal/kernels

View File

@ -86,6 +86,23 @@ std::vector<ComputeTaskDescriptorPtr> Add(int id,
desc->output_buffer = {output_id};
return {desc};
}
// Add vector
auto broadcast = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
if (broadcast) {
desc->is_linkable = true;
desc->shader_source =
R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid,
device FLT4* const broadcast) { return value + broadcast[gid.z]; })";
desc->input_buffers = {{input_ids[0]}};
desc->output_buffer = {output_id};
auto values = options.storage_precision == RuntimeOptions::Precision::FP32
? VectorToUint8Vector(broadcast->data)
: VectorFloatToHalf(broadcast->data);
desc->immutable_buffers = {
{"device FLT4* const", values},
};
return {desc};
}
desc->is_linkable = false;
desc->shader_source = GetAddTableCode(input_ids.size());

View File

@ -31,10 +31,12 @@ limitations under the License.
using ::tflite::gpu::AddAttributes;
using ::tflite::gpu::BHWC;
using ::tflite::gpu::DataType;
using ::tflite::gpu::Linear;
using ::tflite::gpu::OperationType;
using ::tflite::gpu::Tensor;
using ::tflite::gpu::TensorRef;
using ::tflite::gpu::metal::CompareVectors;
using ::tflite::gpu::metal::SingleOpModel;
using ::tflite::gpu::TensorRef;
using ::tflite::gpu::OperationType;
/// XCTestCase subclass named for the ADD operation tests. The interface
/// declares no properties or methods of its own; presumably the test methods
/// live in the matching @implementation (not fully visible here).
@interface AddTest : XCTestCase
@end
@ -88,4 +90,32 @@ using ::tflite::gpu::OperationType;
XCTAssertTrue(status.ok(), @"%s", status.ToString().c_str());
}
- (void)testInputTensorWithConstandBroadcast {
  // Runtime input: a FLOAT32 tensor of shape BHWC(1, 2, 2, 2) bound to ref 0.
  TensorRef<BHWC> input;
  input.ref = 0;
  input.type = DataType::FLOAT32;
  input.shape = BHWC(1, 2, 2, 2);

  // Result tensor: same type and shape as the input, bound to ref 2.
  TensorRef<BHWC> output;
  output.ref = 2;
  output.type = DataType::FLOAT32;
  output.shape = BHWC(1, 2, 2, 2);

  // Constant two-element Linear tensor (id 1) that is handed to the ADD
  // operation through AddAttributes::param.
  Tensor<Linear, DataType::FLOAT32> addend;
  addend.id = 1;
  addend.shape.v = 2;
  const float kAddendValues[] = {10.0f, 20.0f};
  for (float value : kAddendValues) {
    addend.data.push_back(value);
  }
  AddAttributes attr;
  attr.param = std::move(addend);

  SingleOpModel model({ToString(OperationType::ADD), std::move(attr)}, {input}, {output});
  XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}));

  // Run the model and report the status message if the invocation failed.
  auto invokeStatus = model.Invoke();
  XCTAssertTrue(invokeStatus.ok(), @"%s", invokeStatus.ToString().c_str());

  // The expected values are the inputs with 10 and 20 added alternately to
  // consecutive elements, compared with a tolerance of 1e-6.
  auto compareStatus =
      CompareVectors({11.0, 22.0, 13.0, 24.0, 15.0, 26.0, 17.0, 28.0}, model.GetOutput(0), 1e-6f);
  XCTAssertTrue(compareStatus.ok(), @"%s", compareStatus.ToString().c_str());
}
@end