Added interface for tasks testing.
Lstm test logic moved to gpu/common/tasks/lstm_test.h. Left minimal part in gpu/cl/kernels/lstm_test.cc PiperOrigin-RevId: 344147268 Change-Id: I8c80f3f06005d78fdee36ca8f9ac1f433f72e414
This commit is contained in:
parent
c1854cd182
commit
0047603768
@ -38,6 +38,7 @@ cc_library(
|
||||
"//tensorflow/lite/delegates/gpu/common:shape",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common:tensor",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:testing_util",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
@ -303,7 +304,7 @@ cc_test(
|
||||
":cl_test",
|
||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common/tasks:lstm",
|
||||
"//tensorflow/lite/delegates/gpu/common/tasks:lstm_test_util",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
@ -22,6 +22,81 @@ namespace tflite {
|
||||
namespace gpu {
|
||||
namespace cl {
|
||||
|
||||
absl::Status ClExecutionEnvironment::Init() { return CreateEnvironment(&env_); }
|
||||
|
||||
std::vector<CalculationsPrecision>
|
||||
ClExecutionEnvironment::GetSupportedPrecisions() const {
|
||||
return env_.GetSupportedPrecisions();
|
||||
}
|
||||
|
||||
std::vector<TensorStorageType> ClExecutionEnvironment::GetSupportedStorages()
|
||||
const {
|
||||
return env_.GetSupportedStorages();
|
||||
}
|
||||
|
||||
std::vector<TensorStorageType>
|
||||
ClExecutionEnvironment::GetSupportedStoragesWithHWZeroClampSupport() const {
|
||||
return env_.GetSupportedStoragesWithHWZeroClampSupport();
|
||||
}
|
||||
|
||||
const GpuInfo& ClExecutionEnvironment::GetGpuInfo() const {
|
||||
return env_.GetDevicePtr()->GetInfo();
|
||||
}
|
||||
|
||||
absl::Status ClExecutionEnvironment::ExecuteGPUOperation(
|
||||
const std::vector<TensorFloat32>& src_cpu,
|
||||
std::unique_ptr<GPUOperation>&& operation,
|
||||
const std::vector<BHWC>& dst_sizes,
|
||||
const std::vector<TensorFloat32*>& dst_cpu) {
|
||||
CreationContext creation_context;
|
||||
creation_context.device = env_.GetDevicePtr();
|
||||
creation_context.context = &env_.context();
|
||||
creation_context.queue = env_.queue();
|
||||
creation_context.cache = env_.program_cache();
|
||||
|
||||
const OperationDef& op_def = operation->GetDefinition();
|
||||
std::vector<Tensor> src(src_cpu.size());
|
||||
for (int i = 0; i < src_cpu.size(); ++i) {
|
||||
auto src_shape = src_cpu[i].shape;
|
||||
if (src_shape.b != 1 && !op_def.IsBatchSupported()) {
|
||||
return absl::InvalidArgumentError(
|
||||
"Layout doesn't have Batch dimension, but shape.b != 1");
|
||||
}
|
||||
RETURN_IF_ERROR(CreateTensor(*creation_context.context, src_shape,
|
||||
op_def.src_tensors[0], &src[i]));
|
||||
RETURN_IF_ERROR(src[i].WriteData(creation_context.queue, src_cpu[i]));
|
||||
operation->SetSrc(&src[i], i);
|
||||
}
|
||||
|
||||
std::vector<Tensor> dst(dst_cpu.size());
|
||||
for (int i = 0; i < dst_cpu.size(); ++i) {
|
||||
auto dst_shape = dst_sizes[i];
|
||||
if (dst_shape.b != 1 && !op_def.IsBatchSupported()) {
|
||||
return absl::InvalidArgumentError(
|
||||
"Layout doesn't have Batch dimension, but shape.b != 1");
|
||||
}
|
||||
RETURN_IF_ERROR(CreateTensor(*creation_context.context, dst_shape,
|
||||
op_def.dst_tensors[0], &dst[i]));
|
||||
|
||||
operation->SetDst(&dst[i], i);
|
||||
}
|
||||
|
||||
ClOperation cl_op;
|
||||
cl_op.Init(std::move(operation));
|
||||
RETURN_IF_ERROR(cl_op.Compile(creation_context));
|
||||
RETURN_IF_ERROR(cl_op.UpdateParams());
|
||||
cl_op.GetGpuOperation().args_.ReleaseCPURepresentation();
|
||||
RETURN_IF_ERROR(cl_op.AddToQueue(creation_context.queue));
|
||||
RETURN_IF_ERROR(creation_context.queue->WaitForCompletion());
|
||||
|
||||
for (int i = 0; i < dst_cpu.size(); ++i) {
|
||||
dst_cpu[i]->shape = dst_sizes[i];
|
||||
dst_cpu[i]->data = std::vector<float>(dst_sizes[i].DimensionsProduct(), 0);
|
||||
RETURN_IF_ERROR(dst[i].ReadData(creation_context.queue, dst_cpu[i]));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status ExecuteGPUOperation(const std::vector<TensorFloat32>& src_cpu,
|
||||
const CreationContext& creation_context,
|
||||
std::unique_ptr<GPUOperation>&& operation,
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/testing_util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -35,6 +36,30 @@ namespace cl {
|
||||
#define ASSERT_OK(x) ASSERT_TRUE(x.ok());
|
||||
#endif
|
||||
|
||||
class ClExecutionEnvironment : public TestExecutionEnvironment {
|
||||
public:
|
||||
ClExecutionEnvironment() = default;
|
||||
~ClExecutionEnvironment() override = default;
|
||||
|
||||
absl::Status Init();
|
||||
|
||||
std::vector<CalculationsPrecision> GetSupportedPrecisions() const override;
|
||||
std::vector<TensorStorageType> GetSupportedStorages() const override;
|
||||
std::vector<TensorStorageType> GetSupportedStoragesWithHWZeroClampSupport()
|
||||
const override;
|
||||
|
||||
const GpuInfo& GetGpuInfo() const override;
|
||||
|
||||
absl::Status ExecuteGPUOperation(
|
||||
const std::vector<TensorFloat32>& src_cpu,
|
||||
std::unique_ptr<GPUOperation>&& operation,
|
||||
const std::vector<BHWC>& dst_sizes,
|
||||
const std::vector<TensorFloat32*>& dst_cpu) override;
|
||||
|
||||
private:
|
||||
Environment env_;
|
||||
};
|
||||
|
||||
class OpenCLOperationTest : public ::testing::Test {
|
||||
public:
|
||||
void SetUp() override {
|
||||
@ -44,11 +69,15 @@ class OpenCLOperationTest : public ::testing::Test {
|
||||
creation_context_.context = &env_.context();
|
||||
creation_context_.queue = env_.queue();
|
||||
creation_context_.cache = env_.program_cache();
|
||||
|
||||
ASSERT_OK(exec_env_.Init());
|
||||
}
|
||||
|
||||
protected:
|
||||
Environment env_;
|
||||
CreationContext creation_context_;
|
||||
|
||||
ClExecutionEnvironment exec_env_;
|
||||
};
|
||||
|
||||
absl::Status ExecuteGPUOperation(const TensorFloat32& src_cpu,
|
||||
|
@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/tasks/lstm.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
@ -24,68 +22,14 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
|
||||
using ::testing::FloatNear;
|
||||
using ::testing::Pointwise;
|
||||
#include "tensorflow/lite/delegates/gpu/common/tasks/lstm_test_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace cl {
|
||||
namespace {
|
||||
|
||||
TEST_F(OpenCLOperationTest, LSTM) {
|
||||
TensorFloat32 src_tensor;
|
||||
src_tensor.shape = BHWC(1, 1, 1, 16);
|
||||
src_tensor.data = {
|
||||
-std::log(2.0f), -std::log(2.0f), -std::log(2.0f), -std::log(2.0f),
|
||||
std::log(3.0f), std::log(3.0f), std::log(3.0f), std::log(3.0f),
|
||||
-std::log(4.0f), -std::log(4.0f), -std::log(4.0f), -std::log(4.0f),
|
||||
-std::log(5.0f), -std::log(5.0f), -std::log(5.0f), -std::log(5.0f)};
|
||||
// input_gate = 1.0 / (1.0 + exp(log(2.0f))) = 1.0 / 3.0;
|
||||
// new_input = tanh(log(3.0f)) = (exp(2 * log(3.0f)) - 1) / exp(2 * log(3.0f))
|
||||
// + 1 = (9 - 1) / (9 + 1) = 0.8;
|
||||
// forget_gate = 1.0 / (1.0 + exp(log(4.0f)))
|
||||
// = 1.0 / 5.0;
|
||||
// output_gate = 1.0 / (1.0 + exp(log(5.0f))) = 1.0 / 6.0;
|
||||
// new_st = input_gate * new_input + forget_gate * prev_st
|
||||
// = 1.0 / 3.0 * 0.8 + 1.0 / 5.0 * prev_st
|
||||
// = 4.0 / 15.0 + 3.0 / 15.0 = 7.0 / 15.0
|
||||
// activation = output_gate * tanh(new_st)
|
||||
TensorFloat32 prev_state;
|
||||
prev_state.shape = BHWC(1, 1, 1, 4);
|
||||
prev_state.data = {1.0f, 2.0f, 3.0f, 4.0f};
|
||||
TEST_F(OpenCLOperationTest, LSTM) { LstmTest(&exec_env_); }
|
||||
|
||||
for (auto storage : env_.GetSupportedStorages()) {
|
||||
for (auto precision : env_.GetSupportedPrecisions()) {
|
||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
|
||||
OperationDef op_def;
|
||||
op_def.precision = precision;
|
||||
auto data_type = DeduceDataTypeFromPrecision(precision);
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::BHWC});
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::BHWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::BHWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::BHWC});
|
||||
TensorFloat32 new_state;
|
||||
TensorFloat32 new_activ;
|
||||
GPUOperation operation = CreateLSTM(op_def, env_.GetDevicePtr()->info_);
|
||||
ASSERT_OK(ExecuteGPUOperation(
|
||||
{src_tensor, prev_state}, creation_context_,
|
||||
absl::make_unique<GPUOperation>(std::move(operation)),
|
||||
{BHWC(1, 1, 1, 4), BHWC(1, 1, 1, 4)}, {&new_state, &new_activ}));
|
||||
EXPECT_THAT(new_state.data,
|
||||
Pointwise(FloatNear(eps), {7.0 / 15.0, 10.0 / 15.0,
|
||||
13.0 / 15.0, 16.0 / 15.0}));
|
||||
EXPECT_THAT(
|
||||
new_activ.data,
|
||||
Pointwise(FloatNear(eps), {(1.0 / 6.0) * std::tanh(7.0 / 15.0),
|
||||
(1.0 / 6.0) * std::tanh(10.0 / 15.0),
|
||||
(1.0 / 6.0) * std::tanh(13.0 / 15.0),
|
||||
(1.0 / 6.0) * std::tanh(16.0 / 15.0)}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
@ -130,6 +130,19 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "testing_util",
|
||||
hdrs = ["testing_util.h"],
|
||||
deps = [
|
||||
":gpu_operation",
|
||||
":tensor_desc",
|
||||
"//tensorflow/lite/delegates/gpu/common:gpu_info",
|
||||
"//tensorflow/lite/delegates/gpu/common:precision",
|
||||
"//tensorflow/lite/delegates/gpu/common:shape",
|
||||
"//tensorflow/lite/delegates/gpu/common:tensor",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "texture2d_desc",
|
||||
srcs = ["texture2d_desc.cc"],
|
||||
|
55
tensorflow/lite/delegates/gpu/common/task/testing_util.h
Normal file
55
tensorflow/lite/delegates/gpu/common/task/testing_util.h
Normal file
@ -0,0 +1,55 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_TESTING_UTIL_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_TESTING_UTIL_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/precision.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
|
||||
class TestExecutionEnvironment {
|
||||
public:
|
||||
TestExecutionEnvironment() = default;
|
||||
virtual ~TestExecutionEnvironment() = default;
|
||||
|
||||
virtual std::vector<CalculationsPrecision> GetSupportedPrecisions() const = 0;
|
||||
virtual std::vector<TensorStorageType> GetSupportedStorages() const = 0;
|
||||
// returns storage types that support zero clamping when reading OOB in HW
|
||||
// (Height/Width) dimensions.
|
||||
virtual std::vector<TensorStorageType>
|
||||
GetSupportedStoragesWithHWZeroClampSupport() const = 0;
|
||||
|
||||
virtual const GpuInfo& GetGpuInfo() const = 0;
|
||||
|
||||
virtual absl::Status ExecuteGPUOperation(
|
||||
const std::vector<TensorFloat32>& src_cpu,
|
||||
std::unique_ptr<GPUOperation>&& operation,
|
||||
const std::vector<BHWC>& dst_sizes,
|
||||
const std::vector<TensorFloat32*>& dst_cpu) = 0;
|
||||
};
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_TESTING_UTIL_H_
|
@ -315,6 +315,20 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "lstm_test_util",
|
||||
testonly = 1,
|
||||
srcs = ["lstm_test_util.cc"],
|
||||
hdrs = ["lstm_test_util.h"],
|
||||
deps = [
|
||||
":lstm",
|
||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:testing_util",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "max_unpooling",
|
||||
srcs = ["max_unpooling.cc"],
|
||||
|
86
tensorflow/lite/delegates/gpu/common/tasks/lstm_test_util.cc
Normal file
86
tensorflow/lite/delegates/gpu/common/tasks/lstm_test_util.cc
Normal file
@ -0,0 +1,86 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/tasks/lstm_test_util.h"
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/testing_util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tasks/lstm.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
|
||||
void LstmTest(TestExecutionEnvironment* env) {
|
||||
TensorFloat32 src_tensor;
|
||||
src_tensor.shape = BHWC(1, 1, 1, 16);
|
||||
src_tensor.data = {
|
||||
-std::log(2.0f), -std::log(2.0f), -std::log(2.0f), -std::log(2.0f),
|
||||
std::log(3.0f), std::log(3.0f), std::log(3.0f), std::log(3.0f),
|
||||
-std::log(4.0f), -std::log(4.0f), -std::log(4.0f), -std::log(4.0f),
|
||||
-std::log(5.0f), -std::log(5.0f), -std::log(5.0f), -std::log(5.0f)};
|
||||
// input_gate = 1.0 / (1.0 + exp(log(2.0f))) = 1.0 / 3.0;
|
||||
// new_input = tanh(log(3.0f)) = (exp(2 * log(3.0f)) - 1) / exp(2 * log(3.0f))
|
||||
// + 1 = (9 - 1) / (9 + 1) = 0.8;
|
||||
// forget_gate = 1.0 / (1.0 + exp(log(4.0f)))
|
||||
// = 1.0 / 5.0;
|
||||
// output_gate = 1.0 / (1.0 + exp(log(5.0f))) = 1.0 / 6.0;
|
||||
// new_st = input_gate * new_input + forget_gate * prev_st
|
||||
// = 1.0 / 3.0 * 0.8 + 1.0 / 5.0 * prev_st
|
||||
// = 4.0 / 15.0 + 3.0 / 15.0 = 7.0 / 15.0
|
||||
// activation = output_gate * tanh(new_st)
|
||||
TensorFloat32 prev_state;
|
||||
prev_state.shape = BHWC(1, 1, 1, 4);
|
||||
prev_state.data = {1.0f, 2.0f, 3.0f, 4.0f};
|
||||
|
||||
for (auto storage : env->GetSupportedStorages()) {
|
||||
for (auto precision : env->GetSupportedPrecisions()) {
|
||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
|
||||
OperationDef op_def;
|
||||
op_def.precision = precision;
|
||||
auto data_type = DeduceDataTypeFromPrecision(precision);
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::BHWC});
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::BHWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::BHWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::BHWC});
|
||||
TensorFloat32 new_state;
|
||||
TensorFloat32 new_activ;
|
||||
GPUOperation operation = CreateLSTM(op_def, env->GetGpuInfo());
|
||||
ASSERT_TRUE(env->ExecuteGPUOperation(
|
||||
{src_tensor, prev_state},
|
||||
absl::make_unique<GPUOperation>(std::move(operation)),
|
||||
{BHWC(1, 1, 1, 4), BHWC(1, 1, 1, 4)},
|
||||
{&new_state, &new_activ})
|
||||
.ok());
|
||||
EXPECT_THAT(new_state.data,
|
||||
testing::Pointwise(
|
||||
testing::FloatNear(eps),
|
||||
{7.0 / 15.0, 10.0 / 15.0, 13.0 / 15.0, 16.0 / 15.0}))
|
||||
<< ToString(storage) << ", " << ToString(precision);
|
||||
EXPECT_THAT(new_activ.data,
|
||||
testing::Pointwise(testing::FloatNear(eps),
|
||||
{(1.0 / 6.0) * std::tanh(7.0 / 15.0),
|
||||
(1.0 / 6.0) * std::tanh(10.0 / 15.0),
|
||||
(1.0 / 6.0) * std::tanh(13.0 / 15.0),
|
||||
(1.0 / 6.0) * std::tanh(16.0 / 15.0)}))
|
||||
<< ToString(storage) << ", " << ToString(precision);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
29
tensorflow/lite/delegates/gpu/common/tasks/lstm_test_util.h
Normal file
29
tensorflow/lite/delegates/gpu/common/tasks/lstm_test_util.h
Normal file
@ -0,0 +1,29 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_LSTM_TEST_UTIL_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_LSTM_TEST_UTIL_H_
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/testing_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
|
||||
void LstmTest(TestExecutionEnvironment* env);
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_LSTM_TEST_UTIL_H_
|
Loading…
x
Reference in New Issue
Block a user