Merge pull request from matthias-vogt:patch-1

PiperOrigin-RevId: 338241713
Change-Id: Ia2e2d86ef732cf0f2d7e9d41daa9d4ce55a2560e
This commit is contained in:
TensorFlower Gardener 2020-10-21 05:02:36 -07:00
commit c7ddaeba9e

View File

@ -518,7 +518,7 @@
"source": [
"## Choose a `model_spec` that Represents a Model for Text Classifier\n",
"\n",
"Each `model_spec` object represents a specific model for the text classifier. TensorFlow Lite Model Maker currently supports [MobileBERT](https://arxiv.org/pdf/2004.02984.pdf), averaging word embeddings and [BERT-Base]((https://arxiv.org/pdf/1810.04805.pdf) models.\n",
"Each `model_spec` object represents a specific model for the text classifier. TensorFlow Lite Model Maker currently supports [MobileBERT](https://arxiv.org/pdf/2004.02984.pdf), averaging word embeddings and [BERT-Base](https://arxiv.org/pdf/1810.04805.pdf) models.\n",
"\n",
"Supported Model | Name of model_spec | Model Description\n",
"--- | --- | ---\n",
@ -548,7 +548,7 @@
"source": [
"## Load Input Data Specific to an On-device ML App\n",
"\n",
"The [SST-2](https://nlp.stanford.edu/sentiment/index.html) (Stanford Sentiment Treebank) is one of the tasks in the [GLUE](https://gluebenchmark.com/) benchmark . It contains 67,349 movie reviews for training and 872 movie reviews for validation. The dataset has two classes: positive and negative movie reviews.\n",
"The [SST-2](https://nlp.stanford.edu/sentiment/index.html) (Stanford Sentiment Treebank) is one of the tasks in the [GLUE](https://gluebenchmark.com/) benchmark. It contains 67,349 movie reviews for training and 872 movie reviews for validation. The dataset has two classes: positive and negative movie reviews.\n",
"\n",
"Download the archived version of the dataset and extract it.\n"
]
@ -669,9 +669,7 @@
"source": [
"## Evaluate the Customized Model\n",
"\n",
"Evaluate the result of the model and get the loss and accuracy of the model.\n",
"\n",
"Evaluate the loss and accuracy in the test data."
"Evaluate the model with the test data and get its loss and accuracy.",
]
},
{
@ -749,7 +747,7 @@
"id": "HZKYthlVrTos"
},
"source": [
"You can evalute the tflite model with `evaluate_tflite` method."
"You can evalute the tflite model with `evaluate_tflite` method to get its accuracy."
]
},
{
@ -760,7 +758,7 @@
},
"outputs": [],
"source": [
"model.evaluate_tflite('average_word_vec/model.tflite', test_data)"
"accuracy = model.evaluate_tflite('average_word_vec/model.tflite', test_data)"
]
},
{
@ -771,7 +769,7 @@
"source": [
"## Advanced Usage\n",
"\n",
"The `create` function is the driver function that the Model Maker library uses to create models. The `model spec` parameter defines the model specification. The `AverageWordVecModelSpec` and `BertClassifierModelSpec` classes are currently supported. The `create` function comprises of the following steps:\n",
"The `create` function is the driver function that the Model Maker library uses to create models. The `model_spec` parameter defines the model specification. The `AverageWordVecModelSpec` and `BertClassifierModelSpec` classes are currently supported. The `create` function comprises of the following steps:\n",
"\n",
"1. Creates the model for the text classifier according to `model_spec`.\n",
"2. Trains the classifier model. The default epochs and the default batch size are set by the `default_training_epochs` and `default_batch_size` variables in the `model_spec` object.\n",
@ -867,7 +865,7 @@
"The model parameters you can adjust are:\n",
"\n",
"* `seq_len`: Length of the sequence to feed into the model.\n",
"* `initializer_range`: The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n",
"* `initializer_range`: The standard deviation of the `truncated_normal_initializer` for initializing all weight matrices.\n",
"* `trainable`: Boolean that specifies whether the pre-trained layer is trainable.\n",
"\n",
"The training pipeline parameters you can adjust are:\n",