TFLComputeTask converted to C++ style ComputeTask.

PiperOrigin-RevId: 346908996
Change-Id: I8df1ee7b777a0f8b40580922b2fe91cb0d2983cd
Raman Sarokin 2020-12-10 18:30:58 -08:00 committed by TensorFlower Gardener
parent 579ce3a2e1
commit a2d2b5c3d5
3 changed files with 141 additions and 139 deletions
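In practice the conversion replaces Objective-C message sends on TFLComputeTask with plain C++ method calls on ComputeTask. Below is a minimal before/after sketch of the call-site change, mirroring the inference_context diff further down; device, node, precision and shapes stand in for the id<MTLDevice>, NodeDescriptor, CalculationsPrecision and std::map<ValueId, BHWC> that the inference context already holds, so this is an illustration rather than a listing from the commit:

// Before: Objective-C object, message sends.
TFLComputeTask* old_task = [[TFLComputeTask alloc] init];
RETURN_IF_ERROR([old_task compileWithDevice:device taskDescriptor:node precision:precision]);
[old_task setDescription:node.description];
RETURN_IF_ERROR([old_task updateParamsWithDevice:device tensorShapes:shapes]);

// After: C++ object with the same responsibilities, usable by value.
tflite::gpu::metal::ComputeTask task;
RETURN_IF_ERROR(task.CompileWithDevice(device, node, precision));
task.SetDescription(node.description);
RETURN_IF_ERROR(task.UpdateParamsWithDevice(device, shapes));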

View File

@@ -27,34 +27,77 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/precision.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/metal/common.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
#include "tensorflow/lite/delegates/gpu/metal/metal_arguments.h"
#include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"
@interface TFLComputeTask : NSObject
namespace tflite {
namespace gpu {
namespace metal {
/// Returns empty string or error if shader can't be compiled.
- (absl::Status)compileWithDevice:(id<MTLDevice>)device
taskDescriptor:(const tflite::gpu::metal::NodeDescriptor&)desc
precision:(tflite::gpu::CalculationsPrecision)precision;
class ComputeTask {
public:
ComputeTask() = default;
/// Updates parameters for inputs/outputs/intermediate tensors
- (absl::Status)updateParamsWithDevice:(id<MTLDevice>)device
tensorShapes:(const std::map<tflite::gpu::ValueId, tflite::gpu::BHWC>&)
tensorShapes;
// Move only
ComputeTask(ComputeTask&& args) = default;
ComputeTask& operator=(ComputeTask&& args) = default;
ComputeTask(const ComputeTask&) = delete;
ComputeTask& operator=(const ComputeTask&) = delete;
- (bool)hasInOutIds:(const std::set<::tflite::gpu::ValueId>&)ids;
/// Returns empty string or error if shader can't be compiled.
absl::Status CompileWithDevice(id<MTLDevice> device,
const NodeDescriptor& desc,
CalculationsPrecision precision);
- (void)encodeWithEncoder:(id<MTLComputeCommandEncoder>)encoder;
/// Updates parameters for inputs/outputs/intermediate tensors
absl::Status UpdateParamsWithDevice(
id<MTLDevice> device, const std::map<ValueId, BHWC>& tensor_shapes);
- (std::vector<tflite::gpu::ValueId>)getOutputIds;
- (std::vector<tflite::gpu::ValueId>)getInputIds;
bool HasInOutIds(const std::set<ValueId>& ids) const;
- (void)setSrcTensor:(const tflite::gpu::metal::MetalSpatialTensor&)tensor withIndex:(int)index;
void EncodeWithEncoder(id<MTLComputeCommandEncoder> encoder);
- (void)setDstTensor:(const tflite::gpu::metal::MetalSpatialTensor&)tensor withIndex:(int)index;
std::vector<ValueId> GetOutputIds() const;
std::vector<ValueId> GetInputIds() const;
- (void)setDescription:(const std::string&)description;
void SetSrcTensor(const MetalSpatialTensor& tensor, int index);
@end
void SetDstTensor(const MetalSpatialTensor& tensor, int index);
void SetDescription(const std::string& description);
private:
struct InputBuffer {
ValueId uid;
id<MTLBuffer> metal_handle;
};
struct OutputBuffer {
ValueId uid;
id<MTLBuffer> metal_handle;
};
struct UniformBuffer {
std::vector<uint8_t> data;
UniformsFunction data_function;
};
id<MTLComputePipelineState> program_;
std::vector<InputBuffer> input_buffers_;
std::vector<OutputBuffer> output_buffers_;
std::vector<id<MTLBuffer>> immutable_buffers_;
std::vector<UniformBuffer> uniform_buffers_;
uint3 groups_size_;
uint3 groups_count_;
DispatchParamsFunction resize_function_;
std::string description_;
MetalArguments metal_args_;
};
} // namespace metal
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPUTE_TASK_H_
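The new class is move-only (copy construction and assignment are deleted, moves defaulted), so a ComputeTask can own Metal handles yet still be stored by value in a container, which is how the inference context below keeps it. A minimal sketch of that usage, assuming the task has already been compiled as shown earlier:

// Sketch, not part of the commit: move-only ComputeTask objects held by value.
std::vector<tflite::gpu::metal::ComputeTask> tasks;
tflite::gpu::metal::ComputeTask task;
// ... CompileWithDevice / SetDescription / UpdateParamsWithDevice as above ...
tasks.emplace_back(std::move(task));  // fine: move constructor is defaulted
// tasks.push_back(task);             // would not compile: copying is deleted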

View File

@@ -27,55 +27,16 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/common.h"
using ::tflite::gpu::AlignByN;
using ::tflite::gpu::BHWC;
using ::tflite::gpu::HalfBits;
using ::tflite::gpu::metal::ComputeTaskDescriptorPtr;
using ::tflite::gpu::metal::CreateComputeProgram;
using ::tflite::gpu::metal::DispatchParamsFunction;
using ::tflite::gpu::CalculationsPrecision;
using ::tflite::gpu::metal::UniformsFunction;
using ::tflite::gpu::uint3;
using ::tflite::gpu::ValueId;
namespace tflite {
namespace gpu {
namespace metal {
namespace {
struct InputBuffer {
ValueId uid;
id<MTLBuffer> metalHandle;
};
struct OutputBuffer {
ValueId uid;
id<MTLBuffer> metalHandle;
};
struct UniformBuffer {
std::vector<uint8_t> data;
UniformsFunction dataFunction;
};
} // namespace
@implementation TFLComputeTask {
id<MTLComputePipelineState> _program;
std::vector<InputBuffer> _inputBuffers;
std::vector<OutputBuffer> _outputBuffers;
std::vector<id<MTLBuffer>> _immutableBuffers;
std::vector<UniformBuffer> _uniformBuffers;
uint3 _groupsSize;
uint3 _groupsCount;
DispatchParamsFunction _resizeFunction;
std::string _description;
tflite::gpu::metal::MetalArguments _metal_args;
}
- (absl::Status)compileWithDevice:(id<MTLDevice>)device
taskDescriptor:(const tflite::gpu::metal::NodeDescriptor&)desc
precision:(CalculationsPrecision)precision; {
absl::Status ComputeTask::CompileWithDevice(id<MTLDevice> device,
const NodeDescriptor& desc,
CalculationsPrecision precision) {
size_t offset = desc.task->src_tensors_names.size() + desc.task->uniform_buffers.size()
+ desc.task->immutable_buffers.size() + 1;
RETURN_IF_ERROR(_metal_args.Init(device, offset, &desc.task->args, &desc.task->shader_source));
RETURN_IF_ERROR(metal_args_.Init(device, offset, &desc.task->args, &desc.task->shader_source));
NSString* barrier;
// simdgroup_barrier is supported on macOS 10.13+ and Metal shading language version 2.0
if (@available(macOS 10.13, iOS 10.0, tvOS 10.0, *)) {
@@ -129,12 +90,12 @@ struct UniformBuffer {
return absl::InternalError("Unknown shader compilation error");
}
for (auto& id : desc.src_tensors_ids) {
_inputBuffers.emplace_back(InputBuffer{id, nil});
input_buffers_.emplace_back(InputBuffer{id, nil});
}
for (auto& uniform : desc.task->uniform_buffers) {
_uniformBuffers.emplace_back(UniformBuffer{{}, uniform.data_function});
uniform_buffers_.emplace_back(UniformBuffer{{}, uniform.data_function});
}
_outputBuffers.emplace_back(OutputBuffer{desc.dst_tensors_ids[0], nil});
output_buffers_.emplace_back(OutputBuffer{desc.dst_tensors_ids[0], nil});
const bool f32_storage = precision == CalculationsPrecision::F32;
for (auto& immutable : desc.task->immutable_buffers) {
int padding = 4 * (f32_storage ? sizeof(float) : sizeof(HalfBits));
@@ -143,61 +104,60 @@ struct UniformBuffer {
id<MTLBuffer> metalBuffer = [device newBufferWithBytes:immutable.data.data()
length:immutable.data.size()
options:MTLResourceStorageModeShared];
_immutableBuffers.emplace_back(metalBuffer);
immutable_buffers_.emplace_back(metalBuffer);
}
_resizeFunction = desc.task->resize_function;
_program = program;
resize_function_ = desc.task->resize_function;
program_ = program;
return absl::OkStatus();
}
- (absl::Status)
updateParamsWithDevice:(id<MTLDevice>)device
tensorShapes:(const std::map<tflite::gpu::ValueId, tflite::gpu::BHWC>&)tensorShapes {
absl::Status ComputeTask::UpdateParamsWithDevice(
id<MTLDevice> device, const std::map<ValueId, BHWC>& tensor_shapes) {
std::vector<BHWC> src_shapes;
std::vector<BHWC> dst_shapes;
for (const auto& in_buf : _inputBuffers) {
auto it = tensorShapes.find(in_buf.uid);
if (it == tensorShapes.end()) {
for (const auto& in_buf : input_buffers_) {
auto it = tensor_shapes.find(in_buf.uid);
if (it == tensor_shapes.end()) {
return absl::InvalidArgumentError("Missing tensor shape");
}
src_shapes.push_back(it->second);
}
for (const auto& out_buf : _outputBuffers) {
auto it = tensorShapes.find(out_buf.uid);
if (it == tensorShapes.end()) {
for (const auto& out_buf : output_buffers_) {
auto it = tensor_shapes.find(out_buf.uid);
if (it == tensor_shapes.end()) {
return absl::InvalidArgumentError("Missing tensor shape");
}
dst_shapes.push_back(it->second);
}
for (auto& uniform : _uniformBuffers) {
uniform.data = uniform.dataFunction(src_shapes, dst_shapes);
for (auto& uniform : uniform_buffers_) {
uniform.data = uniform.data_function(src_shapes, dst_shapes);
}
// Dispatch parameters re-calculation
auto workGroups = _resizeFunction(src_shapes, dst_shapes);
_groupsSize = workGroups.first;
auto workGroups = resize_function_(src_shapes, dst_shapes);
groups_size_ = workGroups.first;
MTLSize threadsPerGroup = [device maxThreadsPerThreadgroup];
if (_groupsSize.x > threadsPerGroup.width || _groupsSize.y > threadsPerGroup.height ||
_groupsSize.z > threadsPerGroup.depth) {
if (groups_size_.x > threadsPerGroup.width || groups_size_.y > threadsPerGroup.height ||
groups_size_.z > threadsPerGroup.depth) {
std::string error("Threads per working group: ");
error += std::to_string(_groupsSize.x) + ", " + std::to_string(_groupsSize.y) + ", " +
std::to_string(_groupsSize.z);
error += std::to_string(groups_size_.x) + ", " + std::to_string(groups_size_.y) + ", " +
std::to_string(groups_size_.z);
error += "is larger than the MTLDevice can support: ";
error += std::to_string(threadsPerGroup.width) + ", " + std::to_string(threadsPerGroup.height) +
", " + std::to_string(threadsPerGroup.depth);
return absl::InvalidArgumentError(error);
}
_groupsCount = workGroups.second;
groups_count_ = workGroups.second;
return absl::OkStatus();
}
- (bool)hasInOutIds:(const std::set<::tflite::gpu::ValueId>&)ids {
for (auto& buffer : _inputBuffers) {
bool ComputeTask::HasInOutIds(const std::set<ValueId>& ids) const {
for (auto& buffer : input_buffers_) {
if (ids.count(buffer.uid)) {
return true;
}
}
for (auto& buffer : _outputBuffers) {
for (auto& buffer : output_buffers_) {
if (ids.count(buffer.uid)) {
return true;
}
@@ -205,66 +165,66 @@ struct UniformBuffer {
return false;
}
- (void)encodeWithEncoder:(id<MTLComputeCommandEncoder>)encoder {
void ComputeTask::EncodeWithEncoder(id<MTLComputeCommandEncoder> encoder) {
// The dispatch call is intended to be skipped.
if (_groupsCount.x * _groupsCount.y * _groupsCount.z == 0) {
if (groups_count_.x * groups_count_.y * groups_count_.z == 0) {
return;
}
[encoder setComputePipelineState:_program];
[encoder setComputePipelineState:program_];
int bindIndex = 0;
for (const auto& buffer : _outputBuffers) {
[encoder setBuffer:buffer.metalHandle offset:0 atIndex:bindIndex];
for (const auto& buffer : output_buffers_) {
[encoder setBuffer:buffer.metal_handle offset:0 atIndex:bindIndex];
bindIndex++;
}
for (const auto& buffer : _inputBuffers) {
[encoder setBuffer:buffer.metalHandle offset:0 atIndex:bindIndex];
for (const auto& buffer : input_buffers_) {
[encoder setBuffer:buffer.metal_handle offset:0 atIndex:bindIndex];
bindIndex++;
}
for (auto& immutable : _immutableBuffers) {
for (auto& immutable : immutable_buffers_) {
[encoder setBuffer:immutable offset:0 atIndex:bindIndex];
bindIndex++;
}
for (auto& uniform : _uniformBuffers) {
for (auto& uniform : uniform_buffers_) {
[encoder setBytes:uniform.data.data() length:uniform.data.size() atIndex:bindIndex];
bindIndex++;
}
_metal_args.Encode(encoder, bindIndex);
metal_args_.Encode(encoder, bindIndex);
MTLSize groupsCount = MTLSizeMake(_groupsCount.x, _groupsCount.y, _groupsCount.z);
MTLSize groupsSize = MTLSizeMake(_groupsSize.x, _groupsSize.y, _groupsSize.z);
MTLSize groupsCount = MTLSizeMake(groups_count_.x, groups_count_.y, groups_count_.z);
MTLSize groupsSize = MTLSizeMake(groups_size_.x, groups_size_.y, groups_size_.z);
[encoder dispatchThreadgroups:groupsCount threadsPerThreadgroup:groupsSize];
}
- (std::vector<tflite::gpu::ValueId>)getOutputIds {
std::vector<ValueId> ComputeTask::GetOutputIds() const {
std::vector<tflite::gpu::ValueId> result;
for (auto& buffer : _outputBuffers) {
for (auto& buffer : output_buffers_) {
result.push_back(buffer.uid);
}
return result;
}
- (std::vector<tflite::gpu::ValueId>)getInputIds {
std::vector<ValueId> ComputeTask::GetInputIds() const {
std::vector<tflite::gpu::ValueId> result;
for (auto& buffer : _inputBuffers) {
for (auto& buffer : input_buffers_) {
result.push_back(buffer.uid);
}
return result;
}
- (void)setSrcTensor:(const tflite::gpu::metal::MetalSpatialTensor&)tensor
withIndex:(int)index; {
_inputBuffers[index].metalHandle = tensor.GetBufferHandle();
void ComputeTask::SetSrcTensor(const MetalSpatialTensor& tensor, int index) {
input_buffers_[index].metal_handle = tensor.GetBufferHandle();
}
- (void)setDstTensor:(const tflite::gpu::metal::MetalSpatialTensor&)tensor
withIndex:(int)index; {
_outputBuffers[index].metalHandle = tensor.GetBufferHandle();
void ComputeTask::SetDstTensor(const MetalSpatialTensor& tensor, int index) {
output_buffers_[index].metal_handle = tensor.GetBufferHandle();
}
- (void)setDescription:(const std::string&)description {
_description = description;
void ComputeTask::SetDescription(const std::string& description) {
description_ = description;
}
@end
} // namespace metal
} // namespace gpu
} // namespace tflite

View File

@@ -37,6 +37,7 @@ using ::tflite::gpu::DataType;
using ::tflite::gpu::HalfBits;
using ::tflite::gpu::int2;
using ::tflite::gpu::MemoryStrategy;
using ::tflite::gpu::metal::ComputeTask;
using ::tflite::gpu::metal::ComputeTaskDescriptorPtr;
using ::tflite::gpu::metal::MetalSpatialTensor;
using ::tflite::gpu::TensorDescriptor;
@@ -60,7 +61,7 @@ void AddUsage(ValueId id, int task_index,
} // namespace
@implementation TFLInferenceContext {
std::vector<TFLComputeTask*> _computeTasks;
std::vector<ComputeTask> _computeTasks;
// contains indexes of _computeTasks
std::vector<int> _taskIdsWithPreallocatedTensors;
std::vector<ValueId> _inputIds;
@@ -85,17 +86,15 @@ void AddUsage(ValueId id, int task_index,
_precision = precision;
// Metal resources are created here.
for (const auto& node : compiledModel.nodes) {
TFLComputeTask* task = [[TFLComputeTask alloc] init];
RETURN_IF_ERROR([task compileWithDevice:device
taskDescriptor:node
precision:_precision]);
[task setDescription:node.description];
_computeTasks.emplace_back(task);
ComputeTask task;
RETURN_IF_ERROR(task.CompileWithDevice(device, node, _precision));
task.SetDescription(node.description);
_computeTasks.emplace_back(std::move(task));
}
_tensorShapes = compiledModel.tensor_shapes;
for (auto& task : _computeTasks) {
// The same device must be used here as well as on shader compilation stage.
RETURN_IF_ERROR([task updateParamsWithDevice:device tensorShapes:_tensorShapes]);
RETURN_IF_ERROR(task.UpdateParamsWithDevice(device, _tensorShapes));
}
RETURN_IF_ERROR([self allocateTensors:device]);
return absl::OkStatus();
@@ -111,7 +110,7 @@ void AddUsage(ValueId id, int task_index,
}
for (int i = 0; i < _computeTasks.size(); ++i) {
auto& task = _computeTasks[i];
if ([task hasInOutIds:preallocatedIds]) {
if (task.HasInOutIds(preallocatedIds)) {
_taskIdsWithPreallocatedTensors.push_back(i);
}
}
@@ -144,15 +143,15 @@ void AddUsage(ValueId id, int task_index,
- (void)bindTensorsToOperations {
for (auto& task : _computeTasks) {
const auto& src_ids = [task getInputIds];
const auto& src_ids = task.GetInputIds();
for (int i = 0; i < src_ids.size(); ++i) {
MetalSpatialTensor* tensor = [self getTensor:src_ids[i]];
[task setSrcTensor:*tensor withIndex:i];
task.SetSrcTensor(*tensor, i);
}
const auto& dst_ids = [task getOutputIds];
const auto& dst_ids = task.GetOutputIds();
for (int i = 0; i < dst_ids.size(); ++i) {
MetalSpatialTensor* tensor = [self getTensor:dst_ids[i]];
[task setDstTensor:*tensor withIndex:i];
task.SetDstTensor(*tensor, i);
}
}
}
@@ -164,12 +163,12 @@ void AddUsage(ValueId id, int task_index,
}
}
for (int op_index = 0; op_index < _computeTasks.size(); ++op_index) {
for (auto& tensor_id : [_computeTasks[op_index] getInputIds]) {
for (auto& tensor_id : _computeTasks[op_index].GetInputIds()) {
if (_preallocatedTensors.find(tensor_id) == _preallocatedTensors.end()) {
AddUsage(tensor_id, op_index, usages);
}
}
for (auto& tensor_id : [_computeTasks[op_index] getOutputIds]) {
for (auto& tensor_id : _computeTasks[op_index].GetOutputIds()) {
if (_preallocatedTensors.find(tensor_id) == _preallocatedTensors.end()) {
AddUsage(tensor_id, op_index, usages);
}
@@ -239,8 +238,8 @@ void AddUsage(ValueId id, int task_index,
descriptor.data_type = f32_storage ? DataType::FLOAT32 : DataType::FLOAT16;
descriptor.layout = tflite::gpu::Layout::HWC;
for (auto& task : _computeTasks) {
const std::vector<ValueId> input_ids = [task getInputIds];
const std::vector<ValueId> output_ids = [task getOutputIds];
const std::vector<ValueId> input_ids = task.GetInputIds();
const std::vector<ValueId> output_ids = task.GetOutputIds();
std::vector<ValueId> all_ids = input_ids;
all_ids.insert(all_ids.end(), output_ids.begin(), output_ids.end());
for (auto& tensor_id : all_ids) {
@@ -263,7 +262,7 @@ void AddUsage(ValueId id, int task_index,
[self updatePreallocatedTensors:inputOutputBuffers];
for (int i = 0; i < _computeTasks.size(); ++i) {
auto& task = _computeTasks[i];
[task encodeWithEncoder:commandEncoder];
task.EncodeWithEncoder(commandEncoder);
}
}
@@ -274,7 +273,7 @@ void AddUsage(ValueId id, int task_index,
for (int i = 0; i < _computeTasks.size(); ++i) {
id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
auto& task = _computeTasks[i];
[task encodeWithEncoder:encoder];
task.EncodeWithEncoder(encoder);
[encoder endEncoding];
}
}
@@ -288,7 +287,7 @@ void AddUsage(ValueId id, int task_index,
for (int i = 0; i < _computeTasks.size(); ++i) {
id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
auto& task = _computeTasks[i];
[task encodeWithEncoder:encoder];
task.EncodeWithEncoder(encoder);
[encoder endEncoding];
if (i % flushPeriod == (flushPeriod - 1)) {
[commandBuffer commit];
@@ -304,18 +303,18 @@ void AddUsage(ValueId id, int task_index,
}
for (auto& task_index : _taskIdsWithPreallocatedTensors) {
auto& task = _computeTasks[task_index];
const auto& src_ids = [task getInputIds];
const auto& src_ids = task.GetInputIds();
for (int i = 0; i < src_ids.size(); ++i) {
const auto& it = _preallocatedTensors.find(src_ids[i]);
if (it != _preallocatedTensors.end()) {
[task setSrcTensor:it->second withIndex:i];
task.SetSrcTensor(it->second, i);
}
}
const auto& dst_ids = [task getOutputIds];
const auto& dst_ids = task.GetOutputIds();
for (int i = 0; i < dst_ids.size(); ++i) {
const auto& it = _preallocatedTensors.find(dst_ids[i]);
if (it != _preallocatedTensors.end()) {
[task setDstTensor:it->second withIndex:i];
task.SetDstTensor(it->second, i);
}
}
}