TFLComputeTask converted to C++-style ComputeTask.

PiperOrigin-RevId: 346908996
Change-Id: I8df1ee7b777a0f8b40580922b2fe91cb0d2983cd
parent 579ce3a2e1
commit a2d2b5c3d5
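The conversion is mechanical but worth seeing side by side: Objective-C message sends on a heap-allocated TFLComputeTask become member calls on a C++ value type. A minimal before/after sketch, drawn from the inference-context hunks below (RETURN_IF_ERROR, device, node, and precision are from the surrounding source):

    // Before: ObjC object created through ARC, driven by message sends.
    TFLComputeTask* task = [[TFLComputeTask alloc] init];
    RETURN_IF_ERROR([task compileWithDevice:device taskDescriptor:node precision:precision]);

    // After: move-only C++ value, driven by member calls.
    tflite::gpu::metal::ComputeTask task;
    RETURN_IF_ERROR(task.CompileWithDevice(device, node, precision));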
tensorflow/lite/delegates/gpu/metal/compute_task.h
@@ -27,34 +27,77 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/precision.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/metal/common.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/metal_arguments.h"
 #include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"

-@interface TFLComputeTask : NSObject
+namespace tflite {
+namespace gpu {
+namespace metal {

-/// Returns empty string or error if shader can't be compiled.
-- (absl::Status)compileWithDevice:(id<MTLDevice>)device
-                   taskDescriptor:(const tflite::gpu::metal::NodeDescriptor&)desc
-                        precision:(tflite::gpu::CalculationsPrecision)precision;
+class ComputeTask {
+ public:
+  ComputeTask() = default;

-/// Updates parameters for inputs/outputs/intermediate tensors
-- (absl::Status)updateParamsWithDevice:(id<MTLDevice>)device
-    tensorShapes:(const std::map<tflite::gpu::ValueId, tflite::gpu::BHWC>&)
-        tensorShapes;
+  // Move only
+  ComputeTask(ComputeTask&& args) = default;
+  ComputeTask& operator=(ComputeTask&& args) = default;
+  ComputeTask(const ComputeTask&) = delete;
+  ComputeTask& operator=(const ComputeTask&) = delete;

-- (bool)hasInOutIds:(const std::set<::tflite::gpu::ValueId>&)ids;
+  /// Returns empty string or error if shader can't be compiled.
+  absl::Status CompileWithDevice(id<MTLDevice> device,
+                                 const NodeDescriptor& desc,
+                                 CalculationsPrecision precision);

-- (void)encodeWithEncoder:(id<MTLComputeCommandEncoder>)encoder;
+  /// Updates parameters for inputs/outputs/intermediate tensors
+  absl::Status UpdateParamsWithDevice(
+      id<MTLDevice> device, const std::map<ValueId, BHWC>& tensor_shapes);

-- (std::vector<tflite::gpu::ValueId>)getOutputIds;
-- (std::vector<tflite::gpu::ValueId>)getInputIds;
+  bool HasInOutIds(const std::set<ValueId>& ids) const;

-- (void)setSrcTensor:(const tflite::gpu::metal::MetalSpatialTensor&)tensor withIndex:(int)index;
+  void EncodeWithEncoder(id<MTLComputeCommandEncoder> encoder);

-- (void)setDstTensor:(const tflite::gpu::metal::MetalSpatialTensor&)tensor withIndex:(int)index;
+  std::vector<ValueId> GetOutputIds() const;
+  std::vector<ValueId> GetInputIds() const;

-- (void)setDescription:(const std::string&)description;
+  void SetSrcTensor(const MetalSpatialTensor& tensor, int index);

-@end
+  void SetDstTensor(const MetalSpatialTensor& tensor, int index);
+
+  void SetDescription(const std::string& description);
+
+ private:
+  struct InputBuffer {
+    ValueId uid;
+    id<MTLBuffer> metal_handle;
+  };
+
+  struct OutputBuffer {
+    ValueId uid;
+    id<MTLBuffer> metal_handle;
+  };
+
+  struct UniformBuffer {
+    std::vector<uint8_t> data;
+    UniformsFunction data_function;
+  };
+
+  id<MTLComputePipelineState> program_;
+  std::vector<InputBuffer> input_buffers_;
+  std::vector<OutputBuffer> output_buffers_;
+  std::vector<id<MTLBuffer>> immutable_buffers_;
+  std::vector<UniformBuffer> uniform_buffers_;
+  uint3 groups_size_;
+  uint3 groups_count_;
+  DispatchParamsFunction resize_function_;
+  std::string description_;
+  MetalArguments metal_args_;
+};
+
+}  // namespace metal
+}  // namespace gpu
+}  // namespace tflite

 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPUTE_TASK_H_
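Taken together, the new header describes a conventional lifecycle. A sketch of the intended call order, pieced together from the declarations above; the device, node, shape map, tensors, and encoder are assumed to be supplied by the caller, as the inference context further down does:

    tflite::gpu::metal::ComputeTask task;
    // Build the MTLComputePipelineState from the node's shader source.
    RETURN_IF_ERROR(task.CompileWithDevice(device, node, precision));
    task.SetDescription(node.description);
    // Recompute uniform data and threadgroup sizes for the given shapes.
    RETURN_IF_ERROR(task.UpdateParamsWithDevice(device, tensor_shapes));
    // Bind the MTLBuffer handles of the actual tensors.
    task.SetSrcTensor(src_tensor, /*index=*/0);
    task.SetDstTensor(dst_tensor, /*index=*/0);
    // Emit the setBuffer/setBytes calls and the dispatchThreadgroups call.
    task.EncodeWithEncoder(encoder);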
tensorflow/lite/delegates/gpu/metal/compute_task.mm
@@ -27,55 +27,16 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/common.h"

-using ::tflite::gpu::AlignByN;
-using ::tflite::gpu::BHWC;
-using ::tflite::gpu::HalfBits;
-using ::tflite::gpu::metal::ComputeTaskDescriptorPtr;
-using ::tflite::gpu::metal::CreateComputeProgram;
-using ::tflite::gpu::metal::DispatchParamsFunction;
-using ::tflite::gpu::CalculationsPrecision;
-using ::tflite::gpu::metal::UniformsFunction;
-using ::tflite::gpu::uint3;
-using ::tflite::gpu::ValueId;
+namespace tflite {
+namespace gpu {
+namespace metal {

-namespace {
-
-struct InputBuffer {
-  ValueId uid;
-  id<MTLBuffer> metalHandle;
-};
-
-struct OutputBuffer {
-  ValueId uid;
-  id<MTLBuffer> metalHandle;
-};
-
-struct UniformBuffer {
-  std::vector<uint8_t> data;
-  UniformsFunction dataFunction;
-};
-
-}  // namespace
-
-@implementation TFLComputeTask {
-  id<MTLComputePipelineState> _program;
-  std::vector<InputBuffer> _inputBuffers;
-  std::vector<OutputBuffer> _outputBuffers;
-  std::vector<id<MTLBuffer>> _immutableBuffers;
-  std::vector<UniformBuffer> _uniformBuffers;
-  uint3 _groupsSize;
-  uint3 _groupsCount;
-  DispatchParamsFunction _resizeFunction;
-  std::string _description;
-  tflite::gpu::metal::MetalArguments _metal_args;
-}
-
-- (absl::Status)compileWithDevice:(id<MTLDevice>)device
-                   taskDescriptor:(const tflite::gpu::metal::NodeDescriptor&)desc
-                        precision:(CalculationsPrecision)precision; {
+absl::Status ComputeTask::CompileWithDevice(id<MTLDevice> device,
+                                            const NodeDescriptor& desc,
+                                            CalculationsPrecision precision) {
   size_t offset = desc.task->src_tensors_names.size() + desc.task->uniform_buffers.size()
                   + desc.task->immutable_buffers.size() + 1;
-  RETURN_IF_ERROR(_metal_args.Init(device, offset, &desc.task->args, &desc.task->shader_source));
+  RETURN_IF_ERROR(metal_args_.Init(device, offset, &desc.task->args, &desc.task->shader_source));
   NSString* barrier;
   // simdgroup_barrier is supported on macOS 10.13+ and Metal shading language version 2.0
   if (@available(macOS 10.13, iOS 10.0, tvOS 10.0, *)) {
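The offset passed to metal_args_.Init is the count of bind slots the legacy buffers will occupy, so MetalArguments binds its own resources after them. A hedged worked example (the counts are hypothetical; the slot order comes from EncodeWithEncoder further down):

    // Task with 2 source tensors, 1 uniform buffer, 1 immutable buffer:
    //   offset = 2 (src_tensors_names) + 1 (uniform_buffers)
    //          + 1 (immutable_buffers) + 1 (single output tensor) = 5
    // EncodeWithEncoder binds output, inputs, immutables, and uniforms at
    // indices 0..4, then calls metal_args_.Encode(encoder, 5).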
@@ -129,12 +90,12 @@ struct UniformBuffer {
     return absl::InternalError("Unknown shader compilation error");
   }
   for (auto& id : desc.src_tensors_ids) {
-    _inputBuffers.emplace_back(InputBuffer{id, nil});
+    input_buffers_.emplace_back(InputBuffer{id, nil});
   }
   for (auto& uniform : desc.task->uniform_buffers) {
-    _uniformBuffers.emplace_back(UniformBuffer{{}, uniform.data_function});
+    uniform_buffers_.emplace_back(UniformBuffer{{}, uniform.data_function});
   }
-  _outputBuffers.emplace_back(OutputBuffer{desc.dst_tensors_ids[0], nil});
+  output_buffers_.emplace_back(OutputBuffer{desc.dst_tensors_ids[0], nil});
   const bool f32_storage = precision == CalculationsPrecision::F32;
   for (auto& immutable : desc.task->immutable_buffers) {
     int padding = 4 * (f32_storage ? sizeof(float) : sizeof(HalfBits));
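The padding expression sizes one four-channel texel of the storage type. A short reading of the arithmetic (HalfBits is the 16-bit storage type from the removed usings; the immutable buffer length is presumably rounded up to this boundary with AlignByN in the lines elided between this hunk and the next, before the MTLBuffer is created):

    //   F32 precision: padding = 4 * sizeof(float)    = 16 bytes per texel
    //   F16 precision: padding = 4 * sizeof(HalfBits) =  8 bytes per texel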
@@ -143,61 +104,60 @@ struct UniformBuffer {
     id<MTLBuffer> metalBuffer = [device newBufferWithBytes:immutable.data.data()
                                                     length:immutable.data.size()
                                                    options:MTLResourceStorageModeShared];
-    _immutableBuffers.emplace_back(metalBuffer);
+    immutable_buffers_.emplace_back(metalBuffer);
   }
-  _resizeFunction = desc.task->resize_function;
-  _program = program;
+  resize_function_ = desc.task->resize_function;
+  program_ = program;
   return absl::OkStatus();
 }

-- (absl::Status)
-    updateParamsWithDevice:(id<MTLDevice>)device
-               tensorShapes:(const std::map<tflite::gpu::ValueId, tflite::gpu::BHWC>&)tensorShapes {
+absl::Status ComputeTask::UpdateParamsWithDevice(
+    id<MTLDevice> device, const std::map<ValueId, BHWC>& tensor_shapes) {
   std::vector<BHWC> src_shapes;
   std::vector<BHWC> dst_shapes;
-  for (const auto& in_buf : _inputBuffers) {
-    auto it = tensorShapes.find(in_buf.uid);
-    if (it == tensorShapes.end()) {
+  for (const auto& in_buf : input_buffers_) {
+    auto it = tensor_shapes.find(in_buf.uid);
+    if (it == tensor_shapes.end()) {
       return absl::InvalidArgumentError("Missing tensor shape");
     }
    src_shapes.push_back(it->second);
   }
-  for (const auto& out_buf : _outputBuffers) {
-    auto it = tensorShapes.find(out_buf.uid);
-    if (it == tensorShapes.end()) {
+  for (const auto& out_buf : output_buffers_) {
+    auto it = tensor_shapes.find(out_buf.uid);
+    if (it == tensor_shapes.end()) {
       return absl::InvalidArgumentError("Missing tensor shape");
     }
     dst_shapes.push_back(it->second);
   }
-  for (auto& uniform : _uniformBuffers) {
-    uniform.data = uniform.dataFunction(src_shapes, dst_shapes);
+  for (auto& uniform : uniform_buffers_) {
+    uniform.data = uniform.data_function(src_shapes, dst_shapes);
   }

   // Dispatch parameters re-calculation
-  auto workGroups = _resizeFunction(src_shapes, dst_shapes);
-  _groupsSize = workGroups.first;
+  auto workGroups = resize_function_(src_shapes, dst_shapes);
+  groups_size_ = workGroups.first;
   MTLSize threadsPerGroup = [device maxThreadsPerThreadgroup];
-  if (_groupsSize.x > threadsPerGroup.width || _groupsSize.y > threadsPerGroup.height ||
-      _groupsSize.z > threadsPerGroup.depth) {
+  if (groups_size_.x > threadsPerGroup.width || groups_size_.y > threadsPerGroup.height ||
+      groups_size_.z > threadsPerGroup.depth) {
     std::string error("Threads per working group: ");
-    error += std::to_string(_groupsSize.x) + ", " + std::to_string(_groupsSize.y) + ", " +
-             std::to_string(_groupsSize.z);
+    error += std::to_string(groups_size_.x) + ", " + std::to_string(groups_size_.y) + ", " +
+             std::to_string(groups_size_.z);
     error += "is larger than the MTLDevice can support: ";
     error += std::to_string(threadsPerGroup.width) + ", " + std::to_string(threadsPerGroup.height) +
              ", " + std::to_string(threadsPerGroup.depth);
     return absl::InvalidArgumentError(error);
   }
-  _groupsCount = workGroups.second;
+  groups_count_ = workGroups.second;
   return absl::OkStatus();
 }

-- (bool)hasInOutIds:(const std::set<::tflite::gpu::ValueId>&)ids {
-  for (auto& buffer : _inputBuffers) {
+bool ComputeTask::HasInOutIds(const std::set<ValueId>& ids) const {
+  for (auto& buffer : input_buffers_) {
     if (ids.count(buffer.uid)) {
       return true;
     }
   }
-  for (auto& buffer : _outputBuffers) {
+  for (auto& buffer : output_buffers_) {
     if (ids.count(buffer.uid)) {
       return true;
     }
@@ -205,66 +165,66 @@ struct UniformBuffer {
   return false;
 }

-- (void)encodeWithEncoder:(id<MTLComputeCommandEncoder>)encoder {
+void ComputeTask::EncodeWithEncoder(id<MTLComputeCommandEncoder> encoder) {
   // The dispatch call is intended to be skipped.
-  if (_groupsCount.x * _groupsCount.y * _groupsCount.z == 0) {
+  if (groups_count_.x * groups_count_.y * groups_count_.z == 0) {
     return;
   }

-  [encoder setComputePipelineState:_program];
+  [encoder setComputePipelineState:program_];

   int bindIndex = 0;
-  for (const auto& buffer : _outputBuffers) {
-    [encoder setBuffer:buffer.metalHandle offset:0 atIndex:bindIndex];
+  for (const auto& buffer : output_buffers_) {
+    [encoder setBuffer:buffer.metal_handle offset:0 atIndex:bindIndex];
     bindIndex++;
   }
-  for (const auto& buffer : _inputBuffers) {
-    [encoder setBuffer:buffer.metalHandle offset:0 atIndex:bindIndex];
+  for (const auto& buffer : input_buffers_) {
+    [encoder setBuffer:buffer.metal_handle offset:0 atIndex:bindIndex];
     bindIndex++;
   }
-  for (auto& immutable : _immutableBuffers) {
+  for (auto& immutable : immutable_buffers_) {
     [encoder setBuffer:immutable offset:0 atIndex:bindIndex];
     bindIndex++;
   }
-  for (auto& uniform : _uniformBuffers) {
+  for (auto& uniform : uniform_buffers_) {
     [encoder setBytes:uniform.data.data() length:uniform.data.size() atIndex:bindIndex];
     bindIndex++;
   }
-  _metal_args.Encode(encoder, bindIndex);
+  metal_args_.Encode(encoder, bindIndex);

-  MTLSize groupsCount = MTLSizeMake(_groupsCount.x, _groupsCount.y, _groupsCount.z);
-  MTLSize groupsSize = MTLSizeMake(_groupsSize.x, _groupsSize.y, _groupsSize.z);
+  MTLSize groupsCount = MTLSizeMake(groups_count_.x, groups_count_.y, groups_count_.z);
+  MTLSize groupsSize = MTLSizeMake(groups_size_.x, groups_size_.y, groups_size_.z);
   [encoder dispatchThreadgroups:groupsCount threadsPerThreadgroup:groupsSize];
 }

-- (std::vector<tflite::gpu::ValueId>)getOutputIds {
+std::vector<ValueId> ComputeTask::GetOutputIds() const {
   std::vector<tflite::gpu::ValueId> result;
-  for (auto& buffer : _outputBuffers) {
+  for (auto& buffer : output_buffers_) {
     result.push_back(buffer.uid);
   }
   return result;
 }

-- (std::vector<tflite::gpu::ValueId>)getInputIds {
+std::vector<ValueId> ComputeTask::GetInputIds() const {
   std::vector<tflite::gpu::ValueId> result;
-  for (auto& buffer : _inputBuffers) {
+  for (auto& buffer : input_buffers_) {
     result.push_back(buffer.uid);
   }
   return result;
 }

-- (void)setSrcTensor:(const tflite::gpu::metal::MetalSpatialTensor&)tensor
-           withIndex:(int)index; {
-  _inputBuffers[index].metalHandle = tensor.GetBufferHandle();
+void ComputeTask::SetSrcTensor(const MetalSpatialTensor& tensor, int index) {
+  input_buffers_[index].metal_handle = tensor.GetBufferHandle();
 }

-- (void)setDstTensor:(const tflite::gpu::metal::MetalSpatialTensor&)tensor
-           withIndex:(int)index; {
-  _outputBuffers[index].metalHandle = tensor.GetBufferHandle();
+void ComputeTask::SetDstTensor(const MetalSpatialTensor& tensor, int index) {
+  output_buffers_[index].metal_handle = tensor.GetBufferHandle();
 }

-- (void)setDescription:(const std::string&)description {
-  _description = description;
+void ComputeTask::SetDescription(const std::string& description) {
+  description_ = description;
 }

-@end
+}  // namespace metal
+}  // namespace gpu
+}  // namespace tflite
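With the @implementation gone, ComputeTask is a plain move-only value, which is what lets the inference context below switch from std::vector<TFLComputeTask*> to std::vector<ComputeTask>. A minimal sketch of the ownership consequence (container name hypothetical; the fill pattern mirrors the next hunks):

    std::vector<tflite::gpu::metal::ComputeTask> tasks;  // owns tasks by value
    tflite::gpu::metal::ComputeTask task;
    // ... CompileWithDevice / SetDescription as in the hunks below ...
    tasks.emplace_back(std::move(task));  // defaulted move transfers the
                                          // id<MTL...> handles and vectors;
                                          // copying is deleted, so two tasks
                                          // can never share pipeline state.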
tensorflow/lite/delegates/gpu/metal/inference_context.mm
@@ -37,6 +37,7 @@ using ::tflite::gpu::DataType;
 using ::tflite::gpu::HalfBits;
 using ::tflite::gpu::int2;
 using ::tflite::gpu::MemoryStrategy;
+using ::tflite::gpu::metal::ComputeTask;
 using ::tflite::gpu::metal::ComputeTaskDescriptorPtr;
 using ::tflite::gpu::metal::MetalSpatialTensor;
 using ::tflite::gpu::TensorDescriptor;
@@ -60,7 +61,7 @@ void AddUsage(ValueId id, int task_index,
 }  // namespace

 @implementation TFLInferenceContext {
-  std::vector<TFLComputeTask*> _computeTasks;
+  std::vector<ComputeTask> _computeTasks;
   // contains indexes of _computeTasks
   std::vector<int> _taskIdsWithPreallocatedTensors;
   std::vector<ValueId> _inputIds;
@@ -85,17 +86,15 @@ void AddUsage(ValueId id, int task_index,
   _precision = precision;
   // Metal resources are created here.
   for (const auto& node : compiledModel.nodes) {
-    TFLComputeTask* task = [[TFLComputeTask alloc] init];
-    RETURN_IF_ERROR([task compileWithDevice:device
-                             taskDescriptor:node
-                                  precision:_precision]);
-    [task setDescription:node.description];
-    _computeTasks.emplace_back(task);
+    ComputeTask task;
+    RETURN_IF_ERROR(task.CompileWithDevice(device, node, _precision));
+    task.SetDescription(node.description);
+    _computeTasks.emplace_back(std::move(task));
   }
   _tensorShapes = compiledModel.tensor_shapes;
   for (auto& task : _computeTasks) {
     // The same device must be used here as well as on shader compilation stage.
-    RETURN_IF_ERROR([task updateParamsWithDevice:device tensorShapes:_tensorShapes]);
+    RETURN_IF_ERROR(task.UpdateParamsWithDevice(device, _tensorShapes));
   }
   RETURN_IF_ERROR([self allocateTensors:device]);
   return absl::OkStatus();
@@ -111,7 +110,7 @@ void AddUsage(ValueId id, int task_index,
   }
   for (int i = 0; i < _computeTasks.size(); ++i) {
     auto& task = _computeTasks[i];
-    if ([task hasInOutIds:preallocatedIds]) {
+    if (task.HasInOutIds(preallocatedIds)) {
       _taskIdsWithPreallocatedTensors.push_back(i);
     }
   }
@@ -144,15 +143,15 @@ void AddUsage(ValueId id, int task_index,

 - (void)bindTensorsToOperations {
   for (auto& task : _computeTasks) {
-    const auto& src_ids = [task getInputIds];
+    const auto& src_ids = task.GetInputIds();
     for (int i = 0; i < src_ids.size(); ++i) {
       MetalSpatialTensor* tensor = [self getTensor:src_ids[i]];
-      [task setSrcTensor:*tensor withIndex:i];
+      task.SetSrcTensor(*tensor, i);
     }
-    const auto& dst_ids = [task getOutputIds];
+    const auto& dst_ids = task.GetOutputIds();
     for (int i = 0; i < dst_ids.size(); ++i) {
       MetalSpatialTensor* tensor = [self getTensor:dst_ids[i]];
-      [task setDstTensor:*tensor withIndex:i];
+      task.SetDstTensor(*tensor, i);
     }
   }
 }
@@ -164,12 +163,12 @@ void AddUsage(ValueId id, int task_index,
     }
   }
   for (int op_index = 0; op_index < _computeTasks.size(); ++op_index) {
-    for (auto& tensor_id : [_computeTasks[op_index] getInputIds]) {
+    for (auto& tensor_id : _computeTasks[op_index].GetInputIds()) {
       if (_preallocatedTensors.find(tensor_id) == _preallocatedTensors.end()) {
         AddUsage(tensor_id, op_index, usages);
       }
     }
-    for (auto& tensor_id : [_computeTasks[op_index] getOutputIds]) {
+    for (auto& tensor_id : _computeTasks[op_index].GetOutputIds()) {
       if (_preallocatedTensors.find(tensor_id) == _preallocatedTensors.end()) {
         AddUsage(tensor_id, op_index, usages);
       }
@@ -239,8 +238,8 @@ void AddUsage(ValueId id, int task_index,
   descriptor.data_type = f32_storage ? DataType::FLOAT32 : DataType::FLOAT16;
   descriptor.layout = tflite::gpu::Layout::HWC;
   for (auto& task : _computeTasks) {
-    const std::vector<ValueId> input_ids = [task getInputIds];
-    const std::vector<ValueId> output_ids = [task getOutputIds];
+    const std::vector<ValueId> input_ids = task.GetInputIds();
+    const std::vector<ValueId> output_ids = task.GetOutputIds();
     std::vector<ValueId> all_ids = input_ids;
     all_ids.insert(all_ids.end(), output_ids.begin(), output_ids.end());
     for (auto& tensor_id : all_ids) {
@@ -263,7 +262,7 @@ void AddUsage(ValueId id, int task_index,
   [self updatePreallocatedTensors:inputOutputBuffers];
   for (int i = 0; i < _computeTasks.size(); ++i) {
     auto& task = _computeTasks[i];
-    [task encodeWithEncoder:commandEncoder];
+    task.EncodeWithEncoder(commandEncoder);
   }
 }
@@ -274,7 +273,7 @@ void AddUsage(ValueId id, int task_index,
   for (int i = 0; i < _computeTasks.size(); ++i) {
     id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
     auto& task = _computeTasks[i];
-    [task encodeWithEncoder:encoder];
+    task.EncodeWithEncoder(encoder);
     [encoder endEncoding];
   }
 }
@@ -288,7 +287,7 @@ void AddUsage(ValueId id, int task_index,
   for (int i = 0; i < _computeTasks.size(); ++i) {
     id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
     auto& task = _computeTasks[i];
-    [task encodeWithEncoder:encoder];
+    task.EncodeWithEncoder(encoder);
     [encoder endEncoding];
     if (i % flushPeriod == (flushPeriod - 1)) {
       [commandBuffer commit];
@@ -304,18 +303,18 @@ void AddUsage(ValueId id, int task_index,
   }
   for (auto& task_index : _taskIdsWithPreallocatedTensors) {
     auto& task = _computeTasks[task_index];
-    const auto& src_ids = [task getInputIds];
+    const auto& src_ids = task.GetInputIds();
     for (int i = 0; i < src_ids.size(); ++i) {
       const auto& it = _preallocatedTensors.find(src_ids[i]);
       if (it != _preallocatedTensors.end()) {
-        [task setSrcTensor:it->second withIndex:i];
+        task.SetSrcTensor(it->second, i);
       }
     }
-    const auto& dst_ids = [task getOutputIds];
+    const auto& dst_ids = task.GetOutputIds();
     for (int i = 0; i < dst_ids.size(); ++i) {
       const auto& it = _preallocatedTensors.find(dst_ids[i]);
       if (it != _preallocatedTensors.end()) {
-        [task setDstTensor:it->second withIndex:i];
+        task.SetDstTensor(it->second, i);
       }
     }
   }
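One behavioral detail preserved in the flush variant above: the encode loop commits the command buffer periodically rather than once at the end. A hedged worked example of the modulo test (the flushPeriod value is hypothetical):

    // With flushPeriod == 4: i % 4 == 3 at i = 3, 7, 11, ...
    // so the buffer is committed after every fourth task, bounding how much
    // GPU work queues up behind a single command buffer. (The lines that
    // obtain a fresh command buffer afterwards are elided from this diff.)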