FullyConnectedTexture renamed to FullyConnected so as to support all storage types.
PiperOrigin-RevId: 296085184 Change-Id: I3ea56947c7ddf70370c10b4375903880fd3d83c9
This commit is contained in:
parent
3aecbb9fb1
commit
38168415ea
@ -731,9 +731,9 @@ cc_library(
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "fully_connected_texture",
|
||||
srcs = ["fully_connected_texture.cc"],
|
||||
hdrs = ["fully_connected_texture.h"],
|
||||
name = "fully_connected",
|
||||
srcs = ["fully_connected.cc"],
|
||||
hdrs = ["fully_connected.h"],
|
||||
deps = [
|
||||
":gpu_operation",
|
||||
":util",
|
||||
@ -751,8 +751,8 @@ cc_library(
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "fully_connected_texture_test",
|
||||
srcs = ["fully_connected_texture_test.cc"],
|
||||
name = "fully_connected_test",
|
||||
srcs = ["fully_connected_test.cc"],
|
||||
linkstatic = True,
|
||||
tags = tf_gpu_tests_tags() + [
|
||||
"linux",
|
||||
@ -760,7 +760,7 @@ cc_test(
|
||||
],
|
||||
deps = [
|
||||
":cl_test",
|
||||
":fully_connected_texture",
|
||||
":fully_connected",
|
||||
"//tensorflow/lite/delegates/gpu/cl:tensor",
|
||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
@ -1386,7 +1386,7 @@ test_suite(
|
||||
"depth_wise_conv_3x3_test",
|
||||
"depth_wise_conv_test",
|
||||
"elementwise_test",
|
||||
"fully_connected_texture_test",
|
||||
"fully_connected_test",
|
||||
"lstm_test",
|
||||
"max_unpooling_test",
|
||||
"multiply_add_test",
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/fully_connected_texture.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h"
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
@ -90,18 +90,17 @@ std::string GetFullyConnectedKernelCode(
|
||||
}
|
||||
} // namespace
|
||||
|
||||
FullyConnectedTexture::FullyConnectedTexture(const OperationDef& definition)
|
||||
FullyConnected::FullyConnected(const OperationDef& definition)
|
||||
: GPUOperation(definition) {}
|
||||
|
||||
FullyConnectedTexture::FullyConnectedTexture(FullyConnectedTexture&& kernel)
|
||||
FullyConnected::FullyConnected(FullyConnected&& kernel)
|
||||
: GPUOperation(std::move(kernel)),
|
||||
weights_(std::move(kernel.weights_)),
|
||||
biases_(std::move(kernel.biases_)),
|
||||
kernel_(std::move(kernel.kernel_)),
|
||||
work_group_size_(kernel.work_group_size_) {}
|
||||
|
||||
FullyConnectedTexture& FullyConnectedTexture::operator=(
|
||||
FullyConnectedTexture&& kernel) {
|
||||
FullyConnected& FullyConnected::operator=(FullyConnected&& kernel) {
|
||||
if (this != &kernel) {
|
||||
weights_ = std::move(kernel.weights_);
|
||||
biases_ = std::move(kernel.biases_);
|
||||
@ -112,7 +111,7 @@ FullyConnectedTexture& FullyConnectedTexture::operator=(
|
||||
return *this;
|
||||
}
|
||||
|
||||
Status FullyConnectedTexture::Compile(const CreationContext& creation_context) {
|
||||
Status FullyConnected::Compile(const CreationContext& creation_context) {
|
||||
int wg_width = 32;
|
||||
int wg_height = 4;
|
||||
int work_items;
|
||||
@ -136,7 +135,7 @@ Status FullyConnectedTexture::Compile(const CreationContext& creation_context) {
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
Status FullyConnectedTexture::AddToQueue(CLCommandQueue* queue) {
|
||||
Status FullyConnected::AddToQueue(CLCommandQueue* queue) {
|
||||
kernel_.ResetBindingCounter();
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr()));
|
||||
@ -150,11 +149,11 @@ Status FullyConnectedTexture::AddToQueue(CLCommandQueue* queue) {
|
||||
work_group_size_);
|
||||
}
|
||||
|
||||
Status CreateFullyConnectedTexture(const CreationContext& creation_context,
|
||||
const OperationDef& definition,
|
||||
const FullyConnectedAttributes& attr,
|
||||
FullyConnectedTexture* result) {
|
||||
*result = FullyConnectedTexture(definition);
|
||||
Status CreateFullyConnected(const CreationContext& creation_context,
|
||||
const OperationDef& definition,
|
||||
const FullyConnectedAttributes& attr,
|
||||
FullyConnected* result) {
|
||||
*result = FullyConnected(definition);
|
||||
RETURN_IF_ERROR(
|
||||
result->UploadWeights(attr.weights, creation_context.context));
|
||||
LinearStorageCreateInfo create_info;
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_FULLY_CONNECTED_TEXTURE_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_FULLY_CONNECTED_TEXTURE_H_
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_FULLY_CONNECTED_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_FULLY_CONNECTED_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
@ -34,24 +34,25 @@ namespace tflite {
|
||||
namespace gpu {
|
||||
namespace cl {
|
||||
|
||||
class FullyConnectedTexture : public GPUOperation {
|
||||
class FullyConnected : public GPUOperation {
|
||||
public:
|
||||
FullyConnectedTexture() = default;
|
||||
FullyConnected() = default;
|
||||
Status AddToQueue(CLCommandQueue* queue) override;
|
||||
|
||||
Status Compile(const CreationContext& creation_context) override;
|
||||
|
||||
// Move only
|
||||
FullyConnectedTexture(FullyConnectedTexture&& kernel);
|
||||
FullyConnectedTexture& operator=(FullyConnectedTexture&& kernel);
|
||||
FullyConnectedTexture(const FullyConnectedTexture&) = delete;
|
||||
FullyConnectedTexture& operator=(const FullyConnectedTexture&) = delete;
|
||||
FullyConnected(FullyConnected&& kernel);
|
||||
FullyConnected& operator=(FullyConnected&& kernel);
|
||||
FullyConnected(const FullyConnected&) = delete;
|
||||
FullyConnected& operator=(const FullyConnected&) = delete;
|
||||
|
||||
private:
|
||||
explicit FullyConnectedTexture(const OperationDef& definition);
|
||||
friend Status CreateFullyConnectedTexture(
|
||||
const CreationContext& creation_context, const OperationDef& definition,
|
||||
const FullyConnectedAttributes& attr, FullyConnectedTexture* result);
|
||||
explicit FullyConnected(const OperationDef& definition);
|
||||
friend Status CreateFullyConnected(const CreationContext& creation_context,
|
||||
const OperationDef& definition,
|
||||
const FullyConnectedAttributes& attr,
|
||||
FullyConnected* result);
|
||||
|
||||
template <DataType T>
|
||||
Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
|
||||
@ -68,7 +69,7 @@ class FullyConnectedTexture : public GPUOperation {
|
||||
};
|
||||
|
||||
template <DataType T>
|
||||
Status FullyConnectedTexture::UploadWeights(
|
||||
Status FullyConnected::UploadWeights(
|
||||
const ::tflite::gpu::Tensor<OHWI, T>& weights, CLContext* context) {
|
||||
const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4);
|
||||
const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4);
|
||||
@ -92,7 +93,7 @@ Status FullyConnectedTexture::UploadWeights(
|
||||
}
|
||||
|
||||
template <DataType T, typename S>
|
||||
void FullyConnectedTexture::RearrangeWeights(
|
||||
void FullyConnected::RearrangeWeights(
|
||||
const ::tflite::gpu::Tensor<OHWI, T>& weights, absl::Span<S> dst) {
|
||||
const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4);
|
||||
const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4);
|
||||
@ -122,13 +123,13 @@ void FullyConnectedTexture::RearrangeWeights(
|
||||
}
|
||||
}
|
||||
|
||||
Status CreateFullyConnectedTexture(const CreationContext& creation_context,
|
||||
const OperationDef& definition,
|
||||
const FullyConnectedAttributes& attr,
|
||||
FullyConnectedTexture* result);
|
||||
Status CreateFullyConnected(const CreationContext& creation_context,
|
||||
const OperationDef& definition,
|
||||
const FullyConnectedAttributes& attr,
|
||||
FullyConnected* result);
|
||||
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_FULLY_CONNECTED_TEXTURE_H_
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_FULLY_CONNECTED_H_
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/fully_connected_texture.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
@ -31,7 +31,7 @@ namespace gpu {
|
||||
namespace cl {
|
||||
namespace {
|
||||
|
||||
TEST_F(OpenCLOperationTest, FullyConnectedTexture) {
|
||||
TEST_F(OpenCLOperationTest, FullyConnected) {
|
||||
TensorFloat32 src_tensor;
|
||||
src_tensor.shape = BHWC(1, 1, 1, 4);
|
||||
src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f};
|
||||
@ -51,9 +51,9 @@ TEST_F(OpenCLOperationTest, FullyConnectedTexture) {
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
FullyConnectedTexture operation;
|
||||
ASSERT_OK(CreateFullyConnectedTexture(creation_context_, op_def, attr,
|
||||
&operation));
|
||||
FullyConnected operation;
|
||||
ASSERT_OK(
|
||||
CreateFullyConnected(creation_context_, op_def, attr, &operation));
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
|
||||
BHWC(1, 1, 1, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data, Pointwise(FloatNear(eps), {14.5f, 37.5f}));
|
@ -66,7 +66,7 @@ cc_library(
|
||||
"//tensorflow/lite/delegates/gpu/cl/kernels:conv_buffer_1x1",
|
||||
"//tensorflow/lite/delegates/gpu/cl/kernels:conv_powervr",
|
||||
"//tensorflow/lite/delegates/gpu/cl/kernels:conv_texture",
|
||||
"//tensorflow/lite/delegates/gpu/cl/kernels:fully_connected_texture",
|
||||
"//tensorflow/lite/delegates/gpu/cl/kernels:fully_connected",
|
||||
"//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
|
||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/fully_connected_texture.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
|
||||
@ -36,10 +36,10 @@ Status SelectFullyConnectedAdreno(const FullyConnectedAttributes& attr,
|
||||
RETURN_IF_ERROR(CreateConvTexture(creation_context, op_def, attr, &conv));
|
||||
*ptr = absl::make_unique<ConvTexture>(std::move(conv));
|
||||
} else {
|
||||
FullyConnectedTexture fc;
|
||||
FullyConnected fc;
|
||||
RETURN_IF_ERROR(
|
||||
CreateFullyConnectedTexture(creation_context, op_def, attr, &fc));
|
||||
*ptr = absl::make_unique<FullyConnectedTexture>(std::move(fc));
|
||||
CreateFullyConnected(creation_context, op_def, attr, &fc));
|
||||
*ptr = absl::make_unique<FullyConnected>(std::move(fc));
|
||||
}
|
||||
return OkStatus();
|
||||
}
|
||||
@ -53,10 +53,10 @@ Status SelectFullyConnectedPowerVR(const FullyConnectedAttributes& attr,
|
||||
RETURN_IF_ERROR(CreateConvPowerVR(creation_context, op_def, attr, &conv));
|
||||
*ptr = absl::make_unique<ConvPowerVR>(std::move(conv));
|
||||
} else {
|
||||
FullyConnectedTexture fc;
|
||||
FullyConnected fc;
|
||||
RETURN_IF_ERROR(
|
||||
CreateFullyConnectedTexture(creation_context, op_def, attr, &fc));
|
||||
*ptr = absl::make_unique<FullyConnectedTexture>(std::move(fc));
|
||||
CreateFullyConnected(creation_context, op_def, attr, &fc));
|
||||
*ptr = absl::make_unique<FullyConnected>(std::move(fc));
|
||||
}
|
||||
return OkStatus();
|
||||
}
|
||||
@ -77,10 +77,10 @@ Status SelectFullyConnectedMali(const FullyConnectedAttributes& attr,
|
||||
*ptr = absl::make_unique<ConvTexture>(std::move(conv));
|
||||
}
|
||||
} else {
|
||||
FullyConnectedTexture fc;
|
||||
FullyConnected fc;
|
||||
RETURN_IF_ERROR(
|
||||
CreateFullyConnectedTexture(creation_context, op_def, attr, &fc));
|
||||
*ptr = absl::make_unique<FullyConnectedTexture>(std::move(fc));
|
||||
CreateFullyConnected(creation_context, op_def, attr, &fc));
|
||||
*ptr = absl::make_unique<FullyConnected>(std::move(fc));
|
||||
}
|
||||
return OkStatus();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user