Removed tensor ids from task descriptors.

Moved them to the node descriptor.

PiperOrigin-RevId: 345309110
Change-Id: I47978117abcae5aa2e5f91d7f0cc11c4562e051e
Raman Sarokin 2020-12-02 14:07:37 -08:00 committed by TensorFlower Gardener
parent f5d892c802
commit f0d15b56a0
34 changed files with 385 additions and 426 deletions
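
In short, the graph wiring moves up one level: tensor ids leave the compute-task descriptor and live on the node descriptor, which this commit also moves to sit beside ComputeTaskDescriptor. A condensed before/after sketch (abbreviated from the header hunks below, not the full structs):

// Before: tensor ids were embedded in the task descriptor itself.
struct ComputeTaskDescriptor {
  struct InputBufferDescriptor {
    ValueId id;               // graph tensor id
    std::string declaration;  // e.g. "device FLT4* const input_buffer"
  };
  std::vector<InputBufferDescriptor> input_buffers;
  OutputBufferDescriptor output_buffer;  // single output, with its ValueId
  // ...
};

// After: the task keeps only shader-facing buffer names...
struct ComputeTaskDescriptor {
  std::vector<std::string> src_tensors_names;  // "device FLT4* <name>"
  std::vector<std::string> dst_tensors_names;
  void AddSrcTensor(const std::string& tensor_name);
  void AddDstTensor(const std::string& tensor_name);
  // ...
};

// ...while the ids, the per-node description, and the task pointer live together:
struct NodeDescriptor {
  ComputeTaskDescriptorPtr task;
  int id;  // unique id, used to match graph compilation errors
  std::string description;
  std::vector<ValueId> src_tensors_ids;
  std::vector<ValueId> dst_tensors_ids;
};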


@ -54,51 +54,52 @@ namespace gpu {
namespace metal {
namespace {
std::vector<ComputeTaskDescriptorPtr> SelectDepthWiseConv(
ComputeTaskDescriptorPtr SelectDepthWiseConv(
ValueId input_id, ValueId output_id,
const DepthwiseConvolution2DAttributes& attr,
const metal::RuntimeOptions& options) {
if (CheckDepthWiseConv3x3Stride1x1Support(attr)) {
auto gpu_op = DepthWiseConv3x3Stride1x1(input_id, output_id, attr, options);
return {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
} else if (CheckDepthWiseConv3x3Stride2Support(attr)) {
auto gpu_op = DepthWiseConv3x3Stride2(input_id, output_id, attr, options);
return {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
} else {
auto gpu_op = DepthWiseConvolution(input_id, output_id, attr, options);
return {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
}
}
std::vector<ComputeTaskDescriptorPtr> SelectConvolutionTransposed(
ComputeTaskDescriptorPtr SelectConvolutionTransposed(
ValueId input_id, ValueId output_id,
const ConvolutionTransposedAttributes& attr, const GpuInfo& gpu_info,
const metal::RuntimeOptions& options) {
if (CheckConvolutionTransposed4x4Support(attr)) {
auto gpu_op =
ConvolutionTransposed4x4(input_id, output_id, attr, gpu_info, options);
return {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
} else {
auto gpu_op =
ConvolutionTransposed(input_id, output_id, attr, gpu_info, options);
return {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
}
}
std::vector<ComputeTaskDescriptorPtr> SelectQuantizeAndDequantize(
ComputeTaskDescriptorPtr SelectQuantizeAndDequantize(
ValueId input_id, ValueId output_id,
const QuantizeAndDequantizeAttributes& attr) {
auto gpu_op = QuantizeAndDequantize(input_id, output_id, attr);
return {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
}
std::vector<ComputeTaskDescriptorPtr> SelectPReLU(
const BHWC& src_shape, ValueId input_id, ValueId output_id,
const PReLUAttributes& attr, const metal::RuntimeOptions& options) {
ComputeTaskDescriptorPtr SelectPReLU(const BHWC& src_shape, ValueId input_id,
ValueId output_id,
const PReLUAttributes& attr,
const metal::RuntimeOptions& options) {
auto alpha = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.alpha);
if (alpha) {
auto gpu_op = PReLU(input_id, output_id, attr, options);
return {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
}
auto alpha3d = absl::get_if<Tensor<HWC, DataType::FLOAT32>>(&attr.alpha);
if (!alpha3d) {
@ -109,61 +110,60 @@ std::vector<ComputeTaskDescriptorPtr> SelectPReLU(
return {};
}
auto gpu_op = PReLUFull(input_id, output_id, attr, options);
return {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
}
std::vector<ComputeTaskDescriptorPtr> SelectReshape(
const BHWC& src_shape, ValueId input_id, ValueId output_id,
const ReshapeAttributes& attr) {
ComputeTaskDescriptorPtr SelectReshape(const BHWC& src_shape, ValueId input_id,
ValueId output_id,
const ReshapeAttributes& attr) {
if (src_shape.c % 4 == 0 && attr.new_shape.c % 4 == 0) {
auto gpu_op = Reshapex4(input_id, output_id, attr);
return {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
} else {
auto gpu_op = Reshape(input_id, output_id, attr);
return {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
}
}
std::vector<ComputeTaskDescriptorPtr> SelectSoftmax(const BHWC& src_shape,
ValueId input_id,
ValueId output_id,
const GpuInfo& gpu_info) {
ComputeTaskDescriptorPtr SelectSoftmax(const BHWC& src_shape, ValueId input_id,
ValueId output_id,
const GpuInfo& gpu_info) {
if (src_shape.w == 1 && src_shape.h == 1) {
auto gpu_op = Softmax1x1(input_id, output_id, gpu_info, src_shape.c);
return {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
} else {
auto gpu_op = Softmax(input_id, output_id, src_shape.c);
return {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
}
}
std::vector<ComputeTaskDescriptorPtr> SelectSpaceToDepth(
ComputeTaskDescriptorPtr SelectSpaceToDepth(
ValueId input_id, ValueId output_id, const SpaceToDepthAttributes& attr) {
auto gpu_op = SpaceToDepth(input_id, output_id, attr);
return {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
}
std::vector<ComputeTaskDescriptorPtr> SelectWinograd4x4To36(
ComputeTaskDescriptorPtr SelectWinograd4x4To36(
ValueId input_id, ValueId output_id, const Winograd4x4To36Attributes& attr,
const GpuInfo& gpu_info, const metal::RuntimeOptions& options) {
if (gpu_info.IsApple()) {
auto gpu_op = Winograd4x4To36(input_id, output_id, attr);
return {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
} else {
auto gpu_op = Winograd4x4To36TileX6(input_id, output_id, attr, options);
return {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
}
}
std::vector<ComputeTaskDescriptorPtr> SelectWinograd36To4x4(
ComputeTaskDescriptorPtr SelectWinograd36To4x4(
ValueId input_id, ValueId output_id, const Winograd36To4x4Attributes& attr,
const GpuInfo& gpu_info, const metal::RuntimeOptions& options) {
if (gpu_info.IsApple()) {
auto gpu_op = Winograd36To4x4(input_id, output_id, options, attr);
return {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
} else {
auto gpu_op = Winograd36To4x4Tile4x1(input_id, output_id, options, attr);
return {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
return std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
}
}
@ -190,15 +190,20 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
const std::vector<ValueId>& outputs,
const GpuInfo& gpu_info,
const RuntimeOptions& options,
int* last_node_id, int* last_value_id,
int* last_value_id,
std::map<ValueId, BHWC>* tensor_shapes,
std::vector<ComputeTaskDescriptorPtr>* tasks) {
std::vector<NodeDescriptor>* nodes) {
if (!IsBatchMatchesForAllValues(graph)) {
return absl::InvalidArgumentError(
"Only identical batch dimension is supported");
}
int node_id = static_cast<int>(node->id);
auto op_type = OperationTypeFromString(node->operation.type);
nodes->push_back({});
auto& node_desc = nodes->back();
node_desc.description = node->operation.type + "_" + std::to_string(node->id);
node_desc.src_tensors_ids = inputs;
node_desc.dst_tensors_ids = outputs;
switch (op_type) {
case OperationType::ADD: {
if (inputs.size() == 1) {
@ -207,7 +212,8 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
absl::any_cast<ElementwiseAttributes>(node->operation.attributes);
auto gpu_op = ElementwiseWithOneInputAndConstantArguent(
inputs[0], outputs[0], options, op_type, attr.param);
*tasks = {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
node_desc.task =
std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
} else {
return absl::UnimplementedError(
"Missing attributes for single input op: " +
@ -217,10 +223,12 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
const auto srcs = graph.FindInputs(node_id);
auto gpu_op = ElementwiseWithTwoInputs(inputs, outputs[0],
srcs[1]->tensor.shape, op_type);
*tasks = {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
node_desc.task =
std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
} else { // more than 2 inputs
auto gpu_op = Add(inputs, outputs[0], options);
*tasks = {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
node_desc.task =
std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
}
break;
}
@ -233,7 +241,8 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
Concat(inputs, outputs[0],
absl::any_cast<ConcatAttributes>(node->operation.attributes),
input_shapes);
*tasks = {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
node_desc.task =
std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
break;
}
case OperationType::CONVOLUTION_2D: {
@ -253,31 +262,39 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
Winograd4x4To36Attributes wino_up_attr;
wino_up_attr.padding = attr.padding;
(*last_node_id) += 1;
int value_id = *last_value_id + 1;
(*tensor_shapes)[value_id] = shape_0;
(*tensor_shapes)[value_id + 1] = shape_1;
*tasks = SelectWinograd4x4To36(inputs[0], value_id, wino_up_attr,
gpu_info, options);
nodes->resize(3);
(*nodes)[0].description = "winograd_up_" + std::to_string(node->id);
(*nodes)[1].description =
node->operation.type + std::to_string(node->id);
(*nodes)[2].description = "winograd_down_" + std::to_string(node->id);
(*nodes)[0].task = SelectWinograd4x4To36(
inputs[0], value_id, wino_up_attr, gpu_info, options);
(*nodes)[0].src_tensors_ids = {inputs[0]};
(*nodes)[0].dst_tensors_ids = {static_cast<unsigned int>(value_id)};
(*last_node_id) += 1;
auto gpu_op = ConvolutionWino4x4To6x6(value_id, value_id + 1, shape_1,
attr, gpu_info, options);
tasks->push_back(
std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op)));
(*nodes)[1].task =
std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
(*nodes)[1].src_tensors_ids = {static_cast<unsigned int>(value_id)};
(*nodes)[1].dst_tensors_ids = {static_cast<unsigned int>(value_id + 1)};
Winograd36To4x4Attributes wino_down_attr;
wino_down_attr.output_shape = dst_shape;
wino_down_attr.biases = attr.bias;
(*last_node_id) += 1;
auto t2 = SelectWinograd36To4x4(value_id + 1, outputs[0],
wino_down_attr, gpu_info, options);
tasks->insert(tasks->end(), t2.begin(), t2.end());
(*nodes)[2].task = SelectWinograd36To4x4(
value_id + 1, outputs[0], wino_down_attr, gpu_info, options);
(*nodes)[2].src_tensors_ids = {static_cast<unsigned int>(value_id + 1)};
(*nodes)[2].dst_tensors_ids = {outputs[0]};
(*last_value_id) += 2;
} else {
auto gpu_op = ConvolutionGeneric(inputs[0], outputs[0], dst_shape, attr,
gpu_info, options);
*tasks = {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
node_desc.task =
std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
}
break;
}
@ -287,7 +304,7 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
"Convolution Transposed does not support more than 1 runtime "
"tensor");
}
*tasks = SelectConvolutionTransposed(
node_desc.task = SelectConvolutionTransposed(
inputs[0], outputs[0],
absl::any_cast<ConvolutionTransposedAttributes>(
node->operation.attributes),
@ -299,7 +316,7 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
"DepthWise Convolution does not support more than 1 runtime "
"tensor");
}
*tasks =
node_desc.task =
SelectDepthWiseConv(inputs[0], outputs[0],
absl::any_cast<DepthwiseConvolution2DAttributes>(
node->operation.attributes),
@ -310,14 +327,16 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
inputs[0], outputs[0],
absl::any_cast<FullyConnectedAttributes>(node->operation.attributes),
gpu_info, options);
*tasks = {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
node_desc.task =
std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
break;
}
case OperationType::MAX_UNPOOLING_2D: {
auto gpu_op = MaxUnpooling(
inputs[0], inputs[1], outputs[0],
absl::any_cast<MaxUnpooling2DAttributes>(node->operation.attributes));
*tasks = {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
node_desc.task =
std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
break;
}
case OperationType::MEAN: {
@ -326,7 +345,8 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
return absl::UnimplementedError("Mean supports HW axis only in Metal");
}
auto gpu_op = Mean(inputs[0], outputs[0], attr);
*tasks = {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
node_desc.task =
std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
break;
}
case OperationType::MUL:
@ -336,7 +356,8 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
absl::any_cast<ElementwiseAttributes>(node->operation.attributes);
auto gpu_op = ElementwiseWithOneInputAndConstantArguent(
inputs[0], outputs[0], options, op_type, attr.param);
*tasks = {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
node_desc.task =
std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
} else {
return absl::UnimplementedError(
"Missing attributes for single input op: " +
@ -346,7 +367,8 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
const auto srcs = graph.FindInputs(node_id);
auto gpu_op = ElementwiseWithTwoInputs(inputs, outputs[0],
srcs[1]->tensor.shape, op_type);
*tasks = {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
node_desc.task =
std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
}
break;
case OperationType::PAD: {
@ -355,24 +377,32 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
return absl::UnimplementedError("Padding for BATCH is not supported.");
}
auto gpu_op = Padding(inputs[0], outputs[0], attr);
*tasks = {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
node_desc.task =
std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
break;
}
case OperationType::POOLING_2D: {
auto attr =
absl::any_cast<Pooling2DAttributes>(node->operation.attributes);
auto gpu_op = Pooling(inputs[0], outputs[0], attr, false);
*tasks = {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
node_desc.task =
std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
node_desc.dst_tensors_ids = {outputs[0]};
if (attr.type == PoolingType::MAX && attr.output_indices) {
auto gpu_ind_op = Pooling(inputs[0], outputs[1], attr, true);
tasks->push_back(
std::make_shared<ComputeTaskDescriptor>(std::move(gpu_ind_op)));
nodes->push_back({});
nodes->back().description =
node->operation.type + "_indices_" + std::to_string(node->id);
nodes->back().task =
std::make_shared<ComputeTaskDescriptor>(std::move(gpu_ind_op));
nodes->back().src_tensors_ids = {inputs[0]};
nodes->back().dst_tensors_ids = {outputs[1]};
}
break;
}
case OperationType::PRELU: {
const auto src_shape = graph.FindInputs(node_id)[0]->tensor.shape;
*tasks = SelectPReLU(
node_desc.task = SelectPReLU(
src_shape, inputs[0], outputs[0],
absl::any_cast<PReLUAttributes>(node->operation.attributes), options);
break;
@ -381,18 +411,19 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
auto gpu_op =
ReLU(inputs[0], outputs[0],
absl::any_cast<ReLUAttributes>(node->operation.attributes));
*tasks = {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
node_desc.task =
std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
break;
}
case OperationType::QUANTIZE_AND_DEQUANTIZE:
*tasks = SelectQuantizeAndDequantize(
node_desc.task = SelectQuantizeAndDequantize(
inputs[0], outputs[0],
absl::any_cast<QuantizeAndDequantizeAttributes>(
node->operation.attributes));
break;
case OperationType::RESHAPE: {
const auto src_shape = graph.FindInputs(node_id)[0]->tensor.shape;
*tasks = SelectReshape(
node_desc.task = SelectReshape(
src_shape, inputs[0], outputs[0],
absl::any_cast<ReshapeAttributes>(node->operation.attributes));
break;
@ -401,14 +432,16 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
auto gpu_op = Resize(
inputs[0], outputs[0],
absl::any_cast<Resize2DAttributes>(node->operation.attributes));
*tasks = {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
node_desc.task =
std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
break;
}
case OperationType::SLICE: {
auto gpu_op =
Slice(inputs[0], outputs[0],
absl::any_cast<SliceAttributes>(node->operation.attributes));
*tasks = {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
node_desc.task =
std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
break;
}
case OperationType::SOFTMAX: {
@ -418,11 +451,12 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
"Softmax supports only CHANNELS dimension");
}
const auto src_shape = graph.FindInputs(node_id)[0]->tensor.shape;
*tasks = SelectSoftmax(src_shape, inputs[0], outputs[0], gpu_info);
node_desc.task =
SelectSoftmax(src_shape, inputs[0], outputs[0], gpu_info);
break;
}
case OperationType::SPACE_TO_DEPTH:
*tasks = SelectSpaceToDepth(
node_desc.task = SelectSpaceToDepth(
inputs[0], outputs[0],
absl::any_cast<SpaceToDepthAttributes>(node->operation.attributes));
break;
@ -441,7 +475,8 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
case OperationType::SQUARE:
case OperationType::TANH: {
auto gpu_op = ElementwiseWithOneInput(inputs[0], outputs[0], op_type);
*tasks = {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
node_desc.task =
std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
break;
}
case OperationType::DIV:
@ -456,7 +491,8 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
absl::any_cast<ElementwiseAttributes>(node->operation.attributes);
auto gpu_op = ElementwiseWithOneInputAndConstantArguent(
inputs[0], outputs[0], options, op_type, attr.param);
*tasks = {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
node_desc.task =
std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
} else {
return absl::UnimplementedError(
"Missing attributes for single input op: " +
@ -466,7 +502,8 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
const auto srcs = graph.FindInputs(node_id);
auto gpu_op = ElementwiseWithTwoInputs(inputs, outputs[0],
srcs[1]->tensor.shape, op_type);
*tasks = {std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op))};
node_desc.task =
std::make_shared<ComputeTaskDescriptor>(std::move(gpu_op));
}
} break;
case OperationType::BATCH_NORMALIZATION:
@ -501,10 +538,6 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
absl::Status Compile(const GraphFloat32& graph, const GpuInfo& gpu_info,
const RuntimeOptions& options,
CompiledModel* compiled_model) {
int last_node_id = 0;
for (const auto& node : graph.nodes()) {
last_node_id = std::max(last_node_id, static_cast<int>(node->id));
}
int last_value_id = 0;
for (const auto& value : graph.values()) {
compiled_model->tensor_shapes[value->id] = value->tensor.shape;
@ -520,13 +553,14 @@ absl::Status Compile(const GraphFloat32& graph, const GpuInfo& gpu_info,
for (auto& output : graph.FindOutputs(node->id)) {
outputs.push_back(static_cast<ValueId>(output->id));
}
std::vector<ComputeTaskDescriptorPtr> tasks;
std::vector<NodeDescriptor> node_descs;
std::vector<ComputeTaskDescriptorPtr> custom_tasks;
auto custom_status =
RegisterCustomOps(graph, node, inputs, outputs, options, &tasks);
RegisterCustomOps(graph, node, inputs, outputs, options, &custom_tasks);
if (!custom_status.ok()) {
auto primary_status = RegisterPrimaryOps(
graph, node, inputs, outputs, gpu_info, options, &last_node_id,
&last_value_id, &compiled_model->tensor_shapes, &tasks);
graph, node, inputs, outputs, gpu_info, options, &last_value_id,
&compiled_model->tensor_shapes, &node_descs);
if (!primary_status.ok()) {
return absl::UnimplementedError(
absl::Substitute("Unsupported op type: $0; custom registry error: "
@ -534,12 +568,18 @@ absl::Status Compile(const GraphFloat32& graph, const GpuInfo& gpu_info,
node->operation.type, custom_status.message(),
primary_status.message()));
}
} else {
for (auto& custom_task : custom_tasks) {
NodeDescriptor node_desc;
node_desc.task = custom_task;
node_desc.description =
node->operation.type + "_" + std::to_string(node->id);
node_desc.src_tensors_ids = inputs;
node_desc.dst_tensors_ids = outputs;
node_descs.push_back(node_desc);
}
}
for (auto& task : tasks) {
NodeDescriptor node_desc;
node_desc.task = task;
node_desc.description =
node->operation.type + "_" + std::to_string(node->id);
for (auto& node_desc : node_descs) {
node_desc.id = node_linear_id++;
compiled_model->nodes.push_back(node_desc);
}
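
For reference, the common single-task case in RegisterPrimaryOps above now follows this shape (a condensed sketch of the pattern, not a verbatim excerpt; the real code pushes the node first and fills it through a reference, and error handling is omitted):

NodeDescriptor node_desc;
node_desc.description = node->operation.type + "_" + std::to_string(node->id);
node_desc.src_tensors_ids = inputs;   // ValueIds that used to sit on the task
node_desc.dst_tensors_ids = outputs;
node_desc.task = SelectSoftmax(src_shape, inputs[0], outputs[0], gpu_info);
nodes->push_back(node_desc);

Multi-task operations (the Winograd convolution path, max pooling with indices) instead emit one NodeDescriptor per generated task and wire the intermediate ValueIds explicitly, as the hunks above show.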


@ -89,38 +89,20 @@ bool Contains(const std::vector<ValueId>& wide,
return true;
}
// Checks if all elements of the narrow vector exist in the wide vector. Vectors
// are expected to be unsorted.
bool Contains(
const std::vector<ValueId>& wide,
const std::vector<ComputeTaskDescriptor::InputBufferDescriptor>& buffers) {
if (buffers.empty() || buffers.size() > wide.size()) {
return false;
}
std::set<ValueId> wide_sorted(wide.begin(), wide.end());
for (const auto& buffer : buffers) {
if (!std::binary_search(wide_sorted.begin(), wide_sorted.end(),
buffer.id)) {
return false;
}
}
return true;
}
uint32_t BufferUseCount(ValueId id,
const std::list<NodeDescriptor>& descriptors,
std::list<FusionSequence>* chains) {
uint32_t use_count = 0;
// Buffer may be read by both processed and not processed operations.
for (auto& desc : descriptors) {
if (Contains(desc.task->input_buffers, id)) {
if (Contains(desc.src_tensors_ids, id)) {
use_count++;
}
}
for (auto& chain : *chains) {
for (auto& desc : chain) {
if (Contains(desc.task->input_buffers, id)) {
if (Contains(desc.src_tensors_ids, id)) {
use_count++;
}
}
@ -137,8 +119,8 @@ bool CanFuseOperations(const NodeDescriptor& first,
const std::list<NodeDescriptor>& descriptors,
std::list<FusionSequence>* chains) {
return second.task->is_linkable &&
!Contains(output_ids, first.task->output_buffer.id) &&
BufferUseCount(first.task->output_buffer.id, descriptors, chains) == 1;
!Contains(output_ids, first.dst_tensors_ids[0]) &&
BufferUseCount(first.dst_tensors_ids[0], descriptors, chains) == 1;
}
// Takes an unsorted list of task descriptors, builds a list of chains. Each
@ -153,9 +135,10 @@ void BuildFusableChains(const std::vector<ValueId>& input_ids,
for (auto input_id : input_ids) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->is_linkable = true;
desc->output_buffer = {input_id};
desc->AddDstTensor("");
NodeDescriptor node;
node.task = desc;
node.dst_tensors_ids = {input_id};
chains->push_back({node});
}
@ -172,22 +155,22 @@ void BuildFusableChains(const std::vector<ValueId>& input_ids,
std::vector<ValueId> ready_buffer_ids;
ready_buffer_ids.reserve(chains->size());
for (const auto& chain : *chains) {
ready_buffer_ids.push_back(chain.back().task->output_buffer.id);
ready_buffer_ids.push_back(chain.back().dst_tensors_ids[0]);
}
// Check if all inputs of this operation are ready.
if (Contains(ready_buffer_ids, task_descriptor.task->input_buffers)) {
if (Contains(ready_buffer_ids, task_descriptor.src_tensors_ids)) {
// Now find a chain to fuse with.
bool fused = false;
for (auto& chain : *chains) {
// We can fuse only single output for now.
bool can_link = false;
if (task_descriptor.task->is_associative_op) {
can_link = Contains(task_descriptor.task->input_buffers,
chain.back().task->output_buffer.id);
can_link = Contains(task_descriptor.src_tensors_ids,
chain.back().dst_tensors_ids[0]);
} else {
can_link = task_descriptor.task->input_buffers[0].id ==
chain.back().task->output_buffer.id;
can_link = task_descriptor.src_tensors_ids[0] ==
chain.back().dst_tensors_ids[0];
}
if (can_link && CanFuseOperations(chain.back(), task_descriptor,
output_ids, *descriptors, chains)) {
@ -229,7 +212,7 @@ std::list<FusionSequence> SortChains(
ready_buffer_ids.insert(ready_buffer_ids.begin(), graph_input_ids.begin(),
graph_input_ids.end());
for (auto& chain : sorted_chains) {
ready_buffer_ids.push_back(chain.back().task->output_buffer.id);
ready_buffer_ids.push_back(chain.back().dst_tensors_ids[0]);
}
for (auto it = chains->begin(); it != chains->end();) {
@ -240,15 +223,15 @@ std::list<FusionSequence> SortChains(
std::vector<ValueId> elements_output_buffer_ids;
elements_output_buffer_ids.reserve(chain.size());
for (const auto& element : chain) {
elements_output_buffer_ids.push_back(element.task->output_buffer.id);
elements_output_buffer_ids.push_back(element.dst_tensors_ids[0]);
}
// Collect all inputs also for linked operations.
std::vector<ValueId> elements_input_buffer_ids;
for (const auto& element : chain) {
for (const auto& buffer : element.task->input_buffers) {
if (!Contains(elements_output_buffer_ids, buffer.id)) {
elements_input_buffer_ids.push_back(buffer.id);
for (const auto& id : element.src_tensors_ids) {
if (!Contains(elements_output_buffer_ids, id)) {
elements_input_buffer_ids.push_back(id);
}
}
}
@ -275,9 +258,9 @@ std::vector<ValueId> GetUsedInputBufferIds(
output_and_intermediate_ids.reserve(sorted_chains.size());
std::set<ValueId> input_and_intermediate_ids;
for (auto it = sorted_chains.begin(); it != sorted_chains.end(); ++it) {
output_and_intermediate_ids.push_back(it->back().task->output_buffer.id);
for (const auto& buffer : it->front().task->input_buffers) {
input_and_intermediate_ids.insert(buffer.id);
output_and_intermediate_ids.push_back(it->back().dst_tensors_ids[0]);
for (const auto& id : it->front().src_tensors_ids) {
input_and_intermediate_ids.insert(id);
}
}
std::vector<ValueId> input_ids;
@ -299,7 +282,7 @@ std::vector<ValueId> GetMissingOutputBufferIds(
std::vector<ValueId> output_and_intermediate_ids;
output_and_intermediate_ids.reserve(sorted_chains.size());
for (auto it = sorted_chains.begin(); it != sorted_chains.end(); ++it) {
output_and_intermediate_ids.push_back(it->back().task->output_buffer.id);
output_and_intermediate_ids.push_back(it->back().dst_tensors_ids[0]);
}
std::vector<ValueId> missing_output_ids;
for (ValueId id : output_ids) {
@ -323,19 +306,19 @@ std::vector<ValueId> DeductOutputBufferIds(
if (it1 != it2) {
std::vector<ValueId> input_ids;
for (const auto& element : *it2) {
for (const auto& buffer : element.task->input_buffers) {
input_ids.push_back(buffer.id);
for (const auto& id : element.src_tensors_ids) {
input_ids.push_back(id);
}
}
if (Contains(input_ids, it1->back().task->output_buffer.id)) {
if (Contains(input_ids, it1->back().dst_tensors_ids[0])) {
found_as_input = true;
break;
}
}
}
if (!found_as_input) {
if (!Contains(output_ids, it1->back().task->output_buffer.id)) {
extra_output_ids.push_back(it1->back().task->output_buffer.id);
if (!Contains(output_ids, it1->back().dst_tensors_ids[0])) {
extra_output_ids.push_back(it1->back().dst_tensors_ids[0]);
}
}
}
@ -350,7 +333,7 @@ std::vector<int> DeleteUnusedTasks(const std::vector<ValueId>& output_ids,
std::vector<int> unused_operations;
for (auto it1 = chains->rbegin(); it1 != chains->rend();) {
// Don't delete if output is requested.
if (Contains(output_ids, it1->back().task->output_buffer.id)) {
if (Contains(output_ids, it1->back().dst_tensors_ids[0])) {
++it1;
continue;
}
@ -360,11 +343,11 @@ std::vector<int> DeleteUnusedTasks(const std::vector<ValueId>& output_ids,
for (auto it2 = chains->rbegin(); it2 != chains->rend(); ++it2) {
std::vector<ValueId> input_ids;
for (const auto& element : *it2) {
for (const auto& buffer : element.task->input_buffers) {
input_ids.push_back(buffer.id);
for (const auto& id : element.src_tensors_ids) {
input_ids.push_back(id);
}
}
if (Contains(input_ids, it1->back().task->output_buffer.id)) {
if (Contains(input_ids, it1->back().dst_tensors_ids[0])) {
output_used = true;
break;
}
@ -386,7 +369,7 @@ void RemoveInputProxies(std::list<FusionSequence>* chains) {
for (auto it = chains->begin(); it != chains->end();) {
auto& chain = *it;
// Remove input proxy-operations.
if (chain.front().task->input_buffers.empty()) {
if (chain.front().src_tensors_ids.empty()) {
chain.erase(chain.begin());
}
if (chain.empty()) {
@ -420,11 +403,8 @@ NodeDescriptor NonLinkableStub(int operation_id, ValueId input_id,
}
)";
desc->input_buffers = {
{input_id, "device FLT4* const input_buffer"},
};
desc->output_buffer = {output_id, "device FLT4* output_buffer"};
desc->AddSrcTensor("input_buffer");
desc->AddDstTensor("output_buffer");
desc->uniform_buffers = {
{"constant int2& size",
@ -447,18 +427,20 @@ NodeDescriptor NonLinkableStub(int operation_id, ValueId input_id,
NodeDescriptor node_desc;
node_desc.task = desc;
node_desc.id = operation_id;
node_desc.src_tensors_ids = {input_id};
node_desc.dst_tensors_ids = {output_id};
return node_desc;
}
NodeDescriptor FuseChain(const FusionSequence& chain) {
NodeDescriptor node_desc;
auto fused_descriptor = std::make_shared<ComputeTaskDescriptor>();
FusionSequence sequence;
if (chain.front().task->is_linkable) {
// The first task is linkable so it contains only linkable code. Insert
// unlinkable meta-task with remaining shader code.
sequence.push_back(
NonLinkableStub(-1, chain.front().task->input_buffers[0].id,
chain.front().task->input_buffers[0].id));
sequence.push_back(NonLinkableStub(-1, chain.front().src_tensors_ids[0],
chain.front().src_tensors_ids[0]));
}
sequence.insert(sequence.end(), chain.begin(), chain.end());
@ -469,12 +451,12 @@ NodeDescriptor FuseChain(const FusionSequence& chain) {
bool invalid_id = true;
ValueId fused_id;
for (const auto& desc : sequence) {
for (const auto& buffer : desc.task->input_buffers) {
if (invalid_id || buffer.id != fused_id) {
for (const auto& id : desc.src_tensors_ids) {
if (invalid_id || id != fused_id) {
num_inputs++;
}
}
fused_id = desc.task->output_buffer.id;
fused_id = desc.dst_tensors_ids[0];
invalid_id = false;
num_immutables += desc.task->immutable_buffers.size();
}
@ -496,24 +478,25 @@ NodeDescriptor FuseChain(const FusionSequence& chain) {
} else {
// Declare output buffer only for the first unlinkable task.
buffer_declarations +=
desc.task->output_buffer.declaration + "[[buffer(0)]],\n";
desc.task->dst_tensors_names[0] + "[[buffer(0)]],\n";
output_index++;
}
std::string call_arguments;
for (const auto& buffer : desc.task->input_buffers) {
if (invalid_id || buffer.id != fused_id) {
for (int i = 0; i < desc.task->src_tensors_names.size(); ++i) {
if (invalid_id || desc.src_tensors_ids[i] != fused_id) {
std::string index = std::to_string(input_index);
std::string name = (desc.task->is_linkable ? (" buffer" + index) : "");
buffer_declarations +=
buffer.declaration + name + "[[buffer(" + index + ")]],\n";
buffer_declarations += desc.task->src_tensors_names[i] + name +
"[[buffer(" + index + ")]],\n";
call_arguments += ", buffer" + index;
input_index++;
fused_descriptor->input_buffers.push_back({buffer.id, ""});
fused_descriptor->AddSrcTensor("");
node_desc.src_tensors_ids.push_back(desc.src_tensors_ids[i]);
}
}
// We have an output id that is the input for the next task.
fused_id = desc.task->output_buffer.id;
fused_id = desc.dst_tensors_ids[0];
invalid_id = false;
for (const auto& buffer : desc.task->immutable_buffers) {
@ -549,9 +532,9 @@ NodeDescriptor FuseChain(const FusionSequence& chain) {
fused_descriptor->shader_source =
absl::Substitute(non_linkable.task->shader_source, function_code + "$0",
buffer_declarations + "$1", call_code);
fused_descriptor->output_buffer = {fused_id, ""};
fused_descriptor->AddDstTensor("");
fused_descriptor->resize_function = non_linkable.task->resize_function;
NodeDescriptor node_desc;
node_desc.dst_tensors_ids = {fused_id};
node_desc.task = fused_descriptor;
// The id of fused descriptor is the id of the first descriptor in the list.
node_desc.id = chain.front().id;
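
Because a chain's endpoints are now plain ValueId vectors on NodeDescriptor, the fusion pass can drop the Contains() overload over InputBufferDescriptor and treat graph inputs as lightweight proxy nodes. A sketch of the proxy seeding in BuildFusableChains, with comments added:

for (auto input_id : input_ids) {
  auto desc = std::make_shared<ComputeTaskDescriptor>();
  desc->is_linkable = true;
  desc->AddDstTensor("");  // placeholder shader name; no id inside the task
  NodeDescriptor node;
  node.task = desc;
  node.dst_tensors_ids = {input_id};  // the id now lives on the node
  chains->push_back({node});  // a bodiless chain head real ops can fuse onto
}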


@ -27,13 +27,6 @@ namespace tflite {
namespace gpu {
namespace metal {
struct NodeDescriptor {
ComputeTaskDescriptorPtr task;
// Unique ID to match the graph compilation errors.
int id;
std::string description;
};
struct CompiledModel {
std::vector<NodeDescriptor> nodes;
std::map<ValueId, BHWC> tensor_shapes;


@ -37,10 +37,12 @@ static std::vector<tflite::gpu::metal::NodeDescriptor> MulLinkable(ValueId input
desc->shader_source = R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid) {
return value * 1.1f;
})";
desc->input_buffers = {{input_id}};
desc->output_buffer = {output_id};
desc->AddSrcTensor("");
desc->AddDstTensor("");
tflite::gpu::metal::NodeDescriptor node_desc;
node_desc.task = desc;
node_desc.src_tensors_ids = {input_id};
node_desc.dst_tensors_ids = {output_id};
return {node_desc};
}
@ -66,11 +68,8 @@ static std::vector<tflite::gpu::metal::NodeDescriptor> Add(
}
)";
desc->input_buffers = {
{input_id, "device FLT4* const input_buffer"},
};
desc->output_buffer = {output_id, "device FLT4* output_buffer"};
desc->AddSrcTensor("input_buffer");
desc->AddDstTensor("output_buffer");
desc->uniform_buffers = {
{"constant int2& size",
@ -95,6 +94,8 @@ static std::vector<tflite::gpu::metal::NodeDescriptor> Add(
tflite::gpu::metal::NodeDescriptor node_desc;
node_desc.task = desc;
node_desc.src_tensors_ids = {input_id};
node_desc.dst_tensors_ids = {output_id};
return {node_desc};
}
@ -120,12 +121,9 @@ static std::vector<tflite::gpu::metal::NodeDescriptor> Add2(
}
)";
desc->input_buffers = {
{input_id1, "device FLT4* const input_buffer1"},
{input_id2, "device FLT4* const input_buffer2"},
};
desc->output_buffer = {output_id, "device FLT4* output_buffer"};
desc->AddSrcTensor("input_buffer1");
desc->AddSrcTensor("input_buffer2");
desc->AddDstTensor("output_buffer");
desc->uniform_buffers = {
{"constant int2& size",
@ -150,6 +148,8 @@ static std::vector<tflite::gpu::metal::NodeDescriptor> Add2(
tflite::gpu::metal::NodeDescriptor node_desc;
node_desc.task = desc;
node_desc.src_tensors_ids = {input_id1, input_id2};
node_desc.dst_tensors_ids = {output_id};
return {node_desc};
}
@ -163,13 +163,14 @@ static std::vector<tflite::gpu::metal::NodeDescriptor> Add2Linkable(ValueId inpu
FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, device FLT4* const buffer2) {
return value + buffer2[linear_index];
})";
desc->input_buffers = {
{input_id1, "device FLT4* const"},
{input_id2, "device FLT4* const"},
};
desc->output_buffer = {output_id};
desc->AddSrcTensor("");
desc->AddSrcTensor("");
desc->AddDstTensor("");
tflite::gpu::metal::NodeDescriptor node_desc;
node_desc.task = desc;
node_desc.src_tensors_ids = {input_id1, input_id2};
node_desc.dst_tensors_ids = {output_id};
return {node_desc};
}


@ -33,7 +33,7 @@ limitations under the License.
/// Returns empty string or error if shader can't be compiled.
- (absl::Status)compileWithDevice:(id<MTLDevice>)device
taskDescriptor:(::tflite::gpu::metal::ComputeTaskDescriptorPtr)desc
taskDescriptor:(const tflite::gpu::metal::NodeDescriptor&)desc
runtimeOptions:(const ::tflite::gpu::metal::RuntimeOptions&)options;
/// Updates parameters for inputs/outputs/intermediate tensors


@ -72,11 +72,11 @@ struct UniformBuffer {
}
- (absl::Status)compileWithDevice:(id<MTLDevice>)device
taskDescriptor:(ComputeTaskDescriptorPtr)desc
taskDescriptor:(const tflite::gpu::metal::NodeDescriptor&)desc
runtimeOptions:(const RuntimeOptions&)options {
size_t offset = desc->input_buffers.size() + desc->uniform_buffers.size()
+ desc->immutable_buffers.size() + 1;
RETURN_IF_ERROR(_metal_args.Init(device, offset, &desc->args, &desc->shader_source));
size_t offset = desc.task->src_tensors_names.size() + desc.task->uniform_buffers.size()
+ desc.task->immutable_buffers.size() + 1;
RETURN_IF_ERROR(_metal_args.Init(device, offset, &desc.task->args, &desc.task->shader_source));
NSString* barrier;
// simdgroup_barrier is supported on macOS 10.13+ and Metal shading language version 2.0
if (@available(macOS 10.13, iOS 10.0, tvOS 10.0, *)) {
@ -122,21 +122,21 @@ struct UniformBuffer {
@"SIMDGROUP_BARRIER" : barrier,
};
NSString* code = [NSString stringWithCString:desc->shader_source.c_str()
NSString* code = [NSString stringWithCString:desc.task->shader_source.c_str()
encoding:[NSString defaultCStringEncoding]];
id<MTLComputePipelineState> program;
RETURN_IF_ERROR(CreateComputeProgram(device, code, @"ComputeFunction", macros, &program));
if (!program) {
return absl::InternalError("Unknown shader compilation error");
}
for (auto& buffer : desc->input_buffers) {
_inputBuffers.emplace_back(InputBuffer{buffer.id, nil});
for (auto& id : desc.src_tensors_ids) {
_inputBuffers.emplace_back(InputBuffer{id, nil});
}
for (auto& uniform : desc->uniform_buffers) {
for (auto& uniform : desc.task->uniform_buffers) {
_uniformBuffers.emplace_back(UniformBuffer{{}, uniform.data_function});
}
_outputBuffers.emplace_back(OutputBuffer{desc->output_buffer.id, nil});
for (auto& immutable : desc->immutable_buffers) {
_outputBuffers.emplace_back(OutputBuffer{desc.dst_tensors_ids[0], nil});
for (auto& immutable : desc.task->immutable_buffers) {
int padding =
4 * (options.storage_precision == RuntimeOptions::Precision::FP32 ? sizeof(float)
: sizeof(HalfBits));
@ -147,7 +147,7 @@ struct UniformBuffer {
options:MTLResourceStorageModeShared];
_immutableBuffers.emplace_back(metalBuffer);
}
_resizeFunction = desc->resize_function;
_resizeFunction = desc.task->resize_function;
_program = program;
return absl::OkStatus();
}
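
A note on the offset computed in compileWithDevice above (my reading; the source does not spell this out): the buffer count now derives from the shader-facing name lists rather than the removed input_buffers, and the trailing +1 appears to account for the single output buffer, so the MetalArguments bind after every buffer the task declares:

// Buffers the task binds ahead of the MetalArguments block:
size_t offset = desc.task->src_tensors_names.size()  // source tensors
    + desc.task->uniform_buffers.size()              // uniform buffers
    + desc.task->immutable_buffers.size()            // immutable (weights) buffers
    + 1;                                             // the single output buffer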


@ -61,6 +61,14 @@ std::vector<uint8_t> GetByteBufferConvertedResized(
return result;
}
void ComputeTaskDescriptor::AddSrcTensor(const std::string& tensor_name) {
src_tensors_names.push_back("device FLT4* " + tensor_name);
}
void ComputeTaskDescriptor::AddDstTensor(const std::string& tensor_name) {
dst_tensors_names.push_back("device FLT4* " + tensor_name);
}
} // namespace metal
} // namespace gpu
} // namespace tflite
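
These helpers only prepend the Metal buffer type to the shader-facing name, so (a minimal sketch):

ComputeTaskDescriptor desc;
desc.AddSrcTensor("src_buffer");  // appends "device FLT4* src_buffer"
desc.AddDstTensor("");            // linkable tasks pass "": the name is bound
                                  // later, e.g. as " buffer<N>" in FuseChain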


@ -44,20 +44,6 @@ using DispatchParamsFunction = std::function<std::pair<uint3, uint3>(
// building blocks. All required data like immutable operation parameters
// (weights etc.) is attached to the descriptor.
struct ComputeTaskDescriptor {
struct InputBufferDescriptor {
ValueId id;
// The declaration is inserted into the compute function arguments list.
// Example for non-linkable task: "device FLT4* const input_buffer"
// Example for linkable: "device FLT4* const"
std::string declaration;
};
struct OutputBufferDescriptor {
ValueId id;
// The declaration is inserted into the compute function arguments list.
// Example for non-linkable task: "device FLT4* output_buffer"
// Example for linkable: "device FLT4*"
std::string declaration;
};
struct ImmutableBufferDescriptor {
std::string declaration;
std::vector<uint8_t> data;
@ -102,19 +88,30 @@ struct ComputeTaskDescriptor {
// for example add is associative
bool is_associative_op = false;
std::string shader_source;
std::vector<InputBufferDescriptor> input_buffers;
// A single per-operation output is supported now.
OutputBufferDescriptor output_buffer;
std::vector<std::string> src_tensors_names;
std::vector<std::string> dst_tensors_names;
std::vector<ImmutableBufferDescriptor> immutable_buffers;
std::vector<UniformBufferDescriptor> uniform_buffers;
// Dynamic resizing of input tensor is supported. User-defined functions to
// calculate new parameters for GPU compute task dispatching. A leading
// unlinkable task must provide this.
DispatchParamsFunction resize_function;
void AddSrcTensor(const std::string& tensor_name);
void AddDstTensor(const std::string& tensor_name);
};
using ComputeTaskDescriptorPtr = std::shared_ptr<ComputeTaskDescriptor>;
struct NodeDescriptor {
ComputeTaskDescriptorPtr task;
// Unique ID to match the graph compilation errors.
int id;
std::string description;
std::vector<ValueId> src_tensors_ids;
std::vector<ValueId> dst_tensors_ids;
};
/// Helper function to convert buffer's content into stream of bytes
template <typename T>
std::vector<uint8_t> GetByteBuffer(const std::vector<T>& input_vector) {
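
Taken together, a caller now wires a linkable task into the graph like this (condensed from the unit tests later in this commit):

auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->is_linkable = true;
desc->shader_source = R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid) {
  return value * 1.1f;
})";
desc->AddSrcTensor("");  // declaration only; no ValueId here anymore
desc->AddDstTensor("");

tflite::gpu::metal::NodeDescriptor node;
node.task = desc;
node.src_tensors_ids = {input_id};   // graph wiring stays on the node
node.dst_tensors_ids = {output_id};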


@ -62,7 +62,7 @@ using ::tflite::gpu::TensorUsageRecord;
for (const auto& node : compiledModel.nodes) {
TFLComputeTask* task = [[TFLComputeTask alloc] init];
RETURN_IF_ERROR([task compileWithDevice:_device
taskDescriptor:node.task
taskDescriptor:node
runtimeOptions:_options]);
[task setDescription:node.description];
_computeTasks.emplace_back(task);


@ -39,20 +39,19 @@ using ::tflite::gpu::uint3;
using ::tflite::gpu::ValueId;
// This is an example of simple linkable operation performing multiplication by a constant.
static std::vector<ComputeTaskDescriptorPtr> MulLinkable(ValueId input_id,
ValueId output_id) {
static ComputeTaskDescriptorPtr MulLinkable() {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->is_linkable = true;
desc->shader_source = R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid) {
return value * 1.1f;
})";
desc->input_buffers = {{input_id}};
desc->output_buffer = {output_id};
return {desc};
desc->AddSrcTensor("");
desc->AddDstTensor("");
return desc;
}
// This is an example of simple non-linkable operation performing add with a constant.
static std::vector<ComputeTaskDescriptorPtr> Add(ValueId input_id, ValueId output_id) {
static ComputeTaskDescriptorPtr Add() {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->is_linkable = false;
desc->shader_source = R"(
@ -72,11 +71,8 @@ static std::vector<ComputeTaskDescriptorPtr> Add(ValueId input_id, ValueId outpu
}
)";
desc->input_buffers = {
{input_id, "device FLT4* const input_buffer"},
};
desc->output_buffer = {output_id, "device FLT4* output_buffer"};
desc->AddSrcTensor("input_buffer");
desc->AddDstTensor("output_buffer");
desc->uniform_buffers = {
{"constant int2& size",
@ -99,20 +95,20 @@ static std::vector<ComputeTaskDescriptorPtr> Add(ValueId input_id, ValueId outpu
return std::make_pair(groups_size, groups_count);
};
return {desc};
return desc;
}
// This is an example of simple linkable operation performing multiplication by a uniform
static std::vector<ComputeTaskDescriptorPtr> AddUniformLinkable(
ValueId input_id, ValueId output_id, const std::vector<float>& channel_multipliers) {
static ComputeTaskDescriptorPtr AddUniformLinkable(
const std::vector<float>& channel_multipliers) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->is_linkable = true;
desc->shader_source = R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, FLT4 multiplier)
{
return value + multiplier;
})";
desc->input_buffers = {{input_id}};
desc->output_buffer = {output_id};
desc->AddSrcTensor("");
desc->AddDstTensor("");
desc->uniform_buffers = {
{"constant FLT4&",
[channel_multipliers](const std::vector<BHWC>& src_shapes,
@ -120,20 +116,20 @@ static std::vector<ComputeTaskDescriptorPtr> AddUniformLinkable(
return GetByteBuffer(channel_multipliers);
}},
};
return {desc};
return desc;
}
// This is an example of simple linkable operation performing multiplication by a constant.
static std::vector<ComputeTaskDescriptorPtr> MulArrayLinkable(
ValueId input_id, ValueId output_id, const std::vector<float>& channel_multipliers) {
static ComputeTaskDescriptorPtr MulArrayLinkable(
const std::vector<float>& channel_multipliers) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->is_linkable = true;
desc->shader_source = R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid,
device FLT4* const multiplier) {
return value * multiplier[gid.z];
})";
desc->input_buffers = {{input_id}};
desc->output_buffer = {output_id};
desc->AddSrcTensor("");
desc->AddDstTensor("");
desc->immutable_buffers = {
{"device FLT4* const", GetByteBuffer(channel_multipliers)},
};
@ -157,9 +153,14 @@ static std::vector<ComputeTaskDescriptorPtr> MulArrayLinkable(
- (void)testTwoInputsShaderOutput {
ValueId inputBufferID = 1;
ValueId outputBufferID = 3;
auto graph = Add(inputBufferID, 2);
auto graph2 = MulLinkable(2, outputBufferID);
graph.insert(graph.end(), graph2.begin(), graph2.end());
std::vector<tflite::gpu::metal::NodeDescriptor> nodes(2);
nodes[0].task = Add();
nodes[0].src_tensors_ids = {inputBufferID};
nodes[0].dst_tensors_ids = {2};
tflite::gpu::metal::NodeDescriptor node1;
nodes[1].task = MulLinkable();
nodes[1].src_tensors_ids = {2};
nodes[1].dst_tensors_ids = {outputBufferID};
TensorFloat32 input;
input.shape = BHWC(1, 1, 1, 3);
input.id = inputBufferID;
@ -168,7 +169,7 @@ static std::vector<ComputeTaskDescriptorPtr> MulArrayLinkable(
output.shape = BHWC(1, 1, 1, 3);
std::map<ValueId, TensorFloat32> inputs{{inputBufferID, input}};
std::map<ValueId, TensorFloat32> outputs{{outputBufferID, output}};
auto status = RunGraph(graph, _device, inputs, &outputs);
auto status = RunGraph(nodes, _device, inputs, &outputs);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
status = CompareVectors({2.2f, 3.3f, 4.4f}, outputs[outputBufferID].data, 1e-6f);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
@ -177,8 +178,10 @@ static std::vector<ComputeTaskDescriptorPtr> MulArrayLinkable(
- (void)testImmutableShaderOutput {
ValueId inputBufferID = 1;
ValueId outputBufferID = 2;
auto graph = MulArrayLinkable(inputBufferID, outputBufferID,
{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
std::vector<tflite::gpu::metal::NodeDescriptor> nodes(1);
nodes[0].task = MulArrayLinkable({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
nodes[0].src_tensors_ids = {inputBufferID};
nodes[0].dst_tensors_ids = {outputBufferID};
TensorFloat32 input;
input.shape = BHWC(1, 1, 1, 7);
input.id = inputBufferID;
@ -187,7 +190,7 @@ static std::vector<ComputeTaskDescriptorPtr> MulArrayLinkable(
output.shape = BHWC(1, 1, 1, 7);
std::map<ValueId, TensorFloat32> inputs{{inputBufferID, input}};
std::map<ValueId, TensorFloat32> outputs{{outputBufferID, output}};
auto status = RunGraph(graph, _device, inputs, &outputs);
auto status = RunGraph(nodes, _device, inputs, &outputs);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
status = CompareVectors({1, 4, 9, 16, 25, 36, 49}, outputs[outputBufferID].data, 1e-6f);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
@ -196,7 +199,10 @@ static std::vector<ComputeTaskDescriptorPtr> MulArrayLinkable(
- (void)testUniformShaderOutput {
ValueId inputBufferID = 1;
ValueId outputBufferID = 2;
auto graph = AddUniformLinkable(inputBufferID, outputBufferID, {1.0f, 2.0f, 3.0f, 4.0f});
std::vector<tflite::gpu::metal::NodeDescriptor> nodes(1);
nodes[0].task = AddUniformLinkable({1.0f, 2.0f, 3.0f, 4.0f});
nodes[0].src_tensors_ids = {inputBufferID};
nodes[0].dst_tensors_ids = {outputBufferID};
TensorFloat32 input;
input.shape = BHWC(1, 1, 1, 3);
input.id = inputBufferID;
@ -205,7 +211,7 @@ static std::vector<ComputeTaskDescriptorPtr> MulArrayLinkable(
output.shape = BHWC(1, 1, 1, 3);
std::map<ValueId, TensorFloat32> inputs{{inputBufferID, input}};
std::map<ValueId, TensorFloat32> outputs{{outputBufferID, output}};
auto status = RunGraph(graph, _device, inputs, &outputs);
auto status = RunGraph(nodes, _device, inputs, &outputs);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
status = CompareVectors({2, 4, 6}, outputs[outputBufferID].data, 1e-6f);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
@ -214,10 +220,14 @@ static std::vector<ComputeTaskDescriptorPtr> MulArrayLinkable(
- (void)testUniformAndImmutableShaderOutput {
ValueId inputBufferID = 1;
ValueId outputBufferID = 3;
auto graph =
MulArrayLinkable(inputBufferID, 2, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
auto graph2 = AddUniformLinkable(2, outputBufferID, {1.0f, 2.0f, 3.0f, 4.0f});
graph.insert(graph.end(), graph2.begin(), graph2.end());
std::vector<tflite::gpu::metal::NodeDescriptor> nodes(2);
nodes[0].task = MulArrayLinkable({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
nodes[0].src_tensors_ids = {inputBufferID};
nodes[0].dst_tensors_ids = {2};
tflite::gpu::metal::NodeDescriptor node1;
nodes[1].task = AddUniformLinkable({1.0f, 2.0f, 3.0f, 4.0f});
nodes[1].src_tensors_ids = {2};
nodes[1].dst_tensors_ids = {outputBufferID};
TensorFloat32 input;
input.shape = BHWC(1, 1, 1, 7);
input.id = inputBufferID;
@ -226,7 +236,7 @@ static std::vector<ComputeTaskDescriptorPtr> MulArrayLinkable(
output.shape = BHWC(1, 1, 1, 7);
std::map<ValueId, TensorFloat32> inputs{{inputBufferID, input}};
std::map<ValueId, TensorFloat32> outputs{{outputBufferID, output}};
auto status = RunGraph(graph, _device, inputs, &outputs);
auto status = RunGraph(nodes, _device, inputs, &outputs);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
status = CompareVectors({2, 6, 12, 20, 26, 38, 52}, outputs[outputBufferID].data, 1e-6f);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());


@ -57,9 +57,9 @@ ComputeTaskDescriptor Add(const std::vector<ValueId> input_ids,
desc.shader_source = GetAddTableCodeFused(input_ids.size() - 1);
for (int i = 0; i < input_ids.size(); ++i) {
desc.input_buffers.push_back({input_ids[i], "device FLT4* const"});
desc.AddSrcTensor("");
}
desc.output_buffer = {output_id};
desc.AddDstTensor("");
return desc;
}


@ -141,12 +141,9 @@ ComputeTaskDescriptor ConcatZ(std::vector<ValueId> input_ids, ValueId output_id,
desc.shader_source = GetConcatZCode(channels);
for (int i = 0; i < input_ids.size(); ++i) {
const std::string buffer_name =
"device FLT4* const src_buffer" + std::to_string(i);
desc.input_buffers.push_back({input_ids[i], buffer_name});
desc.AddSrcTensor("src_buffer" + std::to_string(i));
}
desc.output_buffer = {output_id, "device FLT4* dst_buffer"};
desc.AddDstTensor("dst_buffer");
desc.uniform_buffers = {
{"constant uniforms& U",
@ -227,12 +224,9 @@ ComputeTaskDescriptor ConcatX(std::vector<ValueId> input_ids, ValueId output_id,
desc.shader_source = code;
for (int i = 0; i < input_ids.size(); ++i) {
const std::string buffer_name =
"device FLT4* const src_buffer" + std::to_string(i);
desc.input_buffers.push_back({input_ids[i], buffer_name});
desc.AddSrcTensor("src_buffer" + std::to_string(i));
}
desc.output_buffer = {output_id, "device FLT4* dst_buffer"};
desc.AddDstTensor("dst_buffer");
desc.uniform_buffers = {
{"constant int3& size",
@ -305,12 +299,9 @@ ComputeTaskDescriptor ConcatY(std::vector<ValueId> input_ids, ValueId output_id,
desc.shader_source = code;
for (int i = 0; i < input_ids.size(); ++i) {
const std::string buffer_name =
"device FLT4* const src_buffer" + std::to_string(i);
desc.input_buffers.push_back({input_ids[i], buffer_name});
desc.AddSrcTensor("src_buffer" + std::to_string(i));
}
desc.output_buffer = {output_id, "device FLT4* dst_buffer"};
desc.AddDstTensor("dst_buffer");
desc.uniform_buffers = {
{"constant int3& size",


@ -1054,12 +1054,8 @@ ComputeTaskDescriptor ConvolutionGeneric(ValueId input_id, ValueId output_id,
ComputeTaskDescriptor desc;
desc.shader_source = GenerateConvolution(params);
desc.input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc.output_buffer = {output_id, "device FLT4* dst_buffer"};
desc.AddSrcTensor("src_buffer");
desc.AddDstTensor("dst_buffer");
auto weights_reordered = ReorderWeightsForConv(attr.weights, params);
std::string addr_space =
@ -1136,12 +1132,8 @@ ComputeTaskDescriptor ConvolutionWino4x4To6x6(
ComputeTaskDescriptor desc;
desc.shader_source = GenerateConvolution(params);
desc.input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc.output_buffer = {output_id, "device FLT4* dst_buffer"};
desc.AddSrcTensor("src_buffer");
desc.AddDstTensor("dst_buffer");
::tflite::gpu::Tensor<OHWI, DataType::FLOAT32> wino_weights;
RearrangeWeightsToWinograd4x4To6x6Weights(attr.weights, &wino_weights);


@ -300,28 +300,23 @@ using ::tflite::gpu::metal::SingleOpModel;
tflite::gpu::GpuInfo gpu_info;
tflite::gpu::GetGpuInfoFromDeviceDescription(device_name, tflite::gpu::GpuApi::kMetal, &gpu_info);
auto gpu_op0 = ConvolutionGeneric(0, 1, dst_shape, attr, gpu_info, options);
std::vector<tflite::gpu::metal::ComputeTaskDescriptorPtr> tasks_v0 =
{std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op0))};
auto status = RunGraph(tasks_v0, device, inputs_v0, &outputs_v0);
std::vector<tflite::gpu::metal::NodeDescriptor> nodes(1);
nodes[0].task = std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op0));
nodes[0].src_tensors_ids = {0};
nodes[0].dst_tensors_ids = {1};
auto status = RunGraph(nodes, device, inputs_v0, &outputs_v0);
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
tflite::gpu::metal::Winograd4x4To36Attributes wino_up_attr;
wino_up_attr.padding = attr.padding;
auto gpu_op1 = tflite::gpu::metal::Winograd4x4To36(0, 2, wino_up_attr);
std::vector<tflite::gpu::metal::ComputeTaskDescriptorPtr> tasks_v1 =
{std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op1))};
auto gpu_op2 = ConvolutionWino4x4To6x6(2, 3, conv_shape, attr, gpu_info, options);
std::vector<tflite::gpu::metal::ComputeTaskDescriptorPtr> tasks_v2 =
{std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op2))};
tflite::gpu::metal::Winograd36To4x4Attributes wino_down_attr;
wino_down_attr.output_shape = dst_shape;
wino_down_attr.biases = attr.bias;
auto gpu_op3 = tflite::gpu::metal::Winograd36To4x4(3, 1, options, wino_down_attr);
std::vector<tflite::gpu::metal::ComputeTaskDescriptorPtr> tasks_v3 =
{std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op3))};
std::map<ValueId, TensorFloat32> inputs_v1;
inputs_v1[0] = src_tensor;
@ -329,21 +324,30 @@ using ::tflite::gpu::metal::SingleOpModel;
outputs_v1[2].shape = conv_shape;
outputs_v1[2].shape.c = src_shape.c;
outputs_v1[2].data.resize(outputs_v1[2].shape.DimensionsProduct());
status = RunGraph(tasks_v1, device, inputs_v1, &outputs_v1);
nodes[0].task = std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op1));
nodes[0].src_tensors_ids = {0};
nodes[0].dst_tensors_ids = {2};
status = RunGraph(nodes, device, inputs_v1, &outputs_v1);
std::map<ValueId, TensorFloat32> inputs_v2;
inputs_v2[2] = outputs_v1[2];
std::map<ValueId, TensorFloat32> outputs_v2;
outputs_v2[3].shape = conv_shape;
outputs_v2[3].data.resize(outputs_v2[3].shape.DimensionsProduct());
status = RunGraph(tasks_v2, device, inputs_v2, &outputs_v2);
nodes[0].task = std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op2));
nodes[0].src_tensors_ids = {2};
nodes[0].dst_tensors_ids = {3};
status = RunGraph(nodes, device, inputs_v2, &outputs_v2);
std::map<ValueId, TensorFloat32> inputs_v3;
inputs_v3[3] = outputs_v2[3];
std::map<ValueId, TensorFloat32> outputs_v3;
outputs_v3[1].shape = dst_shape;
outputs_v3[1].data.resize(outputs_v3[1].shape.DimensionsProduct());
status = RunGraph(tasks_v3, device, inputs_v3, &outputs_v3);
nodes[0].task = std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op3));
nodes[0].src_tensors_ids = {3};
nodes[0].dst_tensors_ids = {1};
status = RunGraph(nodes, device, inputs_v3, &outputs_v3);
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
status = CompareVectors(outputs_v0[1].data, outputs_v3[1].data, 1e-4f);


@ -553,12 +553,8 @@ ComputeTaskDescriptor DepthWiseConvolution(
}
)";
desc.shader_source = shader_source;
desc.input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc.output_buffer = {output_id, "device FLT4* dst_buffer"};
desc.AddSrcTensor("src_buffer");
desc.AddDstTensor("dst_buffer");
const int output_channels_count = attr.weights.shape.i * attr.weights.shape.o;
desc.immutable_buffers = {
@ -618,12 +614,8 @@ ComputeTaskDescriptor DepthWiseConv3x3Stride1x1(
const RuntimeOptions& options) {
ComputeTaskDescriptor desc;
desc.shader_source = GetKernelDepthWiseConv3x3Stride1x1();
desc.input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc.output_buffer = {output_id, "device FLT4* dst_buffer"};
desc.AddSrcTensor("src_buffer");
desc.AddDstTensor("dst_buffer");
// For this operation we keep weights and biases in one buffer
auto weights_reordered = ReorderWeightsDepthWiseConv3x3Stride1x1(attr);
@ -674,11 +666,8 @@ ComputeTaskDescriptor DepthWiseConv3x3Stride2(
ComputeTaskDescriptor desc;
desc.shader_source = GetKernelDepthWiseConv3x3Stride2();
desc.input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc.output_buffer = {output_id, "device FLT4* dst_buffer"};
desc.AddSrcTensor("src_buffer");
desc.AddDstTensor("dst_buffer");
// For this operation we keep weights and biases in one buffer
auto weights_reordered = ReorderWeightsDepthWiseConv3x3Stride2(attr);


@ -108,11 +108,9 @@ ComputeTaskDescriptor ElementwiseWithTwoInputs(std::vector<ValueId> input_ids,
desc.shader_source = code;
desc.input_buffers = {
{input_ids[0], "device FLT4* const"},
{input_ids[1], "device FLT4* const"},
};
desc.output_buffer = {output_id};
desc.AddSrcTensor("");
desc.AddSrcTensor("");
desc.AddDstTensor("");
desc.uniform_buffers = {
{"constant int2&",
@ -139,8 +137,8 @@ ComputeTaskDescriptor ElementwiseWithOneInput(ValueId input_id,
" return " + OneInputFunctor(op_type, "value") + ";\n";
desc.shader_source += " }";
desc.input_buffers = {{input_id}};
desc.output_buffer = {output_id};
desc.AddSrcTensor("");
desc.AddDstTensor("");
return desc;
}
@ -186,8 +184,8 @@ ComputeTaskDescriptor ElementwiseWithOneInputAndConstantArguent(
" return " + TwoInputFunctor(op_type, "value", "second_arg") + ";\n";
desc.shader_source += " }";
desc.input_buffers = {{input_id}};
desc.output_buffer = {output_id};
desc.AddSrcTensor("");
desc.AddDstTensor("");
if (scalar) {
std::vector<uint8_t> scalar_bits =
GetByteBuffer(std::vector<float>{*scalar});

@ -127,11 +127,8 @@ ComputeTaskDescriptor FullyConnected(ValueId input_id, ValueId output_id,
desc.args.AddInt("src_slices", DivideRoundUp(attr.weights.shape.i, 4));
desc.args.AddInt("dst_channels_alignedx8", AlignByN(attr.weights.shape.o, 8));
desc.input_buffers = {
{input_id, "device FLT4* const vector"},
};
desc.output_buffer = {output_id, "device FLT4* result"};
desc.AddSrcTensor("vector");
desc.AddDstTensor("result");
bool shared_memory = gpu_info.IsApple() &&
gpu_info.apple_info.IsLocalMemoryPreferredOverGlobal();

@ -92,12 +92,9 @@ ComputeTaskDescriptor MaxUnpooling(ValueId input_id, ValueId input_indices_id,
ComputeTaskDescriptor desc;
desc.shader_source = GetMaxUnpoolingCode(params.kernel);
desc.input_buffers = {
{input_id, "device FLT4* const src_buffer"},
{input_indices_id, "device FLT4* const src_indices_buffer"},
};
desc.output_buffer = {output_id, "device FLT4* output_buffer"};
desc.AddSrcTensor("src_buffer");
desc.AddSrcTensor("src_indices_buffer");
desc.AddDstTensor("output_buffer");
desc.uniform_buffers = {
{"constant uniforms& params",

@ -115,11 +115,9 @@ ComputeTaskDescriptor Mean(ValueId input_id, ValueId output_id,
std::string code = GetMeanCode(work_group_size);
desc.shader_source = code;
desc.input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc.AddSrcTensor("src_buffer");
desc.AddDstTensor("dst_buffer");
desc.output_buffer = {output_id, "device FLT4* dst_buffer"};
desc.uniform_buffers = {
{"constant uniforms& params",
[work_group_size](const std::vector<BHWC>& src_shapes,

@ -154,11 +154,8 @@ ComputeTaskDescriptor Padding(ValueId input_id, ValueId output_id,
ComputeTaskDescriptor desc;
desc.shader_source = GetPaddingCode(attr);
desc.input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc.output_buffer = {output_id, "device FLT4* dst_buffer"};
desc.AddSrcTensor("src_buffer");
desc.AddDstTensor("dst_buffer");
desc.uniform_buffers = {
{"constant uniforms& params",

@ -205,11 +205,8 @@ ComputeTaskDescriptor Pooling(ValueId input_id, ValueId output_id,
desc.shader_source = GetAveragePoolingCode(params.kernel);
}
desc.input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc.output_buffer = {output_id, "device FLT4* output_buffer"};
desc.AddSrcTensor("src_buffer");
desc.AddDstTensor("output_buffer");
desc.uniform_buffers = {
{"constant uniforms& params",

@ -57,8 +57,8 @@ ComputeTaskDescriptor PReLU(ValueId input_id, ValueId output_id,
return FLT4(max(FLT4(0.0f), value) + alphas[gid.z] * min(FLT4(0.0f), value));
})";
}
desc.input_buffers = {{input_id}};
desc.output_buffer = {output_id};
desc.AddSrcTensor("");
desc.AddDstTensor("");
desc.immutable_buffers = {
{"device FLT4* const",
GetByteBufferConverted(alpha_buffer->data, options.storage_precision)},
@ -99,8 +99,8 @@ ComputeTaskDescriptor PReLUFull(ValueId input_id, ValueId output_id,
return FLT4(max(FLT4(0.0f), value) + alphas[linear_index] * min(FLT4(0.0f), value));
})";
}
desc.input_buffers = {{input_id}};
desc.output_buffer = {output_id};
desc.AddSrcTensor("");
desc.AddDstTensor("");
desc.immutable_buffers = {
{"device FLT4* const", GetByteBufferConverted(ConvertToPHWC4(*alpha),
options.storage_precision)},

@ -36,8 +36,8 @@ ComputeTaskDescriptor QuantizeAndDequantize(
}
)";
desc.input_buffers = {{input_id}};
desc.output_buffer = {output_id};
desc.AddSrcTensor("");
desc.AddDstTensor("");
desc.uniform_buffers = {
{"constant float3&",
[attr](const std::vector<BHWC>& src_shapes,

@ -46,8 +46,8 @@ ComputeTaskDescriptor ReLU(ValueId input_id, ValueId output_id,
desc.shader_source =
parameters + " return FLT4(max(value, " + min_func + "));\n}";
}
desc.input_buffers = {{input_id}};
desc.output_buffer = {output_id};
desc.AddSrcTensor("");
desc.AddDstTensor("");
desc.uniform_buffers = {
{"constant float2&",
[attr](const std::vector<BHWC>& src_shapes,

@ -121,11 +121,8 @@ ComputeTaskDescriptor Reshape(ValueId input_id, ValueId output_id,
ComputeTaskDescriptor desc;
desc.shader_source = GetReshapeCode();
desc.input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc.output_buffer = {output_id, "device FLT4* dst_buffer"};
desc.AddSrcTensor("src_buffer");
desc.AddDstTensor("dst_buffer");
desc.uniform_buffers = {
{"constant uniforms& params",
@ -166,11 +163,8 @@ ComputeTaskDescriptor Reshapex4(ValueId input_id, ValueId output_id,
ComputeTaskDescriptor desc;
desc.shader_source = GetReshapex4Code();
desc.input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc.output_buffer = {output_id, "device FLT4* dst_buffer"};
desc.AddSrcTensor("src_buffer");
desc.AddDstTensor("dst_buffer");
desc.uniform_buffers = {
{"constant uniforms& params",

@ -132,11 +132,8 @@ ComputeTaskDescriptor Resize(ValueId input_id, ValueId output_id,
return {};
}
desc.input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc.output_buffer = {output_id, "device FLT4* output_buffer"};
desc.AddSrcTensor("src_buffer");
desc.AddDstTensor("output_buffer");
desc.uniform_buffers = {
{"constant int4& size",

@ -134,11 +134,8 @@ ComputeTaskDescriptor Slice(ValueId input_id, ValueId output_id,
ComputeTaskDescriptor desc;
desc.shader_source = GetSliceCode(attr);
desc.input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc.output_buffer = {output_id, "device FLT4* dst_buffer"};
desc.AddSrcTensor("src_buffer");
desc.AddDstTensor("dst_buffer");
desc.uniform_buffers = {
{"constant uniforms& params",

@ -145,11 +145,8 @@ ComputeTaskDescriptor Softmax(ValueId input_id, ValueId output_id,
}
)";
desc.input_buffers = {
{input_id, "device FLT4* const input_buffer"},
};
desc.output_buffer = {output_id, "device FLT4* output_buffer"};
desc.AddSrcTensor("input_buffer");
desc.AddDstTensor("output_buffer");
desc.uniform_buffers = {
{"constant int2& size",
@ -176,11 +173,8 @@ ComputeTaskDescriptor Softmax1x1(ValueId input_id, ValueId output_id,
ComputeTaskDescriptor desc;
desc.shader_source = GetSoftmax1x1Code(gpu_info);
desc.input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc.output_buffer = {output_id, "device FLT4* dst_buffer"};
desc.AddSrcTensor("src_buffer");
desc.AddDstTensor("dst_buffer");
desc.uniform_buffers = {
{"constant uniforms& params",

@ -64,9 +64,8 @@ kernel void ComputeFunction($1 uint3 gid[[thread_position_in_grid]]) {
dst_buffer[gid.x + dst_size.x * (gid.y + dst_size.y * gid.z)] = value;
})";
desc.input_buffers = {{input_id, "device FLT4* const src_buffer"}};
desc.output_buffer = {output_id, "device FLT4* dst_buffer"};
desc.AddSrcTensor("src_buffer");
desc.AddDstTensor("dst_buffer");
desc.uniform_buffers = {
{"constant uniforms& params",

@ -63,7 +63,7 @@ absl::Status CompareVectors(const std::vector<float>& reference,
/// Helper function that compiles a previously configured graph (with its added
/// nodes), initializes it with the specified inputs, invokes it, and fills the
/// specified outputs
absl::Status RunGraph(const std::vector<ComputeTaskDescriptorPtr>& nodes,
absl::Status RunGraph(const std::vector<NodeDescriptor>& nodes,
id<MTLDevice> device,
const std::map<ValueId, TensorFloat32>& inputs,
std::map<ValueId, TensorFloat32>* outputs);
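
The calling contract, as the tests in this commit show, is unchanged apart from the node wiring: inputs are keyed by graph tensor id, and each requested output must be pre-sized before the call. A hedged sketch of a single invocation:

    // Sketch: ids 0 and 1 are the node's src/dst tensor ids, as in the tests.
    std::map<ValueId, TensorFloat32> inputs;
    inputs[0] = src_tensor;
    std::map<ValueId, TensorFloat32> outputs;
    outputs[1].shape = dst_shape;
    outputs[1].data.resize(outputs[1].shape.DimensionsProduct());
    id<MTLDevice> device = MTLCreateSystemDefaultDevice();
    absl::Status status = RunGraph(nodes, device, inputs, &outputs);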

@ -166,7 +166,7 @@ absl::Status CompareVectors(const std::vector<float>& reference, const std::vect
return absl::OkStatus();
}
absl::Status RunGraph(const std::vector<ComputeTaskDescriptorPtr>& nodes, id<MTLDevice> device,
absl::Status RunGraph(const std::vector<NodeDescriptor>& nodes, id<MTLDevice> device,
const std::map<ValueId, TensorFloat32>& inputs,
std::map<ValueId, TensorFloat32>* outputs) {
std::vector<ValueId> inputBufferIDs;
@ -181,11 +181,7 @@ absl::Status RunGraph(const std::vector<ComputeTaskDescriptorPtr>& nodes, id<MTL
}
std::map<ValueId, BHWC> outputDimensions;
CompiledModel raw_model;
for (auto& node : nodes) {
NodeDescriptor node_desc;
node_desc.task = node;
raw_model.nodes.push_back(node_desc);
}
raw_model.nodes = nodes;
for(const auto& input : inputs) {
raw_model.tensor_shapes[input.first] = input.second.shape;
}

@ -473,11 +473,8 @@ ComputeTaskDescriptor ConvolutionTransposed(
desc.shader_source = GetDeconvolution(params);
}
desc.input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc.output_buffer = {output_id, "device FLT4* dst_buffer"};
desc.AddSrcTensor("src_buffer");
desc.AddDstTensor("dst_buffer");
const int src_ch_aligned = AlignByN(params.weights.shape.i, 4);
const int dst_ch_aligned = AlignByN(params.weights.shape.o, 4);
@ -612,11 +609,8 @@ ComputeTaskDescriptor ConvolutionTransposed4x4(
const int2 block_size(recommended_2x ? 2 : 1, 1);
desc.shader_source = GetDeconvolution4x4(block_size, gpu_info);
desc.input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc.output_buffer = {output_id, "device FLT4* dst_buffer"};
desc.AddSrcTensor("src_buffer");
desc.AddDstTensor("dst_buffer");
desc.immutable_buffers = {
{"device FLT4* const filters", filters},

@ -468,11 +468,8 @@ ComputeTaskDescriptor Winograd4x4To36(ValueId input_id, ValueId output_id,
ComputeTaskDescriptor desc;
desc.shader_source = GetKernelWinograd4x4To36();
desc.input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc.output_buffer = {output_id, "device FLT4* dst_buffer"};
desc.AddSrcTensor("src_buffer");
desc.AddDstTensor("dst_buffer");
desc.uniform_buffers = {
{"constant uniforms& U",
@ -526,11 +523,8 @@ ComputeTaskDescriptor Winograd4x4To36TileX6(
ComputeTaskDescriptor desc;
desc.shader_source = GetKernelWinograd4x4To36TileX6();
desc.input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc.output_buffer = {output_id, "device FLT4* dst_buffer"};
desc.AddSrcTensor("src_buffer");
desc.AddDstTensor("dst_buffer");
std::vector<float> bt_aligned(6 * 8);
auto bt_mat = BtMatrixForWinograd4x4To6x6();
@ -595,11 +589,8 @@ ComputeTaskDescriptor Winograd36To4x4(ValueId input_id, ValueId output_id,
ComputeTaskDescriptor desc;
desc.shader_source = GetKernelWinograd36To4x4();
desc.input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc.output_buffer = {output_id, "device FLT4* dst_buffer"};
desc.AddSrcTensor("src_buffer");
desc.AddDstTensor("dst_buffer");
desc.immutable_buffers = {
{"device FLT4* const biases",
@ -646,11 +637,8 @@ ComputeTaskDescriptor Winograd36To4x4Tile4x1(
ComputeTaskDescriptor desc;
desc.shader_source = GetKernelWinograd36To4x4Tile4x1();
desc.input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc.output_buffer = {output_id, "device FLT4* dst_buffer"};
desc.AddSrcTensor("src_buffer");
desc.AddDstTensor("dst_buffer");
std::vector<float> at_aligned(4 * 8);
auto at_mat = AtMatrixForWinograd4x4To6x6();

@ -85,8 +85,10 @@ using ::tflite::gpu::metal::CompareVectors;
attr.padding.prepended = tflite::gpu::HW(1, 1);
attr.padding.appended = tflite::gpu::HW(1, 1);
auto gpu_op = tflite::gpu::metal::Winograd4x4To36(0, 1, attr);
std::vector<tflite::gpu::metal::ComputeTaskDescriptorPtr> tasks =
{std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op))};
std::vector<tflite::gpu::metal::NodeDescriptor> nodes(1);
nodes[0].task = std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op));
nodes[0].src_tensors_ids = {0};
nodes[0].dst_tensors_ids = {1};
std::map<ValueId, TensorFloat32> inputs;
inputs[0] = src_tensor;
@ -95,7 +97,7 @@ using ::tflite::gpu::metal::CompareVectors;
outputs[1].data.resize(36, 0.0f);
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
auto status = RunGraph(tasks, device, inputs, &outputs);
auto status = RunGraph(nodes, device, inputs, &outputs);
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
status = CompareVectors(dst_tensor.data, outputs[1].data, 1e-6f);
@ -149,8 +151,10 @@ using ::tflite::gpu::metal::CompareVectors;
attr.padding.prepended = tflite::gpu::HW(1, 1);
attr.padding.appended = tflite::gpu::HW(1, 1);
auto gpu_op = tflite::gpu::metal::Winograd4x4To36TileX6(0, 1, attr, options);
std::vector<tflite::gpu::metal::ComputeTaskDescriptorPtr> tasks =
{std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op))};
std::vector<tflite::gpu::metal::NodeDescriptor> nodes(1);
nodes[0].task = std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op));
nodes[0].src_tensors_ids = {0};
nodes[0].dst_tensors_ids = {1};
std::map<ValueId, TensorFloat32> inputs;
inputs[0] = src_tensor;
@ -159,7 +163,7 @@ using ::tflite::gpu::metal::CompareVectors;
outputs[1].data.resize(36, 0.0f);
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
auto status = RunGraph(tasks, device, inputs, &outputs);
auto status = RunGraph(nodes, device, inputs, &outputs);
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
status = CompareVectors(dst_tensor.data, outputs[1].data, 1e-6f);
@ -214,8 +218,10 @@ using ::tflite::gpu::metal::CompareVectors;
options.accumulator_precision = tflite::gpu::metal::RuntimeOptions::Precision::FP32;
auto gpu_op = tflite::gpu::metal::Winograd36To4x4(0, 1, options, attr);
std::vector<tflite::gpu::metal::ComputeTaskDescriptorPtr> tasks =
{std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op))};
std::vector<tflite::gpu::metal::NodeDescriptor> nodes(1);
nodes[0].task = std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op));
nodes[0].src_tensors_ids = {0};
nodes[0].dst_tensors_ids = {1};
std::map<ValueId, TensorFloat32> inputs;
inputs[0] = src_tensor;
@ -224,7 +230,7 @@ using ::tflite::gpu::metal::CompareVectors;
outputs[1].data.resize(16, 0.0f);
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
auto status = RunGraph(tasks, device, inputs, &outputs);
auto status = RunGraph(nodes, device, inputs, &outputs);
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
status = CompareVectors(dst_tensor.data, outputs[1].data, 1e-5f);
@ -279,8 +285,10 @@ using ::tflite::gpu::metal::CompareVectors;
options.accumulator_precision = tflite::gpu::metal::RuntimeOptions::Precision::FP32;
auto gpu_op = tflite::gpu::metal::Winograd36To4x4Tile4x1(0, 1, options, attr);
std::vector<tflite::gpu::metal::ComputeTaskDescriptorPtr> tasks =
{std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op))};
std::vector<tflite::gpu::metal::NodeDescriptor> nodes(1);
nodes[0].task = std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op));
nodes[0].src_tensors_ids = {0};
nodes[0].dst_tensors_ids = {1};
std::map<ValueId, TensorFloat32> inputs;
inputs[0] = src_tensor;
@ -289,7 +297,7 @@ using ::tflite::gpu::metal::CompareVectors;
outputs[1].data.resize(16, 0.0f);
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
auto status = RunGraph(tasks, device, inputs, &outputs);
auto status = RunGraph(nodes, device, inputs, &outputs);
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
status = CompareVectors(dst_tensor.data, outputs[1].data, 1e-6f);