TFLite GPU: Implement SPACE_TO_DEPTH.

PiperOrigin-RevId: 296321368
Change-Id: I3b5844fde83ef48002a4c326eeb745587068c208
This commit is contained in:
Juhyun Lee 2020-02-20 16:14:26 -08:00 committed by TensorFlower Gardener
parent 0685f70521
commit 0213d7a4d6
19 changed files with 1101 additions and 52 deletions

View File

@ -1193,6 +1193,39 @@ cc_test(
],
)
cc_library(
name = "space_to_depth",
srcs = ["space_to_depth.cc"],
hdrs = ["space_to_depth.h"],
deps = [
":gpu_operation",
":util",
":work_group_picking",
"//tensorflow/lite/delegates/gpu/cl:cl_kernel",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common:types",
],
)
cc_test(
name = "space_to_depth_test",
srcs = ["space_to_depth_test.cc"],
linkstatic = True,
tags = tf_gpu_tests_tags() + [
"linux",
"local",
],
deps = [
":cl_test",
":space_to_depth",
"//tensorflow/lite/delegates/gpu/cl:tensor",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:status",
"@com_google_googletest//:gtest_main",
],
)
cc_library(
name = "strided_slice",
srcs = ["strided_slice.cc"],

View File

@ -0,0 +1,141 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.h"
#include <string>
#include <utility>
#include <vector>
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
namespace tflite {
namespace gpu {
namespace cl {
namespace {
std::string GetSpaceToDepthCode(
const OperationDef& op_def,
const std::vector<ElementwiseOperation*>& linked_operations) {
TensorCodeGenerator src_tensor(
"src_data", WHSPoint{"src_size.x", "src_size.y", "src_size.z"},
op_def.src_tensors[0]);
TensorCodeGenerator dst_tensor(
"dst_data", WHSPoint{"dst_size.x", "dst_size.y", "dst_size.z"},
op_def.dst_tensors[0]);
std::string c = GetCommonDefines(op_def.precision);
c += "__kernel void main_function(\n";
c += src_tensor.GetDeclaration(AccessType::READ);
c += GetArgsDeclaration(linked_operations);
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
c += " int4 src_size,\n";
c += " int4 dst_size,\n";
c += " int src_channels,\n";
c += " int block_size) {\n";
c += " int X = get_global_id(0);\n";
c += " int Y = get_global_id(1);\n";
c += " int Z = get_global_id(2);\n";
c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;\n";
c += " FLT tmp[4];\n";
c += " tmp[0] = (FLT)(0.0f);\n";
c += " tmp[1] = (FLT)(0.0f);\n";
c += " tmp[2] = (FLT)(0.0f);\n";
c += " tmp[3] = (FLT)(0.0f);\n";
c += " for (int i = 0; i < 4; ++i) {\n";
c += " int dst_c = 4 * Z + i;\n";
c += " int block_id = dst_c / src_channels;\n";
c += " int src_x = X * block_size + block_id % block_size;\n";
c += " int src_y = Y * block_size + block_id / block_size;\n";
c += " int src_c = dst_c % src_channels;\n";
c += " int src_z = src_c / 4;\n";
c += " FLT4 t = " + src_tensor.ReadWHS("src_x", "src_y", "src_z") + ";\n";
c += " FLT t_ar[4] = {t.x, t.y, t.z, t.w};\n";
c += " tmp[i] = t_ar[src_c % 4];\n";
c += " }\n";
c += " FLT4 result = (FLT4)(tmp[0], tmp[1], tmp[2], tmp[3]);\n";
const LinkingContext context = {
.var_name = "result",
.x_coord = "X",
.y_coord = "Y",
.s_coord = "Z",
};
c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHS("result", "X", "Y", "Z");
c += "}\n";
return c;
}
} // namespace
SpaceToDepth::SpaceToDepth(SpaceToDepth&& operation)
: GPUOperation(std::move(operation)),
attr_(operation.attr_),
kernel_(std::move(operation.kernel_)),
work_group_size_(operation.work_group_size_) {}
SpaceToDepth& SpaceToDepth::operator=(SpaceToDepth&& operation) {
if (this != &operation) {
attr_ = operation.attr_;
kernel_ = std::move(operation.kernel_);
std::swap(work_group_size_, operation.work_group_size_);
GPUOperation::operator=(std::move(operation));
}
return *this;
}
Status SpaceToDepth::Compile(const CreationContext& creation_context) {
const auto code = GetSpaceToDepthCode(definition_, linked_operations_);
return creation_context.cache->GetOrCreateCLKernel(
code, "main_function", *creation_context.context,
*creation_context.device, &kernel_);
}
Status SpaceToDepth::BindArguments() {
kernel_.ResetBindingCounter();
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->Channels()));
return kernel_.SetBytesAuto(attr_.block_size);
}
int3 SpaceToDepth::GetGridSize() const {
const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
const int grid_y = dst_[0]->Height();
const int grid_z = dst_[0]->Slices();
return int3(grid_x, grid_y, grid_z);
}
Status SpaceToDepth::Tune(const TuningParameters& params) {
RETURN_IF_ERROR(BindArguments());
return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
}
Status SpaceToDepth::AddToQueue(CLCommandQueue* queue) {
RETURN_IF_ERROR(BindArguments());
return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
}
SpaceToDepth CreateSpaceToDepth(const OperationDef& op_def,
const SpaceToDepthAttributes& attr) {
return SpaceToDepth(op_def, attr);
}
} // namespace cl
} // namespace gpu
} // namespace tflite

