Reshape and ReshapeX4 tasks modified to be Metal compatible.

Added reshape_test_util with unified tests.
Added Metal reshape unit tests.

PiperOrigin-RevId: 352853802
Change-Id: I995f72b20e9fc4feb91313bc04475ea323c93e56
This commit is contained in:
Raman Sarokin 2021-01-20 12:45:46 -08:00 committed by TensorFlower Gardener
parent 053d50118d
commit 9e5c6794af
10 changed files with 173 additions and 79 deletions

View File

@ -477,7 +477,7 @@ cc_test(
":cl_test",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common/tasks:reshape",
"//tensorflow/lite/delegates/gpu/common/tasks:reshape_test_util",
"@com_google_googletest//:gtest_main",
],
)
@ -494,7 +494,7 @@ cc_test(
":cl_test",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common/tasks:reshapex4",
"//tensorflow/lite/delegates/gpu/common/tasks:reshape_test_util",
"@com_google_googletest//:gtest_main",
],
)

View File

@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/common/tasks/reshape.h"
#include <vector>
#include <gmock/gmock.h>
@ -22,9 +20,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
using ::testing::FloatNear;
using ::testing::Pointwise;
#include "tensorflow/lite/delegates/gpu/common/tasks/reshape_test_util.h"
namespace tflite {
namespace gpu {
@ -32,30 +28,8 @@ namespace cl {
namespace {
TEST_F(OpenCLOperationTest, Reshape) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 2, 1, 3);
src_tensor.data = {half(0.5f), half(-1.1f), half(-2.2f),
half(3.1f), half(1.2f), half(2.9f)};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation = CreateReshape(op_def);
ASSERT_OK(ExecuteGPUOperation(
src_tensor, creation_context_,
absl::make_unique<GPUOperation>(std::move(operation)),
BHWC(1, 3, 1, 2), &dst_tensor));
EXPECT_THAT(
dst_tensor.data,
Pointwise(FloatNear(0.0f), {half(0.5f), half(-1.1f), half(-2.2f),
half(3.1f), half(1.2f), half(2.9f)}));
}
}
auto status = ReshapeTest(&exec_env_);
ASSERT_TRUE(status.ok()) << status.error_message();
}
} // namespace

View File

@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/common/tasks/reshapex4.h"
#include <vector>
#include <gmock/gmock.h>
@ -22,9 +20,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
using ::testing::FloatNear;
using ::testing::Pointwise;
#include "tensorflow/lite/delegates/gpu/common/tasks/reshape_test_util.h"
namespace tflite {
namespace gpu {
@ -32,30 +28,8 @@ namespace cl {
namespace {
TEST_F(OpenCLOperationTest, Reshapex4) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 1, 1, 8);
src_tensor.data = {half(0.5f), half(-1.1f), half(-2.2f), half(3.1f),
half(1.2f), half(2.9f), half(4.2f), half(-1.9f)};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation = CreateReshapex4(op_def);
ASSERT_OK(ExecuteGPUOperation(
src_tensor, creation_context_,
absl::make_unique<GPUOperation>(std::move(operation)),
BHWC(1, 1, 2, 4), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(0.0f),
{half(0.5f), half(-1.1f), half(-2.2f), half(3.1f),
half(1.2f), half(2.9f), half(4.2f), half(-1.9f)}));
}
}
auto status = Reshapex4Test(&exec_env_);
ASSERT_TRUE(status.ok()) << status.error_message();
}
} // namespace

View File

