Merge pull request #44089 from matthias-vogt:patch-1
PiperOrigin-RevId: 338241713 Change-Id: Ia2e2d86ef732cf0f2d7e9d41daa9d4ce55a2560e
This commit is contained in:
commit
c7ddaeba9e
@ -518,7 +518,7 @@
|
||||
"source": [
|
||||
"## Choose a `model_spec` that Represents a Model for Text Classifier\n",
|
||||
"\n",
|
||||
"Each `model_spec` object represents a specific model for the text classifier. TensorFlow Lite Model Maker currently supports [MobileBERT](https://arxiv.org/pdf/2004.02984.pdf), averaging word embeddings and [BERT-Base]((https://arxiv.org/pdf/1810.04805.pdf) models.\n",
|
||||
"Each `model_spec` object represents a specific model for the text classifier. TensorFlow Lite Model Maker currently supports [MobileBERT](https://arxiv.org/pdf/2004.02984.pdf), averaging word embeddings and [BERT-Base](https://arxiv.org/pdf/1810.04805.pdf) models.\n",
|
||||
"\n",
|
||||
"Supported Model | Name of model_spec | Model Description\n",
|
||||
"--- | --- | ---\n",
|
||||
@ -548,7 +548,7 @@
|
||||
"source": [
|
||||
"## Load Input Data Specific to an On-device ML App\n",
|
||||
"\n",
|
||||
"The [SST-2](https://nlp.stanford.edu/sentiment/index.html) (Stanford Sentiment Treebank) is one of the tasks in the [GLUE](https://gluebenchmark.com/) benchmark . It contains 67,349 movie reviews for training and 872 movie reviews for validation. The dataset has two classes: positive and negative movie reviews.\n",
|
||||
"The [SST-2](https://nlp.stanford.edu/sentiment/index.html) (Stanford Sentiment Treebank) is one of the tasks in the [GLUE](https://gluebenchmark.com/) benchmark. It contains 67,349 movie reviews for training and 872 movie reviews for validation. The dataset has two classes: positive and negative movie reviews.\n",
|
||||
"\n",
|
||||
"Download the archived version of the dataset and extract it.\n"
|
||||
]
|
||||
@ -669,9 +669,7 @@
|
||||
"source": [
|
||||
"## Evaluate the Customized Model\n",
|
||||
"\n",
|
||||
"Evaluate the result of the model and get the loss and accuracy of the model.\n",
|
||||
"\n",
|
||||
"Evaluate the loss and accuracy in the test data."
|
||||
"Evaluate the model with the test data and get its loss and accuracy.",
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -749,7 +747,7 @@
|
||||
"id": "HZKYthlVrTos"
|
||||
},
|
||||
"source": [
|
||||
"You can evalute the tflite model with `evaluate_tflite` method."
|
||||
"You can evalute the tflite model with `evaluate_tflite` method to get its accuracy."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -760,7 +758,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.evaluate_tflite('average_word_vec/model.tflite', test_data)"
|
||||
"accuracy = model.evaluate_tflite('average_word_vec/model.tflite', test_data)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -771,7 +769,7 @@
|
||||
"source": [
|
||||
"## Advanced Usage\n",
|
||||
"\n",
|
||||
"The `create` function is the driver function that the Model Maker library uses to create models. The `model spec` parameter defines the model specification. The `AverageWordVecModelSpec` and `BertClassifierModelSpec` classes are currently supported. The `create` function comprises of the following steps:\n",
|
||||
"The `create` function is the driver function that the Model Maker library uses to create models. The `model_spec` parameter defines the model specification. The `AverageWordVecModelSpec` and `BertClassifierModelSpec` classes are currently supported. The `create` function comprises of the following steps:\n",
|
||||
"\n",
|
||||
"1. Creates the model for the text classifier according to `model_spec`.\n",
|
||||
"2. Trains the classifier model. The default epochs and the default batch size are set by the `default_training_epochs` and `default_batch_size` variables in the `model_spec` object.\n",
|
||||
@ -867,7 +865,7 @@
|
||||
"The model parameters you can adjust are:\n",
|
||||
"\n",
|
||||
"* `seq_len`: Length of the sequence to feed into the model.\n",
|
||||
"* `initializer_range`: The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n",
|
||||
"* `initializer_range`: The standard deviation of the `truncated_normal_initializer` for initializing all weight matrices.\n",
|
||||
"* `trainable`: Boolean that specifies whether the pre-trained layer is trainable.\n",
|
||||
"\n",
|
||||
"The training pipeline parameters you can adjust are:\n",
|
||||
|
Loading…
Reference in New Issue
Block a user