View File

@ -0,0 +1,58 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_SPACE_TO_DEPTH_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_SPACE_TO_DEPTH_H_
#include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
namespace tflite {
namespace gpu {
namespace cl {
class SpaceToDepth : public GPUOperation {
public:
SpaceToDepth(const OperationDef& op_def, const SpaceToDepthAttributes& attr)
: GPUOperation(op_def), attr_(attr), work_group_size_(8, 4, 1) {}
Status AddToQueue(CLCommandQueue* queue) override;
Status Tune(const TuningParameters& params) override;
Status Compile(const CreationContext& creation_context) override;
SpaceToDepth(SpaceToDepth&& operation);
SpaceToDepth& operator=(SpaceToDepth&& operation);
SpaceToDepth(const SpaceToDepth&) = delete;
SpaceToDepth& operator=(const SpaceToDepth&) = delete;
private:
Status BindArguments();
int3 GetGridSize() const;
SpaceToDepthAttributes attr_;
CLKernel kernel_;
int3 work_group_size_;
};
SpaceToDepth CreateSpaceToDepth(const OperationDef& op_def,
const SpaceToDepthAttributes& attr);
} // namespace cl
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_SPACE_TO_DEPTH_H_

View File

@ -0,0 +1,144 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
using ::testing::FloatNear;
using ::testing::Pointwise;
namespace tflite {
namespace gpu {
namespace cl {
namespace {
/*
// A known Qualcomm Adreno bug makes the 1 channel test fail on old devices.
TEST_F(OpenCLOperationTest, SpaceToDepthTensorShape1x2x2x1BlockSize2) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 2, 2, 1);
src_tensor.data = {half(1.0f), half(2.0f), half(3.0f), half(4.0f)};
const SpaceToDepthAttributes attr = {.block_size = 2};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
SpaceToDepth operation = CreateSpaceToDepth(op_def, attr);
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
BHWC(1, 1, 1, 4), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(1e-6),
{half(1.0f), half(2.0f), half(3.0f), half(4.0f)}));
}
}
}
*/
TEST_F(OpenCLOperationTest, SpaceToDepthTensorShape1x2x2x2BlockSize2) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 2, 2, 2);
src_tensor.data = {half(1.4f), half(2.3f), half(3.2f), half(4.1f),
half(5.4f), half(6.3f), half(7.2f), half(8.1f)};
const SpaceToDepthAttributes attr = {.block_size = 2};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
SpaceToDepth operation = CreateSpaceToDepth(op_def, attr);
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
BHWC(1, 1, 1, 8), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(1e-6),
{half(1.4f), half(2.3f), half(3.2f), half(4.1f),
half(5.4f), half(6.3f), half(7.2f), half(8.1f)}));
}
}
}
TEST_F(OpenCLOperationTest, SpaceToDepthTensorShape1x2x2x3BlockSize2) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 2, 2, 3);
src_tensor.data = {half(1.0f), half(2.0f), half(3.0f), half(4.0f),
half(5.0f), half(6.0f), half(7.0f), half(8.0f),
half(9.0f), half(10.0f), half(11.0f), half(12.0f)};
const SpaceToDepthAttributes attr = {.block_size = 2};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
SpaceToDepth operation = CreateSpaceToDepth(op_def, attr);
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
BHWC(1, 1, 1, 12), &dst_tensor));
EXPECT_THAT(
dst_tensor.data,
Pointwise(FloatNear(1e-6), {half(1.0f), half(2.0f), half(3.0f), //
half(4.0f), half(5.0f), half(6.0f), //
half(7.0f), half(8.0f), half(9.0f), //
half(10.0f), half(11.0f), half(12.0f)}));
}
}
}
TEST_F(OpenCLOperationTest, SpaceToDepthTensorShape1x4x4x1BlockSize2) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 4, 4, 1);
src_tensor.data = {half(1.0f), half(2.0f), half(5.0f), half(6.0f),
half(3.0f), half(4.0f), half(7.0f), half(8.0f),
half(9.0f), half(10.0f), half(13.0f), half(14.0f),
half(11.0f), half(12.0f), half(15.0f), half(16.0f)};
const SpaceToDepthAttributes attr = {.block_size = 2};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
SpaceToDepth operation = CreateSpaceToDepth(op_def, attr);
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
BHWC(1, 2, 2, 4), &dst_tensor));
EXPECT_THAT(
dst_tensor.data,
Pointwise(FloatNear(1e-6),
{half(1.0f), half(2.0f), half(3.0f), half(4.0f), //
half(5.0f), half(6.0f), half(7.0f), half(8.0f), //
half(9.0f), half(10.0f), half(11.0f), half(12.0f), //
half(13.0f), half(14.0f), half(15.0f), half(16.0f)}));
}
}
}
} // namespace
} // namespace cl
} // namespace gpu
} // namespace tflite

