Softmax1x1 task modified to be Metal compatible.
Added Metal softmax1x1 unit tests.

PiperOrigin-RevId: 353025248
Change-Id: I1e4e34df391faff04634181da79e81468fa7b18c
commit be308cc157
parent f3afd96a6a
@@ -528,7 +528,7 @@ cc_test(
         ":cl_test",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
-        "//tensorflow/lite/delegates/gpu/common/tasks:softmax1x1",
+        "//tensorflow/lite/delegates/gpu/common/tasks:softmax_test_util",
         "@com_google_googletest//:gtest_main",
     ],
 )
@@ -13,10 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/delegates/gpu/common/tasks/softmax1x1.h"
 
-#include <cmath>
-#include <cstdlib>
-#include <vector>
 
 #include <gmock/gmock.h>
@@ -24,9 +20,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
-
-using ::testing::FloatNear;
-using ::testing::Pointwise;
+#include "tensorflow/lite/delegates/gpu/common/tasks/softmax_test_util.h"
 
 namespace tflite {
 namespace gpu {
@@ -34,67 +28,13 @@ namespace cl {
 namespace {
 
 TEST_F(OpenCLOperationTest, Softmax1x1) {
-  TensorFloat32 src_tensor;
-  src_tensor.shape = BHWC(1, 1, 1, 4);
-  src_tensor.data = {std::log(1.0f), std::log(2.0f), std::log(3.0f),
-                     std::log(4.0f)};
-
-  for (auto storage : env_.GetSupportedStorages()) {
-    for (auto precision : env_.GetSupportedPrecisions()) {
-      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
-      OperationDef op_def;
-      op_def.precision = precision;
-      auto data_type = DeduceDataTypeFromPrecision(precision);
-      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
-      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
-      TensorFloat32 dst_tensor;
-      Softmax1x1 operation = CreateSoftmax1x1(op_def);
-      ASSERT_OK(ExecuteGPUOperation(
-          src_tensor, creation_context_,
-          absl::make_unique<Softmax1x1>(std::move(operation)), BHWC(1, 1, 1, 4),
-          &dst_tensor));
-      EXPECT_THAT(dst_tensor.data,
-                  Pointwise(FloatNear(eps), {0.1f, 0.2f, 0.3f, 0.4f}));
-    }
-  }
+  auto status = Softmax1x1Test(&exec_env_);
+  ASSERT_TRUE(status.ok()) << status.error_message();
 }
 
 TEST_F(OpenCLOperationTest, Softmax1x1BigNumber) {
-  TensorFloat32 src_tensor;
-  src_tensor.shape = BHWC(1, 1, 1, 4);
-  double doubles[4] = {1.0, 2.0, 3.0, 100.0};
-  // exp(100) is inf in float (32 bit) but representable in double (64 bit)
-  src_tensor.data.resize(4);
-  src_tensor.data[0] = doubles[0];
-  src_tensor.data[1] = doubles[1];
-  src_tensor.data[2] = doubles[2];
-  src_tensor.data[3] = doubles[3];
-  EXPECT_TRUE(std::isinf(std::exp(src_tensor.data[3])));
-  EXPECT_FALSE(std::isinf(std::exp(doubles[3])));
-  double s0 = std::exp(doubles[0]) + std::exp(doubles[1]) +
-              std::exp(doubles[2]) + std::exp(doubles[3]);
-
-  for (auto storage : env_.GetSupportedStorages()) {
-    for (auto precision : env_.GetSupportedPrecisions()) {
-      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
-      OperationDef op_def;
-      op_def.precision = precision;
-      auto data_type = DeduceDataTypeFromPrecision(precision);
-      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
-      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
-      TensorFloat32 dst_tensor;
-      Softmax1x1 operation = CreateSoftmax1x1(op_def);
-      ASSERT_OK(ExecuteGPUOperation(
-          src_tensor, creation_context_,
-          absl::make_unique<Softmax1x1>(std::move(operation)), BHWC(1, 1, 1, 4),
-          &dst_tensor));
-      EXPECT_THAT(
-          dst_tensor.data,
-          Pointwise(FloatNear(eps),
-                    {std::exp(doubles[0]) / s0, std::exp(doubles[1]) / s0,
-                     std::exp(doubles[2]) / s0, std::exp(doubles[3]) / s0}));
-    }
-  }
+  auto status = Softmax1x1BigNumberTest(&exec_env_);
+  ASSERT_TRUE(status.ok()) << status.error_message();
 }
 
 }  // namespace
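Note on the expected values in the rewritten tests: the inputs are the logs of 1..4, and softmax(log k_i) = k_i / sum_j k_j, so the output must be {1, 2, 3, 4} / 10 = {0.1f, 0.2f, 0.3f, 0.4f}. A standalone check of that identity, using nothing beyond the C++ standard library:

#include <cmath>
#include <cstdio>

// softmax over {log 1, log 2, log 3, log 4}: exp(log k) == k, so sum == 10.
int main() {
  const float src[4] = {std::log(1.0f), std::log(2.0f), std::log(3.0f),
                        std::log(4.0f)};
  float sum = 0.0f;
  for (float v : src) sum += std::exp(v);
  for (float v : src) std::printf("%.1f ", std::exp(v) / sum);
  std::printf("\n");  // prints: 0.1 0.2 0.3 0.4
  return 0;
}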
@@ -632,6 +632,7 @@ cc_library(
     hdrs = ["softmax_test_util.h"],
     deps = [
         ":softmax",
+        ":softmax1x1",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common/task:testing_util",
@@ -47,22 +47,22 @@ std::string Softmax1x1::GetSoftmaxKernelCode(const OperationDef& op_def) {
   args_.AddFloat("mask_w");
 
   std::string c;
-  c += "__kernel void main_function(\n";
-  c += "$0) {\n";
+  c += "MAIN_FUNCTION($0) {\n";
   if (op_def.IsBatchSupported()) {
-    c += "  int batch_id = get_global_id(1);\n";
+    c += "  int batch_id = GLOBAL_ID_1;\n";
     c += "  if (batch_id >= args.dst_tensor.Batch()) return;\n";
     c += "  args.dst_tensor.SetBatchRef(batch_id);\n";
     c += "  args.src_tensor.SetBatchRef(batch_id);\n";
   }
-  c += "  float4 mask = (float4)(args.mask_x, args.mask_y, args.mask_z, "
+  c += "  float4 mask = INIT_FLOAT4v4(args.mask_x, args.mask_y, args.mask_z, "
        "args.mask_w);\n";
-  c += "  float4 maxx4 = (float4)(args.src_tensor.Read<float>(0, 0, 0).x);\n";
-  c += "  int tid = get_local_id(0);\n";
+  c +=
+      "  float4 maxx4 = INIT_FLOAT4(args.src_tensor.Read<float>(0, 0, 0).x);\n";
+  c += "  int tid = LOCAL_ID_0;\n";
   c += "  for (int s = tid; s < args.src_tensor.Slices(); s += 32) {\n";
   c += "    float4 mask_a = s == args.src_tensor.Slices() - 1 ? mask : "
-       "(float4)(1.0f);\n";
-  c += "    float4 mask_b = (float4)(1.0f) - mask_a;\n";
+       "INIT_FLOAT4(1.0f);\n";
+  c += "    float4 mask_b = INIT_FLOAT4(1.0f) - mask_a;\n";
   c += "    float4 src = args.src_tensor.Read<float>(0, 0, s);\n";
   c += "    src = src * mask_a + mask_b * src.x;\n";
   c += "    maxx4 = max(maxx4, src);\n";
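Note on the macro substitutions above: the removed lines fix what each macro must expand to in OpenCL C, but the Metal expansions are not part of this diff. The sketch below is illustrative only; the real definitions live in TFLite's GPU task layer and may differ in detail. The Metal branch is an assumption based on Metal Shading Language syntax, and MAIN_FUNCTION / thread-id macros on Metal involve kernel-argument attributes handled by codegen, so they are omitted there.

#ifdef __OPENCL_VERSION__  // OpenCL C: expansions implied by the removed lines.
#define MAIN_FUNCTION(args) __kernel void main_function(args)
#define GLOBAL_ID_0 get_global_id(0)
#define GLOBAL_ID_1 get_global_id(1)
#define LOCAL_ID_0 get_local_id(0)
#define LOCAL_MEM_BARRIER barrier(CLK_LOCAL_MEM_FENCE)
#define INIT_FLOAT4(x) (float4)(x)  // broadcast one scalar into all lanes
#define INIT_FLOAT4v4(x, y, z, w) (float4)(x, y, z, w)
#else  // Metal Shading Language (assumed, not shown in this diff).
#define LOCAL_MEM_BARRIER threadgroup_barrier(mem_flags::mem_threadgroup)
#define INIT_FLOAT4(x) float4(x)  // MSL vector constructor syntax
#define INIT_FLOAT4v4(x, y, z, w) float4(x, y, z, w)
#endif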
@@ -73,7 +73,7 @@ std::string Softmax1x1::GetSoftmaxKernelCode(const OperationDef& op_def) {
   c += "  __local float4 tmp[8];\n";
   c += "  __local float* tmpx1 = (__local float*)tmp;\n";
   c += "  tmpx1[tid] = maximum;\n";
-  c += "  barrier(CLK_LOCAL_MEM_FENCE);\n";
+  c += "  LOCAL_MEM_BARRIER;\n";
   c += "  if (tid == 0) {\n";
   c += "    maxx4 = max(tmp[0], tmp[1]);\n";
   c += "    maxx4 = max(maxx4, tmp[2]);\n";
@@ -87,19 +87,19 @@ std::string Softmax1x1::GetSoftmaxKernelCode(const OperationDef& op_def) {
   c += "    maximum = max(maximum, maxx4.w);\n";
   c += "    tmpx1[0] = maximum;\n";
   c += "  }\n";
-  c += "  barrier(CLK_LOCAL_MEM_FENCE);\n";
+  c += "  LOCAL_MEM_BARRIER;\n";
   c += "  maximum = tmpx1[0];\n";
   c += "  float sum = 0.0f;\n";
   c += "  for (int s = tid; s < args.src_tensor.Slices(); s += 32) {\n";
   c += "    float4 mask_temp = s == args.src_tensor.Slices() - 1 ? mask : "
-       "(float4)(1.0f);\n";
+       "INIT_FLOAT4(1.0f);\n";
   c += "    float4 src = args.src_tensor.Read<float>(0, 0, s) - "
-       "(float4)(maximum);\n";
+       "INIT_FLOAT4(maximum);\n";
   c += "    sum += dot(mask_temp, exp(src));\n";
   c += "  }\n";
-  c += "  barrier(CLK_LOCAL_MEM_FENCE);\n";
+  c += "  LOCAL_MEM_BARRIER;\n";
   c += "  tmpx1[tid] = sum;\n";
-  c += "  barrier(CLK_LOCAL_MEM_FENCE);\n";
+  c += "  LOCAL_MEM_BARRIER;\n";
   c += "  if (tid == 0) {\n";
   c += "    sum = dot((float4)(1.0f), tmp[0]);\n";
   c += "    sum += dot((float4)(1.0f), tmp[1]);\n";
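For orientation, the generated kernel runs a fixed workgroup of 32 threads: each thread strides over the channel slices (s += 32) keeping a private running max, the partials meet in the 32-float __local buffer (float4 tmp[8]) across a LOCAL_MEM_BARRIER, thread 0 folds them, and the same pattern repeats for the exp-sum. A rough CPU analogue of that strided two-pass reduction (a sketch, not the kernel itself; assumes a non-empty input):

#include <algorithm>
#include <cmath>
#include <vector>

// Mirrors the kernel's scheme: 32 "threads" each reduce every 32nd element,
// then the partials are folded, like the tmp[]-plus-barrier steps above.
float StableSoftmaxDenominator(const std::vector<float>& x) {
  constexpr int kThreads = 32;
  float partial_max[kThreads];
  std::fill(partial_max, partial_max + kThreads, x[0]);
  for (int tid = 0; tid < kThreads; ++tid)  // pass 1: strided max
    for (size_t s = tid; s < x.size(); s += kThreads)
      partial_max[tid] = std::max(partial_max[tid], x[s]);
  const float maximum = *std::max_element(partial_max, partial_max + kThreads);

  float partial_sum[kThreads] = {0.0f};
  for (int tid = 0; tid < kThreads; ++tid)  // pass 2: strided sum of exp
    for (size_t s = tid; s < x.size(); s += kThreads)
      partial_sum[tid] += std::exp(x[s] - maximum);
  float sum = 0.0f;
  for (float p : partial_sum) sum += p;  // the "tid == 0" fold
  return sum;  // denominator, computed around the max for stability
}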
@@ -111,13 +111,13 @@ std::string Softmax1x1::GetSoftmaxKernelCode(const OperationDef& op_def) {
   c += "    sum += dot((float4)(1.0f), tmp[7]);\n";
   c += "    tmpx1[0] = 1.0f / sum;\n";
   c += "  }\n";
-  c += "  barrier(CLK_LOCAL_MEM_FENCE);\n";
+  c += "  LOCAL_MEM_BARRIER;\n";
   c += "  sum = tmpx1[0];\n";
   c += "\n";
-  c += "  int dst_s = get_global_id(0);\n";
+  c += "  int dst_s = GLOBAL_ID_0;\n";
   c += "  if (dst_s < args.dst_tensor.Slices()) {\n";
   c += "    float4 src = args.src_tensor.Read<float>(0, 0, dst_s) - "
-       "(float4)(maximum);\n";
+       "INIT_FLOAT4(maximum);\n";
   c += "    FLT4 res = TO_FLT4(exp(src) * sum);\n";
   c += "    args.dst_tensor.Write(res, 0, 0, dst_s);\n";
   c += "  }\n";
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/task/testing_util.h"
 #include "tensorflow/lite/delegates/gpu/common/tasks/softmax.h"
+#include "tensorflow/lite/delegates/gpu/common/tasks/softmax1x1.h"
 
 namespace tflite {
 namespace gpu {
@@ -95,5 +96,74 @@ absl::Status SoftmaxBigNumberTest(TestExecutionEnvironment* env) {
   return absl::OkStatus();
 }
 
+absl::Status Softmax1x1Test(TestExecutionEnvironment* env) {
+  TensorFloat32 src_tensor;
+  src_tensor.shape = BHWC(1, 1, 1, 4);
+  src_tensor.data = {std::log(1.0f), std::log(2.0f), std::log(3.0f),
+                     std::log(4.0f)};
+
+  for (auto storage : env->GetSupportedStorages()) {
+    for (auto precision : env->GetSupportedPrecisions()) {
+      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
+      OperationDef op_def;
+      op_def.precision = precision;
+      auto data_type = DeduceDataTypeFromPrecision(precision);
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+      TensorFloat32 dst_tensor;
+      Softmax1x1 operation = CreateSoftmax1x1(op_def);
+      RETURN_IF_ERROR(env->ExecuteGPUOperation(
+          src_tensor, absl::make_unique<Softmax1x1>(std::move(operation)),
+          BHWC(1, 1, 1, 4), &dst_tensor));
+      RETURN_IF_ERROR(
+          PointWiseNear({0.1f, 0.2f, 0.3f, 0.4f}, dst_tensor.data, eps));
+    }
+  }
+  return absl::OkStatus();
+}
+
+absl::Status Softmax1x1BigNumberTest(TestExecutionEnvironment* env) {
+  TensorFloat32 src_tensor;
+  src_tensor.shape = BHWC(1, 1, 1, 4);
+  double doubles[4] = {1.0, 2.0, 3.0, 100.0};
+  // exp(100) is inf in float (32 bit) but representable in double (64 bit)
+  src_tensor.data.resize(4);
+  src_tensor.data[0] = doubles[0];
+  src_tensor.data[1] = doubles[1];
+  src_tensor.data[2] = doubles[2];
+  src_tensor.data[3] = doubles[3];
+  if (!std::isinf(std::exp(src_tensor.data[3]))) {
+    return absl::InternalError("exp(100.0f) not inf in float (32 bit)");
+  }
+  if (std::isinf(std::exp(doubles[3]))) {
+    return absl::InternalError("exp(100.0) inf in double (64 bit)");
+  }
+  double s0 = std::exp(doubles[0]) + std::exp(doubles[1]) +
+              std::exp(doubles[2]) + std::exp(doubles[3]);
+
+  for (auto storage : env->GetSupportedStorages()) {
+    for (auto precision : env->GetSupportedPrecisions()) {
+      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
+      OperationDef op_def;
+      op_def.precision = precision;
+      auto data_type = DeduceDataTypeFromPrecision(precision);
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+      TensorFloat32 dst_tensor;
+      Softmax1x1 operation = CreateSoftmax1x1(op_def);
+      RETURN_IF_ERROR(env->ExecuteGPUOperation(
+          src_tensor, absl::make_unique<Softmax1x1>(std::move(operation)),
+          BHWC(1, 1, 1, 4), &dst_tensor));
+      RETURN_IF_ERROR(
+          PointWiseNear({static_cast<float>(std::exp(doubles[0]) / s0),
+                         static_cast<float>(std::exp(doubles[1]) / s0),
+                         static_cast<float>(std::exp(doubles[2]) / s0),
+                         static_cast<float>(std::exp(doubles[3]) / s0)},
+                        dst_tensor.data, eps));
+    }
+  }
+  return absl::OkStatus();
+}
+
 }  // namespace gpu
 }  // namespace tflite
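Why the BigNumber reference values are built in double: FLT_MAX is roughly 3.4e38 while e^100 is roughly 2.7e43, so a naive exp(100.0f) is +inf in float, which is exactly what the isinf checks above assert. The kernel stays finite by subtracting the running maximum before exponentiating, so the largest exponent it ever evaluates is exp(0) = 1. A minimal stable-softmax reference in the same spirit (a sketch using only the standard library, not the shipped code):

#include <algorithm>
#include <cmath>
#include <vector>

// Numerically stable softmax: exp(x - max) cannot overflow. With inputs
// {1, 2, 3, 100}, the naive exp(100.0f) would already be +inf in float.
std::vector<float> StableSoftmax(const std::vector<float>& x) {
  float maximum = x[0];
  for (float v : x) maximum = std::max(maximum, v);
  float sum = 0.0f;
  for (float v : x) sum += std::exp(v - maximum);
  std::vector<float> out;
  out.reserve(x.size());
  for (float v : x) out.push_back(std::exp(v - maximum) / sum);
  return out;
}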
@@ -23,9 +23,11 @@ namespace tflite {
 namespace gpu {
 
 absl::Status SoftmaxTest(TestExecutionEnvironment* env);
 
 absl::Status SoftmaxBigNumberTest(TestExecutionEnvironment* env);
 
+absl::Status Softmax1x1Test(TestExecutionEnvironment* env);
+absl::Status Softmax1x1BigNumberTest(TestExecutionEnvironment* env);
 
 }  // namespace gpu
 }  // namespace tflite
@@ -111,7 +111,7 @@ using ::tflite::gpu::metal::SingleOpModel;
   XCTAssertFalse(status.ok(), @"%s", std::string(status.message()).c_str());
 }
 
-- (void)testSoftmax1x1 {
+- (void)testSoftmax1x1Op {
   TensorRef<BHWC> input;
   input.type = DataType::FLOAT32;
   input.ref = 0;
@@ -173,7 +173,7 @@ using ::tflite::gpu::metal::SingleOpModel;
   XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
 }
 
-- (void)testSoftmax1x1BigNumber {
+- (void)testSoftmax1x1BigNumberOp {
   TensorRef<BHWC> input;
   input.type = DataType::FLOAT32;
   input.ref = 0;
@@ -219,4 +219,14 @@ using ::tflite::gpu::metal::SingleOpModel;
   XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
 }
 
+- (void)testSoftmax1x1 {
+  auto status = Softmax1x1Test(&exec_env_);
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+}
+
+- (void)testSoftmax1x1BigNumber {
+  auto status = Softmax1x1BigNumberTest(&exec_env_);
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+}
+
 @end
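The Op suffix on the renamed methods keeps the original model-level tests, built on SingleOpModel, distinct from the new plain-named methods, which run the shared Softmax1x1Test / Softmax1x1BigNumberTest helpers against the Metal backend through exec_env_, the same bodies the OpenCL test file now calls.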