Add one example of sparse FullyConnected kernel.

PiperOrigin-RevId: 293450452
Change-Id: I4ad38bc8f871291a6db11b80b44660f5ae9e6421
This commit is contained in:
Yunlu Li 2020-02-05 14:21:41 -08:00 committed by TensorFlower Gardener
parent 8f0dfa951d
commit 8a9c9b6af7
6 changed files with 301 additions and 0 deletions

View File

@ -23,10 +23,12 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/reference/sparse_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
@ -38,11 +40,24 @@ namespace ops {
namespace builtin {
namespace fully_connected {
namespace {
bool SupportedSparsityFormat(const TfLiteSparsity& sparsity) {
if (sparsity.dim_metadata[0].format == kTfLiteDimSparseCSR &&
sparsity.dim_metadata[1].format == kTfLiteDimSparseCSR) {
return true;
}
return false;
}
} // namespace
// This file has several implementations of FullyConnected; the variant is
// selected at registration time via the KernelType template argument.
enum KernelType {
  kReference,        // Portable reference implementation.
  kGenericOptimized, // Optimized dense implementation (optimized_ops).
  kLegacyPie,  // Legacy path used by the PIE team and related clients.
  kSparseReference,  // Reference implementation for sparse weights.
  kSparseOptimized,  // Optimized implementation for sparse (CSR) weights.
};
struct OpData {
@ -574,11 +589,41 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
FullyConnectedParams op_params;
op_params.float_activation_min = output_activation_min;
op_params.float_activation_max = output_activation_max;
reference_ops::FullyConnected(
op_params, GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(filter), GetTensorData<float>(filter),
GetTensorShape(bias), GetTensorData<float>(bias),
GetTensorShape(output), GetTensorData<float>(output));
} else if (kernel_type == kSparseReference) {
FullyConnectedParams op_params;
op_params.float_activation_min = output_activation_min;
op_params.float_activation_max = output_activation_max;
TF_LITE_ENSURE(context, filter->sparsity != nullptr);
const auto& sparsity = *filter->sparsity;
reference_ops::FullyConnectedSparseWeight(
sparsity, op_params, GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(filter), GetTensorData<float>(filter),
GetTensorShape(bias), GetTensorData<float>(bias),
GetTensorShape(output), GetTensorData<float>(output));
} else if (kernel_type == kSparseOptimized) {
FullyConnectedParams op_params;
op_params.float_activation_min = output_activation_min;
op_params.float_activation_max = output_activation_max;
TF_LITE_ENSURE(context, filter->sparsity != nullptr);
const auto& sparsity = *filter->sparsity;
if (!SupportedSparsityFormat(sparsity)) {
context->ReportError(context,
"Unsupported sparse fully-connected weight format.");
return kTfLiteError;
}
optimized_ops::FullyConnectedSparseWeight(
sparsity, op_params, GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(filter), GetTensorData<float>(filter),
GetTensorShape(bias), GetTensorData<float>(bias),
GetTensorShape(output), GetTensorData<float>(output));
} else if (kernel_type == kLegacyPie) {
return EvalPie(context, node, params, data, input, filter, bias, output);
} else {
@ -653,6 +698,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace fully_connected
// TODO(b/147449640): Clean up sparse registrations after conversion is done.
// Registers the reference kernel variant that accepts sparse weights.
TfLiteRegistration* Register_FULLY_CONNECTED_SPARSE_REF() {
  static TfLiteRegistration registration = {
      fully_connected::Init, fully_connected::Free,
      fully_connected::Prepare<fully_connected::kSparseReference>,
      fully_connected::Eval<fully_connected::kSparseReference>};
  return &registration;
}
// Registers the optimized kernel variant that accepts sparse (CSR) weights.
TfLiteRegistration* Register_FULLY_CONNECTED_SPARSE_OPT() {
  static TfLiteRegistration registration = {
      fully_connected::Init, fully_connected::Free,
      fully_connected::Prepare<fully_connected::kSparseOptimized>,
      fully_connected::Eval<fully_connected::kSparseOptimized>};
  return &registration;
}
TfLiteRegistration* Register_FULLY_CONNECTED_REF() {
static TfLiteRegistration r = {
fully_connected::Init, fully_connected::Free,

View File

@ -29,6 +29,8 @@ namespace builtin {
TfLiteRegistration* Register_FULLY_CONNECTED_REF();
TfLiteRegistration* Register_FULLY_CONNECTED_GENERIC_OPT();
TfLiteRegistration* Register_FULLY_CONNECTED_PIE();
// Sparse-weight kernel variants; see the fully_connected kernel source for
// the weight formats each one supports.
TfLiteRegistration* Register_FULLY_CONNECTED_SPARSE_REF();
TfLiteRegistration* Register_FULLY_CONNECTED_SPARSE_OPT();
} // namespace builtin
} // namespace ops
} // namespace tflite

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/fully_connected.h"
#include <initializer_list>
#include <iomanip>
#include <random>
#include <vector>
@ -356,6 +357,11 @@ const auto kKernelMapNoPie = new std::map<string, TfLiteRegistration*>({
{"GenericOptimized", ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT()},
});
// Kernel variants exercised by the sparse fully-connected tests below; the
// parameterized suite runs once per entry.
const auto kKernelMapSparse = new std::map<string, TfLiteRegistration*>({
    {"SparseReference", ops::builtin::Register_FULLY_CONNECTED_SPARSE_REF()},
    {"SparseOptimized", ops::builtin::Register_FULLY_CONNECTED_SPARSE_OPT()},
});
class QuantizedFullyConnectedOpTest : public SingleOpTest {
protected:
const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
@ -1061,5 +1067,113 @@ TEST_P(FloatFullyConnectedOpTest, BlackBoxTest) {
}
}
// Test harness for FULLY_CONNECTED with a constant sparse weight tensor.
// The weights are added as a const sparse input; input and bias tensors are
// dense and populated by the individual tests.
template <typename T>
class SparseFullyConnectedOpModel : public SingleOpModel {
 public:
  SparseFullyConnectedOpModel(TfLiteRegistration* registration, int units,
                              int batches, const TensorData& input,
                              std::initializer_list<int> weights_shape,
                              std::initializer_list<T> weights_data)
      : batches_(batches), units_(units) {
    // Derive the per-batch input size from the flat input element count.
    int flat_input_size = 1;
    for (int dim : input.shape) {
      flat_input_size *= dim;
    }
    input_size_ = flat_input_size / batches_;

    input_ = AddInput(input);
    weights_ = AddConstSparseInput(input.type, weights_shape, weights_data);

    TensorData bias{input.type, {units_}};
    bias_ = AddInput(bias);

    output_ = AddOutput({input.type});

    SetBuiltinOp(
        BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions,
        CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU)
            .Union());
    resolver_ = absl::make_unique<SingleOpResolver>(
        BuiltinOperator_FULLY_CONNECTED, registration);
    BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)});
  }

  void SetBias(const std::vector<T>& data) { PopulateTensor(bias_, data); }
  void SetInput(const std::vector<T>& data) { PopulateTensor(input_, data); }

  std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

  int input_size() { return input_size_; }
  int num_units() { return units_; }
  int num_batches() { return batches_; }

 protected:
  int input_;
  int weights_;
  int bias_;
  int output_;

  int batches_;
  int units_;
  int input_size_;
};
// Parameterized fixture that runs each test against every sparse kernel
// variant registered in kKernelMapSparse.
class SparseFullyConnectedOpTest : public SingleOpTest {
 protected:
  const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
    return *kKernelMapSparse;
  }
};
// Runs a 2x10 input through a (fully populated) 3x10 sparse weight matrix
// with a per-unit bias and RELU activation; expected values computed by hand.
TEST_P(SparseFullyConnectedOpTest, SimpleTest) {
  std::initializer_list<int> kWeightShape = {3, 10};
  std::initializer_list<float> kWeightData = {
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  };
  SparseFullyConnectedOpModel<float> model(
      GetRegistration(), /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 10}}, kWeightShape, kWeightData);
  model.SetBias({1, 2, 3});
  model.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
  });
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(model.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
}
// Minimal case: a single output unit over a 2-element input, two batches.
TEST_P(SparseFullyConnectedOpTest, SimpleTest2) {
  std::initializer_list<int> kWeightShape = {1, 2};
  std::initializer_list<float> kWeightData = {
      2, 4  // u = 0
  };
  SparseFullyConnectedOpModel<float> model(
      GetRegistration(), /*units=*/1, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 2}}, kWeightShape, kWeightData);
  model.SetBias({1});
  model.SetInput({
      1, 2,  // b = 0
      2, 1   // b = 1
  });
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 1));
  EXPECT_THAT(model.GetOutput(), ElementsAre(11, 9));
}
// TODO(b/148391360): Add tests for unsupported sparsity format.
// TEST_P(SparseFullyConnectedOpTest, TestUnsupportedSparsityFormat)

