Using common FullyConnected task and selector for Metal.
PiperOrigin-RevId: 355410583 Change-Id: Ieb7b9a337dc45b290df6dfb4a2176dced07ff75c
This commit is contained in:
parent
54bcf2919f
commit
8130665856
@ -65,6 +65,7 @@ std::string GetCommonOpenCLDefines(CalculationsPrecision precision) {
|
||||
result += "#define GROUP_SIZE_0 get_local_size(0)\n";
|
||||
result += "#define GROUP_SIZE_1 get_local_size(1)\n";
|
||||
result += "#define GROUP_SIZE_2 get_local_size(2)\n";
|
||||
result += "#define SIMD_LOCAL_MEM_BARRIER barrier(CLK_LOCAL_MEM_FENCE)\n";
|
||||
result += "#define LOCAL_MEM_BARRIER barrier(CLK_LOCAL_MEM_FENCE)\n";
|
||||
result += "#define MAIN_FUNCTION __kernel void main_function\n";
|
||||
result += "#define INIT_FLOAT(value) (float)(value)\n";
|
||||
|
@ -85,8 +85,8 @@ std::unique_ptr<GPUOperation> SelectFullyConnected(
|
||||
const OperationDef& op_def, int batch_size) {
|
||||
if (gpu_info.IsAdreno()) {
|
||||
return SelectFullyConnectedAdreno(attr, gpu_info, op_def, batch_size);
|
||||
} else if (gpu_info.IsPowerVR() || gpu_info.IsAMD() ||
|
||||
gpu_info.IsNvidia() || gpu_info.IsIntel()) {
|
||||
} else if (gpu_info.IsPowerVR() || gpu_info.IsAMD() || gpu_info.IsNvidia() ||
|
||||
gpu_info.IsIntel() || gpu_info.IsApple()) {
|
||||
return SelectFullyConnectedPowerVR(attr, gpu_info, op_def, batch_size);
|
||||
} else if (gpu_info.IsMali()) {
|
||||
return SelectFullyConnectedMali(attr, gpu_info, op_def, batch_size);
|
||||
|
@ -374,6 +374,7 @@ cc_library(
|
||||
"//tensorflow/lite/delegates/gpu/common:util",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:buffer_desc",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:storage_type_util",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:tensor_desc",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:tensor_linear_desc",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:texture2d_desc",
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/storage_type_util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/tensor_linear_desc.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||
@ -31,7 +32,8 @@ namespace gpu {
|
||||
|
||||
namespace {
|
||||
bool UseBufferForWeights(const GpuInfo& gpu_info) {
|
||||
return gpu_info.IsAdreno() || gpu_info.IsAMD() || gpu_info.IsMali();
|
||||
return gpu_info.IsAdreno() || gpu_info.IsAMD() || gpu_info.IsMali() ||
|
||||
gpu_info.IsApple();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
@ -46,11 +48,8 @@ FullyConnected::FullyConnected(const OperationDef& definition,
|
||||
} else {
|
||||
work_group_size_ = int3(32, 4, 1);
|
||||
}
|
||||
} else if (gpu_info.IsIntel()) {
|
||||
work_group_size_ = int3(8, 4, 1);
|
||||
} else if (gpu_info.IsNvidia()) {
|
||||
work_group_size_ = int3(8, 4, 1);
|
||||
} else if (gpu_info.IsPowerVR()) {
|
||||
} else if (gpu_info.IsIntel() || gpu_info.IsNvidia() ||
|
||||
gpu_info.IsPowerVR() || gpu_info.IsApple()) {
|
||||
work_group_size_ = int3(8, 4, 1);
|
||||
} else {
|
||||
work_group_size_ = int3(16, 4, 1);
|
||||
@ -76,6 +75,11 @@ FullyConnected& FullyConnected::operator=(FullyConnected&& kernel) {
|
||||
|
||||
std::string FullyConnected::GetFullyConnectedKernelCode(
|
||||
const OperationDef& op_def, const GpuInfo& gpu_info) {
|
||||
const int wg_total_size = work_group_size_.x * work_group_size_.y;
|
||||
const std::string barrier =
|
||||
wg_total_size == 32 && gpu_info.IsWaveSizeEqualTo32()
|
||||
? "SIMD_LOCAL_MEM_BARRIER"
|
||||
: "LOCAL_MEM_BARRIER";
|
||||
AddSrcTensor("src_tensor", op_def.src_tensors[0]);
|
||||
AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
|
||||
|
||||
@ -127,7 +131,9 @@ std::string FullyConnected::GetFullyConnectedKernelCode(
|
||||
}
|
||||
__local ACCUM_FLT4 temp[WG_X][WG_Y];
|
||||
temp[tid.x][tid.y] = s;
|
||||
LOCAL_MEM_BARRIER;
|
||||
)";
|
||||
c += " " + barrier + ";\n";
|
||||
c += R"(
|
||||
if (gid >= args.dst_tensor.Slices()) {
|
||||
return;
|
||||
}
|
||||
@ -157,6 +163,10 @@ FullyConnected CreateFullyConnected(const GpuInfo& gpu_info,
|
||||
TensorLinearDescriptor desc;
|
||||
desc.storage_type = gpu_info.SupportsImages() ? LinearStorageType::TEXTURE_2D
|
||||
: LinearStorageType::BUFFER;
|
||||
if (gpu_info.IsApple()) {
|
||||
desc.storage_type =
|
||||
DeduceLinearStorageType(definition.GetPrimaryStorageType());
|
||||
}
|
||||
desc.element_type = definition.GetDataType();
|
||||
desc.UploadLinearData(attr.bias);
|
||||
result.args_.AddObject(
|
||||
|
@ -136,6 +136,7 @@ absl::Status ComputeTask::CompileProgram(MetalDevice* device,
|
||||
@"TO_ACCUM_TYPE" : toAccumulatorType4,
|
||||
@"TO_FLT4" : [NSString stringWithFormat:@"%@4", storageType],
|
||||
@"SIMDGROUP_BARRIER" : barrier,
|
||||
@"SIMD_LOCAL_MEM_BARRIER" : barrier,
|
||||
@"MAIN_FUNCTION" : @"\"kernel void ComputeFunction\"",
|
||||
@"GLOBAL_ID_0" : @"static_cast<int>(reserved_gid.x)",
|
||||
@"GLOBAL_ID_1" : @"static_cast<int>(reserved_gid.y)",
|
||||
|
@ -20,7 +20,6 @@ cc_library(
|
||||
deps = [
|
||||
":conv",
|
||||
":depthwise_conv",
|
||||
":fully_connected",
|
||||
":transpose_conv",
|
||||
":winograd",
|
||||
],
|
||||
@ -180,30 +179,12 @@ ios_unit_test(
|
||||
deps = [":elementwise_test_lib"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "fully_connected",
|
||||
srcs = ["fully_connected.cc"],
|
||||
hdrs = ["fully_connected.h"],
|
||||
deps = [
|
||||
":util",
|
||||
"//tensorflow/lite/delegates/gpu/common:gpu_info",
|
||||
"//tensorflow/lite/delegates/gpu/common:model",
|
||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||
"//tensorflow/lite/delegates/gpu/common:types",
|
||||
"//tensorflow/lite/delegates/gpu/common:util",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:buffer_desc",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "fully_connected_test_lib",
|
||||
testonly = 1,
|
||||
srcs = ["fully_connected_test.mm"],
|
||||
sdk_frameworks = ["XCTest"],
|
||||
deps = [
|
||||
":fully_connected",
|
||||
":test_util",
|
||||
"//tensorflow/lite/delegates/gpu/common/tasks:fully_connected_test_util",
|
||||
],
|
||||
|
@ -1,182 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/buffer_desc.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace metal {
|
||||
namespace {
|
||||
|
||||
std::string GetFullyConnectedCode(const GpuInfo& gpu_info, int src_channels,
|
||||
int dst_channels) {
|
||||
bool shared_memory = gpu_info.IsApple() &&
|
||||
gpu_info.apple_info.IsLocalMemoryPreferredOverGlobal();
|
||||
const std::string barrier = gpu_info.IsWaveSizeEqualTo32()
|
||||
? "SIMDGROUP_BARRIER"
|
||||
: "threadgroup_barrier";
|
||||
const int src_depth = DivideRoundUp(src_channels, 4);
|
||||
std::stringstream code;
|
||||
code << R"(
|
||||
kernel void ComputeFunction(
|
||||
$$0
|
||||
uint3 tid[[thread_position_in_threadgroup]],
|
||||
uint tid_index[[thread_index_in_threadgroup]],
|
||||
uint3 ugid[[thread_position_in_grid]]) {
|
||||
|
||||
)";
|
||||
if (shared_memory) {
|
||||
code << R"(
|
||||
float summa = 0.0f;
|
||||
threadgroup FLT4 local_vector[32];
|
||||
for (int j = 0; j < args.src_depth_sub_groups; ++j) {
|
||||
local_vector[tid_index] = j * 32 + tid_index >= args.src_tensor.Slices() ?
|
||||
FLT4(0.0f) : args.src_tensor.Read(0, 0, j * 32 + tid_index);
|
||||
$0(mem_flags::mem_threadgroup);
|
||||
for (uint i = 0, counter = j * 32 + tid.y * 8; i < 8; ++i, ++counter) {
|
||||
summa += dot(local_vector[tid.y * 8 + i], args.weights.Read(counter * args.dst_channels_alignedx8 + ugid.x));
|
||||
}
|
||||
$0(mem_flags::mem_none);
|
||||
}
|
||||
)";
|
||||
} else {
|
||||
code << R"(
|
||||
float summa = 0.0f;
|
||||
int counter = int(ugid.y) * args.src_depth_sub_groups;
|
||||
for (int i = 0; i < args.src_depth_sub_groups; ++i, ++counter) {
|
||||
)";
|
||||
if (src_depth % 4 != 0) {
|
||||
code << " if (counter >= args.src_tensor.Slices()) continue;"
|
||||
<< std::endl;
|
||||
}
|
||||
code << " summa += dot(args.src_tensor.Read(0, 0, counter), "
|
||||
"args.weights.Read(counter * "
|
||||
"args.dst_channels_alignedx8 + ugid.x));"
|
||||
<< std::endl;
|
||||
code << " }" << std::endl;
|
||||
}
|
||||
code << R"(
|
||||
|
||||
threadgroup float temp[8][4];
|
||||
temp[tid.x][tid.y] = summa;
|
||||
$0(mem_flags::mem_threadgroup);
|
||||
if (tid.y == 0) {
|
||||
summa += temp[tid.x][1];
|
||||
summa += temp[tid.x][2];
|
||||
summa += temp[tid.x][3];
|
||||
temp[tid.x][0] = summa;
|
||||
}
|
||||
$0(mem_flags::mem_threadgroup);
|
||||
int dst_s = ugid.x / 4;
|
||||
if (tid.y == 0 && tid.x % 4 == 0 && dst_s < args.dst_tensor.Slices()) {
|
||||
FLT4 value = FLT4(temp[tid.x][0], temp[tid.x + 1][0], temp[tid.x + 2][0], temp[tid.x + 3][0]) +
|
||||
args.bias.Read(dst_s);
|
||||
args.dst_tensor.Write(value, 0, 0, dst_s);
|
||||
}
|
||||
}
|
||||
)";
|
||||
return absl::Substitute(code.str(), barrier);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
int3 FullyConnected::GetGridSize() const {
|
||||
const int dst_channels_aligned = AlignByN(dst_[0]->Channels(), 8);
|
||||
return int3(dst_channels_aligned, 1, 1);
|
||||
}
|
||||
|
||||
FullyConnected CreateFullyConnected(const GpuInfo& gpu_info,
|
||||
const OperationDef& definition,
|
||||
const FullyConnectedAttributes& attr) {
|
||||
FullyConnected desc(definition);
|
||||
desc.code_ = GetFullyConnectedCode(gpu_info, attr.weights.shape.i,
|
||||
attr.weights.shape.o);
|
||||
|
||||
bool shared_memory = gpu_info.IsApple() &&
|
||||
gpu_info.apple_info.IsLocalMemoryPreferredOverGlobal();
|
||||
const int src_depth = DivideRoundUp(attr.weights.shape.i, 4);
|
||||
const int src_depth_sub_groups = shared_memory ? DivideRoundUp(src_depth, 32)
|
||||
: DivideRoundUp(src_depth, 4);
|
||||
desc.args_.AddInt("dst_channels_alignedx8",
|
||||
AlignByN(attr.weights.shape.o, 8));
|
||||
desc.args_.AddInt("src_depth_sub_groups", src_depth_sub_groups);
|
||||
|
||||
desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
|
||||
desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
|
||||
|
||||
const int src_depth_aligned = AlignByN(src_depth, shared_memory ? 32 : 4);
|
||||
const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 8);
|
||||
|
||||
int counter = 0;
|
||||
std::vector<float> filters_reordered(dst_channels_aligned *
|
||||
src_depth_aligned * 4);
|
||||
for (int j = 0; j < src_depth_aligned; ++j) {
|
||||
for (int i = 0; i < dst_channels_aligned; ++i) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
if (j * 4 + k >= attr.weights.shape.i || i >= attr.weights.shape.o) {
|
||||
filters_reordered[counter++] = 0.0f;
|
||||
} else {
|
||||
const int f_index =
|
||||
attr.weights.shape.LinearIndex({i, 0, 0, j * 4 + k});
|
||||
filters_reordered[counter++] = attr.weights.data[f_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto data_type = DeduceDataTypeFromPrecision(definition.precision);
|
||||
BufferDescriptor weights_desc;
|
||||
weights_desc.element_type = data_type;
|
||||
weights_desc.element_size = 4;
|
||||
weights_desc.data = GetByteBufferConverted(filters_reordered, data_type);
|
||||
weights_desc.size = weights_desc.data.size();
|
||||
|
||||
desc.args_.AddObject(
|
||||
"weights", absl::make_unique<BufferDescriptor>(std::move(weights_desc)));
|
||||
|
||||
BufferDescriptor bias_desc;
|
||||
bias_desc.element_type = data_type;
|
||||
bias_desc.element_size = 4;
|
||||
bias_desc.data = GetByteBufferConvertedResized(attr.bias.data, data_type,
|
||||
dst_channels_aligned);
|
||||
bias_desc.size = bias_desc.data.size();
|
||||
|
||||
desc.args_.AddObject(
|
||||
"bias", absl::make_unique<BufferDescriptor>(std::move(bias_desc)));
|
||||
|
||||
desc.work_group_size_ = int3(8, 4, 1);
|
||||
|
||||
return desc;
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
@ -1,63 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_FULLY_CONNECTED_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_FULLY_CONNECTED_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace metal {
|
||||
|
||||
class FullyConnected : public GPUOperation {
|
||||
public:
|
||||
FullyConnected() = default;
|
||||
void GetPossibleKernelWorkGroups(
|
||||
TuningType tuning_type, const GpuInfo& gpu_info,
|
||||
const KernelInfo& kernel_info,
|
||||
std::vector<int3>* work_groups) const override {
|
||||
work_groups->push_back(work_group_size_);
|
||||
}
|
||||
int3 GetGridSize() const override;
|
||||
|
||||
// Move only
|
||||
FullyConnected(FullyConnected&& kernel) = default;
|
||||
FullyConnected& operator=(FullyConnected&& kernel) = default;
|
||||
FullyConnected(const FullyConnected&) = delete;
|
||||
FullyConnected& operator=(const FullyConnected&) = delete;
|
||||
|
||||
private:
|
||||
explicit FullyConnected(const OperationDef& definition)
|
||||
: GPUOperation(definition) {}
|
||||
friend FullyConnected CreateFullyConnected(
|
||||
const GpuInfo& gpu_info, const OperationDef& definition,
|
||||
const FullyConnectedAttributes& attr);
|
||||
};
|
||||
|
||||
FullyConnected CreateFullyConnected(const GpuInfo& gpu_info,
|
||||
const OperationDef& definition,
|
||||
const FullyConnectedAttributes& attr);
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_FULLY_CONNECTED_H_
|
@ -24,7 +24,6 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/common/tasks/fully_connected_test_util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
|
||||
|
||||
@interface FullyConnectedMetalTest : XCTestCase
|
||||
@ -34,50 +33,6 @@ limitations under the License.
|
||||
tflite::gpu::metal::MetalExecutionEnvironment exec_env_;
|
||||
}
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace metal {
|
||||
|
||||
absl::Status FullyConnectedTest(TestExecutionEnvironment* env) {
|
||||
TensorFloat32 src_tensor;
|
||||
src_tensor.shape = BHWC(1, 1, 1, 2);
|
||||
src_tensor.data = {1, 2};
|
||||
|
||||
FullyConnectedAttributes attr;
|
||||
attr.weights.shape = OHWI(4, 1, 1, 2);
|
||||
attr.weights.data = {1, 2, 3, 4, 5, 6, 7, 8};
|
||||
attr.bias.shape = Linear(4);
|
||||
attr.bias.data = {1, 2, 3, 4};
|
||||
|
||||
for (auto storage : env->GetSupportedStorages()) {
|
||||
for (auto precision : env->GetSupportedPrecisions()) {
|
||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
|
||||
OperationDef op_def;
|
||||
op_def.precision = precision;
|
||||
auto data_type = DeduceDataTypeFromPrecision(precision);
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
FullyConnected operation = CreateFullyConnected(env->GetGpuInfo(), op_def, attr);
|
||||
RETURN_IF_ERROR(env->ExecuteGPUOperation(
|
||||
src_tensor, absl::make_unique<FullyConnected>(std::move(operation)), BHWC(1, 1, 1, 4),
|
||||
&dst_tensor));
|
||||
RETURN_IF_ERROR(PointWiseNear({6, 13, 20, 27}, dst_tensor.data, eps))
|
||||
<< "Failed using precision " << ToString(precision);
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
||||
- (void)testFullyConnectedMetal {
|
||||
auto status = tflite::gpu::metal::FullyConnectedTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testFullyConnected {
|
||||
auto status = tflite::gpu::FullyConnectedTest(&exec_env_);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
|
@ -18,6 +18,7 @@ cc_library(
|
||||
"//tensorflow/lite/delegates/gpu/common:util",
|
||||
"//tensorflow/lite/delegates/gpu/common:winograd_util",
|
||||
"//tensorflow/lite/delegates/gpu/common/selectors:default_selector",
|
||||
"//tensorflow/lite/delegates/gpu/common/selectors:fully_connected_selector",
|
||||
"//tensorflow/lite/delegates/gpu/common/selectors:subgraph",
|
||||
"//tensorflow/lite/delegates/gpu/common/tasks:add",
|
||||
"//tensorflow/lite/delegates/gpu/common/tasks:concat_xy",
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/common/model_hints.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/selectors/default_selector.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/selectors/fully_connected_selector.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/selectors/subgraph.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
@ -53,7 +54,6 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/common/winograd_util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/conv.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/kernels/winograd.h"
|
||||
|
||||
@ -399,11 +399,11 @@ absl::Status GPUOperationFromNode(const GpuInfo& gpu_info,
|
||||
gpu_info);
|
||||
break;
|
||||
case OperationType::FULLY_CONNECTED: {
|
||||
FullyConnected conv_op = CreateFullyConnected(
|
||||
gpu_info, op_def,
|
||||
absl::any_cast<FullyConnectedAttributes>(node.operation.attributes));
|
||||
*gpu_op = absl::make_unique<FullyConnected>(std::move(conv_op));
|
||||
break;
|
||||
auto attr =
|
||||
absl::any_cast<FullyConnectedAttributes>(node.operation.attributes);
|
||||
*gpu_op = SelectFullyConnected(attr, gpu_info, op_def,
|
||||
inputs[0]->tensor.shape.b);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
case OperationType::LSTM: {
|
||||
*gpu_op = SelectLSTM(op_def, gpu_info);
|
||||
|
Loading…
Reference in New Issue
Block a user