From f4a49c6871a36444a0a28e9b127ab052efc1f9ca Mon Sep 17 00:00:00 2001 From: Yunlu Li Date: Thu, 14 May 2020 12:40:01 -0700 Subject: [PATCH] Set sparse FullyConnected op version properly. PiperOrigin-RevId: 311586496 Change-Id: Ieb57857388bbb25de02163b9a6594dd02666b867 --- tensorflow/lite/toco/tflite/operator.cc | 1 + tensorflow/lite/tools/versioning/op_version.cc | 10 ++++++++++ tensorflow/lite/tools/versioning/op_version.h | 3 +++ tensorflow/lite/tools/versioning/op_version_test.cc | 9 +++++++++ 4 files changed, 23 insertions(+) diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc index 917fd24c952..fee10a19787 100644 --- a/tensorflow/lite/toco/tflite/operator.cc +++ b/tensorflow/lite/toco/tflite/operator.cc @@ -487,6 +487,7 @@ class FullyConnected op_sig.options.fully_connected.keep_num_dims = fc_op.keep_num_dims; op_sig.options.fully_connected.weights_format = GetWeightFormat(fc_op.weights_format); + op_sig.options.fully_connected.sparse_weight = false; return ::tflite::GetBuiltinOperatorVersion(op_sig); } }; diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index 9022afca629..118e2d420f8 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -121,6 +121,11 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { // | Quantized Int8 | 4 | 4 | // +-----------------+--------------------+--------------------------+ + // FullyConnected with sparse weight is supported at version 8. + if (op_sig.options.fully_connected.sparse_weight) { + return 8; + } + // Int16 fully fixed point kernel is at version 7. if (op_sig.input_types.at(0) == TensorType_INT16 && op_sig.input_types.at(1) == TensorType_INT16 && @@ -578,6 +583,11 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op, op_sig.options.fully_connected.weights_format = fully_connected_option->weights_format(); } + + const Tensor* weight_tensor = + subgraph->tensors()->Get(op->inputs()->Get(1)); + op_sig.options.fully_connected.sparse_weight = + (weight_tensor->sparsity() != nullptr); } break; case BuiltinOperator_MUL: { diff --git a/tensorflow/lite/tools/versioning/op_version.h b/tensorflow/lite/tools/versioning/op_version.h index 4b0fe8836e2..df74ffaf6dd 100644 --- a/tensorflow/lite/tools/versioning/op_version.h +++ b/tensorflow/lite/tools/versioning/op_version.h @@ -37,6 +37,9 @@ typedef struct { struct { bool keep_num_dims; FullyConnectedOptionsWeightsFormat weights_format; + // TODO(b/156530611): Make this global when more ops support sparse + // computation. + bool sparse_weight; } fully_connected; struct { float input1_scale; diff --git a/tensorflow/lite/tools/versioning/op_version_test.cc b/tensorflow/lite/tools/versioning/op_version_test.cc index f0d8259d764..4017fc3bff0 100644 --- a/tensorflow/lite/tools/versioning/op_version_test.cc +++ b/tensorflow/lite/tools/versioning/op_version_test.cc @@ -352,6 +352,15 @@ TEST(OpVersionTest, VersioningFullyConnectedTest) { fake_op_sig.options.fully_connected = { false, FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8}; EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6); + + fake_op_sig = { + .op = BuiltinOperator_FULLY_CONNECTED, + .input_types = std::vector{TensorType_INT8, TensorType_INT8}, + .output_types = std::vector{TensorType_INT8}, + }; + fake_op_sig.options.fully_connected = { + false, FullyConnectedOptionsWeightsFormat_DEFAULT, true}; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 8); } TEST(OpVersionTest, VersioningDequantizeTest) {