Added an interface for testing tasks.

LSTM test logic moved to gpu/common/tasks/lstm_test_util.h/.cc.
A minimal part remains in gpu/cl/kernels/lstm_test.cc.

PiperOrigin-RevId: 344147268
Change-Id: I8c80f3f06005d78fdee36ca8f9ac1f433f72e414
Raman Sarokin authored 2020-11-24 15:56:04 -08:00; committed by TensorFlower Gardener
parent c1854cd182
commit 0047603768
9 changed files with 305 additions and 59 deletions
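
To make the intent of the new TestExecutionEnvironment interface concrete, here is a minimal, self-contained sketch of the pattern: the task test is written once against an abstract environment, and each backend only supplies the environment, so its own test file reduces to a one-line TEST_F forward. All names in the sketch (FakeTestExecutionEnvironment, DoubleOpTest, CpuFakeEnvironment) are illustrative stand-ins, not types from this change.

#include <gtest/gtest.h>

#include <cstddef>
#include <string>
#include <vector>

// Stand-in for tflite::gpu::TestExecutionEnvironment (illustrative only).
class FakeTestExecutionEnvironment {
 public:
  virtual ~FakeTestExecutionEnvironment() = default;
  // Backends report what they support; the shared test iterates over it.
  virtual std::vector<std::string> GetSupportedPrecisions() const = 0;
  // Backends run the operation under test on CPU-side inputs and return
  // CPU-side outputs.
  virtual std::vector<float> ExecuteGPUOperation(
      const std::vector<float>& src) const = 0;
};

// Shared, backend-agnostic task test (the analogue of LstmTest in
// lstm_test_util.cc): it exercises every supported configuration once.
void DoubleOpTest(const FakeTestExecutionEnvironment* env) {
  for (const std::string& precision : env->GetSupportedPrecisions()) {
    const std::vector<float> src = {1.0f, 2.0f, 3.0f};
    const std::vector<float> dst = env->ExecuteGPUOperation(src);
    ASSERT_EQ(dst.size(), src.size()) << precision;
    for (std::size_t i = 0; i < src.size(); ++i) {
      EXPECT_NEAR(dst[i], 2.0f * src[i], 1e-6f) << precision;
    }
  }
}

// One concrete backend (the analogue of ClExecutionEnvironment); a CPU fake
// here so the example stays self-contained.
class CpuFakeEnvironment : public FakeTestExecutionEnvironment {
 public:
  std::vector<std::string> GetSupportedPrecisions() const override {
    return {"F32", "F16"};
  }
  std::vector<float> ExecuteGPUOperation(
      const std::vector<float>& src) const override {
    std::vector<float> dst(src.size());
    for (std::size_t i = 0; i < src.size(); ++i) dst[i] = 2.0f * src[i];
    return dst;
  }
};

// The backend-specific test reduces to a one-line forward, exactly like
// gpu/cl/kernels/lstm_test.cc after this change. Link with gtest_main.
TEST(FakeOperationTest, DoubleOp) {
  CpuFakeEnvironment env;
  DoubleOpTest(&env);
}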

tensorflow/lite/delegates/gpu/cl/kernels/BUILD

@@ -38,6 +38,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/common:shape",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common:tensor",
"//tensorflow/lite/delegates/gpu/common/task:testing_util",
"@com_google_googletest//:gtest",
],
)
@@ -303,7 +304,7 @@ cc_test(
":cl_test",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common/tasks:lstm",
"//tensorflow/lite/delegates/gpu/common/tasks:lstm_test_util",
"@com_google_googletest//:gtest_main",
],
)

tensorflow/lite/delegates/gpu/cl/kernels/cl_test.cc

@@ -22,6 +22,81 @@ namespace tflite {
namespace gpu {
namespace cl {
absl::Status ClExecutionEnvironment::Init() { return CreateEnvironment(&env_); }
std::vector<CalculationsPrecision>
ClExecutionEnvironment::GetSupportedPrecisions() const {
return env_.GetSupportedPrecisions();
}
std::vector<TensorStorageType> ClExecutionEnvironment::GetSupportedStorages()
const {
return env_.GetSupportedStorages();
}
std::vector<TensorStorageType>
ClExecutionEnvironment::GetSupportedStoragesWithHWZeroClampSupport() const {
return env_.GetSupportedStoragesWithHWZeroClampSupport();
}
const GpuInfo& ClExecutionEnvironment::GetGpuInfo() const {
return env_.GetDevicePtr()->GetInfo();
}
absl::Status ClExecutionEnvironment::ExecuteGPUOperation(
const std::vector<TensorFloat32>& src_cpu,
std::unique_ptr<GPUOperation>&& operation,
const std::vector<BHWC>& dst_sizes,
const std::vector<TensorFloat32*>& dst_cpu) {
CreationContext creation_context;
creation_context.device = env_.GetDevicePtr();
creation_context.context = &env_.context();
creation_context.queue = env_.queue();
creation_context.cache = env_.program_cache();
const OperationDef& op_def = operation->GetDefinition();
std::vector<Tensor> src(src_cpu.size());
for (int i = 0; i < src_cpu.size(); ++i) {
auto src_shape = src_cpu[i].shape;
if (src_shape.b != 1 && !op_def.IsBatchSupported()) {
return absl::InvalidArgumentError(
"Layout doesn't have Batch dimension, but shape.b != 1");
}
RETURN_IF_ERROR(CreateTensor(*creation_context.context, src_shape,
op_def.src_tensors[0], &src[i]));
RETURN_IF_ERROR(src[i].WriteData(creation_context.queue, src_cpu[i]));
operation->SetSrc(&src[i], i);
}
std::vector<Tensor> dst(dst_cpu.size());
for (int i = 0; i < dst_cpu.size(); ++i) {
auto dst_shape = dst_sizes[i];
if (dst_shape.b != 1 && !op_def.IsBatchSupported()) {
return absl::InvalidArgumentError(
"Layout doesn't have Batch dimension, but shape.b != 1");
}
RETURN_IF_ERROR(CreateTensor(*creation_context.context, dst_shape,
op_def.dst_tensors[0], &dst[i]));
operation->SetDst(&dst[i], i);
}
ClOperation cl_op;
cl_op.Init(std::move(operation));
RETURN_IF_ERROR(cl_op.Compile(creation_context));
RETURN_IF_ERROR(cl_op.UpdateParams());
cl_op.GetGpuOperation().args_.ReleaseCPURepresentation();
RETURN_IF_ERROR(cl_op.AddToQueue(creation_context.queue));
RETURN_IF_ERROR(creation_context.queue->WaitForCompletion());
for (int i = 0; i < dst_cpu.size(); ++i) {
dst_cpu[i]->shape = dst_sizes[i];
dst_cpu[i]->data = std::vector<float>(dst_sizes[i].DimensionsProduct(), 0);
RETURN_IF_ERROR(dst[i].ReadData(creation_context.queue, dst_cpu[i]));
}
return absl::OkStatus();
}
absl::Status ExecuteGPUOperation(const std::vector<TensorFloat32>& src_cpu,
const CreationContext& creation_context,
std::unique_ptr<GPUOperation>&& operation,

tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h

@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/testing_util.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
namespace tflite {
@@ -35,6 +36,30 @@ namespace cl {
#define ASSERT_OK(x) ASSERT_TRUE(x.ok());
#endif
class ClExecutionEnvironment : public TestExecutionEnvironment {
public:
ClExecutionEnvironment() = default;
~ClExecutionEnvironment() override = default;
absl::Status Init();
std::vector<CalculationsPrecision> GetSupportedPrecisions() const override;
std::vector<TensorStorageType> GetSupportedStorages() const override;
std::vector<TensorStorageType> GetSupportedStoragesWithHWZeroClampSupport()
const override;
const GpuInfo& GetGpuInfo() const override;
absl::Status ExecuteGPUOperation(
const std::vector<TensorFloat32>& src_cpu,
std::unique_ptr<GPUOperation>&& operation,
const std::vector<BHWC>& dst_sizes,
const std::vector<TensorFloat32*>& dst_cpu) override;
private:
Environment env_;
};
class OpenCLOperationTest : public ::testing::Test {
public:
void SetUp() override {
@@ -44,11 +69,15 @@ class OpenCLOperationTest : public ::testing::Test {
creation_context_.context = &env_.context();
creation_context_.queue = env_.queue();
creation_context_.cache = env_.program_cache();
ASSERT_OK(exec_env_.Init());
}
protected:
Environment env_;
CreationContext creation_context_;
ClExecutionEnvironment exec_env_;
};
absl::Status ExecuteGPUOperation(const TensorFloat32& src_cpu,

tensorflow/lite/delegates/gpu/cl/kernels/lstm_test.cc

@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/common/tasks/lstm.h"
#include <cmath>
#include <cstdlib>
#include <vector>
@@ -24,68 +22,14 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
using ::testing::FloatNear;
using ::testing::Pointwise;
#include "tensorflow/lite/delegates/gpu/common/tasks/lstm_test_util.h"
namespace tflite {
namespace gpu {
namespace cl {
namespace {
TEST_F(OpenCLOperationTest, LSTM) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 1, 1, 16);
src_tensor.data = {
-std::log(2.0f), -std::log(2.0f), -std::log(2.0f), -std::log(2.0f),
std::log(3.0f), std::log(3.0f), std::log(3.0f), std::log(3.0f),
-std::log(4.0f), -std::log(4.0f), -std::log(4.0f), -std::log(4.0f),
-std::log(5.0f), -std::log(5.0f), -std::log(5.0f), -std::log(5.0f)};
// input_gate = 1.0 / (1.0 + exp(log(2.0f))) = 1.0 / 3.0;
// new_input = tanh(log(3.0f))
//           = (exp(2 * log(3.0f)) - 1) / (exp(2 * log(3.0f)) + 1)
//           = (9 - 1) / (9 + 1) = 0.8;
// forget_gate = 1.0 / (1.0 + exp(log(4.0f))) = 1.0 / 5.0;
// output_gate = 1.0 / (1.0 + exp(log(5.0f))) = 1.0 / 6.0;
// new_st = input_gate * new_input + forget_gate * prev_st
//        = 1.0 / 3.0 * 0.8 + 1.0 / 5.0 * prev_st
//        = 4.0 / 15.0 + 3.0 / 15.0 = 7.0 / 15.0 (for prev_st = 1.0f)
// activation = output_gate * tanh(new_st)
TensorFloat32 prev_state;
prev_state.shape = BHWC(1, 1, 1, 4);
prev_state.data = {1.0f, 2.0f, 3.0f, 4.0f};
TEST_F(OpenCLOperationTest, LSTM) { LstmTest(&exec_env_); }
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::BHWC});
op_def.src_tensors.push_back({data_type, storage, Layout::BHWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::BHWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::BHWC});
TensorFloat32 new_state;
TensorFloat32 new_activ;
GPUOperation operation = CreateLSTM(op_def, env_.GetDevicePtr()->info_);
ASSERT_OK(ExecuteGPUOperation(
{src_tensor, prev_state}, creation_context_,
absl::make_unique<GPUOperation>(std::move(operation)),
{BHWC(1, 1, 1, 4), BHWC(1, 1, 1, 4)}, {&new_state, &new_activ}));
EXPECT_THAT(new_state.data,
Pointwise(FloatNear(eps), {7.0 / 15.0, 10.0 / 15.0,
13.0 / 15.0, 16.0 / 15.0}));
EXPECT_THAT(
new_activ.data,
Pointwise(FloatNear(eps), {(1.0 / 6.0) * std::tanh(7.0 / 15.0),
(1.0 / 6.0) * std::tanh(10.0 / 15.0),
(1.0 / 6.0) * std::tanh(13.0 / 15.0),
(1.0 / 6.0) * std::tanh(16.0 / 15.0)}));
}
}
}
} // namespace
} // namespace cl
} // namespace gpu
} // namespace tflite

tensorflow/lite/delegates/gpu/common/task/BUILD

@@ -130,6 +130,19 @@ cc_library(
],
)
cc_library(
name = "testing_util",
hdrs = ["testing_util.h"],
deps = [
":gpu_operation",
":tensor_desc",
"//tensorflow/lite/delegates/gpu/common:gpu_info",
"//tensorflow/lite/delegates/gpu/common:precision",
"//tensorflow/lite/delegates/gpu/common:shape",
"//tensorflow/lite/delegates/gpu/common:tensor",
],
)
cc_library(
name = "texture2d_desc",
srcs = ["texture2d_desc.cc"],

tensorflow/lite/delegates/gpu/common/task/testing_util.h

@@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_TESTING_UTIL_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_TESTING_UTIL_H_
#include <vector>
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
#include "tensorflow/lite/delegates/gpu/common/precision.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
#include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
namespace tflite {
namespace gpu {
class TestExecutionEnvironment {
public:
TestExecutionEnvironment() = default;
virtual ~TestExecutionEnvironment() = default;
virtual std::vector<CalculationsPrecision> GetSupportedPrecisions() const = 0;
virtual std::vector<TensorStorageType> GetSupportedStorages() const = 0;
// Returns storage types that support zero clamping when reading out of
// bounds (OOB) in the HW (Height/Width) dimensions.
virtual std::vector<TensorStorageType>
GetSupportedStoragesWithHWZeroClampSupport() const = 0;
virtual const GpuInfo& GetGpuInfo() const = 0;
virtual absl::Status ExecuteGPUOperation(
const std::vector<TensorFloat32>& src_cpu,
std::unique_ptr<GPUOperation>&& operation,
const std::vector<BHWC>& dst_sizes,
const std::vector<TensorFloat32*>& dst_cpu) = 0;
};
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_TESTING_UTIL_H_

tensorflow/lite/delegates/gpu/common/tasks/BUILD

@@ -315,6 +315,20 @@ cc_library(
],
)
cc_library(
name = "lstm_test_util",
testonly = 1,
srcs = ["lstm_test_util.cc"],
hdrs = ["lstm_test_util.h"],
deps = [
":lstm",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common/task:testing_util",
"@com_google_googletest//:gtest",
],
)
cc_library(
name = "max_unpooling",
srcs = ["max_unpooling.cc"],

tensorflow/lite/delegates/gpu/common/tasks/lstm_test_util.cc

@@ -0,0 +1,86 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/common/tasks/lstm_test_util.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/testing_util.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/lstm.h"
namespace tflite {
namespace gpu {
void LstmTest(TestExecutionEnvironment* env) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 1, 1, 16);
src_tensor.data = {
-std::log(2.0f), -std::log(2.0f), -std::log(2.0f), -std::log(2.0f),
std::log(3.0f), std::log(3.0f), std::log(3.0f), std::log(3.0f),
-std::log(4.0f), -std::log(4.0f), -std::log(4.0f), -std::log(4.0f),
-std::log(5.0f), -std::log(5.0f), -std::log(5.0f), -std::log(5.0f)};
// input_gate = 1.0 / (1.0 + exp(log(2.0f))) = 1.0 / 3.0;
// new_input = tanh(log(3.0f))
//           = (exp(2 * log(3.0f)) - 1) / (exp(2 * log(3.0f)) + 1)
//           = (9 - 1) / (9 + 1) = 0.8;
// forget_gate = 1.0 / (1.0 + exp(log(4.0f))) = 1.0 / 5.0;
// output_gate = 1.0 / (1.0 + exp(log(5.0f))) = 1.0 / 6.0;
// new_st = input_gate * new_input + forget_gate * prev_st
//        = 1.0 / 3.0 * 0.8 + 1.0 / 5.0 * prev_st
//        = 4.0 / 15.0 + 3.0 / 15.0 = 7.0 / 15.0 (for prev_st = 1.0f)
// activation = output_gate * tanh(new_st)
TensorFloat32 prev_state;
prev_state.shape = BHWC(1, 1, 1, 4);
prev_state.data = {1.0f, 2.0f, 3.0f, 4.0f};
for (auto storage : env->GetSupportedStorages()) {
for (auto precision : env->GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::BHWC});
op_def.src_tensors.push_back({data_type, storage, Layout::BHWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::BHWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::BHWC});
TensorFloat32 new_state;
TensorFloat32 new_activ;
GPUOperation operation = CreateLSTM(op_def, env->GetGpuInfo());
ASSERT_TRUE(env->ExecuteGPUOperation(
{src_tensor, prev_state},
absl::make_unique<GPUOperation>(std::move(operation)),
{BHWC(1, 1, 1, 4), BHWC(1, 1, 1, 4)},
{&new_state, &new_activ})
.ok());
EXPECT_THAT(new_state.data,
testing::Pointwise(
testing::FloatNear(eps),
{7.0 / 15.0, 10.0 / 15.0, 13.0 / 15.0, 16.0 / 15.0}))
<< ToString(storage) << ", " << ToString(precision);
EXPECT_THAT(new_activ.data,
testing::Pointwise(testing::FloatNear(eps),
{(1.0 / 6.0) * std::tanh(7.0 / 15.0),
(1.0 / 6.0) * std::tanh(10.0 / 15.0),
(1.0 / 6.0) * std::tanh(13.0 / 15.0),
(1.0 / 6.0) * std::tanh(16.0 / 15.0)}))
<< ToString(storage) << ", " << ToString(precision);
}
}
}
} // namespace gpu
} // namespace tflite
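
As a side note, the reference values asserted in LstmTest above can be re-derived on the CPU. The sketch below only replays the comment arithmetic (it does not call the GPU kernel), and it assumes the gate ordering implied by the comments: input gate, new input, forget gate, output gate over the four groups of four input channels.

#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
  // Gate values implied by the constant test input (four channels per gate).
  const float input_gate = 1.0f / (1.0f + std::exp(std::log(2.0f)));   // 1/3
  const float new_input = std::tanh(std::log(3.0f));                   // 0.8
  const float forget_gate = 1.0f / (1.0f + std::exp(std::log(4.0f)));  // 1/5
  const float output_gate = 1.0f / (1.0f + std::exp(std::log(5.0f)));  // 1/6
  const float prev_state[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  for (int i = 0; i < 4; ++i) {
    const float new_st = input_gate * new_input + forget_gate * prev_state[i];
    const float activation = output_gate * std::tanh(new_st);
    // Matches the expectations in LstmTest: 7/15, 10/15, 13/15, 16/15.
    const float expected_st = (7.0f + 3.0f * i) / 15.0f;
    assert(std::fabs(new_st - expected_st) < 1e-6f);
    std::printf("channel %d: new_st = %.6f, activation = %.6f\n", i, new_st,
                activation);
  }
  return 0;
}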

tensorflow/lite/delegates/gpu/common/tasks/lstm_test_util.h

@@ -0,0 +1,29 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_LSTM_TEST_UTIL_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_LSTM_TEST_UTIL_H_
#include "tensorflow/lite/delegates/gpu/common/task/testing_util.h"
namespace tflite {
namespace gpu {
void LstmTest(TestExecutionEnvironment* env);
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_LSTM_TEST_UTIL_H_