Merged commit includes the following changes:

236041198  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Doc fix.

--
236039839  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Remove Kmeans from V2 endpoint.

--
236039121  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Internal change.

--
236036378  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Run all the save/summary ops on all workers.

--
236036371  by A. Unique TensorFlower<gardener@tensorflow.org>:

    PR #25930: Typo error fixed in resolve_multiply_by_zero.cc

    Please approve this CL. It will be submitted automatically, and its GitHub pull request will be marked as merged.

    Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/25930

    Copybara import of the project:

      - 3c955f9d499d60ce76a84265011d60928dcadc9d Typo error fixed in resolve_multiply_by_zero.cc by Albin Joy <albin.joy@huawei.com>
      - d98e6be685f04aa22ad00fcc01e59e684f3eb3ad Merge 3c955f9d499d60ce76a84265011d60928dcadc9d into aa250... by Albin Joy <albin.joy@huawei.com>

--
236035773  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Fix for dynamic model training.

--
236035592  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Automated g4 rollback of changelist 235488332. Did not check the tests
    carefully :(

    Also found that since semanticlift.app_android is no longer a valid tap
    project, the test configuration will fail and cause the tap presubmit not
    to be executed.

--
236035568  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Updated models documentation

--

PiperOrigin-RevId: 236041198
Commit f9a5fdc6d8 (parent 3e21fe5fae)
Author: A. Unique TensorFlower
Date: 2019-02-27 19:38:16 -08:00
Committed by: TensorFlower Gardener

26 changed files with 1220 additions and 683 deletions


@@ -79,40 +79,36 @@ upper_tabs:
- title: Optimizing for mobile
path: /lite/tfmobile/optimizing
# - name: Models
# contents:
# - title: Overview
# path: /lite/models/
# - title: Hosted models
# path: /lite/models/hosted
# - title: Image classification
# section:
# - title: Overview
# path: /lite/models/image_classification/overview
# - title: Android
# path: /lite/models/image_classification/android
# - title: iOS
# path: /lite/models/image_classification/ios
# - title: Object detection
# section:
# - title: Overview
# path: /lite/models/object_detection/overview
# - title: Speech recognition
# section:
# - title: Overview
# path: /lite/models/speech_recognition/overview
# - title: Pose estimation
# section:
# - title: Overview
# path: /lite/models/pose_estimation/overview
# - title: Segmentation
# section:
# - title: Overview
# path: /lite/models/segmentation/overview
# - title: Smart reply
# section:
# - title: Overview
# path: /lite/models/smart_reply/overview
- name: Models
contents:
- title: Overview
path: /lite/models/
- title: Hosted models
path: /lite/models/hosted
- title: Image classification
section:
- title: Overview
path: /lite/models/image_classification/overview
- title: Android
path: /lite/models/image_classification/android
- title: iOS
path: /lite/models/image_classification/ios
- title: Object detection
section:
- title: Overview
path: /lite/models/object_detection/overview
- title: Pose estimation
section:
- title: Overview
path: /lite/models/pose_estimation/overview
- title: Segmentation
section:
- title: Overview
path: /lite/models/segmentation/overview
- title: Smart reply
section:
- title: Overview
path: /lite/models/smart_reply/overview
- name: API
skip_translation: true


@@ -1,63 +1,27 @@
# Hosted models
# AutoML mobile image classification models (Float Models)
The following is an incomplete list of pre-trained models optimized to work with
TensorFlow Lite.
Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance^
------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ---------------------:
MnasNet_0.50_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_0.5_224_09_07_2018.tgz) | 8.5 Mb | 68.03% | 87.79% | 37 ms
MnasNet_0.75_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_0.75_224_09_07_2018.tgz) | 12 Mb | 71.72% | 90.17% | 61 ms
MnasNet_1.0_96| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_96_09_07_2018.tgz) | 17 Mb | 62.33% | 83.98% | 23 ms
MnasNet_1.0_128| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_128_09_07_2018.tgz) | 17 Mb | 67.32% | 87.70% | 34 ms
MnasNet_1.0_160| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_160_09_07_2018.tgz) | 17 Mb | 70.63% | 89.58% | 51 ms
MnasNet_1.0_192| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_192_09_07_2018.tgz) | 17 Mb | 72.56% | 90.76% | 70 ms
MnasNet_1.0_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_224_09_07_2018.tgz) | 17 Mb | 74.08% | 91.75% | 93 ms
MnasNet_1.3_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.3_224_09_07_2018.tgz) | 24 Mb | 75.24% | 92.55% | 152 ms
To get started choosing a model, visit <a href="./">Models</a>.
Note: The best model for a given application depends on your requirements. For
example, some applications might benefit from higher accuracy, while others
require a small model size. You should test your application with a variety of
models to find the optimal balance between size, performance, and accuracy.
^ Performance numbers are generated on Pixel-1 using single thread large BIG core.
## Image classification
For more information about image classification, see
<a href="image_classification/overview.md">Image classification</a>.
## Image classification (Float Models)
### Quantized models
Model Name | Paper_Model_Files^ | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance^^ | Tensorflow Performance
--------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | --------------------: | ---------------------:
DenseNet | [paper](https://arxiv.org/abs/1608.06993), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/densenet_2018_04_27.tgz) | 43.6 Mb | 64.2% | 85.6% | 894 ms | 1262 ms
SqueezeNet | [paper](https://arxiv.org/abs/1602.07360), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz) | 5.0 Mb | 49.0% | 72.9% | 224 ms | 255 ms
NASNet mobile | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz) | 21.4 Mb | 73.9% | 91.5% | 261 ms | 389 ms
NASNet large | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_large_2018_04_27.tgz) | 355.3 Mb | 82.6% | 96.1% | 6697 ms | 7940 ms
ResNet_V2_101 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/resnet_v2_101.tgz) | 178.3 Mb | 76.8% | 93.6% | 1880 ms | 1970 ms
Inception_V3 | [paper](http://arxiv.org/abs/1512.00567), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz) | 95.3 Mb | 77.9% | 93.8% | 1433 ms | 1522 ms
Inception_V4 | [paper](http://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz) | 170.7 Mb | 80.1% | 95.1% | 2986 ms | 3139 ms
Inception_ResNet_V2 | [paper](https://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz) | 121.0 Mb | 77.5% | 94.0% | 2731 ms | 2926 ms
Mobilenet_V1_0.25_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz) | 1.9 Mb | 41.4% | 66.2% | 6.2 ms | 13.0 ms
Mobilenet_V1_0.25_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160.tgz) | 1.9 Mb | 45.4% | 70.2% | 8.6 ms | 19.5 ms
Mobilenet_V1_0.25_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192.tgz) | 1.9 Mb | 47.1% | 72.0% | 12.1 ms | 27.8 ms
Mobilenet_V1_0.25_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224.tgz) | 1.9 Mb | 49.7% | 74.1% | 16.2 ms | 37.3 ms
Mobilenet_V1_0.50_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128.tgz) | 5.3 Mb | 56.2% | 79.3% | 18.1 ms | 29.9 ms
Mobilenet_V1_0.50_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz) | 5.3 Mb | 59.0% | 81.8% | 26.8 ms | 45.9 ms
Mobilenet_V1_0.50_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192.tgz) | 5.3 Mb | 61.7% | 83.5% | 35.6 ms | 65.3 ms
Mobilenet_V1_0.50_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224.tgz) | 5.3 Mb | 63.2% | 84.9% | 47.6 ms | 164.2 ms
Mobilenet_V1_0.75_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128.tgz) | 10.3 Mb | 62.0% | 83.8% | 34.6 ms | 48.7 ms
Mobilenet_V1_0.75_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160.tgz) | 10.3 Mb | 65.2% | 85.9% | 51.3 ms | 75.2 ms
Mobilenet_V1_0.75_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192.tgz) | 10.3 Mb | 67.1% | 87.2% | 71.7 ms | 107.0 ms
Mobilenet_V1_0.75_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224.tgz) | 10.3 Mb | 68.3% | 88.1% | 95.7 ms | 143.4 ms
Mobilenet_V1_1.0_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128.tgz) | 16.9 Mb | 65.2% | 85.7% | 57.4 ms | 76.8 ms
Mobilenet_V1_1.0_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160.tgz) | 16.9 Mb | 68.0% | 87.7% | 86.0 ms | 117.7 ms
Mobilenet_V1_1.0_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192.tgz) | 16.9 Mb | 69.9% | 89.1% | 118.6 ms | 167.3 ms
Mobilenet_V1_1.0_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz) | 16.9 Mb | 71.0% | 89.9% | 160.1 ms | 224.3 ms
Mobilenet_V2_1.0_224 | [paper](https://arxiv.org/pdf/1801.04381.pdf), [tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz) | 14.0 Mb | 71.8% | 90.6% | 117 ms |
<a href="../performance/post_training_quantization.md">Quantized</a> image
classification models offer the smallest model size and fastest performance, at
the expense of accuracy.
^ The model files include both TF Lite FlatBuffer and Tensorflow frozen Graph.
^^ The performance numbers are generated in the benchmark on Pixel-2 using
single thread large core.
^^ Accuracy numbers were computed using the
[TFLite accuracy tool](../tools/accuracy/ilsvrc) .
## Image classification (Quantized Models)
Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance
Model name | Paper and model | Model size | Top-1 accuracy | Top-5 accuracy | TF Lite performance
--------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------:
Mobilenet_V1_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.5% | 64.4% | 3.7 ms
Mobilenet_V1_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 42.8% | 68.1% | 5.5 ms
@@ -81,9 +45,104 @@ Inception_V2_quant | [paper](https://arxiv.org/abs/1512.00567), [tflite
Inception_V3_quant | [paper](https://arxiv.org/abs/1806.08342),[tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/inception_v3_quant.tgz) | 23 Mb | 77.5% | 93.7% | 637 ms
Inception_V4_quant | [paper](https://arxiv.org/abs/1602.07261), [tflite&pb](http://download.tensorflow.org/models/inception_v4_299_quant_20181026.tgz) | 41 Mb | 79.5% | 93.9% | 1250.8 ms
## Other models
Note: The model files include both TF Lite FlatBuffer and Tensorflow frozen
Graph.
Model       | TF Lite FlatBuffer
----------- | :----------------:
Smart Reply | [reference](https://research.googleblog.com/2017/11/on-device-conversational-modeling-with.html), [tflite](https://storage.googleapis.com/download.tensorflow.org/models/smartreply_1.0_2017_11_01.zip)
Note: Performance numbers were benchmarked on Pixel-2 using single thread large
core. Accuracy numbers were computed using the
[TFLite accuracy tool](../tools/accuracy/ilsvrc.md).
### Floating point models
Floating point models offer the best accuracy, at the expense of model size and
performance. <a href="../performance/gpu.md">GPU acceleration</a> requires the
use of floating point models.
Model name | Paper and model | Model size | Top-1 accuracy | Top-5 accuracy | TF Lite performance | Tensorflow performance
--------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------: | ---------------------:
DenseNet | [paper](https://arxiv.org/abs/1608.06993), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/densenet_2018_04_27.tgz) | 43.6 Mb | 64.2% | 85.6% | 894 ms | 1262 ms
SqueezeNet | [paper](https://arxiv.org/abs/1602.07360), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz) | 5.0 Mb | 49.0% | 72.9% | 224 ms | 255 ms
NASNet mobile | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz) | 21.4 Mb | 73.9% | 91.5% | 261 ms | 389 ms
NASNet large | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_large_2018_04_27.tgz) | 355.3 Mb | 82.6% | 96.1% | 6697 ms | 7940 ms
ResNet_V2_101 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/resnet_v2_101.tgz) | 178.3 Mb | 76.8% | 93.6% | 1880 ms | 1970 ms
Inception_V3 | [paper](http://arxiv.org/abs/1512.00567), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz) | 95.3 Mb | 77.9% | 93.8% | 1433 ms | 1522 ms
Inception_V4 | [paper](http://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz) | 170.7 Mb | 80.1% | 95.1% | 2986 ms | 3139 ms
Inception_ResNet_V2 | [paper](https://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz) | 121.0 Mb | 77.5% | 94.0% | 2731 ms | 2926 ms
Mobilenet_V1_0.25_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz) | 1.9 Mb | 41.4% | 66.2% | 6.2 ms | 13.0 ms
Mobilenet_V1_0.25_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160.tgz) | 1.9 Mb | 45.4% | 70.2% | 8.6 ms | 19.5 ms
Mobilenet_V1_0.25_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192.tgz) | 1.9 Mb | 47.1% | 72.0% | 12.1 ms | 27.8 ms
Mobilenet_V1_0.25_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224.tgz) | 1.9 Mb | 49.7% | 74.1% | 16.2 ms | 37.3 ms
Mobilenet_V1_0.50_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128.tgz) | 5.3 Mb | 56.2% | 79.3% | 18.1 ms | 29.9 ms
Mobilenet_V1_0.50_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz) | 5.3 Mb | 59.0% | 81.8% | 26.8 ms | 45.9 ms
Mobilenet_V1_0.50_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192.tgz) | 5.3 Mb | 61.7% | 83.5% | 35.6 ms | 65.3 ms
Mobilenet_V1_0.50_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224.tgz) | 5.3 Mb | 63.2% | 84.9% | 47.6 ms | 164.2 ms
Mobilenet_V1_0.75_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128.tgz) | 10.3 Mb | 62.0% | 83.8% | 34.6 ms | 48.7 ms
Mobilenet_V1_0.75_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160.tgz) | 10.3 Mb | 65.2% | 85.9% | 51.3 ms | 75.2 ms
Mobilenet_V1_0.75_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192.tgz) | 10.3 Mb | 67.1% | 87.2% | 71.7 ms | 107.0 ms
Mobilenet_V1_0.75_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224.tgz) | 10.3 Mb | 68.3% | 88.1% | 95.7 ms | 143.4 ms
Mobilenet_V1_1.0_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128.tgz) | 16.9 Mb | 65.2% | 85.7% | 57.4 ms | 76.8 ms
Mobilenet_V1_1.0_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160.tgz) | 16.9 Mb | 68.0% | 87.7% | 86.0 ms | 117.7 ms
Mobilenet_V1_1.0_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192.tgz) | 16.9 Mb | 69.9% | 89.1% | 118.6 ms | 167.3 ms
Mobilenet_V1_1.0_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz) | 16.9 Mb | 71.0% | 89.9% | 160.1 ms | 224.3 ms
Mobilenet_V2_1.0_224 | [paper](https://arxiv.org/pdf/1801.04381.pdf), [tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz) | 14.0 Mb | 71.8% | 90.6% | 117 ms |
### AutoML mobile models
The following image classification models were created using
<a href="https://cloud.google.com/automl/">Cloud AutoML</a>.
Model Name | Paper and model | Model size | Top-1 accuracy | Top-5 accuracy | TF Lite performance
---------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------:
MnasNet_0.50_224 | [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_0.5_224_09_07_2018.tgz) | 8.5 Mb | 68.03% | 87.79% | 37 ms
MnasNet_0.75_224 | [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_0.75_224_09_07_2018.tgz) | 12 Mb | 71.72% | 90.17% | 61 ms
MnasNet_1.0_96 | [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_96_09_07_2018.tgz) | 17 Mb | 62.33% | 83.98% | 23 ms
MnasNet_1.0_128 | [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_128_09_07_2018.tgz) | 17 Mb | 67.32% | 87.70% | 34 ms
MnasNet_1.0_160 | [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_160_09_07_2018.tgz) | 17 Mb | 70.63% | 89.58% | 51 ms
MnasNet_1.0_192 | [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_192_09_07_2018.tgz) | 17 Mb | 72.56% | 90.76% | 70 ms
MnasNet_1.0_224 | [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_224_09_07_2018.tgz) | 17 Mb | 74.08% | 91.75% | 93 ms
MnasNet_1.3_224 | [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.3_224_09_07_2018.tgz) | 24 Mb | 75.24% | 92.55% | 152 ms
Note: Performance numbers were benchmarked on Pixel-1 using single thread large
BIG core.
## Object detection
For more information about object detection, see
<a href="object_detection/overview.md">Object detection</a>.
The object detection model we currently host is
**coco_ssd_mobilenet_v1_1.0_quant_2018_06_29**.
<a class="button button-primary" href="http://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip">Download
model and labels</a>
## Pose estimation
For more information about pose estimation, see
<a href="pose_estimation/overview.md">Pose estimation</a>.
The pose estimation model we currently host is
**multi_person_mobilenet_v1_075_float**.
<a class="button button-primary" href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/gpu/multi_person_mobilenet_v1_075_float.tflite">Download
model</a>
## Image segmentation
For more information about image segmentation, see
<a href="segmentation/overview.md">Segmentation</a>.
The image segmentation model we currently host is **deeplabv3_257_mv_gpu**.
<a class="button button-primary" href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/gpu/deeplabv3_257_mv_gpu.tflite">Download
model</a>
## Smart reply
For more information about smart reply, see
<a href="smart_reply/overview.md">Smart reply</a>.
The smart reply model we currently host is **smartreply_1.0_2017_11_01**.
<a class="button button-primary" href="https://storage.googleapis.com/download.tensorflow.org/models/smartreply_1.0_2017_11_01.zip">Download
model</a>

(Binary image file added, not shown; 717 KiB)


@@ -1,34 +1,74 @@
# Image classification
<img src="../images/image.png" class="attempt-right">
Use a pre-trained and optimized model to identify hundreds of classes of objects, including people, activities, animals, plants, and places.
Use a pre-trained and optimized model to identify hundreds of classes of
objects, including people, activities, animals, plants, and places.
## Get started
If you are unfamiliar with the concept of image classification, you should start by reading <a href="#what_is_image_classification">What is image classification?</a>
If you are unfamiliar with the concept of image classification, you should start
by reading <a href="#what_is_image_classification">What is image
classification?</a>
If you understand image classification, you're new to TensorFlow Lite, and you're working with Android or iOS, we recommend following the corresponding tutorial that will walk you through our sample code.
If you understand image classification, you're new to TensorFlow Lite, and
you're working with Android or iOS, we recommend following the corresponding
tutorial that will walk you through our sample code.
<a class="button button-primary" href="android">Android</a>
<a class="button button-primary" href="ios">iOS</a>
If you are using a platform other than Android or iOS, or you are already familiar with the <a href="https://www.tensorflow.org/lite/apis">TensorFlow Lite APIs</a>, you can download our starter image classification model and the accompanying labels.
We also provide <a href="example_applications">example applications</a> you can
use to get started.
Once you have the starter model running on your target device, you can experiment with different models to find the optimal balance between performance, accuracy, and model size. For guidance, see Choose a different model.
If you are using a platform other than Android or iOS, or you are already
familiar with the <a href="../../apis">TensorFlow Lite APIs</a>, you can
download our starter image classification model and the accompanying labels.
<a class="button button-primary" href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_quant_and_labels.zip">Download
starter model and labels</a>
If you are using a platform other than Android or iOS, or you are already familiar with the <a href="../apis">TensorFlow Lite APIs</a>, you can download our starter image classification model and the accompanying labels.
Once you have the starter model running on your target device, you can
experiment with different models to find the optimal balance between
performance, accuracy, and model size. For guidance, see
<a href="#choose_a_different_model">Choose a different model</a>.
<a class="button button-primary" href="">Download starter model and labels</a>
If you are using a platform other than Android or iOS, or you are already
familiar with the <a href="../../apis.md">TensorFlow Lite APIs</a>, you can
download our starter image classification model and the accompanying labels.
<a class="button button-primary" href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_quant_and_labels.zip">Download
starter model and labels</a>
### Example applications
We have example applications for image classification for both Android and iOS.
<a class="button button-primary" href="https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android">Android
example</a>
<a class="button button-primary" href="https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/ios">iOS
example</a>
The following screenshot shows the Android image classification example:
<img src="images/android_banana.png" alt="Screenshot of Android example" width="30%">
## What is image classification?
A common use of machine learning is to identify what an image represents. For example, we might want to know what type of animal appears in the following photograph.
A common use of machine learning is to identify what an image represents. For
example, we might want to know what type of animal appears in the following
photograph.
<img src="images/dog.png" alt="dog" width="50%">
The task of predicting what an image represents is called image classification. An image classification model is trained to recognize various classes of images. For example, a model might be trained to recognize photos representing three different types of animals: rabbits, hamsters, and dogs.
The task of predicting what an image represents is called _image
classification_. An image classification model is trained to recognize various
classes of images. For example, a model might be trained to recognize photos
representing three different types of animals: rabbits, hamsters, and dogs.
When we subsequently provide a new image as input to the model, it will output the probabilities of the image representing each of the types of animal it was trained on. An example output might be as follows:
When we subsequently provide a new image as input to the model, it will output
the probabilities of the image representing each of the types of animal it was
trained on. An example output might be as follows:
<table style="width: 40%;">
<thead>
@@ -53,49 +93,40 @@ When we subsequently provide a new image as input to the model, it will output t
</tbody>
</table>
Based on the output, we can see that the classification model has predicted that the image has a high probability of representing a dog.
Based on the output, we can see that the classification model has predicted that
the image has a high probability of representing a dog.
Note: Image classification can only tell you the probability that an image represents one or more of the classes that the model was trained on. It cannot tell you the position or identity of objects within the image. If you need to identify objects and their positions within images, you should use an <a href="object_detection">object detection</a> model.
Note: Image classification can only tell you the probability that an image
represents one or more of the classes that the model was trained on. It cannot
tell you the position or identity of objects within the image. If you need to
identify objects and their positions within images, you should use an
<a href="../object_detection/overview.md">object detection</a> model.
### Training, labels, and inference
During training, an image classification model is fed images and their associated labels. Each label is the name of a distinct concept, or class, that the model will learn to recognize. Here are some examples of labels and training data for our hypothetical model that classifies animal photos:
During training, an image classification model is fed images and their
associated _labels_. Each label is the name of a distinct concept, or class,
that the model will learn to recognize.
<table>
<thead>
<tr>
<th>Label</th>
<th>Training data</th>
</tr>
</thead>
<tbody>
<tr>
<td>rabbit</td>
<td>[three different images of rabbits]</td>
</tr>
<tr>
<td>hamster</td>
<td>[three different images of hamsters]</td>
</tr>
<tr>
<td>dog</td>
<td>[three different images of dogs]</td>
</tr>
</tbody>
</table>
Given sufficient training data (often hundreds or thousands of images per
label), an image classification model can learn to predict whether new images
belong to any of the classes it has been trained on. This process of prediction
is called _inference_.
Given sufficient training data (often hundreds or thousands of images per label), an image classification model can learn to predict whether new images belong to any of the classes it has been trained on. This process of prediction is called inference.
To perform inference, an image is passed as input to a model. The model will then output an array of probabilities between 0 and 1. With our example model, this process might look like the following:
To perform inference, an image is passed as input to a model. The model will
then output an array of probabilities between 0 and 1. With our example model,
this process might look like the following:
<table style="width: 60%">
<tr style="border-top: 0px;">
<td style="width: 40%"><img src="images/dog.png" alt="dog"></td>
<td style="width: 20%; font-size: 2em; vertical-align: middle;"></td>
<td style="width: 40%; vertical-align: middle;">[0.07, 0.02, 0.91]</td>
<td style="width: 20%; font-size: 2em; vertical-align: middle; text-align: center;"></td>
<td style="width: 40%; vertical-align: middle; text-align: center;">[0.07, 0.02, 0.91]</td>
</table>
Each number in the output corresponds to a label in our training data. Associating our output with the three labels the model was trained on, we can see the model has predicted a high probability that the image represents a dog.
Each number in the output corresponds to a label in our training data.
Associating our output with the three labels the model was trained on, we can
see the model has predicted a high probability that the image represents a dog.
<table style="width: 40%;">
<thead>
@@ -120,11 +151,18 @@ Each number in the output corresponds to a label in our training data. Associati
</tbody>
</table>
You might notice that the sum of all the probabilities (for rabbit, hamster, and dog) is equal to 1. This is a common type of output for models with multiple classes (see <a href="https://developers.google.com/machine-learning/crash-course/multi-class-neural-networks/softmax">Softmax</a> for more information).
You might notice that the sum of all the probabilities (for rabbit, hamster, and
dog) is equal to 1. This is a common type of output for models with multiple
classes (see
<a href="https://developers.google.com/machine-learning/crash-course/multi-class-neural-networks/softmax">Softmax</a>
for more information).
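
For illustration, the inference flow described above can be run on a desktop with the TensorFlow Lite Python interpreter. The sketch below is a minimal example, assuming the quantized MobileNet model and labels file extracted from the starter zip linked above; the exact file names and the interpreter's module path may differ between TensorFlow releases.

```python
import numpy as np
from PIL import Image
import tensorflow as tf  # provides tf.lite.Interpreter in recent releases

# Assumed file names, matching the contents of the starter model zip linked above.
MODEL_PATH = "mobilenet_v1_1.0_224_quant.tflite"
LABELS_PATH = "labels.txt"

interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
interpreter.allocate_tensors()
input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]

# The quantized MobileNet variants expect a 224x224 RGB image, one byte per channel.
image = Image.open("dog.png").convert("RGB").resize((224, 224))
interpreter.set_tensor(input_index, np.expand_dims(np.asarray(image, dtype=np.uint8), 0))
interpreter.invoke()

# One score per label; for the quantized model the raw scores are uint8 (0-255).
scores = interpreter.get_tensor(output_index)[0]
with open(LABELS_PATH) as f:
    labels = [line.strip() for line in f]
for i in scores.argsort()[-3:][::-1]:
    print(labels[i], scores[i] / 255.0)
```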
### Ambiguous results
Since the probabilities will always sum to 1, if the image is not confidently recognized as belonging to any of the classes the model was trained on you may see the probability distributed throughout the labels without any one value being significantly larger.
Since the probabilities will always sum to 1, if the image is not confidently
recognized as belonging to any of the classes the model was trained on you may
see the probability distributed throughout the labels without any one value
being significantly larger.
For example, the following might indicate an ambiguous result:
@@ -153,9 +191,15 @@ For example, the following might indicate an ambiguous result:
### Uses and limitations
The image classification models that we provide are useful for single-label classification, which means predicting which single label the image is most likely to represent. They are trained to recognize 1000 classes of image. For a full list of classes, see the labels file.
The image classification models that we provide are useful for single-label
classification, which means predicting which single label the image is most
likely to represent. They are trained to recognize 1000 classes of image. For a
full list of classes, see the labels file in the
<a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_quant_and_labels.zip">model
zip</a>.
If you want to train a model to recognize new classes, see <a href="#customize_model">Customize model</a>.
If you want to train a model to recognize new classes, see
<a href="#customize_model">Customize model</a>.
For the following use cases, you should use a different type of model:
@@ -164,48 +208,78 @@ For the following use cases, you should use a different type of model:
<li>Predicting the composition of an image, for example subject versus background (see <a href="segmentation">segmentation</a>)</li>
</ul>
Once you have the starter model running on your target device, you can experiment with different models to find the optimal balance between performance, accuracy, and model size. For guidance, see <a href="#choose_a_different_model">Choose a different model</a>.
Once you have the starter model running on your target device, you can
experiment with different models to find the optimal balance between
performance, accuracy, and model size. For guidance, see
<a href="#choose_a_different_model">Choose a different model</a>.
## Choose a different model
There are a large number of image classification models available on our List of hosted models. You should aim to choose the optimal model for your application based on performance, accuracy and model size. There are trade-offs between each of them.
There are a large number of image classification models available on our
<a href="../hosted.md">List of hosted models</a>. You should aim to choose the
optimal model for your application based on performance, accuracy and model
size. There are trade-offs between each of them.
### Performance
We measure performance in terms of the amount of time it takes for a model to run inference on a given piece of hardware. The less time, the faster the model.
We measure performance in terms of the amount of time it takes for a model to
run inference on a given piece of hardware. The less time, the faster the model.
The performance you require depends on your application. Performance can be important for applications like real-time video, where it may be important to analyze each frame in the time before the next frame is drawn (e.g. inference must be faster than 33ms to perform real-time inference on a 30fps video stream).
The performance you require depends on your application. Performance can be
important for applications like real-time video, where it may be important to
analyze each frame in the time before the next frame is drawn (e.g. inference
must be faster than 33ms to perform real-time inference on a 30fps video
stream).
Our quantized Mobilenet models' performance ranges from 3.7 ms to 80.3 ms.
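
One rough way to reproduce this kind of latency number on a desktop is to time `interpreter.invoke()` directly. This is only a sketch, with an assumed model file name, and it skips the warm-up runs and threading controls a real benchmark tool would use.

```python
import time
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="mobilenet_v1_1.0_224_quant.tflite")  # assumed name
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]

# Random data is sufficient for a rough latency measurement.
dummy = np.random.randint(0, 256, size=tuple(inp["shape"]), dtype=np.uint8)

runs = 50
start = time.perf_counter()
for _ in range(runs):
    interpreter.set_tensor(inp["index"], dummy)
    interpreter.invoke()
print("average inference time: %.1f ms" % ((time.perf_counter() - start) * 1000 / runs))
```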
### Accuracy
We measure accuracy in terms of how often the model correctly classifies an image. For example, a model with a stated accuracy of 60% can be expected to classify an image correctly an average of 60% of the time.
Our List of hosted models provides Top-1 and Top-5 accuracy statistics. Top-1 refers to how often the correct label appears as the label with the highest probability in the model's output. Top-5 refers to how often the correct label appears in the top 5 highest probabilities in the model's output.
We measure accuracy in terms of how often the model correctly classifies an
image. For example, a model with a stated accuracy of 60% can be expected to
classify an image correctly an average of 60% of the time.
Our <a href="../hosted.md">List of hosted models</a> provides Top-1 and Top-5
accuracy statistics. Top-1 refers to how often the correct label appears as the
label with the highest probability in the model's output. Top-5 refers to how
often the correct label appears in the top 5 highest probabilities in the
model's output.
Our quantized Mobilenet models' Top-5 accuracy ranges from 64.4 to 89.9%.
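
For reference, Top-1 and Top-5 can be computed from a model's output array with a few lines of NumPy. The sketch below assumes `scores` is the score array from a single inference and `true_label` is the known class index.

```python
import numpy as np

def in_top_k(scores, true_label, k):
    """True if the correct label is among the k highest-scoring classes."""
    return true_label in np.argsort(scores)[-k:]

# Example with the three-class rabbit/hamster/dog model described earlier.
scores = np.array([0.07, 0.02, 0.91])
print(in_top_k(scores, true_label=2, k=1))  # Top-1 hit: True
```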
### Size
The size of a model on-disk varies with its performance and accuracy. Size may be important for mobile development (where it might impact app download sizes) or when working with hardware (where available storage might be limited).
The size of a model on-disk varies with its performance and accuracy. Size may
be important for mobile development (where it might impact app download sizes)
or when working with hardware (where available storage might be limited).
Our quantized Mobilenet models' size ranges from 0.5 to 3.4 Mb.
### Architecture
There are several different architectures of models available on List of hosted models, indicated by the model's name. For example, you can choose between Mobilenet, Inception, and others.
The architecture of a model impacts its performance, accuracy, and size. All of our hosted models are trained on the same data, meaning you can use the provided statistics to compare them and choose which is optimal for your application.
There are several different architectures of models available on
<a href="../hosted.md">List of hosted models</a>, indicated by the models name.
For example, you can choose between Mobilenet, Inception, and others.
Note: The image classification models we provide accept varying sizes of input. For some models, this is indicated in the filename. For example, the Mobilenet_V1_1.0_224 model accepts an input of 224x224 pixels. <br /><br />All of the models require three color channels per pixel (red, green, and blue). Quantized models require 1 byte per channel, and float models require 4 bytes per channel.<br /><br />Our Android and iOS code samples demonstrate how to process full-sized camera images into the required format for each model.
The architecture of a model impacts its performance, accuracy, and size. All of
our hosted models are trained on the same data, meaning you can use the provided
statistics to compare them and choose which is optimal for your application.
Note: The image classification models we provide accept varying sizes of input. For some models, this is indicated in the filename. For example, the Mobilenet_V1_1.0_224 model accepts an input of 224x224 pixels. <br /><br />All of the models require three color channels per pixel (red, green, and blue). Quantized models require 1 byte per channel, and float models require 4 bytes per channel.<br /><br />Our <a href="android.md">Android</a> and <a href="ios">iOS</a> code samples demonstrate how to process full-sized camera images into the required format for each model.
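
As a concrete illustration of the note above, the following sketch resizes an RGB image to a model's expected input size and packs it as either uint8 (quantized models) or float32 (float models). The [-1, 1] normalization shown for float input is an assumption that matches common MobileNet-style models; check the documentation for the specific model you use.

```python
import numpy as np
from PIL import Image

def preprocess(image_path, width, height, quantized):
    """Resize an RGB image and lay it out the way the classification models expect."""
    image = Image.open(image_path).convert("RGB").resize((width, height))
    pixels = np.asarray(image)                              # (height, width, 3), uint8
    if quantized:
        data = pixels.astype(np.uint8)                      # 1 byte per channel, 0-255
    else:
        data = (pixels.astype(np.float32) - 127.5) / 127.5  # 4 bytes per channel, -1..1
    return np.expand_dims(data, axis=0)                     # add the batch dimension

batch = preprocess("dog.png", 224, 224, quantized=True)     # e.g. Mobilenet_V1_1.0_224
```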
## Customize model
The pre-trained models we provide are trained to recognize 1000 classes of image. For a full list of classes, see the labels file.
You can use a technique known as transfer learning to re-train a model to recognize classes not in the original set. For example, you could re-train the model to distinguish between different species of tree, despite there being no trees in the original training data. To do this, you will need a set of training images for each of the new labels you wish to train.
The pre-trained models we provide are trained to recognize 1000 classes of
image. For a full list of classes, see the labels file in the
<a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_quant_and_labels.zip">model
zip</a>.
Learn how to perform transfer learning in the TensorFlow for Poets codelab.
You can use a technique known as _transfer learning_ to re-train a model to
recognize classes not in the original set. For example, you could re-train the
model to distinguish between different species of tree, despite there being no
trees in the original training data. To do this, you will need a set of training
images for each of the new labels you wish to train.
## Read more about this
<ul>
<li>Blog post:</li>
<li>Image classification GitHub:</li>
</ul>
Learn how to perform transfer learning in the
<a href="https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/">TensorFlow
for Poets</a> codelab.

(Binary image file added, not shown; 724 KiB)

(Binary image file added, not shown; 675 KiB)


@@ -1,30 +1,59 @@
# Object detection
<img src="../images/detection.png" class="attempt-right">
Detect multiple objects with bounding boxes. Yes, dogs and cats too.
Detect multiple objects within an image, with bounding boxes. Recognize 80
different classes of objects.
<a class="button button-primary" href="http://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip">Download starter model and labels</a>
## Get started
## Tutorials (coming soon)
<a class="button button-primary" href="">iOS</a>
<a class="button button-primary" href="">Android</a>
If you are new to TensorFlow Lite and are working with Android or iOS, we
recommend exploring the following example applications that can help you get
started.
<a class="button button-primary" href="https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android">Android
example</a>
<a class="button button-primary" href="https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/ios">iOS
example</a>
If you are using a platform other than Android or iOS, or you are already
familiar with the <a href="../apis.md">TensorFlow Lite APIs</a>, you can
download our starter object detection model and the accompanying labels.
<a class="button button-primary" href="http://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip">Download
starter model and labels</a>
For more information about the starter model, see
<a href="#starter_model">Starter model</a>.
## What is object detection?
Given an image or a video stream, an object detection model can identify which of a known set of objects might be present and provide information about their positions within the image.
<!-- TODO -->
For example, this screenshot of our <a href="">object detection sample app</a> shows how several objects have been recognized and their positions annotated:
Given an image or a video stream, an object detection model can identify which
of a known set of objects might be present and provide information about their
positions within the image.
For example, this screenshot of our <a href="#get_started">example
application</a> shows how two objects have been recognized and their positions
annotated:
<!-- TODO -->
TODO: Insert image
<img src="images/android_apple_banana.png" alt="Screenshot of Android example" width="30%">
An object detection model is trained to detect the presence and location of multiple classes of objects. For example, a model might be trained with images that contain various pieces of computer hardware, along with a label that specifies the class of hardware they represent (e.g. a laptop, a keyboard, or a monitor), and data specifying where each object appears in the image.
An object detection model is trained to detect the presence and location of
multiple classes of objects. For example, a model might be trained with images
that contain various pieces of fruit, along with a _label_ that specifies the
class of fruit they represent (e.g. an apple, a banana, or a strawberry), and
data specifying where each object appears in the image.
When we subsequently provide an image to the model, it will output a list of the objects it detects, the location of a bounding box that contains each object, and a score that indicates the confidence that detection was correct.
When we subsequently provide an image to the model, it will output a list of the
objects it detects, the location of a bounding box that contains each object,
and a score that indicates the confidence that detection was correct.
### Model output
Imagine a model has been trained to detect apples, bananas, and strawberries.
When we pass it an image, it will output a set number of detection results - in
this example, 5.
<table style="width: 60%;">
<thead>
<tr>
@@ -35,27 +64,27 @@ When we subsequently provide an image to the model, it will output a list of the
</thead>
<tbody>
<tr>
<td>Laptop</td>
<td>Apple</td>
<td>0.92</td>
<td>[18, 21, 57, 63]</td>
</tr>
<tr>
<td>Keyboard</td>
<td>Banana</td>
<td>0.88</td>
<td>[100, 30, 180, 150]</td>
</tr>
<tr>
<td>Monitor</td>
<td>Strawberry</td>
<td>0.87</td>
<td>[7, 82, 89, 163] </td>
</tr>
<tr>
<td>Keyboard</td>
<td>Banana</td>
<td>0.23</td>
<td>[42, 66, 57, 83]</td>
</tr>
<tr>
<td>Monitor</td>
<td>Apple</td>
<td>0.11</td>
<td>[6, 42, 31, 58]</td>
</tr>
@@ -64,9 +93,16 @@ When we subsequently provide an image to the model, it will output a list of the
### Confidence score
To interpret these results, we can look at the score and the location for each detected object. The score is a number between 0 and 1 that indicates confidence that the object was genuinely detected. The closer the number is to 1, the more confident the model is.
To interpret these results, we can look at the score and the location for each
detected object. The score is a number between 0 and 1 that indicates confidence
that the object was genuinely detected. The closer the number is to 1, the more
confident the model is.
Depending on your application, you can decide a cut-off threshold below which you will discard detection results. For our example, we might decide a sensible cut-off is a score of 0.5 (meaning a 50% probability that the detection is valid). In that case, we would ignore the last two objects in the array, because those confidence scores are below 0.5:
Depending on your application, you can decide a cut-off threshold below which
you will discard detection results. For our example, we might decide a sensible
cut-off is a score of 0.5 (meaning a 50% probability that the detection is
valid). In that case, we would ignore the last two objects in the array, because
those confidence scores are below 0.5:
<table style="width: 60%;">
<thead>
@@ -78,41 +114,51 @@ Depending on your application, you can decide a cut-off threshold below which yo
</thead>
<tbody>
<tr>
<td>Laptop</td>
<td>Apple</td>
<td>0.92</td>
<td>[18, 21, 57, 63]</td>
</tr>
<tr>
<td>Keyboard</td>
<td>Banana</td>
<td>0.88</td>
<td>[100, 30, 180, 150]</td>
</tr>
<tr>
<td>Monitor</td>
<td>Strawberry</td>
<td>0.87</td>
<td>[7, 82, 89, 163] </td>
</tr>
<tr>
<td style="background-color: #e9cecc; text-decoration-line: line-through;">Keyboard</td>
<td style="background-color: #e9cecc; text-decoration-line: line-through;">Banana</td>
<td style="background-color: #e9cecc; text-decoration-line: line-through;">0.23</td>
<td style="background-color: #e9cecc; text-decoration-line: line-through;">[42, 66, 57, 83]</td>
</tr>
<tr>
<td style="background-color: #e9cecc; text-decoration-line: line-through;">Monitor</td>
<td style="background-color: #e9cecc; text-decoration-line: line-through;">Apple</td>
<td style="background-color: #e9cecc; text-decoration-line: line-through;">0.11</td>
<td style="background-color: #e9cecc; text-decoration-line: line-through;">[6, 42, 31, 58]</td>
</tr>
</tbody>
</table>
The cut-off you use should be based on whether you are more comfortable with false positives (objects that are wrongly identified, or areas of the image that are erroneously identified as objects when they are not), or false negatives (genuine objects that are missed because their confidence was low).
The cut-off you use should be based on whether you are more comfortable with
false positives (objects that are wrongly identified, or areas of the image that
are erroneously identified as objects when they are not), or false negatives
(genuine objects that are missed because their confidence was low).
<!-- TODO -->
TODO: Insert screenshot showing both
For example, in the following image, a pear (which is not an object that the
model was trained to detect) was misidentified as a "person". This is an example
of a false positive that could be ignored by selecting an appropriate cut-off.
In this case, a cut-off of 0.6 (or 60%) would comfortably exclude the false
positive.
<img src="images/false_positive.png" alt="Screenshot of Android example showing a false positive" width="30%">
### Location
For each detected object, the model will return an array of four numbers representing a bounding rectangle that surrounds its position. The numbers are ordered as follows:
For each detected object, the model will return an array of four numbers
representing a bounding rectangle that surrounds its position. For the starter
model we provide, the numbers are ordered as follows:
<table style="width: 50%; margin: 0 auto;">
<tbody>
@@ -127,49 +173,52 @@ For each detected object, the model will return an array of four numbers represe
</tbody>
</table>
The top value represents the distance of the rectangle's top edge from the top of the image, in pixels. The left value represents the left edge's distance from the left of the input image. The other values represent the bottom and right edges in a similar manner.
The top value represents the distance of the rectangle's top edge from the top
of the image, in pixels. The left value represents the left edge's distance from
the left of the input image. The other values represent the bottom and right
edges in a similar manner.
<!-- TODO -->
Note: Object detection models accept input images of a specific size. This is likely to be different from the size of the raw image captured by your device's camera, and you will have to write code to crop and scale your raw image to fit the model's input size (there are examples of this in our <a href="">sample code</a>).<br /><br />The pixel values output by the model refer to the position in the cropped and scaled image, so you must scale them to fit the raw image in order to interpret them correctly.
Note: Object detection models accept input images of a specific size. This is likely to be different from the size of the raw image captured by your device's camera, and you will have to write code to crop and scale your raw image to fit the model's input size (there are examples of this in our <a href="#get_started">example applications</a>).<br /><br />The pixel values output by the model refer to the position in the cropped and scaled image, so you must scale them to fit the raw image in order to interpret them correctly.
## Starter model
We recommend starting with this pre-trained quantized COCO SSD MobileNet v1
model.
<a class="button button-primary" href="http://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip">Download
starter model and labels</a>
### Uses and limitations
<!-- TODO -->
The object detection model we provide can identify and locate up to 10 objects in an image. It is trained to recognize 80 classes of object. For a full list of classes, see the labels file in the <a href="">model zip</a>.
The object detection model we provide can identify and locate up to 10 objects
in an image. It is trained to recognize 80 classes of object. For a full list of
classes, see the labels file in the
<a href="http://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip">model
zip</a>.
If you want to train a model to recognize new classes, see <a href="#customize_model">Customize model</a>.
If you want to train a model to recognize new classes, see
<a href="#customize_model">Customize model</a>.
For the following use cases, you should use a different type of model:
<ul>
<li>Predicting which single label the image most likely represents (see <a href="image_classification">image classification</a>)</li>
<li>Predicting the composition of an image, for example subject versus background (see <a href="segmentation">segmentation</a>)</li>
<li>Predicting which single label the image most likely represents (see <a href="../image_classification/overview.md">image classification</a>)</li>
<li>Predicting the composition of an image, for example subject versus background (see <a href="../segmentation/overview.md">segmentation</a>)</li>
</ul>
Get started
If you are new to TensorFlow Lite and are working with Android or iOS, we recommend following the corresponding tutorial that will walk you through our sample code.
<!-- TODO -->
<a class="button button-primary" href="">iOS</a>
<a class="button button-primary" href="">Android</a>
If you are using a platform other than Android or iOS, or you are already familiar with the <a href="../apis">TensorFlow Lite APIs</a>, you can download our starter object detection model and the accompanying labels.
<a href="http://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip">Download starter model and labels</a>
The model will return 10 detection results...
## Starter model
We recommend starting to implement object detection using the quantized COCO SSD MobileNet v1 model, available with labels from this download link:
<a href="http://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip">Download starter model and labels</a>
### Input
The model takes an image as input. The expected image is 300x300 pixels, with three channels (red, blue, and green) per pixel. This should be fed to the model as a flattened buffer of 270,000 byte values (300x300x3). Since the model is <a href="">quantized</a>, each value should be a single byte representing a value between 0 and 255.
The model takes an image as input. The expected image is 300x300 pixels, with
three channels (red, blue, and green) per pixel. This should be fed to the model
as a flattened buffer of 270,000 byte values (300x300x3). Since the model is
<a href="../../performance/post_training_quantization.md">quantized</a>, each
value should be a single byte representing a value between 0 and 255.
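As a rough sketch (not an excerpt from the sample apps), feeding an image to the quantized starter model with the TensorFlow Lite Python interpreter could look like this. The file name `detect.tflite` is a placeholder; use whatever the downloaded zip contains.

```python
# A minimal sketch of feeding an image to the quantized starter model with the
# TensorFlow Lite Python interpreter. "detect.tflite" is a placeholder file name.
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="detect.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()

# The quantized model expects a [1, 300, 300, 3] uint8 tensor (values 0-255).
image = np.zeros((1, 300, 300, 3), dtype=np.uint8)  # replace with your resized image
interpreter.set_tensor(input_details[0]["index"], image)
interpreter.invoke()
```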
### Output
The model outputs four arrays, mapped to the indices 0-3. Arrays 0, 1, and 2 describe 10 detected objects, with one element in each array corresponding to each object. There will always be 10 objects detected.
The model outputs four arrays, mapped to the indices 0-3. Arrays 0, 1, and 2
describe 10 detected objects, with one element in each array corresponding to
each object. There will always be 10 objects detected.
<table>
<thead>
@ -205,16 +254,17 @@ The model outputs four arrays, mapped to the indices 0-4. Arrays 0, 1, and 2 des
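Continuing the interpreter sketch from the Input section above, reading the four output arrays could look like the following. The output tensor ordering is assumed to match the description in this section; verify it with `get_output_details()` on your model.

```python
# Continuing the sketch above (assumes `interpreter` has already been created,
# fed an image, and invoked as shown in the Input section).
output_details = interpreter.get_output_details()
boxes = interpreter.get_tensor(output_details[0]["index"])[0]    # [10, 4] box coordinates
classes = interpreter.get_tensor(output_details[1]["index"])[0]  # [10] class indices
scores = interpreter.get_tensor(output_details[2]["index"])[0]   # [10] confidence scores
num_detections = interpreter.get_tensor(output_details[3]["index"])[0]

for i in range(len(scores)):
    if scores[i] > 0.5:  # keep only reasonably confident detections
        print("class", int(classes[i]), "score", float(scores[i]), "box", boxes[i])
```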
## Customize model
<!-- TODO -->
The pre-trained models we provide are trained to detect 80 classes of object. For a full list of classes, see the labels file in the <a href="">model zip</a>.
The pre-trained models we provide are trained to detect 80 classes of object.
For a full list of classes, see the labels file in the
<a href="http://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip">model
zip</a>.
You can use a technique known as transfer learning to re-train a model to recognize classes not in the original set. For example, you could re-train the model to detect multiple types of vegetable, despite there only being one vegetable in the original training data. To do this, you will need a set of training images for each of the new labels you wish to train.
You can use a technique known as transfer learning to re-train a model to
recognize classes not in the original set. For example, you could re-train the
model to detect multiple types of vegetable, despite there only being one
vegetable in the original training data. To do this, you will need a set of
training images for each of the new labels you wish to train.
Learn how to perform transfer learning in the <a href="https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193">Training and serving a real-time mobile object detector in 30 minutes</a> blog post.
<!-- TODO -->
Read more about this
<ul>
<li>Blog post:</li>
<li>Object detection GitHub:</li>
</ul>
Learn how to perform transfer learning in
<a href="https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193">Training
and serving a real-time mobile object detector in 30 minutes</a>.


@ -1,20 +1,31 @@
# Pose estimation
<img src="../images/pose.png" class="attempt-right" />
<i>PoseNet</i> is a vision model that can be used to estimate the pose of a person in an image/video by estimating where key body joints are.
## Get started
<a class="button button-primary" href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/gpu/multi_person_mobilenet_v1_075_float.tflite">Download starter model</a>
_PoseNet_ is a vision model that can be used to estimate the pose of a person in
an image or video by estimating where key body joints are.
## Tutorials (coming soon)
<a class="button button-primary" href="">iOS</a>
<a class="button button-primary" href="">Android</a>
<a class="button button-primary" href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/gpu/multi_person_mobilenet_v1_075_float.tflite">Download
starter model</a>
Android and iOS end-to-end tutorials are coming soon. In the meantime, if you
want to experiment with this in a web browser, check out the
<a href="https://github.com/tensorflow/tfjs-models/tree/master/posenet">TensorFlow.js
GitHub repository</a>.
## How it works
Pose estimation refers to computer vision techniques that detect human figures in images and videos, so that one could determine, for example, where someone's elbow shows up in an image.
To be clear, this technology is not recognizing who is in an image; there is no personally identifiable information associated with pose detection. The algorithm is simply estimating where key body joints are.
Pose estimation refers to computer vision techniques that detect human figures
in images and videos, so that one could determine, for example, where someone's
elbow shows up in an image.
The key points detected are indexed by part id with a confidence score between 0.0 and 1.0; 1.0 being the highest.
To be clear, this technology is not recognizing who is in an image. The
algorithm is simply estimating where key body joints are.
The key points detected are indexed by "Part ID", with a confidence score
between 0.0 and 1.0, 1.0 being the highest.
<table style="width: 30%;">
<thead>
@ -96,33 +107,47 @@ The key points detected are indexed by part id with a confidence score between 0
</table>
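As an illustration of how these keypoints are typically recovered, the sketch below decodes a single pose from PoseNet-style heatmaps and offsets. The tensor layout ([height, width, parts] heatmaps with y-offsets in the first `num_parts` offset channels) is an assumption and varies between releases, so check your model's output details.

```python
# A rough sketch of single-pose decoding from PoseNet-style outputs. Assumes
# `heatmaps` has shape [height, width, num_parts] and `offsets` has shape
# [height, width, 2 * num_parts], with y-offsets in the first num_parts channels.
import numpy as np

def decode_single_pose(heatmaps, offsets, output_stride=16):
    num_parts = heatmaps.shape[-1]
    keypoints = []
    for part in range(num_parts):
        heatmap = heatmaps[:, :, part]
        y, x = np.unravel_index(np.argmax(heatmap), heatmap.shape)
        score = float(heatmap[y, x])
        # Refine the coarse grid position using the offset vectors.
        point_y = y * output_stride + offsets[y, x, part]
        point_x = x * output_stride + offsets[y, x, part + num_parts]
        keypoints.append({"part_id": part, "y": point_y, "x": point_x, "score": score})
    return keypoints
```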
## Example output
<img src="https://www.tensorflow.org/images/models/pose_estimation.gif" />
## Get started
Android and iOS end-to-end tutorials are coming soon. In the meantime, if you want to experiment with this in a web browser, check out the TensorFlow.js <a href="https://github.com/tensorflow/tfjs-models/tree/master/posenet">GitHub repository</a>.
<img alt="Animation showing pose estimation" src="https://www.tensorflow.org/images/models/pose_estimation.gif" />
## How it performs
Performance varies based on your device and output stride (heatmaps and offset vectors). The PoseNet model is image size invariant, which means it can predict pose positions in the same scale as the original image regardless of whether the image is downscaled. This means PoseNet can be configured to have a higher accuracy at the expense of performance.
The output stride determines how much we're scaling down the output relative to the input image size. It affects the size of the layers and the model outputs. The higher the output stride, the smaller the resolution of layers in the network and the outputs, and correspondingly their accuracy. In this implementation, the output stride can have values of 8, 16, or 32. In other words, an output stride of 32 will result in the fastest performance but lowest accuracy, while 8 will result in the highest accuracy but slowest performance. We recommend starting with 16.
Performance varies based on your device and output stride (heatmaps and offset
vectors). The PoseNet model is image size invariant, which means it can predict
pose positions in the same scale as the original image regardless of whether the
image is downscaled. This means PoseNet can be configured to have a higher
accuracy at the expense of performance.
<img src="../images/models/output_stride.png" >
<span style="font-size: 0.8em">The output stride determines how much we're scaling down the output relative to the input image size. A higher output stride is faster but results in lower accuracy.</span>
The output stride determines how much we're scaling down the output relative to
the input image size. It affects the size of the layers and the model outputs.
The higher the output stride, the smaller the resolution of layers in the
network and the outputs, and correspondingly their accuracy. In this
implementation, the output stride can have values of 8, 16, or 32. In other
words, an output stride of 32 will result in the fastest performance but lowest
accuracy, while 8 will result in the highest accuracy but slowest performance.
We recommend starting with 16.
The following image shows how the output stride determines how much we're
scaling down the output relative to the input image size. A higher output stride
is faster but results in lower accuracy.
<img alt="Output stride and heatmap resolution" src="../images/output_stride.png" >
## Read more about pose estimation
## Read more about this
<ul>
<li><a href="">Blog post: Real-time Human Pose Estimation in the Browser with TensorFlow.js</a></li>
<li><a href="">TF.js GitHub: Pose Detection in the Browser: PoseNet Model</a></li>
<li><a href="https://medium.com/tensorflow/real-time-human-pose-estimation-in-the-browser-with-tensorflow-js-7dd0bc881cd5">Blog post: Real-time Human Pose Estimation in the Browser with TensorFlow.js</a></li>
<li><a href="https://github.com/tensorflow/tfjs-models/tree/master/posenet">TF.js GitHub: Pose Detection in the Browser: PoseNet Model</a></li>
</ul>
## Users
### Use cases
<ul>
<li><a href="">PomPom Mirror</a></li>
<li><a href="">Amazing Art Installation Turns You Into A Bird | Chris Milk "The Treachery of Sanctuary"</a></li>
<li><a href="">Puppet Parade - Interactive Kinect Puppets</a></li>
<li><a href="">Messa di Voce (Performance), Excerpts</a></li>
<li><a href="">Augmented reality</a></li>
<li><a href="">Interactive animation</a></li>
<li><a href="">Gait analysis</a></li>
<li><a href="https://vimeo.com/128375543">PomPom Mirror</a></li>
<li><a href="https://youtu.be/I5__9hq-yas">Amazing Art Installation Turns You Into A Bird | Chris Milk "The Treachery of Sanctuary"</a></li>
<li><a href="https://vimeo.com/34824490">Puppet Parade - Interactive Kinect Puppets</a></li>
<li><a href="https://vimeo.com/2892576">Messa di Voce (Performance), Excerpts</a></li>
<li><a href="https://www.instagram.com/p/BbkKLiegrTR/">Augmented reality</a></li>
<li><a href="https://www.instagram.com/p/Bg1EgOihgyh/">Interactive animation</a></li>
<li><a href="https://www.runnersneed.com/expert-advice/gear-guides/gait-analysis.html">Gait analysis</a></li>
</ul>


@ -1,18 +1,26 @@
# Segmentation (GPU)
# Segmentation
<img src="../images/segmentation.png" class="attempt-right" />
<i>DeepLab</i> is a state-of-the-art deep learning model for semantic image segmentation, where the goal is to assign semantic labels (e.g., person, dog, cat and so on) to every pixel in the input image.
## Get started
<a class="button button-primary" href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/gpu/deeplabv3_257_mv_gpu.tflite">Download starter model</a>
_DeepLab_ is a state-of-the-art deep learning model for semantic image segmentation,
where the goal is to assign semantic labels (e.g. person, dog, cat) to every
pixel in the input image.
## Tutorials (coming soon)
<a class="button button-primary" href="">iOS</a>
<a class="button button-primary" href="">Android</a>
<a class="button button-primary" href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/gpu/deeplabv3_257_mv_gpu.tflite">Download
starter model</a>
## How it works
It all started with classification where the model predicts an entire input. With advances in data, hardware, and software, object detection can infer objects with spatial location. Semantic segmentation offers the highest level of granularity with labels at a pixel level.
Current implementation includes the following features:
Semantic image segmentation predicts whether each pixel of an image is
associated with a certain class. This is in contrast to
<a href="../object_detection/overview.md">object detection</a>, which detects
objects in rectangular regions, and
<a href="../image_classification/overview.md">image classification</a>, which
classifies the overall image.
The current implementation includes the following features:
<ol>
<li>DeepLabv1: We use atrous convolution to explicitly control the resolution at which feature responses are computed within Deep Convolutional Neural Networks.</li>
<li>DeepLabv2: We use atrous spatial pyramid pooling (ASPP) to robustly segment objects at multiple scales with filters at multiple sampling rates and effective fields-of-views.</li>
@ -21,12 +29,15 @@ Current implementation includes the following features:
</ol>
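As a minimal sketch of using the starter model, the following runs the interpreter and takes a per-pixel argmax over the class scores. The 257x257 input size and 21-class output are assumptions about the starter model; confirm them with `get_input_details()` and `get_output_details()`.

```python
# A minimal sketch of running the starter segmentation model and reducing its
# per-pixel class scores to a label map. Input/output shapes are assumptions.
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="deeplabv3_257_mv_gpu.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

image = np.zeros((1, 257, 257, 3), dtype=np.float32)  # replace with a normalized image
interpreter.set_tensor(input_details[0]["index"], image)
interpreter.invoke()

scores = interpreter.get_tensor(output_details[0]["index"])[0]  # e.g. [257, 257, 21]
label_map = np.argmax(scores, axis=-1)  # predicted class index for each pixel
print(label_map.shape)
```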
## Example output
The model will create a mask over the target objects with high accuracy.
<img src="images/segmentation.gif" />
## Read more about this
The model will create a mask over the target objects with high accuracy.
<img alt="Animation showing image segmentation" src="images/segmentation.gif" />
## Read more about segmentation
<ul>
<li>Blog post: <a href="https://ai.googleblog.com/2018/03/semantic-image-segmentation-with.html">Semantic Image Segmentation with DeepLab in TensorFlow</a></li>
<li><a href="https://medium.com/tensorflow/tensorflow-lite-now-faster-with-mobile-gpus-developer-preview-e15797e6dee7">Blog post: TensorFlow Lite Now Faster with Mobile GPUs (Developer Preview)</a></li>
<li><a href="https://github.com/tensorflow/models/tree/master/research/deeplab">DeepLab GitHub: DeepLab: Deep Labelling for Semantic Image Segmentation</a></li>
<li><a href="https://ai.googleblog.com/2018/03/semantic-image-segmentation-with.html">Semantic Image Segmentation with DeepLab in TensorFlow</a></li>
<li><a href="https://medium.com/tensorflow/tensorflow-lite-now-faster-with-mobile-gpus-developer-preview-e15797e6dee7">TensorFlow Lite Now Faster with Mobile GPUs (Developer Preview)</a></li>
<li><a href="https://github.com/tensorflow/models/tree/master/research/deeplab">DeepLab: Deep Labelling for Semantic Image Segmentation</a></li>
</ul>


@ -1,37 +1,49 @@
# Smart reply
<img src="../images/smart_reply.png" class="attempt-right" />
Smart replies are contextually relevant, one-touch responses that help the user to reply to an incoming text message (or email) efficiently and effortlessly.
## Get started
<a class="button button-primary" href="http://download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip">Download starter model and labels</a>
Our smart reply model generates reply suggestions based on chat messages. The
suggestions are intended to be contextually relevant, one-touch responses that
help the user to easily reply to an incoming message.
## Tutorials (coming soon)
<a class="button button-primary" href="">iOS</a>
<a class="button button-primary" href="">Android</a>
<a class="button button-primary" href="http://download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip">Download
starter model and labels</a>
### Sample application
We have provided a pre-built APK that demonstrates the smart reply model on
Android.
Go to the
<a href="https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/models/smartreply/g3doc">GitHub
page</a> for instructions and a list of supported ops and functionalities.
## How it works
The model generates reply suggestions to input conversational chat messages with an efficient inference that can easily be plugged into your chat application to power on-device conversational intelligence.
The model generates reply suggestions to conversational chat messages.
The on-device model comes with several benefits. It is:
<ul>
<li>Faster: The model resides on the device and does not require internet connectivity. Thus, the inference is very fast and has an average latency of only a few milliseconds.</li>
<li>Fast: The model resides on the device and does not require internet connectivity. Thus, inference is very fast and has an average latency of only a few milliseconds.</li>
<li>Resource efficient: The model has a small memory footprint on the device.</li>
<li>Privacy-friendly: The user data never leaves the device and this eliminates any privacy restrictions.</li>
<li>Privacy-friendly: User data never leaves the device.</li>
</ul>
## Example output
<img src="images/smart_reply.gif" />
## How to use this model?
We have provided a pre-built demo APK that you can download, install, and test on your phone. Go to the <a href="https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/models/smartreply/g3doc">GitHub page</a> for instructions and a list of supported ops and functionalities.
<img alt="Animation showing smart reply" src="images/smart_reply.gif" />
## Read more about this
<ul>
<li><a href="https://arxiv.org/pdf/1708.00630.pdf">Research paper</a></li>
<li><a href="https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/models/smartreply/">Source code</a></li>
</ul>
## Users
<ul>
<li><a href="https://www.blog.google/products/gmail/save-time-with-smart-reply-in-gmail/">Gmail</a></li>
<li><a href="https://www.blog.google/products/gmail/computer-respond-to-this-email/">Inbox</a></li>


@ -1,14 +0,0 @@
# Speech recognition
<img src="../images/audio.png" class="attempt-right">
Recognize audio keywords!
<a class="button button-primary" href="">Download starter model</a>
## Tutorials (coming soon)
<a class="button button-primary" href="">iOS</a>
<a class="button button-primary" href="">Android</a>
## What is speech recognition?
Coming soon.


@ -20,6 +20,7 @@ limitations under the License.
#include <sys/types.h>
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <memory>
#include <tuple>
@ -1743,6 +1744,221 @@ inline void ShuffledFullyConnected(
gemm_context->workers_pool()->Execute(tasks);
}
inline void MeanImpl(const tflite::MeanParams& op_params,
const RuntimeShape& input_shape, const uint8_t* input_data,
int32 input_zero_point, float input_scale,
const RuntimeShape& output_shape, uint8_t* output_data,
int32 output_zero_point, float output_scale,
int start_depth, int end_depth) {
gemmlowp::ScopedProfilingLabel label("Mean4D/Uint8/MeanImpl");
// The current implementation only supports 4-D tensors and simultaneous
// reduction over width and height.
const int output_batch = output_shape.Dims(0);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
const int input_height = input_shape.Dims(1);
const int input_width = input_shape.Dims(2);
const float num_elements_in_axis = input_width * input_height;
TFLITE_DCHECK_EQ(op_params.axis_count, 2);
TFLITE_DCHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
(op_params.axis[0] == 2 && op_params.axis[1] == 1));
TFLITE_DCHECK_EQ(output_height, 1);
TFLITE_DCHECK_EQ(output_width, 1);
const bool ordinary_mean =
(input_zero_point == output_zero_point && input_scale == output_scale);
float scale, bias;
if (!ordinary_mean) {
scale = input_scale / output_scale;
bias = -input_zero_point * scale + 0.5;
}
#ifdef USE_NEON
const float32x4_t num_elements_dup = vdupq_n_f32(num_elements_in_axis);
// This is only an approximation as NEON does not offer a division instruction.
const float32x4_t num_elements_reverse = vrecpeq_f32(num_elements_dup);
const float32x4_t kRounding = vdupq_n_f32(0.5);
float32x4_t bias_dup;
float32x4_t output_zero_point_dup;
if (!ordinary_mean) {
bias_dup = vdupq_n_f32(bias);
output_zero_point_dup = vdupq_n_f32(output_zero_point);
}
#endif
for (int out_b = 0; out_b < output_batch; ++out_b) {
int out_d = start_depth;
#ifdef USE_NEON
for (; out_d < end_depth - 8; out_d += 8) {
float32x4_t temp_sum_1 = vdupq_n_f32(0);
float32x4_t temp_sum_2 = vdupq_n_f32(0);
for (int in_h = 0; in_h < input_height; ++in_h) {
for (int in_w = 0; in_w < input_width; ++in_w) {
const uint8_t* input_data_ptr =
input_data + Offset(input_shape, out_b, in_h, in_w, out_d);
uint8x8_t input_data_val = vld1_u8(input_data_ptr);
int16x8_t input_data_val_shift =
vreinterpretq_s16_u16(vmovl_u8(input_data_val));
float32x4_t input_float_1 =
vcvtq_f32_s32(vmovl_s16(vget_high_s16(input_data_val_shift)));
float32x4_t input_float_2 =
vcvtq_f32_s32(vmovl_s16(vget_low_s16(input_data_val_shift)));
temp_sum_1 = vaddq_f32(temp_sum_1, input_float_1);
temp_sum_2 = vaddq_f32(temp_sum_2, input_float_2);
}
}
float32x4_t mean_1 = vmulq_f32(temp_sum_1, num_elements_reverse);
float32x4_t mean_2 = vmulq_f32(temp_sum_2, num_elements_reverse);
if (!ordinary_mean) {
// maq is not supported, break down into two ops.
mean_1 = vmulq_n_f32(mean_1, scale);
mean_1 = vaddq_f32(mean_1, bias_dup);
mean_2 = vmulq_n_f32(mean_2, scale);
mean_2 = vaddq_f32(mean_2, bias_dup);
}
if (!ordinary_mean) {
mean_1 = vaddq_f32(mean_1, output_zero_point_dup);
mean_2 = vaddq_f32(mean_2, output_zero_point_dup);
}
// Rounding.
mean_1 = vaddq_f32(mean_1, kRounding);
mean_2 = vaddq_f32(mean_2, kRounding);
uint32x4_t casted_mean_1 = vcvtq_u32_f32(mean_1);
uint16x4_t narrow_range_mean_1 = vmovn_u32(casted_mean_1);
uint32x4_t casted_mean_2 = vcvtq_u32_f32(mean_2);
uint16x4_t narrow_range_mean_2 = vmovn_u32(casted_mean_2);
uint16x8_t combined_mean =
vcombine_u16(narrow_range_mean_2, narrow_range_mean_1);
uint8x8_t narrowed_combined_mean = vmovn_u16(combined_mean);
uint8_t* output_data_ptr =
output_data + Offset(output_shape, out_b, 0, 0, out_d);
vst1_u8(output_data_ptr, narrowed_combined_mean);
}
#endif
for (; out_d < end_depth; ++out_d) {
float temp_value = 0;
for (int in_h = 0; in_h < input_height; ++in_h) {
for (int in_w = 0; in_w < input_width; ++in_w) {
temp_value +=
input_data[Offset(input_shape, out_b, in_h, in_w, out_d)];
}
}
temp_value = temp_value / num_elements_in_axis;
if (ordinary_mean) {
output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
static_cast<uint8_t>(round(temp_value));
} else {
output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
static_cast<uint8_t>(round(temp_value * scale + bias)) +
output_zero_point;
}
}
}
}
struct MeanWorkerTask : public gemmlowp::Task {
MeanWorkerTask(const tflite::MeanParams& op_params,
const RuntimeShape& input_shape, const uint8_t* input_data,
int32 input_zero_point, float input_scale,
const RuntimeShape& output_shape, uint8_t* output_data,
int32 output_zero_point, float output_scale, int start_height,
int end_height)
: op_params_(op_params),
input_shape_(input_shape),
input_data_(input_data),
input_zero_point_(input_zero_point),
input_scale_(input_scale),
output_shape_(output_shape),
output_data_(output_data),
output_zero_point_(output_zero_point),
output_scale_(output_scale),
start_height_(start_height),
end_height_(end_height) {}
void Run() override {
MeanImpl(op_params_, input_shape_, input_data_, input_zero_point_,
input_scale_, output_shape_, output_data_, output_zero_point_,
output_scale_, start_height_, end_height_);
}
private:
const tflite::MeanParams& op_params_;
const RuntimeShape& input_shape_;
const uint8_t* input_data_;
int32 input_zero_point_;
float input_scale_;
const RuntimeShape& output_shape_;
uint8_t* output_data_;
int32 output_zero_point_;
float output_scale_;
int start_height_;
int end_height_;
gemmlowp::GemmContext* gemm_context_;
};
inline void Mean(const tflite::MeanParams& op_params,
const RuntimeShape& unextended_input_shape,
const uint8_t* input_data, int32 input_zero_point,
float input_scale, const RuntimeShape& unextended_output_shape,
uint8_t* output_data, int32 output_zero_point,
float output_scale, gemmlowp::GemmContext* gemm_context) {
gemmlowp::ScopedProfilingLabel label("Mean4D/Uint8");
// The current implementation only supports 4-D tensors and simultaneous
// reduction over width and height.
TFLITE_CHECK_EQ(unextended_input_shape.DimensionsCount(), 4);
TFLITE_CHECK_LE(unextended_output_shape.DimensionsCount(), 4);
const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
const int output_depth = output_shape.Dims(3);
TFLITE_DCHECK_EQ(op_params.axis_count, 2);
TFLITE_DCHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
(op_params.axis[0] == 2 && op_params.axis[1] == 1));
TFLITE_DCHECK_EQ(output_height, 1);
TFLITE_DCHECK_EQ(output_width, 1);
constexpr int kMinDepthPerThread = 8;
int thread_count = output_depth / kMinDepthPerThread;
thread_count = thread_count > 0 ? thread_count : 1;
const int capped_thread_count =
std::min(thread_count, gemm_context->max_num_threads());
if (thread_count == 1) {
MeanImpl(op_params, input_shape, input_data, input_zero_point, input_scale,
output_shape, output_data, output_zero_point, output_scale, 0,
output_depth);
} else {
// Instead of parallelizing over the batch, we loop over output_depth since
// the batch size is typically 1.
std::vector<gemmlowp::Task*> tasks(capped_thread_count);
int depth_start = 0;
for (int i = 0; i < capped_thread_count; ++i) {
// Try to distribute the tasks as evenly as possible.
int depth_end =
    depth_start + (output_depth - depth_start) / (capped_thread_count - i);
tasks[i] = new MeanWorkerTask(op_params, input_shape, input_data,
input_zero_point, input_scale, output_shape,
output_data, output_zero_point,
output_scale, depth_start, depth_end);
depth_start = depth_end;
}
gemm_context->workers_pool()->Execute(tasks);
}
}
template <typename T>
inline void ExtractPatchIntoBufferColumn(const RuntimeShape& input_shape, int w,
int h, int b, int kheight, int kwidth,


@ -17,6 +17,8 @@ limitations under the License.
#include <vector>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/gemm_support.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
@ -49,6 +51,7 @@ struct OpContext {
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
gemm_support::IncrementUsageCounter(context);
// Creates two temp tensors to store index and axis for internal
// implementation only.
auto* scratch_tensor_index = new int;
@ -57,6 +60,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
}
void Free(TfLiteContext* context, void* buffer) {
gemm_support::DecrementUsageCounter(context);
delete reinterpret_cast<int*>(buffer);
}
@ -248,6 +252,7 @@ void ResolveAxis(const int* axis_data, int axis_count,
template <KernelType kernel_type>
TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
OpContext op_context(context, node);
int num_axis = static_cast<int>(NumElements(op_context.axis));
TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
@ -272,13 +277,15 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
(op_params.axis[0] == 2 && op_params.axis[1] == 1))) {
if (op_context.input->type == kTfLiteUInt8) {
reference_ops::Mean(
gemmlowp::GemmContext* gemm_context =
gemm_support::GetFromContext(context);
optimized_ops::Mean(
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
op_context.input->params.zero_point, op_context.input->params.scale,
GetTensorShape(op_context.output),
GetTensorData<uint8_t>(op_context.output),
op_context.output->params.zero_point,
op_context.output->params.scale);
op_context.output->params.scale, gemm_context);
} else {
reference_ops::Mean(op_params, GetTensorShape(input),
GetTensorData<float>(input),


@ -259,7 +259,7 @@ TEST(ConstFloatMeanOpTest, KeepDims) {
// Uses a set of reduction conditions that trigger the specialized 4D version
// of Mean.
TEST(ConstFloatMeanOpTest, KeepDims_4DMean) {
TEST(ConstFloatMeanOpTest, KeepDims4DMean) {
std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
@ -272,7 +272,7 @@ TEST(ConstFloatMeanOpTest, KeepDims_4DMean) {
ElementsAreArray(ArrayFloatNear({6, 7, 18, 19})));
}
TEST(ConstFloatMeanOpTest, KeepDims_4DMean_UInt8) {
TEST(ConstFloatMeanOpTest, KeepDims4DMeanUInt8) {
float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
std::vector<float> data = {0.1, 0.2, 0.3, 0.4, 0.1, 0.2,
0.3, 0.4, 0.1, 0.2, 0.3, 0.4};
@ -286,7 +286,24 @@ TEST(ConstFloatMeanOpTest, KeepDims_4DMean_UInt8) {
kQuantizedTolerance)));
}
TEST(ConstFloatMeanOpTest, KeepDims_4DMean_Quantized) {
TEST(ConstFloatMeanOpTest, KeepDims4DMeanLargeDepthUInt8) {
float kQuantizedTolerance = GetTolerance(-5.0, 5.0);
std::vector<float> data = {0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.1,
0.1, 0.1, 0.1, 0.4, 0.2, 0.2, 0.2, 0.9, 0.9,
0.9, 0.9, 0.2, 0.3, 0.7, 0.7, 0.1, 0.1, 0.3,
0.3, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4};
MeanOpConstModel m({TensorType_UINT8, {1, 2, 2, 9}, -1.0, 1.0},
{TensorType_UINT8, {2}, -1.0, 1.0}, {2}, {1, 2}, true);
m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 9}));
EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
ElementsAreArray(ArrayFloatNear(
{0.35, 0.325, 0.2, 0.35, 0.375, 0.325, 0.225, 0.45, 0.425},
kQuantizedTolerance)));
}
TEST(ConstFloatMeanOpTest, KeepDims4DMeanQuantized) {
float kQuantizedTolerance = GetTolerance(-5.0, 5.0);
std::vector<float> data = {0.1, 0.2, 0.3, 0.4, 0.1, 0.2,
0.3, 0.4, 0.1, 0.2, 0.3, 0.4};


@ -242,17 +242,18 @@ class TFLiteConverterV2(object):
Input shape is not specified.
None value for dimension in input_tensor.
"""
graph_def = _convert_to_constants.convert_variables_to_constants_v2(
frozen_func = _convert_to_constants.convert_variables_to_constants_v2(
self._func)
input_tensors = [
tensor for tensor in self._func.inputs
tensor for tensor in frozen_func.inputs
if tensor.dtype != _dtypes.resource
]
output_tensors = self._func.outputs
output_tensors = frozen_func.outputs
# Run a Grappler pass.
graph_def = _run_graph_optimizations(graph_def, input_tensors,
output_tensors, self._func.graph)
graph_def = _run_graph_optimizations(frozen_func.graph.as_graph_def(),
input_tensors, output_tensors,
frozen_func.graph)
# Checks dimensions in input tensor.
for tensor in input_tensors:


@ -49,8 +49,8 @@ void FillArrayWithZeros(Array* array) {
} // namespace
// Removes a multiplication by array of constant zeros by making the output
// array an array of constant zeros and removing the input arrays if they are no
// longer needed.
// array to an array of constant zeros and removing the input arrays if they
// are no longer needed.
::tensorflow::Status ResolveMultiplyByZero::Run(Model* model,
std::size_t op_index,
bool* modified) {


@ -1,245 +1,245 @@
tensorflow/contrib/tpu/profiler/pip_package/BUILD
tensorflow/contrib/tpu/profiler/pip_package/setup.py
tensorflow/contrib/tpu/profiler/pip_package/README
tensorflow/contrib/tpu/profiler/pip_package/build_pip_package.sh
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/__init__.py
tensorflow/contrib/mpi/BUILD
tensorflow/tools/ci_build/remote/BUILD
tensorflow/tools/pip_package/README
tensorflow/tools/pip_package/MANIFEST.in
tensorflow/tools/pip_package/simple_console.py
tensorflow/tools/pip_package/build_pip_package.sh
tensorflow/tools/pip_package/check_load_py_test.py
tensorflow/tools/pip_package/pip_smoke_test.py
tensorflow/tools/pip_package/simple_console_for_windows.py
tensorflow/tools/pip_package/setup.py
tensorflow/tools/pip_package/BUILD
tensorflow/tools/lib_package/concat_licenses.sh
tensorflow/tools/lib_package/libtensorflow_test.c
tensorflow/tools/lib_package/LibTensorFlowTest.java
tensorflow/tools/lib_package/BUILD
tensorflow/tools/lib_package/libtensorflow_test.sh
tensorflow/tools/lib_package/README.md
tensorflow/tools/lib_package/libtensorflow_java_test.sh
tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
tensorflow/tools/def_file_filter/BUILD
tensorflow/tools/def_file_filter/BUILD.tpl
tensorflow/tools/def_file_filter/def_file_filter.py.tpl
tensorflow/third_party/mkl/MKL_LICENSE
tensorflow/third_party/mkl/LICENSE
tensorflow/third_party/mkl/BUILD
tensorflow/third_party/mkl/mkl.BUILD
tensorflow/third_party/mkl/build_defs.bzl
tensorflow/third_party/backports_weakref.BUILD
tensorflow/third_party/toolchains/clang6/BUILD
tensorflow/third_party/toolchains/clang6/README.md
tensorflow/third_party/toolchains/clang6/repo.bzl
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
tensorflow/third_party/toolchains/clang6/clang.BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc7-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda9.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
tensorflow/third_party/toolchains/preconfig/generate/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_018/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_018/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
tensorflow/third_party/systemlibs/nsync.BUILD
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
tensorflow/third_party/systemlibs/curl.BUILD
tensorflow/third_party/systemlibs/cython.BUILD
tensorflow/third_party/systemlibs/astor.BUILD
tensorflow/third_party/systemlibs/jsoncpp.BUILD
tensorflow/third_party/systemlibs/png.BUILD
tensorflow/third_party/systemlibs/pcre.BUILD
tensorflow/third_party/systemlibs/grpc.BUILD
tensorflow/third_party/systemlibs/protobuf.BUILD
tensorflow/third_party/systemlibs/double_conversion.BUILD
tensorflow/third_party/systemlibs/six.BUILD
tensorflow/third_party/systemlibs/zlib.BUILD
tensorflow/third_party/systemlibs/lmdb.BUILD
tensorflow/third_party/systemlibs/sqlite.BUILD
tensorflow/third_party/systemlibs/gast.BUILD
tensorflow/third_party/systemlibs/absl_py.BUILD
tensorflow/third_party/systemlibs/boringssl.BUILD
tensorflow/third_party/systemlibs/BUILD.tpl
tensorflow/third_party/systemlibs/BUILD
tensorflow/third_party/systemlibs/termcolor.BUILD
tensorflow/third_party/systemlibs/gif.BUILD
tensorflow/third_party/systemlibs/protobuf.bzl
tensorflow/third_party/systemlibs/snappy.BUILD
tensorflow/third_party/systemlibs/googleapis.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
tensorflow/third_party/systemlibs/re2.BUILD
tensorflow/third_party/systemlibs/swig.BUILD
tensorflow/third_party/systemlibs/syslibs_configure.bzl
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
tensorflow/third_party/pprof.BUILD
tensorflow/third_party/toolchains/remote/execution.bzl.tpl
tensorflow/third_party/toolchains/remote/BUILD.tpl
tensorflow/third_party/toolchains/remote/BUILD
tensorflow/third_party/toolchains/remote/configure.bzl
tensorflow/third_party/toolchains/cpus/py3/BUILD
tensorflow/third_party/toolchains/cpus/py/BUILD
tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl
tensorflow/third_party/toolchains/cpus/arm/CROSSTOOL.tpl
tensorflow/third_party/toolchains/cpus/arm/BUILD
tensorflow/third_party/toolchains/cpus/py3/BUILD
tensorflow/third_party/toolchains/cpus/py/BUILD
tensorflow/third_party/toolchains/remote/configure.bzl
tensorflow/third_party/toolchains/remote/BUILD.tpl
tensorflow/third_party/toolchains/remote/BUILD
tensorflow/third_party/toolchains/remote/execution.bzl.tpl
tensorflow/third_party/toolchains/BUILD
tensorflow/third_party/gpus/BUILD
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
tensorflow/third_party/gpus/crosstool/CROSSTOOL.tpl
tensorflow/third_party/gpus/crosstool/CROSSTOOL_hipcc.tpl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_018/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_018/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
tensorflow/third_party/toolchains/preconfig/generate/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda9.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc7-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
tensorflow/third_party/toolchains/clang6/repo.bzl
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
tensorflow/third_party/toolchains/clang6/BUILD
tensorflow/third_party/toolchains/clang6/clang.BUILD
tensorflow/third_party/toolchains/clang6/README.md
tensorflow/third_party/farmhash.BUILD
tensorflow/third_party/git/BUILD.tpl
tensorflow/third_party/git/git_configure.bzl
tensorflow/third_party/git/BUILD
tensorflow/third_party/cub.BUILD
tensorflow/third_party/gpus/cuda_configure.bzl
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
tensorflow/third_party/gpus/rocm/BUILD.tpl
tensorflow/third_party/gpus/rocm/BUILD
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
tensorflow/third_party/gpus/rocm_configure.bzl
tensorflow/third_party/gpus/crosstool/LICENSE
tensorflow/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
tensorflow/third_party/gpus/crosstool/CROSSTOOL_hipcc.tpl
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
tensorflow/third_party/gpus/crosstool/CROSSTOOL.tpl
tensorflow/third_party/gpus/crosstool/BUILD.tpl
tensorflow/third_party/gpus/crosstool/BUILD
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
tensorflow/third_party/gpus/cuda/LICENSE
tensorflow/third_party/gpus/cuda/BUILD.tpl
tensorflow/third_party/gpus/cuda/BUILD.windows.tpl
tensorflow/third_party/gpus/cuda/cuda_config.h.tpl
tensorflow/third_party/gpus/cuda/BUILD.tpl
tensorflow/third_party/gpus/cuda/BUILD
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
tensorflow/third_party/gpus/rocm/BUILD
tensorflow/third_party/gpus/rocm/BUILD.tpl
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
tensorflow/third_party/gpus/cuda_configure.bzl
tensorflow/third_party/gpus/rocm_configure.bzl
tensorflow/third_party/snappy.BUILD
tensorflow/third_party/cython.BUILD
tensorflow/third_party/farmhash.BUILD
tensorflow/third_party/eigen3/Eigen/Cholesky
tensorflow/third_party/eigen3/Eigen/QR
tensorflow/third_party/eigen3/Eigen/LU
tensorflow/third_party/eigen3/Eigen/Core
tensorflow/third_party/eigen3/Eigen/SVD
tensorflow/third_party/eigen3/Eigen/Eigenvalues
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
tensorflow/third_party/eigen3/gpu_packet_math.patch
tensorflow/third_party/eigen3/LICENSE
tensorflow/third_party/eigen3/BUILD
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
tensorflow/third_party/systemlibs/absl_py.BUILD
tensorflow/third_party/systemlibs/curl.BUILD
tensorflow/third_party/systemlibs/termcolor.BUILD
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
tensorflow/third_party/systemlibs/grpc.BUILD
tensorflow/third_party/systemlibs/swig.BUILD
tensorflow/third_party/systemlibs/protobuf.bzl
tensorflow/third_party/systemlibs/protobuf.BUILD
tensorflow/third_party/systemlibs/BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
tensorflow/third_party/systemlibs/astor.BUILD
tensorflow/third_party/systemlibs/six.BUILD
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
tensorflow/third_party/systemlibs/boringssl.BUILD
tensorflow/third_party/systemlibs/nsync.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
tensorflow/third_party/systemlibs/gif.BUILD
tensorflow/third_party/systemlibs/pcre.BUILD
tensorflow/third_party/systemlibs/BUILD.tpl
tensorflow/third_party/systemlibs/snappy.BUILD
tensorflow/third_party/systemlibs/gast.BUILD
tensorflow/third_party/systemlibs/cython.BUILD
tensorflow/third_party/systemlibs/double_conversion.BUILD
tensorflow/third_party/systemlibs/zlib.BUILD
tensorflow/third_party/systemlibs/jsoncpp.BUILD
tensorflow/third_party/systemlibs/re2.BUILD
tensorflow/third_party/systemlibs/lmdb.BUILD
tensorflow/third_party/systemlibs/googleapis.BUILD
tensorflow/third_party/systemlibs/png.BUILD
tensorflow/third_party/systemlibs/syslibs_configure.bzl
tensorflow/third_party/systemlibs/sqlite.BUILD
tensorflow/third_party/python_runtime/BUILD
tensorflow/third_party/sycl/crosstool/BUILD
tensorflow/third_party/ngraph/LICENSE
tensorflow/third_party/ngraph/tbb.BUILD
tensorflow/third_party/ngraph/BUILD
tensorflow/third_party/ngraph/ngraph.BUILD
tensorflow/third_party/ngraph/build_defs.bzl
tensorflow/third_party/ngraph/NGRAPH_LICENSE
tensorflow/third_party/ngraph/ngraph_tf.BUILD
tensorflow/third_party/ngraph/nlohmann_json.BUILD
tensorflow/third_party/clang_toolchain/download_clang.bzl
tensorflow/third_party/clang_toolchain/BUILD
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
tensorflow/third_party/gast.BUILD
tensorflow/third_party/llvm/BUILD
tensorflow/third_party/llvm/expand_cmake_vars.py
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
tensorflow/third_party/llvm/llvm.bzl
tensorflow/third_party/icu/udata.patch
tensorflow/third_party/nccl/archive.BUILD
tensorflow/third_party/nccl/LICENSE
tensorflow/third_party/nccl/system.BUILD.tpl
tensorflow/third_party/nccl/nccl_configure.bzl
tensorflow/third_party/nccl/build_defs.bzl.tpl
tensorflow/third_party/nccl/BUILD
tensorflow/third_party/fft2d/BUILD
tensorflow/third_party/fft2d/fft.h
tensorflow/third_party/fft2d/LICENSE
tensorflow/third_party/fft2d/fft2d.BUILD
tensorflow/third_party/boringssl/BUILD
tensorflow/third_party/mpi/.gitignore
tensorflow/third_party/mpi/BUILD
tensorflow/third_party/tensorrt/LICENSE
tensorflow/third_party/tensorrt/BUILD
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
tensorflow/third_party/tensorrt/BUILD.tpl
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
tensorflow/third_party/kafka/config.patch
tensorflow/third_party/kafka/BUILD
tensorflow/third_party/android/BUILD
tensorflow/third_party/android/android.bzl.tpl
tensorflow/third_party/android/android_configure.bzl
tensorflow/third_party/android/android_configure.BUILD.tpl
tensorflow/third_party/tflite_smartreply.BUILD
tensorflow/third_party/gpus/BUILD
tensorflow/third_party/common.bzl
tensorflow/third_party/tflite_mobilenet_quant.BUILD
tensorflow/third_party/linenoise.BUILD
tensorflow/third_party/curl.BUILD
tensorflow/third_party/mkl_dnn/LICENSE
tensorflow/third_party/mkl_dnn/mkldnn.BUILD
tensorflow/third_party/pcre.BUILD
tensorflow/third_party/linenoise.BUILD
tensorflow/third_party/sqlite.BUILD
tensorflow/third_party/common.bzl
tensorflow/third_party/com_google_absl.BUILD
tensorflow/third_party/pprof.BUILD
tensorflow/third_party/BUILD
tensorflow/third_party/tflite_mobilenet_quant.BUILD
tensorflow/third_party/lmdb.BUILD
tensorflow/third_party/git/BUILD.tpl
tensorflow/third_party/git/BUILD
tensorflow/third_party/git/git_configure.bzl
tensorflow/third_party/protobuf/BUILD
tensorflow/third_party/enum34.BUILD
tensorflow/third_party/tflite_mobilenet.BUILD
tensorflow/third_party/py/BUILD
tensorflow/third_party/py/BUILD.tpl
tensorflow/third_party/py/numpy/BUILD
tensorflow/third_party/py/python_configure.bzl
tensorflow/third_party/termcolor.BUILD
tensorflow/third_party/png_fix_rpi.patch
tensorflow/third_party/swig.BUILD
tensorflow/third_party/astor.BUILD
tensorflow/third_party/fft2d/LICENSE
tensorflow/third_party/fft2d/fft2d.BUILD
tensorflow/third_party/fft2d/fft.h
tensorflow/third_party/fft2d/BUILD
tensorflow/third_party/ngraph/LICENSE
tensorflow/third_party/ngraph/build_defs.bzl
tensorflow/third_party/ngraph/tbb.BUILD
tensorflow/third_party/ngraph/ngraph.BUILD
tensorflow/third_party/ngraph/nlohmann_json.BUILD
tensorflow/third_party/ngraph/BUILD
tensorflow/third_party/ngraph/ngraph_tf.BUILD
tensorflow/third_party/ngraph/NGRAPH_LICENSE
tensorflow/third_party/grpc/BUILD
tensorflow/third_party/curl.BUILD
tensorflow/third_party/arm_neon_2_x86_sse.BUILD
tensorflow/third_party/cython.BUILD
tensorflow/third_party/icu/udata.patch
tensorflow/third_party/astor.BUILD
tensorflow/third_party/jsoncpp.BUILD
tensorflow/third_party/sycl/crosstool/BUILD
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
tensorflow/third_party/llvm/expand_cmake_vars.py
tensorflow/third_party/llvm/llvm.bzl
tensorflow/third_party/llvm/BUILD
tensorflow/third_party/png.BUILD
tensorflow/third_party/googleapis.BUILD
tensorflow/third_party/mpi_collectives/BUILD
tensorflow/third_party/nanopb.BUILD
tensorflow/third_party/gif.BUILD
tensorflow/third_party/arm_neon_2_x86_sse.BUILD
tensorflow/third_party/codegen.BUILD
tensorflow/third_party/enum34.BUILD
tensorflow/third_party/kafka/config.patch
tensorflow/third_party/kafka/BUILD
tensorflow/third_party/pcre.BUILD
tensorflow/third_party/mpi/BUILD
tensorflow/third_party/mpi/.gitignore
tensorflow/third_party/clang_toolchain/BUILD
tensorflow/third_party/clang_toolchain/download_clang.bzl
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
tensorflow/third_party/tflite_ovic_testdata.BUILD
tensorflow/third_party/repo.bzl
tensorflow/third_party/png_fix_rpi.patch
tensorflow/third_party/py/python_configure.bzl
tensorflow/third_party/py/BUILD.tpl
tensorflow/third_party/py/BUILD
tensorflow/third_party/py/numpy/BUILD
tensorflow/third_party/double_conversion.BUILD
tensorflow/third_party/six.BUILD
tensorflow/third_party/tflite_mobilenet_float.BUILD
tensorflow/third_party/repo.bzl
tensorflow/third_party/codegen.BUILD
tensorflow/third_party/cub.BUILD
tensorflow/third_party/jsoncpp.BUILD
tensorflow/third_party/tflite_ovic_testdata.BUILD
tensorflow/third_party/libxsmm.BUILD
tensorflow/third_party/zlib.BUILD
tensorflow/third_party/lmdb.BUILD
tensorflow/third_party/nanopb.BUILD
tensorflow/third_party/android/android.bzl.tpl
tensorflow/third_party/android/BUILD
tensorflow/third_party/android/android_configure.BUILD.tpl
tensorflow/third_party/android/android_configure.bzl
tensorflow/third_party/tflite_mobilenet_float.BUILD
tensorflow/third_party/sqlite.BUILD
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
tensorflow/third_party/tensorrt/LICENSE
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
tensorflow/third_party/tensorrt/BUILD.tpl
tensorflow/third_party/tensorrt/BUILD
tensorflow/third_party/gast.BUILD
tensorflow/third_party/mpi_collectives/BUILD
tensorflow/third_party/libxsmm.BUILD
tensorflow/third_party/eigen.BUILD
tensorflow/third_party/com_google_absl.BUILD
tensorflow/third_party/eigen3/LICENSE
tensorflow/third_party/eigen3/gpu_packet_math.patch
tensorflow/third_party/eigen3/BUILD
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
tensorflow/third_party/eigen3/Eigen/QR
tensorflow/third_party/eigen3/Eigen/SVD
tensorflow/third_party/eigen3/Eigen/LU
tensorflow/third_party/eigen3/Eigen/Cholesky
tensorflow/third_party/eigen3/Eigen/Eigenvalues
tensorflow/third_party/eigen3/Eigen/Core
tensorflow/third_party/BUILD
tensorflow/third_party/termcolor.BUILD
tensorflow/third_party/gif.BUILD
tensorflow/third_party/tflite_mobilenet.BUILD
tensorflow/third_party/mkl/LICENSE
tensorflow/third_party/mkl/build_defs.bzl
tensorflow/third_party/mkl/mkl.BUILD
tensorflow/third_party/mkl/MKL_LICENSE
tensorflow/third_party/mkl/BUILD
tensorflow/third_party/nccl/build_defs.bzl.tpl
tensorflow/third_party/nccl/LICENSE
tensorflow/third_party/nccl/nccl_configure.bzl
tensorflow/third_party/nccl/archive.BUILD
tensorflow/third_party/nccl/BUILD
tensorflow/third_party/nccl/system.BUILD.tpl
tensorflow/third_party/snappy.BUILD
tensorflow/third_party/python_runtime/BUILD
tensorflow/third_party/googleapis.BUILD
tensorflow/third_party/boringssl/BUILD
tensorflow/third_party/protobuf/BUILD
tensorflow/third_party/backports_weakref.BUILD
tensorflow/third_party/tflite_smartreply.BUILD
tensorflow/third_party/swig.BUILD
tensorflow/compat_template.__init__.py
tensorflow/tools/lib_package/libtensorflow_test.sh
tensorflow/tools/lib_package/libtensorflow_java_test.sh
tensorflow/tools/lib_package/libtensorflow_test.c
tensorflow/tools/lib_package/concat_licenses.sh
tensorflow/tools/lib_package/LibTensorFlowTest.java
tensorflow/tools/lib_package/BUILD
tensorflow/tools/lib_package/README.md
tensorflow/tools/pip_package/check_load_py_test.py
tensorflow/tools/pip_package/simple_console.py
tensorflow/tools/pip_package/pip_smoke_test.py
tensorflow/tools/pip_package/BUILD
tensorflow/tools/pip_package/simple_console_for_windows.py
tensorflow/tools/pip_package/build_pip_package.sh
tensorflow/tools/pip_package/README
tensorflow/tools/pip_package/setup.py
tensorflow/tools/pip_package/MANIFEST.in
tensorflow/tools/ci_build/remote/BUILD
tensorflow/tools/def_file_filter/def_file_filter.py.tpl
tensorflow/tools/def_file_filter/BUILD.tpl
tensorflow/tools/def_file_filter/BUILD
tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
tensorflow/api_template.__init__.py
tensorflow/contrib/tpu/profiler/pip_package/BUILD
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/__init__.py
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
tensorflow/contrib/tpu/profiler/pip_package/build_pip_package.sh
tensorflow/contrib/tpu/profiler/pip_package/README
tensorflow/contrib/tpu/profiler/pip_package/setup.py
tensorflow/contrib/mpi/BUILD
tensorflow/__init__.py
tensorflow/stream_executor/build_defs.bzl
tensorflow/api_template_v1.__init__.py
tensorflow/compat_template_v1.__init__.py
tensorflow/compat_template.__init__.py
tensorflow/api_template.__init__.py
tensorflow/__init__.py
tensorflow/compat_template_v1.__init__.py


@ -5248,6 +5248,7 @@ tf_py_test(
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/testing:testing_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/distribute:collective_all_reduce_strategy",
"//tensorflow/python/distribute:distribute_coordinator",
],
tags = [


@ -21,6 +21,10 @@ from __future__ import print_function
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.eager import function
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.platform import tf_logging as logging
@ -55,6 +59,64 @@ def _run_inline_graph_optimization(func):
return tf_optimizer.OptimizeGraph(config, meta_graph)
def _get_tensors_from_graph(graph, tensors):
"""Gets the Tensors in `graph` with the name of the tensors in `tensors`.
Args:
graph: TensorFlow Graph.
tensors: List of Tensors.
Returns:
List of Tensors.
"""
new_tensors = []
for orig_tensor in tensors:
new_tensor = graph.get_tensor_by_name(orig_tensor.name)
if new_tensor.shape.rank is None:
new_tensor.set_shape(orig_tensor.shape)
new_tensors.append(new_tensor)
return new_tensors
def _construct_concrete_function(input_func, graph_def):
"""Creates a ConcreteFunction from the input function and frozen graph.
Args:
input_func: ConcreteFunction.
graph_def: TensorFlow GraphDef.
Returns:
ConcreteFunction containing the graph_def.
"""
output_graph = func_graph.FuncGraph(input_func.graph.name)
with output_graph.as_default():
importer.import_graph_def(graph_def, name="")
output_graph.inputs = _get_tensors_from_graph(output_graph,
input_func.inputs)
output_graph.outputs = _get_tensors_from_graph(output_graph,
input_func.outputs)
output_graph.structured_outputs = input_func.graph.structured_outputs
output_graph.structured_input_signature = (
input_func.graph.structured_input_signature)
# Create the ConcreteFunction and add it to the global context.
output_func = function.ConcreteFunction(output_graph)
output_func.add_to_graph()
# Inject the captured inputs into the ConcreteFunction.
output_func._captured_inputs = input_func.captured_inputs # pylint: disable=protected-access
output_func.graph.variables = input_func.graph.variables
output_func._arg_keywords = input_func._arg_keywords # pylint: disable=protected-access
output_func._num_positional_args = input_func._num_positional_args # pylint: disable=protected-access
# Register the gradients in the current root context.
with ops.init_scope():
output_func._register_gradient() # pylint: disable=protected-access
return output_func
def convert_variables_to_constants_v2(func):
"""Replaces all the variables in a graph with constants of the same values.
@ -71,7 +133,7 @@ def convert_variables_to_constants_v2(func):
func: ConcreteFunction.
Returns:
GraphDef containing a simplified version of the original.
ConcreteFunction containing a simplified version of the original.
"""
# TODO(nupurgarg): Replace ResourceGather with Gather.
# TODO(nupurgarg): Change attr for Variables in control flow and functions.
@ -145,4 +207,5 @@ def convert_variables_to_constants_v2(func):
output_node.CopyFrom(input_node)
logging.info("Converted %d variables to const ops.", how_many_converted)
return output_graph_def
# TODO(b/126613403): Use wrap_function.function_from_graph_def.
return _construct_concrete_function(func, output_graph_def)
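With this change `convert_variables_to_constants_v2` returns the ConcreteFunction built by `_construct_concrete_function` rather than a bare GraphDef, so callers get both the frozen graph and a callable object. A minimal usage sketch, assuming a TF 2.x-style eager program; the variable and function names below are illustrative and not taken from this change:

import tensorflow as tf
from tensorflow.python.framework import convert_to_constants

v = tf.Variable(3.)

@tf.function(input_signature=[tf.TensorSpec(shape=[1], dtype=tf.float32)])
def f(x):
  return v * x  # captures the variable `v`

input_func = f.get_concrete_function()
frozen_func = convert_to_constants.convert_variables_to_constants_v2(input_func)
# The variable read has been folded into a constant in the returned function.
graph_def = frozen_func.graph.as_graph_def()

The GraphDef that used to be returned directly is still reachable through the new function's graph, which is what the updated tests check via `output_func.graph.as_graph_def()`.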

View File

@ -35,6 +35,7 @@ from tensorflow.python.saved_model.save import save
from tensorflow.python.training.tracking import tracking
# TODO(nupurgarg): Simplify the test cases to use the ConcreteFunction.
class VariablesToConstantsTest(test.TestCase):
def _hasStatefulPartitionedCallOp(self, graph_def):
@ -77,20 +78,21 @@ class VariablesToConstantsTest(test.TestCase):
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
save(root, save_dir, to_save)
saved_model = load(save_dir)
concrete_func = saved_model.signatures["serving_default"]
input_func = saved_model.signatures["serving_default"]
variable_graph_def = concrete_func.graph.as_graph_def()
variable_graph_def = input_func.graph.as_graph_def()
self.assertEqual(0, self._getNumVariables(variable_graph_def))
self.assertTrue(variable_graph_def.library.function)
constant_graph_def = convert_to_constants.convert_variables_to_constants_v2(
concrete_func)
output_func = convert_to_constants.convert_variables_to_constants_v2(
input_func)
constant_graph_def = output_func.graph.as_graph_def()
self.assertEqual(0, self._getNumVariables(constant_graph_def))
self.assertFalse(constant_graph_def.library.function)
# Check value.
expected_value = root.f(input_data)
actual_value = self._evaluateGraphDef(constant_graph_def, concrete_func,
actual_value = self._evaluateGraphDef(constant_graph_def, input_func,
[input_data.numpy()])
self.assertEqual(expected_value.numpy(), actual_value)
@ -102,19 +104,20 @@ class VariablesToConstantsTest(test.TestCase):
root.v1 = variables.Variable(3.)
root.v2 = variables.Variable(2.)
root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
concrete_func = root.f.get_concrete_function(input_data)
input_func = root.f.get_concrete_function(input_data)
variable_graph_def = concrete_func.graph.as_graph_def()
variable_graph_def = input_func.graph.as_graph_def()
self.assertEqual(2, self._getNumVariables(variable_graph_def))
constant_graph_def = convert_to_constants.convert_variables_to_constants_v2(
concrete_func)
output_func = convert_to_constants.convert_variables_to_constants_v2(
input_func)
constant_graph_def = output_func.graph.as_graph_def()
self.assertEqual(0, self._getNumVariables(constant_graph_def))
self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))
# Check value.
expected_value = root.f(input_data)
actual_value = self._evaluateGraphDef(constant_graph_def, concrete_func,
actual_value = self._evaluateGraphDef(constant_graph_def, input_func,
[input_data.numpy()])
self.assertEqual(expected_value.numpy(), actual_value)
@ -131,19 +134,20 @@ class VariablesToConstantsTest(test.TestCase):
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
save(root, save_dir, to_save)
saved_model = load(save_dir)
concrete_func = saved_model.signatures["serving_default"]
input_func = saved_model.signatures["serving_default"]
variable_graph_def = concrete_func.graph.as_graph_def()
variable_graph_def = input_func.graph.as_graph_def()
self.assertTrue(self._hasStatefulPartitionedCallOp(variable_graph_def))
constant_graph_def = convert_to_constants.convert_variables_to_constants_v2(
concrete_func)
output_func = convert_to_constants.convert_variables_to_constants_v2(
input_func)
constant_graph_def = output_func.graph.as_graph_def()
self.assertEqual(0, self._getNumVariables(constant_graph_def))
self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))
# Check value.
expected_value = root.f(input_data)
actual_value = self._evaluateGraphDef(constant_graph_def, concrete_func,
actual_value = self._evaluateGraphDef(constant_graph_def, input_func,
[input_data.numpy()])
self.assertEqual(expected_value.numpy(), actual_value)
@ -171,19 +175,48 @@ class VariablesToConstantsTest(test.TestCase):
input_data = constant_op.constant(1., shape=[1])
root = BasicModel()
concrete_func = root.add.get_concrete_function(input_data)
input_func = root.add.get_concrete_function(input_data)
variable_graph_def = concrete_func.graph.as_graph_def()
variable_graph_def = input_func.graph.as_graph_def()
self.assertEqual(1, self._getNumVariables(variable_graph_def))
constant_graph_def = convert_to_constants.convert_variables_to_constants_v2(
concrete_func)
output_func = convert_to_constants.convert_variables_to_constants_v2(
input_func)
constant_graph_def = output_func.graph.as_graph_def()
self.assertEqual(0, self._getNumVariables(constant_graph_def))
self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))
# Check value.
expected_value = root.add(input_data)
actual_value = self._evaluateGraphDef(constant_graph_def, concrete_func,
actual_value = self._evaluateGraphDef(constant_graph_def, input_func,
[input_data.numpy()])
self.assertEqual(expected_value.numpy(), actual_value)
@test_util.run_v2_only
def testConstructConcreteFunction(self):
input_data = constant_op.constant(1., shape=[1])
root = tracking.AutoTrackable()
root.v1 = variables.Variable(3.)
root.v2 = variables.Variable(2.)
root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
func = root.f.get_concrete_function(input_data)
input_func = convert_to_constants._construct_concrete_function(
func, func.graph.as_graph_def())
# Test if model has enough metadata to be frozen afterwards.
variable_graph_def = input_func.graph.as_graph_def()
self.assertEqual(2, self._getNumVariables(variable_graph_def))
output_func = convert_to_constants.convert_variables_to_constants_v2(
input_func)
constant_graph_def = output_func.graph.as_graph_def()
self.assertEqual(0, self._getNumVariables(constant_graph_def))
self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))
# Check value.
expected_value = root.f(input_data)
actual_value = self._evaluateGraphDef(constant_graph_def, input_func,
[input_data.numpy()])
self.assertEqual(expected_value.numpy(), actual_value)
@ -205,19 +238,20 @@ class VariablesToConstantsTest(test.TestCase):
def to_save(x):
return model(x)
concrete_func = to_save.get_concrete_function(input_data)
input_func = to_save.get_concrete_function(input_data)
variable_graph_def = concrete_func.graph.as_graph_def()
variable_graph_def = input_func.graph.as_graph_def()
self.assertEqual(2, self._getNumVariables(variable_graph_def))
constant_graph_def = convert_to_constants.convert_variables_to_constants_v2(
concrete_func)
output_func = convert_to_constants.convert_variables_to_constants_v2(
input_func)
constant_graph_def = output_func.graph.as_graph_def()
self.assertEqual(0, self._getNumVariables(constant_graph_def))
self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))
# Check value.
expected_value = to_save(input_data)
actual_value = self._evaluateGraphDef(constant_graph_def, concrete_func,
actual_value = self._evaluateGraphDef(constant_graph_def, input_func,
[input_data.numpy()])
self.assertEqual(expected_value.numpy(), actual_value)

View File

@ -237,14 +237,14 @@ def _process_single_batch(model,
raise ValueError('The model cannot be run '
'because it has no loss to optimize.')
if training:
if not model._collected_trainable_weights:
if not model.trainable_weights:
logging.warning('The list of trainable weights is empty. Make sure that'
' you are not setting model.trainable to False before '
'compiling the model.')
else:
grads = tape.gradient(total_loss, model._collected_trainable_weights)
grads = tape.gradient(total_loss, model.trainable_weights)
model.optimizer.apply_gradients(zip(grads,
model._collected_trainable_weights))
model.trainable_weights))
return outs, total_loss, output_losses, aggregated_output_losses, masks
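The diff above makes the eager single-batch step read `model.trainable_weights` at gradient time instead of the `_collected_trainable_weights` list cached at compile time, so weights that a dynamic model creates on its first call are still trained. A rough standalone equivalent of that step, using only public TF 2.x APIs; `loss_fn` and the other names are hypothetical stand-ins for the Keras internals:

import tensorflow as tf

def train_step(model, optimizer, loss_fn, x, y):
  with tf.GradientTape() as tape:
    outs = model(x, training=True)
    total_loss = loss_fn(y, outs)
  if not model.trainable_weights:
    # Mirrors the warning above: nothing to optimize, likely because
    # model.trainable was set to False before compiling.
    tf.print('The list of trainable weights is empty.')
  else:
    # Query trainable_weights here, not a list cached earlier, so lazily
    # created variables are included.
    grads = tape.gradient(total_loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
  return total_loss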

View File

@ -24,6 +24,7 @@ from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import testing_utils
@ -34,6 +35,32 @@ from tensorflow.python.platform import test
class TrainingTest(keras_parameterized.TestCase):
@test_util.run_in_graph_and_eager_modes()
def test_dynamic_model_has_trainable_weights(self):
if not context.executing_eagerly():
# Only test eager mode; graph mode is not relevant for dynamic models.
return
class DynamicModel(keras.Model):
def __init__(self):
super(DynamicModel, self).__init__(dynamic=True)
self.dense = keras.layers.Dense(
1, kernel_initializer='zeros', bias_initializer='ones')
def call(self, inputs):
return self.dense(inputs)
model = DynamicModel()
model.compile('rmsprop', 'mae')
hist = model.fit(np.zeros((1, 1)), np.zeros((1, 1)))
self.assertEqual(hist.history['loss'][-1], 1)
self.assertEqual(len(model.trainable_weights), 2)
loss = model.train_on_batch(np.zeros((1, 1)), np.zeros((1, 1)))
# The loss can only have decreased if the trainable weights were picked up
# and updated during training.
self.assertLess(loss, 1)
@keras_parameterized.run_with_all_model_types(exclude_models='sequential')
@keras_parameterized.run_all_keras_modes
def test_model_methods_with_eager_tensors_multi_io(self):

View File

@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import abc
import os
import sys
import six
@ -327,32 +328,81 @@ def _create_monitored_session_with_worker_context(worker_context, # pylint: dis
if chief_only_hooks and worker_context.is_chief:
all_hooks.extend(chief_only_hooks)
# Save and summary ops must run on all workers, since they may contain
# collective ops; running them on only some workers would make the collective
# ops hang. Workers that don't need to actually write checkpoints or
# summaries therefore write to a temp directory instead.
# pylint: disable=protected-access
if type(worker_context._strategy).__name__ in ('CollectiveAllReduceStrategy',
'MultiWorkerMirroredStrategy'):
if worker_context.task_type:
tmpdir = 'tmp_%s_%d' % (worker_context.task_type, worker_context.task_id)
else:
tmpdir = 'tmp'
if save_checkpoint_secs:
logging.warning('Collective ops may deadlock with '
'`save_checkpoint_secs`; please use '
'`save_checkpoint_steps` instead. Clearing '
'`save_checkpoint_secs` and setting '
'`save_checkpoint_steps` to 1000 now.')
save_checkpoint_secs = None
save_checkpoint_steps = 1000
if save_summaries_secs:
logging.warning('Collective ops may run out of sync with '
'`save_summaries_secs`; please use '
'`save_summaries_steps` instead.')
else:
tmpdir = None
summary_dir = summary_dir or checkpoint_dir
if summary_dir and worker_context.should_save_summary:
if log_step_count_steps and log_step_count_steps > 0:
if summary_dir and log_step_count_steps and log_step_count_steps > 0:
if worker_context.should_save_summary:
all_hooks.append(
basic_session_run_hooks.StepCounterHook(
output_dir=summary_dir, every_n_steps=log_step_count_steps))
elif tmpdir:
all_hooks.append(
basic_session_run_hooks.StepCounterHook(
output_dir=os.path.join(summary_dir, tmpdir),
every_n_steps=log_step_count_steps))
if (save_summaries_steps and save_summaries_steps > 0) or (
save_summaries_secs and save_summaries_secs > 0):
if (((save_summaries_steps and save_summaries_steps > 0) or
(save_summaries_secs and save_summaries_secs > 0)) and summary_dir):
if worker_context.should_save_summary:
all_hooks.append(
basic_session_run_hooks.SummarySaverHook(
scaffold=scaffold,
save_steps=save_summaries_steps,
save_secs=save_summaries_secs,
output_dir=summary_dir))
if checkpoint_dir and worker_context.should_checkpoint:
if (save_checkpoint_secs and save_checkpoint_secs > 0) or (
save_checkpoint_steps and save_checkpoint_steps > 0):
elif tmpdir:
all_hooks.append(
basic_session_run_hooks.CheckpointSaverHook(
checkpoint_dir,
save_steps=save_checkpoint_steps,
save_secs=save_checkpoint_secs,
scaffold=scaffold))
basic_session_run_hooks.SummarySaverHook(
scaffold=scaffold,
save_steps=save_summaries_steps,
save_secs=save_summaries_secs,
output_dir=os.path.join(summary_dir, tmpdir)))
if (((save_checkpoint_secs and save_checkpoint_secs > 0) or
(save_checkpoint_steps and save_checkpoint_steps > 0)) and
checkpoint_dir):
if worker_context.should_checkpoint:
all_hooks.append(
basic_session_run_hooks.CheckpointSaverHook(
checkpoint_dir,
save_steps=save_checkpoint_steps,
save_secs=save_checkpoint_secs,
scaffold=scaffold))
elif tmpdir:
all_hooks.append(
basic_session_run_hooks.CheckpointSaverHook(
os.path.join(checkpoint_dir, tmpdir),
save_steps=save_checkpoint_steps,
save_secs=save_checkpoint_secs,
scaffold=scaffold))
logging.info('all_hooks %r', all_hooks)
session_creator = worker_context.session_creator(
scaffold,
config=config,
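The block above encodes the policy spelled out in the comment: with a collective-ops strategy every worker attaches the step-counter, summary, and checkpoint hooks, but only the chief writes to the real directories; the other workers are redirected to a per-task `tmp_<task_type>_<task_id>` subdirectory so the save ops (and any collective ops inside them) run everywhere without clobbering the chief's output. A small illustrative helper (hypothetical, not part of this change) showing where a given worker's checkpoints end up:

import os

def checkpoint_output_dir(checkpoint_dir, task_type, task_id, is_chief):
  # The chief writes real checkpoints; every other worker still runs the save
  # ops, so collective ops stay in sync, but writes to a throwaway subdir.
  if is_chief:
    return checkpoint_dir
  tmpdir = 'tmp_%s_%d' % (task_type, task_id) if task_type else 'tmp'
  return os.path.join(checkpoint_dir, tmpdir)

# e.g. non-chief worker 1:
# checkpoint_output_dir('/logdir', 'worker', 1, is_chief=False)
#   -> '/logdir/tmp_worker_1', the path the new test below asserts against.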

View File

@ -32,6 +32,7 @@ from tensorflow.contrib.testing.python.framework import util_test
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import debug_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import distribute_coordinator
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -487,6 +488,32 @@ class MonitoredTrainingSessionWithDistributeCoordinatorTest(test.TestCase):
checkpoint = checkpoint_management.latest_checkpoint(logdir)
self.assertIsNone(checkpoint)
def test_checkpoint_hook_enable_on_non_chief_with_collective_ops(self):
strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
strategy.extended._is_chief = False
context = distribute_coordinator._WorkerContext(strategy, None, 'worker', 1)
logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_disabled')
with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step()
new_gstep = state_ops.assign_add(gstep, 1)
with context, monitored_session.MonitoredTrainingSession(
checkpoint_dir=logdir,
save_checkpoint_steps=100,
log_step_count_steps=10) as session:
for _ in range(100):
session.run(new_gstep)
# No checkpoint is saved.
checkpoint = checkpoint_management.latest_checkpoint(logdir)
self.assertIsNone(checkpoint)
# But saved to a temporary directory.
checkpoint = checkpoint_management.latest_checkpoint(
os.path.join(logdir, 'tmp_worker_1'))
self.assertIsNotNone(checkpoint)
class StopAtNSession(monitored_session._WrappedSession):
"""A wrapped session that stops at the N-th call to _check_stop."""

View File

@ -1,115 +0,0 @@
path: "tensorflow.estimator.experimental.KMeans"
tf_class {
is_instance: "<class \'tensorflow_estimator.python.estimator.canned.kmeans.KMeansClustering\'>"
is_instance: "<class \'tensorflow_estimator.python.estimator.estimator.Estimator\'>"
is_instance: "<class \'tensorflow_estimator.python.estimator.estimator.EstimatorV2\'>"
is_instance: "<type \'object\'>"
member {
name: "ALL_DISTANCES"
mtype: "<type \'str\'>"
}
member {
name: "CLUSTER_CENTERS_VAR_NAME"
mtype: "<type \'str\'>"
}
member {
name: "CLUSTER_INDEX"
mtype: "<type \'str\'>"
}
member {
name: "COSINE_DISTANCE"
mtype: "<type \'str\'>"
}
member {
name: "KMEANS_PLUS_PLUS_INIT"
mtype: "<type \'str\'>"
}
member {
name: "RANDOM_INIT"
mtype: "<type \'str\'>"
}
member {
name: "SCORE"
mtype: "<type \'str\'>"
}
member {
name: "SQUARED_EUCLIDEAN_DISTANCE"
mtype: "<type \'str\'>"
}
member {
name: "config"
mtype: "<type \'property\'>"
}
member {
name: "model_dir"
mtype: "<type \'property\'>"
}
member {
name: "model_fn"
mtype: "<type \'property\'>"
}
member {
name: "params"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'num_clusters\', \'model_dir\', \'initial_clusters\', \'distance_metric\', \'seed\', \'use_mini_batch\', \'mini_batch_steps_per_iteration\', \'kmeans_plus_plus_num_retries\', \'relative_tolerance\', \'config\', \'feature_columns\'], varargs=None, keywords=None, defaults=[\'None\', \'random\', \'squared_euclidean\', \'None\', \'True\', \'1\', \'2\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "cluster_centers"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "eval_dir"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "evaluate"
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "experimental_export_all_saved_models"
argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "export_saved_model"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
member_method {
name: "get_variable_names"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_variable_value"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "latest_checkpoint"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "predict"
argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
}
member_method {
name: "predict_cluster_index"
argspec: "args=[\'self\', \'input_fn\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "score"
argspec: "args=[\'self\', \'input_fn\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "train"
argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "transform"
argspec: "args=[\'self\', \'input_fn\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -4,10 +4,6 @@ tf_module {
name: "InMemoryEvaluatorHook"
mtype: "<type \'type\'>"
}
member {
name: "KMeans"
mtype: "<type \'type\'>"
}
member {
name: "LinearSDCA"
mtype: "<type \'type\'>"