View File

@ -196,6 +196,10 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
SelectReshape(src_channels, attr.new_shape.c, op_def, gpu_op);
return OkStatus();
}
case OperationType::RESIZE: {
auto attr = absl::any_cast<Resize2DAttributes>(node.operation.attributes);
return SelectResize(attr, op_def, gpu_op);
}
case OperationType::SLICE: {
auto attr = absl::any_cast<SliceAttributes>(node.operation.attributes);
SelectStridedSlice(attr, op_def, gpu_op);
@ -205,16 +209,18 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
SelectSoftmax(inputs[0]->tensor.shape, op_def, gpu_op);
return OkStatus();
}
case OperationType::SPACE_TO_DEPTH: {
auto attr =
absl::any_cast<SpaceToDepthAttributes>(node.operation.attributes);
SelectSpaceToDepth(attr, op_def, gpu_op);
return OkStatus();
}
case OperationType::TRANSPOSE: {
auto attr =
absl::any_cast<TransposeAttributes>(node.operation.attributes);
SelectTranspose(attr, op_def, gpu_op);
return OkStatus();
}
case OperationType::RESIZE: {
auto attr = absl::any_cast<Resize2DAttributes>(node.operation.attributes);
return SelectResize(attr, op_def, gpu_op);
}
case OperationType::ABS:
case OperationType::COS:
case OperationType::HARD_SWISH:

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/kernels/resize.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/softmax.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/transpose.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
@ -125,6 +126,13 @@ void SelectReshape(int src_channels, int dst_channels,
}
}
void SelectSpaceToDepth(const SpaceToDepthAttributes& attr,
const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr) {
SpaceToDepth operation = CreateSpaceToDepth(op_def, attr);
*ptr = absl::make_unique<SpaceToDepth>(std::move(operation));
}
void SelectPadding(const PadAttributes& attr, const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr) {
Padding operation = CreatePadding(op_def, attr);

View File

@ -82,6 +82,10 @@ Status SelectBroadcastAdd(const AddAttributes& attr,
void SelectSoftmax(const BHWC& shape, const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr);
void SelectSpaceToDepth(const SpaceToDepthAttributes& attr,
const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr);
void SelectTranspose(const TransposeAttributes& attr,
const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr);

View File

@ -1872,50 +1872,6 @@ class Resize2DOperationParser : public TFLiteOperationParser {
SamplingType sampling_type_ = SamplingType::UNKNOWN;
};
class SoftmaxOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
RETURN_IF_ERROR(
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
TfLiteSoftmaxParams* tf_options = nullptr;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
if (tf_options->beta != 1) {
// TODO(eignasheva): figure out, what's wrong with softmax.
return UnimplementedError("Softmax.beta != 1 is not supported.");
}
return OkStatus();
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::SOFTMAX);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
const auto* tf_options =
reinterpret_cast<const TfLiteSoftmaxParams*>(tflite_node->builtin_data);
if (!tf_options) {
return InternalError("Missing tflite params");
}
if (tf_options->beta != 1) {
// there is multiply by scalar operation fused in softmax. Make a layer
// out of it before softmax.
return UnimplementedError("Softmax.beta != 1 is not supported.");
// auto mul_node = reader->NewPassthroughNode(node);
// mul_node->operation.type = ToString(OperationType::MUL);
}
SoftmaxAttributes attr;
attr.axis = Axis::CHANNELS; // always by channels
node->operation.attributes = attr;
return OkStatus();
}
};
class SliceOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
@ -1995,6 +1951,86 @@ class SliceOperationParser : public TFLiteOperationParser {
}
};
class SoftmaxOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
RETURN_IF_ERROR(
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
TfLiteSoftmaxParams* tf_options = nullptr;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
if (tf_options->beta != 1) {
// TODO(eignasheva): figure out, what's wrong with softmax.
return UnimplementedError("Softmax.beta != 1 is not supported.");
}
return OkStatus();
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::SOFTMAX);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
const auto* tf_options =
reinterpret_cast<const TfLiteSoftmaxParams*>(tflite_node->builtin_data);
if (!tf_options) {
return InternalError("Missing tflite params");
}
if (tf_options->beta != 1) {
// there is multiply by scalar operation fused in softmax. Make a layer
// out of it before softmax.
return UnimplementedError("Softmax.beta != 1 is not supported.");
// auto mul_node = reader->NewPassthroughNode(node);
// mul_node->operation.type = ToString(OperationType::MUL);
}
SoftmaxAttributes attr;
attr.axis = Axis::CHANNELS; // always by channels
node->operation.attributes = attr;
return OkStatus();
}
};
class SpaceToDepthOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
RETURN_IF_ERROR(
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
// TODO(impjdi): Dims check.
TfLiteSpaceToDepthParams* s2d_params = nullptr;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &s2d_params));
if (s2d_params->block_size == 1) {
return InvalidArgumentError("SPACE_TO_DEPTH block_size = 1 is a no-op.");
}
if (s2d_params->block_size < 1) {
return InvalidArgumentError("SPACE_TO_DEPTH block_size must be > 1.");
}
return OkStatus();
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::SPACE_TO_DEPTH);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
const auto* tf_options = reinterpret_cast<const TfLiteSpaceToDepthParams*>(
tflite_node->builtin_data);
SpaceToDepthAttributes attr;
attr.block_size = tf_options->block_size;
node->operation.attributes = attr;
return OkStatus();
}
};
class StridedSliceOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
@ -2651,12 +2687,12 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
OperationType::RSQRT);
case kTfLiteBuiltinSin:
return absl::make_unique<ElementwiseOperationParser>(OperationType::SIN);
case kTfLiteBuiltinSoftmax:
return absl::make_unique<SoftmaxOperationParser>();
case kTfLiteBuiltinSlice:
return absl::make_unique<SliceOperationParser>();
case kTfLiteBuiltinStridedSlice:
return absl::make_unique<StridedSliceOperationParser>();
case kTfLiteBuiltinSoftmax:
return absl::make_unique<SoftmaxOperationParser>();
case kTfLiteBuiltinSpaceToDepth:
return absl::make_unique<SpaceToDepthOperationParser>();
case kTfLiteBuiltinSqrt:
return absl::make_unique<ElementwiseOperationParser>(OperationType::SQRT);
case kTfLiteBuiltinSquare:
@ -2665,6 +2701,8 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
case kTfLiteBuiltinSquaredDifference:
return absl::make_unique<ElementwiseOperationParser>(
OperationType::SQUARED_DIFF);
case kTfLiteBuiltinStridedSlice:
return absl::make_unique<StridedSliceOperationParser>();
case kTfLiteBuiltinSub:
return absl::make_unique<ElementwiseOperationParser>(OperationType::SUB);
case kTfLiteBuiltinTanh:

