Elementwise tests logic moved to elementwise_test_util that can be reused among different backends.

Added new Elementwise tests in Metal.

PiperOrigin-RevId: 352016582
Change-Id: Idb3c40251aab8f724d8ea436fdb984e2dae177b5
This commit is contained in:
Raman Sarokin 2021-01-15 08:43:22 -08:00 committed by TensorFlower Gardener
parent 0e1f4fd484
commit 4bb669c09c
7 changed files with 1406 additions and 933 deletions

View File

@ -266,7 +266,7 @@ cc_test(
":cl_test",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common/tasks:elementwise",
"//tensorflow/lite/delegates/gpu/common/tasks:elementwise_test_util",
"@com_google_googletest//:gtest_main",
],
)

File diff suppressed because it is too large Load Diff

View File

@ -295,6 +295,19 @@ cc_library(
],
)
cc_library(
name = "elementwise_test_util",
testonly = 1,
srcs = ["elementwise_test_util.cc"],
hdrs = ["elementwise_test_util.h"],
deps = [
":elementwise",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common/task:testing_util",
],
)
cc_library(
name = "fully_connected",
srcs = ["fully_connected.cc"],

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,66 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_ELEMENTWISE_TEST_UTIL_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_ELEMENTWISE_TEST_UTIL_H_
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/testing_util.h"
namespace tflite {
namespace gpu {
absl::Status AbsTest(TestExecutionEnvironment* env);
absl::Status CosTest(TestExecutionEnvironment* env);
absl::Status CopyTest(TestExecutionEnvironment* env);
absl::Status EluTest(TestExecutionEnvironment* env);
absl::Status ExpTest(TestExecutionEnvironment* env);
absl::Status HardSwishTest(TestExecutionEnvironment* env);
absl::Status LogTest(TestExecutionEnvironment* env);
absl::Status NegTest(TestExecutionEnvironment* env);
absl::Status RsqrtTest(TestExecutionEnvironment* env);
absl::Status SigmoidTest(TestExecutionEnvironment* env);
absl::Status SinTest(TestExecutionEnvironment* env);
absl::Status SqrtTest(TestExecutionEnvironment* env);
absl::Status SquareTest(TestExecutionEnvironment* env);
absl::Status TanhTest(TestExecutionEnvironment* env);
absl::Status SubTest(TestExecutionEnvironment* env);
absl::Status SquaredDiffTest(TestExecutionEnvironment* env);
absl::Status DivTest(TestExecutionEnvironment* env);
absl::Status PowTest(TestExecutionEnvironment* env);
absl::Status AddTest(TestExecutionEnvironment* env);
absl::Status MaximumTest(TestExecutionEnvironment* env);
absl::Status MaximumWithScalarTest(TestExecutionEnvironment* env);
absl::Status MaximumWithConstantLinearTensorTest(TestExecutionEnvironment* env);
absl::Status MaximumWithConstantHWCTensorTest(TestExecutionEnvironment* env);
absl::Status MaximumWithConstantHWCTensorBroadcastChannelsTest(
TestExecutionEnvironment* env);
absl::Status MinimumTest(TestExecutionEnvironment* env);
absl::Status MinimumWithScalarTest(TestExecutionEnvironment* env);
absl::Status MulTest(TestExecutionEnvironment* env);
absl::Status MulBroadcastHWTest(TestExecutionEnvironment* env);
absl::Status MulBroadcastChannelsTest(TestExecutionEnvironment* env);
absl::Status SubWithScalarAtFirstPositionTest(TestExecutionEnvironment* env);
absl::Status LessTest(TestExecutionEnvironment* env);
absl::Status LessEqualTest(TestExecutionEnvironment* env);
absl::Status GreaterTest(TestExecutionEnvironment* env);
absl::Status GreaterEqualTest(TestExecutionEnvironment* env);
absl::Status EqualTest(TestExecutionEnvironment* env);
absl::Status NotEqualTest(TestExecutionEnvironment* env);
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_ELEMENTWISE_TEST_UTIL_H_

View File

@ -231,6 +231,7 @@ objc_library(
deps = [
":elementwise",
":test_util",
"//tensorflow/lite/delegates/gpu/common/tasks:elementwise_test_util",
],
)
@ -904,6 +905,7 @@ objc_library(
"//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/common:util",
"//tensorflow/lite/delegates/gpu/common/tasks:add_test_util",
"//tensorflow/lite/delegates/gpu/common/tasks:elementwise_test_util",
"//tensorflow/lite/delegates/gpu/common/tasks:prelu_test_util",
"//tensorflow/lite/delegates/gpu/common/tasks:quantize_and_dequantize_test_util",
"//tensorflow/lite/delegates/gpu/common/tasks:relu_test_util",

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/kernels/add.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h"
#import <XCTest/XCTest.h>
@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/elementwise_test_util.h"
using ::tflite::gpu::DataType;
using ::tflite::gpu::HWC;
@ -39,7 +40,10 @@ using ::tflite::gpu::metal::SingleOpModel;
@interface ElementwiseTest : XCTestCase
@end
@implementation ElementwiseTest
@implementation ElementwiseTest {
tflite::gpu::metal::MetalExecutionEnvironment exec_env_;
}
- (void)setUp {
[super setUp];
}
@ -419,4 +423,184 @@ TensorRef<BHWC> GetTensorRef(int ref, const BHWC& shape) {
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testAbsUnit {
auto status = AbsTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testCosUnit {
auto status = CosTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testCopyUnit {
auto status = CopyTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testEluUnit {
auto status = EluTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testExpUnit {
auto status = ExpTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testHardSwishUnit {
auto status = HardSwishTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testLogUnit {
auto status = LogTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testNegUnit {
auto status = NegTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testRsqrtUnit {
auto status = RsqrtTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testSigmoidUnit {
auto status = SigmoidTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testSinUnit {
auto status = SinTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testSqrtUnit {
auto status = SqrtTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testSquareUnit {
auto status = SquareTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testTanhUnit {
auto status = TanhTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testSubUnit {
auto status = SubTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testSquaredDiffUnit {
auto status = SquaredDiffTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testDivUnit {
auto status = DivTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testPowUnit {
auto status = PowTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testAddUnit {
auto status = AddTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testMaximumUnit {
auto status = MaximumTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testMaximumWithScalarUnit {
auto status = MaximumWithScalarTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testMaximumWithConstantLinearTensorUnit {
auto status = MaximumWithConstantLinearTensorTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testMaximumWithConstantHWCTensorUnit {
auto status = MaximumWithConstantHWCTensorTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testMaximumWithConstantHWCTensorBroadcastChannelsUnit {
auto status = MaximumWithConstantHWCTensorBroadcastChannelsTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testMinimumUnit {
auto status = MinimumTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testMinimumWithScalarUnit {
auto status = MinimumWithScalarTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testMulUnit {
auto status = MulTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testMulBroadcastHWUnit {
auto status = MulBroadcastHWTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testMulBroadcastChannelsUnit {
auto status = MulBroadcastChannelsTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testSubWithScalarAtFirstPositionUnit {
auto status = SubWithScalarAtFirstPositionTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testLessUnit {
auto status = LessTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testLessEqualUnit {
auto status = LessEqualTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testGreaterUnit {
auto status = GreaterTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testGreaterEqualUnit {
auto status = GreaterEqualTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testEqualUnit {
auto status = EqualTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testNotEqualUnit {
auto status = NotEqualTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
@end