Raman Sarokin 9c9e961174 Added support for HWC constant tensors in ElementwiseWithOneInputAndConstantArguent.
Removed Mul (replaced by elementwise).
Changed Add (some cases handled in elementwise).

PiperOrigin-RevId: 315000693
Change-Id: Icd981170d11f17df418361a554dc08822f3ee273
2020-06-05 14:56:38 -07:00


/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/api.h"
#include <vector>
#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/compiled_model.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
#include "tensorflow/lite/delegates/gpu/metal/environment.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/add.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/concat.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/conv.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/mean.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/padding.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/pooling.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/prelu.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/quantize_and_dequantize.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/relu.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/reshape.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/resize.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/slice.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/softmax.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/winograd.h"
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
namespace tflite {
namespace gpu {
namespace metal {
namespace {
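// Prefers the specialized 3x3 depthwise kernels (stride 1x1 or stride 2)
// when the attributes allow it, falling back to the generic implementation
// otherwise.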
std::vector<ComputeTaskDescriptorPtr> SelectDepthWiseConv(
    int id, ValueId input_id, ValueId output_id,
    const DepthwiseConvolution2DAttributes& attr,
    const metal::RuntimeOptions& options) {
  if (CheckDepthWiseConv3x3Stride1x1Support(attr)) {
    return DepthWiseConv3x3Stride1x1(id, input_id, output_id, attr, options);
  } else if (CheckDepthWiseConv3x3Stride2Support(attr)) {
    return DepthWiseConv3x3Stride2(id, input_id, output_id, attr, options);
  } else {
    return DepthWiseConvolution(id, input_id, output_id, attr, options);
  }
}

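// Uses the 4x4-specialized transposed convolution when the attributes
// support it, otherwise the generic kernel.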
std::vector<ComputeTaskDescriptorPtr> SelectConvolutionTransposed(
    int id, ValueId input_id, ValueId output_id,
    const ConvolutionTransposedAttributes& attr, const DeviceInfo& device_info,
    const metal::RuntimeOptions& options) {
  if (CheckConvolutionTransposed4x4Support(attr)) {
    return ConvolutionTransposed4x4(id, input_id, output_id, attr, device_info,
                                    options);
  } else {
    return ConvolutionTransposed(id, input_id, output_id, attr, device_info,
                                 options);
  }
}

std::vector<ComputeTaskDescriptorPtr> SelectQuantizeAndDequantize(
    int id, ValueId input_id, ValueId output_id,
    const QuantizeAndDequantizeAttributes& attr) {
  return QuantizeAndDequantize(id, input_id, output_id, attr);
}

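// PReLU alpha may arrive either as a per-channel Linear tensor or as a full
// HWC tensor. The HWC variant (PReLUFull) is used only when alpha's shape
// matches the input exactly; any other shape produces no tasks, which the
// caller reports as an unsupported op.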
std::vector<ComputeTaskDescriptorPtr> SelectPReLU(
    const GraphFloat32& graph, int id, ValueId input_id, ValueId output_id,
    const PReLUAttributes& attr, const metal::RuntimeOptions& options) {
  auto alpha = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.alpha);
  if (alpha) {
    return PReLU(id, input_id, output_id, attr, options);
  }
  auto alpha3d = absl::get_if<Tensor<HWC, DataType::FLOAT32>>(&attr.alpha);
  if (!alpha3d) {
    return {};
  }
  const auto shape = graph.FindInputs(id)[0]->tensor.shape;
  if (alpha3d->shape.h != shape.h || alpha3d->shape.w != shape.w ||
      alpha3d->shape.c != shape.c) {
    return {};
  }
  return PReLUFull(id, input_id, output_id, attr, options);
}

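// When both the source and destination channel counts are multiples of 4,
// the slice-aligned Reshapex4 kernel can be used instead of the generic one.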
std::vector<ComputeTaskDescriptorPtr> SelectReshape(
    const GraphFloat32& graph, int id, ValueId input_id, ValueId output_id,
    const ReshapeAttributes& attr) {
  const auto src_shape = graph.FindInputs(id)[0]->tensor.shape;
  if (src_shape.c % 4 == 0 && attr.new_shape.c % 4 == 0) {
    return Reshapex4(id, input_id, output_id, attr);
  } else {
    return Reshape(id, input_id, output_id, attr);
  }
}

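// A 1x1 spatial shape reduces softmax to a single reduction over channels,
// for which a dedicated kernel exists.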
std::vector<ComputeTaskDescriptorPtr> SelectSoftmax(
    const GraphFloat32& graph, int id, ValueId input_id, ValueId output_id,
    const DeviceInfo& device_info) {
  const auto src_shape = graph.FindInputs(id)[0]->tensor.shape;
  if (src_shape.w == 1 && src_shape.h == 1) {
    return Softmax1x1(id, input_id, output_id, device_info, src_shape.c);
  } else {
    return Softmax(id, input_id, output_id, src_shape.c);
  }
}

std::vector<ComputeTaskDescriptorPtr> SelectSpaceToDepth(
    const GraphFloat32& graph, int id, ValueId input_id, ValueId output_id,
    const SpaceToDepthAttributes& attr) {
  return SpaceToDepth(id, input_id, output_id, attr);
}

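// The plain Winograd transforms are selected on Apple GPUs; other GPUs get
// the tiled variants, presumably a better fit for their threadgroup sizes.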
std::vector<ComputeTaskDescriptorPtr> SelectWinograd4x4To36(
    int id, ValueId input_id, ValueId output_id,
    const Winograd4x4To36Attributes& attr, const DeviceInfo& device_info,
    const metal::RuntimeOptions& options) {
  if (device_info.IsAppleGPU()) {
    return Winograd4x4To36(id, input_id, output_id, attr);
  } else {
    return Winograd4x4To36TileX6(id, input_id, output_id, attr, options);
  }
}

std::vector<ComputeTaskDescriptorPtr> SelectWinograd36To4x4(
    int id, ValueId input_id, ValueId output_id,
    const Winograd36To4x4Attributes& attr, const DeviceInfo& device_info,
    const metal::RuntimeOptions& options) {
  if (device_info.IsAppleGPU()) {
    return Winograd36To4x4(id, input_id, output_id, options, attr);
  } else {
    return Winograd36To4x4Tile4x1(id, input_id, output_id, options, attr);
  }
}

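// Heuristic for the Winograd F(4x4, 3x3) path: the kernel must be 3x3 with
// unit stride and dilation, and the problem must be large enough (both
// channel depths >= 16 slices and at least 32 output tiles), presumably so
// the transform overhead pays off. For example, a 3x3/s1 convolution with
// 64 input and 64 output channels onto a 32x32 output gives
// src_depth = dst_depth = 16 and 8 * 8 = 64 tiles, so it qualifies.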
bool IsSuitableForWinograd4x4To6x6(const Convolution2DAttributes& attr,
                                   const BHWC& dst_shape) {
  const int tiles_x = DivideRoundUp(dst_shape.w, 4);
  const int tiles_y = DivideRoundUp(dst_shape.h, 4);
  const int src_depth = DivideRoundUp(attr.weights.shape.i, 4);
  const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
  const bool suitable_attributes =
      attr.weights.shape.w == 3 && attr.weights.shape.h == 3 &&
      attr.dilations == HW(1, 1) && attr.strides == HW(1, 1);
  const int min_depth = 16;
  const int min_hw = 32;
  const bool recommended_channels =
      src_depth >= min_depth && dst_depth >= min_depth;
  const bool recommended_hw = tiles_x * tiles_y >= min_hw;
  return suitable_attributes && recommended_channels && recommended_hw;
}

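// Translates a single graph node into one or more Metal compute tasks.
// *last_node_id and *last_value_id are advanced when an op expands into a
// multi-kernel pipeline that needs fresh intermediate ids (see the Winograd
// convolution path below).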
absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
                                const std::vector<ValueId>& inputs,
                                const std::vector<ValueId>& outputs,
                                const DeviceInfo& device_info,
                                const RuntimeOptions& options,
                                int* last_node_id, int* last_value_id,
                                std::vector<ComputeTaskDescriptorPtr>* tasks) {
  if (!IsBatchMatchesForAllValues(graph)) {
    return absl::InvalidArgumentError(
        "Only identical batch dimension is supported");
  }
  int node_id = static_cast<int>(node->id);
  auto op_type = OperationTypeFromString(node->operation.type);
  switch (op_type) {
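    // ADD with a single runtime input takes its second operand from the
    // attributes (a scalar or constant tensor); two runtime inputs go through
    // the generic elementwise kernel; only the N-ary case still uses the
    // dedicated Add kernel.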
    case OperationType::ADD: {
      if (inputs.size() == 1) {
        if (node->operation.attributes.has_value()) {
          auto attr = absl::any_cast<AddAttributes>(node->operation.attributes);
          *tasks = ElementwiseWithOneInputAndConstantArguent(
              node_id, inputs[0], outputs[0], options, op_type, attr.param);
        } else {
          return absl::UnimplementedError(
              "Missing attributes for single input op: " +
              node->operation.type);
        }
      } else if (inputs.size() == 2) {
        const auto srcs = graph.FindInputs(node_id);
        *tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0],
                                          srcs[1]->tensor.shape, op_type);
      } else {  // more than 2 inputs
        *tasks = Add(node_id, inputs, outputs[0], options);
      }
      break;
    }
    case OperationType::CONCAT: {
      std::vector<BHWC> input_shapes;
      for (auto& input : graph.FindInputs(node->id)) {
        input_shapes.push_back(input->tensor.shape);
      }
      *tasks =
          Concat(node_id, inputs, outputs[0],
                 absl::any_cast<ConcatAttributes>(node->operation.attributes),
                 input_shapes);
      break;
    }
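    // CONVOLUTION_2D either runs a single generic kernel or, when the
    // Winograd heuristic fires, expands into a three-task pipeline
    // (4x4->36 input transform, convolution in the 36-element domain,
    // 36->4x4 output transform) connected by two new intermediate value ids.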
    case OperationType::CONVOLUTION_2D: {
      if (graph.FindInputs(node->id).size() != 1) {
        return absl::UnimplementedError(
            "Convolution does not support more than 1 runtime tensor");
      }
      const auto dst_shape = graph.FindOutputs(node_id)[0]->tensor.shape;
      auto attr =
          absl::any_cast<Convolution2DAttributes>(node->operation.attributes);
      if (IsSuitableForWinograd4x4To6x6(attr, dst_shape)) {
        int tiles_x = DivideRoundUp(dst_shape.w, 4);
        int tiles_y = DivideRoundUp(dst_shape.h, 4);
        Winograd4x4To36Attributes wino_up_attr;
        wino_up_attr.padding = attr.padding;
        (*last_node_id) += 1;
        int value_id = *last_value_id + 1;
        *tasks = SelectWinograd4x4To36(*last_node_id, inputs[0], value_id,
                                       wino_up_attr, device_info, options);
        BHWC conv_shape{dst_shape.b, 36, tiles_x * tiles_y, dst_shape.c};
        (*last_node_id) += 1;
        auto t1 =
            ConvolutionWino4x4To6x6(*last_node_id, value_id, value_id + 1,
                                    conv_shape, attr, device_info, options);
        tasks->insert(tasks->end(), t1.begin(), t1.end());
        Winograd36To4x4Attributes wino_down_attr;
        wino_down_attr.output_shape = dst_shape;
        wino_down_attr.biases = attr.bias;
        (*last_node_id) += 1;
        auto t2 = SelectWinograd36To4x4(*last_node_id, value_id + 1, outputs[0],
                                        wino_down_attr, device_info, options);
        tasks->insert(tasks->end(), t2.begin(), t2.end());
        (*last_value_id) += 2;
      } else {
        *tasks = ConvolutionGeneric(node_id, inputs[0], outputs[0], dst_shape,
                                    attr, device_info, options);
      }
      break;
    }
    case OperationType::CONVOLUTION_TRANSPOSED:
      *tasks = SelectConvolutionTransposed(
          node_id, inputs[0], outputs[0],
          absl::any_cast<ConvolutionTransposedAttributes>(
              node->operation.attributes),
          device_info, options);
      break;
    case OperationType::DEPTHWISE_CONVOLUTION:
      *tasks =
          SelectDepthWiseConv(node_id, inputs[0], outputs[0],
                              absl::any_cast<DepthwiseConvolution2DAttributes>(
                                  node->operation.attributes),
                              options);
      break;
    case OperationType::FULLY_CONNECTED:
      *tasks = FullyConnected(
          node_id, inputs[0], outputs[0],
          absl::any_cast<FullyConnectedAttributes>(node->operation.attributes),
          device_info, options);
      break;
    case OperationType::MAX_UNPOOLING_2D:
      *tasks = MaxUnpooling(
          node_id, inputs[0], inputs[1], outputs[0],
          absl::any_cast<MaxUnpooling2DAttributes>(node->operation.attributes));
      break;
    case OperationType::MEAN:
      *tasks = Mean(node_id, inputs[0], outputs[0],
                    absl::any_cast<MeanAttributes>(node->operation.attributes));
      break;
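    // MUL no longer has a dedicated kernel: a single runtime input with a
    // constant argument (scalar or HWC tensor) and the two-input case are
    // both served by the elementwise kernels.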
    case OperationType::MUL:
      if (inputs.size() == 1) {
        if (node->operation.attributes.has_value()) {
          auto attr =
              absl::any_cast<MultiplyAttributes>(node->operation.attributes);
          *tasks = ElementwiseWithOneInputAndConstantArguent(
              node_id, inputs[0], outputs[0], options, op_type, attr.param);
        } else {
          return absl::UnimplementedError(
              "Missing attributes for single input op: " +
              node->operation.type);
        }
      } else if (inputs.size() == 2) {
        const auto srcs = graph.FindInputs(node_id);
        *tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0],
                                          srcs[1]->tensor.shape, op_type);
      }
      break;
    case OperationType::PAD: {
      auto attr = absl::any_cast<PadAttributes>(node->operation.attributes);
      if (attr.appended.b != 0 || attr.prepended.b != 0) {
        return absl::UnimplementedError("Padding for BATCH is not supported.");
      }
      *tasks = Padding(node_id, inputs[0], outputs[0], attr);
      break;
    }
    case OperationType::POOLING_2D:
      *tasks = Pooling(
          node_id, inputs[0], outputs,
          absl::any_cast<Pooling2DAttributes>(node->operation.attributes));
      break;
    case OperationType::PRELU:
      *tasks = SelectPReLU(
          graph, node_id, inputs[0], outputs[0],
          absl::any_cast<PReLUAttributes>(node->operation.attributes), options);
      break;
    case OperationType::RELU:
      *tasks = ReLU(node_id, inputs[0], outputs[0],
                    absl::any_cast<ReLUAttributes>(node->operation.attributes));
      break;
    case OperationType::QUANTIZE_AND_DEQUANTIZE:
      *tasks = SelectQuantizeAndDequantize(
          node_id, inputs[0], outputs[0],
          absl::any_cast<QuantizeAndDequantizeAttributes>(
              node->operation.attributes));
      break;
    case OperationType::RESHAPE:
      *tasks = SelectReshape(
          graph, node_id, inputs[0], outputs[0],
          absl::any_cast<ReshapeAttributes>(node->operation.attributes));
      break;
    case OperationType::RESIZE:
      *tasks = Resize(
          node_id, inputs[0], outputs[0],
          absl::any_cast<Resize2DAttributes>(node->operation.attributes));
      break;
    case OperationType::SLICE:
      *tasks =
          Slice(node_id, inputs[0], outputs[0],
                absl::any_cast<SliceAttributes>(node->operation.attributes));
      break;
    case OperationType::SOFTMAX: {
      auto attr = absl::any_cast<SoftmaxAttributes>(node->operation.attributes);
      if (attr.axis != Axis::CHANNELS) {
        return absl::UnimplementedError(
            "Softmax supports only CHANNELS dimension");
      }
      *tasks =
          SelectSoftmax(graph, node_id, inputs[0], outputs[0], device_info);
      break;
    }
    case OperationType::SPACE_TO_DEPTH:
      *tasks = SelectSpaceToDepth(
          graph, node_id, inputs[0], outputs[0],
          absl::any_cast<SpaceToDepthAttributes>(node->operation.attributes));
      break;
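    // Unary elementwise ops share a single kernel parameterized by op_type.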
    case OperationType::ABS:
    case OperationType::COS:
    case OperationType::EXP:
    case OperationType::HARD_SWISH:
    case OperationType::LOG:
    case OperationType::RSQRT:
    case OperationType::SIGMOID:
    case OperationType::SIN:
    case OperationType::SQRT:
    case OperationType::SQUARE:
    case OperationType::TANH:
      *tasks = ElementwiseWithOneInput(node_id, inputs[0], outputs[0], op_type);
      break;
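    // These binary ops follow the same pattern as ADD/MUL above: a constant
    // second operand taken from the attributes, or a second runtime input.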
    case OperationType::DIV:
    case OperationType::MAXIMUM:
    case OperationType::MINIMUM:
    case OperationType::POW:
    case OperationType::SQUARED_DIFF:
    case OperationType::SUB: {
      if (inputs.size() == 1) {
        if (node->operation.attributes.has_value()) {
          auto attr =
              absl::any_cast<ElementwiseAttributes>(node->operation.attributes);
          *tasks = ElementwiseWithOneInputAndConstantArguent(
              node_id, inputs[0], outputs[0], options, op_type, attr.param);
        } else {
          return absl::UnimplementedError(
              "Missing attributes for single input op: " +
              node->operation.type);
        }
      } else if (inputs.size() == 2) {
        const auto srcs = graph.FindInputs(node_id);
        *tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0],
                                          srcs[1]->tensor.shape, op_type);
      }
    } break;
    case OperationType::BATCH_NORMALIZATION:
    case OperationType::BATCH_TO_SPACE:
    case OperationType::CONST:
    case OperationType::LSTM:
    case OperationType::SPACE_TO_BATCH:
    case OperationType::TRANSPOSE:
    case OperationType::UNKNOWN:
      return absl::UnimplementedError("Unsupported op: " +
                                      node->operation.type);
  }
  return absl::OkStatus();
}

}  // namespace

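// Compiles the whole graph into a flat list of Metal compute tasks. For each
// node the custom registry gets first shot; the primary registry is the
// fallback, and only if both fail is the node reported as unsupported, with
// both error messages. A minimal usage sketch (RETURN_IF_ERROR is the
// delegate's status-propagation macro; `graph` and `device_info` are assumed
// to be populated by the caller):
//
//   CompiledModel model;
//   RuntimeOptions options;
//   RETURN_IF_ERROR(Compile(graph, device_info, options, &model));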
absl::Status Compile(const GraphFloat32& graph, const DeviceInfo& device_info,
                     const RuntimeOptions& options,
                     CompiledModel* compiled_model) {
  int last_node_id = 0;
  for (const auto& node : graph.nodes()) {
    last_node_id = std::max(last_node_id, static_cast<int>(node->id));
  }
  int last_value_id = 0;
  for (const auto& value : graph.values()) {
    last_value_id = std::max(last_value_id, static_cast<int>(value->id));
  }
  for (const auto& node : graph.nodes()) {
    std::vector<ValueId> inputs;
    for (auto& input : graph.FindInputs(node->id)) {
      inputs.push_back(static_cast<ValueId>(input->id));
    }
    std::vector<ValueId> outputs;
    for (auto& output : graph.FindOutputs(node->id)) {
      outputs.push_back(static_cast<ValueId>(output->id));
    }
    std::vector<ComputeTaskDescriptorPtr> tasks;
    auto custom_status =
        RegisterCustomOps(graph, node, inputs, outputs, options, &tasks);
    if (!custom_status.ok()) {
      auto primary_status =
          RegisterPrimaryOps(graph, node, inputs, outputs, device_info,
                             options, &last_node_id, &last_value_id, &tasks);
      if (!primary_status.ok()) {
        return absl::UnimplementedError(
            absl::Substitute("Unsupported op type: $0; custom registry error: "
                             "$1; primary registry error: $2;",
                             node->operation.type, custom_status.message(),
                             primary_status.message()));
      }
    }
    for (const auto& task : tasks) {
      task->description = node->operation.type + "_" + std::to_string(node->id);
    }
    compiled_model->insert(compiled_model->end(), tasks.begin(), tasks.end());
  }
  return absl::OkStatus();
}

}  // namespace metal
}  // namespace gpu
}  // namespace tflite