Add an example of a sparse FullyConnected kernel.
PiperOrigin-RevId: 293450452 Change-Id: I4ad38bc8f871291a6db11b80b44660f5ae9e6421
This commit is contained in:
parent
8f0dfa951d
commit
8a9c9b6af7
@ -23,10 +23,12 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
||||
#include "tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/sparse_ops/fully_connected.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
|
||||
@ -38,11 +40,24 @@ namespace ops {
|
||||
namespace builtin {
|
||||
namespace fully_connected {
|
||||
|
||||
namespace {
|
||||
bool SupportedSparsityFormat(const TfLiteSparsity& sparsity) {
|
||||
if (sparsity.dim_metadata[0].format == kTfLiteDimSparseCSR &&
|
||||
sparsity.dim_metadata[1].format == kTfLiteDimSparseCSR) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// This file has several implementations of FullyConnected, selected at
// registration time via the KernelType template parameter. (The previous
// comment said "four", but the enum below lists five kernels.)
enum KernelType {
  kReference,
  kGenericOptimized,
  kLegacyPie,  // Legacy path used by the PIE team and related clients.
  kSparseReference,
  kSparseOptimized,
};
|
||||
|
||||
struct OpData {
|
||||
@ -574,11 +589,41 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
|
||||
FullyConnectedParams op_params;
|
||||
op_params.float_activation_min = output_activation_min;
|
||||
op_params.float_activation_max = output_activation_max;
|
||||
|
||||
reference_ops::FullyConnected(
|
||||
op_params, GetTensorShape(input), GetTensorData<float>(input),
|
||||
GetTensorShape(filter), GetTensorData<float>(filter),
|
||||
GetTensorShape(bias), GetTensorData<float>(bias),
|
||||
GetTensorShape(output), GetTensorData<float>(output));
|
||||
} else if (kernel_type == kSparseReference) {
|
||||
FullyConnectedParams op_params;
|
||||
op_params.float_activation_min = output_activation_min;
|
||||
op_params.float_activation_max = output_activation_max;
|
||||
TF_LITE_ENSURE(context, filter->sparsity != nullptr);
|
||||
|
||||
const auto& sparsity = *filter->sparsity;
|
||||
reference_ops::FullyConnectedSparseWeight(
|
||||
sparsity, op_params, GetTensorShape(input), GetTensorData<float>(input),
|
||||
GetTensorShape(filter), GetTensorData<float>(filter),
|
||||
GetTensorShape(bias), GetTensorData<float>(bias),
|
||||
GetTensorShape(output), GetTensorData<float>(output));
|
||||
} else if (kernel_type == kSparseOptimized) {
|
||||
FullyConnectedParams op_params;
|
||||
op_params.float_activation_min = output_activation_min;
|
||||
op_params.float_activation_max = output_activation_max;
|
||||
TF_LITE_ENSURE(context, filter->sparsity != nullptr);
|
||||
|
||||
const auto& sparsity = *filter->sparsity;
|
||||
if (!SupportedSparsityFormat(sparsity)) {
|
||||
context->ReportError(context,
|
||||
"Unsupported sparse fully-connected weight format.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
optimized_ops::FullyConnectedSparseWeight(
|
||||
sparsity, op_params, GetTensorShape(input), GetTensorData<float>(input),
|
||||
GetTensorShape(filter), GetTensorData<float>(filter),
|
||||
GetTensorShape(bias), GetTensorData<float>(bias),
|
||||
GetTensorShape(output), GetTensorData<float>(output));
|
||||
} else if (kernel_type == kLegacyPie) {
|
||||
return EvalPie(context, node, params, data, input, filter, bias, output);
|
||||
} else {
|
||||
@ -653,6 +698,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
} // namespace fully_connected
|
||||
|
||||
// TODO(b/147449640): Clean up sparse registrations after conversion is done.
|
||||
// Registration for the sparse-weight reference FullyConnected kernel.
TfLiteRegistration* Register_FULLY_CONNECTED_SPARSE_REF() {
  static TfLiteRegistration r = {
      /*init=*/fully_connected::Init,
      /*free=*/fully_connected::Free,
      /*prepare=*/fully_connected::Prepare<fully_connected::kSparseReference>,
      /*invoke=*/fully_connected::Eval<fully_connected::kSparseReference>};
  return &r;
}
|
||||
|
||||
// Registration for the sparse-weight optimized FullyConnected kernel.
TfLiteRegistration* Register_FULLY_CONNECTED_SPARSE_OPT() {
  static TfLiteRegistration r = {
      /*init=*/fully_connected::Init,
      /*free=*/fully_connected::Free,
      /*prepare=*/fully_connected::Prepare<fully_connected::kSparseOptimized>,
      /*invoke=*/fully_connected::Eval<fully_connected::kSparseOptimized>};
  return &r;
}
|
||||
|
||||
TfLiteRegistration* Register_FULLY_CONNECTED_REF() {
|
||||
static TfLiteRegistration r = {
|
||||
fully_connected::Init, fully_connected::Free,
|
||||
|
@ -29,6 +29,8 @@ namespace builtin {
|
||||
TfLiteRegistration* Register_FULLY_CONNECTED_REF();
|
||||
TfLiteRegistration* Register_FULLY_CONNECTED_GENERIC_OPT();
|
||||
TfLiteRegistration* Register_FULLY_CONNECTED_PIE();
|
||||
TfLiteRegistration* Register_FULLY_CONNECTED_SPARSE_REF();
|
||||
TfLiteRegistration* Register_FULLY_CONNECTED_SPARSE_OPT();
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/kernels/fully_connected.h"
|
||||
|
||||
#include <initializer_list>
|
||||
#include <iomanip>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
@ -356,6 +357,11 @@ const auto kKernelMapNoPie = new std::map<string, TfLiteRegistration*>({
|
||||
{"GenericOptimized", ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT()},
|
||||
});
|
||||
|
||||
const auto kKernelMapSparse = new std::map<string, TfLiteRegistration*>({
|
||||
{"SparseReference", ops::builtin::Register_FULLY_CONNECTED_SPARSE_REF()},
|
||||
{"SparseOptimized", ops::builtin::Register_FULLY_CONNECTED_SPARSE_OPT()},
|
||||
});
|
||||
|
||||
class QuantizedFullyConnectedOpTest : public SingleOpTest {
|
||||
protected:
|
||||
const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
|
||||
@ -1061,5 +1067,113 @@ TEST_P(FloatFullyConnectedOpTest, BlackBoxTest) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class SparseFullyConnectedOpModel : public SingleOpModel {
|
||||
public:
|
||||
SparseFullyConnectedOpModel(TfLiteRegistration* registration, int units,
|
||||
int batches, const TensorData& input,
|
||||
std::initializer_list<int> weights_shape,
|
||||
std::initializer_list<T> weights_data)
|
||||
: batches_(batches), units_(units) {
|
||||
int total_input_size = 1;
|
||||
for (size_t i = 0; i < input.shape.size(); ++i) {
|
||||
total_input_size *= input.shape[i];
|
||||
}
|
||||
input_size_ = total_input_size / batches_;
|
||||
|
||||
input_ = AddInput(input);
|
||||
weights_ = AddConstSparseInput(input.type, weights_shape, weights_data);
|
||||
|
||||
TensorData bias{input.type, {units_}};
|
||||
bias_ = AddInput(bias);
|
||||
|
||||
output_ = AddOutput({input.type});
|
||||
|
||||
SetBuiltinOp(
|
||||
BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions,
|
||||
CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU)
|
||||
.Union());
|
||||
resolver_ = absl::make_unique<SingleOpResolver>(
|
||||
BuiltinOperator_FULLY_CONNECTED, registration);
|
||||
BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)});
|
||||
}
|
||||
void SetBias(const std::vector<T>& data) { PopulateTensor(bias_, data); }
|
||||
void SetInput(const std::vector<T>& data) { PopulateTensor(input_, data); }
|
||||
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
|
||||
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||
|
||||
int input_size() { return input_size_; }
|
||||
int num_units() { return units_; }
|
||||
int num_batches() { return batches_; }
|
||||
|
||||
protected:
|
||||
int input_;
|
||||
int weights_;
|
||||
int bias_;
|
||||
int output_;
|
||||
|
||||
int batches_;
|
||||
int units_;
|
||||
int input_size_;
|
||||
};
|
||||
|
||||
class SparseFullyConnectedOpTest : public SingleOpTest {
|
||||
protected:
|
||||
const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
|
||||
return *kKernelMapSparse;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(SparseFullyConnectedOpTest, SimpleTest) {
  // Three output units, each with the identical weight row 1..10.
  std::initializer_list<int> kWeightShape = {3, 10};
  std::initializer_list<float> kWeightValues = {
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  };
  SparseFullyConnectedOpModel<float> m(GetRegistration(), /*units=*/3,
                                       /*batches=*/2,
                                       /*input=*/{TensorType_FLOAT32, {2, 10}},
                                       kWeightShape, kWeightValues);
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
  });

  m.Invoke();

  // dot(row, b0) = 23 and dot(row, b1) = 57; adding bias {1, 2, 3} gives
  // {24, 25, 26} and {58, 59, 60}.
  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
}
|
||||
|
||||
TEST_P(SparseFullyConnectedOpTest, SimpleTest2) {
  // Single output unit with weights {2, 4}.
  std::initializer_list<int> kWeightShape = {1, 2};
  std::initializer_list<float> kWeightValues = {
      2, 4  // u = 0
  };
  SparseFullyConnectedOpModel<float> m(GetRegistration(), /*units=*/1,
                                       /*batches=*/2,
                                       /*input=*/{TensorType_FLOAT32, {2, 2}},
                                       kWeightShape, kWeightValues);
  m.SetBias({1});

  m.SetInput({
      1, 2,  // b = 0
      2, 1   // b = 1
  });

  m.Invoke();

  // b0: 1*2 + 2*4 + 1 = 11;  b1: 2*2 + 1*4 + 1 = 9.
  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 1));
  EXPECT_THAT(m.GetOutput(), ElementsAre(11, 9));
}
|
||||
|
||||
// TODO(b/148391360): Add tests for unsupported sparsity format.
|
||||
// TEST_P(SparseFullyConnectedOpTest, TestUnsupportedSparsityFormat)
|
||||
|
||||
// Instantiate the fixture once per sparse kernel tag ("SparseReference",
// "SparseOptimized").
INSTANTIATE_TEST_SUITE_P(
    SparseFullyConnectedOpTest, SparseFullyConnectedOpTest,
    ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMapSparse)));
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
@ -232,6 +232,7 @@ cc_library(
|
||||
"optimized/integer_ops/softmax.h",
|
||||
"optimized/integer_ops/transpose_conv.h",
|
||||
"optimized/optimized_ops.h",
|
||||
"optimized/sparse_ops/fully_connected.h",
|
||||
],
|
||||
copts = tflite_copts(),
|
||||
deps = [
|
||||
@ -456,6 +457,7 @@ cc_library(
|
||||
"reference/requantize.h",
|
||||
"reference/round.h",
|
||||
"reference/softmax.h",
|
||||
"reference/sparse_ops/fully_connected.h",
|
||||
"reference/strided_slice.h",
|
||||
"reference/svdf.h",
|
||||
],
|
||||
|
@ -0,0 +1,75 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SPARSE_OPS_FULLY_CONNECTED_H_
|
||||
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SPARSE_OPS_FULLY_CONNECTED_H_
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/round.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace optimized_ops {
|
||||
|
||||
inline void FullyConnectedSparseWeight(
|
||||
const TfLiteSparsity& sparsity, const FullyConnectedParams& params,
|
||||
const RuntimeShape& input_shape, const float* input_data,
|
||||
const RuntimeShape& weights_shape, const float* weights_data,
|
||||
const RuntimeShape& bias_shape, const float* bias_data,
|
||||
const RuntimeShape& output_shape, float* output_data) {
|
||||
const float output_activation_min = params.float_activation_min;
|
||||
const float output_activation_max = params.float_activation_max;
|
||||
|
||||
const int output_elements = output_shape.FlatSize();
|
||||
const int output_dims_count = output_shape.DimensionsCount();
|
||||
const int weights_dims_count = weights_shape.DimensionsCount();
|
||||
const int batches = FlatSizeSkipDim(output_shape, output_dims_count - 1);
|
||||
const int output_depth = MatchingDim(weights_shape, weights_dims_count - 2,
|
||||
output_shape, output_dims_count - 1);
|
||||
const int accum_depth = weights_shape.Dims(weights_dims_count - 1);
|
||||
const int* w0_segments = sparsity.dim_metadata[0].array_segments->data;
|
||||
const int* w0_indices = sparsity.dim_metadata[0].array_indices->data;
|
||||
const int* w1_segments = sparsity.dim_metadata[1].array_segments->data;
|
||||
const int* w1_indices = sparsity.dim_metadata[1].array_indices->data;
|
||||
|
||||
for (int i = 0; i < output_elements; ++i) {
|
||||
output_data[i] = 0.f;
|
||||
}
|
||||
|
||||
for (int b = 0; b < batches; ++b) {
|
||||
for (int pw0 = w0_segments[0]; pw0 < w0_segments[1]; ++pw0) {
|
||||
int idx_0 = w0_indices[pw0];
|
||||
for (int pw1 = w1_segments[pw0]; pw1 < w1_segments[pw0 + 1]; ++pw1) {
|
||||
int idx_1 = w1_indices[pw1];
|
||||
output_data[b * output_depth + idx_0] +=
|
||||
weights_data[pw1] * input_data[b * accum_depth + idx_1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int b = 0; b < batches; ++b) {
|
||||
for (int i = 0; i < output_depth; ++i) {
|
||||
float total = output_data[b * output_depth + i];
|
||||
float bias_value = bias_data[i];
|
||||
output_data[b * output_depth + i] = ActivationFunctionWithMinMax(
|
||||
total + bias_value, output_activation_min, output_activation_max);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace optimized_ops
|
||||
} // namespace tflite
|
||||
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SPARSE_OPS_FULLY_CONNECTED_H_
|
@ -0,0 +1,46 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPARSE_OPS_FULLY_CONNECTED_H_
|
||||
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPARSE_OPS_FULLY_CONNECTED_H_
|
||||
|
||||
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
|
||||
#include "tensorflow/lite/tools/optimize/sparsity/format_converter.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace reference_ops {
|
||||
|
||||
// Convert weights to dense format and run dense fully connected.
|
||||
inline void FullyConnectedSparseWeight(
|
||||
const TfLiteSparsity& sparsity, const FullyConnectedParams& params,
|
||||
const RuntimeShape& input_shape, const float* input_data,
|
||||
const RuntimeShape& weights_shape, const float* weights_data,
|
||||
const RuntimeShape& bias_shape, const float* bias_data,
|
||||
const RuntimeShape& output_shape, float* output_data) {
|
||||
std::vector<int> weights_shape_vector(weights_shape.DimensionsCount());
|
||||
for (int i = 0; i < weights_shape.DimensionsCount(); i++) {
|
||||
weights_shape_vector[i] = weights_shape.Dims(i);
|
||||
}
|
||||
tflite::optimize::sparsity::FormatConverter<float> converter(
|
||||
weights_shape_vector, sparsity);
|
||||
converter.SparseToDense(weights_data);
|
||||
const std::vector<float> dense_weights_data = converter.GetData();
|
||||
FullyConnected(params, input_shape, input_data, weights_shape,
|
||||
dense_weights_data.data(), bias_shape, bias_data, output_shape,
|
||||
output_data);
|
||||
}
|
||||
|
||||
} // namespace reference_ops
|
||||
} // namespace tflite
|
||||
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPARSE_OPS_FULLY_CONNECTED_H_
|
Loading…
Reference in New Issue
Block a user