iOS Metal delegate: the ADD operation now supports broadcasting a constant vector.
PiperOrigin-RevId: 272298006
This commit is contained in:
parent
264073b659
commit
cb98da75f2
tensorflow/lite/delegates/gpu/metal/kernels
@ -86,6 +86,23 @@ std::vector<ComputeTaskDescriptorPtr> Add(int id,
|
||||
desc->output_buffer = {output_id};
|
||||
return {desc};
|
||||
}
|
||||
// Add vector
|
||||
auto broadcast = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
|
||||
if (broadcast) {
|
||||
desc->is_linkable = true;
|
||||
desc->shader_source =
|
||||
R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid,
|
||||
device FLT4* const broadcast) { return value + broadcast[gid.z]; })";
|
||||
desc->input_buffers = {{input_ids[0]}};
|
||||
desc->output_buffer = {output_id};
|
||||
auto values = options.storage_precision == RuntimeOptions::Precision::FP32
|
||||
? VectorToUint8Vector(broadcast->data)
|
||||
: VectorFloatToHalf(broadcast->data);
|
||||
desc->immutable_buffers = {
|
||||
{"device FLT4* const", values},
|
||||
};
|
||||
return {desc};
|
||||
}
|
||||
|
||||
desc->is_linkable = false;
|
||||
desc->shader_source = GetAddTableCode(input_ids.size());
|
||||
|
@ -31,10 +31,12 @@ limitations under the License.
|
||||
using ::tflite::gpu::AddAttributes;
|
||||
using ::tflite::gpu::BHWC;
|
||||
using ::tflite::gpu::DataType;
|
||||
using ::tflite::gpu::Linear;
|
||||
using ::tflite::gpu::OperationType;
|
||||
using ::tflite::gpu::Tensor;
|
||||
using ::tflite::gpu::TensorRef;
|
||||
using ::tflite::gpu::metal::CompareVectors;
|
||||
using ::tflite::gpu::metal::SingleOpModel;
|
||||
using ::tflite::gpu::TensorRef;
|
||||
using ::tflite::gpu::OperationType;
|
||||
|
||||
@interface AddTest : XCTestCase
|
||||
@end
|
||||
@ -88,4 +90,32 @@ using ::tflite::gpu::OperationType;
|
||||
XCTAssertTrue(status.ok(), @"%s", status.ToString().c_str());
|
||||
}
|
||||
|
||||
// Runs ADD on a BHWC(1, 2, 2, 2) input with a two-element constant Linear
// tensor {10, 20} supplied through AddAttributes. The expected output adds 10
// and 20 alternately to consecutive input values.
- (void)testInputTensorWithConstandBroadcast {
  // Constant operand carried by the operation attributes.
  Tensor<Linear, DataType::FLOAT32> constant;
  constant.id = 1;
  constant.shape.v = 2;
  constant.data = {10.0f, 20.0f};

  AddAttributes attributes;
  attributes.param = std::move(constant);

  // Runtime input and output share the same shape and element type.
  TensorRef<BHWC> source;
  source.ref = 0;
  source.type = DataType::FLOAT32;
  source.shape = BHWC(1, 2, 2, 2);

  TensorRef<BHWC> result;
  result.ref = 2;
  result.type = DataType::FLOAT32;
  result.shape = BHWC(1, 2, 2, 2);

  SingleOpModel model({ToString(OperationType::ADD), std::move(attributes)}, {source}, {result});
  XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}));

  auto invokeStatus = model.Invoke();
  XCTAssertTrue(invokeStatus.ok(), @"%s", invokeStatus.ToString().c_str());

  auto compareStatus =
      CompareVectors({11.0, 22.0, 13.0, 24.0, 15.0, 26.0, 17.0, 28.0}, model.GetOutput(0), 1e-6f);
  XCTAssertTrue(compareStatus.ok(), @"%s", compareStatus.ToString().c_str());
}
|
||||
|
||||
@end
|
||||
|
Loading…
Reference in New Issue
Block a user