Merge pull request #44089 from matthias-vogt:patch-1
PiperOrigin-RevId: 338241713 Change-Id: Ia2e2d86ef732cf0f2d7e9d41daa9d4ce55a2560e
This commit is contained in:
commit
c7ddaeba9e
@ -518,7 +518,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"## Choose a `model_spec` that Represents a Model for Text Classifier\n",
|
"## Choose a `model_spec` that Represents a Model for Text Classifier\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Each `model_spec` object represents a specific model for the text classifier. TensorFlow Lite Model Maker currently supports [MobileBERT](https://arxiv.org/pdf/2004.02984.pdf), averaging word embeddings and [BERT-Base]((https://arxiv.org/pdf/1810.04805.pdf) models.\n",
|
"Each `model_spec` object represents a specific model for the text classifier. TensorFlow Lite Model Maker currently supports [MobileBERT](https://arxiv.org/pdf/2004.02984.pdf), averaging word embeddings and [BERT-Base](https://arxiv.org/pdf/1810.04805.pdf) models.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Supported Model | Name of model_spec | Model Description\n",
|
"Supported Model | Name of model_spec | Model Description\n",
|
||||||
"--- | --- | ---\n",
|
"--- | --- | ---\n",
|
||||||
@ -548,7 +548,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"## Load Input Data Specific to an On-device ML App\n",
|
"## Load Input Data Specific to an On-device ML App\n",
|
||||||
"\n",
|
"\n",
|
||||||
"The [SST-2](https://nlp.stanford.edu/sentiment/index.html) (Stanford Sentiment Treebank) is one of the tasks in the [GLUE](https://gluebenchmark.com/) benchmark . It contains 67,349 movie reviews for training and 872 movie reviews for validation. The dataset has two classes: positive and negative movie reviews.\n",
|
"The [SST-2](https://nlp.stanford.edu/sentiment/index.html) (Stanford Sentiment Treebank) is one of the tasks in the [GLUE](https://gluebenchmark.com/) benchmark. It contains 67,349 movie reviews for training and 872 movie reviews for validation. The dataset has two classes: positive and negative movie reviews.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Download the archived version of the dataset and extract it.\n"
|
"Download the archived version of the dataset and extract it.\n"
|
||||||
]
|
]
|
||||||
@ -669,9 +669,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"## Evaluate the Customized Model\n",
|
"## Evaluate the Customized Model\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Evaluate the result of the model and get the loss and accuracy of the model.\n",
|
"Evaluate the model with the test data and get its loss and accuracy.",
|
||||||
"\n",
|
|
||||||
"Evaluate the loss and accuracy in the test data."
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -749,7 +747,7 @@
|
|||||||
"id": "HZKYthlVrTos"
|
"id": "HZKYthlVrTos"
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"You can evalute the tflite model with `evaluate_tflite` method."
|
"You can evalute the tflite model with `evaluate_tflite` method to get its accuracy."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -760,7 +758,7 @@
|
|||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"model.evaluate_tflite('average_word_vec/model.tflite', test_data)"
|
"accuracy = model.evaluate_tflite('average_word_vec/model.tflite', test_data)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -771,7 +769,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"## Advanced Usage\n",
|
"## Advanced Usage\n",
|
||||||
"\n",
|
"\n",
|
||||||
"The `create` function is the driver function that the Model Maker library uses to create models. The `model spec` parameter defines the model specification. The `AverageWordVecModelSpec` and `BertClassifierModelSpec` classes are currently supported. The `create` function comprises of the following steps:\n",
|
"The `create` function is the driver function that the Model Maker library uses to create models. The `model_spec` parameter defines the model specification. The `AverageWordVecModelSpec` and `BertClassifierModelSpec` classes are currently supported. The `create` function comprises of the following steps:\n",
|
||||||
"\n",
|
"\n",
|
||||||
"1. Creates the model for the text classifier according to `model_spec`.\n",
|
"1. Creates the model for the text classifier according to `model_spec`.\n",
|
||||||
"2. Trains the classifier model. The default epochs and the default batch size are set by the `default_training_epochs` and `default_batch_size` variables in the `model_spec` object.\n",
|
"2. Trains the classifier model. The default epochs and the default batch size are set by the `default_training_epochs` and `default_batch_size` variables in the `model_spec` object.\n",
|
||||||
@ -867,7 +865,7 @@
|
|||||||
"The model parameters you can adjust are:\n",
|
"The model parameters you can adjust are:\n",
|
||||||
"\n",
|
"\n",
|
||||||
"* `seq_len`: Length of the sequence to feed into the model.\n",
|
"* `seq_len`: Length of the sequence to feed into the model.\n",
|
||||||
"* `initializer_range`: The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n",
|
"* `initializer_range`: The standard deviation of the `truncated_normal_initializer` for initializing all weight matrices.\n",
|
||||||
"* `trainable`: Boolean that specifies whether the pre-trained layer is trainable.\n",
|
"* `trainable`: Boolean that specifies whether the pre-trained layer is trainable.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"The training pipeline parameters you can adjust are:\n",
|
"The training pipeline parameters you can adjust are:\n",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user