diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD index 31a694604c4..b7a5ad8ea70 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD @@ -477,7 +477,7 @@ cc_test( ":cl_test", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", - "//tensorflow/lite/delegates/gpu/common/tasks:reshape", + "//tensorflow/lite/delegates/gpu/common/tasks:reshape_test_util", "@com_google_googletest//:gtest_main", ], ) @@ -494,7 +494,7 @@ cc_test( ":cl_test", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", - "//tensorflow/lite/delegates/gpu/common/tasks:reshapex4", + "//tensorflow/lite/delegates/gpu/common/tasks:reshape_test_util", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reshape_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/reshape_test.cc index 93321b8eb0a..4949a18c2a3 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/reshape_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/reshape_test.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/delegates/gpu/common/tasks/reshape.h" - #include #include @@ -22,9 +20,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/status.h" - -using ::testing::FloatNear; -using ::testing::Pointwise; +#include "tensorflow/lite/delegates/gpu/common/tasks/reshape_test_util.h" namespace tflite { namespace gpu { @@ -32,30 +28,8 @@ namespace cl { namespace { TEST_F(OpenCLOperationTest, Reshape) { - TensorFloat32 src_tensor; - src_tensor.shape = BHWC(1, 2, 1, 3); - src_tensor.data = {half(0.5f), half(-1.1f), half(-2.2f), - half(3.1f), half(1.2f), half(2.9f)}; - - for (auto storage : env_.GetSupportedStorages()) { - for (auto precision : env_.GetSupportedPrecisions()) { - OperationDef op_def; - op_def.precision = precision; - auto data_type = DeduceDataTypeFromPrecision(precision); - op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); - op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); - TensorFloat32 dst_tensor; - GPUOperation operation = CreateReshape(op_def); - ASSERT_OK(ExecuteGPUOperation( - src_tensor, creation_context_, - absl::make_unique(std::move(operation)), - BHWC(1, 3, 1, 2), &dst_tensor)); - EXPECT_THAT( - dst_tensor.data, - Pointwise(FloatNear(0.0f), {half(0.5f), half(-1.1f), half(-2.2f), - half(3.1f), half(1.2f), half(2.9f)})); - } - } + auto status = ReshapeTest(&exec_env_); + ASSERT_TRUE(status.ok()) << status.error_message(); } } // namespace diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4_test.cc index 2e51fc2a281..bb1db97e660 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4_test.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/delegates/gpu/common/tasks/reshapex4.h" - #include #include @@ -22,9 +20,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/status.h" - -using ::testing::FloatNear; -using ::testing::Pointwise; +#include "tensorflow/lite/delegates/gpu/common/tasks/reshape_test_util.h" namespace tflite { namespace gpu { @@ -32,30 +28,8 @@ namespace cl { namespace { TEST_F(OpenCLOperationTest, Reshapex4) { - TensorFloat32 src_tensor; - src_tensor.shape = BHWC(1, 1, 1, 8); - src_tensor.data = {half(0.5f), half(-1.1f), half(-2.2f), half(3.1f), - half(1.2f), half(2.9f), half(4.2f), half(-1.9f)}; - - for (auto storage : env_.GetSupportedStorages()) { - for (auto precision : env_.GetSupportedPrecisions()) { - OperationDef op_def; - op_def.precision = precision; - auto data_type = DeduceDataTypeFromPrecision(precision); - op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); - op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); - TensorFloat32 dst_tensor; - GPUOperation operation = CreateReshapex4(op_def); - ASSERT_OK(ExecuteGPUOperation( - src_tensor, creation_context_, - absl::make_unique(std::move(operation)), - BHWC(1, 1, 2, 4), &dst_tensor)); - EXPECT_THAT(dst_tensor.data, - Pointwise(FloatNear(0.0f), - {half(0.5f), half(-1.1f), half(-2.2f), half(3.1f), - half(1.2f), half(2.9f), half(4.2f), half(-1.9f)})); - } - } + auto status = Reshapex4Test(&exec_env_); + ASSERT_TRUE(status.ok()) << status.error_message(); } } // namespace diff --git a/tensorflow/lite/delegates/gpu/common/tasks/BUILD b/tensorflow/lite/delegates/gpu/common/tasks/BUILD index 08c6177fdc2..1e3d40081cc 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/BUILD +++ b/tensorflow/lite/delegates/gpu/common/tasks/BUILD @@ -573,6 +573,20 @@ cc_library( ], ) +cc_library( + name = "reshape_test_util", + testonly = 1, + srcs = ["reshape_test_util.cc"], + hdrs = ["reshape_test_util.h"], + deps = [ + ":reshape", + ":reshapex4", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common/task:testing_util", + ], +) + cc_library( name = "resize", srcs = ["resize.cc"], diff --git a/tensorflow/lite/delegates/gpu/common/tasks/reshape.cc b/tensorflow/lite/delegates/gpu/common/tasks/reshape.cc index 10097720b96..d5bc3ef7bc8 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/reshape.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/reshape.cc @@ -25,27 +25,26 @@ namespace gpu { namespace { std::string GetReshapeCode(const OperationDef& op_def) { std::string c; - c += "__kernel void main_function(\n"; - c += "$0) {\n"; + c += "MAIN_FUNCTION($0) {\n"; if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) { - c += " int linear_id = get_global_id(0);\n"; + c += " int linear_id = GLOBAL_ID_0;\n"; c += " int X = linear_id / args.dst_tensor.Batch();\n"; c += " int B = linear_id % args.dst_tensor.Batch();\n"; c += " args.dst_tensor.SetBatchRef(B);\n"; } else { - c += " int X = get_global_id(0);\n"; + c += " int X = GLOBAL_ID_0;\n"; } - c += " int Y = get_global_id(1);\n"; - c += " int Z = get_global_id(2);\n"; + c += " int Y = GLOBAL_ID_1;\n"; + c += " int Z = GLOBAL_ID_2;\n"; c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || " "Z >= args.dst_tensor.Slices()) { \n"; c += " return; \n"; c += " } \n"; c += " FLT temps[4];\n"; - c += " temps[0] = (FLT)(0.0f);\n"; - c += " temps[1] = (FLT)(0.0f);\n"; - c += " temps[2] = (FLT)(0.0f);\n"; - c += " temps[3] = (FLT)(0.0f);\n"; + c += " temps[0] = INIT_FLT(0.0f);\n"; + c += " temps[1] = INIT_FLT(0.0f);\n"; + c += " temps[2] = INIT_FLT(0.0f);\n"; + c += " temps[3] = INIT_FLT(0.0f);\n"; if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) { c += " int base = B;\n"; } else { @@ -73,7 +72,11 @@ std::string GetReshapeCode(const OperationDef& op_def) { c += " temps[i] = t_ar[src_sub_ch];\n"; c += " }\n"; c += " }\n"; - c += " FLT4 result = (FLT4)(temps[0], temps[1], temps[2], temps[3]);\n"; + c += " FLT4 result;\n"; + c += " result.x = temps[0];\n"; + c += " result.y = temps[1];\n"; + c += " result.z = temps[2];\n"; + c += " result.w = temps[3];\n"; c += " args.dst_tensor.Write(result, X, Y, Z);\n"; c += "}\n"; return c; diff --git a/tensorflow/lite/delegates/gpu/common/tasks/reshape_test_util.cc b/tensorflow/lite/delegates/gpu/common/tasks/reshape_test_util.cc new file mode 100644 index 00000000000..13799d56171 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/tasks/reshape_test_util.cc @@ -0,0 +1,83 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/tasks/reshape_test_util.h" + +#include + +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/testing_util.h" +#include "tensorflow/lite/delegates/gpu/common/tasks/reshape.h" +#include "tensorflow/lite/delegates/gpu/common/tasks/reshapex4.h" + +namespace tflite { +namespace gpu { + +absl::Status ReshapeTest(TestExecutionEnvironment* env) { + TensorFloat32 src_tensor; + src_tensor.shape = BHWC(1, 2, 1, 3); + src_tensor.data = {half(0.5f), half(-1.1f), half(-2.2f), + half(3.1f), half(1.2f), half(2.9f)}; + + for (auto storage : env->GetSupportedStorages()) { + for (auto precision : env->GetSupportedPrecisions()) { + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + GPUOperation operation = CreateReshape(op_def); + RETURN_IF_ERROR(env->ExecuteGPUOperation( + src_tensor, absl::make_unique(std::move(operation)), + BHWC(1, 3, 1, 2), &dst_tensor)); + RETURN_IF_ERROR(PointWiseNear({half(0.5f), half(-1.1f), half(-2.2f), + half(3.1f), half(1.2f), half(2.9f)}, + dst_tensor.data, 0.0)); + } + } + return absl::OkStatus(); +} + +absl::Status Reshapex4Test(TestExecutionEnvironment* env) { + TensorFloat32 src_tensor; + src_tensor.shape = BHWC(1, 1, 1, 8); + src_tensor.data = {half(0.5f), half(-1.1f), half(-2.2f), half(3.1f), + half(1.2f), half(2.9f), half(4.2f), half(-1.9f)}; + + for (auto storage : env->GetSupportedStorages()) { + for (auto precision : env->GetSupportedPrecisions()) { + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + GPUOperation operation = CreateReshapex4(op_def); + RETURN_IF_ERROR(env->ExecuteGPUOperation( + src_tensor, absl::make_unique(std::move(operation)), + BHWC(1, 1, 2, 4), &dst_tensor)); + RETURN_IF_ERROR( + PointWiseNear({half(0.5f), half(-1.1f), half(-2.2f), half(3.1f), + half(1.2f), half(2.9f), half(4.2f), half(-1.9f)}, + dst_tensor.data, 0.0f)); + } + } + return absl::OkStatus(); +} + +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/tasks/reshape_test_util.h b/tensorflow/lite/delegates/gpu/common/tasks/reshape_test_util.h new file mode 100644 index 00000000000..6b2652b687a --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/tasks/reshape_test_util.h @@ -0,0 +1,31 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_RESHAPE_TEST_UTIL_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_RESHAPE_TEST_UTIL_H_ + +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/testing_util.h" + +namespace tflite { +namespace gpu { + +absl::Status ReshapeTest(TestExecutionEnvironment* env); +absl::Status Reshapex4Test(TestExecutionEnvironment* env); + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_RESHAPE_TEST_UTIL_H_ diff --git a/tensorflow/lite/delegates/gpu/common/tasks/reshapex4.cc b/tensorflow/lite/delegates/gpu/common/tasks/reshapex4.cc index f5c481996c8..4ee2e50e492 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/reshapex4.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/reshapex4.cc @@ -26,18 +26,17 @@ namespace { std::string GetReshapeCode(const OperationDef& op_def) { std::string c; - c += "__kernel void main_function(\n"; - c += "$0) {\n"; + c += "MAIN_FUNCTION($0) {\n"; if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) { - c += " int linear_id = get_global_id(0);\n"; + c += " int linear_id = GLOBAL_ID_0;\n"; c += " int X = linear_id / args.dst_tensor.Batch();\n"; c += " int B = linear_id % args.dst_tensor.Batch();\n"; c += " args.dst_tensor.SetBatchRef(B);\n"; } else { - c += " int X = get_global_id(0);\n"; + c += " int X = GLOBAL_ID_0;\n"; } - c += " int Y = get_global_id(1);\n"; - c += " int Z = get_global_id(2);\n"; + c += " int Y = GLOBAL_ID_1;\n"; + c += " int Z = GLOBAL_ID_2;\n"; c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || " "Z >= args.dst_tensor.Slices()) { \n"; c += " return; \n"; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD index acfa58b9db2..af59535e6ad 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD +++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD @@ -562,6 +562,7 @@ objc_library( deps = [ ":reshape", ":test_util", + "//tensorflow/lite/delegates/gpu/common/tasks:reshape_test_util", ], ) @@ -883,6 +884,7 @@ objc_library( "//tensorflow/lite/delegates/gpu/common/tasks:prelu_test_util", "//tensorflow/lite/delegates/gpu/common/tasks:quantize_and_dequantize_test_util", "//tensorflow/lite/delegates/gpu/common/tasks:relu_test_util", + "//tensorflow/lite/delegates/gpu/common/tasks:reshape_test_util", "//tensorflow/lite/delegates/gpu/common/tasks:space_to_depth_test_util", "//tensorflow/lite/delegates/gpu/common/tasks:strided_slice_test_util", "//tensorflow/lite/delegates/gpu/common/tasks:transpose_test_util", diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/reshape_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/reshape_test.mm index 1b36988f5b0..a1979279fac 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/reshape_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/reshape_test.mm @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/tasks/reshape_test_util.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" @@ -36,10 +37,13 @@ using ::tflite::gpu::TensorRef; using ::tflite::gpu::metal::CompareVectors; using ::tflite::gpu::metal::SingleOpModel; -@interface ReshapeTest : XCTestCase +@interface ReshapeMetalTest : XCTestCase @end -@implementation ReshapeTest +@implementation ReshapeMetalTest { + tflite::gpu::metal::MetalExecutionEnvironment exec_env_; +} + - (void)setUp { [super setUp]; } @@ -154,4 +158,14 @@ using ::tflite::gpu::metal::SingleOpModel; @"%s", std::string(status.message()).c_str()); } +- (void)testReshape { + auto status = ReshapeTest(&exec_env_); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); +} + +- (void)testReshapex4 { + auto status = Reshapex4Test(&exec_env_); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); +} + @end