Set sparse FullyConnected op version properly.

PiperOrigin-RevId: 311586496
Change-Id: Ieb57857388bbb25de02163b9a6594dd02666b867
This commit is contained in:
Yunlu Li 2020-05-14 12:40:01 -07:00 committed by TensorFlower Gardener
parent 215616fddc
commit f4a49c6871
4 changed files with 23 additions and 0 deletions

View File

@ -487,6 +487,7 @@ class FullyConnected
op_sig.options.fully_connected.keep_num_dims = fc_op.keep_num_dims;
op_sig.options.fully_connected.weights_format =
GetWeightFormat(fc_op.weights_format);
op_sig.options.fully_connected.sparse_weight = false;
return ::tflite::GetBuiltinOperatorVersion(op_sig);
}
};

View File

@ -121,6 +121,11 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
// | Quantized Int8 | 4 | 4 |
// +-----------------+--------------------+--------------------------+
// FullyConnected with sparse weight is supported at version 8.
if (op_sig.options.fully_connected.sparse_weight) {
return 8;
}
// Int16 fully fixed point kernel is at version 7.
if (op_sig.input_types.at(0) == TensorType_INT16 &&
op_sig.input_types.at(1) == TensorType_INT16 &&
@ -578,6 +583,11 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
op_sig.options.fully_connected.weights_format =
fully_connected_option->weights_format();
}
const Tensor* weight_tensor =
subgraph->tensors()->Get(op->inputs()->Get(1));
op_sig.options.fully_connected.sparse_weight =
(weight_tensor->sparsity() != nullptr);
} break;
case BuiltinOperator_MUL: {

View File

@ -37,6 +37,9 @@ typedef struct {
struct {
bool keep_num_dims;
FullyConnectedOptionsWeightsFormat weights_format;
// TODO(b/156530611): Make this global when more ops support sparse
// computation.
bool sparse_weight;
} fully_connected;
struct {
float input1_scale;

View File

@ -352,6 +352,15 @@ TEST(OpVersionTest, VersioningFullyConnectedTest) {
fake_op_sig.options.fully_connected = {
false, FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6);
fake_op_sig = {
.op = BuiltinOperator_FULLY_CONNECTED,
.input_types = std::vector<TensorType>{TensorType_INT8, TensorType_INT8},
.output_types = std::vector<TensorType>{TensorType_INT8},
};
fake_op_sig.options.fully_connected = {
false, FullyConnectedOptionsWeightsFormat_DEFAULT, true};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 8);
}
TEST(OpVersionTest, VersioningDequantizeTest) {