View File

@ -134,6 +134,8 @@ std::string ToString(enum OperationType op) {
return "softmax";
case OperationType::SPACE_TO_BATCH:
return "space_to_batch";
case OperationType::SPACE_TO_DEPTH:
return "space_to_depth";
case OperationType::SQRT:
return "sqrt";
case OperationType::SQUARE:
@ -186,6 +188,7 @@ OperationType OperationTypeFromString(const std::string& name) {
{"sin", OperationType::SIN},
{"slice", OperationType::SLICE},
{"softmax", OperationType::SOFTMAX},
{"space_to_depth", OperationType::SPACE_TO_DEPTH},
{"sqrt", OperationType::SQRT},
{"square", OperationType::SQUARE},
{"squared_diff", OperationType::SQUARED_DIFF},

View File

@ -65,6 +65,7 @@ enum class OperationType {
SLICE,
SOFTMAX,
SPACE_TO_BATCH,
SPACE_TO_DEPTH,
SQRT,
SQUARE,
SQUARED_DIFF,
@ -472,6 +473,10 @@ struct TransposeAttributes {
// the given input.
BHWC CalculateOutputShape(const BHWC& input, const TransposeAttributes& attr);
struct SpaceToDepthAttributes {
int block_size;
};
} // namespace gpu
} // namespace tflite

View File

@ -569,6 +569,35 @@ cc_test(
],
)
cc_library(
name = "space_to_depth",
srcs = ["space_to_depth.cc"],
hdrs = ["space_to_depth.h"],
deps = [
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/gl:node_shader",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:any",
],
)
cc_test(
name = "space_to_depth_test",
srcs = ["space_to_depth_test.cc"],
tags = tf_gpu_tests_tags() + [
"notap",
"tflite_not_portable_ios",
],
deps = [
":space_to_depth",
":test_util",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:shape",
"@com_google_googletest//:gtest",
],
)
cc_library(
name = "test_util",
testonly = 1,
@ -676,6 +705,7 @@ TFLITE_GPU_BINARY_RELEASE_OPERATORS = [
"resize",
"slice",
"softmax",
"space_to_depth",
"transpose_conv",
]

View File

@ -0,0 +1,74 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/gl/kernels/space_to_depth.h"
#include <string>
#include <utility>
#include "absl/memory/memory.h"
#include "absl/types/any.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/gl/node_shader.h"
namespace tflite {
namespace gpu {
namespace gl {
namespace {
class SpaceToDepth : public NodeShader {
public:
Status GenerateCode(const GenerationContext& ctx,
GeneratedCode* generated_code) const final {
const auto attr =
absl::any_cast<SpaceToDepthAttributes>(ctx.node->operation.attributes);
const auto& input_data_0 = ctx.graph->FindInputs(ctx.node->id)[0]->tensor;
std::string code = R"(
for (int i = 0; i < 4; ++i) {
int dst_c = 4 * gid.z + i;
int block_id = dst_c / $input_data_0_c$;
int src_x = gid.x * $block_size$ + block_id % $block_size$;
int src_y = gid.y * $block_size$ + block_id / $block_size$;
int src_c = dst_c % $input_data_0_c$;
value_0[i] = $input_data_0[src_x, src_y, src_c / 4]$[src_c % 4];
}
)";
*generated_code = {
/*parameters=*/{
{"block_size", attr.block_size},
{"input_data_0_c", input_data_0.shape.c},
},
/*objects=*/{},
/*shared_variables=*/{},
/*workload=*/uint3(),
/*workgroup=*/uint3(),
/*source_code=*/std::move(code),
/*input=*/IOStructure::ONLY_DEFINITIONS,
/*output=*/IOStructure::AUTO,
};
return OkStatus();
}
};
} // namespace
std::unique_ptr<NodeShader> NewSpaceToDepthNodeShader() {
return absl::make_unique<SpaceToDepth>();
}
} // namespace gl
} // namespace gpu
} // namespace tflite

View File

@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_SPACE_TO_DEPTH_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_SPACE_TO_DEPTH_H_
#include <memory>
#include "tensorflow/lite/delegates/gpu/gl/node_shader.h"
namespace tflite {
namespace gpu {
namespace gl {
std::unique_ptr<NodeShader> NewSpaceToDepthNodeShader();
} // namespace gl
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_SPACE_TO_DEPTH_H_

View File

@ -0,0 +1,104 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/gl/kernels/space_to_depth.h"
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/gl/kernels/test_util.h"
using ::testing::FloatNear;
using ::testing::Pointwise;
namespace tflite {
namespace gpu {
namespace gl {
namespace {
TEST(SpaceToDepthTest, TensorShape1x2x2x1BlockSize2) {
const TensorRef<BHWC> input = {
.type = DataType::FLOAT32, .shape = BHWC(1, 2, 2, 1), .ref = 0};
const TensorRef<BHWC> output = {
.type = DataType::FLOAT32, .shape = BHWC(1, 1, 1, 4), .ref = 1};
const SpaceToDepthAttributes attr = {.block_size = 2};
SingleOpModel model({ToString(OperationType::SPACE_TO_DEPTH), attr}, {input},
{output});
ASSERT_TRUE(model.PopulateTensor(0, {1.0f, 2.0f, 3.0f, 4.0f}));
ASSERT_OK(model.Invoke(*NewSpaceToDepthNodeShader()));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {1.0f, 2.0f, 3.0f, 4.0f}));
}
TEST(SpaceToDepthTest, TensorShape1x2x2x2BlockSize2) {
const TensorRef<BHWC> input = {
.type = DataType::FLOAT32, .shape = BHWC(1, 2, 2, 2), .ref = 0};
const TensorRef<BHWC> output = {
.type = DataType::FLOAT32, .shape = BHWC(1, 1, 1, 8), .ref = 1};
const SpaceToDepthAttributes attr = {.block_size = 2};
SingleOpModel model({ToString(OperationType::SPACE_TO_DEPTH), attr}, {input},
{output});
ASSERT_TRUE(model.PopulateTensor(
0, {1.4f, 2.3f, 3.2f, 4.1f, 5.4f, 6.3f, 7.2f, 8.1f}));
ASSERT_OK(model.Invoke(*NewSpaceToDepthNodeShader()));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6),
{1.4f, 2.3f, 3.2f, 4.1f, 5.4f, 6.3f, 7.2f, 8.1f}));
}
TEST(SpaceToDepthTest, TensorShape1x2x2x3BlockSize2) {
const TensorRef<BHWC> input = {
.type = DataType::FLOAT32, .shape = BHWC(1, 2, 2, 3), .ref = 0};
const TensorRef<BHWC> output = {
.type = DataType::FLOAT32, .shape = BHWC(1, 1, 1, 12), .ref = 1};
const SpaceToDepthAttributes attr = {.block_size = 2};
SingleOpModel model({ToString(OperationType::SPACE_TO_DEPTH), attr}, {input},
{output});
ASSERT_TRUE(model.PopulateTensor(0, {1.0f, 2.0f, 3.0f, //
4.0f, 5.0f, 6.0f, //
7.0f, 8.0f, 9.0f, //
10.0f, 11.0f, 12.0f}));
ASSERT_OK(model.Invoke(*NewSpaceToDepthNodeShader()));
EXPECT_THAT(
model.GetOutput(0),
Pointwise(FloatNear(1e-6), {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, //
7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}));
}
TEST(SpaceToDepthTest, TensorShape1x4x4x1BlockSize2) {
const TensorRef<BHWC> input = {
.type = DataType::FLOAT32, .shape = BHWC(1, 4, 4, 1), .ref = 0};
const TensorRef<BHWC> output = {
.type = DataType::FLOAT32, .shape = BHWC(1, 2, 2, 4), .ref = 1};
const SpaceToDepthAttributes attr = {.block_size = 2};
SingleOpModel model({ToString(OperationType::SPACE_TO_DEPTH), attr}, {input},
{output});
ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 5.0, 6.0, //
3.0, 4.0, 7.0, 8.0, //
9.0, 10.0, 13.0, 14.0, //
11.0, 12.0, 15.0, 16.0}));
ASSERT_OK(model.Invoke(*NewSpaceToDepthNodeShader()));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {1.0, 2.0, 3.0, 4.0, //
5.0, 6.0, 7.0, 8.0, //
9.0, 10.0, 11.0, 12.0, //
13.0, 14.0, 15.0, 16.0}));
}
} // namespace
} // namespace gl
} // namespace gpu
} // namespace tflite

