Add one example of sparse FullyConnected kernel.
PiperOrigin-RevId: 293450452
Change-Id: I4ad38bc8f871291a6db11b80b44660f5ae9e6421
tensorflow/lite/kernels/fully_connected.cc
@@ -23,10 +23,12 @@ limitations under the License.
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/cpu_backend_context.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
 #include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
 #include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/lite/kernels/internal/reference/sparse_ops/fully_connected.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
@@ -38,11 +40,24 @@ namespace ops {
 namespace builtin {
 namespace fully_connected {
 
+namespace {
+bool SupportedSparsityFormat(const TfLiteSparsity& sparsity) {
+  if (sparsity.dim_metadata[0].format == kTfLiteDimSparseCSR &&
+      sparsity.dim_metadata[1].format == kTfLiteDimSparseCSR) {
+    return true;
+  }
+
+  return false;
+}
+}  // namespace
+
 // This file has four implementations of FullyConnected
 enum KernelType {
   kReference,
   kGenericOptimized,
   kLegacyPie,  // Legacy path used by the PIE team and related clients.
+  kSparseReference,
+  kSparseOptimized,
 };
 
 struct OpData {
@@ -574,11 +589,41 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
     FullyConnectedParams op_params;
     op_params.float_activation_min = output_activation_min;
     op_params.float_activation_max = output_activation_max;
 
     reference_ops::FullyConnected(
         op_params, GetTensorShape(input), GetTensorData<float>(input),
         GetTensorShape(filter), GetTensorData<float>(filter),
         GetTensorShape(bias), GetTensorData<float>(bias),
         GetTensorShape(output), GetTensorData<float>(output));
+  } else if (kernel_type == kSparseReference) {
+    FullyConnectedParams op_params;
+    op_params.float_activation_min = output_activation_min;
+    op_params.float_activation_max = output_activation_max;
+    TF_LITE_ENSURE(context, filter->sparsity != nullptr);
+
+    const auto& sparsity = *filter->sparsity;
+    reference_ops::FullyConnectedSparseWeight(
+        sparsity, op_params, GetTensorShape(input), GetTensorData<float>(input),
+        GetTensorShape(filter), GetTensorData<float>(filter),
+        GetTensorShape(bias), GetTensorData<float>(bias),
+        GetTensorShape(output), GetTensorData<float>(output));
+  } else if (kernel_type == kSparseOptimized) {
+    FullyConnectedParams op_params;
+    op_params.float_activation_min = output_activation_min;
+    op_params.float_activation_max = output_activation_max;
+    TF_LITE_ENSURE(context, filter->sparsity != nullptr);
+
+    const auto& sparsity = *filter->sparsity;
+    if (!SupportedSparsityFormat(sparsity)) {
+      context->ReportError(context,
+                           "Unsupported sparse fully-connected weight format.");
+      return kTfLiteError;
+    }
+    optimized_ops::FullyConnectedSparseWeight(
+        sparsity, op_params, GetTensorShape(input), GetTensorData<float>(input),
+        GetTensorShape(filter), GetTensorData<float>(filter),
+        GetTensorShape(bias), GetTensorData<float>(bias),
+        GetTensorShape(output), GetTensorData<float>(output));
   } else if (kernel_type == kLegacyPie) {
     return EvalPie(context, node, params, data, input, filter, bias, output);
   } else {
@@ -653,6 +698,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 
 }  // namespace fully_connected
 
+// TODO(b/147449640): Clean up sparse registrations after conversion is done.
+TfLiteRegistration* Register_FULLY_CONNECTED_SPARSE_REF() {
+  static TfLiteRegistration r = {
+      fully_connected::Init, fully_connected::Free,
+      fully_connected::Prepare<fully_connected::kSparseReference>,
+      fully_connected::Eval<fully_connected::kSparseReference>};
+  return &r;
+}
+
+TfLiteRegistration* Register_FULLY_CONNECTED_SPARSE_OPT() {
+  static TfLiteRegistration r = {
+      fully_connected::Init, fully_connected::Free,
+      fully_connected::Prepare<fully_connected::kSparseOptimized>,
+      fully_connected::Eval<fully_connected::kSparseOptimized>};
+  return &r;
+}
+
 TfLiteRegistration* Register_FULLY_CONNECTED_REF() {
   static TfLiteRegistration r = {
       fully_connected::Init, fully_connected::Free,
tensorflow/lite/kernels/fully_connected.h
@@ -29,6 +29,8 @@ namespace builtin {
 TfLiteRegistration* Register_FULLY_CONNECTED_REF();
 TfLiteRegistration* Register_FULLY_CONNECTED_GENERIC_OPT();
 TfLiteRegistration* Register_FULLY_CONNECTED_PIE();
+TfLiteRegistration* Register_FULLY_CONNECTED_SPARSE_REF();
+TfLiteRegistration* Register_FULLY_CONNECTED_SPARSE_OPT();
 }  // namespace builtin
 }  // namespace ops
 }  // namespace tflite
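For context, a minimal usage sketch of how one of these new registrations could be routed to the builtin FULLY_CONNECTED op. This is not part of the commit; BuildSparseResolver is a hypothetical helper, and MutableOpResolver is the usual TFLite mechanism for overriding a builtin kernel.

// Hypothetical sketch, not part of this change: make the sparse reference
// kernel the handler for the builtin FULLY_CONNECTED op.
#include "tensorflow/lite/kernels/fully_connected.h"
#include "tensorflow/lite/mutable_op_resolver.h"

tflite::MutableOpResolver BuildSparseResolver() {
  tflite::MutableOpResolver resolver;
  resolver.AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
                      tflite::ops::builtin::Register_FULLY_CONNECTED_SPARSE_REF());
  return resolver;
}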
tensorflow/lite/kernels/fully_connected_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
 
 #include "tensorflow/lite/kernels/fully_connected.h"
 
+#include <initializer_list>
 #include <iomanip>
 #include <random>
 #include <vector>
@@ -356,6 +357,11 @@ const auto kKernelMapNoPie = new std::map<string, TfLiteRegistration*>({
     {"GenericOptimized", ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT()},
 });
 
+const auto kKernelMapSparse = new std::map<string, TfLiteRegistration*>({
+    {"SparseReference", ops::builtin::Register_FULLY_CONNECTED_SPARSE_REF()},
+    {"SparseOptimized", ops::builtin::Register_FULLY_CONNECTED_SPARSE_OPT()},
+});
+
 class QuantizedFullyConnectedOpTest : public SingleOpTest {
  protected:
   const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
@@ -1061,5 +1067,113 @@ TEST_P(FloatFullyConnectedOpTest, BlackBoxTest) {
   }
 }
 
+template <typename T>
+class SparseFullyConnectedOpModel : public SingleOpModel {
+ public:
+  SparseFullyConnectedOpModel(TfLiteRegistration* registration, int units,
+                              int batches, const TensorData& input,
+                              std::initializer_list<int> weights_shape,
+                              std::initializer_list<T> weights_data)
+      : batches_(batches), units_(units) {
+    int total_input_size = 1;
+    for (size_t i = 0; i < input.shape.size(); ++i) {
+      total_input_size *= input.shape[i];
+    }
+    input_size_ = total_input_size / batches_;
+
+    input_ = AddInput(input);
+    weights_ = AddConstSparseInput(input.type, weights_shape, weights_data);
+
+    TensorData bias{input.type, {units_}};
+    bias_ = AddInput(bias);
+
+    output_ = AddOutput({input.type});
+
+    SetBuiltinOp(
+        BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions,
+        CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU)
+            .Union());
+    resolver_ = absl::make_unique<SingleOpResolver>(
+        BuiltinOperator_FULLY_CONNECTED, registration);
+    BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)});
+  }
+  void SetBias(const std::vector<T>& data) { PopulateTensor(bias_, data); }
+  void SetInput(const std::vector<T>& data) { PopulateTensor(input_, data); }
+  std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
+  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+  int input_size() { return input_size_; }
+  int num_units() { return units_; }
+  int num_batches() { return batches_; }
+
+ protected:
+  int input_;
+  int weights_;
+  int bias_;
+  int output_;
+
+  int batches_;
+  int units_;
+  int input_size_;
+};
+
+class SparseFullyConnectedOpTest : public SingleOpTest {
+ protected:
+  const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+    return *kKernelMapSparse;
+  }
+};
+
+TEST_P(SparseFullyConnectedOpTest, SimpleTest) {
+  std::initializer_list<int> weight_shape = {3, 10};
+  std::initializer_list<float> weight_data = {
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
+  };
+  SparseFullyConnectedOpModel<float> m(
+      GetRegistration(), /*units=*/3, /*batches=*/2,
+      /*input=*/{TensorType_FLOAT32, {2, 10}}, weight_shape, weight_data);
+  m.SetBias({1, 2, 3});
+
+  m.SetInput({
+      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
+      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
+  });
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 3));
+  EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
+}
+
+TEST_P(SparseFullyConnectedOpTest, SimpleTest2) {
+  std::initializer_list<int> weight_shape = {1, 2};
+  std::initializer_list<float> weight_data = {
+      2, 4  // u = 0
+  };
+  SparseFullyConnectedOpModel<float> m(
+      GetRegistration(), /*units=*/1, /*batches=*/2,
+      /*input=*/{TensorType_FLOAT32, {2, 2}}, weight_shape, weight_data);
+  m.SetBias({1});
+
+  m.SetInput({
+      1, 2,  // b = 0
+      2, 1   // b = 1
+  });
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 1));
+  EXPECT_THAT(m.GetOutput(), ElementsAre(11, 9));
+}
+
+// TODO(b/148391360): Add tests for unsupported sparsity format.
+// TEST_P(SparseFullyConnectedOpTest, TestUnsupportedSparsityFormat)
+
+INSTANTIATE_TEST_SUITE_P(
+    SparseFullyConnectedOpTest, SparseFullyConnectedOpTest,
+    ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMapSparse)));
+
 }  // namespace
 }  // namespace tflite
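As a quick sanity check on the expected values above (all numbers taken from the test itself): in SimpleTest every unit shares the weight row (1, ..., 10), so batch 0's dot product is 1+4+9+16+25+36+49+64-81-100 = 23 and batch 1's is 1+4+9+16+25+36+49-64+81-100 = 57; adding the bias (1, 2, 3) and applying ReLU yields (24, 25, 26) and (58, 59, 60). Likewise in SimpleTest2, (1, 2)·(2, 4)+1 = 11 and (2, 1)·(2, 4)+1 = 9, matching the ElementsAre expectations.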
tensorflow/lite/kernels/internal/BUILD
@@ -232,6 +232,7 @@ cc_library(
         "optimized/integer_ops/softmax.h",
         "optimized/integer_ops/transpose_conv.h",
         "optimized/optimized_ops.h",
+        "optimized/sparse_ops/fully_connected.h",
     ],
     copts = tflite_copts(),
     deps = [
@@ -456,6 +457,7 @@ cc_library(
         "reference/requantize.h",
         "reference/round.h",
         "reference/softmax.h",
+        "reference/sparse_ops/fully_connected.h",
        "reference/strided_slice.h",
        "reference/svdf.h",
     ],
tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h (new file)
@@ -0,0 +1,75 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SPARSE_OPS_FULLY_CONNECTED_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SPARSE_OPS_FULLY_CONNECTED_H_
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/round.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace optimized_ops {
+
+inline void FullyConnectedSparseWeight(
+    const TfLiteSparsity& sparsity, const FullyConnectedParams& params,
+    const RuntimeShape& input_shape, const float* input_data,
+    const RuntimeShape& weights_shape, const float* weights_data,
+    const RuntimeShape& bias_shape, const float* bias_data,
+    const RuntimeShape& output_shape, float* output_data) {
+  const float output_activation_min = params.float_activation_min;
+  const float output_activation_max = params.float_activation_max;
+
+  const int output_elements = output_shape.FlatSize();
+  const int output_dims_count = output_shape.DimensionsCount();
+  const int weights_dims_count = weights_shape.DimensionsCount();
+  const int batches = FlatSizeSkipDim(output_shape, output_dims_count - 1);
+  const int output_depth = MatchingDim(weights_shape, weights_dims_count - 2,
+                                       output_shape, output_dims_count - 1);
+  const int accum_depth = weights_shape.Dims(weights_dims_count - 1);
+  const int* w0_segments = sparsity.dim_metadata[0].array_segments->data;
+  const int* w0_indices = sparsity.dim_metadata[0].array_indices->data;
+  const int* w1_segments = sparsity.dim_metadata[1].array_segments->data;
+  const int* w1_indices = sparsity.dim_metadata[1].array_indices->data;
+
+  for (int i = 0; i < output_elements; ++i) {
+    output_data[i] = 0.f;
+  }
+
+  for (int b = 0; b < batches; ++b) {
+    for (int pw0 = w0_segments[0]; pw0 < w0_segments[1]; ++pw0) {
+      int idx_0 = w0_indices[pw0];
+      for (int pw1 = w1_segments[pw0]; pw1 < w1_segments[pw0 + 1]; ++pw1) {
+        int idx_1 = w1_indices[pw1];
+        output_data[b * output_depth + idx_0] +=
+            weights_data[pw1] * input_data[b * accum_depth + idx_1];
+      }
+    }
+  }
+
+  for (int b = 0; b < batches; ++b) {
+    for (int i = 0; i < output_depth; ++i) {
+      float total = output_data[b * output_depth + i];
+      float bias_value = bias_data[i];
+      output_data[b * output_depth + i] = ActivationFunctionWithMinMax(
+          total + bias_value, output_activation_min, output_activation_max);
+    }
+  }
+}
+
+}  // namespace optimized_ops
+}  // namespace tflite
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SPARSE_OPS_FULLY_CONNECTED_H_
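To make the traversal above concrete, here is a small self-contained sketch of how the two levels of segments/indices metadata encode a 3x4 weight matrix and how the same double loop recovers W·x for one batch. The numbers are hypothetical and not taken from the commit; bias and activation clamping are omitted.

// Hypothetical illustration of the CSR-style metadata walked by
// optimized_ops::FullyConnectedSparseWeight.
// Dense weights (output_depth = 3, accum_depth = 4):
//   [[0, 2, 0, 4],
//    [0, 0, 0, 0],
//    [7, 0, 0, 0]]
#include <cstdio>

int main() {
  const int w0_segments[] = {0, 2};     // dim 0: one segment holding 2 stored rows
  const int w0_indices[] = {0, 2};      // rows 0 and 2 carry non-zero values
  const int w1_segments[] = {0, 2, 3};  // dim 1: row-pointer over the stored rows
  const int w1_indices[] = {1, 3, 0};   // column index of each stored value
  const float weights_data[] = {2, 4, 7};
  const float input_data[] = {1, 1, 1, 1};  // one batch, accum_depth = 4
  float output_data[3] = {0, 0, 0};         // output_depth = 3

  // Same loop structure as the kernel above, for batch b = 0.
  for (int pw0 = w0_segments[0]; pw0 < w0_segments[1]; ++pw0) {
    const int row = w0_indices[pw0];
    for (int pw1 = w1_segments[pw0]; pw1 < w1_segments[pw0 + 1]; ++pw1) {
      output_data[row] += weights_data[pw1] * input_data[w1_indices[pw1]];
    }
  }
  std::printf("%g %g %g\n", output_data[0], output_data[1], output_data[2]);
  // Prints: 6 0 7
  return 0;
}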
tensorflow/lite/kernels/internal/reference/sparse_ops/fully_connected.h (new file)
@@ -0,0 +1,46 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPARSE_OPS_FULLY_CONNECTED_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPARSE_OPS_FULLY_CONNECTED_H_
+
+#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
+#include "tensorflow/lite/tools/optimize/sparsity/format_converter.h"
+
+namespace tflite {
+namespace reference_ops {
+
+// Convert weights to dense format and run dense fully connected.
+inline void FullyConnectedSparseWeight(
+    const TfLiteSparsity& sparsity, const FullyConnectedParams& params,
+    const RuntimeShape& input_shape, const float* input_data,
+    const RuntimeShape& weights_shape, const float* weights_data,
+    const RuntimeShape& bias_shape, const float* bias_data,
+    const RuntimeShape& output_shape, float* output_data) {
+  std::vector<int> weights_shape_vector(weights_shape.DimensionsCount());
+  for (int i = 0; i < weights_shape.DimensionsCount(); i++) {
+    weights_shape_vector[i] = weights_shape.Dims(i);
+  }
+  tflite::optimize::sparsity::FormatConverter<float> converter(
+      weights_shape_vector, sparsity);
+  converter.SparseToDense(weights_data);
+  const std::vector<float> dense_weights_data = converter.GetData();
+  FullyConnected(params, input_shape, input_data, weights_shape,
+                 dense_weights_data.data(), bias_shape, bias_data, output_shape,
+                 output_data);
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPARSE_OPS_FULLY_CONNECTED_H_