Set sparse FullyConnected op version properly.
PiperOrigin-RevId: 311586496 Change-Id: Ieb57857388bbb25de02163b9a6594dd02666b867
This commit is contained in:
parent
215616fddc
commit
f4a49c6871
@ -487,6 +487,7 @@ class FullyConnected
|
||||
op_sig.options.fully_connected.keep_num_dims = fc_op.keep_num_dims;
|
||||
op_sig.options.fully_connected.weights_format =
|
||||
GetWeightFormat(fc_op.weights_format);
|
||||
op_sig.options.fully_connected.sparse_weight = false;
|
||||
return ::tflite::GetBuiltinOperatorVersion(op_sig);
|
||||
}
|
||||
};
|
||||
|
@ -121,6 +121,11 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
||||
// | Quantized Int8 | 4 | 4 |
|
||||
// +-----------------+--------------------+--------------------------+
|
||||
|
||||
// FullyConnected with sparse weight is supported at version 8.
|
||||
if (op_sig.options.fully_connected.sparse_weight) {
|
||||
return 8;
|
||||
}
|
||||
|
||||
// Int16 fully fixed point kernel is at version 7.
|
||||
if (op_sig.input_types.at(0) == TensorType_INT16 &&
|
||||
op_sig.input_types.at(1) == TensorType_INT16 &&
|
||||
@ -578,6 +583,11 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
|
||||
op_sig.options.fully_connected.weights_format =
|
||||
fully_connected_option->weights_format();
|
||||
}
|
||||
|
||||
const Tensor* weight_tensor =
|
||||
subgraph->tensors()->Get(op->inputs()->Get(1));
|
||||
op_sig.options.fully_connected.sparse_weight =
|
||||
(weight_tensor->sparsity() != nullptr);
|
||||
} break;
|
||||
|
||||
case BuiltinOperator_MUL: {
|
||||
|
@ -37,6 +37,9 @@ typedef struct {
|
||||
struct {
|
||||
bool keep_num_dims;
|
||||
FullyConnectedOptionsWeightsFormat weights_format;
|
||||
// TODO(b/156530611): Make this global when more ops support sparse
|
||||
// computation.
|
||||
bool sparse_weight;
|
||||
} fully_connected;
|
||||
struct {
|
||||
float input1_scale;
|
||||
|
@ -352,6 +352,15 @@ TEST(OpVersionTest, VersioningFullyConnectedTest) {
|
||||
fake_op_sig.options.fully_connected = {
|
||||
false, FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6);
|
||||
|
||||
fake_op_sig = {
|
||||
.op = BuiltinOperator_FULLY_CONNECTED,
|
||||
.input_types = std::vector<TensorType>{TensorType_INT8, TensorType_INT8},
|
||||
.output_types = std::vector<TensorType>{TensorType_INT8},
|
||||
};
|
||||
fake_op_sig.options.fully_connected = {
|
||||
false, FullyConnectedOptionsWeightsFormat_DEFAULT, true};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 8);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningDequantizeTest) {
|
||||
|
Loading…
Reference in New Issue
Block a user