View File

@ -43,6 +43,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/metal/kernels/resize.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/slice.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/softmax.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h"
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
@ -137,6 +138,12 @@ std::vector<ComputeTaskDescriptorPtr> SelectSoftmax(const GraphFloat32& graph,
}
}
std::vector<ComputeTaskDescriptorPtr> SelectSpaceToDepth(
const GraphFloat32& graph, int id, ValueId input_id, ValueId output_id,
const SpaceToDepthAttributes& attr) {
return SpaceToDepth(id, input_id, output_id, attr);
}
Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
const std::vector<ValueId>& inputs,
const std::vector<ValueId>& outputs,
@ -254,6 +261,11 @@ Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
*tasks = SelectSoftmax(graph, node_id, inputs[0], outputs[0]);
break;
}
case OperationType::SPACE_TO_DEPTH:
*tasks = SelectSpaceToDepth(
graph, node_id, inputs[0], outputs[0],
absl::any_cast<SpaceToDepthAttributes>(node->operation.attributes));
break;
case OperationType::ABS:
case OperationType::COS:
case OperationType::HARD_SWISH:

View File

@ -30,6 +30,7 @@ cc_library(
":resize",
":slice",
":softmax",
":space_to_depth",
":transpose_conv",
],
)
@ -698,6 +699,42 @@ ios_unit_test(
deps = [":softmax_test_lib"],
)
cc_library(
name = "space_to_depth",
srcs = ["space_to_depth.cc"],
hdrs = ["space_to_depth.h"],
deps = [
"//tensorflow/lite/delegates/gpu/common:model",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:util",
"//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
"//tensorflow/lite/delegates/gpu/metal:runtime_options",
"//tensorflow/lite/delegates/gpu/metal/kernels:util",
],
)
objc_library(
name = "space_to_depth_test_lib",
testonly = 1,
srcs = ["space_to_depth_test.mm"],
sdk_frameworks = ["XCTest"],
deps = [
":space_to_depth",
":test_util",
],
)
ios_unit_test(
name = "space_to_depth_test",
testonly = 1,
minimum_os_version = "10.0",
tags = tf_gpu_tests_tags() + [
"notap",
"tflite_not_portable_android",
],
deps = [":space_to_depth_test_lib"],
)
cc_library(
name = "transpose_conv",
srcs = ["transpose_conv.cc"],

View File

@ -0,0 +1,129 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth.h"
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/util.h"
namespace tflite {
namespace gpu {
namespace metal {
std::vector<ComputeTaskDescriptorPtr> SpaceToDepth(
int id, ValueId input_id, ValueId output_id,
const SpaceToDepthAttributes& attr) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = false;
desc->shader_source = R"(
#include <metal_stdlib>
using namespace metal;
struct uniforms {
uint4 src_size;
uint4 dst_size;
uint4 block_size;
};
$0
kernel void ComputeFunction($1 uint3 gid[[thread_position_in_grid]]) {
uint3 src_size = (uint3)(params.src_size.xyz);
uint3 dst_size = (uint3)(params.dst_size.xyz);
uint block_size = (uint)(params.block_size.x);
if (gid.x >= dst_size.x || gid.y >= dst_size.y || gid.z * 4 >= dst_size.z) {
return;
}
FLT4 value;
for (uint i = 0; i < 4; ++i) {
uint dst_c = 4 * gid.z + i;
uint block_id = dst_c / src_size.z;
uint src_x = gid.x * block_size + block_id % block_size;
uint src_y = gid.y * block_size + block_id / block_size;
uint src_c = dst_c % src_size.z;
value[i] =
src_buffer[src_x + src_size.x * (src_y + src_size.y * (src_c / 4))]
[src_c % 4];
}
$2
dst_buffer[gid.x + dst_size.x * (gid.y + dst_size.y * gid.z)] = value;
})";
desc->input_buffers = {{input_id, "device FLT4* const src_buffer"}};
desc->output_buffer = {
output_id, "device FLT4* dst_buffer",
[input_id, attr](const std::map<ValueId, BHWC>& buffers) -> BHWC {
const BHWC& input_shape = buffers.find(input_id)->second;
return BHWC(input_shape.b, //
input_shape.h / attr.block_size,
input_shape.w / attr.block_size,
input_shape.c * attr.block_size * attr.block_size);
}};
desc->uniform_buffers = {
{"constant uniforms& params",
[input_id, output_id, attr](const std::map<ValueId, BHWC>& buffers) {
const BHWC& input_shape = buffers.find(input_id)->second;
const BHWC& output_shape = buffers.find(output_id)->second;
const std::vector<int> uniform_params = {
// src_size
input_shape.w,
input_shape.h,
input_shape.c,
0,
// dst_size
output_shape.w,
output_shape.h,
output_shape.c,
0,
// block_size
attr.block_size,
0,
0,
0,
};
return GetByteBuffer(uniform_params);
}},
};
desc->resize_function =
[input_id, attr](
const std::map<ValueId, BHWC>& buffers) -> std::pair<uint3, uint3> {
const BHWC& input_shape = buffers.find(input_id)->second;
const BHWC output_shape(input_shape.b, //
input_shape.h / attr.block_size,
input_shape.w / attr.block_size,
input_shape.c * attr.block_size * attr.block_size);
const uint3 grid = uint3(output_shape.w, output_shape.h,
IntegralDivideRoundUp(output_shape.c, 4));
const uint3 groups_size = GetWorkGroupSizeForGrid(grid);
const int groups_x = IntegralDivideRoundUp(grid.x, groups_size.x);
const int groups_y = IntegralDivideRoundUp(grid.y, groups_size.y);
const int groups_z = IntegralDivideRoundUp(grid.z, groups_size.z);
return std::make_pair(groups_size, uint3(groups_x, groups_y, groups_z));
};
return {desc};
}
} // namespace metal
} // namespace gpu
} // namespace tflite

