diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/fully_connected_op_builder.cc b/tensorflow/lite/experimental/delegates/coreml/builders/fully_connected_op_builder.cc
index 2efc767d703..376830922c9 100644
--- a/tensorflow/lite/experimental/delegates/coreml/builders/fully_connected_op_builder.cc
+++ b/tensorflow/lite/experimental/delegates/coreml/builders/fully_connected_op_builder.cc
@@ -51,21 +51,37 @@ CoreML::Specification::NeuralNetworkLayer* FullyConnectedOpBuilder::Build() {
 void FullyConnectedOpBuilder::FillCoreMLWeights() {
   layer_->mutable_innerproduct()->set_inputchannels(weights_->dims->data[1]);
   layer_->mutable_innerproduct()->set_outputchannels(weights_->dims->data[0]);
-  const float* weights_data = GetTensorData<float>(weights_);
-  std::copy(weights_data, weights_data + NumElements(weights_),
-            google::protobuf::RepeatedFieldBackInserter(layer_->mutable_innerproduct()
-                                                            ->mutable_weights()
-                                                            ->mutable_floatvalue()));
+  if (weights_->type == kTfLiteFloat32) {
+    const float* weights_data = GetTensorData<float>(weights_);
+    std::copy(weights_data, weights_data + NumElements(weights_),
+              google::protobuf::RepeatedFieldBackInserter(layer_->mutable_innerproduct()
+                                                              ->mutable_weights()
+                                                              ->mutable_floatvalue()));
+  } else if (weights_->type == kTfLiteFloat16) {
+    // float16value is stored as bytes (std::string).
+    layer_->mutable_innerproduct()
+        ->mutable_weights()
+        ->mutable_float16value()
+        ->assign(weights_->data.raw, weights_->bytes);
+  }
 }
 
 void FullyConnectedOpBuilder::FillCoreMLBias() {
   if (bias_ != nullptr) {
     layer_->mutable_innerproduct()->set_hasbias(true);
-    const float* bias_data = GetTensorData<float>(bias_);
-    std::copy(bias_data, bias_data + NumElements(bias_),
-              google::protobuf::RepeatedFieldBackInserter(layer_->mutable_innerproduct()
-                                                              ->mutable_bias()
-                                                              ->mutable_floatvalue()));
+    if (bias_->type == kTfLiteFloat32) {
+      const float* bias_data = GetTensorData<float>(bias_);
+      std::copy(bias_data, bias_data + NumElements(bias_),
+                google::protobuf::RepeatedFieldBackInserter(layer_->mutable_innerproduct()
+                                                                ->mutable_bias()
+                                                                ->mutable_floatvalue()));
+    } else if (bias_->type == kTfLiteFloat16) {
+      // float16value is stored as bytes (std::string).
+      layer_->mutable_innerproduct()
+          ->mutable_bias()
+          ->mutable_float16value()
+          ->assign(bias_->data.raw, bias_->bytes);
+    }
   }
 }
 
@@ -120,6 +136,10 @@ OpBuilder* CreateFullyConnectedOpBuilder(GraphBuilder* graph_builder) {
   return new FullyConnectedOpBuilder(graph_builder);
 }
 
+bool IsFloatType(TfLiteType type) {
+  return type == kTfLiteFloat32 || type == kTfLiteFloat16;
+}
+
 bool IsFullyConnectedOpSupported(const TfLiteRegistration* registration,
                                  const TfLiteNode* node,
                                  TfLiteContext* context) {
@@ -136,10 +156,10 @@ bool IsFullyConnectedOpSupported(const TfLiteRegistration* registration,
   const TfLiteTensor* input = GetInput(context, node, kInput);
   const TfLiteTensor* weights = GetInput(context, node, kWeights);
 
-  if (input->type != kTfLiteFloat32) {
+  if (!IsFloatType(input->type)) {
     return false;
   }
-  if (weights->type != kTfLiteFloat32 || !IsConstantTensor(weights)) {
+  if (!IsFloatType(weights->type) || !IsConstantTensor(weights)) {
     return false;
   }
   // Core ML 2 only supports single-batch fully connected layer, thus dimensions
@@ -150,7 +170,7 @@ bool IsFullyConnectedOpSupported(const TfLiteRegistration* registration,
 
   if (node->inputs->size > 2) {
     const TfLiteTensor* bias = GetInput(context, node, kBias);
-    if (bias->type != kTfLiteFloat32 || !IsConstantTensor(bias)) {
+    if (!IsFloatType(bias->type) || !IsConstantTensor(bias)) {
       return false;
     }
   }
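
Note: the two fp16 branches in the patch work because `WeightParams.float16value` is declared as `bytes` in the Core ML protobuf schema, which C++ protobuf exposes as a `std::string`, so the tensor's raw IEEE-754 half-precision buffer can be copied verbatim with `assign()`. Below is a minimal sketch of the shared pattern; the free function `CopyFloatWeights` is hypothetical (not part of this patch), and the include paths are assumptions that may need adjusting to the build.

    #include <algorithm>

    #include "google/protobuf/repeated_field.h"
    #include "mlmodel/format/NeuralNetwork.pb.h"  // Core ML protos; path is an assumption.
    #include "tensorflow/lite/c/common.h"
    #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
    #include "tensorflow/lite/kernels/kernel_util.h"

    // Sketch only: mirrors the float32/float16 branches above for any tensor.
    // `params` is the WeightParams message that mutable_weights() or
    // mutable_bias() on the InnerProductLayerParams returns a pointer to.
    void CopyFloatWeights(const TfLiteTensor* tensor,
                          CoreML::Specification::WeightParams* params) {
      if (tensor->type == kTfLiteFloat32) {
        // floatvalue is `repeated float`; append element by element.
        const float* data = tflite::GetTensorData<float>(tensor);
        std::copy(data, data + tflite::NumElements(tensor),
                  google::protobuf::RepeatedFieldBackInserter(
                      params->mutable_floatvalue()));
      } else if (tensor->type == kTfLiteFloat16) {
        // float16value is `bytes` (std::string); copy the raw buffer as-is.
        // For fp16, tensor->bytes equals NumElements(tensor) * sizeof(uint16_t).
        params->mutable_float16value()->assign(tensor->data.raw, tensor->bytes);
      }
    }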