diff --git a/tensorflow/lite/kernels/fully_connected.cc b/tensorflow/lite/kernels/fully_connected.cc index 36dab796a28..fe39971f303 100644 --- a/tensorflow/lite/kernels/fully_connected.cc +++ b/tensorflow/lite/kernels/fully_connected.cc @@ -23,10 +23,12 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/reference/fully_connected.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h" #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/reference/sparse_ops/fully_connected.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" @@ -38,11 +40,24 @@ namespace ops { namespace builtin { namespace fully_connected { +namespace { +bool SupportedSparsityFormat(const TfLiteSparsity& sparsity) { + if (sparsity.dim_metadata[0].format == kTfLiteDimSparseCSR && + sparsity.dim_metadata[1].format == kTfLiteDimSparseCSR) { + return true; + } + + return false; +} +} // namespace + // This file has four implementations of FullyConnected enum KernelType { kReference, kGenericOptimized, kLegacyPie, // Legacy path used by the PIE team and related clients. 
+  kSparseReference,
+  kSparseOptimized,
 };
 
 struct OpData {
@@ -574,11 +589,41 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
     FullyConnectedParams op_params;
     op_params.float_activation_min = output_activation_min;
     op_params.float_activation_max = output_activation_max;
+
     reference_ops::FullyConnected(
         op_params, GetTensorShape(input), GetTensorData<float>(input),
         GetTensorShape(filter), GetTensorData<float>(filter),
         GetTensorShape(bias), GetTensorData<float>(bias),
         GetTensorShape(output), GetTensorData<float>(output));
+  } else if (kernel_type == kSparseReference) {
+    FullyConnectedParams op_params;
+    op_params.float_activation_min = output_activation_min;
+    op_params.float_activation_max = output_activation_max;
+    TF_LITE_ENSURE(context, filter->sparsity != nullptr);
+
+    const auto& sparsity = *filter->sparsity;
+    reference_ops::FullyConnectedSparseWeight(
+        sparsity, op_params, GetTensorShape(input), GetTensorData<float>(input),
+        GetTensorShape(filter), GetTensorData<float>(filter),
+        GetTensorShape(bias), GetTensorData<float>(bias),
+        GetTensorShape(output), GetTensorData<float>(output));
+  } else if (kernel_type == kSparseOptimized) {
+    FullyConnectedParams op_params;
+    op_params.float_activation_min = output_activation_min;
+    op_params.float_activation_max = output_activation_max;
+    TF_LITE_ENSURE(context, filter->sparsity != nullptr);
+
+    const auto& sparsity = *filter->sparsity;
+    if (!SupportedSparsityFormat(sparsity)) {
+      context->ReportError(context,
+                           "Unsupported sparse fully-connected weight format.");
+      return kTfLiteError;
+    }
+    optimized_ops::FullyConnectedSparseWeight(
+        sparsity, op_params, GetTensorShape(input), GetTensorData<float>(input),
+        GetTensorShape(filter), GetTensorData<float>(filter),
+        GetTensorShape(bias), GetTensorData<float>(bias),
+        GetTensorShape(output), GetTensorData<float>(output));
   } else if (kernel_type == kLegacyPie) {
     return EvalPie(context, node, params, data, input, filter, bias, output);
   } else {
@@ -653,6 +698,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }
// namespace fully_connected +// TODO(b/147449640): Clean up sparse registrations after conversion is done. +TfLiteRegistration* Register_FULLY_CONNECTED_SPARSE_REF() { + static TfLiteRegistration r = { + fully_connected::Init, fully_connected::Free, + fully_connected::Prepare, + fully_connected::Eval}; + return &r; +} + +TfLiteRegistration* Register_FULLY_CONNECTED_SPARSE_OPT() { + static TfLiteRegistration r = { + fully_connected::Init, fully_connected::Free, + fully_connected::Prepare, + fully_connected::Eval}; + return &r; +} + TfLiteRegistration* Register_FULLY_CONNECTED_REF() { static TfLiteRegistration r = { fully_connected::Init, fully_connected::Free, diff --git a/tensorflow/lite/kernels/fully_connected.h b/tensorflow/lite/kernels/fully_connected.h index a9a6cc70a80..badc9e7a91c 100644 --- a/tensorflow/lite/kernels/fully_connected.h +++ b/tensorflow/lite/kernels/fully_connected.h @@ -29,6 +29,8 @@ namespace builtin { TfLiteRegistration* Register_FULLY_CONNECTED_REF(); TfLiteRegistration* Register_FULLY_CONNECTED_GENERIC_OPT(); TfLiteRegistration* Register_FULLY_CONNECTED_PIE(); +TfLiteRegistration* Register_FULLY_CONNECTED_SPARSE_REF(); +TfLiteRegistration* Register_FULLY_CONNECTED_SPARSE_OPT(); } // namespace builtin } // namespace ops } // namespace tflite diff --git a/tensorflow/lite/kernels/fully_connected_test.cc b/tensorflow/lite/kernels/fully_connected_test.cc index a4b49c59efb..21273cd7a3a 100644 --- a/tensorflow/lite/kernels/fully_connected_test.cc +++ b/tensorflow/lite/kernels/fully_connected_test.cc @@ -16,6 +16,7 @@ limitations under the License. 
 #include "tensorflow/lite/kernels/fully_connected.h"
 
+#include <initializer_list>
 #include <iomanip>
 #include <random>
 #include <vector>
 
@@ -356,6 +357,11 @@ const auto kKernelMapNoPie = new std::map<string, TfLiteRegistration*>({
     {"GenericOptimized", ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT()},
 });
 
+const auto kKernelMapSparse = new std::map<string, TfLiteRegistration*>({
+    {"SparseReference", ops::builtin::Register_FULLY_CONNECTED_SPARSE_REF()},
+    {"SparseOptimized", ops::builtin::Register_FULLY_CONNECTED_SPARSE_OPT()},
+});
+
 class QuantizedFullyConnectedOpTest : public SingleOpTest {
  protected:
   const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
@@ -1061,5 +1067,113 @@ TEST_P(FloatFullyConnectedOpTest, BlackBoxTest) {
   }
 }
 
+template <typename T>
+class SparseFullyConnectedOpModel : public SingleOpModel {
+ public:
+  SparseFullyConnectedOpModel(TfLiteRegistration* registration, int units,
+                              int batches, const TensorData& input,
+                              std::initializer_list<int> weights_shape,
+                              std::initializer_list<T> weights_data)
+      : batches_(batches), units_(units) {
+    int total_input_size = 1;
+    for (size_t i = 0; i < input.shape.size(); ++i) {
+      total_input_size *= input.shape[i];
+    }
+    input_size_ = total_input_size / batches_;
+
+    input_ = AddInput(input);
+    weights_ = AddConstSparseInput(input.type, weights_shape, weights_data);
+
+    TensorData bias{input.type, {units_}};
+    bias_ = AddInput(bias);
+
+    output_ = AddOutput({input.type});
+
+    SetBuiltinOp(
+        BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions,
+        CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU)
+            .Union());
+    resolver_ = absl::make_unique<SingleOpResolver>(
+        BuiltinOperator_FULLY_CONNECTED, registration);
+    BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)});
+  }
+  void SetBias(const std::vector<T>& data) { PopulateTensor(bias_, data); }
+  void SetInput(const std::vector<T>& data) { PopulateTensor(input_, data); }
+  std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
+  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+  int input_size() { return input_size_; }
+  int num_units() { return units_; }
+  int num_batches() { return batches_; }
+
+ protected:
+  int input_;
+  int weights_;
+  int bias_;
+  int output_;
+
+  int batches_;
+  int units_;
+  int input_size_;
+};
+
+class SparseFullyConnectedOpTest : public SingleOpTest {
+ protected:
+  const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+    return *kKernelMapSparse;
+  }
+};
+
+TEST_P(SparseFullyConnectedOpTest, SimpleTest) {
+  std::initializer_list<int> weight_shape = {3, 10};
+  std::initializer_list<float> weight_data = {
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
+  };
+  SparseFullyConnectedOpModel<float> m(
+      GetRegistration(), /*units=*/3, /*batches=*/2,
+      /*input=*/{TensorType_FLOAT32, {2, 10}}, weight_shape, weight_data);
+  m.SetBias({1, 2, 3});
+
+  m.SetInput({
+      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
+      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
+  });
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 3));
+  EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
+}
+
+TEST_P(SparseFullyConnectedOpTest, SimpleTest2) {
+  std::initializer_list<int> weight_shape = {1, 2};
+  std::initializer_list<float> weight_data = {
+      2, 4  // u = 0
+  };
+  SparseFullyConnectedOpModel<float> m(
+      GetRegistration(), /*units=*/1, /*batches=*/2,
+      /*input=*/{TensorType_FLOAT32, {2, 2}}, weight_shape, weight_data);
+  m.SetBias({1});
+
+  m.SetInput({
+      1, 2,  // b = 0
+      2, 1   // b = 1
+  });
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 1));
+  EXPECT_THAT(m.GetOutput(), ElementsAre(11, 9));
+}
+
+// TODO(b/148391360): Add tests for unsupported sparsity format.
+// TEST_P(SparseFullyConnectedOpTest, TestUnsupportedSparsityFormat) + +INSTANTIATE_TEST_SUITE_P( + SparseFullyConnectedOpTest, SparseFullyConnectedOpTest, + ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMapSparse))); + } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 0c063ed6338..03bdfd1cf36 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -232,6 +232,7 @@ cc_library( "optimized/integer_ops/softmax.h", "optimized/integer_ops/transpose_conv.h", "optimized/optimized_ops.h", + "optimized/sparse_ops/fully_connected.h", ], copts = tflite_copts(), deps = [ @@ -456,6 +457,7 @@ cc_library( "reference/requantize.h", "reference/round.h", "reference/softmax.h", + "reference/sparse_ops/fully_connected.h", "reference/strided_slice.h", "reference/svdf.h", ], diff --git a/tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h b/tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h new file mode 100644 index 00000000000..9a05b61c006 --- /dev/null +++ b/tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h @@ -0,0 +1,75 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SPARSE_OPS_FULLY_CONNECTED_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SPARSE_OPS_FULLY_CONNECTED_H_ + +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/round.h" +#include "tensorflow/lite/kernels/internal/types.h" + +namespace tflite { +namespace optimized_ops { + +inline void FullyConnectedSparseWeight( + const TfLiteSparsity& sparsity, const FullyConnectedParams& params, + const RuntimeShape& input_shape, const float* input_data, + const RuntimeShape& weights_shape, const float* weights_data, + const RuntimeShape& bias_shape, const float* bias_data, + const RuntimeShape& output_shape, float* output_data) { + const float output_activation_min = params.float_activation_min; + const float output_activation_max = params.float_activation_max; + + const int output_elements = output_shape.FlatSize(); + const int output_dims_count = output_shape.DimensionsCount(); + const int weights_dims_count = weights_shape.DimensionsCount(); + const int batches = FlatSizeSkipDim(output_shape, output_dims_count - 1); + const int output_depth = MatchingDim(weights_shape, weights_dims_count - 2, + output_shape, output_dims_count - 1); + const int accum_depth = weights_shape.Dims(weights_dims_count - 1); + const int* w0_segments = sparsity.dim_metadata[0].array_segments->data; + const int* w0_indices = sparsity.dim_metadata[0].array_indices->data; + const int* w1_segments = sparsity.dim_metadata[1].array_segments->data; + const int* w1_indices = sparsity.dim_metadata[1].array_indices->data; + + for (int i = 0; i < output_elements; ++i) { + output_data[i] = 0.f; + } + + for (int b = 0; b < batches; ++b) { + for (int pw0 = w0_segments[0]; pw0 < w0_segments[1]; ++pw0) { + int idx_0 = 
w0_indices[pw0]; + for (int pw1 = w1_segments[pw0]; pw1 < w1_segments[pw0 + 1]; ++pw1) { + int idx_1 = w1_indices[pw1]; + output_data[b * output_depth + idx_0] += + weights_data[pw1] * input_data[b * accum_depth + idx_1]; + } + } + } + + for (int b = 0; b < batches; ++b) { + for (int i = 0; i < output_depth; ++i) { + float total = output_data[b * output_depth + i]; + float bias_value = bias_data[i]; + output_data[b * output_depth + i] = ActivationFunctionWithMinMax( + total + bias_value, output_activation_min, output_activation_max); + } + } +} + +} // namespace optimized_ops +} // namespace tflite +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SPARSE_OPS_FULLY_CONNECTED_H_ diff --git a/tensorflow/lite/kernels/internal/reference/sparse_ops/fully_connected.h b/tensorflow/lite/kernels/internal/reference/sparse_ops/fully_connected.h new file mode 100644 index 00000000000..0f8a248d61c --- /dev/null +++ b/tensorflow/lite/kernels/internal/reference/sparse_ops/fully_connected.h @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPARSE_OPS_FULLY_CONNECTED_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPARSE_OPS_FULLY_CONNECTED_H_
+
+#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
+#include "tensorflow/lite/tools/optimize/sparsity/format_converter.h"
+
+namespace tflite {
+namespace reference_ops {
+
+// Convert weights to dense format and run dense fully connected.
+inline void FullyConnectedSparseWeight(
+    const TfLiteSparsity& sparsity, const FullyConnectedParams& params,
+    const RuntimeShape& input_shape, const float* input_data,
+    const RuntimeShape& weights_shape, const float* weights_data,
+    const RuntimeShape& bias_shape, const float* bias_data,
+    const RuntimeShape& output_shape, float* output_data) {
+  std::vector<int> weights_shape_vector(weights_shape.DimensionsCount());
+  for (int i = 0; i < weights_shape.DimensionsCount(); i++) {
+    weights_shape_vector[i] = weights_shape.Dims(i);
+  }
+  tflite::optimize::sparsity::FormatConverter<float> converter(
+      weights_shape_vector, sparsity);
+  converter.SparseToDense(weights_data);
+  const std::vector<float> dense_weights_data = converter.GetData();
+  FullyConnected(params, input_shape, input_data, weights_shape,
+                 dense_weights_data.data(), bias_shape, bias_data, output_shape,
+                 output_data);
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPARSE_OPS_FULLY_CONNECTED_H_