Added SplitV support to model_builder.

SplitV implementation for Metal and OpenCL.

PiperOrigin-RevId: 357872655
Change-Id: Ia06c5290891d37e7a455bee031081398f2e068aa
This commit is contained in:
Raman Sarokin 2021-02-16 21:34:24 -08:00 committed by TensorFlower Gardener
parent f6957507f3
commit 08359250e2
10 changed files with 285 additions and 0 deletions

View File

@ -1819,6 +1819,40 @@ class SpaceToDepthOperationParser : public TFLiteOperationParser {
}
};
class SplitVOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
const TfLiteSplitVParams* split_params;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &split_params));
if (split_params->num_splits == 1) {
return absl::InvalidArgumentError(
"SplitV with num_splits = 1 is a no-op.");
}
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
const TfLiteTensor* input = reader->GetInputTensor(0);
const TfLiteTensor* axis_tensor = reader->GetInputTensor(2);
SplitAttributes attr;
RETURN_IF_ERROR(
ExtractAxisFromIndex(*input, axis_tensor->data.i32[0], &attr.axis));
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::SPLIT);
node->operation.attributes = attr;
RETURN_IF_ERROR(reader->AddInput(node, 0));
for (int i = 0; i < tflite_node->outputs->size; ++i) {
RETURN_IF_ERROR(reader->AddOutput(node, i));
}
return absl::OkStatus();
}
};
class StridedSliceOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
@ -2382,6 +2416,8 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
return std::make_unique<SoftmaxOperationParser>();
case kTfLiteBuiltinSpaceToDepth:
return std::make_unique<SpaceToDepthOperationParser>();
case kTfLiteBuiltinSplitV:
return std::make_unique<SplitVOperationParser>();
case kTfLiteBuiltinSqrt:
return std::make_unique<ElementwiseOperationParser>(OperationType::SQRT);
case kTfLiteBuiltinSquare:

View File

