Support Resize operation with two NearestNeighbor and Bilinear params.

PiperOrigin-RevId: 291261285
Change-Id: Ia4e20f6766182af9e451a75be23d612d2807a8ab
This commit is contained in:
A. Unique TensorFlower 2020-01-23 15:44:46 -08:00 committed by TensorFlower Gardener
parent 306dee4096
commit 4b17c10739
24 changed files with 403 additions and 271 deletions

View File

@ -1222,9 +1222,9 @@ cc_library(
)
cc_library(
name = "upsample",
srcs = ["upsample.cc"],
hdrs = ["upsample.h"],
name = "resize",
srcs = ["resize.cc"],
hdrs = ["resize.h"],
deps = [
":gpu_operation",
":util",
@ -1237,8 +1237,8 @@ cc_library(
)
cc_test(
name = "upsample_test",
srcs = ["upsample_test.cc"],
name = "resize_test",
srcs = ["resize_test.cc"],
linkstatic = True,
tags = tf_gpu_tests_tags() + [
"linux",
@ -1246,7 +1246,7 @@ cc_test(
],
deps = [
":cl_test",
":upsample",
":resize",
"//tensorflow/lite/delegates/gpu/cl:tensor",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:status",
@ -1317,10 +1317,10 @@ test_suite(
"relu_test",
"reshape_test",
"reshapex4_test",
"resize_test",
"softmax1x1_test",
"softmax_test",
"strided_slice_test",
"transpose_test",
"upsample_test",
],
)

View File

@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,18 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/cl/kernels/upsample.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/resize.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
namespace tflite {
namespace gpu {
namespace cl {
namespace {
std::string GetUpsampleCode(
std::string GetResizeCode(
const OperationDef& op_def,
const std::vector<ElementwiseOperation*>& linked_operations) {
TensorCodeGenerator src_tensor(
@ -84,7 +85,7 @@ std::string GetUpsampleCode(
return c;
}
std::string GetUpsample3DCode(
std::string GetResize3DCode(
const OperationDef& op_def,
const std::vector<ElementwiseOperation*>& linked_operations) {
TensorCodeGenerator src_tensor(
@ -161,13 +162,13 @@ std::string GetUpsample3DCode(
} // namespace
Upsample::Upsample(Upsample&& operation)
Resize::Resize(Resize&& operation)
: GPUOperation(std::move(operation)),
attr_(operation.attr_),
kernel_(std::move(operation.kernel_)),
work_group_size_(operation.work_group_size_) {}
Upsample& Upsample::operator=(Upsample&& operation) {
Resize& Resize::operator=(Resize&& operation) {
if (this != &operation) {
attr_ = operation.attr_;
kernel_ = std::move(operation.kernel_);
@ -177,14 +178,17 @@ Upsample& Upsample::operator=(Upsample&& operation) {
return *this;
}
Status Upsample::Compile(const CreationContext& creation_context) {
const auto code = GetUpsampleCode(definition_, linked_operations_);
Status Resize::Compile(const CreationContext& creation_context) {
if (attr_.type != SamplingType::BILINEAR) {
return InternalError("Only bilinear sampling is currently supported");
}
const auto code = GetResizeCode(definition_, linked_operations_);
return creation_context.cache->GetOrCreateCLKernel(
code, "main_function", *creation_context.context,
*creation_context.device, &kernel_);
}
Status Upsample::BindArguments() {
Status Resize::BindArguments() {
kernel_.ResetBindingCounter();
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
@ -200,35 +204,35 @@ Status Upsample::BindArguments() {
return OkStatus();
}
int3 Upsample::GetGridSize() const {
int3 Resize::GetGridSize() const {
const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
const int grid_y = dst_[0]->Height();
const int grid_z = dst_[0]->Slices();
return int3(grid_x, grid_y, grid_z);
}
Status Upsample::AddToQueue(CLCommandQueue* queue) {
Status Resize::AddToQueue(CLCommandQueue* queue) {
RETURN_IF_ERROR(BindArguments());
return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
}
Status Upsample::Tune(const TuningParameters& params) {
Status Resize::Tune(const TuningParameters& params) {
RETURN_IF_ERROR(BindArguments());
return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
}
Upsample CreateUpsample(const OperationDef& definition,
const Upsample2DAttributes& attr) {
return Upsample(definition, attr);
Resize CreateResize(const OperationDef& definition,
const Resize2DAttributes& attr) {
return Resize(definition, attr);
}
Upsample3D::Upsample3D(Upsample3D&& operation)
Resize3D::Resize3D(Resize3D&& operation)
: GPUOperation(std::move(operation)),
attr_(operation.attr_),
kernel_(std::move(operation.kernel_)),
work_group_size_(operation.work_group_size_) {}
Upsample3D& Upsample3D::operator=(Upsample3D&& operation) {
Resize3D& Resize3D::operator=(Resize3D&& operation) {
if (this != &operation) {
attr_ = operation.attr_;
kernel_ = std::move(operation.kernel_);
@ -238,14 +242,14 @@ Upsample3D& Upsample3D::operator=(Upsample3D&& operation) {
return *this;
}
Status Upsample3D::Compile(const CreationContext& creation_context) {
const auto code = GetUpsample3DCode(definition_, linked_operations_);
Status Resize3D::Compile(const CreationContext& creation_context) {
const auto code = GetResize3DCode(definition_, linked_operations_);
return creation_context.cache->GetOrCreateCLKernel(
code, "main_function", *creation_context.context,
*creation_context.device, &kernel_);
}
Status Upsample3D::BindArguments() {
Status Resize3D::BindArguments() {
kernel_.ResetBindingCounter();
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
@ -265,26 +269,26 @@ Status Upsample3D::BindArguments() {
return OkStatus();
}
int3 Upsample3D::GetGridSize() const {
int3 Resize3D::GetGridSize() const {
const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
const int grid_y = dst_[0]->Height();
const int grid_z = dst_[0]->Slices() * dst_[0]->Depth();
return int3(grid_x, grid_y, grid_z);
}
Status Upsample3D::AddToQueue(CLCommandQueue* queue) {
Status Resize3D::AddToQueue(CLCommandQueue* queue) {
RETURN_IF_ERROR(BindArguments());
return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
}
Status Upsample3D::Tune(const TuningParameters& params) {
Status Resize3D::Tune(const TuningParameters& params) {
RETURN_IF_ERROR(BindArguments());
return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
}
Upsample3D CreateUpsample3D(const OperationDef& definition,
const Upsample3DAttributes& attr) {
return Upsample3D(definition, attr);
Resize3D CreateResize3D(const OperationDef& definition,
const Resize3DAttributes& attr) {
return Resize3D(definition, attr);
}
} // namespace cl

View File

@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_UPSAMPLE_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_UPSAMPLE_H_
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_RESIZE_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_RESIZE_H_
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
@ -25,7 +25,7 @@ namespace tflite {
namespace gpu {
namespace cl {
class Upsample : public GPUOperation {
class Resize : public GPUOperation {
public:
Status AddToQueue(CLCommandQueue* queue) override;
Status Tune(const TuningParameters& params) override;
@ -33,30 +33,30 @@ class Upsample : public GPUOperation {
Status Compile(const CreationContext& creation_context) override;
// Move only
Upsample(Upsample&& operation);
Upsample& operator=(Upsample&& operation);
Upsample(const Upsample&) = delete;
Upsample& operator=(const Upsample&) = delete;
Resize(Resize&& operation);
Resize& operator=(Resize&& operation);
Resize(const Resize&) = delete;
Resize& operator=(const Resize&) = delete;
friend Upsample CreateUpsample(const OperationDef& definition,
const Upsample2DAttributes& attr);
friend Resize CreateResize(const OperationDef& definition,
const Resize2DAttributes& attr);
private:
Upsample(const OperationDef& definition, const Upsample2DAttributes& attr)
Resize(const OperationDef& definition, const Resize2DAttributes& attr)
: GPUOperation(definition), attr_(attr) {}
Status BindArguments();
int3 GetGridSize() const;
Upsample2DAttributes attr_;
Resize2DAttributes attr_;
CLKernel kernel_;
int3 work_group_size_ = int3(8, 4, 1);
};
Upsample CreateUpsample(const OperationDef& definition,
const Upsample2DAttributes& attr);
Resize CreateResize(const OperationDef& definition,
const Resize2DAttributes& attr);
class Upsample3D : public GPUOperation {
class Resize3D : public GPUOperation {
public:
Status AddToQueue(CLCommandQueue* queue) override;
Status Tune(const TuningParameters& params) override;
@ -64,31 +64,31 @@ class Upsample3D : public GPUOperation {
Status Compile(const CreationContext& creation_context) override;
// Move only
Upsample3D(Upsample3D&& operation);
Upsample3D& operator=(Upsample3D&& operation);
Upsample3D(const Upsample3D&) = delete;
Upsample3D& operator=(const Upsample3D&) = delete;
Resize3D(Resize3D&& operation);
Resize3D& operator=(Resize3D&& operation);
Resize3D(const Resize3D&) = delete;
Resize3D& operator=(const Resize3D&) = delete;
friend Upsample3D CreateUpsample3D(const OperationDef& definition,
const Upsample3DAttributes& attr);
friend Resize3D CreateResize3D(const OperationDef& definition,
const Resize3DAttributes& attr);
private:
Upsample3D(const OperationDef& definition, const Upsample3DAttributes& attr)
Resize3D(const OperationDef& definition, const Resize3DAttributes& attr)
: GPUOperation(definition), attr_(attr) {}
Status BindArguments();
int3 GetGridSize() const;
Upsample3DAttributes attr_;
Resize3DAttributes attr_;
CLKernel kernel_;
int3 work_group_size_ = int3(8, 4, 1);
};
Upsample3D CreateUpsample3D(const OperationDef& definition,
const Upsample3DAttributes& attr);
Resize3D CreateResize3D(const OperationDef& definition,
const Resize3DAttributes& attr);
} // namespace cl
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_UPSAMPLE_H_
#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_RESIZE_H_

View File

@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/cl/kernels/upsample.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/resize.h"
#include <vector>
@ -31,13 +31,13 @@ namespace gpu {
namespace cl {
namespace {
TEST_F(OpenCLOperationTest, UpsampleBilinearAligned) {
TEST_F(OpenCLOperationTest, ResizeBilinearAligned) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 2, 3, 1);
src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
Upsample2DAttributes attr;
attr.type = UpsamplingType::BILINEAR;
Resize2DAttributes attr;
attr.type = SamplingType::BILINEAR;
attr.new_shape = HW(4, 4);
attr.align_corners = true;
@ -50,7 +50,7 @@ TEST_F(OpenCLOperationTest, UpsampleBilinearAligned) {
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
Upsample operation = CreateUpsample(op_def, attr);
Resize operation = CreateResize(op_def, attr);
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
BHWC(1, 4, 4, 1), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
@ -62,13 +62,13 @@ TEST_F(OpenCLOperationTest, UpsampleBilinearAligned) {
}
}
TEST_F(OpenCLOperationTest, UpsampleBilinearNonAligned) {
TEST_F(OpenCLOperationTest, ResizeBilinearNonAligned) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 2, 3, 1);
src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
Upsample2DAttributes attr;
attr.type = UpsamplingType::BILINEAR;
Resize2DAttributes attr;
attr.type = SamplingType::BILINEAR;
attr.new_shape = HW(4, 4);
attr.align_corners = false;
@ -81,7 +81,7 @@ TEST_F(OpenCLOperationTest, UpsampleBilinearNonAligned) {
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
Upsample operation = CreateUpsample(op_def, attr);
Resize operation = CreateResize(op_def, attr);
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
BHWC(1, 4, 4, 1), &dst_tensor));
EXPECT_THAT(

View File

@ -116,11 +116,11 @@ cc_library(
"//tensorflow/lite/delegates/gpu/cl/kernels:relu",
"//tensorflow/lite/delegates/gpu/cl/kernels:reshape",
"//tensorflow/lite/delegates/gpu/cl/kernels:reshapex4",
"//tensorflow/lite/delegates/gpu/cl/kernels:resize",
"//tensorflow/lite/delegates/gpu/cl/kernels:softmax",
"//tensorflow/lite/delegates/gpu/cl/kernels:softmax1x1",
"//tensorflow/lite/delegates/gpu/cl/kernels:strided_slice",
"//tensorflow/lite/delegates/gpu/cl/kernels:transpose",
"//tensorflow/lite/delegates/gpu/cl/kernels:upsample",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:shape",
"//tensorflow/lite/delegates/gpu/common:status",

View File

@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -160,10 +160,9 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
SelectTranspose(attr, op_def, gpu_op);
return OkStatus();
}
case OperationType::UPSAMPLE_2D: {
auto attr =
absl::any_cast<Upsample2DAttributes>(node.operation.attributes);
return SelectUpsampling(attr, op_def, gpu_op);
case OperationType::RESIZE: {
auto attr = absl::any_cast<Resize2DAttributes>(node.operation.attributes);
return SelectResize(attr, op_def, gpu_op);
}
case OperationType::ABS:
case OperationType::COS:

View File

@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -31,11 +31,11 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/kernels/relu.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/reshape.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/resize.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/softmax.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/transpose.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/upsample.h"
namespace tflite {
namespace gpu {
@ -90,14 +90,13 @@ void SelectAdd(const OperationDef& op_def, const std::vector<int>& channels,
*ptr = absl::make_unique<Add>(std::move(operation));
}
Status SelectUpsampling(const Upsample2DAttributes& attr,
const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr) {
if (attr.type != UpsamplingType::BILINEAR) {
return UnimplementedError("Upsample2D supports only bilinear type.");
Status SelectResize(const Resize2DAttributes& attr, const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr) {
if (attr.type != SamplingType::BILINEAR) {
return UnimplementedError("Resize2D supports only bilinear sampling.");
}
Upsample operation = CreateUpsample(op_def, attr);
*ptr = absl::make_unique<Upsample>(std::move(operation));
Resize operation = CreateResize(op_def, attr);
*ptr = absl::make_unique<Resize>(std::move(operation));
return OkStatus();
}

View File

@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -52,9 +52,8 @@ void SelectMaxUnpooling(const MaxUnpooling2DAttributes& attr,
void SelectAdd(const OperationDef& op_def, const std::vector<int>& channels,
int dst_channels, std::unique_ptr<GPUOperation>* ptr);
Status SelectUpsampling(const Upsample2DAttributes& attr,
const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr);
Status SelectResize(const Resize2DAttributes& attr, const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr);
Status SelectConcat(const ConcatAttributes& attr,
const std::vector<int>& channels,

View File

@ -1693,8 +1693,11 @@ class ReshapeOperationParser : public TFLiteOperationParser {
}
};
class ResizeBilinearOperationParser : public TFLiteOperationParser {
class Resize2DOperationParser : public TFLiteOperationParser {
public:
explicit Resize2DOperationParser(SamplingType sampling_type)
: sampling_type_(sampling_type) {}
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
@ -1702,9 +1705,9 @@ class ResizeBilinearOperationParser : public TFLiteOperationParser {
RETURN_IF_ERROR(
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
// TODO(eignasheva): check shapes.
TfLiteResizeBilinearParams* tf_options = nullptr;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
RETURN_IF_ERROR(CheckOnlyUpsamplingIsSupported(context, tflite_node));
bool align_corners;
RETURN_IF_ERROR(GetAlignCornersValue(tflite_node, &align_corners));
return OkStatus();
}
@ -1712,26 +1715,71 @@ class ResizeBilinearOperationParser : public TFLiteOperationParser {
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::UPSAMPLE_2D);
node->operation.type = ToString(OperationType::RESIZE);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
// Here we may have extra inputs. Other tensors were supposed to
// define new shape, but in TFLite these are ignored.
const auto* tf_options =
reinterpret_cast<const TfLiteResizeBilinearParams*>(
tflite_node->builtin_data);
if (!tf_options) {
return InternalError("Missing tflite params");
}
Upsample2DAttributes attr;
attr.align_corners = tf_options->align_corners;
attr.type = UpsamplingType::BILINEAR;
Resize2DAttributes attr;
RETURN_IF_ERROR(GetAlignCornersValue(tflite_node, &attr.align_corners));
attr.type = sampling_type_;
attr.new_shape.CopyAllDefinedAxis(
graph->FindOutputs(node->id)[0]->tensor.shape);
node->operation.attributes = attr;
return OkStatus();
}
private:
Status GetAlignCornersValue(const TfLiteNode* tflite_node,
bool* align_corners) {
switch (sampling_type_) {
case SamplingType::BILINEAR:
return GetAlignCornersValueForType<TfLiteResizeBilinearParams>(
tflite_node, align_corners);
case SamplingType::NEAREST:
return GetAlignCornersValueForType<TfLiteResizeNearestNeighborParams>(
tflite_node, align_corners);
case SamplingType::UNKNOWN:
return InternalError("Sampling type is not specified");
}
return OkStatus();
}
template <class T>
Status GetAlignCornersValueForType(const TfLiteNode* tflite_node,
bool* align_corners) {
const auto* tf_options =
reinterpret_cast<const T*>(tflite_node->builtin_data);
if (!tf_options) {
return InternalError("Missing tflite params");
}
*align_corners = tf_options->align_corners;
return OkStatus();
}
Status CheckOnlyUpsamplingIsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node) {
const auto* input = context->tensors + tflite_node->inputs->data[0];
const auto* output = context->tensors + tflite_node->outputs->data[0];
if (!input->dims || input->dims->size != 4) {
return InvalidArgumentError("input.dims.size != 4");
}
if (!output->dims || output->dims->size != 4) {
return InvalidArgumentError("output.dims.size != 4");
}
if (output->dims->data[1] < input->dims->data[1] ||
output->dims->data[2] < input->dims->data[2]) {
return InvalidArgumentError(absl::StrCat(
"Only upsampling is supported, received output h,w = ",
output->dims->data[1], ",", output->dims->data[2],
" input h,w = ", input->dims->data[1], ",", input->dims->data[2]));
}
return OkStatus();
}
SamplingType sampling_type_ = SamplingType::UNKNOWN;
};
class SoftmaxOperationParser : public TFLiteOperationParser {
@ -2499,7 +2547,9 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
case kTfLiteBuiltinReshape:
return absl::make_unique<ReshapeOperationParser>();
case kTfLiteBuiltinResizeBilinear:
return absl::make_unique<ResizeBilinearOperationParser>();
return absl::make_unique<Resize2DOperationParser>(SamplingType::BILINEAR);
case kTfLiteBuiltinResizeNearestNeighbor:
return absl::make_unique<Resize2DOperationParser>(SamplingType::NEAREST);
case kTfLiteBuiltinRsqrt:
return absl::make_unique<ElementwiseOperationParser>(
OperationType::RSQRT);

View File

@ -146,8 +146,6 @@ std::string ToString(enum OperationType op) {
return "tanh";
case OperationType::TRANSPOSE:
return "transpose";
case OperationType::UPSAMPLE_2D:
return "upsample_2d";
default:
break;
}
@ -194,7 +192,6 @@ OperationType OperationTypeFromString(const std::string& name) {
{"subtract", OperationType::SUB},
{"tanh", OperationType::TANH},
{"transpose", OperationType::TRANSPOSE},
{"upsample_2d", OperationType::UPSAMPLE_2D},
});
auto op = operations->find(name);
return op == operations->end() ? OperationType::UNKNOWN : op->second;
@ -598,25 +595,24 @@ Padding3D CalculateSamePadding(const BHWDC& input,
}
float CalculateResizeScale(int32_t input_size, int32_t output_size,
const Upsample2DAttributes& attr) {
const Resize2DAttributes& attr) {
return attr.align_corners && input_size > 1 && output_size > 1
? static_cast<float>(input_size - 1) / (output_size - 1)
: static_cast<float>(input_size) / output_size;
}
float CalculateResizeScale(int32_t input_size, int32_t output_size,
const Upsample3DAttributes& attr) {
const Resize3DAttributes& attr) {
return attr.align_corners && input_size > 1 && output_size > 1
? static_cast<float>(input_size - 1) / (output_size - 1)
: static_cast<float>(input_size) / output_size;
}
BHWC CalculateOutputShape(const BHWC& input, const Upsample2DAttributes& attr) {
BHWC CalculateOutputShape(const BHWC& input, const Resize2DAttributes& attr) {
return BHWC(input.b, attr.new_shape.h, attr.new_shape.w, input.c);
}
BHWDC CalculateOutputShape(const BHWDC& input,
const Upsample3DAttributes& attr) {
BHWDC CalculateOutputShape(const BHWDC& input, const Resize3DAttributes& attr) {
return BHWDC(input.b, attr.new_shape.h, attr.new_shape.w, attr.new_shape.d,
input.c);
}

View File

@ -72,7 +72,6 @@ enum class OperationType {
SUB,
TANH,
TRANSPOSE,
UPSAMPLE_2D,
};
std::string ToString(enum OperationType op);
@ -360,25 +359,27 @@ struct MultiplyScalarAttributes {
param;
};
enum class UpsamplingType {
NEAREST = 0,
BILINEAR = 1,
enum class SamplingType {
UNKNOWN = 0,
NEAREST = 1,
BILINEAR = 2,
};
struct Upsample2DAttributes {
struct Resize2DAttributes {
HW new_shape;
UpsamplingType type = UpsamplingType::NEAREST;
SamplingType type = SamplingType::UNKNOWN;
// If true, the centers of the 4 corner pixels of the input and output tensors
// are aligned, preserving the values at the corner pixels. Defaults to false.
bool align_corners = false;
};
struct Upsample3DAttributes {
// TODO(b/147771327): rename to Resize3D
struct Resize3DAttributes {
HWD new_shape;
UpsamplingType type = UpsamplingType::NEAREST;
SamplingType type = SamplingType::NEAREST;
// If true, the centers of the 8 corner pixels of the input and output tensors
// are aligned, preserving the values at the corner pixels. Defaults to false.
@ -386,19 +387,18 @@ struct Upsample3DAttributes {
};
float CalculateResizeScale(int32_t input_size, int32_t output_size,
const Upsample2DAttributes& attr);
const Resize2DAttributes& attr);
float CalculateResizeScale(int32_t input_size, int32_t output_size,
const Upsample3DAttributes& attr);
const Resize3DAttributes& attr);
// @return shape of a tensor after upscale operation is applied to the given
// @return shape of a tensor after scale operation is applied to the given
// input.
BHWC CalculateOutputShape(const BHWC& input, const Upsample2DAttributes& attr);
BHWC CalculateOutputShape(const BHWC& input, const Resize2DAttributes& attr);
// @return shape of a tensor after upscale operation is applied to the given
// @return shape of a tensor after scale operation is applied to the given
// input.
BHWDC CalculateOutputShape(const BHWDC& input,
const Upsample3DAttributes& attr);
BHWDC CalculateOutputShape(const BHWDC& input, const Resize3DAttributes& attr);
enum class PaddingContentType {
ZEROS = 0,

View File

@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -84,7 +84,7 @@ std::unique_ptr<SequenceTransformation> NewRemoveSingleInputAdd() {
}
std::unique_ptr<SequenceTransformation> NewRemoveDegenerateUpsampling() {
auto type = ToString(OperationType::UPSAMPLE_2D);
auto type = ToString(OperationType::RESIZE);
return absl::make_unique<RemoveOperation>(
[type](GraphFloat32* graph, Node* node) {
if (node->operation.type != type) {

View File

@ -147,10 +147,10 @@ TEST(RemoveDegenerateUpsampling, Smoke) {
Value<TensorRef<BHWC>>* output;
ASSERT_TRUE(AddOutput(&graph, node_to_remove, &output).ok());
output->tensor.shape = BHWC(1, 5, 5, 1);
node_to_remove->operation.type = ToString(OperationType::UPSAMPLE_2D);
Upsample2DAttributes attr;
node_to_remove->operation.type = ToString(OperationType::RESIZE);
Resize2DAttributes attr;
attr.new_shape = HW(5, 5);
attr.type = UpsamplingType::BILINEAR;
attr.type = SamplingType::BILINEAR;
node_to_remove->operation.attributes = attr;
Value<TensorRef<BHWC>>* link;

View File

@ -630,9 +630,9 @@ cc_test(
)
cc_library(
name = "upsampling_bilinear",
srcs = ["upsampling_bilinear.cc"],
hdrs = ["upsampling_bilinear.h"],
name = "resize",
srcs = ["resize.cc"],
hdrs = ["resize.h"],
deps = [
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:status",
@ -644,15 +644,15 @@ cc_library(
)
cc_test(
name = "upsampling_bilinear_test",
srcs = ["upsampling_bilinear_test.cc"],
name = "resize_test",
srcs = ["resize_test.cc"],
tags = tf_gpu_tests_tags() + [
"notap",
"tflite_not_portable_ios",
],
deps = [
":resize",
":test_util",
":upsampling_bilinear",
"//tensorflow/lite/delegates/gpu/common:operations",
"@com_google_googletest//:gtest",
],
@ -673,10 +673,10 @@ TFLITE_GPU_BINARY_RELEASE_OPERATORS = [
"relu",
"mean",
"reshape",
"resize",
"slice",
"softmax",
"transpose_conv",
"upsampling_bilinear",
]
NON_TFLITE_GPU_BINARY_RELEASE_OPERATORS = [

View File

@ -42,10 +42,10 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/gl/kernels/prelu.h"
#include "tensorflow/lite/delegates/gpu/gl/kernels/relu.h"
#include "tensorflow/lite/delegates/gpu/gl/kernels/reshape.h"
#include "tensorflow/lite/delegates/gpu/gl/kernels/resize.h"
#include "tensorflow/lite/delegates/gpu/gl/kernels/slice.h"
#include "tensorflow/lite/delegates/gpu/gl/kernels/softmax.h"
#include "tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.h"
#include "tensorflow/lite/delegates/gpu/gl/kernels/upsampling_bilinear.h"
#ifndef TFLITE_GPU_BINARY_RELEASE
#include "tensorflow/lite/delegates/gpu/gl/kernels/max_unpooling.h"
@ -87,10 +87,10 @@ class Registry : public NodeShader {
insert_op(Type::POOLING_2D, NewPoolingNodeShader);
insert_op(Type::PRELU, NewPReLUNodeShader);
insert_op(Type::RELU, NewReLUNodeShader);
insert_op(Type::RESIZE, NewResizeNodeShader);
insert_op(Type::RESHAPE, NewReshapeNodeShader);
insert_op(Type::SLICE, NewSliceNodeShader);
insert_op(Type::SOFTMAX, NewSoftmaxNodeShader);
insert_op(Type::UPSAMPLE_2D, NewUpsamplingNodeShader);
insert_elementwise_op(Type::ABS);
insert_elementwise_op(Type::COS);

View File

@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/gl/kernels/upsampling_bilinear.h"
#include "tensorflow/lite/delegates/gpu/gl/kernels/resize.h"
#include <algorithm>
#include <cstdint>
@ -22,25 +22,25 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/gl/variable.h"
namespace tflite {
namespace gpu {
namespace gl {
namespace {
class UpsamplingBilinear : public NodeShader {
class Resize : public NodeShader {
public:
UpsamplingBilinear() {}
Resize() {}
Status GenerateCode(const GenerationContext& ctx,
GeneratedCode* generated_code) const final {
auto input = ctx.graph->FindInputs(ctx.node->id)[0];
auto output = ctx.graph->FindOutputs(ctx.node->id)[0];
auto attr =
absl::any_cast<Upsample2DAttributes>(ctx.node->operation.attributes);
absl::any_cast<Resize2DAttributes>(ctx.node->operation.attributes);
if (input->tensor.shape.w > output->tensor.shape.w ||
input->tensor.shape.h > output->tensor.shape.h) {
@ -54,9 +54,6 @@ class UpsamplingBilinear : public NodeShader {
if (input->tensor.shape.c != output->tensor.shape.c) {
return InvalidArgumentError("Input/output channels mismatch.");
}
if (attr.type != UpsamplingType::BILINEAR) {
return UnimplementedError("Upsample2D supports only bilinear type.");
}
if (input->tensor.shape.h == 1 && input->tensor.shape.w == 1) {
// Copy a single element from input.
*generated_code = {
@ -81,23 +78,31 @@ class UpsamplingBilinear : public NodeShader {
output->tensor.shape.h, attr))},
};
std::string source = R"(
vec2 coord = vec2(gid.xy) * $scale_factor$;
std::string source;
if (attr.type == SamplingType::BILINEAR) {
source = R"(
vec2 coord = vec2(gid.xy) * $scale_factor$;
ivec2 borders = ivec2($input_data_0_w$, $input_data_0_h$) - ivec2(1, 1);
ivec4 st;
st.xy = ivec2(coord);
st.zw = min(st.xy + ivec2(1, 1), borders);
ivec2 borders = ivec2($input_data_0_w$, $input_data_0_h$) - ivec2(1, 1);
ivec4 st;
st.xy = ivec2(coord);
st.zw = min(st.xy + ivec2(1, 1), borders);
vec2 t = coord - vec2(st.xy); //interpolating factors
vec2 t = coord - vec2(st.xy); //interpolating factors
vec4 tex11 = $input_data_0[st.x, st.y, gid.z]$;
vec4 tex21 = $input_data_0[st.z, st.y, gid.z]$;
vec4 tex12 = $input_data_0[st.x, st.w, gid.z]$;
vec4 tex22 = $input_data_0[st.z, st.w, gid.z]$;
vec4 tex11 = $input_data_0[st.x, st.y, gid.z]$;
vec4 tex21 = $input_data_0[st.z, st.y, gid.z]$;
vec4 tex12 = $input_data_0[st.x, st.w, gid.z]$;
vec4 tex22 = $input_data_0[st.z, st.w, gid.z]$;
value_0 = mix(mix(tex11, tex21, t.x), mix(tex12, tex22, t.x), t.y);
)";
value_0 = mix(mix(tex11, tex21, t.x), mix(tex12, tex22, t.x), t.y);)";
} else if (attr.type == SamplingType::NEAREST) {
source = R"(
ivec2 coord = ivec2(vec2(gid.xy) * $scale_factor$);
value_0 = $input_data_0[coord.x, coord.y, gid.z]$;
)";
} else {
return InvalidArgumentError("Unknown sampling type");
}
*generated_code = {
/*parameters=*/std::move(parameters),
/*objects=*/{},
@ -114,8 +119,8 @@ class UpsamplingBilinear : public NodeShader {
} // namespace
std::unique_ptr<NodeShader> NewUpsamplingNodeShader() {
return absl::make_unique<UpsamplingBilinear>();
std::unique_ptr<NodeShader> NewResizeNodeShader() {
return absl::make_unique<Resize>();
}
} // namespace gl

View File

@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_UPSAMPLING_BILINEAR_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_UPSAMPLING_BILINEAR_H_
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_RESIZE_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_RESIZE_H_
#include <memory>
@ -25,10 +25,10 @@ namespace tflite {
namespace gpu {
namespace gl {
std::unique_ptr<NodeShader> NewUpsamplingNodeShader();
std::unique_ptr<NodeShader> NewResizeNodeShader();
} // namespace gl
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_UPSAMPLING_BILINEAR_H_
#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_RESIZE_H_

View File

@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,9 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/gl/kernels/upsampling_bilinear.h"
#include <vector>
#include "tensorflow/lite/delegates/gpu/gl/kernels/resize.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
@ -30,7 +28,7 @@ namespace gpu {
namespace gl {
namespace {
TEST(UpsamplingBilinearTest, 1x1x2To2x2x2) {
TEST(ResizeTest, Bilinear1x1x2To2x2x2) {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
@ -41,21 +39,21 @@ TEST(UpsamplingBilinearTest, 1x1x2To2x2x2) {
output.ref = 1;
output.shape = BHWC(1, 2, 2, 2);
Upsample2DAttributes attr;
Resize2DAttributes attr;
attr.align_corners = true;
attr.new_shape = HW(2, 2);
attr.type = UpsamplingType::BILINEAR;
attr.type = SamplingType::BILINEAR;
SingleOpModel model({ToString(OperationType::UPSAMPLE_2D), attr}, {input},
SingleOpModel model({ToString(OperationType::RESIZE), attr}, {input},
{output});
ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0}));
ASSERT_OK(model.Invoke(*NewUpsamplingNodeShader()));
ASSERT_OK(model.Invoke(*NewResizeNodeShader()));
EXPECT_THAT(
model.GetOutput(0),
Pointwise(FloatNear(1e-6), {1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0}));
}
TEST(UpsamplingBilinearTest, 1x2x1To1x4x1) {
TEST(ResizeTest, Bilinear1x2x1To1x4x1) {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
@ -66,20 +64,20 @@ TEST(UpsamplingBilinearTest, 1x2x1To1x4x1) {
output.ref = 1;
output.shape = BHWC(1, 1, 4, 1);
Upsample2DAttributes attr;
Resize2DAttributes attr;
attr.align_corners = false;
attr.new_shape = HW(1, 4);
attr.type = UpsamplingType::BILINEAR;
attr.type = SamplingType::BILINEAR;
SingleOpModel model({ToString(OperationType::UPSAMPLE_2D), attr}, {input},
SingleOpModel model({ToString(OperationType::RESIZE), attr}, {input},
{output});
ASSERT_TRUE(model.PopulateTensor(0, {1.0, 4.0}));
ASSERT_OK(model.Invoke(*NewUpsamplingNodeShader()));
ASSERT_OK(model.Invoke(*NewResizeNodeShader()));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {1.0, 2.5, 4.0, 4.0}));
}
TEST(UpsamplingBilinearTest, 2x2x1To4x4x1) {
TEST(ResizeTest, Bilinear2x2x1To4x4x1) {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
@ -90,21 +88,46 @@ TEST(UpsamplingBilinearTest, 2x2x1To4x4x1) {
output.ref = 1;
output.shape = BHWC(1, 4, 4, 1);
Upsample2DAttributes attr;
Resize2DAttributes attr;
attr.align_corners = false;
attr.new_shape = HW(4, 4);
attr.type = UpsamplingType::BILINEAR;
attr.type = SamplingType::BILINEAR;
SingleOpModel model({ToString(OperationType::UPSAMPLE_2D), attr}, {input},
SingleOpModel model({ToString(OperationType::RESIZE), attr}, {input},
{output});
ASSERT_TRUE(model.PopulateTensor(0, {1.0, 4.0, 6.0, 8.0}));
ASSERT_OK(model.Invoke(*NewUpsamplingNodeShader()));
ASSERT_OK(model.Invoke(*NewResizeNodeShader()));
EXPECT_THAT(
model.GetOutput(0),
Pointwise(FloatNear(1e-6), {1.0, 2.5, 4.0, 4.0, 3.5, 4.75, 6.0, 6.0, 6.0,
7.0, 8.0, 8.0, 6.0, 7.0, 8.0, 8.0}));
}
TEST(ResizeTest, Nearest1x2x1To2x4x1) {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 1, 2, 1);
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 2;
output.shape = BHWC(1, 2, 4, 1);
Resize2DAttributes attr;
attr.align_corners = false;
attr.new_shape = HW(2, 4);
attr.type = SamplingType::NEAREST;
SingleOpModel model({ToString(OperationType::RESIZE), attr}, {input},
{output});
ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0}));
ASSERT_OK(model.Invoke(*NewResizeNodeShader()));
EXPECT_THAT(
model.GetOutput(0),
Pointwise(FloatNear(1e-6), {1.0, 1.0, 2.0, 2.0, 1.0, 1.0, 2.0, 2.0}));
}
} // namespace
} // namespace gl
} // namespace gpu

View File

@ -273,10 +273,10 @@ objc_library(
"//tensorflow/lite/delegates/gpu/metal/kernels:prelu_test.mm",
"//tensorflow/lite/delegates/gpu/metal/kernels:relu_test.mm",
"//tensorflow/lite/delegates/gpu/metal/kernels:reshape_test.mm",
"//tensorflow/lite/delegates/gpu/metal/kernels:resize_test.mm",
"//tensorflow/lite/delegates/gpu/metal/kernels:slice_test.mm",
"//tensorflow/lite/delegates/gpu/metal/kernels:softmax_test.mm",
"//tensorflow/lite/delegates/gpu/metal/kernels:transpose_conv_test.mm",
"//tensorflow/lite/delegates/gpu/metal/kernels:upsample_test.mm",
],
hdrs = [
],

View File

@ -40,10 +40,10 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/metal/kernels/prelu.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/relu.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/reshape.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/resize.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/slice.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/softmax.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/upsample.h"
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
namespace tflite {
@ -232,6 +232,11 @@ Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
graph, node_id, inputs[0], outputs[0],
absl::any_cast<ReshapeAttributes>(node->operation.attributes));
break;
case OperationType::RESIZE:
*tasks = Resize(
node_id, inputs[0], outputs[0],
absl::any_cast<Resize2DAttributes>(node->operation.attributes));
break;
case OperationType::SLICE:
*tasks =
Slice(node_id, inputs[0], outputs[0],
@ -245,11 +250,6 @@ Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
*tasks = SelectSoftmax(graph, node_id, inputs[0], outputs[0]);
break;
}
case OperationType::UPSAMPLE_2D:
*tasks = Upsample(
node_id, inputs[0], outputs[0],
absl::any_cast<Upsample2DAttributes>(node->operation.attributes));
break;
case OperationType::ABS:
case OperationType::COS:
case OperationType::HARD_SWISH:
@ -274,7 +274,6 @@ Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
case OperationType::CONST:
case OperationType::LSTM:
case OperationType::MUL:
case OperationType::RESIZE:
case OperationType::SPACE_TO_BATCH:
case OperationType::TRANSPOSE:
case OperationType::UNKNOWN:

View File

@ -27,10 +27,10 @@ cc_library(
":prelu",
":relu",
":reshape",
":resize",
":slice",
":softmax",
":transpose_conv",
":upsample",
],
)
@ -545,6 +545,44 @@ ios_unit_test(
deps = [":relu_test_lib"],
)
cc_library(
name = "resize",
srcs = ["resize.cc"],
hdrs = ["resize.h"],
deps = [
":util",
"//tensorflow/lite/delegates/gpu/common:model",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:tensor",
"//tensorflow/lite/delegates/gpu/common:util",
"//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
"//tensorflow/lite/delegates/gpu/metal:runtime_options",
"@com_google_absl//absl/types:variant",
],
)
objc_library(
name = "resize_test_lib",
testonly = 1,
srcs = ["resize_test.mm"],
sdk_frameworks = ["XCTest"],
deps = [
":resize",
":test_util",
],
)
ios_unit_test(
name = "resize_test",
testonly = 1,
minimum_os_version = "10.0",
tags = tf_gpu_tests_tags() + [
"notap",
"tflite_not_portable_android",
],
deps = [":resize_test_lib"],
)
cc_library(
name = "reshape",
srcs = ["reshape.cc"],
@ -699,44 +737,6 @@ ios_unit_test(
deps = [":transpose_conv_test_lib"],
)
cc_library(
name = "upsample",
srcs = ["upsample.cc"],
hdrs = ["upsample.h"],
deps = [
"//tensorflow/lite/delegates/gpu/common:model",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:shape",
"//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/common:util",
"//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
"//tensorflow/lite/delegates/gpu/metal:runtime_options",
"@com_google_absl//absl/strings",
],
)
objc_library(
name = "upsample_test_lib",
testonly = 1,
srcs = ["upsample_test.mm"],
sdk_frameworks = ["XCTest"],
deps = [
":test_util",
":upsample",
],
)
ios_unit_test(
name = "upsample_test",
testonly = 1,
minimum_os_version = "10.0",
tags = tf_gpu_tests_tags() + [
"notap",
"tflite_not_portable_android",
],
deps = [":upsample_test_lib"],
)
cc_library(
name = "util",
srcs = ["util.cc"],

View File

@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/kernels/upsample.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/resize.h"
#include <map>
#include <memory>
@ -31,14 +31,8 @@ namespace tflite {
namespace gpu {
namespace metal {
std::vector<ComputeTaskDescriptorPtr> Upsample(
int id, ValueId input_id, ValueId output_id,
const Upsample2DAttributes& attr) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = false;
desc->shader_source = R"(
std::string GetResizeBilinearCode() {
return R"(
#include <metal_stdlib>
using namespace metal;
$0
@ -70,6 +64,46 @@ std::vector<ComputeTaskDescriptorPtr> Upsample(
output_buffer[linear_index] = value;
}
)";
}
std::string GetResizeNearestCode() {
return R"(
#include <metal_stdlib>
using namespace metal;
$0
kernel void ComputeFunction(
$1
uint3 gid[[thread_position_in_grid]]) {
if (int(gid.x) >= size.z || int(gid.y) >= size.w) {
return;
}
const int2 coord = int2(float2(gid.xy) * scale);
const int src_index = (gid.z * size.y + coord.y) * size.x + coord.x;
FLT4 value = src_buffer[src_index];
const int linear_index = (gid.z * size.w + gid.y) * size.z + gid.x;
$2
output_buffer[linear_index] = value;
}
)";
}
std::vector<ComputeTaskDescriptorPtr> Resize(int id, ValueId input_id,
ValueId output_id,
const Resize2DAttributes& attr) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = false;
switch (attr.type) {
case SamplingType::BILINEAR:
desc->shader_source = GetResizeBilinearCode();
break;
case SamplingType::NEAREST:
desc->shader_source = GetResizeNearestCode();
break;
default:
// Unknown sampling type
return {};
}
desc->input_buffers = {
{input_id, "device FLT4* const src_buffer"},

View File

@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_UPSAMPLE_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_UPSAMPLE_H_
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_RESIZE_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_RESIZE_H_
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
@ -24,12 +24,12 @@ namespace tflite {
namespace gpu {
namespace metal {
std::vector<ComputeTaskDescriptorPtr> Upsample(
int id, ValueId input_id, ValueId output_id,
const Upsample2DAttributes& attr);
std::vector<ComputeTaskDescriptorPtr> Resize(int id, ValueId input_id,
ValueId output_id,
const Resize2DAttributes& attr);
} // namespace metal
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_UPSAMPLE_H_
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_RESIZE_H_

View File

@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/kernels/upsample.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/resize.h"
#import <XCTest/XCTest.h>
@ -32,21 +32,21 @@ using ::tflite::gpu::BHWC;
using ::tflite::gpu::DataType;
using ::tflite::gpu::HW;
using ::tflite::gpu::OperationType;
using ::tflite::gpu::Resize2DAttributes;
using ::tflite::gpu::SamplingType;
using ::tflite::gpu::TensorRef;
using ::tflite::gpu::Upsample2DAttributes;
using ::tflite::gpu::UpsamplingType;
using ::tflite::gpu::metal::CompareVectors;
using ::tflite::gpu::metal::SingleOpModel;
@interface UpsampleTest : XCTestCase
@interface ResizeTest : XCTestCase
@end
@implementation UpsampleTest
@implementation ResizeTest
- (void)setUp {
[super setUp];
}
- (void)testUpsamplingBilinear1x1x2To2x2x2 {
- (void)testResizeBilinear1x1x2To2x2x2 {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
@ -57,12 +57,12 @@ using ::tflite::gpu::metal::SingleOpModel;
output.ref = 1;
output.shape = BHWC(1, 2, 2, 2);
Upsample2DAttributes attr;
Resize2DAttributes attr;
attr.align_corners = true;
attr.new_shape = HW(2, 2);
attr.type = UpsamplingType::BILINEAR;
attr.type = SamplingType::BILINEAR;
SingleOpModel model({ToString(OperationType::UPSAMPLE_2D), attr}, {input}, {output});
SingleOpModel model({ToString(OperationType::RESIZE), attr}, {input}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0}));
auto status = model.Invoke();
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
@ -70,7 +70,7 @@ using ::tflite::gpu::metal::SingleOpModel;
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
}
- (void)testUpsamplingBilinear1x2x1To1x4x1 {
- (void)testResizeBilinear1x2x1To1x4x1 {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
@ -81,12 +81,12 @@ using ::tflite::gpu::metal::SingleOpModel;
output.ref = 1;
output.shape = BHWC(1, 1, 4, 1);
Upsample2DAttributes attr;
Resize2DAttributes attr;
attr.align_corners = false;
attr.new_shape = HW(1, 4);
attr.type = UpsamplingType::BILINEAR;
attr.type = SamplingType::BILINEAR;
SingleOpModel model({ToString(OperationType::UPSAMPLE_2D), attr}, {input}, {output});
SingleOpModel model({ToString(OperationType::RESIZE), attr}, {input}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1.0, 4.0}));
auto status = model.Invoke();
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
@ -94,7 +94,7 @@ using ::tflite::gpu::metal::SingleOpModel;
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
}
- (void)testUpsamplingBilinear2x2x1To4x4x1 {
- (void)testResizeBilinear2x2x1To4x4x1 {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
@ -105,12 +105,12 @@ using ::tflite::gpu::metal::SingleOpModel;
output.ref = 1;
output.shape = BHWC(1, 4, 4, 1);
Upsample2DAttributes attr;
Resize2DAttributes attr;
attr.align_corners = false;
attr.new_shape = HW(4, 4);
attr.type = UpsamplingType::BILINEAR;
attr.type = SamplingType::BILINEAR;
SingleOpModel model({ToString(OperationType::UPSAMPLE_2D), attr}, {input}, {output});
SingleOpModel model({ToString(OperationType::RESIZE), attr}, {input}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1.0, 4.0, 6.0, 8.0}));
auto status = model.Invoke();
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
@ -120,4 +120,28 @@ using ::tflite::gpu::metal::SingleOpModel;
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
}
- (void)testResizeNearest1x2x1To2x4x1 {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 1, 2, 1);
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 2;
output.shape = BHWC(1, 2, 4, 1);
Resize2DAttributes attr;
attr.align_corners = false;
attr.new_shape = HW(2, 4);
attr.type = SamplingType::NEAREST;
SingleOpModel model({ToString(OperationType::RESIZE), attr}, {input}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0}));
auto status = model.Invoke();
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
status = CompareVectors({1.0, 1.0, 2.0, 2.0, 1.0, 1.0, 2.0, 2.0}, model.GetOutput(0), 1e-6f);
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
}
@end