Elementwise tests logic moved to elementwise_test_util that can be reused among different backends.
Added new Elementwise tests in Metal. PiperOrigin-RevId: 352016582 Change-Id: Idb3c40251aab8f724d8ea436fdb984e2dae177b5
This commit is contained in:
parent
0e1f4fd484
commit
4bb669c09c
@ -266,7 +266,7 @@ cc_test(
|
||||
":cl_test",
|
||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common/tasks:elementwise",
|
||||
"//tensorflow/lite/delegates/gpu/common/tasks:elementwise_test_util",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -295,6 +295,19 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "elementwise_test_util",
|
||||
testonly = 1,
|
||||
srcs = ["elementwise_test_util.cc"],
|
||||
hdrs = ["elementwise_test_util.h"],
|
||||
deps = [
|
||||
":elementwise",
|
||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:testing_util",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "fully_connected",
|
||||
srcs = ["fully_connected.cc"],
|
||||
|
1064
tensorflow/lite/delegates/gpu/common/tasks/elementwise_test_util.cc
Normal file
1064
tensorflow/lite/delegates/gpu/common/tasks/elementwise_test_util.cc
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,66 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_ELEMENTWISE_TEST_UTIL_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_ELEMENTWISE_TEST_UTIL_H_
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/testing_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
|
||||
absl::Status AbsTest(TestExecutionEnvironment* env);
|
||||
absl::Status CosTest(TestExecutionEnvironment* env);
|
||||
absl::Status CopyTest(TestExecutionEnvironment* env);
|
||||
absl::Status EluTest(TestExecutionEnvironment* env);
|
||||
absl::Status ExpTest(TestExecutionEnvironment* env);
|
||||
absl::Status HardSwishTest(TestExecutionEnvironment* env);
|
||||
absl::Status LogTest(TestExecutionEnvironment* env);
|
||||
absl::Status NegTest(TestExecutionEnvironment* env);
|
||||
absl::Status RsqrtTest(TestExecutionEnvironment* env);
|
||||
absl::Status SigmoidTest(TestExecutionEnvironment* env);
|
||||
absl::Status SinTest(TestExecutionEnvironment* env);
|
||||
absl::Status SqrtTest(TestExecutionEnvironment* env);
|
||||
absl::Status SquareTest(TestExecutionEnvironment* env);
|
||||
absl::Status TanhTest(TestExecutionEnvironment* env);
|
||||
absl::Status SubTest(TestExecutionEnvironment* env);
|
||||
absl::Status SquaredDiffTest(TestExecutionEnvironment* env);
|
||||
absl::Status DivTest(TestExecutionEnvironment* env);
|
||||
absl::Status PowTest(TestExecutionEnvironment* env);
|
||||
absl::Status AddTest(TestExecutionEnvironment* env);
|
||||
absl::Status MaximumTest(TestExecutionEnvironment* env);
|
||||
absl::Status MaximumWithScalarTest(TestExecutionEnvironment* env);
|
||||
absl::Status MaximumWithConstantLinearTensorTest(TestExecutionEnvironment* env);
|
||||
absl::Status MaximumWithConstantHWCTensorTest(TestExecutionEnvironment* env);
|
||||
absl::Status MaximumWithConstantHWCTensorBroadcastChannelsTest(
|
||||
TestExecutionEnvironment* env);
|
||||
absl::Status MinimumTest(TestExecutionEnvironment* env);
|
||||
absl::Status MinimumWithScalarTest(TestExecutionEnvironment* env);
|
||||
absl::Status MulTest(TestExecutionEnvironment* env);
|
||||
absl::Status MulBroadcastHWTest(TestExecutionEnvironment* env);
|
||||
absl::Status MulBroadcastChannelsTest(TestExecutionEnvironment* env);
|
||||
absl::Status SubWithScalarAtFirstPositionTest(TestExecutionEnvironment* env);
|
||||
absl::Status LessTest(TestExecutionEnvironment* env);
|
||||
absl::Status LessEqualTest(TestExecutionEnvironment* env);
|
||||
absl::Status GreaterTest(TestExecutionEnvironment* env);
|
||||
absl::Status GreaterEqualTest(TestExecutionEnvironment* env);
|
||||
absl::Status EqualTest(TestExecutionEnvironment* env);
|
||||
absl::Status NotEqualTest(TestExecutionEnvironment* env);
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_ELEMENTWISE_TEST_UTIL_H_
|
@ -231,6 +231,7 @@ objc_library(
|
||||
deps = [
|
||||
":elementwise",
|
||||
":test_util",
|
||||
"//tensorflow/lite/delegates/gpu/common/tasks:elementwise_test_util",
|
||||
],
|
||||
)
|
||||
|
||||
@ -904,6 +905,7 @@ objc_library(
|
||||
"//tensorflow/lite/delegates/gpu/common:types",
|
||||
"//tensorflow/lite/delegates/gpu/common:util",
|
||||
"//tensorflow/lite/delegates/gpu/common/tasks:add_test_util",
|
||||
"//tensorflow/lite/delegates/gpu/common/tasks:elementwise_test_util",
|
||||
"//tensorflow/lite/delegates/gpu/common/tasks:prelu_test_util",
|
||||
"//tensorflow/lite/delegates/gpu/common/tasks:quantize_and_dequantize_test_util",
|
||||
"//tensorflow/lite/delegates/gpu/common/tasks:relu_test_util",
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/add.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h"
|
||||
|
||||
#import <XCTest/XCTest.h>
|
||||
|
||||
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tasks/elementwise_test_util.h"
|
||||
|
||||
using ::tflite::gpu::DataType;
|
||||
using ::tflite::gpu::HWC;
|
||||
@ -39,7 +40,10 @@ using ::tflite::gpu::metal::SingleOpModel;
|
||||
@interface ElementwiseTest : XCTestCase
|
||||
@end
|
||||
|
||||
@implementation ElementwiseTest
|
||||
@implementation ElementwiseTest {
|
||||
tflite::gpu::metal::MetalExecutionEnvironment exec_env_;
|
||||
}
|
||||
|
||||
- (void)setUp {
|
||||
[super setUp];
|
||||
}
|
||||
@ -419,4 +423,184 @@ TensorRef<BHWC> GetTensorRef(int ref, const BHWC& shape) {
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testAbsUnit {
|
||||
auto status = AbsTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testCosUnit {
|
||||
auto status = CosTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testCopyUnit {
|
||||
auto status = CopyTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testEluUnit {
|
||||
auto status = EluTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testExpUnit {
|
||||
auto status = ExpTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testHardSwishUnit {
|
||||
auto status = HardSwishTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testLogUnit {
|
||||
auto status = LogTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testNegUnit {
|
||||
auto status = NegTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testRsqrtUnit {
|
||||
auto status = RsqrtTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testSigmoidUnit {
|
||||
auto status = SigmoidTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testSinUnit {
|
||||
auto status = SinTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testSqrtUnit {
|
||||
auto status = SqrtTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testSquareUnit {
|
||||
auto status = SquareTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testTanhUnit {
|
||||
auto status = TanhTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testSubUnit {
|
||||
auto status = SubTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testSquaredDiffUnit {
|
||||
auto status = SquaredDiffTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testDivUnit {
|
||||
auto status = DivTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testPowUnit {
|
||||
auto status = PowTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testAddUnit {
|
||||
auto status = AddTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testMaximumUnit {
|
||||
auto status = MaximumTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testMaximumWithScalarUnit {
|
||||
auto status = MaximumWithScalarTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testMaximumWithConstantLinearTensorUnit {
|
||||
auto status = MaximumWithConstantLinearTensorTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testMaximumWithConstantHWCTensorUnit {
|
||||
auto status = MaximumWithConstantHWCTensorTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testMaximumWithConstantHWCTensorBroadcastChannelsUnit {
|
||||
auto status = MaximumWithConstantHWCTensorBroadcastChannelsTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testMinimumUnit {
|
||||
auto status = MinimumTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testMinimumWithScalarUnit {
|
||||
auto status = MinimumWithScalarTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testMulUnit {
|
||||
auto status = MulTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testMulBroadcastHWUnit {
|
||||
auto status = MulBroadcastHWTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testMulBroadcastChannelsUnit {
|
||||
auto status = MulBroadcastChannelsTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testSubWithScalarAtFirstPositionUnit {
|
||||
auto status = SubWithScalarAtFirstPositionTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testLessUnit {
|
||||
auto status = LessTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testLessEqualUnit {
|
||||
auto status = LessEqualTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testGreaterUnit {
|
||||
auto status = GreaterTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testGreaterEqualUnit {
|
||||
auto status = GreaterEqualTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testEqualUnit {
|
||||
auto status = EqualTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testNotEqualUnit {
|
||||
auto status = NotEqualTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
@end
|
||||
|
Loading…
Reference in New Issue
Block a user