@ -176,6 +176,8 @@ std::string ToString(enum OperationType op) {
return "space_to_batch";
case OperationType::SPACE_TO_DEPTH:
return "space_to_depth";
case OperationType::SPLIT:
return "split";
case OperationType::SQRT:
return "sqrt";
case OperationType::SQUARE:
@ -246,6 +248,7 @@ OperationType OperationTypeFromString(const std::string& name) {
{"slice", OperationType::SLICE},
{"softmax", OperationType::SOFTMAX},
{"space_to_depth", OperationType::SPACE_TO_DEPTH},
{"split", OperationType::SPLIT},
{"sqrt", OperationType::SQRT},
{"square", OperationType::SQUARE},
{"squared_diff", OperationType::SQUARED_DIFF},

View File

@ -85,6 +85,7 @@ enum class OperationType {
SOFTMAX,
SPACE_TO_BATCH,
SPACE_TO_DEPTH,
SPLIT,
SQRT,
SQUARE,
SQUARED_DIFF,
@ -547,6 +548,11 @@ struct SpaceToDepthAttributes {
int block_size;
};
struct SplitAttributes {
// Defines axis by which to split.
Axis axis = Axis::UNKNOWN;
};
// These help perform a combination of Quantize & Dequantize to adjust float
// values like quantized inference would.
struct QuantizeAndDequantizeAttributes {

View File

@ -136,6 +136,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/common/tasks:softmax",
"//tensorflow/lite/delegates/gpu/common/tasks:softmax1x1",
"//tensorflow/lite/delegates/gpu/common/tasks:space_to_depth",
"//tensorflow/lite/delegates/gpu/common/tasks:split",
"//tensorflow/lite/delegates/gpu/common/tasks:strided_slice",
"//tensorflow/lite/delegates/gpu/common/tasks:transpose",
"//tensorflow/lite/delegates/gpu/common/tasks:winograd",

View File

@ -491,6 +491,11 @@ absl::Status GPUOperationFromNode(const GpuInfo& gpu_info,
SelectSpaceToDepth(attr, op_def, gpu_op);
return absl::OkStatus();
}
case OperationType::SPLIT: {
auto attr = absl::any_cast<SplitAttributes>(node.operation.attributes);
RETURN_IF_ERROR(SelectSplit(attr, op_def, gpu_op));
return absl::OkStatus();
}
case OperationType::TRANSPOSE: {
auto attr =
absl::any_cast<TransposeAttributes>(node.operation.attributes);

View File

@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/tasks/softmax.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/softmax1x1.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/space_to_depth.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/split.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/strided_slice.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/transpose.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/winograd.h"
@ -134,6 +135,17 @@ void SelectSpaceToDepth(const SpaceToDepthAttributes& attr,
*ptr = absl::make_unique<GPUOperation>(std::move(operation));
}
absl::Status SelectSplit(const SplitAttributes& attr,
const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr) {
if (attr.axis != Axis::CHANNELS) {
return absl::UnimplementedError("No split for this axis.");
}
Split operation = CreateSplit(op_def, attr);
*ptr = absl::make_unique<Split>(std::move(operation));
return absl::OkStatus();
}
void SelectPadding(const PadAttributes& attr, const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr) {
GPUOperation operation = CreatePadding(op_def, attr);

View File

@ -82,6 +82,10 @@ void SelectSpaceToDepth(const SpaceToDepthAttributes& attr,
const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr);
absl::Status SelectSplit(const SplitAttributes& attr,
const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr);
void SelectTranspose(const TransposeAttributes& attr,
const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr);

View File

@ -861,6 +861,19 @@ cc_library(
],
)
cc_library(
name = "split",
srcs = ["split.cc"],
hdrs = ["split.h"],
deps = [
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
"//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
],
)
cc_library(
name = "strided_slice",
srcs = ["strided_slice.cc"],

View File

@ -0,0 +1,156 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/common/tasks/split.h"
#include <string>
#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
namespace tflite {
namespace gpu {
Split::Split(const OperationDef& definition, const SplitAttributes& attr)
: GPUOperation(definition), attr_(attr) {
work_group_size_ = int3(8, 4, 1);
code_ = attr.axis == Axis::CHANNELS ? GetSplitChannelsCode() : GetSplitCode();
}
std::string Split::GetSplitCode() {
AddSrcTensor("src_tensor", definition_.src_tensors[0]);
for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
AddDstTensor("dst_tensor_" + std::to_string(i), definition_.dst_tensors[i]);
}
const std::string task_width =
attr_.axis == Axis::WIDTH ? "1" : "args.src_tensor.Width()";
const std::string task_height =
attr_.axis == Axis::HEIGHT ? "1" : "args.src_tensor.Height()";
const std::string task_depth =
attr_.axis == Axis::DEPTH ? "1" : "args.src_tensor.Depth()";
const std::string task_batch =
attr_.axis == Axis::BATCH ? "1" : "args.src_tensor.Batch()";
const std::string task_slices =
attr_.axis == Axis::CHANNELS ? "1" : "args.src_tensor.Slices()";
std::string c;
c += "MAIN_FUNCTION($0) {\n";
c += " int task_width = "
";\n";
if (definition_.src_tensors[0].HasAxis(Axis::BATCH)) {
c += " int linear_id = GLOBAL_ID_0;\n";
c += " int X = linear_id / " + task_batch + ";\n";
c += " int B = linear_id % " + task_batch + ";\n";
} else {
c += " int X = GLOBAL_ID_0;\n";
}
if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) {
c += " int linear_id = GLOBAL_ID_1;\n";
c += " int Y = linear_id % " + task_height + ";\n";
c += " int B = linear_id / " + task_height + ";\n";
} else {
c += " int Y = GLOBAL_ID_1;\n";
}
c += " int S = GLOBAL_ID_2;\n";
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
"S >= args.dst_tensor.Slices()) { \n";
c += " return; \n";
c += " } \n";
c += " int src_counter = 0;\n";
for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
const std::string dst_name = "args.dst_tensor_" + std::to_string(i);
c += " for (int i = 0; i < " + dst_name +
".Slices(); ++i, src_counter++) {\n";
c += " FLT4 result = args.src_tensor.Read(s_x, s_y, src_counter);\n";
c += " " + dst_name + ".Write(result, X, Y, i);\n";
c += " }\n";
}
c += "}\n";
return c;
}
std::string Split::GetSplitChannelsCode() {
AddSrcTensor("src_tensor", definition_.src_tensors[0]);
for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
AddDstTensor("dst_tensor_" + std::to_string(i), definition_.dst_tensors[i]);
}
const std::string batch_coord =
definition_.src_tensors[0].HasAxis(Axis::BATCH) ? ", B" : "";
std::string coords = "X, Y";
std::string c;
c += "MAIN_FUNCTION($0) {\n";
if (definition_.src_tensors[0].HasAxis(Axis::BATCH)) {
c += " int linear_id = GLOBAL_ID_0;\n";
c += " int X = linear_id / args.src_tensor.Batch();\n";
c += " int B = linear_id % args.src_tensor.Batch();\n";
c += " if (X >= args.src_tensor.Width()) return;\n";
} else {
c += " int X = GLOBAL_ID_0;\n";
c += " if (X >= args.src_tensor.Width()) return;\n";
}
if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) {
c += " int linear_id = GLOBAL_ID_1;\n";
c += " int Y = linear_id % args.src_tensor.Height();\n";
c += " int Z = linear_id / args.src_tensor.Height();\n";
c += " if (Z >= args.src_tensor.Depth()) return;\n";
coords += ", Z";
} else {
c += " int Y = GLOBAL_ID_1;\n";
c += " if (Y >= args.src_tensor.Height()) return;\n";
}
c += " int src_channel = 0;\n";
const std::string postfixes[] = {"x", "y", "z", "w"};
for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
const std::string dst_name = "args.dst_tensor_" + std::to_string(i);
c += " for (int i = 0; i < " + dst_name + ".Slices(); ++i) {\n";
c += " FLT4 result = INIT_FLT4(0.0f);\n";
for (int j = 0; j < 4; ++j) {
c += " if (i * 4 + " + std::to_string(j) + " < " + dst_name +
".Channels()) {\n";
c += " int src_slice = src_channel >> 2;\n";
c += " int src_sub_ch = src_channel & 3;\n";
c += " FLT4 t = args.src_tensor.Read(" + coords + ", src_slice" +
batch_coord + ");\n";
c += " FLT t_ar[4] = {t.x, t.y, t.z, t.w};\n";
c += " result." + postfixes[j] + " = t_ar[src_sub_ch];\n";
c += " src_channel++;\n";
c += " }\n";
}
c += " " + dst_name + ".Write(result, " + coords + ", i" + batch_coord +
");\n";
c += " }\n";
}
c += "}\n";
return c;
}
int3 Split::GetGridSize() const {
const int width = attr_.axis == Axis::WIDTH ? 1 : src_[0]->Width();
const int height = attr_.axis == Axis::HEIGHT ? 1 : src_[0]->Height();
const int depth = attr_.axis == Axis::DEPTH ? 1 : src_[0]->Depth();
const int batch = attr_.axis == Axis::BATCH ? 1 : src_[0]->Batch();
const int slices = attr_.axis == Axis::CHANNELS ? 1 : src_[0]->Slices();
const int grid_x = width * batch;
const int grid_y = height * depth;
const int grid_z = slices;
return int3(grid_x, grid_y, grid_z);
}
Split CreateSplit(const OperationDef& definition, const SplitAttributes& attr) {
return Split(definition, attr);
}
} // namespace gpu
} // namespace tflite

View File

@ -0,0 +1,49 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_STRIDED_SPLIT_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_STRIDED_SPLIT_H_
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
namespace tflite {
namespace gpu {
class Split : public GPUOperation {
public:
Split(const OperationDef& definition, const SplitAttributes& attr);
int3 GetGridSize() const override;
// Move only
Split(Split&& operation) = default;
Split& operator=(Split&& operation) = default;
Split(const Split&) = delete;
Split& operator=(const Split&) = delete;
private:
std::string GetSplitCode();
std::string GetSplitChannelsCode();
SplitAttributes attr_;
};
Split CreateSplit(const OperationDef& definition, const SplitAttributes& attr);
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_STRIDED_SPLIT_H_