@ -573,6 +573,20 @@ cc_library(
],
)
cc_library(
name = "reshape_test_util",
testonly = 1,
srcs = ["reshape_test_util.cc"],
hdrs = ["reshape_test_util.h"],
deps = [
":reshape",
":reshapex4",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common/task:testing_util",
],
)
cc_library(
name = "resize",
srcs = ["resize.cc"],

View File

@ -25,27 +25,26 @@ namespace gpu {
namespace {
std::string GetReshapeCode(const OperationDef& op_def) {
std::string c;
c += "__kernel void main_function(\n";
c += "$0) {\n";
c += "MAIN_FUNCTION($0) {\n";
if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
c += " int linear_id = get_global_id(0);\n";
c += " int linear_id = GLOBAL_ID_0;\n";
c += " int X = linear_id / args.dst_tensor.Batch();\n";
c += " int B = linear_id % args.dst_tensor.Batch();\n";
c += " args.dst_tensor.SetBatchRef(B);\n";
} else {
c += " int X = get_global_id(0);\n";
c += " int X = GLOBAL_ID_0;\n";
}
c += " int Y = get_global_id(1);\n";
c += " int Z = get_global_id(2);\n";
c += " int Y = GLOBAL_ID_1;\n";
c += " int Z = GLOBAL_ID_2;\n";
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
"Z >= args.dst_tensor.Slices()) { \n";
c += " return; \n";
c += " } \n";
c += " FLT temps[4];\n";
c += " temps[0] = (FLT)(0.0f);\n";
c += " temps[1] = (FLT)(0.0f);\n";
c += " temps[2] = (FLT)(0.0f);\n";
c += " temps[3] = (FLT)(0.0f);\n";
c += " temps[0] = INIT_FLT(0.0f);\n";
c += " temps[1] = INIT_FLT(0.0f);\n";
c += " temps[2] = INIT_FLT(0.0f);\n";
c += " temps[3] = INIT_FLT(0.0f);\n";
if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
c += " int base = B;\n";
} else {
@ -73,7 +72,11 @@ std::string GetReshapeCode(const OperationDef& op_def) {
c += " temps[i] = t_ar[src_sub_ch];\n";
c += " }\n";
c += " }\n";
c += " FLT4 result = (FLT4)(temps[0], temps[1], temps[2], temps[3]);\n";
c += " FLT4 result;\n";
c += " result.x = temps[0];\n";
c += " result.y = temps[1];\n";
c += " result.z = temps[2];\n";
c += " result.w = temps[3];\n";
c += " args.dst_tensor.Write(result, X, Y, Z);\n";
c += "}\n";
return c;

View File

@ -0,0 +1,83 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/common/tasks/reshape_test_util.h"
#include <vector>
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/testing_util.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/reshape.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/reshapex4.h"
namespace tflite {
namespace gpu {
absl::Status ReshapeTest(TestExecutionEnvironment* env) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 2, 1, 3);
src_tensor.data = {half(0.5f), half(-1.1f), half(-2.2f),
half(3.1f), half(1.2f), half(2.9f)};
for (auto storage : env->GetSupportedStorages()) {
for (auto precision : env->GetSupportedPrecisions()) {
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation = CreateReshape(op_def);
RETURN_IF_ERROR(env->ExecuteGPUOperation(
src_tensor, absl::make_unique<GPUOperation>(std::move(operation)),
BHWC(1, 3, 1, 2), &dst_tensor));
RETURN_IF_ERROR(PointWiseNear({half(0.5f), half(-1.1f), half(-2.2f),
half(3.1f), half(1.2f), half(2.9f)},
dst_tensor.data, 0.0));
}
}
return absl::OkStatus();
}
absl::Status Reshapex4Test(TestExecutionEnvironment* env) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 1, 1, 8);
src_tensor.data = {half(0.5f), half(-1.1f), half(-2.2f), half(3.1f),
half(1.2f), half(2.9f), half(4.2f), half(-1.9f)};
for (auto storage : env->GetSupportedStorages()) {
for (auto precision : env->GetSupportedPrecisions()) {
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation = CreateReshapex4(op_def);
RETURN_IF_ERROR(env->ExecuteGPUOperation(
src_tensor, absl::make_unique<GPUOperation>(std::move(operation)),
BHWC(1, 1, 2, 4), &dst_tensor));
RETURN_IF_ERROR(
PointWiseNear({half(0.5f), half(-1.1f), half(-2.2f), half(3.1f),
half(1.2f), half(2.9f), half(4.2f), half(-1.9f)},
dst_tensor.data, 0.0f));
}
}
return absl::OkStatus();
}
} // namespace gpu
} // namespace tflite

View File

@ -0,0 +1,31 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_RESHAPE_TEST_UTIL_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_RESHAPE_TEST_UTIL_H_
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/testing_util.h"
namespace tflite {
namespace gpu {
absl::Status ReshapeTest(TestExecutionEnvironment* env);
absl::Status Reshapex4Test(TestExecutionEnvironment* env);
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_RESHAPE_TEST_UTIL_H_

View File

@ -26,18 +26,17 @@ namespace {
std::string GetReshapeCode(const OperationDef& op_def) {
std::string c;
c += "__kernel void main_function(\n";
c += "$0) {\n";
c += "MAIN_FUNCTION($0) {\n";
if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
c += " int linear_id = get_global_id(0);\n";
c += " int linear_id = GLOBAL_ID_0;\n";
c += " int X = linear_id / args.dst_tensor.Batch();\n";
c += " int B = linear_id % args.dst_tensor.Batch();\n";
c += " args.dst_tensor.SetBatchRef(B);\n";
} else {
c += " int X = get_global_id(0);\n";
c += " int X = GLOBAL_ID_0;\n";
}
c += " int Y = get_global_id(1);\n";
c += " int Z = get_global_id(2);\n";
c += " int Y = GLOBAL_ID_1;\n";
c += " int Z = GLOBAL_ID_2;\n";
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
"Z >= args.dst_tensor.Slices()) { \n";
c += " return; \n";

View File

@ -562,6 +562,7 @@ objc_library(
deps = [
":reshape",
":test_util",
"//tensorflow/lite/delegates/gpu/common/tasks:reshape_test_util",
],
)
@ -883,6 +884,7 @@ objc_library(
"//tensorflow/lite/delegates/gpu/common/tasks:prelu_test_util",
"//tensorflow/lite/delegates/gpu/common/tasks:quantize_and_dequantize_test_util",
"//tensorflow/lite/delegates/gpu/common/tasks:relu_test_util",
"//tensorflow/lite/delegates/gpu/common/tasks:reshape_test_util",
"//tensorflow/lite/delegates/gpu/common/tasks:space_to_depth_test_util",
"//tensorflow/lite/delegates/gpu/common/tasks:strided_slice_test_util",
"//tensorflow/lite/delegates/gpu/common/tasks:transpose_test_util",

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/reshape_test_util.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
@ -36,10 +37,13 @@ using ::tflite::gpu::TensorRef;
using ::tflite::gpu::metal::CompareVectors;
using ::tflite::gpu::metal::SingleOpModel;
@interface ReshapeTest : XCTestCase
@interface ReshapeMetalTest : XCTestCase
@end
@implementation ReshapeTest
@implementation ReshapeMetalTest {
tflite::gpu::metal::MetalExecutionEnvironment exec_env_;
}
- (void)setUp {
[super setUp];
}
@ -154,4 +158,14 @@ using ::tflite::gpu::metal::SingleOpModel;
@"%s", std::string(status.message()).c_str());
}
- (void)testReshape {
auto status = ReshapeTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testReshapex4 {
auto status = Reshapex4Test(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
@end