Update op_version for split, it should be checking for input[1] which is the input. Input[0] is the axis

PiperOrigin-RevId: 303214325
Change-Id: Id93c8a66d173fd6af460a0db8263268d50aa5f6b
This commit is contained in:
Karim Nosir 2020-03-26 16:09:46 -07:00 committed by TensorFlower Gardener
parent dce27acd28
commit 3009664be0

View File

@ -216,10 +216,11 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
case BuiltinOperator_SPLIT:
// If the op take int8 input, it is version 2, for int32 it's version 3.
if (op_sig.input_types.at(0) == TensorType_INT32) {
// The input tensor is at index 1 not 0, 0 is the axis.
if (op_sig.input_types.at(1) == TensorType_INT32) {
return 3;
}
if (op_sig.input_types.at(0) == TensorType_INT8) {
if (op_sig.input_types.at(1) == TensorType_INT8) {
return 2;
}
return 1;