Add Bool support in transpose op.

PiperOrigin-RevId: 258304540
This commit is contained in:
Haoliang Zhang 2019-07-15 22:57:43 -07:00 committed by TensorFlower Gardener
parent c1f7f64b03
commit b51a1b258b
4 changed files with 12 additions and 2 deletions

View File

@ -254,7 +254,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
/* max_version */ 2);
AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE(),
/* min_version */ 1,
/* max_version */ 2);
/* max_version */ 3);
AddBuiltin(BuiltinOperator_MEAN, Register_MEAN(),
/* min_version */ 1,
/* max_version */ 2);

View File

@ -132,6 +132,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_TRANSPOSE(reference_ops, int64_t);
}
break;
case kTfLiteBool:
if (kernel_type == kReference) {
TF_LITE_TRANSPOSE(reference_ops, bool);
}
break;
default:
context->ReportError(context,
"Type %d is currently not supported by Transpose.",

View File

@ -90,6 +90,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
{{OperatorType::kSpaceToDepth, 2}, "1.14.0"},
{{OperatorType::kTranspose, 1}, "1.6.0"},
{{OperatorType::kTranspose, 2}, "1.14.0"},
{{OperatorType::kTranspose, 3}, kPendingReleaseOpVersion},
{{OperatorType::kLstmCell, 1}, "1.7.0"},
{{OperatorType::kLstmCell, 2}, "1.10.0"},
{{OperatorType::kLstmCell, 3}, "1.14.0"},

View File

@ -952,7 +952,11 @@ class Transpose
int GetVersion(const OperatorSignature& op_signature) const override {
const string& input_name = op_signature.op->inputs[0];
const Array& input_array = op_signature.model->GetArray(input_name);
// If the op take int8 input, it is version 2.
// If the op takes bool input, it is version 3.
if (input_array.data_type == ArrayDataType::kBool) {
return 3;
}
// If the op takes int8 input, it is version 2.
if (input_array.data_type == ArrayDataType::kInt8) {
return 2;
}