Add Bool support in transpose op.
PiperOrigin-RevId: 258304540
This commit is contained in:
parent
c1f7f64b03
commit
b51a1b258b
@ -254,7 +254,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
|||||||
/* max_version */ 2);
|
/* max_version */ 2);
|
||||||
AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE(),
|
AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE(),
|
||||||
/* min_version */ 1,
|
/* min_version */ 1,
|
||||||
/* max_version */ 2);
|
/* max_version */ 3);
|
||||||
AddBuiltin(BuiltinOperator_MEAN, Register_MEAN(),
|
AddBuiltin(BuiltinOperator_MEAN, Register_MEAN(),
|
||||||
/* min_version */ 1,
|
/* min_version */ 1,
|
||||||
/* max_version */ 2);
|
/* max_version */ 2);
|
||||||
|
@ -132,6 +132,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TF_LITE_TRANSPOSE(reference_ops, int64_t);
|
TF_LITE_TRANSPOSE(reference_ops, int64_t);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
case kTfLiteBool:
|
||||||
|
if (kernel_type == kReference) {
|
||||||
|
TF_LITE_TRANSPOSE(reference_ops, bool);
|
||||||
|
}
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
context->ReportError(context,
|
context->ReportError(context,
|
||||||
"Type %d is currently not supported by Transpose.",
|
"Type %d is currently not supported by Transpose.",
|
||||||
|
@ -90,6 +90,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
|
|||||||
{{OperatorType::kSpaceToDepth, 2}, "1.14.0"},
|
{{OperatorType::kSpaceToDepth, 2}, "1.14.0"},
|
||||||
{{OperatorType::kTranspose, 1}, "1.6.0"},
|
{{OperatorType::kTranspose, 1}, "1.6.0"},
|
||||||
{{OperatorType::kTranspose, 2}, "1.14.0"},
|
{{OperatorType::kTranspose, 2}, "1.14.0"},
|
||||||
|
{{OperatorType::kTranspose, 3}, kPendingReleaseOpVersion},
|
||||||
{{OperatorType::kLstmCell, 1}, "1.7.0"},
|
{{OperatorType::kLstmCell, 1}, "1.7.0"},
|
||||||
{{OperatorType::kLstmCell, 2}, "1.10.0"},
|
{{OperatorType::kLstmCell, 2}, "1.10.0"},
|
||||||
{{OperatorType::kLstmCell, 3}, "1.14.0"},
|
{{OperatorType::kLstmCell, 3}, "1.14.0"},
|
||||||
|
@ -952,7 +952,11 @@ class Transpose
|
|||||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||||
const string& input_name = op_signature.op->inputs[0];
|
const string& input_name = op_signature.op->inputs[0];
|
||||||
const Array& input_array = op_signature.model->GetArray(input_name);
|
const Array& input_array = op_signature.model->GetArray(input_name);
|
||||||
// If the op take int8 input, it is version 2.
|
// If the op takes bool input, it is version 3.
|
||||||
|
if (input_array.data_type == ArrayDataType::kBool) {
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
|
// If the op takes int8 input, it is version 2.
|
||||||
if (input_array.data_type == ArrayDataType::kInt8) {
|
if (input_array.data_type == ArrayDataType::kInt8) {
|
||||||
return 2;
|
return 2;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user