[Core ML] Support FP16 in fully connected op

PiperOrigin-RevId: 315425894
Change-Id: I5340adf7b70d3e6d51a9c1edb814e6a309e99a84
Author: Taehee Jeong
Date:   2020-06-08 23:01:36 -07:00
Committed by: TensorFlower Gardener
parent 47b4145e68
commit 33014a38d9

@@ -51,21 +51,37 @@ CoreML::Specification::NeuralNetworkLayer* FullyConnectedOpBuilder::Build() {
 void FullyConnectedOpBuilder::FillCoreMLWeights() {
   layer_->mutable_innerproduct()->set_inputchannels(weights_->dims->data[1]);
   layer_->mutable_innerproduct()->set_outputchannels(weights_->dims->data[0]);
-  const float* weights_data = GetTensorData<float>(weights_);
-  std::copy(weights_data, weights_data + NumElements(weights_),
-            google::protobuf::RepeatedFieldBackInserter(layer_->mutable_innerproduct()
-                                                            ->mutable_weights()
-                                                            ->mutable_floatvalue()));
+  if (weights_->type == kTfLiteFloat32) {
+    const float* weights_data = GetTensorData<float>(weights_);
+    std::copy(weights_data, weights_data + NumElements(weights_),
+              google::protobuf::RepeatedFieldBackInserter(layer_->mutable_innerproduct()
+                                                              ->mutable_weights()
+                                                              ->mutable_floatvalue()));
+  } else if (weights_->type == kTfLiteFloat16) {
+    // float16value has type of bytes (std::string)
+    layer_->mutable_innerproduct()
+        ->mutable_weights()
+        ->mutable_float16value()
+        ->assign(weights_->data.raw, weights_->bytes);
+  }
 }
 
 void FullyConnectedOpBuilder::FillCoreMLBias() {
   if (bias_ != nullptr) {
     layer_->mutable_innerproduct()->set_hasbias(true);
-    const float* bias_data = GetTensorData<float>(bias_);
-    std::copy(bias_data, bias_data + NumElements(bias_),
-              google::protobuf::RepeatedFieldBackInserter(layer_->mutable_innerproduct()
-                                                              ->mutable_bias()
-                                                              ->mutable_floatvalue()));
+    if (bias_->type == kTfLiteFloat32) {
+      const float* bias_data = GetTensorData<float>(bias_);
+      std::copy(bias_data, bias_data + NumElements(bias_),
+                google::protobuf::RepeatedFieldBackInserter(layer_->mutable_innerproduct()
+                                                                ->mutable_bias()
+                                                                ->mutable_floatvalue()));
+    } else if (bias_->type == kTfLiteFloat16) {
+      // float16value has type of bytes (std::string)
+      layer_->mutable_innerproduct()
+          ->mutable_bias()
+          ->mutable_float16value()
+          ->assign(bias_->data.raw, bias_->bytes);
+    }
   }
 }
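The FP16 branches above work because Core ML's WeightParams message stores half-precision weights in its float16Value field, which is a protobuf bytes field (exposed as std::string in C++): the TFLite tensor's raw buffer is assigned byte-for-byte, two bytes per element, with no widening to float32. A minimal, self-contained sketch of that copy; FakeTensor and the bare std::string are stand-ins for TfLiteTensor and the protobuf field, not the delegate's real types:

    #include <cstdint>
    #include <string>
    #include <vector>

    struct FakeTensor {
      std::vector<uint16_t> fp16;  // IEEE 754 half-precision bit patterns
      const char* raw() const {
        return reinterpret_cast<const char*>(fp16.data());
      }
      size_t bytes() const { return fp16.size() * sizeof(uint16_t); }
    };

    int main() {
      FakeTensor weights{{0x3C00, 0xC000, 0x0000}};  // 1.0, -2.0, 0.0 in FP16
      std::string float16value;  // protobuf bytes fields map to std::string
      float16value.assign(weights.raw(), weights.bytes());
      // Three FP16 elements occupy 6 bytes; nothing is converted to float32.
      return float16value.size() == 6 ? 0 : 1;
    }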
@@ -120,6 +136,10 @@ OpBuilder* CreateFullyConnectedOpBuilder(GraphBuilder* graph_builder) {
   return new FullyConnectedOpBuilder(graph_builder);
 }
 
+bool IsFloatType(TfLiteType type) {
+  return type == kTfLiteFloat32 || type == kTfLiteFloat16;
+}
+
 bool IsFullyConnectedOpSupported(const TfLiteRegistration* registration,
                                  const TfLiteNode* node,
                                  TfLiteContext* context) {
@@ -136,10 +156,10 @@ bool IsFullyConnectedOpSupported(const TfLiteRegistration* registration,
   const TfLiteTensor* input = GetInput(context, node, kInput);
   const TfLiteTensor* weights = GetInput(context, node, kWeights);
-  if (input->type != kTfLiteFloat32) {
+  if (!IsFloatType(input->type)) {
     return false;
   }
-  if (weights->type != kTfLiteFloat32 || !IsConstantTensor(weights)) {
+  if (!IsFloatType(weights->type) || !IsConstantTensor(weights)) {
     return false;
   }
   // Core ML 2 only supports single-batch fully connected layer, thus dimensions
@@ -150,7 +170,7 @@ bool IsFullyConnectedOpSupported(const TfLiteRegistration* registration,
   if (node->inputs->size > 2) {
     const TfLiteTensor* bias = GetInput(context, node, kBias);
-    if (bias->type != kTfLiteFloat32 || !IsConstantTensor(bias)) {
+    if (!IsFloatType(bias->type) || !IsConstantTensor(bias)) {
       return false;
     }
   }
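Taken together, the checks above now admit any mix of float32 and float16 across input, weights, and bias, while still requiring weights and bias to be constant tensors. A condensed, self-contained sketch of just these type gates; the local TfLiteType enum, the Tensor struct, and TypesSupported are illustrative stand-ins, and the real function performs further checks (e.g., the single-batch constraint noted in the comment above):

    enum TfLiteType { kTfLiteFloat32, kTfLiteFloat16, kTfLiteInt8 };

    bool IsFloatType(TfLiteType type) {
      return type == kTfLiteFloat32 || type == kTfLiteFloat16;
    }

    struct Tensor {
      TfLiteType type;
      bool is_constant;
    };

    // Mirrors only the type/constness gates of IsFullyConnectedOpSupported.
    bool TypesSupported(const Tensor& input, const Tensor& weights,
                        const Tensor* bias) {
      if (!IsFloatType(input.type)) return false;
      if (!IsFloatType(weights.type) || !weights.is_constant) return false;
      if (bias != nullptr && (!IsFloatType(bias->type) || !bias->is_constant)) {
        return false;
      }
      return true;
    }

    int main() {
      Tensor input{kTfLiteFloat16, false};
      Tensor weights{kTfLiteFloat16, true};
      Tensor bias{kTfLiteFloat32, true};  // mixed widths pass the float check
      return TypesSupported(input, weights, &bias) ? 0 : 1;
    }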