diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index 95c027b8775..700b321c7a9 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -254,7 +254,7 @@ BuiltinOpResolver::BuiltinOpResolver() { /* max_version */ 2); AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE(), /* min_version */ 1, - /* max_version */ 2); + /* max_version */ 3); AddBuiltin(BuiltinOperator_MEAN, Register_MEAN(), /* min_version */ 1, /* max_version */ 2); diff --git a/tensorflow/lite/kernels/transpose.cc b/tensorflow/lite/kernels/transpose.cc index 0ef4972d1a8..33dee1ff838 100644 --- a/tensorflow/lite/kernels/transpose.cc +++ b/tensorflow/lite/kernels/transpose.cc @@ -132,6 +132,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_TRANSPOSE(reference_ops, int64_t); } break; + case kTfLiteBool: + if (kernel_type == kReference) { + TF_LITE_TRANSPOSE(reference_ops, bool); + } + break; default: context->ReportError(context, "Type %d is currently not supported by Transpose.", diff --git a/tensorflow/lite/toco/tflite/op_version.cc b/tensorflow/lite/toco/tflite/op_version.cc index 71deb5aac17..ddd5f598eec 100644 --- a/tensorflow/lite/toco/tflite/op_version.cc +++ b/tensorflow/lite/toco/tflite/op_version.cc @@ -90,6 +90,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) { {{OperatorType::kSpaceToDepth, 2}, "1.14.0"}, {{OperatorType::kTranspose, 1}, "1.6.0"}, {{OperatorType::kTranspose, 2}, "1.14.0"}, + {{OperatorType::kTranspose, 3}, kPendingReleaseOpVersion}, {{OperatorType::kLstmCell, 1}, "1.7.0"}, {{OperatorType::kLstmCell, 2}, "1.10.0"}, {{OperatorType::kLstmCell, 3}, "1.14.0"}, diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc index 3c65eed17a9..03313aaca91 100644 --- a/tensorflow/lite/toco/tflite/operator.cc +++ b/tensorflow/lite/toco/tflite/operator.cc @@ -952,7 +952,11 @@ class Transpose int GetVersion(const OperatorSignature& op_signature) const override { const string& input_name = op_signature.op->inputs[0]; const Array& input_array = op_signature.model->GetArray(input_name); - // If the op take int8 input, it is version 2. + // If the op takes bool input, it is version 3. + if (input_array.data_type == ArrayDataType::kBool) { + return 3; + } + // If the op takes int8 input, it is version 2. if (input_array.data_type == ArrayDataType::kInt8) { return 2; }