View File

@ -0,0 +1,37 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_SPACE_TO_DEPTH_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_SPACE_TO_DEPTH_H_
#include <vector>
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
namespace tflite {
namespace gpu {
namespace metal {
std::vector<ComputeTaskDescriptorPtr> SpaceToDepth(
int id, ValueId input_id, ValueId output_id,
const SpaceToDepthAttributes& attr);
} // namespace metal
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_SPACE_TO_DEPTH_H_

View File

@ -0,0 +1,153 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth.h"
#import <XCTest/XCTest.h>
#include <cmath>
#include <string>
#include <vector>
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
using ::tflite::gpu::BHWC;
using ::tflite::gpu::DataType;
using ::tflite::gpu::OperationType;
using ::tflite::gpu::SpaceToDepthAttributes;
using ::tflite::gpu::TensorRef;
using ::tflite::gpu::metal::CompareVectors;
using ::tflite::gpu::metal::SingleOpModel;
@interface SpaceToDepthTest : XCTestCase
@end
@implementation SpaceToDepthTest
- (void)testTensorShape1x2x2x1BlockSize2 {
const TensorRef<BHWC> input = {.type = DataType::FLOAT32, .shape = BHWC(1, 2, 2, 1), .ref = 0};
const TensorRef<BHWC> output = {.type = DataType::FLOAT32, .shape = BHWC(1, 1, 1, 4), .ref = 1};
const SpaceToDepthAttributes attr = {.block_size = 2};
SingleOpModel model({ToString(OperationType::SPACE_TO_DEPTH), attr}, {input}, {output});
if (!model.PopulateTensor(0, {1.0f, 2.0f, 3.0f, 4.0f})) {
XCTFail(@"PopulateTensor()");
}
const auto status = model.Invoke();
if (!status.ok()) XCTFail(@"%s", status.error_message().c_str());
const std::vector<float>& actual = model.GetOutput(0);
const std::vector<float> expected = {1.0f, 2.0f, 3.0f, 4.0f};
XCTAssertEqual(actual[0], expected[0]);
XCTAssertEqual(actual[1], expected[1]);
XCTAssertEqual(actual[2], expected[2]);
XCTAssertEqual(actual[3], expected[3]);
}
- (void)testTensorShape1x2x2x2BlockSize2 {
const TensorRef<BHWC> input = {.type = DataType::FLOAT32, .shape = BHWC(1, 2, 2, 2), .ref = 0};
const TensorRef<BHWC> output = {.type = DataType::FLOAT32, .shape = BHWC(1, 1, 1, 8), .ref = 1};
const SpaceToDepthAttributes attr = {.block_size = 2};
SingleOpModel model({ToString(OperationType::SPACE_TO_DEPTH), attr}, {input}, {output});
if (!model.PopulateTensor(0, {1.4f, 2.3f, 3.2f, 4.1f, 5.4f, 6.3f, 7.2f, 8.1f})) {
XCTFail(@"PopulateTensor()");
}
const auto status = model.Invoke();
if (!status.ok()) XCTFail(@"%s", status.error_message().c_str());
const std::vector<float>& actual = model.GetOutput(0);
const std::vector<float> expected = {1.4f, 2.3f, 3.2f, 4.1f, 5.4f, 6.3f, 7.2f, 8.1f};
XCTAssertEqual(actual[0], expected[0]);
XCTAssertEqual(actual[1], expected[1]);
XCTAssertEqual(actual[2], expected[2]);
XCTAssertEqual(actual[3], expected[3]);
XCTAssertEqual(actual[4], expected[4]);
XCTAssertEqual(actual[5], expected[5]);
XCTAssertEqual(actual[6], expected[6]);
XCTAssertEqual(actual[7], expected[7]);
}
- (void)testTensorShape1x2x2x3BlockSize2 {
const TensorRef<BHWC> input = {.type = DataType::FLOAT32, .shape = BHWC(1, 2, 2, 3), .ref = 0};
const TensorRef<BHWC> output = {.type = DataType::FLOAT32, .shape = BHWC(1, 1, 1, 12), .ref = 1};
const SpaceToDepthAttributes attr = {.block_size = 2};
SingleOpModel model({ToString(OperationType::SPACE_TO_DEPTH), attr}, {input}, {output});
if (!model.PopulateTensor(0, {1.0f, 2.0f, 3.0f, //
4.0f, 5.0f, 6.0f, //
7.0f, 8.0f, 9.0f, //
10.0f, 11.0f, 12.0f})) {
XCTFail(@"PopulateTensor()");
}
const auto status = model.Invoke();
if (!status.ok()) XCTFail(@"%s", status.error_message().c_str());
const std::vector<float>& actual = model.GetOutput(0);
const std::vector<float> expected = {1.0f, 2.0f, 3.0f, //
4.0f, 5.0f, 6.0f, //
7.0f, 8.0f, 9.0f, //
10.0f, 11.0f, 12.0f};
XCTAssertEqual(actual[0], expected[0]);
XCTAssertEqual(actual[1], expected[1]);
XCTAssertEqual(actual[2], expected[2]);
XCTAssertEqual(actual[3], expected[3]);
XCTAssertEqual(actual[4], expected[4]);
XCTAssertEqual(actual[5], expected[5]);
XCTAssertEqual(actual[6], expected[6]);
XCTAssertEqual(actual[7], expected[7]);
XCTAssertEqual(actual[8], expected[8]);
XCTAssertEqual(actual[9], expected[9]);
XCTAssertEqual(actual[10], expected[10]);
XCTAssertEqual(actual[11], expected[11]);
}
- (void)testTensorShape1x4x4x1BlockSize2 {
const TensorRef<BHWC> input = {.type = DataType::FLOAT32, .shape = BHWC(1, 4, 4, 1), .ref = 0};
const TensorRef<BHWC> output = {.type = DataType::FLOAT32, .shape = BHWC(1, 2, 2, 4), .ref = 1};
const SpaceToDepthAttributes attr = {.block_size = 2};
SingleOpModel model({ToString(OperationType::SPACE_TO_DEPTH), attr}, {input}, {output});
if (!model.PopulateTensor(0, {1.0f, 2.0f, 5.0f, 6.0f, //
3.0f, 4.0f, 7.0f, 8.0f, //
9.0f, 10.0f, 13.0f, 14.0f, //
11.0f, 12.0f, 15.0f, 16.0f})) {
XCTFail(@"PopulateTensor()");
}
const auto status = model.Invoke();
if (!status.ok()) XCTFail(@"%s", status.error_message().c_str());
const std::vector<float>& actual = model.GetOutput(0);
const std::vector<float> expected = {1.0f, 2.0f, 3.0f, 4.0f, //
5.0f, 6.0f, 7.0f, 8.0f, //
9.0f, 10.0f, 11.0f, 12.0f, //
13.0f, 14.0f, 15.0f, 16.0f};
XCTAssertEqual(actual[0], expected[0]);
XCTAssertEqual(actual[1], expected[1]);
XCTAssertEqual(actual[2], expected[2]);
XCTAssertEqual(actual[3], expected[3]);
XCTAssertEqual(actual[4], expected[4]);
XCTAssertEqual(actual[5], expected[5]);
XCTAssertEqual(actual[6], expected[6]);
XCTAssertEqual(actual[7], expected[7]);
XCTAssertEqual(actual[8], expected[8]);
XCTAssertEqual(actual[9], expected[9]);
XCTAssertEqual(actual[10], expected[10]);
XCTAssertEqual(actual[11], expected[11]);
XCTAssertEqual(actual[12], expected[12]);
XCTAssertEqual(actual[13], expected[13]);
XCTAssertEqual(actual[14], expected[14]);
XCTAssertEqual(actual[15], expected[15]);
}
@end