[Core ML] Support FP16 in fully connected op
PiperOrigin-RevId: 315425894 Change-Id: I5340adf7b70d3e6d51a9c1edb814e6a309e99a84
This commit is contained in:
parent
47b4145e68
commit
33014a38d9
@ -51,21 +51,37 @@ CoreML::Specification::NeuralNetworkLayer* FullyConnectedOpBuilder::Build() {
|
|||||||
void FullyConnectedOpBuilder::FillCoreMLWeights() {
|
void FullyConnectedOpBuilder::FillCoreMLWeights() {
|
||||||
layer_->mutable_innerproduct()->set_inputchannels(weights_->dims->data[1]);
|
layer_->mutable_innerproduct()->set_inputchannels(weights_->dims->data[1]);
|
||||||
layer_->mutable_innerproduct()->set_outputchannels(weights_->dims->data[0]);
|
layer_->mutable_innerproduct()->set_outputchannels(weights_->dims->data[0]);
|
||||||
const float* weights_data = GetTensorData<float>(weights_);
|
if (weights_->type == kTfLiteFloat32) {
|
||||||
std::copy(weights_data, weights_data + NumElements(weights_),
|
const float* weights_data = GetTensorData<float>(weights_);
|
||||||
google::protobuf::RepeatedFieldBackInserter(layer_->mutable_innerproduct()
|
std::copy(weights_data, weights_data + NumElements(weights_),
|
||||||
->mutable_weights()
|
google::protobuf::RepeatedFieldBackInserter(layer_->mutable_innerproduct()
|
||||||
->mutable_floatvalue()));
|
->mutable_weights()
|
||||||
|
->mutable_floatvalue()));
|
||||||
|
} else if (weights_->type == kTfLiteFloat16) {
|
||||||
|
// float16value has type of bytes (std::string)
|
||||||
|
layer_->mutable_innerproduct()
|
||||||
|
->mutable_weights()
|
||||||
|
->mutable_float16value()
|
||||||
|
->assign(weights_->data.raw, weights_->bytes);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void FullyConnectedOpBuilder::FillCoreMLBias() {
|
void FullyConnectedOpBuilder::FillCoreMLBias() {
|
||||||
if (bias_ != nullptr) {
|
if (bias_ != nullptr) {
|
||||||
layer_->mutable_innerproduct()->set_hasbias(true);
|
layer_->mutable_innerproduct()->set_hasbias(true);
|
||||||
const float* bias_data = GetTensorData<float>(bias_);
|
if (bias_->type == kTfLiteFloat32) {
|
||||||
std::copy(bias_data, bias_data + NumElements(bias_),
|
const float* bias_data = GetTensorData<float>(bias_);
|
||||||
google::protobuf::RepeatedFieldBackInserter(layer_->mutable_innerproduct()
|
std::copy(bias_data, bias_data + NumElements(bias_),
|
||||||
->mutable_bias()
|
google::protobuf::RepeatedFieldBackInserter(layer_->mutable_innerproduct()
|
||||||
->mutable_floatvalue()));
|
->mutable_bias()
|
||||||
|
->mutable_floatvalue()));
|
||||||
|
} else if (bias_->type == kTfLiteFloat16) {
|
||||||
|
// float16value has type of bytes (std::string)
|
||||||
|
layer_->mutable_innerproduct()
|
||||||
|
->mutable_bias()
|
||||||
|
->mutable_float16value()
|
||||||
|
->assign(bias_->data.raw, bias_->bytes);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -120,6 +136,10 @@ OpBuilder* CreateFullyConnectedOpBuilder(GraphBuilder* graph_builder) {
|
|||||||
return new FullyConnectedOpBuilder(graph_builder);
|
return new FullyConnectedOpBuilder(graph_builder);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsFloatType(TfLiteType type) {
|
||||||
|
return type == kTfLiteFloat32 || type == kTfLiteFloat16;
|
||||||
|
}
|
||||||
|
|
||||||
bool IsFullyConnectedOpSupported(const TfLiteRegistration* registration,
|
bool IsFullyConnectedOpSupported(const TfLiteRegistration* registration,
|
||||||
const TfLiteNode* node,
|
const TfLiteNode* node,
|
||||||
TfLiteContext* context) {
|
TfLiteContext* context) {
|
||||||
@ -136,10 +156,10 @@ bool IsFullyConnectedOpSupported(const TfLiteRegistration* registration,
|
|||||||
const TfLiteTensor* input = GetInput(context, node, kInput);
|
const TfLiteTensor* input = GetInput(context, node, kInput);
|
||||||
const TfLiteTensor* weights = GetInput(context, node, kWeights);
|
const TfLiteTensor* weights = GetInput(context, node, kWeights);
|
||||||
|
|
||||||
if (input->type != kTfLiteFloat32) {
|
if (!IsFloatType(input->type)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (weights->type != kTfLiteFloat32 || !IsConstantTensor(weights)) {
|
if (!IsFloatType(weights->type) || !IsConstantTensor(weights)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// Core ML 2 only supports single-batch fully connected layer, thus dimensions
|
// Core ML 2 only supports single-batch fully connected layer, thus dimensions
|
||||||
@ -150,7 +170,7 @@ bool IsFullyConnectedOpSupported(const TfLiteRegistration* registration,
|
|||||||
|
|
||||||
if (node->inputs->size > 2) {
|
if (node->inputs->size > 2) {
|
||||||
const TfLiteTensor* bias = GetInput(context, node, kBias);
|
const TfLiteTensor* bias = GetInput(context, node, kBias);
|
||||||
if (bias->type != kTfLiteFloat32 || !IsConstantTensor(bias)) {
|
if (!IsFloatType(bias->type) || !IsConstantTensor(bias)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user