Added SplitV support to model_builder.
SplitV implementation for Metal and OpenCL. PiperOrigin-RevId: 357872655 Change-Id: Ia06c5290891d37e7a455bee031081398f2e068aa
This commit is contained in:
parent
f6957507f3
commit
08359250e2
@ -1819,6 +1819,40 @@ class SpaceToDepthOperationParser : public TFLiteOperationParser {
|
||||
}
|
||||
};
|
||||
|
||||
class SplitVOperationParser : public TFLiteOperationParser {
|
||||
public:
|
||||
absl::Status IsSupported(const TfLiteContext* context,
|
||||
const TfLiteNode* tflite_node,
|
||||
const TfLiteRegistration* registration) final {
|
||||
const TfLiteSplitVParams* split_params;
|
||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &split_params));
|
||||
if (split_params->num_splits == 1) {
|
||||
return absl::InvalidArgumentError(
|
||||
"SplitV with num_splits = 1 is a no-op.");
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Parse(const TfLiteNode* tflite_node,
|
||||
const TfLiteRegistration* registration,
|
||||
GraphFloat32* graph, ObjectReader* reader) final {
|
||||
const TfLiteTensor* input = reader->GetInputTensor(0);
|
||||
const TfLiteTensor* axis_tensor = reader->GetInputTensor(2);
|
||||
SplitAttributes attr;
|
||||
RETURN_IF_ERROR(
|
||||
ExtractAxisFromIndex(*input, axis_tensor->data.i32[0], &attr.axis));
|
||||
|
||||
Node* node = graph->NewNode();
|
||||
node->operation.type = ToString(OperationType::SPLIT);
|
||||
node->operation.attributes = attr;
|
||||
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
||||
for (int i = 0; i < tflite_node->outputs->size; ++i) {
|
||||
RETURN_IF_ERROR(reader->AddOutput(node, i));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
class StridedSliceOperationParser : public TFLiteOperationParser {
|
||||
public:
|
||||
absl::Status IsSupported(const TfLiteContext* context,
|
||||
@ -2382,6 +2416,8 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
|
||||
return std::make_unique<SoftmaxOperationParser>();
|
||||
case kTfLiteBuiltinSpaceToDepth:
|
||||
return std::make_unique<SpaceToDepthOperationParser>();
|
||||
case kTfLiteBuiltinSplitV:
|
||||
return std::make_unique<SplitVOperationParser>();
|
||||
case kTfLiteBuiltinSqrt:
|
||||
return std::make_unique<ElementwiseOperationParser>(OperationType::SQRT);
|
||||
case kTfLiteBuiltinSquare:
|
||||
|
@ -176,6 +176,8 @@ std::string ToString(enum OperationType op) {
|
||||
return "space_to_batch";
|
||||
case OperationType::SPACE_TO_DEPTH:
|
||||
return "space_to_depth";
|
||||
case OperationType::SPLIT:
|
||||
return "split";
|
||||
case OperationType::SQRT:
|
||||
return "sqrt";
|
||||
case OperationType::SQUARE:
|
||||
@ -246,6 +248,7 @@ OperationType OperationTypeFromString(const std::string& name) {
|
||||
{"slice", OperationType::SLICE},
|
||||
{"softmax", OperationType::SOFTMAX},
|
||||
{"space_to_depth", OperationType::SPACE_TO_DEPTH},
|
||||
{"split", OperationType::SPLIT},
|
||||
{"sqrt", OperationType::SQRT},
|
||||
{"square", OperationType::SQUARE},
|
||||
{"squared_diff", OperationType::SQUARED_DIFF},
|
||||
|
@ -85,6 +85,7 @@ enum class OperationType {
|
||||
SOFTMAX,
|
||||
SPACE_TO_BATCH,
|
||||
SPACE_TO_DEPTH,
|
||||
SPLIT,
|
||||
SQRT,
|
||||
SQUARE,
|
||||
SQUARED_DIFF,
|
||||
@ -547,6 +548,11 @@ struct SpaceToDepthAttributes {
|
||||
int block_size;
|
||||
};
|
||||
|
||||
struct SplitAttributes {
|
||||
// Defines axis by which to split.
|
||||
Axis axis = Axis::UNKNOWN;
|
||||
};
|
||||
|
||||
// These help perform a combination of Quantize & Dequantize to adjust float
|
||||
// values like quantized inference would.
|
||||
struct QuantizeAndDequantizeAttributes {
|
||||
|
@ -136,6 +136,7 @@ cc_library(
|
||||
"//tensorflow/lite/delegates/gpu/common/tasks:softmax",
|
||||
"//tensorflow/lite/delegates/gpu/common/tasks:softmax1x1",
|
||||
"//tensorflow/lite/delegates/gpu/common/tasks:space_to_depth",
|
||||
"//tensorflow/lite/delegates/gpu/common/tasks:split",
|
||||
"//tensorflow/lite/delegates/gpu/common/tasks:strided_slice",
|
||||
"//tensorflow/lite/delegates/gpu/common/tasks:transpose",
|
||||
"//tensorflow/lite/delegates/gpu/common/tasks:winograd",
|
||||
|
@ -491,6 +491,11 @@ absl::Status GPUOperationFromNode(const GpuInfo& gpu_info,
|
||||
SelectSpaceToDepth(attr, op_def, gpu_op);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
case OperationType::SPLIT: {
|
||||
auto attr = absl::any_cast<SplitAttributes>(node.operation.attributes);
|
||||
RETURN_IF_ERROR(SelectSplit(attr, op_def, gpu_op));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
case OperationType::TRANSPOSE: {
|
||||
auto attr =
|
||||
absl::any_cast<TransposeAttributes>(node.operation.attributes);
|
||||
|
@ -38,6 +38,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/common/tasks/softmax.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tasks/softmax1x1.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tasks/space_to_depth.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tasks/split.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tasks/strided_slice.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tasks/transpose.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tasks/winograd.h"
|
||||
@ -134,6 +135,17 @@ void SelectSpaceToDepth(const SpaceToDepthAttributes& attr,
|
||||
*ptr = absl::make_unique<GPUOperation>(std::move(operation));
|
||||
}
|
||||
|
||||
// Creates the Split GPU operation described by `attr` and stores it in `*ptr`.
// Only Axis::CHANNELS is accepted; for any other axis the function returns
// absl::UnimplementedError and does not assign `*ptr`.
absl::Status SelectSplit(const SplitAttributes& attr,
                         const OperationDef& op_def,
                         std::unique_ptr<GPUOperation>* ptr) {
  const bool splits_channels = attr.axis == Axis::CHANNELS;
  if (!splits_channels) {
    return absl::UnimplementedError("No split for this axis.");
  }
  *ptr = absl::make_unique<Split>(CreateSplit(op_def, attr));
  return absl::OkStatus();
}
|
||||
|
||||
void SelectPadding(const PadAttributes& attr, const OperationDef& op_def,
|
||||
std::unique_ptr<GPUOperation>* ptr) {
|
||||
GPUOperation operation = CreatePadding(op_def, attr);
|
||||
|
@ -82,6 +82,10 @@ void SelectSpaceToDepth(const SpaceToDepthAttributes& attr,
|
||||
const OperationDef& op_def,
|
||||
std::unique_ptr<GPUOperation>* ptr);
|
||||
|
||||
// Creates the Split GPU operation for `attr` and stores it in `*ptr`.
// Returns absl::UnimplementedError when `attr.axis` is not Axis::CHANNELS,
// otherwise absl::OkStatus().
absl::Status SelectSplit(const SplitAttributes& attr,
|
||||
const OperationDef& op_def,
|
||||
std::unique_ptr<GPUOperation>* ptr);
|
||||
|
||||
void SelectTranspose(const TransposeAttributes& attr,
|
||||
const OperationDef& op_def,
|
||||
std::unique_ptr<GPUOperation>* ptr);
|
||||
|
@ -861,6 +861,19 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
# Library target for the Split GPU task (split.cc / split.h).
cc_library(
|
||||
name = "split",
|
||||
srcs = ["split.cc"],
|
||||
hdrs = ["split.h"],
|
||||
deps = [
|
||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common:types",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "strided_slice",
|
||||
srcs = ["strided_slice.cc"],
|
||||
|
156
tensorflow/lite/delegates/gpu/common/tasks/split.cc
Normal file
156
tensorflow/lite/delegates/gpu/common/tasks/split.cc
Normal file
@ -0,0 +1,156 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/tasks/split.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
|
||||
// Builds the Split task: keeps a copy of `attr`, sets the work group size to
// 8x4x1 and generates the kernel source. Splitting along Axis::CHANNELS uses
// the per-channel kernel; every other axis uses the generic one.
Split::Split(const OperationDef& definition, const SplitAttributes& attr)
    : GPUOperation(definition), attr_(attr) {
  work_group_size_ = int3(8, 4, 1);
  if (attr.axis == Axis::CHANNELS) {
    code_ = GetSplitChannelsCode();
  } else {
    code_ = GetSplitCode();
  }
}
|
||||
|
||||
std::string Split::GetSplitCode() {
|
||||
AddSrcTensor("src_tensor", definition_.src_tensors[0]);
|
||||
for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
|
||||
AddDstTensor("dst_tensor_" + std::to_string(i), definition_.dst_tensors[i]);
|
||||
}
|
||||
const std::string task_width =
|
||||
attr_.axis == Axis::WIDTH ? "1" : "args.src_tensor.Width()";
|
||||
const std::string task_height =
|
||||
attr_.axis == Axis::HEIGHT ? "1" : "args.src_tensor.Height()";
|
||||
const std::string task_depth =
|
||||
attr_.axis == Axis::DEPTH ? "1" : "args.src_tensor.Depth()";
|
||||
const std::string task_batch =
|
||||
attr_.axis == Axis::BATCH ? "1" : "args.src_tensor.Batch()";
|
||||
const std::string task_slices =
|
||||
attr_.axis == Axis::CHANNELS ? "1" : "args.src_tensor.Slices()";
|
||||
|
||||
std::string c;
|
||||
c += "MAIN_FUNCTION($0) {\n";
|
||||
c += " int task_width = "
|
||||
";\n";
|
||||
if (definition_.src_tensors[0].HasAxis(Axis::BATCH)) {
|
||||
c += " int linear_id = GLOBAL_ID_0;\n";
|
||||
c += " int X = linear_id / " + task_batch + ";\n";
|
||||
c += " int B = linear_id % " + task_batch + ";\n";
|
||||
} else {
|
||||
c += " int X = GLOBAL_ID_0;\n";
|
||||
}
|
||||
if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) {
|
||||
c += " int linear_id = GLOBAL_ID_1;\n";
|
||||
c += " int Y = linear_id % " + task_height + ";\n";
|
||||
c += " int B = linear_id / " + task_height + ";\n";
|
||||
} else {
|
||||
c += " int Y = GLOBAL_ID_1;\n";
|
||||
}
|
||||
c += " int S = GLOBAL_ID_2;\n";
|
||||
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
|
||||
"S >= args.dst_tensor.Slices()) { \n";
|
||||
c += " return; \n";
|
||||
c += " } \n";
|
||||
c += " int src_counter = 0;\n";
|
||||
for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
|
||||
const std::string dst_name = "args.dst_tensor_" + std::to_string(i);
|
||||
c += " for (int i = 0; i < " + dst_name +
|
||||
".Slices(); ++i, src_counter++) {\n";
|
||||
c += " FLT4 result = args.src_tensor.Read(s_x, s_y, src_counter);\n";
|
||||
c += " " + dst_name + ".Write(result, X, Y, i);\n";
|
||||
c += " }\n";
|
||||
}
|
||||
c += "}\n";
|
||||
return c;
|
||||
}
|
||||
|
||||
std::string Split::GetSplitChannelsCode() {
|
||||
AddSrcTensor("src_tensor", definition_.src_tensors[0]);
|
||||
for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
|
||||
AddDstTensor("dst_tensor_" + std::to_string(i), definition_.dst_tensors[i]);
|
||||
}
|
||||
|
||||
const std::string batch_coord =
|
||||
definition_.src_tensors[0].HasAxis(Axis::BATCH) ? ", B" : "";
|
||||
std::string coords = "X, Y";
|
||||
std::string c;
|
||||
c += "MAIN_FUNCTION($0) {\n";
|
||||
if (definition_.src_tensors[0].HasAxis(Axis::BATCH)) {
|
||||
c += " int linear_id = GLOBAL_ID_0;\n";
|
||||
c += " int X = linear_id / args.src_tensor.Batch();\n";
|
||||
c += " int B = linear_id % args.src_tensor.Batch();\n";
|
||||
c += " if (X >= args.src_tensor.Width()) return;\n";
|
||||
} else {
|
||||
c += " int X = GLOBAL_ID_0;\n";
|
||||
c += " if (X >= args.src_tensor.Width()) return;\n";
|
||||
}
|
||||
if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) {
|
||||
c += " int linear_id = GLOBAL_ID_1;\n";
|
||||
c += " int Y = linear_id % args.src_tensor.Height();\n";
|
||||
c += " int Z = linear_id / args.src_tensor.Height();\n";
|
||||
c += " if (Z >= args.src_tensor.Depth()) return;\n";
|
||||
coords += ", Z";
|
||||
} else {
|
||||
c += " int Y = GLOBAL_ID_1;\n";
|
||||
c += " if (Y >= args.src_tensor.Height()) return;\n";
|
||||
}
|
||||
c += " int src_channel = 0;\n";
|
||||
const std::string postfixes[] = {"x", "y", "z", "w"};
|
||||
for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
|
||||
const std::string dst_name = "args.dst_tensor_" + std::to_string(i);
|
||||
c += " for (int i = 0; i < " + dst_name + ".Slices(); ++i) {\n";
|
||||
c += " FLT4 result = INIT_FLT4(0.0f);\n";
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
c += " if (i * 4 + " + std::to_string(j) + " < " + dst_name +
|
||||
".Channels()) {\n";
|
||||
c += " int src_slice = src_channel >> 2;\n";
|
||||
c += " int src_sub_ch = src_channel & 3;\n";
|
||||
c += " FLT4 t = args.src_tensor.Read(" + coords + ", src_slice" +
|
||||
batch_coord + ");\n";
|
||||
c += " FLT t_ar[4] = {t.x, t.y, t.z, t.w};\n";
|
||||
c += " result." + postfixes[j] + " = t_ar[src_sub_ch];\n";
|
||||
c += " src_channel++;\n";
|
||||
c += " }\n";
|
||||
}
|
||||
c += " " + dst_name + ".Write(result, " + coords + ", i" + batch_coord +
|
||||
");\n";
|
||||
c += " }\n";
|
||||
}
|
||||
c += "}\n";
|
||||
return c;
|
||||
}
|
||||
|
||||
// Returns the dispatch grid (width * batch, height * depth, slices) of the
// first source tensor, where the extent along the split axis is replaced by 1.
int3 Split::GetGridSize() const {
  const Axis split_axis = attr_.axis;

  int grid_x = 1;
  if (split_axis != Axis::WIDTH) grid_x *= src_[0]->Width();
  if (split_axis != Axis::BATCH) grid_x *= src_[0]->Batch();

  int grid_y = 1;
  if (split_axis != Axis::HEIGHT) grid_y *= src_[0]->Height();
  if (split_axis != Axis::DEPTH) grid_y *= src_[0]->Depth();

  int grid_z = 1;
  if (split_axis != Axis::CHANNELS) grid_z = src_[0]->Slices();

  return int3(grid_x, grid_y, grid_z);
}
|
||||
|
||||
// Factory for the Split task: constructs it from the operation definition and
// the split attributes and returns it by value.
Split CreateSplit(const OperationDef& definition, const SplitAttributes& attr) {
  Split operation(definition, attr);
  return operation;
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
49
tensorflow/lite/delegates/gpu/common/tasks/split.h
Normal file
49
tensorflow/lite/delegates/gpu/common/tasks/split.h
Normal file
@ -0,0 +1,49 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_SPLIT_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_SPLIT_H_

#include <string>

#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"

namespace tflite {
namespace gpu {

// GPU task that splits one source tensor into several destination tensors
// along the axis given by SplitAttributes.
class Split : public GPUOperation {
 public:
  // Stores `attr` and generates the kernel source for the requested axis.
  Split(const OperationDef& definition, const SplitAttributes& attr);

  // Returns the dispatch grid for the task.
  int3 GetGridSize() const override;

  // Move only
  Split(Split&& operation) = default;
  Split& operator=(Split&& operation) = default;
  Split(const Split&) = delete;
  Split& operator=(const Split&) = delete;

 private:
  // Kernel source generator for axes other than CHANNELS.
  std::string GetSplitCode();
  // Kernel source generator for Axis::CHANNELS.
  std::string GetSplitChannelsCode();

  SplitAttributes attr_;
};

// Creates a Split task for the given definition and attributes.
Split CreateSplit(const OperationDef& definition, const SplitAttributes& attr);

}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_SPLIT_H_
|
Loading…
x
Reference in New Issue
Block a user