FullyConnectedTexture renamed to FullyConnected so as to support all storage types.
PiperOrigin-RevId: 296085184 Change-Id: I3ea56947c7ddf70370c10b4375903880fd3d83c9
This commit is contained in:
parent
3aecbb9fb1
commit
38168415ea
@ -731,9 +731,9 @@ cc_library(
|
|||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "fully_connected_texture",
|
name = "fully_connected",
|
||||||
srcs = ["fully_connected_texture.cc"],
|
srcs = ["fully_connected.cc"],
|
||||||
hdrs = ["fully_connected_texture.h"],
|
hdrs = ["fully_connected.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":gpu_operation",
|
":gpu_operation",
|
||||||
":util",
|
":util",
|
||||||
@ -751,8 +751,8 @@ cc_library(
|
|||||||
)
|
)
|
||||||
|
|
||||||
cc_test(
|
cc_test(
|
||||||
name = "fully_connected_texture_test",
|
name = "fully_connected_test",
|
||||||
srcs = ["fully_connected_texture_test.cc"],
|
srcs = ["fully_connected_test.cc"],
|
||||||
linkstatic = True,
|
linkstatic = True,
|
||||||
tags = tf_gpu_tests_tags() + [
|
tags = tf_gpu_tests_tags() + [
|
||||||
"linux",
|
"linux",
|
||||||
@ -760,7 +760,7 @@ cc_test(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":cl_test",
|
":cl_test",
|
||||||
":fully_connected_texture",
|
":fully_connected",
|
||||||
"//tensorflow/lite/delegates/gpu/cl:tensor",
|
"//tensorflow/lite/delegates/gpu/cl:tensor",
|
||||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||||
"//tensorflow/lite/delegates/gpu/common:status",
|
"//tensorflow/lite/delegates/gpu/common:status",
|
||||||
@ -1386,7 +1386,7 @@ test_suite(
|
|||||||
"depth_wise_conv_3x3_test",
|
"depth_wise_conv_3x3_test",
|
||||||
"depth_wise_conv_test",
|
"depth_wise_conv_test",
|
||||||
"elementwise_test",
|
"elementwise_test",
|
||||||
"fully_connected_texture_test",
|
"fully_connected_test",
|
||||||
"lstm_test",
|
"lstm_test",
|
||||||
"max_unpooling_test",
|
"max_unpooling_test",
|
||||||
"multiply_add_test",
|
"multiply_add_test",
|
||||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/fully_connected_texture.h"
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h"
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
@ -90,18 +90,17 @@ std::string GetFullyConnectedKernelCode(
|
|||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
FullyConnectedTexture::FullyConnectedTexture(const OperationDef& definition)
|
FullyConnected::FullyConnected(const OperationDef& definition)
|
||||||
: GPUOperation(definition) {}
|
: GPUOperation(definition) {}
|
||||||
|
|
||||||
FullyConnectedTexture::FullyConnectedTexture(FullyConnectedTexture&& kernel)
|
FullyConnected::FullyConnected(FullyConnected&& kernel)
|
||||||
: GPUOperation(std::move(kernel)),
|
: GPUOperation(std::move(kernel)),
|
||||||
weights_(std::move(kernel.weights_)),
|
weights_(std::move(kernel.weights_)),
|
||||||
biases_(std::move(kernel.biases_)),
|
biases_(std::move(kernel.biases_)),
|
||||||
kernel_(std::move(kernel.kernel_)),
|
kernel_(std::move(kernel.kernel_)),
|
||||||
work_group_size_(kernel.work_group_size_) {}
|
work_group_size_(kernel.work_group_size_) {}
|
||||||
|
|
||||||
FullyConnectedTexture& FullyConnectedTexture::operator=(
|
FullyConnected& FullyConnected::operator=(FullyConnected&& kernel) {
|
||||||
FullyConnectedTexture&& kernel) {
|
|
||||||
if (this != &kernel) {
|
if (this != &kernel) {
|
||||||
weights_ = std::move(kernel.weights_);
|
weights_ = std::move(kernel.weights_);
|
||||||
biases_ = std::move(kernel.biases_);
|
biases_ = std::move(kernel.biases_);
|
||||||
@ -112,7 +111,7 @@ FullyConnectedTexture& FullyConnectedTexture::operator=(
|
|||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status FullyConnectedTexture::Compile(const CreationContext& creation_context) {
|
Status FullyConnected::Compile(const CreationContext& creation_context) {
|
||||||
int wg_width = 32;
|
int wg_width = 32;
|
||||||
int wg_height = 4;
|
int wg_height = 4;
|
||||||
int work_items;
|
int work_items;
|
||||||
@ -136,7 +135,7 @@ Status FullyConnectedTexture::Compile(const CreationContext& creation_context) {
|
|||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status FullyConnectedTexture::AddToQueue(CLCommandQueue* queue) {
|
Status FullyConnected::AddToQueue(CLCommandQueue* queue) {
|
||||||
kernel_.ResetBindingCounter();
|
kernel_.ResetBindingCounter();
|
||||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
||||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr()));
|
RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr()));
|
||||||
@ -150,11 +149,11 @@ Status FullyConnectedTexture::AddToQueue(CLCommandQueue* queue) {
|
|||||||
work_group_size_);
|
work_group_size_);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status CreateFullyConnectedTexture(const CreationContext& creation_context,
|
Status CreateFullyConnected(const CreationContext& creation_context,
|
||||||
const OperationDef& definition,
|
const OperationDef& definition,
|
||||||
const FullyConnectedAttributes& attr,
|
const FullyConnectedAttributes& attr,
|
||||||
FullyConnectedTexture* result) {
|
FullyConnected* result) {
|
||||||
*result = FullyConnectedTexture(definition);
|
*result = FullyConnected(definition);
|
||||||
RETURN_IF_ERROR(
|
RETURN_IF_ERROR(
|
||||||
result->UploadWeights(attr.weights, creation_context.context));
|
result->UploadWeights(attr.weights, creation_context.context));
|
||||||
LinearStorageCreateInfo create_info;
|
LinearStorageCreateInfo create_info;
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_FULLY_CONNECTED_TEXTURE_H_
|
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_FULLY_CONNECTED_H_
|
||||||
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_FULLY_CONNECTED_TEXTURE_H_
|
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_FULLY_CONNECTED_H_
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -34,24 +34,25 @@ namespace tflite {
|
|||||||
namespace gpu {
|
namespace gpu {
|
||||||
namespace cl {
|
namespace cl {
|
||||||
|
|
||||||
class FullyConnectedTexture : public GPUOperation {
|
class FullyConnected : public GPUOperation {
|
||||||
public:
|
public:
|
||||||
FullyConnectedTexture() = default;
|
FullyConnected() = default;
|
||||||
Status AddToQueue(CLCommandQueue* queue) override;
|
Status AddToQueue(CLCommandQueue* queue) override;
|
||||||
|
|
||||||
Status Compile(const CreationContext& creation_context) override;
|
Status Compile(const CreationContext& creation_context) override;
|
||||||
|
|
||||||
// Move only
|
// Move only
|
||||||
FullyConnectedTexture(FullyConnectedTexture&& kernel);
|
FullyConnected(FullyConnected&& kernel);
|
||||||
FullyConnectedTexture& operator=(FullyConnectedTexture&& kernel);
|
FullyConnected& operator=(FullyConnected&& kernel);
|
||||||
FullyConnectedTexture(const FullyConnectedTexture&) = delete;
|
FullyConnected(const FullyConnected&) = delete;
|
||||||
FullyConnectedTexture& operator=(const FullyConnectedTexture&) = delete;
|
FullyConnected& operator=(const FullyConnected&) = delete;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
explicit FullyConnectedTexture(const OperationDef& definition);
|
explicit FullyConnected(const OperationDef& definition);
|
||||||
friend Status CreateFullyConnectedTexture(
|
friend Status CreateFullyConnected(const CreationContext& creation_context,
|
||||||
const CreationContext& creation_context, const OperationDef& definition,
|
const OperationDef& definition,
|
||||||
const FullyConnectedAttributes& attr, FullyConnectedTexture* result);
|
const FullyConnectedAttributes& attr,
|
||||||
|
FullyConnected* result);
|
||||||
|
|
||||||
template <DataType T>
|
template <DataType T>
|
||||||
Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
|
Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
|
||||||
@ -68,7 +69,7 @@ class FullyConnectedTexture : public GPUOperation {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <DataType T>
|
template <DataType T>
|
||||||
Status FullyConnectedTexture::UploadWeights(
|
Status FullyConnected::UploadWeights(
|
||||||
const ::tflite::gpu::Tensor<OHWI, T>& weights, CLContext* context) {
|
const ::tflite::gpu::Tensor<OHWI, T>& weights, CLContext* context) {
|
||||||
const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4);
|
const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4);
|
||||||
const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4);
|
const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4);
|
||||||
@ -92,7 +93,7 @@ Status FullyConnectedTexture::UploadWeights(
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <DataType T, typename S>
|
template <DataType T, typename S>
|
||||||
void FullyConnectedTexture::RearrangeWeights(
|
void FullyConnected::RearrangeWeights(
|
||||||
const ::tflite::gpu::Tensor<OHWI, T>& weights, absl::Span<S> dst) {
|
const ::tflite::gpu::Tensor<OHWI, T>& weights, absl::Span<S> dst) {
|
||||||
const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4);
|
const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4);
|
||||||
const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4);
|
const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4);
|
||||||
@ -122,13 +123,13 @@ void FullyConnectedTexture::RearrangeWeights(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Status CreateFullyConnectedTexture(const CreationContext& creation_context,
|
Status CreateFullyConnected(const CreationContext& creation_context,
|
||||||
const OperationDef& definition,
|
const OperationDef& definition,
|
||||||
const FullyConnectedAttributes& attr,
|
const FullyConnectedAttributes& attr,
|
||||||
FullyConnectedTexture* result);
|
FullyConnected* result);
|
||||||
|
|
||||||
} // namespace cl
|
} // namespace cl
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_FULLY_CONNECTED_TEXTURE_H_
|
#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_FULLY_CONNECTED_H_
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/fully_connected_texture.h"
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h"
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -31,7 +31,7 @@ namespace gpu {
|
|||||||
namespace cl {
|
namespace cl {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
TEST_F(OpenCLOperationTest, FullyConnectedTexture) {
|
TEST_F(OpenCLOperationTest, FullyConnected) {
|
||||||
TensorFloat32 src_tensor;
|
TensorFloat32 src_tensor;
|
||||||
src_tensor.shape = BHWC(1, 1, 1, 4);
|
src_tensor.shape = BHWC(1, 1, 1, 4);
|
||||||
src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f};
|
src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f};
|
||||||
@ -51,9 +51,9 @@ TEST_F(OpenCLOperationTest, FullyConnectedTexture) {
|
|||||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||||
TensorFloat32 dst_tensor;
|
TensorFloat32 dst_tensor;
|
||||||
FullyConnectedTexture operation;
|
FullyConnected operation;
|
||||||
ASSERT_OK(CreateFullyConnectedTexture(creation_context_, op_def, attr,
|
ASSERT_OK(
|
||||||
&operation));
|
CreateFullyConnected(creation_context_, op_def, attr, &operation));
|
||||||
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
|
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
|
||||||
BHWC(1, 1, 1, 2), &dst_tensor));
|
BHWC(1, 1, 1, 2), &dst_tensor));
|
||||||
EXPECT_THAT(dst_tensor.data, Pointwise(FloatNear(eps), {14.5f, 37.5f}));
|
EXPECT_THAT(dst_tensor.data, Pointwise(FloatNear(eps), {14.5f, 37.5f}));
|
@ -66,7 +66,7 @@ cc_library(
|
|||||||
"//tensorflow/lite/delegates/gpu/cl/kernels:conv_buffer_1x1",
|
"//tensorflow/lite/delegates/gpu/cl/kernels:conv_buffer_1x1",
|
||||||
"//tensorflow/lite/delegates/gpu/cl/kernels:conv_powervr",
|
"//tensorflow/lite/delegates/gpu/cl/kernels:conv_powervr",
|
||||||
"//tensorflow/lite/delegates/gpu/cl/kernels:conv_texture",
|
"//tensorflow/lite/delegates/gpu/cl/kernels:conv_texture",
|
||||||
"//tensorflow/lite/delegates/gpu/cl/kernels:fully_connected_texture",
|
"//tensorflow/lite/delegates/gpu/cl/kernels:fully_connected",
|
||||||
"//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
|
"//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
|
||||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||||
"//tensorflow/lite/delegates/gpu/common:status",
|
"//tensorflow/lite/delegates/gpu/common:status",
|
||||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h"
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h"
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.h"
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/fully_connected_texture.h"
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
|
|
||||||
@ -36,10 +36,10 @@ Status SelectFullyConnectedAdreno(const FullyConnectedAttributes& attr,
|
|||||||
RETURN_IF_ERROR(CreateConvTexture(creation_context, op_def, attr, &conv));
|
RETURN_IF_ERROR(CreateConvTexture(creation_context, op_def, attr, &conv));
|
||||||
*ptr = absl::make_unique<ConvTexture>(std::move(conv));
|
*ptr = absl::make_unique<ConvTexture>(std::move(conv));
|
||||||
} else {
|
} else {
|
||||||
FullyConnectedTexture fc;
|
FullyConnected fc;
|
||||||
RETURN_IF_ERROR(
|
RETURN_IF_ERROR(
|
||||||
CreateFullyConnectedTexture(creation_context, op_def, attr, &fc));
|
CreateFullyConnected(creation_context, op_def, attr, &fc));
|
||||||
*ptr = absl::make_unique<FullyConnectedTexture>(std::move(fc));
|
*ptr = absl::make_unique<FullyConnected>(std::move(fc));
|
||||||
}
|
}
|
||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
@ -53,10 +53,10 @@ Status SelectFullyConnectedPowerVR(const FullyConnectedAttributes& attr,
|
|||||||
RETURN_IF_ERROR(CreateConvPowerVR(creation_context, op_def, attr, &conv));
|
RETURN_IF_ERROR(CreateConvPowerVR(creation_context, op_def, attr, &conv));
|
||||||
*ptr = absl::make_unique<ConvPowerVR>(std::move(conv));
|
*ptr = absl::make_unique<ConvPowerVR>(std::move(conv));
|
||||||
} else {
|
} else {
|
||||||
FullyConnectedTexture fc;
|
FullyConnected fc;
|
||||||
RETURN_IF_ERROR(
|
RETURN_IF_ERROR(
|
||||||
CreateFullyConnectedTexture(creation_context, op_def, attr, &fc));
|
CreateFullyConnected(creation_context, op_def, attr, &fc));
|
||||||
*ptr = absl::make_unique<FullyConnectedTexture>(std::move(fc));
|
*ptr = absl::make_unique<FullyConnected>(std::move(fc));
|
||||||
}
|
}
|
||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
@ -77,10 +77,10 @@ Status SelectFullyConnectedMali(const FullyConnectedAttributes& attr,
|
|||||||
*ptr = absl::make_unique<ConvTexture>(std::move(conv));
|
*ptr = absl::make_unique<ConvTexture>(std::move(conv));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
FullyConnectedTexture fc;
|
FullyConnected fc;
|
||||||
RETURN_IF_ERROR(
|
RETURN_IF_ERROR(
|
||||||
CreateFullyConnectedTexture(creation_context, op_def, attr, &fc));
|
CreateFullyConnected(creation_context, op_def, attr, &fc));
|
||||||
*ptr = absl::make_unique<FullyConnectedTexture>(std::move(fc));
|
*ptr = absl::make_unique<FullyConnected>(std::move(fc));
|
||||||
}
|
}
|
||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user