// Instantiate the parameterized suite once per kernel tag registered in
// kKernelMapSparse (SparseReference and SparseOptimized).
INSTANTIATE_TEST_SUITE_P(
    SparseFullyConnectedOpTest, SparseFullyConnectedOpTest,
    ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMapSparse)));
} // namespace
} // namespace tflite

View File

@ -232,6 +232,7 @@ cc_library(
"optimized/integer_ops/softmax.h",
"optimized/integer_ops/transpose_conv.h",
"optimized/optimized_ops.h",
"optimized/sparse_ops/fully_connected.h",
],
copts = tflite_copts(),
deps = [
@ -456,6 +457,7 @@ cc_library(
"reference/requantize.h",
"reference/round.h",
"reference/softmax.h",
"reference/sparse_ops/fully_connected.h",
"reference/strided_slice.h",
"reference/svdf.h",
],

View File

@ -0,0 +1,75 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SPARSE_OPS_FULLY_CONNECTED_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SPARSE_OPS_FULLY_CONNECTED_H_
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/round.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace optimized_ops {
// Float fully-connected layer whose weight matrix is stored sparsely, with
// both dimensions encoded as CSR-style segment/index arrays (dim 0 over
// output rows, dim 1 over input columns). The caller is responsible for
// validating the format (see SupportedSparsityFormat in the kernel).
//
// `bias_data` may be null (bias is optional for FULLY_CONNECTED); a null
// bias is treated as all zeros. Output is written in-place to `output_data`
// with the fused activation from `params` applied.
inline void FullyConnectedSparseWeight(
    const TfLiteSparsity& sparsity, const FullyConnectedParams& params,
    const RuntimeShape& input_shape, const float* input_data,
    const RuntimeShape& weights_shape, const float* weights_data,
    const RuntimeShape& bias_shape, const float* bias_data,
    const RuntimeShape& output_shape, float* output_data) {
  const float output_activation_min = params.float_activation_min;
  const float output_activation_max = params.float_activation_max;

  const int output_elements = output_shape.FlatSize();
  const int output_dims_count = output_shape.DimensionsCount();
  const int weights_dims_count = weights_shape.DimensionsCount();
  const int batches = FlatSizeSkipDim(output_shape, output_dims_count - 1);
  const int output_depth = MatchingDim(weights_shape, weights_dims_count - 2,
                                       output_shape, output_dims_count - 1);
  const int accum_depth = weights_shape.Dims(weights_dims_count - 1);

  // CSR metadata for the two weight dimensions.
  const int* w0_segments = sparsity.dim_metadata[0].array_segments->data;
  const int* w0_indices = sparsity.dim_metadata[0].array_indices->data;
  const int* w1_segments = sparsity.dim_metadata[1].array_segments->data;
  const int* w1_indices = sparsity.dim_metadata[1].array_indices->data;

  // The accumulation loop below only touches cells with stored weights, so
  // the whole output must be cleared first.
  for (int i = 0; i < output_elements; ++i) {
    output_data[i] = 0.f;
  }

  // Accumulate each stored (row idx_0, col idx_1) weight into the matching
  // output cell for every batch.
  for (int b = 0; b < batches; ++b) {
    for (int pw0 = w0_segments[0]; pw0 < w0_segments[1]; ++pw0) {
      int idx_0 = w0_indices[pw0];
      for (int pw1 = w1_segments[pw0]; pw1 < w1_segments[pw0 + 1]; ++pw1) {
        int idx_1 = w1_indices[pw1];
        output_data[b * output_depth + idx_0] +=
            weights_data[pw1] * input_data[b * accum_depth + idx_1];
      }
    }
  }

  // Second pass: add the (optional) bias and apply the fused activation.
  for (int b = 0; b < batches; ++b) {
    for (int i = 0; i < output_depth; ++i) {
      float total = output_data[b * output_depth + i];
      // Bias is optional; match the dense reference kernel by treating a
      // null bias pointer as zeros instead of dereferencing it.
      float bias_value = bias_data != nullptr ? bias_data[i] : 0.0f;
      output_data[b * output_depth + i] = ActivationFunctionWithMinMax(
          total + bias_value, output_activation_min, output_activation_max);
    }
  }
}
} // namespace optimized_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SPARSE_OPS_FULLY_CONNECTED_H_

View File

@ -0,0 +1,46 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPARSE_OPS_FULLY_CONNECTED_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPARSE_OPS_FULLY_CONNECTED_H_
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
#include "tensorflow/lite/tools/optimize/sparsity/format_converter.h"
namespace tflite {
namespace reference_ops {
// Convert weights to dense format and run dense fully connected.
// Reference path: densify the sparse weights with FormatConverter, then
// delegate to the dense reference FullyConnected kernel. Intentionally
// simple rather than fast; it re-densifies on every call.
inline void FullyConnectedSparseWeight(
    const TfLiteSparsity& sparsity, const FullyConnectedParams& params,
    const RuntimeShape& input_shape, const float* input_data,
    const RuntimeShape& weights_shape, const float* weights_data,
    const RuntimeShape& bias_shape, const float* bias_data,
    const RuntimeShape& output_shape, float* output_data) {
  const int weights_dims_count = weights_shape.DimensionsCount();
  std::vector<int> dense_shape(weights_dims_count);
  for (int d = 0; d < weights_dims_count; ++d) {
    dense_shape[d] = weights_shape.Dims(d);
  }

  tflite::optimize::sparsity::FormatConverter<float> converter(dense_shape,
                                                               sparsity);
  converter.SparseToDense(weights_data);
  const auto& dense_weights = converter.GetData();

  FullyConnected(params, input_shape, input_data, weights_shape,
                 dense_weights.data(), bias_shape, bias_data, output_shape,
                 output_data);
}
} // namespace reference_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPARSE_OPS_FULLY_CONNECTED_H_