Merge changes from github.
Change: 146918929

parent: 15ff7b7027
commit: 639b4e71f5
Changed files:

README.md
RELEASE.md
configure
tensorflow/
  BUILD
  workspace.bzl
  contrib/
    BUILD
    __init__.py
    cmake/
    factorization/python/ops/
    learn/
      BUILD
      python/learn/
    losses/python/losses/
    nccl/
    rnn/python/kernel_tests/
    slim/
    sparsemax/
  core/
    BUILD
    graph/
    kernels/
      BUILD
      conv_grad_input_ops.cc
      conv_ops.cc
      segment_reduction_ops.cc
      segment_reduction_ops.h
      segment_reduction_ops_gpu.cu.cc
      sparse_matmul_op.cc
      xsmm_conv2d.cc
      xsmm_conv2d_test.cc
    ops/
    platform/default/
    public/
  examples/
    how_tos/reading_data/
    image_retraining/
    multibox_detector/
    udacity/
  g3doc/
    api_docs/python/
    get_started/
    how_tos/
    resources/
  go/genop/
  java/
  python/
    framework/
    kernel_tests/
    ops/
    training/
  tensorboard/
  tensorflow.bzl
  tools/
    benchmark/
    ci_build/
    compatibility/
    docker/
    docs/
    lib_package/
    pip_package/
    tfprof/
README.md

@@ -33,10 +33,11 @@ and discussion.**

 People who are a little more adventurous can also try our nightly binaries:

-* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.12.1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.12.1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.12.1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
+* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.0.0rc1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.0.0rc1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.0.0rc1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
-* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-0.12.1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-0.12.1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-0.12.1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
+* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.0.0rc1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.0.0rc1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.0.0rc1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
-* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.12.1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.12.1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
+* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.0.0rc1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.0.0rc1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
-* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-0.12.1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-0.12.1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/))
+* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.0.0rc1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.0.0rc1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/))
-* [Android](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-android/TF_BUILD_CONTAINER_TYPE=ANDROID,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=NO_PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=android-slave/lastSuccessfulBuild/artifact/bazel-out/local_linux/bin/tensorflow/examples/android/tensorflow_demo.apk) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-android/TF_BUILD_CONTAINER_TYPE=ANDROID,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=NO_PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=android-slave/))
+* Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/)
+  ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/))
RELEASE.md (112 changes)
@@ -1,7 +1,20 @@
-# Changes since the last release
+# Release 1.0.0
+
+## Major Features and Improvements
+* XLA (experimental): initial release of [XLA](https://www.tensorflow.org/versions/master/experimental/xla/), a domain-specific compiler for TensorFlow graphs that targets CPUs and GPUs.
+* TensorFlow Debugger (tfdbg): command-line interface and API.
+* New Python 3 Docker images added.
+* Made pip packages PyPI compliant. TensorFlow can now be installed with the
+  `pip install tensorflow` command.
+* Several Python API calls have been changed to resemble NumPy more closely.
+* Android: person detection + tracking demo implementing Scalable Object
+  Detection using Deep Neural Networks.
+* New (experimental) [Java API](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/java).
+* Add new Android image stylization demo based on "A Learned Representation For Artistic Style", and add YOLO object detector support.

 ## Breaking Changes to the API
+To help you upgrade your existing TensorFlow Python code to match the API changes below, we have prepared a [conversion script](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/compatibility).
+* TensorFlow/models have been moved to a separate github repository.
 * Division and modulus operators (/, //, %) now match Python (flooring)
   semantics. This applies to `tf.div` and `tf.mod` as well. To obtain forced
   integer truncation based behaviors you can use `tf.truncatediv`
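A quick plain-Python illustration of the flooring vs. truncating behavior described above (a sketch; the assertions use Python integer semantics, which the note says the TF ops now match):

```python
# Flooring (Python) semantics, which tf.div / tf.mod now follow per the note:
assert 7 // -2 == -4   # floor(-3.5) is -4, not -3
assert 7 % -2 == -1    # the remainder takes the divisor's sign

# tf.truncatediv instead rounds toward zero: trunc(-3.5) -> -3.
```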
@@ -51,16 +64,93 @@
   keywords. In particular we now match NumPy order as
   `tf.sparse_split(sp_input, num_split, axis)`. NOTE: we have temporarily
   made `tf.sparse_split` require keyword arguments.
-* Deprecated `tf.concat` operator. Please switch to use `tf.concat_v2` for now.
-  In the Beta release, we will update `tf.concat` to match argument order of
-  `tf.concat_v2`.
-* tf.image.decode_jpeg by default uses the faster DCT method, sacrificing
+* `tf.concat` now takes arguments in reversed order and with different keywords. In particular we now match NumPy order as `tf.concat(values, axis, name)`.
+* `tf.image.decode_jpeg` by default uses the faster DCT method, sacrificing
   a little fidelity for improved speed. One can revert to the old
-  behavior by specifying the attribute dct_method='INTEGER_ACCURATE'.
+  behavior by specifying the attribute `dct_method='INTEGER_ACCURATE'`.
 * `tf.complex_abs` has been removed from the Python interface. `tf.abs`
   supports complex tensors and should be used instead.
 * In the C++ API (in tensorflow/cc), Input, Output, etc. have moved
   from the tensorflow::ops namespace to tensorflow.
+* `Template.var_scope` property renamed to `.variable_scope`.
+* SyncReplicasOptimizer is removed and SyncReplicasOptimizerV2 renamed to
+  SyncReplicasOptimizer.
+* `tf.zeros_initializer()` and `tf.ones_initializer()` now return a callable
+  that must be called with initializer arguments; in your code, replace
+  `tf.zeros_initializer` with `tf.zeros_initializer()`.
+* `SparseTensor.shape` has been renamed to `SparseTensor.dense_shape`. Same for
+  `SparseTensorValue.shape`.
+* Replace `tf.scalar_summary`, `tf.histogram_summary`, `tf.audio_summary`, and
+  `tf.image_summary` with `tf.summary.scalar`, `tf.summary.histogram`,
+  `tf.summary.audio`, and `tf.summary.image`, respectively. The new summary
+  ops take `name` rather than `tag` as their first argument, meaning summary
+  ops now respect TensorFlow name scopes.
+* Replace `tf.train.SummaryWriter` and `tf.train.SummaryWriterCache` with
+  `tf.summary.FileWriter` and `tf.summary.FileWriterCache`.
+* Removes RegisterShape from public API. Use C++ shape function registration
+  instead.
+* Deprecated `_ref` dtypes from the python API.
+* Change arg order for `{softmax,sparse_softmax,sigmoid}_cross_entropy_with_logits` to be (labels, predictions), and force use of named args.
+
+## Bug Fixes and Other Changes
+* New op: `parallel_stack`.
+* Introducing common tf io compression options constants for
+  RecordReader/RecordWriter.
+* Add `sparse_column_with_vocabulary_file`, to specify a feature column that
+  transforms string features to IDs, where the mapping is defined by a
+  vocabulary file.
+* Added `index_to_string_table`, which returns a lookup table that maps
+  indices to strings.
+* Add `string_to_index_table`, which returns a lookup table that matches
+  strings to indices.
+* Add a `ParallelForWithWorkerId` function.
+* Support restore session from checkpoint files in v2 in
+  `contrib/session_bundle`.
+* Added a `tf.contrib.image.rotate` function for arbitrary angles.
+* Added `tf.contrib.framework.filter_variables` as a convenience function to
+  filter lists of variables based on regular expressions.
+* `make_template()` takes an optional `custom_getter_` param.
+* Added comment about how existing directories are handled by
+  `recursive_create_dir`.
+* Added an op for QR factorizations.
+* Divides and mods in Python API now use flooring (Python) semantics.
+* Android: pre-built libs are now built nightly.
+* Android: cmake/gradle build for TensorFlow Inference library under
+  `contrib/android/cmake`.
+* Android: much more robust Session initialization code.
+* Android: TF stats now exposed directly in demo and log when debug mode is
+  active.
+* Android: new/better README.md documentation.
+* saved_model is available as `tf.saved_model`.
+* Empty op is now stateful.
+* Improve speed of scatter_update on the cpu for ASSIGN operations.
+* Change `reduce_join` to treat `reduction_indices` in the same way as other
+  `reduce_` ops.
+* Move `TensorForestEstimator` to `contrib/tensor_forest`.
+* Enable compiler optimizations by default and allow configuration in
+  configure.
+* `tf.divide` now honors the name field.
+* Make metrics weight broadcasting more strict.
+* Add new queue-like `StagingArea` and new ops: `stage` and `unstage`.
+
+## Thanks to our Contributors
+
+This release contains contributions from many people at Google, as well as:
+
+Aaron Hu, Abhishek Aggarwal, Adam Michael, Adriano Carmezim, @AfirSraftGarrier,
+Alexander Novikov, Alexander Rosenberg Johansen, Andrew Gibiansky, Andrew Hundt,
+Anish Shah, Anton Loss, @b0noI, @BoyuanJiang, Carl Thomé, Chad Kennedy, Comic
+Chang, Connor Braa, Daniel N. Lang, Daniel Trebbien,
+@danielgordon10, Darcy Liu, Darren Garvey, Dmitri Lapin, Eron Wright, Evan
+Cofer, Fabrizio Milo, Finbarr Timbers, Franck Dernoncourt, Garrett Smith,
+@guschmue, Hao Wei, Henrik Holst, Huazuo Gao, @Ian, @Issac, Jacob Israel,
+Jangsoo Park, Jin Kim, Jingtian Peng, John Pope, Kye Bostelmann, Liangliang He,
+Ling Zhang, Luheng He, Luke Iwanski, @lvli, Michael Basilyan, Mihir Patel,
+Mikalai Drabovich, Morten Just, @newge, Nick Butlin, Nishant Shukla,
+Pengfei Ni, Przemyslaw Tredak, @rasbt, @Ronny, Rudolf Rosa, @RustingSword,
+Sam Abrahams, Sam Putnam, @SeongAhJo, Shi Jiaxin, @skavulya, Steffen MüLler,
+@TheUSER123, @tiriplicamihai, @vhasanov, Victor Costan, Vit Stepanovs,
+Wangda Tan, Wenjian Huang, Xingdong Zuo, Yaroslav Bulatov, Yota Toyama,
+Yuan (Terry) Tang, Yuxin Wu
+
+We are also grateful to all who filed issues or helped resolve them, asked and
+answered questions, and were part of inspiring discussions.
+

 # Release 0.12.0
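Taken together, the notes above imply migrations like the following hedged sketch (names follow the TF 1.0 API as described in these notes; not part of the diff):

```python
import tensorflow as tf

a = tf.zeros([2, 3])
b = tf.ones([2, 3])

# tf.concat now takes (values, axis), matching NumPy order:
c = tf.concat([a, b], 0)            # was: tf.concat_v2([a, b], 0)

# Summary ops moved under tf.summary and take a name, not a tag:
s = tf.summary.scalar('c_mean', tf.reduce_mean(c))   # was: tf.scalar_summary
writer = tf.summary.FileWriter('/tmp/logs')          # was: tf.train.SummaryWriter
```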
@@ -100,15 +190,15 @@
 ## Breaking Changes to the API

 * `BusAdjacency` enum replaced with a protocol buffer `DeviceLocality`. PCI bus
-  indexing now starts from 1 instead of 0, and bus_id==0 is used where
-  previously BUS_ANY was used.
+  indexing now starts from 1 instead of 0, and `bus_id==0` is used where
+  previously `BUS_ANY` was used.
 * `Env::FileExists` and `FileSystem::FileExists` now return a tensorflow::Status
   instead of a bool. Any callers to this function can be converted to a bool
   by adding .ok() to the call.
 * The C++ API type `TF_SessionWithGraph` has been renamed to `TF_Session`,
   indicating its preferred use in language bindings for TensorFlow.
   What was previously `TF_Session` has been renamed to `TF_DeprecatedSession`.
-* Renamed TF_Port to TF_Output in the C API.
+* Renamed `TF_Port` to `TF_Output` in the C API.
 * Removes RegisterShape from public API. Use C++ shape function registration instead.
   indexing now starts from 1 instead of 0, and `bus_id==0` is used where
   previously `BUS_ANY` was used.
@@ -143,7 +233,7 @@
 `tf.global_variables_initializer` respectively.
 * `tf.zeros_initializer()` and `tf.ones_initializer()` now return a callable
   that must be called with initializer arguments, in your code replace
-  tf.zeros_initializer with tf.zeros_initializer()
+  `tf.zeros_initializer` with `tf.zeros_initializer()`

 ## Bug Fixes and Other Changes
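A minimal sketch of the initializer change noted above (assuming the TF 1.0 API described in these notes):

```python
import tensorflow as tf

# The initializer factories are now callables: construct an instance first.
init = tf.zeros_initializer()   # note the parentheses
v = tf.get_variable('v', shape=[3], initializer=init)
```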
configure (vendored, 116 changes)
@@ -41,7 +41,8 @@ function bazel_clean_and_fetch() {
   if ! is_windows; then
     bazel clean --expunge
   fi
-  bazel fetch "//tensorflow/... -//tensorflow/examples/android/..."
+  bazel fetch "//tensorflow/... -//tensorflow/contrib/nccl/... \
+      -//tensorflow/examples/android/..."
 }

 # Delete any leftover BUILD files from the Makefile build, which would interfere
@@ -73,10 +74,77 @@ while true; do
   # Retry
 done

+## Set up MKL related environment settings
+if false; then # Disable building with MKL for now
+  while [ "$TF_NEED_MKL" == "" ]; do
+    fromuser=""
+    read -p "Do you wish to build TensorFlow with MKL support? [y/N] " INPUT
+    fromuser="1"
+    case $INPUT in
+      [Yy]* ) echo "MKL support will be enabled for TensorFlow"; TF_NEED_MKL=1;;
+      [Nn]* ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
+      "" ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
+      * ) echo "Invalid selection: " $INPUT;;
+    esac
+  done
+
+  OSNAME=`uname -s`
+
+  if [ "$TF_NEED_MKL" == "1" ]; then # TF_NEED_MKL
+    DST=`dirname $0`
+    ARCHIVE_BASENAME=mklml_lnx_2017.0.2.20170110.tgz
+    GITHUB_RELEASE_TAG=v0.3
+    MKLURL="https://github.com/01org/mkl-dnn/releases/download/$GITHUB_RELEASE_TAG/$ARCHIVE_BASENAME"
+    if ! [ -e "$DST/third_party/mkl/$ARCHIVE_BASENAME" ]; then
+      wget --no-check-certificate -P $DST/third_party/mkl/ $MKLURL
+    fi
+    tar -xzf $DST/third_party/mkl/$ARCHIVE_BASENAME -C $DST/third_party/mkl/
+    extracted_dir_name="${ARCHIVE_BASENAME%.*}"
+    MKL_INSTALL_PATH=$DST/third_party/mkl/$extracted_dir_name
+    MKL_INSTALL_PATH=`${PYTHON_BIN_PATH} -c "import os; print(os.path.realpath(os.path.expanduser('${MKL_INSTALL_PATH}')))"`
+
+    if [ "$OSNAME" == "Linux" ]; then
+      # Full MKL configuration
+      MKL_RT_LIB_PATH="lib/intel64/libmkl_rt.so" #${TF_MKL_EXT}#TODO version?
+      MKL_RT_OMP_LIB_PATH="../compiler/lib/intel64/libiomp5.so" #TODO VERSION?
+
+      # MKL-ML configuration
+      MKL_ML_LIB_PATH="lib/libmklml_intel.so" #${TF_MKL_EXT}#TODO version?
+      MKL_ML_OMP_LIB_PATH="lib/libiomp5.so" #TODO VERSION?
+    elif [ "$OSNAME" == "Darwin" ]; then
+      echo "Darwin is unsupported yet";
+      exit 1
+    fi
+
+    if [ -e "$MKL_INSTALL_PATH/${MKL_ML_LIB_PATH}" ]; then
+      ln -sf $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} third_party/mkl/
+      ln -sf $MKL_INSTALL_PATH/${MKL_ML_OMP_LIB_PATH} third_party/mkl/
+      ln -sf $MKL_INSTALL_PATH/include third_party/mkl/
+      ln -sf $MKL_INSTALL_PATH/include third_party/eigen3/mkl_include
+    else
+      echo "ERROR: $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} does not exist";
+      exit 1
+    fi
+
+    if [ -z "$fromuser" ]; then
+      exit 1
+    fi
+
+cat > third_party/mkl/mkl.config <<EOF
+# MKL_INSTALL_PATH refers to the location of MKL root folder. The MKL header and library
+# files can be either in this directory, or under include/ and lib64/
+MKL_INSTALL_PATH=$MKL_INSTALL_PATH
+EOF
+
+  fi # TF_NEED_MKL
+  ################## MKL
+fi # Disable building with MKL for now
+
 ## Set up architecture-dependent optimization flags.
 if [ -z "$CC_OPT_FLAGS" ]; then
   default_cc_opt_flags="-march=native"
-  read -p "Please specify optimization flags to use during compilation [Default is $default_cc_opt_flags]: " CC_OPT_FLAGS
+  read -p "Please specify optimization flags to use during compilation when bazel option "\
+"\"--config=opt\" is specified [Default is $default_cc_opt_flags]: " CC_OPT_FLAGS
   if [ -z "$CC_OPT_FLAGS" ]; then
     CC_OPT_FLAGS=$default_cc_opt_flags
   fi
@@ -328,46 +396,8 @@ while true; do

   if [[ -z "$TF_CUDNN_VERSION" ]]; then
     TF_CUDNN_EXT=""
-    cudnn_lib_path=""
-    cudnn_alt_lib_path=""
-    if is_windows; then
-      cudnn_lib_path="${CUDNN_INSTALL_PATH}/lib/x64/cudnn.lib"
-      cudnn_alt_lib_path="${CUDNN_INSTALL_PATH}/lib/x64/cudnn.lib"
-    elif is_linux; then
-      cudnn_lib_path="${CUDNN_INSTALL_PATH}/lib64/libcudnn.so"
-      cudnn_alt_lib_path="${CUDNN_INSTALL_PATH}/libcudnn.so"
-    elif is_macos; then
-      cudnn_lib_path="${CUDNN_INSTALL_PATH}/lib/libcudnn.dylib"
-      cudnn_alt_lib_path="${CUDNN_INSTALL_PATH}/libcudnn.dylib"
-    fi
-    # Resolve to the SONAME of the symlink. Use readlink without -f
-    # to resolve exactly once to the SONAME. E.g, libcudnn.so ->
-    # libcudnn.so.4.
-    # If the path is not a symlink, readlink will exit with an error code, so
-    # in that case, we return the path itself.
-    if [ -f "$cudnn_lib_path" ]; then
-      REALVAL=`readlink "${cudnn_lib_path}" || echo "${cudnn_lib_path}"`
-    else
-      REALVAL=`readlink "${cudnn_alt_lib_path}" || echo "${cudnn_alt_lib_path}"`
-    fi
-
-    # Extract the version of the SONAME, if it was indeed symlinked to
-    # the SONAME version of the file.
-    if [[ "$REALVAL" =~ .so[.]+([0-9]*) ]]; then
-      TF_CUDNN_EXT="."${BASH_REMATCH[1]}
-      TF_CUDNN_VERSION=${BASH_REMATCH[1]}
-      echo "libcudnn.so resolves to libcudnn${TF_CUDNN_EXT}"
-    elif [[ "$REALVAL" =~ ([0-9]*).dylib ]]; then
-      TF_CUDNN_EXT=${BASH_REMATCH[1]}".dylib"
-      TF_CUDNN_VERSION=${BASH_REMATCH[1]}
-      echo "libcudnn.dylib resolves to libcudnn${TF_CUDNN_EXT}"
-    fi
   else
-    TF_CUDNN_EXT=".$TF_CUDNN_VERSION"
+    if is_macos; then
+      TF_CUDNN_EXT=".${TF_CUDNN_VERSION}.dylib"
+    else
+      TF_CUDNN_EXT=".$TF_CUDNN_VERSION"
+    fi
   fi

   if is_windows; then
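The removed shell logic resolved a symlink exactly one level to its SONAME; a rough, illustrative Python equivalent of that logic (not part of the diff):

```python
import os
import re

def cudnn_version_from_soname(lib_path):
    # Resolve a symlink exactly one level (like `readlink` without -f),
    # e.g. libcudnn.so -> libcudnn.so.5; a regular file resolves to itself.
    target = os.readlink(lib_path) if os.path.islink(lib_path) else lib_path
    # Extract the trailing version from the SONAME, if present.
    m = re.search(r'\.so\.([0-9]+)', target) or re.search(r'([0-9]+)\.dylib', target)
    return m.group(1) if m else None
```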
@@ -377,8 +407,8 @@ while true; do
     CUDA_DNN_LIB_PATH="lib64/libcudnn.so${TF_CUDNN_EXT}"
     CUDA_DNN_LIB_ALT_PATH="libcudnn.so${TF_CUDNN_EXT}"
   elif is_macos; then
-    CUDA_DNN_LIB_PATH="lib/libcudnn${TF_CUDNN_EXT}"
-    CUDA_DNN_LIB_ALT_PATH="libcudnn${TF_CUDNN_EXT}"
+    CUDA_DNN_LIB_PATH="lib/libcudnn${TF_CUDNN_EXT}.dylib"
+    CUDA_DNN_LIB_ALT_PATH="libcudnn${TF_CUDNN_EXT}.dylib"
   fi

   if [ -e "$CUDNN_INSTALL_PATH/${CUDA_DNN_LIB_ALT_PATH}" -o -e "$CUDNN_INSTALL_PATH/${CUDA_DNN_LIB_PATH}" ]; then
@@ -171,6 +171,7 @@ filegroup(
         "//tensorflow/contrib/slim/python/slim/data:all_files",
         "//tensorflow/contrib/slim/python/slim/nets:all_files",
         "//tensorflow/contrib/solvers:all_files",
+        "//tensorflow/contrib/sparsemax:all_files",
         "//tensorflow/contrib/specs:all_files",
         "//tensorflow/contrib/stat_summarizer:all_files",
         "//tensorflow/contrib/tensor_forest:all_files",
|
|||||||
visibility = [":__subpackages__"],
|
visibility = [":__subpackages__"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
#load(
|
||||||
|
# "//third_party/mkl:build_defs.bzl",
|
||||||
|
# "if_mkl",
|
||||||
|
#)
|
||||||
|
|
||||||
|
#filegroup(
|
||||||
|
# name = "intel_binary_blob",
|
||||||
|
# data = if_mkl(
|
||||||
|
# [
|
||||||
|
# "//third_party/mkl:intel_binary_blob",
|
||||||
|
# ],
|
||||||
|
# ),
|
||||||
|
#)
|
||||||
|
|
||||||
# -------------------------------------------
|
# -------------------------------------------
|
||||||
# New rules should be added above this target.
|
# New rules should be added above this target.
|
||||||
# -------------------------------------------
|
# -------------------------------------------
|
||||||
|
@@ -47,6 +47,7 @@ py_library(
         "//tensorflow/contrib/slim",
         "//tensorflow/contrib/slim:nets",
         "//tensorflow/contrib/solvers:solvers_py",
+        "//tensorflow/contrib/sparsemax:sparsemax_py",
         "//tensorflow/contrib/specs",
         "//tensorflow/contrib/stat_summarizer:stat_summarizer_py",
         "//tensorflow/contrib/tensor_forest:init_py",
@@ -49,6 +49,7 @@ from tensorflow.contrib import rnn
 from tensorflow.contrib import seq2seq
 from tensorflow.contrib import slim
 from tensorflow.contrib import solvers
+from tensorflow.contrib import sparsemax
 from tensorflow.contrib import stat_summarizer
 from tensorflow.contrib import tensor_forest
 from tensorflow.contrib import tensorboard
@@ -170,7 +170,8 @@ if (tensorflow_ENABLE_GPU)

   # add cudnn
   include_directories(${CUDNN_HOME})
-  set(CUDA_LIBRARIES ${CUDA_LIBRARIES} ${CUDNN_HOME}/lib/x64/cudnn.lib)
+  set(CUDA_LIBRARIES ${CUDA_LIBRARIES} ${CUDA_CUDA_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_CUFFT_LIBRARIES}
+      ${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDNN_HOME}/lib/x64/cudnn.lib)

   # create cuda_config.h
   FILE(WRITE ${tensorflow_source_dir}/third_party/gpus/cuda/cuda_config.h
@@ -179,6 +180,7 @@ if (tensorflow_ENABLE_GPU)
       "#define TF_CUDA_CAPABILITIES CudaVersion(\"3.0\"),CudaVersion(\"3.5\"),CudaVersion(\"5.2\")\n"
       "#define TF_CUDA_VERSION \"64_80\"\n"
       "#define TF_CUDNN_VERSION \"64_5\"\n"
+      "#define TF_CUDA_TOOLKIT_PATH \"${CUDA_TOOLKIT_ROOT_DIR}\"\n"
       "#endif // CUDA_CUDA_CONFIG_H_\n"
   )
@@ -71,7 +71,7 @@ foreach(tf_cc_op_lib_name ${tf_cc_op_lib_names})
       COMMAND ${tf_cc_op_lib_name}_gen_cc ${cc_ops_target_dir}/${tf_cc_op_lib_name}.h ${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc ${tensorflow_source_dir}/tensorflow/cc/ops/op_gen_overrides.pbtxt ${cc_ops_include_internal}
       DEPENDS ${tf_cc_op_lib_name}_gen_cc create_cc_ops_header_dir
   )

   list(APPEND tf_cc_ops_generated_files ${cc_ops_target_dir}/${tf_cc_op_lib_name}.h)
   list(APPEND tf_cc_ops_generated_files ${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc)
   list(APPEND tf_cc_ops_generated_files ${cc_ops_target_dir}/${tf_cc_op_lib_name}_internal.h)
@@ -79,6 +79,7 @@ foreach(tf_cc_op_lib_name ${tf_cc_op_lib_names})
 endforeach()

+
 ########################################################
 # tf_cc_ops library
 ########################################################
@@ -372,6 +372,9 @@ add_python_module("tensorflow/contrib/slim/python/slim/nets")
 add_python_module("tensorflow/contrib/solvers")
 add_python_module("tensorflow/contrib/solvers/python")
 add_python_module("tensorflow/contrib/solvers/python/ops")
+add_python_module("tensorflow/contrib/sparsemax")
+add_python_module("tensorflow/contrib/sparsemax/python")
+add_python_module("tensorflow/contrib/sparsemax/python/ops")
 add_python_module("tensorflow/contrib/specs")
 add_python_module("tensorflow/contrib/specs/python")
 add_python_module("tensorflow/contrib/stat_summarizer")
@@ -102,7 +102,12 @@ class GMM(estimator.Estimator):
     results = self.evaluate(input_fn=input_fn, batch_size=batch_size,
                             steps=steps)
     return np.sum(results[GMM.SCORES])

+  def weights(self):
+    """Returns the cluster weights."""
+    return checkpoint_utils.load_variable(
+        self.model_dir, gmm_ops.GmmAlgorithm.CLUSTERS_WEIGHT)
+
   def clusters(self):
     """Returns cluster centers."""
     clusters = checkpoint_utils.load_variable(
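A hedged usage sketch of the accessor added above; `gmm_lib`, `num_centers`, `initial_means`, and `input_fn` are stand-ins borrowed from the tests later in this diff:

```python
# After fitting, the mixture weights (the 'alphas' variable) can be read
# back from the checkpoint via the new accessor:
gmm = gmm_lib.GMM(num_centers, initial_clusters=initial_means, random_seed=4)
gmm.fit(input_fn=input_fn, steps=0)
alphas = gmm.weights()   # numpy array; alphas.shape == (num_centers,)
```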
@@ -92,6 +92,7 @@ def _init_clusters_random(data, num_clusters, random_seed):

 class GmmAlgorithm(object):
   """Tensorflow Gaussian mixture model clustering class."""
+  CLUSTERS_WEIGHT = 'alphas'
   CLUSTERS_VARIABLE = 'clusters'
   CLUSTERS_COVS_VARIABLE = 'clusters_covs'
@@ -187,11 +188,13 @@ class GmmAlgorithm(object):
           array_ops.expand_dims(array_ops.diag_part(cov), 0),
           [self._num_classes, 1])
     self._covs = variables.Variable(
-        covs, name='clusters_covs', validate_shape=False)
+        covs, name=self.CLUSTERS_COVS_VARIABLE, validate_shape=False)
     # Mixture weights, representing the probability that a randomly
     # selected unobservable data (in EM terms) was generated by component k.
     self._alpha = variables.Variable(
-        array_ops.tile([1.0 / self._num_classes], [self._num_classes]))
+        array_ops.tile([1.0 / self._num_classes], [self._num_classes]),
+        name=self.CLUSTERS_WEIGHT,
+        validate_shape=False)

   def training_ops(self):
     """Returns the training operation."""
@@ -109,6 +109,16 @@ class GMMTest(test.TestCase):
               np.linalg.inv(covs[assignments[r]])), points[r, :] -
                       means[assignments[r]])))
     return (points, assignments, scores)

+  def test_weights(self):
+    """Tests the shape of the weights."""
+    gmm = gmm_lib.GMM(self.num_centers,
+                      initial_clusters=self.initial_means,
+                      random_seed=4,
+                      config=run_config.RunConfig(tf_random_seed=2))
+    gmm.fit(input_fn=self.input_fn(), steps=0)
+    weights = gmm.weights()
+    self.assertAllEqual(list(weights.shape), [self.num_centers])
+
   def test_clusters(self):
     """Tests the shape of the clusters."""
@@ -480,6 +480,7 @@ py_test(
     size = "medium",
     srcs = ["python/learn/estimators/estimator_test.py"],
     srcs_version = "PY2AND3",
+    tags = ["manual"],
     deps = [
         ":learn",
         "//tensorflow/contrib/framework:framework_py",
@@ -191,6 +191,9 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None):
   if not dnn_feature_columns:
     dnn_logits = None
   else:
+    if not dnn_hidden_units:
+      raise ValueError(
+          "dnn_hidden_units must be defined when dnn_feature_columns is specified.")
     dnn_partitioner = (
         partitioned_variables.min_max_variable_partitioner(
             max_partitions=num_ps_replicas))
@@ -241,6 +241,26 @@ class DNNLinearCombinedClassifierTest(test.TestCase):
           dnn_feature_columns=None,
           dnn_hidden_units=[3, 3])

+  def testNoDnnHiddenUnits(self):
+    def _input_fn():
+      return {
+          'age':
+              constant_op.constant([1]),
+          'language':
+              sparse_tensor.SparseTensor(
+                  values=['english'], indices=[[0, 0]], dense_shape=[1, 1])
+      }, constant_op.constant([[1]])
+
+    language = feature_column.sparse_column_with_hash_bucket('language', 100)
+    age = feature_column.real_valued_column('age')
+
+    with self.assertRaisesRegexp(
+        ValueError,
+        'dnn_hidden_units must be defined when dnn_feature_columns is specified'):
+      classifier = dnn_linear_combined.DNNLinearCombinedClassifier(
+          dnn_feature_columns=[age, language])
+      classifier.fit(input_fn=_input_fn, steps=2)
+
   def testEmbeddingMultiplier(self):
     embedding_language = feature_column.embedding_column(
         feature_column.sparse_column_with_hash_bucket('language', 10),
@@ -274,10 +274,10 @@ def bidirectional_rnn(cell_fw,
     output_bw = _reverse_seq(tmp, sequence_length)
   # Concat each of the forward/backward outputs
   outputs = [
-      array_ops_.concat_v2([fw, bw], 1) for fw, bw in zip(output_fw, output_bw)
+      array_ops_.concat([fw, bw], 1) for fw, bw in zip(output_fw, output_bw)
   ]

-  return outputs, array_ops_.concat_v2([state_fw, state_bw], 1)
+  return outputs, array_ops_.concat([state_fw, state_bw], 1)


 # End of TensorFlow 0.7
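A minimal sketch of the rename applied above: `tf.concat_v2(values, axis)` becomes `tf.concat(values, axis)`, taking its arguments in NumPy order (stand-in tensors, not part of the diff):

```python
import tensorflow as tf

fw = tf.zeros([2, 3])   # stand-ins for forward/backward RNN outputs
bw = tf.ones([2, 3])
merged = tf.concat([fw, bw], 1)   # was tf.concat_v2([fw, bw], 1); shape [2, 6]
```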
@@ -59,7 +59,7 @@ def embedding_lookup(params, ids, name='embedding_lookup'):
     ids_flat = array_ops_.reshape(
         ids, math_ops.reduce_prod(shape, keep_dims=True))
     embeds_flat = nn.embedding_lookup(params, ids_flat, name)
-    embed_shape = array_ops_.concat_v2([shape, [-1]], 0)
+    embed_shape = array_ops_.concat([shape, [-1]], 0)
     embeds = array_ops_.reshape(embeds_flat, embed_shape)
     embeds.set_shape(ids.get_shape().concatenate(params.get_shape()[1:]))
     return embeds
@@ -427,7 +427,6 @@ def sparse_softmax_cross_entropy(logits, labels, weights=1.0, scope=None):
   with ops.name_scope(scope, "sparse_softmax_cross_entropy_loss",
                       [logits, labels, weights]) as scope:
     labels = array_ops.reshape(labels, shape=[array_ops.shape(labels)[0]])
-    weights = array_ops.squeeze(weights)

     losses = nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                          logits=logits,
@@ -243,6 +243,34 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
       expected_value = 400.0 * label_smoothing / 3.0
       self.assertAlmostEqual(loss.eval(), expected_value, 3)

+  def testLossWithDynamicallyShapedWeights1D(self):
+    logits = constant_op.constant([[10.0, 0.0, 0.0],
+                                   [0.0, 10.0, 0.0],
+                                   [0.0, 0.0, 10.0]])
+    labels = constant_op.constant([[0, 0, 1],
+                                   [1, 0, 0],
+                                   [0, 1, 0]])
+    weights = [2.3, 2.4, 2.5]
+    weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None])
+    loss = loss_ops.softmax_cross_entropy(logits, labels, weights_placeholder)
+    with self.test_session() as sess:
+      loss = sess.run(loss, {weights_placeholder: weights})
+      self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
+
+  def testLossWithDynamicallyShapedWeights2D(self):
+    logits = constant_op.constant([[10.0, 0.0, 0.0],
+                                   [0.0, 10.0, 0.0],
+                                   [0.0, 0.0, 10.0]])
+    labels = constant_op.constant([[0, 0, 1],
+                                   [1, 0, 0],
+                                   [0, 1, 0]])
+    weights = [[2.3], [2.4], [2.5]]
+    weights_placeholder = array_ops.placeholder(dtypes.float32,
+                                                shape=[None, None])
+    loss = loss_ops.softmax_cross_entropy(logits, labels, weights_placeholder)
+    with self.test_session() as sess:
+      loss = sess.run(loss, {weights_placeholder: weights})
+      self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
+

 class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
@@ -445,6 +473,30 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
       loss_ops.sparse_softmax_cross_entropy(
           logits, labels, weights=weights).eval()

+  def testLossWithDynamicallyShapedWeights1D(self):
+    logits = constant_op.constant([[10.0, 0.0, 0.0],
+                                   [0.0, 10.0, 0.0],
+                                   [0.0, 0.0, 10.0]])
+    labels = constant_op.constant([2, 0, 1])
+    weights = [2.3, 2.4, 2.5]
+    weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None])
+    loss = loss_ops.sparse_softmax_cross_entropy(
+        logits, labels, weights_placeholder)
+    with self.test_session() as sess:
+      loss = sess.run(loss, {weights_placeholder: weights})
+      self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
+
+  def testLossWithDynamicallyShapedWeights2D(self):
+    logits = constant_op.constant([[10.0, 0.0, 0.0],
+                                   [0.0, 10.0, 0.0],
+                                   [0.0, 0.0, 10.0]])
+    labels = constant_op.constant([2, 0, 1])
+    weights = [[2.3], [2.4], [2.5]]
+    weights_placeholder = array_ops.placeholder(dtypes.float32,
+                                                shape=[None, None])
+    loss = loss_ops.sparse_softmax_cross_entropy(
+        logits, labels, weights_placeholder)
+    with self.test_session() as sess:
+      loss = sess.run(loss, {weights_placeholder: weights})
+      self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
+

 class SigmoidCrossEntropyLossTest(test.TestCase):
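A short check of the expected values these tests assert: each example puts logit 10.0 on a non-target class and 0.0 on the target, so the per-example cross entropy is about 10.0, and the weighted loss about mean(weights) * 10.0. A standalone NumPy verification (illustration only):

```python
import numpy as np

# Per-example cross entropy with target logit 0.0 against a competing 10.0:
# -log(e^0 / (e^10 + e^0 + e^0)) ~= 10.0 (the two small terms are negligible).
ce = -np.log(np.exp(0.0) / (np.exp(10.0) + 2 * np.exp(0.0)))
print(ce)                                # ~= 10.000
print(np.average([2.3, 2.4, 2.5]) * ce)  # ~= 24.0, the asserted weighted loss
```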
@@ -84,7 +84,7 @@ cuda_py_test(

 tf_cuda_cc_test(
     name = "nccl_manager_test",
-    size = "small",
+    size = "medium",
     srcs = if_cuda(
         [
             "kernels/nccl_manager.cc",
@@ -95,7 +95,7 @@ class RNNCellTest(test.TestCase):
       input_size = 4
       feature_size = 2
       frequency_skip = 1
-      num_shifts = (input_size - feature_size) / frequency_skip + 1
+      num_shifts = (input_size - feature_size) // frequency_skip + 1
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([batch_size, input_size])
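A two-line illustration of why the operator change above matters under Python 3 (or `from __future__ import division`): `/` yields a float, while the result is used as a count:

```python
input_size, feature_size, frequency_skip = 4, 2, 1
print((input_size - feature_size) / frequency_skip + 1)    # 3.0 -- a float
print((input_size - feature_size) // frequency_skip + 1)   # 3 -- usable as a shape/count
```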
@@ -880,7 +880,7 @@ names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({

 # Create the summary ops such that they also print out to std output:
 summary_ops = []
-for metric_name, metric_value in metrics_to_values.iteritems():
+for metric_name, metric_value in names_to_values.iteritems():
   op = tf.summary.scalar(metric_name, metric_value)
   op = tf.Print(op, [metric_value], metric_name)
   summary_ops.append(op)
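The fix above makes the documented loop read from the dict actually defined by `aggregate_metric_map` (`names_to_values`); `.iteritems()` is Python 2 only, so a version-agnostic sketch of the same snippet would be:

```python
summary_ops = []
for metric_name, metric_value in names_to_values.items():  # .iteritems() on Python 2
    op = tf.summary.scalar(metric_name, metric_value)
    op = tf.Print(op, [metric_value], metric_name)
    summary_ops.append(op)
```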
tensorflow/contrib/sparsemax/BUILD (new file, 76 lines)
@@ -0,0 +1,76 @@
+# Description:
+#   Contains ops to train linear models on top of TensorFlow.
+#   APIs here are meant to evolve over time.
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//visibility:public"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
+load(
+    "//tensorflow:tensorflow.bzl",
+    "tf_custom_op_library",
+    "tf_py_test",
+)
+load(
+    "//tensorflow/core:platform/default/build_config.bzl",
+    "tf_kernel_tests_linkstatic",
+)
+
+py_library(
+    name = "sparsemax_py",
+    srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
+    srcs_version = "PY2AND3",
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/contrib/util:util_py",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:nn",
+    ],
+)
+
+cuda_py_tests(
+    name = "sparsemax_test",
+    size = "small",
+    srcs = ["python/kernel_tests/sparsemax_test.py"],
+    additional_deps = [
+        ":sparsemax_py",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:gradients",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+cuda_py_tests(
+    name = "sparsemax_loss_test",
+    size = "medium",
+    srcs = ["python/kernel_tests/sparsemax_loss_test.py"],
+    additional_deps = [
+        ":sparsemax_py",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:gradients",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)
tensorflow/contrib/sparsemax/__init__.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Module that implements sparsemax and sparsemax loss, see [1].
+
+[1] https://arxiv.org/abs/1602.02068
+
+## Sparsemax
+
+@@sparsemax
+@@sparsemax_loss
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.sparsemax.python.ops.sparsemax import sparsemax
+from tensorflow.contrib.sparsemax.python.ops.sparsemax_loss \
+    import sparsemax_loss
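A hedged usage sketch of the new module (import path per the `__init__.py` above; the expected output is hand-computed from the sparsemax definition in the referenced paper):

```python
import tensorflow as tf
from tensorflow.contrib.sparsemax import sparsemax

logits = tf.constant([[1.0, 2.0, 2.5]])
probs = sparsemax(logits)   # like softmax, but can assign exact zeros

with tf.Session() as sess:
    print(sess.run(probs))  # approximately [[0. 0.25 0.75]] -- sums to 1, sparse
```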
@ -0,0 +1,224 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for SparsemaxLossOp."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.sparsemax import sparsemax, sparsemax_loss
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test

test_obs = 10


class SparsemaxLossTest(test.TestCase):

  def _np_sparsemax(self, z):
    z = z - np.mean(z, axis=1)[:, np.newaxis]

    # sort z
    z_sorted = np.sort(z, axis=1)[:, ::-1]

    # calculate k(z)
    z_cumsum = np.cumsum(z_sorted, axis=1)
    k = np.arange(1, z.shape[1] + 1)
    z_check = 1 + k * z_sorted > z_cumsum
    # use argmax to get the index by row as .nonzero() doesn't
    # take an axis argument. np.argmax returns the first index, but the last
    # index is required here, so reverse the rows and use
    # `z.shape[1]` to compensate for the reversal afterwards.
    k_z = z.shape[1] - np.argmax(z_check[:, ::-1], axis=1)

    # calculate tau(z)
    tau_sum = z_cumsum[np.arange(0, z.shape[0]), k_z - 1]
    tau_z = ((tau_sum - 1) / k_z).reshape(-1, 1)

    # calculate p
    return np.maximum(0, z - tau_z)

  def _np_sparsemax_loss(self, z, q):
    z = z - np.mean(z, axis=1)[:, np.newaxis]

    # Calculate q^T * z
    z_k = np.sum(q * z, axis=1)

    # calculate sum over S(z)
    p = self._np_sparsemax(z)
    s = p > 0
    # z_i^2 - tau(z)^2 = p_i (2 * z_i - p_i) for i \in S(z)
    S_sum = np.sum(s * p * (2 * z - p), axis=1)

    # because q is binary, sum([q_1^2, q_2^2, ...]) is just sum(q)
    q_norm = np.sum(q, axis=1)

    return -z_k + 0.5 * S_sum + 0.5 * q_norm

  def _np_sparsemax_loss_grad(self, z, q):
    # chain rule
    grad = 1

    return grad * (-q + self._np_sparsemax(z))

  def _tf_sparsemax(self, z, dtype, use_gpu):
    with self.test_session(use_gpu=use_gpu):
      tf_sparsemax_op = sparsemax(z.astype(dtype))
      tf_sparsemax_out = tf_sparsemax_op.eval()

    return tf_sparsemax_op, tf_sparsemax_out

  def _tf_sparsemax_loss(self, z, q, dtype, use_gpu):
    z = z.astype(dtype)
    q = q.astype(dtype)

    with self.test_session(use_gpu=use_gpu):
      tf_sparsemax_op = sparsemax(z)
      tf_loss_op = sparsemax_loss(z, tf_sparsemax_op, q)
      tf_loss_out = tf_loss_op.eval()

    return tf_loss_op, tf_loss_out

  def _test_sparsemax_loss_against_numpy(self, dtype, random, use_gpu):
    """check sparsemax-loss kernel against numpy"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10))
    q = np.zeros((test_obs, 10))
    q[np.arange(0, test_obs), random.randint(0, 10, size=test_obs)] = 1

    tf_loss_op, tf_loss_out = self._tf_sparsemax_loss(z, q, dtype, use_gpu)
    np_loss = self._np_sparsemax_loss(z, q).astype(dtype)

    self.assertAllCloseAccordingToType(np_loss, tf_loss_out,
                                       half_atol=1e-2, half_rtol=5e-3)
    self.assertShapeEqual(np_loss, tf_loss_op)

  def _test_constant_add(self, dtype, random, use_gpu):
    """check sparsemax-loss proposition 3"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10))
    c = random.uniform(low=-3, high=3, size=(test_obs, 1))
    q = np.zeros((test_obs, 10))
    q[np.arange(0, test_obs), random.randint(0, 10, size=test_obs)] = 1

    _, tf_loss_zpc = self._tf_sparsemax_loss(
      z + c, q, dtype, use_gpu
    )

    _, tf_loss_z = self._tf_sparsemax_loss(
      z, q, dtype, use_gpu
    )

    self.assertAllCloseAccordingToType(tf_loss_zpc, tf_loss_z,
                                       float_atol=5e-6, float_rtol=5e-6,
                                       half_atol=1e-2, half_rtol=1e-2)

  def _test_sparsemax_loss_positive(self, dtype, random, use_gpu):
    """check sparsemax-loss proposition 4"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10))
    q = np.zeros((test_obs, 10))
    q[np.arange(0, test_obs), random.randint(0, 10, size=test_obs)] = 1

    tf_loss_op, tf_loss_out = self._tf_sparsemax_loss(z, q, dtype, use_gpu)

    self.assertAllCloseAccordingToType(np.abs(tf_loss_out), tf_loss_out)
    self.assertShapeEqual(np.zeros(test_obs), tf_loss_op)

  def _test_sparsemax_loss_zero(self, dtype, random, use_gpu):
    """check sparsemax-loss proposition 5"""
    # construct z and q, such that z_k >= 1 + max_{j!=k} z_j holds for
    # delta_0 = 1.
    z = random.uniform(low=-3, high=3, size=(test_obs, 10))
    z[:, 0] = np.max(z, axis=1) + 1.05

    q = np.zeros((test_obs, 10))
    q[:, 0] = 1

    tf_loss_op, tf_loss_out = self._tf_sparsemax_loss(z, q, dtype, use_gpu)
    tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu)

    self.assertAllCloseAccordingToType(np.zeros(test_obs), tf_loss_out)
    self.assertShapeEqual(np.zeros(test_obs), tf_loss_op)

    self.assertAllCloseAccordingToType(q, tf_sparsemax_out)
    self.assertShapeEqual(q, tf_sparsemax_op)

  def _test_gradient_against_estimate(self, dtype, random, use_gpu):
    """check sparsemax-loss Rop, against estimated-loss Rop"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype)
    q = np.zeros((test_obs, 10)).astype(dtype)
    q[np.arange(0, test_obs), random.randint(0, 10, size=test_obs)] = 1

    logits = array_ops.placeholder(dtype, name='z')
    sparsemax_op = sparsemax(logits)
    loss_op = sparsemax_loss(logits, sparsemax_op, q)

    with self.test_session(use_gpu=use_gpu):
      err = gradient_checker.compute_gradient_error(
        logits, z.shape,
        loss_op, (test_obs, ),
        x_init_value=z, delta=1e-9
      )

    self.assertLess(err, 1e-4)

  def _test_gradient_against_numpy(self, dtype, random, use_gpu):
    """check sparsemax-loss Rop, against numpy Rop"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10))
    q = np.zeros((test_obs, 10))
    q[np.arange(0, test_obs), random.randint(0, 10, size=test_obs)] = 1

    logits = constant_op.constant(z.astype(dtype), name='z')
    sparsemax_op = sparsemax(logits)
    loss_op = sparsemax_loss(logits, sparsemax_op, q.astype(dtype))
    loss_grad_op = gradients_impl.gradients(loss_op, [logits])[0]

    with self.test_session(use_gpu=use_gpu):
      tf_grad = loss_grad_op.eval()
      np_grad = self._np_sparsemax_loss_grad(z, q).astype(dtype)

    self.assertAllCloseAccordingToType(np_grad, tf_grad,
                                       half_atol=1e-2, half_rtol=5e-3)
    self.assertShapeEqual(np_grad, loss_grad_op)

  def _test_dtype(self, dtype):
    random = np.random.RandomState(1)

    self._test_sparsemax_loss_against_numpy(dtype, random, use_gpu=False)

    self._test_constant_add(dtype, random, use_gpu=False)

    self._test_sparsemax_loss_positive(dtype, random, use_gpu=False)

    self._test_sparsemax_loss_zero(dtype, random, use_gpu=False)

    # sparsemax is not a smooth function so gradient estimation is only
    # possible for float64.
    if dtype == 'float64':
      self._test_gradient_against_estimate(dtype, random, use_gpu=False)

      self._test_gradient_against_numpy(dtype, random, use_gpu=False)

  def testFloat(self):
    self._test_dtype('float32')

  def testDouble(self):
    self._test_dtype('float64')

if __name__ == "__main__":
  test.main()
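As a hand-checkable instance of the `_np_sparsemax` reference above (numbers chosen only for illustration; the mean-shift used in the test is omitted since sparsemax is invariant to adding a constant): for `z = [1.0, 1.2, -2.0]` the support has size 2, `tau = (1.2 + 1.0 - 1) / 2 = 0.6`, and the projection is `[0.4, 0.6, 0.0]`:

```python
import numpy as np

z = np.array([[1.0, 1.2, -2.0]])
z_sorted = np.sort(z, axis=1)[:, ::-1]                      # [[ 1.2,  1.0, -2.0]]
z_cumsum = np.cumsum(z_sorted, axis=1)                      # [[ 1.2,  2.2,  0.2]]
k = np.arange(1, z.shape[1] + 1)
z_check = 1 + k * z_sorted > z_cumsum                       # [[ True, True, False]]
k_z = z.shape[1] - np.argmax(z_check[:, ::-1], axis=1)      # [2]
tau = (z_cumsum[np.arange(z.shape[0]), k_z - 1] - 1) / k_z  # [0.6]
print(np.maximum(0, z - tau.reshape(-1, 1)))                # [[0.4, 0.6, 0.0]]
```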
@@ -0,0 +1,252 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for SparsemaxOp."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.sparsemax import sparsemax
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test

test_obs = 10


class SparsemaxTest(test.TestCase):

  def _np_sparsemax(self, z):
    z = z - np.mean(z, axis=1)[:, np.newaxis]

    # sort z
    z_sorted = np.sort(z, axis=1)[:, ::-1]

    # calculate k(z)
    z_cumsum = np.cumsum(z_sorted, axis=1)
    k = np.arange(1, z.shape[1] + 1)
    z_check = 1 + k * z_sorted > z_cumsum
    # use argmax to get the index by row as .nonzero() doesn't
    # take an axis argument. np.argmax returns the first index, but the last
    # index is required here, so reverse the rows and use
    # `z.shape[1]` to compensate for the reversal afterwards.
    k_z = z.shape[1] - np.argmax(z_check[:, ::-1], axis=1)

    # calculate tau(z)
    tau_sum = z_cumsum[np.arange(0, z.shape[0]), k_z - 1]
    tau_z = ((tau_sum - 1) / k_z).reshape(-1, 1)

    # calculate p
    return np.maximum(0, z - tau_z)

  def _np_sparsemax_grad(self, z):
    # chain rule
    grad = np.ones_like(z)

    # Construct S(z)
    probability = self._np_sparsemax(z)
    support = probability > 0

    # Calculate \hat{v}, which will be a vector (scalar for each z)
    v_hat = np.sum(grad * support, axis=1) / np.sum(support, axis=1)

    # Calculates J(z) * v
    return support * (grad - v_hat[:, np.newaxis])

  def _tf_sparsemax(self, z, dtype, use_gpu):
    with self.test_session(use_gpu=use_gpu):
      tf_sparsemax_op = sparsemax(z.astype(dtype))
      tf_sparsemax_out = tf_sparsemax_op.eval()

    return tf_sparsemax_op, tf_sparsemax_out

  def _test_sparsemax_against_numpy(self, dtype, random, use_gpu):
    """check sparsemax kernel against numpy"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10))

    tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu)
    p_sparsemax = self._np_sparsemax(z).astype(dtype)

    self.assertAllCloseAccordingToType(p_sparsemax, tf_sparsemax_out,
                                       half_atol=5e-3)
    self.assertShapeEqual(p_sparsemax, tf_sparsemax_op)

  def _test_sparsemax_of_zero(self, dtype, random, use_gpu):
    """check sparsemax proposition 1, part 1"""
    z = np.zeros((1, 10))

    tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu)
    p_sparsemax = np.ones_like(z, dtype=dtype) / z.size

    self.assertAllCloseAccordingToType(p_sparsemax, tf_sparsemax_out)
    self.assertShapeEqual(p_sparsemax, tf_sparsemax_op)

  def _test_sparsemax_of_inf(self, dtype, random, use_gpu):
    """check sparsemax proposition 1, part 2"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10))

    # assume |A(z)| = 1, as z is continuous random
    z_sort_arg = np.argsort(z, axis=1)[:, ::-1]
    z_sort = np.sort(z, axis=-1)[:, ::-1]
    gamma_z = z_sort[:, 0] - z_sort[:, 1]
    epsilon = (0.99 * gamma_z * 1).reshape(-1, 1)

    # construct the expected 1_A(z) array
    p_expected = np.zeros((test_obs, 10), dtype=dtype)
    p_expected[np.arange(0, test_obs), z_sort_arg[:, 0]] = 1

    tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(
      (1 / epsilon) * z, dtype, use_gpu
    )

    self.assertAllCloseAccordingToType(p_expected, tf_sparsemax_out)
    self.assertShapeEqual(p_expected, tf_sparsemax_op)

  def _test_constant_add(self, dtype, random, use_gpu):
    """check sparsemax proposition 2"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype)
    c = random.uniform(low=-3, high=3, size=(test_obs, 1)).astype(dtype)

    _, tf_sparsemax_zpc = self._tf_sparsemax(
      z + c, dtype, use_gpu
    )

    _, tf_sparsemax_z = self._tf_sparsemax(
      z, dtype, use_gpu
    )

    self.assertAllCloseAccordingToType(tf_sparsemax_zpc, tf_sparsemax_z,
                                       half_atol=5e-3)

  def _test_permutation(self, dtype, random, use_gpu):
    """check sparsemax proposition 3"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10))
    _, p = self._tf_sparsemax(z, dtype, use_gpu)

    for i in range(test_obs):
      per = random.permutation(10)

      tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(
        z[i, per].reshape(1, -1), dtype, use_gpu
      )
      p_expected = p[i, per].reshape(1, -1)

      self.assertAllCloseAccordingToType(p_expected, tf_sparsemax_out,
                                         half_atol=5e-3)
      self.assertShapeEqual(p_expected, tf_sparsemax_op)

  def _test_difference(self, dtype, random, use_gpu):
    """check sparsemax proposition 4"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10))
    _, p = self._tf_sparsemax(z, dtype, use_gpu)

    etol = {'float16': 1e-2, 'float32': 1e-6, 'float64': 1e-9}[dtype]

    for val in range(0, test_obs):
      for i in range(0, 10):
        for j in range(0, 10):
          # check the condition; the opposite pair will be checked anyway
          if z[val, i] > z[val, j]:
            continue

          self.assertTrue(
            0 <= p[val, j] - p[val, i] <= z[val, j] - z[val, i] + etol,
            "0 <= %.10f <= %.10f" % (
              p[val, j] - p[val, i], z[val, j] - z[val, i] + etol
            )
          )

  def _test_two_dimensional(self, dtype, random, use_gpu):
    """check the two-dimensional sparsemax case"""
    t = np.linspace(-2, 2, test_obs, dtype=dtype)
    z = np.vstack([
      t, np.zeros(test_obs, dtype=dtype)
    ]).T

    tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu)

    p0_expected = np.select([t < -1, t <= 1, t > 1], [0, (t + 1) / 2, 1])

    self.assertAllCloseAccordingToType(p0_expected, tf_sparsemax_out[:, 0])
    self.assertAllCloseAccordingToType(1 - p0_expected, tf_sparsemax_out[:, 1])
    self.assertShapeEqual(z, tf_sparsemax_op)

  def _test_gradient_against_estimate(self, dtype, random, use_gpu):
    """check sparsemax Rop, against estimated Rop"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype)

    logits = array_ops.placeholder(dtype, name='z')
    sparsemax_op = sparsemax(logits)

    with self.test_session(use_gpu=use_gpu):
      err = gradient_checker.compute_gradient_error(
        logits, z.shape,
        sparsemax_op, z.shape,
        x_init_value=z, delta=1e-9
      )

    self.assertLess(err, 1e-4)

  def _test_gradient_against_numpy(self, dtype, random, use_gpu):
    """check sparsemax Rop, against numpy Rop"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype)

    logits = constant_op.constant(z, name='z')
    sparsemax_op = sparsemax(logits)
    sparsemax_grad_op = gradients_impl.gradients(sparsemax_op, [logits])[0]

    with self.test_session(use_gpu=use_gpu):
      tf_grad = sparsemax_grad_op.eval()
      np_grad = self._np_sparsemax_grad(z)

    self.assertAllCloseAccordingToType(np_grad, tf_grad)
    self.assertShapeEqual(np_grad, sparsemax_grad_op)

  def _test_dtype(self, dtype):
    random = np.random.RandomState(1)

    self._test_sparsemax_against_numpy(dtype, random, use_gpu=False)

    self._test_sparsemax_of_zero(dtype, random, use_gpu=False)

    self._test_sparsemax_of_inf(dtype, random, use_gpu=False)

    self._test_constant_add(dtype, random, use_gpu=False)

    self._test_permutation(dtype, random, use_gpu=False)

    self._test_difference(dtype, random, use_gpu=False)

    self._test_two_dimensional(dtype, random, use_gpu=False)

    # sparsemax is not a smooth function so gradient estimation is only
    # possible for float64.
    if dtype == 'float64':
      self._test_gradient_against_estimate(dtype, random, use_gpu=False)

      self._test_gradient_against_numpy(dtype, random, use_gpu=False)

  def testFloat(self):
    self._test_dtype('float32')

  def testDouble(self):
    self._test_dtype('float64')

if __name__ == "__main__":
  test.main()
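The `_np_sparsemax_grad` reference above is the sparsemax Jacobian-vector product from [1], restated here in LaTeX (not a new result); $s$ is the 0/1 indicator of the support $S(z)$ and $v$ is the incoming gradient:

```latex
J_{\text{sparsemax}}(z)\, v \;=\; s \odot v \;-\; \frac{s^\top v}{|S(z)|}\, s
```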
74
tensorflow/contrib/sparsemax/python/ops/sparsemax.py
Normal file
@@ -0,0 +1,74 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sparsemax op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.util import loader
from tensorflow.python.platform import resource_loader
from tensorflow.python.framework import ops, dtypes
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn


def sparsemax(logits, name=None):
  """Computes sparsemax activations [1].

  For each batch `i` and class `j` we have
    sparsemax[i, j] = max(logits[i, j] - tau(logits[i, :]), 0)

  [1]: https://arxiv.org/abs/1602.02068

  Args:
    logits: A `Tensor`. Must be one of the following types: `half`, `float32`,
      `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `logits`.
  """

  with ops.name_scope(name, "sparsemax", [logits]) as name:
    logits = ops.convert_to_tensor(logits, name="logits")
    obs = array_ops.shape(logits)[0]
    dims = array_ops.shape(logits)[1]

    z = logits - math_ops.reduce_mean(logits, axis=1)[:, array_ops.newaxis]

    # sort z
    z_sorted, _ = nn.top_k(z, k=dims)

    # calculate k(z)
    z_cumsum = math_ops.cumsum(z_sorted, axis=1)
    k = math_ops.range(
        1, math_ops.cast(dims, logits.dtype) + 1, dtype=logits.dtype
    )
    z_check = 1 + k * z_sorted > z_cumsum
    # because the z_check vector is always [1,1,...,1,0,0,...,0], finding the
    # (index + 1) of the last `1` is the same as just summing the number of 1s.
    k_z = math_ops.reduce_sum(math_ops.cast(z_check, dtypes.int32), axis=1)

    # calculate tau(z)
    indices = array_ops.stack([math_ops.range(0, obs), k_z - 1], axis=1)
    tau_sum = array_ops.gather_nd(z_cumsum, indices)
    tau_z = (tau_sum - 1) / math_ops.cast(k_z, logits.dtype)

    # calculate p
    return math_ops.maximum(
        math_ops.cast(0, logits.dtype),
        z - tau_z[:, array_ops.newaxis]
    )
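Written out, the graph above computes the standard sparsemax threshold from [1]; with $z_{(1)} \ge \dots \ge z_{(K)}$ the sorted logits, `k_z` and `tau_z` correspond to (restated from the paper, not new definitions):

```latex
k(z) = \max\Bigl\{\, k \in [K] : 1 + k\, z_{(k)} > \textstyle\sum_{j \le k} z_{(j)} \Bigr\},
\qquad
\tau(z) = \frac{\sum_{j \le k(z)} z_{(j)} - 1}{k(z)},
\qquad
\operatorname{sparsemax}_i(z) = \max\bigl(z_i - \tau(z),\, 0\bigr).
```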
59
tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py
Normal file
@@ -0,0 +1,59 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sparsemax Loss op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.util import loader
from tensorflow.python.platform import resource_loader
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


def sparsemax_loss(logits, sparsemax, labels, name=None):
  """Computes sparsemax loss function [1].

  [1]: https://arxiv.org/abs/1602.02068

  Args:
    logits: A `Tensor`. Must be one of the following types: `half`, `float32`,
      `float64`.
    sparsemax: A `Tensor`. Must have the same type as `logits`.
    labels: A `Tensor`. Must have the same type as `logits`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `logits`.
  """

  with ops.name_scope(name, "sparsemax_loss",
                      [logits, sparsemax, labels]) as name:
    logits = ops.convert_to_tensor(logits, name="logits")
    sparsemax = ops.convert_to_tensor(sparsemax, name="sparsemax")
    labels = ops.convert_to_tensor(labels, name="labels")

    shifted_logits = logits - \
        math_ops.reduce_mean(logits, axis=1)[:, array_ops.newaxis]

    # sum over support
    support = math_ops.cast(sparsemax > 0, sparsemax.dtype)
    sum_s = support * sparsemax * (shifted_logits - 0.5 * sparsemax)

    # - z_k + ||q||^2
    q_part = labels * (0.5 * labels - shifted_logits)

    return math_ops.reduce_sum(sum_s + q_part, axis=1)
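For reference, the `sum_s` and `q_part` terms above implement the sparsemax loss from [1]; using $p = \operatorname{sparsemax}(z)$ and the identity $z_j^2 - \tau(z)^2 = p_j(2 z_j - p_j)$ on the support $S(z)$ (restated from the paper, not a new derivation):

```latex
L_{\text{sparsemax}}(z; q)
  = -\, q^\top z
    + \tfrac{1}{2} \sum_{j \in S(z)} \bigl(z_j^2 - \tau(z)^2\bigr)
    + \tfrac{1}{2} \lVert q \rVert^2
  = \sum_j \Bigl[ p_j \bigl(z_j - \tfrac{1}{2} p_j\bigr)
                  + q_j \bigl(\tfrac{1}{2} q_j - z_j\bigr) \Bigr].
```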
@@ -77,6 +77,8 @@ load(
    "tf_opts_nortti_if_android",
    "cc_header_only_library",
)

#load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl")
load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu")
load("//tensorflow:tensorflow.bzl", "tf_cc_tests_gpu")
load("//tensorflow:tensorflow.bzl", "tf_version_info_genrule")
@@ -111,7 +113,10 @@ load(
    "//tensorflow/core:platform/default/build_config_root.bzl",
    "tf_cuda_tests_tags",
)

#load(
#    "//third_party/mkl:build_defs.bzl",
#    "if_mkl",
#)
# -----------------------------------------------------------------------------
# Public targets

@@ -1863,6 +1868,35 @@ tf_cc_tests(
    ],
)

#if_mkl(
#    tf_cc_test_mkl(
#        name = "mkl_related_tests",
#        size = "small",
#        srcs = ["graph/mkl_optimizer_merge_test.cc"],
#        linkstatic = tf_kernel_tests_linkstatic(),
#        deps = [
#            ":core",
#            ":core_cpu",
#            ":core_cpu_internal",
#            ":direct_session_internal",
#            ":framework",
#            ":framework_internal",
#            ":lib",
#            ":lib_internal",
#            ":ops",
#            ":protos_all_cc",
#            ":test",
#            ":test_main",
#            ":testlib",
#            "//third_party/eigen3",
#            "//tensorflow/cc:cc_ops",
#            "//tensorflow/cc:scope",
#            "//tensorflow/cc:sendrecv_ops",
#            "//tensorflow/core/kernels:ops_util",
#        ],
#    ),
#)

tf_cc_tests_gpu(
    name = "gpu_related_tests",
    size = "small",
596
tensorflow/core/graph/mkl_optimizer_merge.cc
Normal file
@@ -0,0 +1,596 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef INTEL_MKL
// This module implements node merging optimization on the graph.
// We process the nodes in the graph in reverse postorder
// (i.e. inputs before their downstream dependencies).
//
#include <set>
#include <vector>
#include <queue>
#include <utility>

#include "tensorflow/core/graph/mkl_optimizer_merge.h"

#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/common_runtime/function.h"

namespace tensorflow {

// How many hops do we search for a matching node in the backward dataflow
// graph? We use a maxhop of 10 based on empirical observations. Also, these
// are maxhops in the backward data-flow graph. Since the input of forward
// nodes (Conv2D) goes directly to backward nodes, we do not expect the
// hop-distance to be more than a few nodes.
static size_t kNodeMergeContextMaxDepth = 10;

// This optimization pass performs two tasks: merge
// nodes in the forward pass, and rewrite the gradient ops
// corresponding to merged forward ops.
//
// Merging nodes in the graph: Currently, it merges Conv2D+AddBias together.
//
// Rewriting nodes in the graph: This is needed in order to optimize
// gradient ops of Conv2D+AddBias. The gradient op of both Conv2D and
// MatMul is BiasAddGrad, and we need to rewrite BiasAddGrad into
// Conv2D-specific BiasAddGrad and MatMul-specific BiasAddGrad.
// This is a context-specific optimization, where the context is the
// forward operator that the BiasAddGrad corresponds to.
class NodeMergeRewritePass : public GraphOptimizationPass {
 public:
  NodeMergeRewritePass() {
    csinfo_.conv2d = "Conv2D";
    csinfo_.conv2dwithbias = "Conv2DWithBias";
    csinfo_.conv2dwithbiasbackpropbias = "Conv2DWithBiasBackpropBias";
    csinfo_.biasadd = "BiasAdd";
    csinfo_.matmul = "MatMul";
    csinfo_.biasaddgrad = "BiasAddGrad";

    minfo_.push_back({csinfo_.conv2d, csinfo_.biasadd, 0,
                      csinfo_.conv2dwithbias});

    // We use a maxhop of 10 based on empirical observations. Also, these are
    // maxhops in the backward data-flow graph. Since the input of forward
    // nodes (Conv2D) goes directly to backward nodes, we do not expect the
    // hop-distance to be more than a few nodes.
    rinfo_.push_back({csinfo_.biasaddgrad, csinfo_.conv2dwithbiasbackpropbias,
                      {csinfo_.conv2dwithbias, kNodeMergeContextMaxDepth}});
    rinfo_.push_back({csinfo_.biasaddgrad, csinfo_.conv2dwithbiasbackpropbias,
                      {csinfo_.conv2d, kNodeMergeContextMaxDepth}});
    // For now, we are rewriting BiasAddGrad to BiasAddGrad for MatMul. This is
    // because we do not have a separate Op for MatMulwithBias.
    rinfo_.push_back({csinfo_.biasaddgrad, csinfo_.biasaddgrad,
                      {csinfo_.matmul, kNodeMergeContextMaxDepth}});
  }

  // Standard interface to run optimization pass
  Status Run(const GraphOptimizationPassOptions& options);

  // Helper function which does most of the heavy lifting for node merge
  //
  // Extracts common functionality between the Run public interface and
  // the test interface.
  //
  // @return true, if and only if graph is mutated; false otherwise.
  bool RunPass(std::unique_ptr<Graph>* g);

 private:
  /// Structure to specify information used in node merge
  typedef struct {
    string pred;     // Predecessor node string
    string succ;     // Successor node string
    int    op;       // Which input slot of the successor node does the
                     // predecessor node feed?
    string newnode;  // Name of the node after merge
  } MergeInfo;

  /// Structure to specify information used in node rewrite
  typedef struct {
    string node;     // Name of the node to be rewritten
    string rewrite;  // New name of the node after rewrite
    typedef struct {
      string fwd;     // Node name in forward pass that this node
                      // corresponds to
      size_t maxhop;  // Maximum number of hops fwd is located
                      // from this node. If fwd is farther than maxhop
                      // then we do not rewrite the node.
    } ContextInfo;
    ContextInfo cinfo;  // Context for rewrite
  } RewriteInfo;

  /// Structure to store all constant strings
  typedef struct {
    string conv2d;
    string conv2dwithbias;
    string conv2dwithbiasbackpropbias;
    string biasadd;
    string matmul;
    string biasaddgrad;
  } ConstStringInfo;

  ConstStringInfo csinfo_;
  std::vector<MergeInfo> minfo_;
  std::vector<RewriteInfo> rinfo_;

 private:
  // Return a node that can be merged with the input node
  //
  // @return pointer to the node if we can find such a
  // node. Otherwise, it returns nullptr.
  Node* FindNodeForMerge(const Node* a) const;

  // Merge predecessor node with its successor.
  // Currently, we merge Conv2D with AddBias only.
  //
  // Input nodes succ and pred may be deleted if the call to
  // this function is successful. Attempts to use the pointers
  // after the call to this function may result in undefined behavior.
  //
  // @input g - input graph, succ - successor node, pred - predecessor node
  // @return Status::OK(), if merging is successful and supported.
  //         Returns appropriate Status error code otherwise.
  //         Graph is updated in case nodes are merged. Otherwise, it is
  //         not updated.
  Status MergeNode(std::unique_ptr<Graph>* g, Node* succ, Node* pred);

  // Is input node (n) a candidate for rewrite?
  //
  // @return true, if it can be rewritten; false, otherwise.
  bool IsApplicableRewriteNode(const Node* n) const;

  // Rewrites input node to a new node specified by its matching rewrite info.
  //
  // Method first searches matching rewrite info for input node and then
  // uses that info to rewrite.
  //
  // Input node may be deleted in case of rewrite. Attempts to use the node
  // after the call can result in undefined behavior.
  //
  // @input g - input graph, n - Node to be rewritten
  // @return Status::OK(), if the input node is rewritten;
  //         Returns appropriate Status error code otherwise.
  //         Graph is updated in case the input node is rewritten.
  //         Otherwise, it is not updated.
  Status RewriteNode(std::unique_ptr<Graph>* g, Node* n);

  // Helper function that searches for the matching rewriteinfo for the node.
  // Implements breadth-first search in the data dependence graph for the
  // gradient op in the backward direction.
  //
  // @input n - Node (gradient op) whose rewriteinfo is to be searched,
  //        fwdn - pointer to node from the forward pass that this node
  //        belongs to
  // @return Matching rewriteinfo in case a match is found; null otherwise.
  const RewriteInfo* FindMatchingRewriteInfo(const Node* n,
                                             const Node** fwdn) const;
};

/// We register merge optimizer for phase 1 and MKLToTF insertion for phase 2.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 1,
                      NodeMergeRewritePass);

static void FillInputs(const Node* n,
                       gtl::InlinedVector<Node*, 4>* control_edges,
                       gtl::InlinedVector<std::pair<Node*, int>, 4>* in) {
  DCHECK_EQ(in->size(), n->num_inputs());
  control_edges->clear();
  for (const Edge* e : n->in_edges()) {
    if (e->IsControlEdge()) {
      control_edges->push_back(e->src());
    } else {
      (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output());
    }
  }
  std::sort(control_edges->begin(), control_edges->end());
  if (n->op_def().is_commutative()) {
    // For commutative inputs, we sort the input by the input Node*
    // to get a canonical ordering (so that add(a,b) and add(b, a) will
    // hash to the same value if is_commutative is true for 'add').
    std::sort(in->begin(), in->end());
  }
}

Node* NodeMergeRewritePass::FindNodeForMerge(const Node* a) const {
  // Search for all matching mergeinfo.
  // We allow more than one match for extensibility.
  std::vector<const MergeInfo*> matching_mi;
  for (auto mi = minfo_.cbegin(); mi != minfo_.cend(); ++mi) {
    if (a->type_string() == mi->succ) {
      matching_mi.push_back(&*mi);
    }
  }

  VLOG(1) << "FindNodeForMerge: " << a->type_string();

  for (const MergeInfo* mi : matching_mi) {
    const int N_in = a->num_inputs();
    if (mi->op >= N_in) {
      continue;
    }

    // Get the control edges and inputs of the node
    gtl::InlinedVector<Node*, 4> a_control_edges;
    gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in);
    FillInputs(a, &a_control_edges, &a_in);

    // Get the operand op of the operator
    Node* b = nullptr;
    b = a_in[mi->op].first;
    if (b == nullptr || (b->type_string() != mi->pred)) {
      // NOTE: Should the first check be an assert?
      continue;
    }

    VLOG(1) << "  FindNode: " << b->type_string();

    gtl::InlinedVector<Node*, 4> b_control_edges;
    gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(N_in);
    FillInputs(b, &b_control_edges, &b_in);

    // Shouldn't merge if a and b have different control edges.
    if (a_control_edges != b_control_edges) {
      continue;
    } else {
      // We found a match.
      return b;
    }
  }

  return nullptr;
}

Status NodeMergeRewritePass::MergeNode(std::unique_ptr<Graph>* g,
                                       Node* succ, Node* pred) {
  CHECK_NOTNULL(succ);
  CHECK_NOTNULL(pred);

  if (succ->type_string() == csinfo_.biasadd &&
      pred->type_string() == csinfo_.conv2d) {
    // 1. Get all attributes from input nodes.
    DataType T_pred, T_succ;
    string padding;
    std::vector<int32> strides;
    string data_format_pred, data_format_succ;
    bool use_cudnn_on_gpu;
    int groups = 1;
    TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
    TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
    TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding));
    TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides));
    TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred));
    TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
    TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu",
                            &use_cudnn_on_gpu));
    // The groups attribute may not be there on the input node, so we do not
    // check for error in the GetNodeAttr call.
    GetNodeAttr(pred->def(), "groups", &groups);
    // We check to ensure that the data formats of both succ and pred are the
    // same. We expect them to be the same, so we could enforce this with an
    // assert. But an assert can be too strict, so we enforce this as a check.
    // If the check fails, then we do not merge the two nodes.
    if (data_format_pred != data_format_succ ||
        T_pred != T_succ) {
      return Status(error::Code::INVALID_ARGUMENT,
                    "data_format or T attribute of Conv2D and BiasAdd "
                    "do not match. Will skip node merge optimization");
    }

    // 2. Get inputs from both the nodes.
    // Find the 2 inputs from the conv and the bias from the BiasAdd.
    Node* oper1 = nullptr;
    Node* oper2 = nullptr;
    Node* oper3 = nullptr;

    const int succ_num = succ->num_inputs();
    gtl::InlinedVector<Node*, 4> succ_control_edges;
    gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
    FillInputs(succ, &succ_control_edges, &succ_in);

    const int pred_num = pred->num_inputs();
    gtl::InlinedVector<Node*, 4> pred_control_edges;
    gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
    FillInputs(pred, &pred_control_edges, &pred_in);

    // We need to ensure that there is only 1 edge between Conv2D and AddBias.
    // Otherwise, merging is semantically incorrect.
    if (pred->out_edges().size() != 1) {
      return Status(error::Code::INVALID_ARGUMENT,
                    "Conv2D has multiple outputs. "
                    "Will skip node merge optimization");
    }

    for (const Edge* e : pred->out_edges()) {
      if (e->dst() != succ) {
        return Status(error::Code::INVALID_ARGUMENT,
                      "Conv2D does not feed to BiasAdd. "
                      "Will skip node merge optimization");
      }
    }

    // Get operands 0 and 1 of Conv2D
    oper1 = pred_in[0].first;
    oper2 = pred_in[1].first;
    // Get operand 1 of BiasAdd
    oper3 = succ_in[1].first;

    Node* ret;
    // We will use the node name of BiasAdd as the name of the new node
    TF_CHECK_OK(NodeBuilder(succ->name(), csinfo_.conv2dwithbias)
                    .Input(oper1)
                    .Input(oper2)
                    .Input(oper3)
                    .Attr("T", T_pred)
                    .Attr("strides", strides)
                    .Attr("padding", padding)
                    .Attr("data_format", data_format_pred)
                    .Attr("use_cudnn_on_gpu", use_cudnn_on_gpu)
                    .Attr("groups", groups)
                    .Finalize(&**g, &ret));
    CHECK_NOTNULL(ret);

    // Incoming edges are fixed, we will fix the outgoing edges now.
    for (const Edge* e : succ->out_edges()) {
      (*g)->AddEdge(ret, e->src_output(), e->dst(), e->dst_input());
    }

    (*g)->RemoveNode(succ);
    (*g)->RemoveNode(pred);

    return Status::OK();
  }

  return Status(error::Code::UNIMPLEMENTED,
                "Unimplemented case for node merge optimization.");
}

Status NodeMergeRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node* n) {
  CHECK_NOTNULL(n);

  // Get the matching rewriteinfo for the node
  const Node* fwdn = nullptr;
  const RewriteInfo* ri = FindMatchingRewriteInfo(n, &fwdn);
  if (ri == nullptr || fwdn == nullptr) {
    VLOG(1) << "Rewriteinfo not found for: " << n->type_string();
    return Status(error::Code::INVALID_ARGUMENT,
                  "Rewrite info not found for the node. "
                  "Will skip node rewrite optimization");
  }

  VLOG(1) << "Rewrite called for: " << n->type_string();

  if (n->type_string() == csinfo_.biasaddgrad &&
      ri->node == csinfo_.biasaddgrad &&
      (ri->rewrite == csinfo_.conv2dwithbiasbackpropbias ||
       ri->rewrite == csinfo_.biasaddgrad)) {
    DataType T;
    string data_format;
    TF_CHECK_OK(GetNodeAttr(n->def(), "T", &T));
    TF_CHECK_OK(GetNodeAttr(n->def(), "data_format", &data_format));

    int n_num = n->num_inputs();  // this must be 1.
    CHECK_EQ(n_num, 1);

    gtl::InlinedVector<Node*, 4> n_control_edges;
    gtl::InlinedVector<std::pair<Node*, int>, 4> n_in(n_num);
    FillInputs(n, &n_control_edges, &n_in);

    Node *ret = nullptr, *op = n_in[0].first;

    if (ri->rewrite == csinfo_.conv2dwithbiasbackpropbias) {
      // Get strides info from Conv2D (the node in the forward pass that this
      // node corresponds to).
      std::vector<int32> strides;
      TF_CHECK_OK(GetNodeAttr(fwdn->def(), "strides", &strides));

      // We use the same name as the original node name, as there may be
      // fetchoutputs associated with it.
      TF_CHECK_OK(NodeBuilder(n->name(), ri->rewrite)
                      .Input(op)
                      .Attr("T", T)
                      .Attr("data_format", data_format)
                      .Attr("strides", strides)
                      .Finalize(&**g, &ret));
    } else {
      CHECK_EQ(ri->rewrite, csinfo_.biasaddgrad);
      TF_CHECK_OK(NodeBuilder(n->name(), ri->rewrite)
                      .Input(op)
                      .Attr("T", T)
                      .Attr("data_format", data_format)
                      .Finalize(&**g, &ret));
    }

    CHECK_NOTNULL(ret);

    // Incoming edges are fixed, we will fix the outgoing edges now.
    for (const Edge* e : n->out_edges()) {
      (*g)->AddEdge(ret, e->src_output(), e->dst(), e->dst_input());
    }

    VLOG(1) << "Rewrite node: " << n->type_string() << " successful";
    (*g)->RemoveNode(n);

    return Status::OK();
  }

  return Status(error::Code::UNIMPLEMENTED,
                "Unimplemented case for node rewrite optimization.");
}

const NodeMergeRewritePass::RewriteInfo*
NodeMergeRewritePass::FindMatchingRewriteInfo(const Node* n,
                                              const Node** fwdn) const {
  CHECK_NOTNULL(n);
  CHECK_NOTNULL(fwdn);
  *fwdn = nullptr;

  // Search for matching rewriteinfo based on node name.
  // There could be more than one matching rewriteinfo.
  std::vector<const RewriteInfo*> matching_ri;
  for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
    if (n->type_string() == ri->node) {
      matching_ri.push_back(&*ri);
    }
  }

  VLOG(1) << "Searching graph for: " << n->type_string() << " in backwards.";

  // Now we will check for the forward op name for rewrite info in the data
  // flow graph. Get the max hops we should search for the fwd node.
  // We are now going to search (breadth-first) backwards in the data
  // dependence graph (for up to max hops) from n for the node
  // specified in fwd.
  // Queue to maintain nodes to be visited and depth info for
  // breadth-first search
  std::queue<std::pair<const Node*, int>> nqueue;
  const Node* curr_node = n;
  size_t curr_depth = 0;
  nqueue.push(std::make_pair(curr_node, curr_depth));

  std::set<const Node*> visited_nodes;
  while (curr_depth < kNodeMergeContextMaxDepth && !nqueue.empty()) {
    std::pair<const Node*, int> curr_pair = nqueue.front();
    nqueue.pop();

    curr_node = curr_pair.first;
    curr_depth = curr_pair.second;
    CHECK_NOTNULL(curr_node);

    VLOG(1) << "Visiting node: " << curr_node->type_string()
            << " at depth: " << curr_depth
            << " for node: " << n->type_string();

    // If we find a match, we return immediately with the matching rewrite
    // info.
    for (const RewriteInfo* ri : matching_ri) {
      if (curr_node->type_string() == ri->cinfo.fwd) {
        *fwdn = curr_node;
        return ri;
      }
    }

    // Else we explore backward edges from the current node.
    // Add the source nodes of all incoming edges of the node to the queue.
    for (const Edge* e : curr_node->in_edges()) {
      // We do not visit an already visited node.
      if (visited_nodes.find(e->src()) == visited_nodes.end()) {
        // Depth of these nodes is 1 more than the depth of the current node.
        nqueue.push(std::make_pair(e->src(), curr_depth + 1));
        visited_nodes.insert(e->src());
      }
    }
  }  /* while */

  return nullptr;
}

bool NodeMergeRewritePass::IsApplicableRewriteNode(const Node* n) const {
  CHECK_NOTNULL(n);

  // Search for matching rewriteinfo.
  // Even if we find one match, we return true.
  bool match_found = false;
  for (const RewriteInfo& ri : rinfo_) {
    if (n->type_string() == ri.node) {
      match_found = true;
      break;
    }
  }

  return match_found;
}

bool NodeMergeRewritePass::RunPass(std::unique_ptr<Graph>* g) {
  bool result = false;
  CHECK_NOTNULL(g);

  DumpGraph("Before OptimizeMerge", &**g);

  std::vector<Node*> order;
  GetReversePostOrder(**g, &order);
  std::vector<std::pair<Node*, Node*>> nodes_to_be_merged;
  std::vector<Node*> nodes_to_be_rewritten;

  VLOG(1) << "Running NodeMerge Optimization";

  for (Node* n : order) {
    if (!n->IsOp()) continue;
    Node* n1 = nullptr;
    if ((n1 = FindNodeForMerge(n)) != nullptr) {
      VLOG(1) << "Scheduled nodes " << n->name() << " and "
              << n1->name() << " for merging";
      nodes_to_be_merged.push_back(std::make_pair(n, n1));
    } else if (IsApplicableRewriteNode(n)) {
      VLOG(1) << "Scheduled node " << n->name() << " for rewrite";
      nodes_to_be_rewritten.push_back(n);
    }
  }

  for (std::pair<Node*, Node*> i : nodes_to_be_merged) {
    // Even if MergeNode merges a single pair of nodes, we
    // need to return true.
    string n1_name = i.first->name();
    string n2_name = i.second->name();
    if (MergeNode(g, i.first, i.second) == Status::OK()) {
      VLOG(1) << "Merged nodes " << n1_name << " and " << n2_name;
      result = true;
    }
  }

  DumpGraph("After OptimizeMerge(nodemerge)", &**g);

  for (Node* i : nodes_to_be_rewritten) {
    string name = i->name();
    if (RewriteNode(g, i) == Status::OK()) {
      VLOG(1) << "Rewrite node: " << name << " successful.";
      result = true;
    }
  }

  DumpGraph("After OptimizeMerge(noderewrite)", &**g);

  return result;
}

bool OptimizeNodeMerge(std::unique_ptr<Graph>* g) {
  return NodeMergeRewritePass().RunPass(g);
}

Status NodeMergeRewritePass::Run(const GraphOptimizationPassOptions& options) {
  // Currently checking only for two cases - Conv2D+Bias and MatMul+Bias.
  // It is possible to extend it to other operators in the future.
  if (options.graph == nullptr) {
    return Status::OK();
  }

  // Take ownership of the graph
  std::unique_ptr<Graph>* g = std::move(options.graph);

  RunPass(g);

  // Return ownership of the graph
  options.graph->reset(g->release());

  return Status::OK();
}

}  // namespace tensorflow

#endif
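The context search in `FindMatchingRewriteInfo` above is a depth-bounded breadth-first walk over incoming edges. A minimal sketch of the same idea in Python, over a toy predecessor map (the function name, the map, and the node names are all invented for illustration; the real pass walks `Node::in_edges()` on a TensorFlow `Graph`):

```python
from collections import deque

def find_forward_context(node, preds, targets, max_depth=10):
    """Walk backwards through `preds` (node -> list of input nodes) from
    `node`, breadth-first, and return the first node whose op type is in
    `targets`, or None if none is reachable within `max_depth` hops."""
    visited = {node}
    queue = deque([(node, 0)])
    while queue:
        current, depth = queue.popleft()
        if current in targets:
            return current
        if depth >= max_depth:
            continue  # do not expand beyond the hop budget
        for src in preds.get(current, []):
            if src not in visited:
                visited.add(src)
                queue.append((src, depth + 1))
    return None

# Toy backward graph: BiasAddGrad <- ReluGrad <- Conv2DWithBias
preds = {'BiasAddGrad': ['ReluGrad'], 'ReluGrad': ['Conv2DWithBias']}
print(find_forward_context('BiasAddGrad', preds, {'Conv2DWithBias'}))
```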
42
tensorflow/core/graph/mkl_optimizer_merge.h
Normal file
@@ -0,0 +1,42 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// An optimization pass that performs node merging and rewrite on graph nodes

#ifndef TENSORFLOW_GRAPH_MKL_OPTIMIZER_MERGE_H_
#define TENSORFLOW_GRAPH_MKL_OPTIMIZER_MERGE_H_

#ifdef INTEL_MKL

#include <sys/types.h>
#include <vector>
#include <string>
#include <memory>
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {

// Interface to invoke the pass for unit test
//
// Returns true if and only if 'g' is mutated.
extern bool OptimizeNodeMerge(std::unique_ptr<Graph>* g);

}  // namespace tensorflow

#endif  // INTEL_MKL

#endif  // TENSORFLOW_GRAPH_MKL_OPTIMIZER_MERGE_H_
397
tensorflow/core/graph/mkl_optimizer_merge_test.cc
Normal file
@@ -0,0 +1,397 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef INTEL_MKL

#include "tensorflow/core/graph/mkl_optimizer_merge.h"

#include <vector>
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

namespace tensorflow {
namespace {

class OptimizerMergeTest : public ::testing::Test {
 public:
  OptimizerMergeTest() : graph_(OpRegistry::Global()) {}

  static void InitGraph(const string& s, Graph* graph) {
    GraphDef graph_def;

    auto parser = protobuf::TextFormat::Parser();
    CHECK(parser.MergeFromString(s, &graph_def)) << s;
    GraphConstructorOptions opts;
    TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));
  }

  void InitGraph(const string& s) {
    InitGraph(s, &graph_);
    original_ = CanonicalGraphString(&graph_);
  }

  static bool IncludeNode(const Node* n) { return n->IsOp(); }

  static string EdgeId(const Node* n, int index) {
    if (index == 0) {
      return n->name();
    } else if (index == Graph::kControlSlot) {
      return strings::StrCat(n->name(), ":control");
    } else {
      return strings::StrCat(n->name(), ":", index);
    }
  }

  string CanonicalGraphString(Graph* g) {
    std::vector<string> nodes;
    std::vector<string> edges;
    for (const Node* n : g->nodes()) {
      if (IncludeNode(n)) {
        nodes.push_back(strings::StrCat(n->name(), "(", n->type_string(), ")"));
      }
    }
    for (const Edge* e : g->edges()) {
      if (IncludeNode(e->src()) && IncludeNode(e->dst())) {
        edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->",
                                        EdgeId(e->dst(), e->dst_input())));
      }
    }
    // Canonicalize
    std::sort(nodes.begin(), nodes.end());
    std::sort(edges.begin(), edges.end());
    return strings::StrCat(str_util::Join(nodes, ";"), "|",
                           str_util::Join(edges, ";"));
  }

  string DoNodeMerge() {
    string before = CanonicalGraphString(&graph_);
    LOG(ERROR) << "Before node merge optimize: " << before;

    std::unique_ptr<Graph>* ug = new std::unique_ptr<Graph>(&graph_);
    OptimizeNodeMerge(ug);

    string result = CanonicalGraphString(&graph_);
    LOG(ERROR) << "After node merge optimize: " << result;
    return result;
  }

  const string& OriginalGraph() const { return original_; }

  Graph graph_;
  string original_;
};

REGISTER_OP("Input").Output("o: float").SetIsStateful();

TEST_F(OptimizerMergeTest, Basic) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoNodeMerge(),
            "A(Input);B(Input);C(Mul);D(Mul)|"
            "A->C;A->D;B->C:1;B->D:1");
}

// Test set 1: Conv2D + AddBias

// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y)
TEST_F(OptimizerMergeTest, Conv2DWithBias_Positive) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding' value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'BiasAdd'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " input: ['C', 'D'] }"
      "node { name: 'Y' op: 'Input'}"
      "node { name: 'Z' op: 'Sub'"
      " attr {key: 'T' value { type: DT_FLOAT } }"
      " input: ['E', 'Y']}");
  EXPECT_EQ(DoNodeMerge(),
            "A(Input);B(Input);D(Input);E(Conv2DWithBias);Y(Input);Z(Sub)|"
            "A->E;B->E:1;D->E:2;E->Z;Y->Z:1");
}

// Graph contains only Conv2D, no AddBias.
TEST_F(OptimizerMergeTest, Conv2DWithBias_Negative_NoAddBias) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding' value { s: 'SAME' } }"
      " input: ['A', 'B']}");
  EXPECT_EQ(DoNodeMerge(),
            "A(Input);B(Input);C(Conv2D)|"
            "A->C;B->C:1");
}

// Conv2D output does not go to BiasAdd.
TEST_F(OptimizerMergeTest, Conv2DWithBias_Negative_Dataflow1) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding' value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'Input'}"
      "node { name: 'F' op: 'BiasAdd'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " input: ['D', 'E'] }");  // Output of Conv2D does not go to BiasAdd.
  EXPECT_EQ(DoNodeMerge(),
            "A(Input);B(Input);C(Conv2D);D(Input);E(Input);F(BiasAdd)|"
            "A->C;B->C:1;D->F;E->F:1");
}

// Conv2D has two outgoing edges: BiasAdd and some other dummy node (Add).
// Merge should not be done in such case.
TEST_F(OptimizerMergeTest, Conv2DWithBias_Negative_Dataflow2) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding' value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'Input'}"
      "node { name: 'F' op: 'BiasAdd'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " input: ['D', 'E'] }"  // Conv2D has two outputs.
                              // No merge should happen.
      "node { name: 'G' op: 'Add'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'E'] }");
  EXPECT_EQ(DoNodeMerge(),
            "A(Input);B(Input);C(Conv2D);D(Input);E(Input);F(BiasAdd);G(Add)|"
            "A->C;B->C:1;C->G;D->F;E->F:1;E->G:1");
}

// data_format attribute value mismatch. Merge should not be done
// in such case.
TEST_F(OptimizerMergeTest, Conv2DWithBias_Negative_AttrMismatch) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding' value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'BiasAdd'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NHCW' } }"
      " input: ['C', 'D'] }");
  EXPECT_EQ(DoNodeMerge(),
            "A(Input);B(Input);C(Conv2D);D(Input);E(BiasAdd)|"
            "A->C;B->C:1;C->E;D->E:1");
}

// Test set 2: Conv2D..BiasAddGrad -> Conv2DWithBiasBackpropBias rewrite tests

// C=Conv2D(A,B); D=Sub(C,A); F=BiasAddGrad(D)
TEST_F(OptimizerMergeTest, Conv2DBackprop_Positive) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding' value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Sub'"
      " attr {key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'A']}"
      "node { name: 'E' op: 'BiasAddGrad'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " input: ['D'] }");
  EXPECT_EQ(DoNodeMerge(),
            "A(Input);B(Input);C(Conv2D);D(Sub);E(Conv2DWithBiasBackpropBias)|"
            "A->C;A->D:1;B->C:1;C->D;D->E");
}

// No Conv2D in the context for BiasAddGrad. No rewrite should happen.
// C=Add(A,B); D=Sub(C,A); F=BiasAddGrad(D,E)
TEST_F(OptimizerMergeTest, Conv2DBackprop_Negative_NoConv2D) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Add'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Sub'"
      " attr {key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'A']}"
      "node { name: 'E' op: 'BiasAddGrad'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " input: ['D'] }");
  EXPECT_EQ(DoNodeMerge(),
            "A(Input);B(Input);C(Add);D(Sub);E(BiasAddGrad)|"
            "A->C;A->D:1;B->C:1;C->D;D->E");
}

// No Conv2D in the context for BiasAddGrad, but MatMul in context.
// Rewrite should happen, but name of BiasAddGrad does not change.
// C=MatMul(A,B); D=Sub(C,A); F=BiasAddGrad(D,E)
TEST_F(OptimizerMergeTest, Conv2DBackprop_Negative_NoConv2D_MatMul) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'MatMul'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'transpose_a' value { b: false } }"
      " attr { key: 'transpose_b' value { b: false } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Sub'"
      " attr {key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'A']}"
      "node { name: 'E' op: 'BiasAddGrad'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " input: ['D'] }");
  EXPECT_EQ(DoNodeMerge(),
            "A(Input);B(Input);C(MatMul);D(Sub);E(BiasAddGrad)|"
            "A->C;A->D:1;B->C:1;C->D;D->E");
}

// Test set 3: MatMul..BiasAddGrad -> BiasAddGrad rewrite tests
// C=MatMul(A,B); D=Sub(C,A); F=BiasAddGrad(D,E)
TEST_F(OptimizerMergeTest, MatMulBiasAddGrad_Positive) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'MatMul'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'transpose_a' value { b: false } }"
      " attr { key: 'transpose_b' value { b: false } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Sub'"
      " attr {key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'A']}"
      "node { name: 'E' op: 'BiasAddGrad'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " input: ['D'] }");
  EXPECT_EQ(DoNodeMerge(),
            "A(Input);B(Input);C(MatMul);D(Sub);E(BiasAddGrad)|"
            "A->C;A->D:1;B->C:1;C->D;D->E");
}

// No MatMul in the context for BiasAddGrad. No rewrite should happen.
// C=Add(A,B); D=Sub(C,A); F=BiasAddGrad(D,E)
TEST_F(OptimizerMergeTest, MatMulBiasAddGrad_Negative_NoMatMul) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Add'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Sub'"
      " attr {key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'A']}"
      "node { name: 'E' op: 'BiasAddGrad'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " input: ['D'] }");
  EXPECT_EQ(DoNodeMerge(),
            "A(Input);B(Input);C(Add);D(Sub);E(BiasAddGrad)|"
            "A->C;A->D:1;B->C:1;C->D;D->E");
}

static void BM_NodeMerge(int iters, int op_nodes) {
  testing::StopTiming();
  string s;
  for (int in = 0; in < 10; in++) {
    s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in);
  }
  random::PhiloxRandom philox(301, 17);
  random::SimplePhilox rnd(&philox);
  for (int op = 0; op < op_nodes; op++) {
    s += strings::Printf(
        "node { name: 'op%04d' op: 'Mul' attr { key: 'T' value { "
        "type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }",
        op, rnd.Uniform(10), rnd.Uniform(10));
  }

  bool first = true;
  while (iters > 0) {
    Graph* graph = new Graph(OpRegistry::Global());
    OptimizerMergeTest::InitGraph(s, graph);
    int N = graph->num_node_ids();
    if (first) {
      testing::SetLabel(strings::StrCat("Per graph node.  Nodes: ", N));
      first = false;
    }
    {
      testing::StartTiming();
      std::unique_ptr<Graph> ug(graph);
      OptimizeNodeMerge(&ug);
      testing::StopTiming();
    }
    iters -= N;  // Our benchmark units are individual graph nodes,
                 // not whole graphs
    // delete graph;
  }
}
BENCHMARK(BM_NodeMerge)->Arg(1000)->Arg(10000);

}  // namespace
}  // namespace tensorflow

#endif /* INTEL_MKL */
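The test harness above compares graphs by canonical strings (sorted node and edge lists) rather than walking structure, which makes the assertions insensitive to node insertion order. A standalone sketch of the same idea, using hypothetical stand-in types rather than TensorFlow's Graph:

```cpp
// Standalone sketch of the canonical-string comparison used above; the
// SimpleEdge struct is an illustrative stand-in, not a TensorFlow type.
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

struct SimpleEdge { std::string src, dst; };

std::string Canonicalize(std::vector<std::string> nodes,
                         std::vector<SimpleEdge> edges) {
  std::vector<std::string> e;
  for (const auto& x : edges) e.push_back(x.src + "->" + x.dst);
  // Sorting makes the string independent of graph construction order.
  std::sort(nodes.begin(), nodes.end());
  std::sort(e.begin(), e.end());
  std::string out;
  for (size_t i = 0; i < nodes.size(); ++i) out += (i ? ";" : "") + nodes[i];
  out += "|";
  for (size_t i = 0; i < e.size(); ++i) out += (i ? ";" : "") + e[i];
  return out;
}

int main() {
  // Two insertion orders, one canonical form: A(Input);B(Input)|A->B
  std::cout << Canonicalize({"B(Input)", "A(Input)"}, {{"A", "B"}}) << "\n";
}
```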
@@ -7,9 +7,9 @@
 # append "_gpu" to the test name to invoke the GPU benchmarks. Example:
 #
 #   # for CPU tests
-#   $ bazel test -c opt --copt=-mavx //third_party/tensorflow/core/kernels:my_op_test
+#   $ bazel test --config opt //third_party/tensorflow/core/kernels:my_op_test
 #   # for GPU benchmarks
-#   $ bazel run -c opt --copt=-mavx --config=cuda //third_party/tensorflow/core/kernels:my_op_test_gpu -- --benchmarks=..
+#   $ bazel run --config opt --config=cuda //third_party/tensorflow/core/kernels:my_op_test_gpu -- --benchmarks=..
 #
 package(default_visibility = ["//visibility:public"])
@@ -170,12 +170,11 @@ struct LaunchXsmmBackwardInputConvolution<CPUDevice, float> {
     desc.pad_w_out = 0;
     desc.threads = num_threads;
     desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
-    desc.buffer_format = LIBXSMM_DNN_CONV_FORMAT_NHWC;
-    desc.filter_format = LIBXSMM_DNN_CONV_FORMAT_LIBXSMM;//LIBXSMM_DNN_CONV_FORMAT_RSCK;
+    desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
+    desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;//LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
     desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
     desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
-    desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
-    desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
+    desc.datatype = LIBXSMM_DNN_DATATYPE_F32;

     auto input_ptr = input_backward.data();
     auto filter_ptr = kernel.data();
@@ -219,12 +219,11 @@ class LaunchXsmmConvOp<CPUDevice, float> {
     desc.pad_w_out = 0;
     desc.threads = num_threads;
     desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
-    desc.buffer_format = LIBXSMM_DNN_CONV_FORMAT_NHWC;
-    desc.filter_format = LIBXSMM_DNN_CONV_FORMAT_LIBXSMM;
+    desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
+    desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;
     desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
     desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
-    desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
-    desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
+    desc.datatype = LIBXSMM_DNN_DATATYPE_F32;

     if (!CanUseXsmmConv2D(desc, data_format)) {
       return false;
@@ -220,13 +220,15 @@ REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
 namespace functor {

 // UnsortedSegmentSumFunctor implementation for CPUDevice.
+// todo: Remove duplicate code in UnsortedSegmentSumFunctor and UnsortedSegmentMaxFunctor.
 template <typename T, typename Index>
-struct UnsortedSegmentSumFunctor<CPUDevice, T, Index> {
+struct UnsortedSegmentSumFunctor<CPUDevice, T, Index>
+    : UnsortedSegmentBaseFunctor<CPUDevice, T, Index> {
   void operator()(OpKernelContext* ctx, const CPUDevice& d,
                   const Index output_rows, const TensorShape& segment_ids_shape,
                   typename TTypes<Index>::ConstFlat segment_ids,
                   const Index data_size, const T* data,
-                  typename TTypes<T, 2>::Tensor output) {
+                  typename TTypes<T, 2>::Tensor output) override {
     output.setZero();
     if (data_size == 0) {
       return;
@@ -243,16 +245,44 @@ struct UnsortedSegmentSumFunctor<CPUDevice, T, Index> {
     }
   }
 };
+// UnsortedSegmentMaxFunctor implementation for CPUDevice.
+template <typename T, typename Index>
+struct UnsortedSegmentMaxFunctor<CPUDevice, T, Index>
+    : UnsortedSegmentBaseFunctor<CPUDevice, T, Index> {
+  void operator()(OpKernelContext* ctx, const CPUDevice& d,
+                  const Index output_rows, const TensorShape& segment_ids_shape,
+                  typename TTypes<Index>::ConstFlat segment_ids,
+                  const Index data_size, const T* data,
+                  typename TTypes<T, 2>::Tensor output) override {
+    output.setConstant(std::numeric_limits<T>::min());
+    if (data_size == 0) {
+      return;
+    }
+    const int64 N = segment_ids.dimension(0);
+    auto data_flat = typename TTypes<T, 2>::ConstTensor(data, N, data_size / N);
+    for (int64 i = 0; i < N; ++i) {
+      Index j = internal::SubtleMustCopy(segment_ids(i));
+      OP_REQUIRES(ctx, FastBoundsCheck(j, output_rows),
+                  errors::InvalidArgument(
+                      "segment_ids", SliceDebugString(segment_ids_shape, i),
+                      " = ", j, " is out of range [0, ", output_rows, ")"));
+      output.template chip<0>(j) =
+          data_flat.template chip<0>(i).cwiseMax(output.template chip<0>(j));
+    }
+  }
+};
 }  // namespace functor

-// Similar to SegmentReductionOp but can handle unsorted segment definitions and
-// specifying size of output.
+// Base class for SegmentReductionOps that can handle unsorted segment
+// definitions
+// and specifying the size of the output in addition to a reduction function
 template <typename Device, class T, class Index>
-class UnsortedSegmentSumOp : public OpKernel {
+class UnsortedSegmentBaseOp : public OpKernel {
  public:
-  explicit UnsortedSegmentSumOp(OpKernelConstruction* context)
-      : OpKernel(context) {}
+  explicit UnsortedSegmentBaseOp(
+      OpKernelConstruction* context,
+      functor::UnsortedSegmentBaseFunctor<Device, T, Index>& functor)
+      : OpKernel(context), reduction_functor_(functor) {}

   void Compute(OpKernelContext* context) override {
     const Tensor& data = context->input(0);
@@ -288,27 +318,70 @@ class UnsortedSegmentSumOp : public OpKernel {
     auto output_flat = output->flat_outer_dims<T>();

     auto data_ptr = data.template flat<T>().data();
-    functor::UnsortedSegmentSumFunctor<Device, T, Index>()(
-        context, context->template eigen_device<Device>(), output_rows,
-        segment_ids.shape(), segment_flat, data.NumElements(), data_ptr,
-        output_flat);
+    reduction_functor_(context, context->template eigen_device<Device>(),
+                       output_rows, segment_ids.shape(), segment_flat,
+                       data.NumElements(), data_ptr, output_flat);
   }
+ private:
+  functor::UnsortedSegmentBaseFunctor<Device, T, Index>& reduction_functor_;
 };

+template <typename Device, class T, class Index>
+class UnsortedSegmentSumOp : public UnsortedSegmentBaseOp<Device, T, Index> {
+ public:
+  explicit UnsortedSegmentSumOp(OpKernelConstruction* context)
+      : UnsortedSegmentBaseOp<Device, T, Index>(
+            context,
+            sum_functor_) {}
+ private:
+  functor::UnsortedSegmentSumFunctor<Device, T, Index> sum_functor_;
+};
+
+template <typename Device, class T, class Index>
+class UnsortedSegmentMaxOp : public UnsortedSegmentBaseOp<Device, T, Index> {
+ public:
+  explicit UnsortedSegmentMaxOp(OpKernelConstruction* context)
+      : UnsortedSegmentBaseOp<Device, T, Index>(
+            context,
+            max_functor_) {}
+ private:
+  functor::UnsortedSegmentMaxFunctor<Device, T, Index> max_functor_;
+};
+
+#define REGISTER_REAL_CPU_UNSORTED_KERNELS(type, index_type)           \
+  REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentSum")                   \
+                              .Device(DEVICE_CPU)                      \
+                              .TypeConstraint<type>("T")               \
+                              .TypeConstraint<index_type>("Tindices"), \
+                          UnsortedSegmentSumOp<CPUDevice, type, index_type>); \
+  REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentMax")                   \
+                              .Device(DEVICE_CPU)                      \
+                              .TypeConstraint<type>("T")               \
+                              .TypeConstraint<index_type>("Tindices"), \
+                          UnsortedSegmentMaxOp<CPUDevice, type, index_type>);
+
-#define REGISTER_CPU_UNSORTED_KERNELS(type, index_type) \
+#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, index_type)        \
   REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentSum")                   \
                               .Device(DEVICE_CPU)                      \
                               .TypeConstraint<type>("T")               \
                               .TypeConstraint<index_type>("Tindices"), \
                           UnsortedSegmentSumOp<CPUDevice, type, index_type>);

-#define REGISTER_CPU_UNSORTED_KERNELS_ALL(type) \
-  REGISTER_CPU_UNSORTED_KERNELS(type, int32);   \
-  REGISTER_CPU_UNSORTED_KERNELS(type, int64);
+#define REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL(type) \
+  REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int32);   \
+  REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int64)

-TF_CALL_NUMBER_TYPES(REGISTER_CPU_UNSORTED_KERNELS_ALL);
-#undef REGISTER_CPU_UNSORTED_KERNELS
-#undef REGISTER_CPU_UNSORTED_KERNELS_ALL
+#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(type) \
+  REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, int32);   \
+  REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, int64)
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL);
+REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex64);
+REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128);
+#undef REGISTER_REAL_CPU_UNSORTED_KERNELS
+#undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS
+#undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL
+#undef REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL

 #if GOOGLE_CUDA
 #define REGISTER_GPU_UNSORTED_KERNELS(type, index_type) \
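The refactor above funnels both reductions through a reference to `UnsortedSegmentBaseFunctor`, so the shape and validation logic in the op is written once and only the reduction varies. A toy sketch of that dispatch pattern with plain standard-library types (all names here are illustrative, not TensorFlow's):

```cpp
// Toy sketch of the virtual-functor dispatch introduced above: a base
// functor, sum/max specializations, and a driver that only knows the base.
#include <algorithm>
#include <iostream>
#include <vector>

struct ReductionFunctor {
  virtual ~ReductionFunctor() {}
  virtual void operator()(std::vector<int>& out, int id, int value) = 0;
};

struct SumFunctor : ReductionFunctor {
  void operator()(std::vector<int>& out, int id, int value) override {
    out[id] += value;  // accumulate into the segment
  }
};

struct MaxFunctor : ReductionFunctor {
  void operator()(std::vector<int>& out, int id, int value) override {
    out[id] = std::max(out[id], value);  // keep the segment maximum
  }
};

// The "op" holds a reference to the base functor, so the segment walk and
// bounds handling are shared between the sum and max variants.
void SegmentReduce(ReductionFunctor& f, const std::vector<int>& data,
                   const std::vector<int>& ids, std::vector<int>& out) {
  for (size_t i = 0; i < data.size(); ++i) f(out, ids[i], data[i]);
}

int main() {
  std::vector<int> data = {1, 3, 2, 4}, ids = {0, 0, 1, 1};
  std::vector<int> sums(2, 0), maxs(2, -1000000);
  SumFunctor sum; MaxFunctor max;
  SegmentReduce(sum, data, ids, sums);  // sums = {4, 6}
  SegmentReduce(max, data, ids, maxs);  // maxs = {3, 4}
  std::cout << sums[0] << " " << sums[1] << " "
            << maxs[0] << " " << maxs[1] << "\n";
}
```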
@@ -26,6 +26,17 @@ namespace tensorflow {
 class OpKernelContext;

 namespace functor {
+// BaseFunctor for definition of UnsorteSegmentReductionOp
+// for usage without templates.
+template <typename Device, typename T, typename Index>
+struct UnsortedSegmentBaseFunctor{
+  virtual ~UnsortedSegmentBaseFunctor(){}
+  virtual void operator()(OpKernelContext* ctx, const Device& d,
+                          const Index output_rows, const TensorShape& segment_ids_shape,
+                          typename TTypes<Index>::ConstFlat segment_ids,
+                          const Index data_size, const T* data,
+                          typename TTypes<T, 2>::Tensor output){};
+};

 // Functor for UnsortedSegmentSumOp.
 // 'output_rows': the number of output segments (unique segment ids in
@@ -37,7 +48,7 @@ namespace functor {
 // 'data': input data tensor.
 // 'output': output reshaped to {output_rows, output.size/output_rows}
 template <typename Device, typename T, typename Index>
-struct UnsortedSegmentSumFunctor {
+struct UnsortedSegmentSumFunctor: public UnsortedSegmentBaseFunctor<Device, T, Index> {
   void operator()(OpKernelContext* ctx, const Device& d,
                   const Index output_rows, const TensorShape& segment_ids_shape,
                   typename TTypes<Index>::ConstFlat segment_ids,
@@ -45,6 +56,23 @@ struct UnsortedSegmentSumFunctor {
                   typename TTypes<T, 2>::Tensor output);
 };

+// Functor for UnsortedSegmentMaxOp.
+// 'output_rows': the number of output segments (unique segment ids in
+// 'segment_ids').
+// 'segment_ids_shape': shape of 'segment_ids' tensor.
+// 'segment_ids': unsorted map from input to output segment ids at which to
+// perform segment sum operation.
+// 'data_size': size of input data tensor.
+// 'data': input data tensor.
+// 'output': output reshaped to {output_rows, output.size/output_rows}
+template <typename Device, typename T, typename Index>
+struct UnsortedSegmentMaxFunctor: public UnsortedSegmentBaseFunctor<Device, T, Index> {
+  void operator()(OpKernelContext* ctx, const Device& d,
+                  const Index output_rows, const TensorShape& segment_ids_shape,
+                  typename TTypes<Index>::ConstFlat segment_ids,
+                  const Index data_size, const T* data,
+                  typename TTypes<T, 2>::Tensor output);
+};
 }  // namespace functor
 }  // namespace tensorflow
@@ -56,12 +56,12 @@ namespace functor {

 // UnsortedSegmentSumFunctor implementation for GPUDevice.
 template <typename T, typename Index>
-struct UnsortedSegmentSumFunctor<GPUDevice, T, Index> {
+struct UnsortedSegmentSumFunctor<GPUDevice, T, Index>: UnsortedSegmentBaseFunctor<GPUDevice, T, Index> {
   void operator()(OpKernelContext* ctx, const GPUDevice& d,
                   const Index output_rows, const TensorShape& segment_ids_shape,
                   typename TTypes<Index>::ConstFlat segment_ids,
                   const Index data_size, const T* data,
-                  typename TTypes<T, 2>::Tensor output) {
+                  typename TTypes<T, 2>::Tensor output) override {
     if (output.size() == 0) {
       return;
     }
@@ -37,6 +37,7 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"
 #ifdef TENSORFLOW_USE_LIBXSMM
 #include "include/libxsmm_intrinsics_x86.h"
+#include "include/libxsmm_malloc.h"
 #include "include/libxsmm_spmdm.h"
 #endif

@@ -896,6 +897,8 @@ class LibxsmmSparseMatMul {
     } else {
       std::unique_ptr<TensorInfoCacheEntry> e{
           new TensorInfoCacheEntry{M, K, N, max_threads, {}, nullptr}};
+      // setup scoped allocator, which uses cpu_allocator() for this scope
+      const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator;
       libxsmm_spmdm_init(M, N, K, max_threads, &e->handle, &e->output_csr);
       return e;
     }
@@ -33,6 +33,7 @@ void dummy_xsmm_conv2d_ensure_file_is_not_empty(void);

 #include "include/libxsmm_cpuid.h"
 #include "libxsmm_dnn_handle.h"
+#include "libxsmm_malloc.h"

 namespace tensorflow {

@@ -143,26 +144,28 @@ struct HashFunction{
     S << w.d.S; u << w.d.u;
     v << w.d.v; padh << w.d.pad_h_in;
     padw << w.d.pad_w_in;


    std::string out_ = N.str() + C.str()\
                      + H.str() + W.str()\
                      + K.str() + R.str()\
                      + S.str() + u.str()\
                      + v.str() + padh.str()\
                      + padw.str();

    return ( std::hash<std::string>()(out_));
  }
 };

 class handles{
  public:
-  libxsmm_dnn_conv_handle* find( const libxsmm_dnn_conv_desc_wrap &w) {
-    std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_conv_handle*, HashFunction>::iterator i = libxsmm_handles.find(w);
+  libxsmm_dnn_layer* find( const libxsmm_dnn_conv_desc_wrap &w) {
+    std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*,
+                       HashFunction>::iterator i = libxsmm_handles.find(w);
     if (i == libxsmm_handles.end()){
       libxsmm_dnn_err_t status;
-      libxsmm_dnn_conv_handle* libxsmm_handle = libxsmm_dnn_create_conv_handle_check(w.d, &status);
+      libxsmm_dnn_layer* libxsmm_handle =
+          libxsmm_dnn_create_conv_layer(w.d, &status);
       chk_libxsmm_err(status, "Create handle");
       libxsmm_handles.insert(std::make_pair(w, libxsmm_handle));
       return libxsmm_handle;
@@ -171,15 +174,14 @@ class handles{
     return i->second;
   }
   ~handles(){
-    std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_conv_handle*, HashFunction>::iterator i;
+    std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*,
+                       HashFunction>::iterator i;
     for (i= libxsmm_handles.begin(); i != libxsmm_handles.end(); i++)
-      chk_libxsmm_err(libxsmm_dnn_destroy_conv_handle(i->second),
+      chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(i->second),
                       "Destroy handle");
   }
  private:
-  std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_conv_handle*, HashFunction> libxsmm_handles;
-
+  std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*, HashFunction> libxsmm_handles;
 };

 static handles libxsmm_handles;
@@ -187,22 +189,25 @@ static handles libxsmm_handles;
 template <typename InputPtr, typename FilterPtr, typename OutputPtr>
 static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
                                    const libxsmm_dnn_conv_desc& desc,
-                                   libxsmm_dnn_conv_kind kind, InputPtr input,
+                                   libxsmm_dnn_compute_kind kind, InputPtr input,
                                    FilterPtr filter, OutputPtr output) {
+  // setup scoped allocator, which adopts the allocator from the context
+  const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator(*ctx);
   libxsmm_dnn_err_t status;
-  libxsmm_dnn_conv_handle* libxsmm_handle;
+  libxsmm_dnn_layer* libxsmm_handle;
   libxsmm_dnn_conv_desc_wrap w(desc);
+  void* scratch;

-  if(kind == LIBXSMM_DNN_CONV_KIND_FWD)
+  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD)
     libxsmm_handle = libxsmm_handles.find(w);
-  else{
-    libxsmm_handle = libxsmm_dnn_create_conv_handle_check(desc, &status);
+  else {
+    libxsmm_handle = libxsmm_dnn_create_conv_layer(desc, &status);
     chk_libxsmm_err(status, "Create handle");
   }

   status = libxsmm_dnn_get_codegen_success(libxsmm_handle, kind);
   if (status == LIBXSMM_DNN_WARN_FALLBACK) {
-    chk_libxsmm_err(libxsmm_dnn_destroy_conv_handle(libxsmm_handle),
+    chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle),
                     "Destroy handle");
     return false;  // Use non-libxsmm code
   }
@@ -211,23 +216,23 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
   libxsmm_dnn_buffer* libxsmm_input;
   libxsmm_dnn_buffer* libxsmm_output;
   libxsmm_dnn_filter* libxsmm_filter;

   /*
   const DeviceBase::CpuWorkerThreads* worker_threads =
       ctx->device()->tensorflow_cpu_worker_threads();

   int num_threads = worker_threads->num_threads;
   */

   int ifmblock = (libxsmm_handle->ifmblock);
   int ofmblock = (libxsmm_handle->ofmblock);

   int blocksifm = desc.C%ifmblock ==0 ? desc.C/ifmblock :desc.C/ifmblock + 1;
   int blocksofm = desc.K%ofmblock ==0 ? desc.K/ofmblock :desc.K/ofmblock + 1;
-  float *native_filter = (float*)libxsmm_aligned_malloc( blocksofm*blocksifm*desc.R*desc.S*ifmblock*ofmblock*sizeof(float), 2097152);
+  float *native_filter = (float*)libxsmm_aligned_scratch(
+      blocksofm*blocksifm*desc.R*desc.S*ifmblock*ofmblock*sizeof(float),
+      2097152);

   const DeviceBase::CpuWorkerThreads* worker_threads =
       ctx->device()->tensorflow_cpu_worker_threads();

@@ -264,50 +269,78 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
     count.Wait();
   }

-  libxsmm_input = libxsmm_dnn_link_input_buffer_check(
-      libxsmm_handle, input, LIBXSMM_DNN_CONV_FORMAT_NHWC_PTR, &status);
+  libxsmm_input = libxsmm_dnn_link_buffer(
+      libxsmm_handle, LIBXSMM_DNN_INPUT, input, LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
   chk_libxsmm_err(status, "Link input buffer");
-  libxsmm_output = libxsmm_dnn_link_output_buffer_check(
-      libxsmm_handle, output, LIBXSMM_DNN_CONV_FORMAT_NHWC_PTR, &status);
+  libxsmm_output = libxsmm_dnn_link_buffer(
+      libxsmm_handle, LIBXSMM_DNN_OUTPUT, output, LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
   chk_libxsmm_err(status, "Link output buffer");
-  libxsmm_filter = libxsmm_dnn_link_filter_check(
-      libxsmm_handle, native_filter, LIBXSMM_DNN_CONV_FORMAT_LIBXSMM_PTR, &status);
+  libxsmm_filter = libxsmm_dnn_link_filter(
+      libxsmm_handle, LIBXSMM_DNN_FILTER, native_filter, LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR, &status);
   chk_libxsmm_err(status, "Link filter");

   chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_output), "Zero output");

-  chk_libxsmm_err(libxsmm_dnn_bind_input_buffer(libxsmm_handle, libxsmm_input),
-                  "Bind input");
-  chk_libxsmm_err(
-      libxsmm_dnn_bind_output_buffer(libxsmm_handle, libxsmm_output),
-      "Bind output");
-  chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter),
-                  "Bind filter");
-
-  if (kind == LIBXSMM_DNN_CONV_KIND_BWD) {
-    libxsmm_dnn_transpose_filter(libxsmm_handle);
+  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
+    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, LIBXSMM_DNN_REGULAR_INPUT),
+                    "Bind input forward");
+    chk_libxsmm_err(
+        libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, LIBXSMM_DNN_REGULAR_OUTPUT),
+        "Bind output forward");
+    chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, LIBXSMM_DNN_REGULAR_FILTER),
+                    "Bind filter forward");
+  } else {
+    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, LIBXSMM_DNN_GRADIENT_INPUT),
+                    "Bind input backward");
+    chk_libxsmm_err(
+        libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, LIBXSMM_DNN_GRADIENT_OUTPUT),
+        "Bind output backward");
+    chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, LIBXSMM_DNN_REGULAR_FILTER),
+                    "Bind filter backward");
+  }
+
+  /* bind scratch */
+  scratch = (void*)libxsmm_aligned_scratch( libxsmm_dnn_get_scratch_size( libxsmm_handle, kind, &status ), 2097152);
+  chk_libxsmm_err( status, "scratch allocation" );
+  chk_libxsmm_err( libxsmm_dnn_bind_scratch( libxsmm_handle, kind, scratch ), "binding scratch" );
+
+  if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
+    libxsmm_dnn_transpose_filter(libxsmm_handle, LIBXSMM_DNN_FILTER);
   }

   BlockingCounter counter(num_threads);

   for (int i = 0; i < num_threads; ++i) {
     worker_threads->workers->Schedule([=, &counter]() {
-      chk_libxsmm_err(libxsmm_dnn_convolve_st(libxsmm_handle, kind, 0, i),
+      chk_libxsmm_err(libxsmm_dnn_execute_st(libxsmm_handle, kind, 0, i),
                       "Worker");
       counter.DecrementCount();
     });
   }
   counter.Wait();

+  /* clean up */
+  chk_libxsmm_err( libxsmm_dnn_release_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL ), "release scratch" );
+  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
+    chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT ), "release input" );
+    chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_REGULAR_OUTPUT ), "release output" );
+    chk_libxsmm_err( libxsmm_dnn_release_filter( libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER ), "release filter" );
+  } else {
+    chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_GRADIENT_INPUT ), "release input" );
+    chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT ), "release output" );
+    chk_libxsmm_err( libxsmm_dnn_release_filter( libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER ), "release filter" );
+  }
   chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_input), "Destroy input");
   chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_output), "Destroy output");
   chk_libxsmm_err(libxsmm_dnn_destroy_filter(libxsmm_filter), "Destroy filter");

-  if(kind != LIBXSMM_DNN_CONV_KIND_FWD)
-    chk_libxsmm_err(libxsmm_dnn_destroy_conv_handle(libxsmm_handle),
+  if(kind != LIBXSMM_DNN_COMPUTE_KIND_FWD)
+    chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle),
                     "Destroy handle");

   libxsmm_free(native_filter);
+  libxsmm_free(scratch);
   return true;  // Succeeded
 }

@@ -315,7 +348,7 @@ template <typename T>
 struct XsmmFwdConv2D<CPUDevice, T> {
   bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                   const T* input, const T* filter, T* output) {
-    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_CONV_KIND_FWD, input,
+    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_FWD, input,
                                   filter, output);
   }
 };
@@ -324,7 +357,7 @@ template <typename T>
 struct XsmmBkwInputConv2D<CPUDevice, T> {
   bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                   T* input, const T* filter, const T* output) {
-    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_CONV_KIND_BWD, input,
+    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_BWD, input,
                                   filter, output);
   }
 };
@@ -333,7 +366,7 @@ template <typename T>
 struct XsmmBkwFilterConv2D<CPUDevice, T> {
   bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                   const T* input, T* filter, const T* output) {
-    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_CONV_KIND_UPD, input,
+    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_UPD, input,
                                   filter, output);
   }
 };
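The `handles` class above memoizes libxsmm layer handles keyed by a hashed copy of the convolution descriptor, so repeated forward convolutions with the same shape reuse one handle instead of paying layer-creation cost each step. A self-contained sketch of that caching pattern; `Desc` and `Handle` are hypothetical stand-ins for the libxsmm types:

```cpp
// Toy sketch of the descriptor-keyed handle cache used above.
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>

struct Desc {
  int N, C, K;
  bool operator==(const Desc& o) const {
    return N == o.N && C == o.C && K == o.K;
  }
};

struct DescHash {
  std::size_t operator()(const Desc& d) const {
    // Same idea as HashFunction above: serialize the fields, hash the string.
    std::string s = std::to_string(d.N) + "," + std::to_string(d.C) + "," +
                    std::to_string(d.K);
    return std::hash<std::string>()(s);
  }
};

struct Handle { /* stand-in for an expensive-to-create compute handle */ };

class HandleCache {
 public:
  Handle* find(const Desc& d) {
    auto it = cache_.find(d);
    if (it != cache_.end()) return it->second;  // reuse the cached handle
    Handle* h = new Handle();                   // create once per descriptor
    cache_.emplace(d, h);
    return h;
  }
  ~HandleCache() { for (auto& kv : cache_) delete kv.second; }
 private:
  std::unordered_map<Desc, Handle*, DescHash> cache_;
};

int main() {
  HandleCache cache;
  Desc d{1, 64, 128};
  Handle* a = cache.find(d);
  Handle* b = cache.find(d);
  return a == b ? 0 : 1;  // same descriptor -> same cached handle
}
```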
@@ -188,6 +188,8 @@ class XsmmConv2DTest : public OpsTestBase {
 TEST_F(XsmmConv2DTest, Basic) {
   MakeOp(1);

+  // setup scoped allocator, which uses cpu_allocator() for this scope
+  const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator;

   int ifw = 14; /* input width, "W" */
   int ifh = 14; /* input height, "H" */
@@ -223,9 +225,9 @@ TEST_F(XsmmConv2DTest, Basic) {
   //Initialization of Filter and Image

   /* allocate data */
-  float *naive_input = (float*)libxsmm_aligned_malloc( nImg*nIfm*ifhp*ifwp*sizeof(float), 2097152);
-  float *naive_output = (float*)libxsmm_aligned_malloc( nImg*nOfm*ofhp*ofwp*sizeof(float), 2097152);
-  float *naive_filter = (float*)libxsmm_aligned_malloc( nOfm*nIfm*kh*kw* sizeof(float), 2097152);
+  float *naive_input = (float*)libxsmm_aligned_scratch( nImg*nIfm*ifhp*ifwp*sizeof(float), 2097152);
+  float *naive_output = (float*)libxsmm_aligned_scratch( nImg*nOfm*ofhp*ofwp*sizeof(float), 2097152);
+  float *naive_filter = (float*)libxsmm_aligned_scratch( nOfm*nIfm*kh*kw* sizeof(float), 2097152);
   /* initialize data */
   init_buf(naive_input, nImg*nIfm*ifhp*ifwp, 0, 0);
   zero_buf(naive_output, nImg*nOfm*ofhp*ofwp);
@@ -322,12 +324,11 @@ TEST(XsmmConv2DTest, Basic) {
   desc.pad_w_out = 0;
   desc.threads = num_threads;
   desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
-  desc.buffer_format = LIBXSMM_DNN_CONV_FORMAT_NHWC;
-  desc.filter_format = LIBXSMM_DNN_CONV_FORMAT_LIBXSMM;//LIBXSMM_DNN_CONV_FORMAT_RSCK;
+  desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
+  desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;//LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
   desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
   desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
-  desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
-  desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
+  desc.datatype = LIBXSMM_DNN_DATATYPE_F32;

   if (!CanUseXsmmConv2D(desc, data_format)) {
     return false;
@@ -588,6 +588,7 @@ REGISTER_OP_GRADIENT("Mean", MeanGrad);
 // REGISTER_OP_GRADIENT("SegmentMin", SegmentMinGrad);
 // REGISTER_OP_GRADIENT("SegmentMax", SegmentMaxGrad);
 // REGISTER_OP_GRADIENT("UnsortedSegmentSum", UnsortedSegmentSumGrad);
+// REGISTER_OP_GRADIENT("UnsortedSegmentMax", UnsortedSegmentMaxGrad);

 Status MinMaxGradHelper(const string& op, const AttrSlice& attrs,
                         FunctionDef* g) {
@@ -1342,6 +1342,36 @@ Status SparseSegmentReductionGradShapeFn(InferenceContext* c) {
   return Status::OK();
 }

+Status UnsortedSegmentReductionShapeFn(InferenceContext* c) {
+  ShapeHandle s_data = c->input(0);
+  ShapeHandle s_segment_ids = c->input(1);
+  ShapeHandle s_num_segments = c->input(2);
+  TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments));
+
+  ShapeHandle out;
+
+  // Leading dimensions of data must be compatible with dimensions of
+  // <s_segment_ids>.
+  if (c->RankKnown(s_segment_ids)) {
+    TF_RETURN_IF_ERROR(
+        c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids));
+
+    // Get the value of the num_segments input tensor.
+    DimensionHandle num_segments_dim;
+    TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim));
+
+    // Output is {segment_id_rank} + s_data[segment_id_rank:].
+    ShapeHandle s_data_suffix;
+    TF_RETURN_IF_ERROR(
+        c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix));
+    TF_RETURN_IF_ERROR(
+        c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out));
+  } else {
+    out = c->UnknownShape();
+  }
+  c->set_output(0, out);
+  return Status::OK();
+}
 }  // namespace

 REGISTER_OP("SegmentSum")
@@ -1495,36 +1525,7 @@ REGISTER_OP("UnsortedSegmentSum")
     .Output("output: T")
     .Attr("T: numbertype")
     .Attr("Tindices: {int32,int64}")
-    .SetShapeFn([](InferenceContext* c) {
-      ShapeHandle s_data = c->input(0);
-      ShapeHandle s_segment_ids = c->input(1);
-      ShapeHandle s_num_segments = c->input(2);
-      TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments));
-
-      ShapeHandle out;
-
-      // Leading dimensions of data must be compatible with dimensions of
-      // <s_segment_ids>.
-      if (c->RankKnown(s_segment_ids)) {
-        TF_RETURN_IF_ERROR(
-            c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids));
-
-        // Get the value of the num_segments input tensor.
-        DimensionHandle num_segments_dim;
-        TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim));
-
-        // Output is {segment_id_rank} + s_data[segment_id_rank:].
-        ShapeHandle s_data_suffix;
-        TF_RETURN_IF_ERROR(
-            c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix));
-        TF_RETURN_IF_ERROR(
-            c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out));
-      } else {
-        out = c->UnknownShape();
-      }
-      c->set_output(0, out);
-      return Status::OK();
-    })
+    .SetShapeFn(UnsortedSegmentReductionShapeFn)
     .Doc(R"doc(
 Computes the sum along segments of a tensor.

@@ -1554,6 +1555,43 @@ output: Has same shape as data, except for the first `segment_ids.rank`

 )doc");

+REGISTER_OP("UnsortedSegmentMax")
+    .Input("data: T")
+    .Input("segment_ids: Tindices")
+    .Input("num_segments: int32")
+    .Output("output: T")
+    .Attr("T: realnumbertype")
+    .Attr("Tindices: {int32,int64}")
+    .SetShapeFn(UnsortedSegmentReductionShapeFn)
+    .Doc(R"doc(
+Computes the Max along segments of a tensor.
+
+Read [the section on
+Segmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation
+of segments.
+
+This operator is similar to the [unsorted segment sum operator](../../api_docs/python/math_ops.md#UnsortedSegmentSum).
+Instead of computing the sum over segments, it computes the maximum
+such that:
+
+\\(output_i = \max_j data_j\\) where max is over `j` such
+that `segment_ids[j] == i`.
+
+If the maximum is empty for a given segment ID `i`, it outputs the smallest possible value for specific numeric type,
+`output[i] = numeric_limits<T>::min()`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../../images/UnsortedSegmentSum.png" alt>
+</div>
+
+segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
+first dimension.
+
+output: Has same shape as data, except for dimension 0 which
+has size `num_segments`.
+
+)doc");
 REGISTER_OP("SparseSegmentSum")
     .Input("data: T")
     .Input("indices: Tidx")

@@ -25261,6 +25261,59 @@ op {
   summary: "Computes the sum along segments of a tensor."
   description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nComputes a tensor such that\n`(output[i] = sum_{j...} data[j...]` where the sum is over tuples `j...` such\nthat `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids`\nneed not be sorted and need not cover all values in the full\nrange of valid values.\n\nIf the sum is empty for a given segment ID `i`, `output[i] = 0`.\n\n`num_segments` should equal the number of distinct segment IDs.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/UnsortedSegmentSum.png\" alt>\n</div>"
 }
+op {
+  name: "UnsortedSegmentSum"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    description: "A tensor whose shape is a prefix of `data.shape`."
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "num_segments"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    description: "Has same shape as data, except for the first `segment_ids.rank`\ndimensions, which are replaced with a single dimension which has size\n`num_segments`."
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  summary: "Computes the max along segments of a tensor."
+  description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nComputes a tensor such that\n\\\\(output_i = \\sum_j data_j\\\\) where sum is over `j` such\nthat `segment_ids[j] == i`. Unlike `SegmentSum`, `segment_ids`\nneed not be sorted and need not cover all values in the full\n range of valid values.\n\nIf the sum is empty for a given segment ID `i`, `output[i] = 0`.\n\n`num_segments` should equal the number of distinct segment IDs.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/UnsortedSegmentSum.png\" alt>\n</div>"
+}
 op {
   name: "Unstage"
   output_arg {
@@ -77,14 +77,17 @@ void LogMessage::GenerateLogMessage() {

 void LogMessage::GenerateLogMessage() {
   static EnvTime* env_time = tensorflow::EnvTime::Default();
-  time_t now = static_cast<time_t>(env_time->NowSeconds());
+  uint64 now_micros = env_time->NowMicros();
+  time_t now_seconds = static_cast<time_t>(now_micros / 1000000);
+  int32 micros_remainder = static_cast<int32>(now_micros % 1000000);
   const size_t time_buffer_size = 30;
   char time_buffer[time_buffer_size];
-  strftime(time_buffer, time_buffer_size, "%Y-%m-%d %H:%M:%S", localtime(&now));
+  strftime(time_buffer, time_buffer_size, "%Y-%m-%d %H:%M:%S",
+           localtime(&now_seconds));

   // TODO(jeff,sanjay): Replace this with something that logs through the env.
-  fprintf(stderr, "%s: %c %s:%d] %s\n", time_buffer, "IWEF"[severity_], fname_,
-          line_, str().c_str());
+  fprintf(stderr, "%s.%06d: %c %s:%d] %s\n", time_buffer, micros_remainder,
+          "IWEF"[severity_], fname_, line_, str().c_str());
 }
 #endif

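The net effect of the `%s.%06d` format string above is log timestamps with microsecond resolution; an illustrative line (file name, line number, and message are hypothetical):

```
2017-01-18 14:03:07.123456: I some_file.cc:42] example log message
```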
@@ -18,9 +18,9 @@ limitations under the License.

 // TensorFlow uses semantic versioning, see http://semver.org/.

-#define TF_MAJOR_VERSION 0
-#define TF_MINOR_VERSION 12
-#define TF_PATCH_VERSION head
+#define TF_MAJOR_VERSION 1
+#define TF_MINOR_VERSION 0
+#define TF_PATCH_VERSION 0-rc1

 // TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
 // "-beta", "-rc", "-rc.1")
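Once a package is built at this version, the values above surface through the Python API; a quick sanity check (the printed value assumes this release candidate):

```python
import tensorflow as tf

# tf.__version__ is assembled from the TF_MAJOR/TF_MINOR/TF_PATCH_VERSION macros.
print(tf.__version__)  # expected: 1.0.0-rc1
```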
@@ -17,7 +17,7 @@

 Run using bazel:

-  bazel run -c opt \
+  bazel run --config opt \
     <...>/tensorflow/examples/how_tos/reading_data:fully_connected_preloaded

 or, if installed via pip:
@@ -17,7 +17,7 @@

 Run using bazel:

-  bazel run -c opt \
+  bazel run --config opt \
     <...>/tensorflow/examples/how_tos/reading_data:fully_connected_preloaded_var

 or, if installed via pip:
@@ -346,6 +346,17 @@ def read_list_of_floats_from_file(file_path):

 bottleneck_path_2_bottleneck_values = {}

+
+def create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
+                           image_dir, category, sess, jpeg_data_tensor, bottleneck_tensor):
+  print('Creating bottleneck at ' + bottleneck_path)
+  image_path = get_image_path(image_lists, label_name, index, image_dir, category)
+  if not gfile.Exists(image_path):
+    tf.logging.fatal('File does not exist %s', image_path)
+  image_data = gfile.FastGFile(image_path, 'rb').read()
+  bottleneck_values = run_bottleneck_on_image(sess, image_data, jpeg_data_tensor, bottleneck_tensor)
+  bottleneck_string = ','.join(str(x) for x in bottleneck_values)
+  with open(bottleneck_path, 'w') as bottleneck_file:
+    bottleneck_file.write(bottleneck_string)
+

 def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir,
                              category, bottleneck_dir, jpeg_data_tensor,
@@ -376,28 +387,25 @@ def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir,
   sub_dir = label_lists['dir']
   sub_dir_path = os.path.join(bottleneck_dir, sub_dir)
   ensure_dir_exists(sub_dir_path)
-  bottleneck_path = get_bottleneck_path(image_lists, label_name, index,
-                                        bottleneck_dir, category)
+  bottleneck_path = get_bottleneck_path(image_lists, label_name, index, bottleneck_dir, category)
   if not os.path.exists(bottleneck_path):
-    print('Creating bottleneck at ' + bottleneck_path)
-    image_path = get_image_path(image_lists, label_name, index, image_dir,
-                                category)
-    if not gfile.Exists(image_path):
-      tf.logging.fatal('File does not exist %s', image_path)
-    image_data = gfile.FastGFile(image_path, 'rb').read()
-    bottleneck_values = run_bottleneck_on_image(sess, image_data,
-                                                jpeg_data_tensor,
-                                                bottleneck_tensor)
-    bottleneck_string = ','.join(str(x) for x in bottleneck_values)
-    with open(bottleneck_path, 'w') as bottleneck_file:
-      bottleneck_file.write(bottleneck_string)
+    create_bottleneck_file(bottleneck_path, image_lists, label_name, index, image_dir, category, sess, jpeg_data_tensor, bottleneck_tensor)
   with open(bottleneck_path, 'r') as bottleneck_file:
     bottleneck_string = bottleneck_file.read()
-  bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
+  did_hit_error = False
+  try:
+    bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
+  except:
+    print("Invalid float found, recreating bottleneck")
+    did_hit_error = True
+  if did_hit_error:
+    create_bottleneck_file(bottleneck_path, image_lists, label_name, index, image_dir, category, sess, jpeg_data_tensor, bottleneck_tensor)
+    with open(bottleneck_path, 'r') as bottleneck_file:
+      bottleneck_string = bottleneck_file.read()
+    # Allow exceptions to propagate here, since they shouldn't happen after a fresh creation
+    bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
   return bottleneck_values


 def cache_bottlenecks(sess, image_lists, image_dir, bottleneck_dir,
                       jpeg_data_tensor, bottleneck_tensor):
   """Ensures all the training, testing, and validation bottlenecks are cached.
@@ -430,6 +438,7 @@ def cache_bottlenecks(sess, image_lists, image_dir, bottleneck_dir,
       get_or_create_bottleneck(sess, image_lists, label_name, index,
                                image_dir, category, bottleneck_dir,
                                jpeg_data_tensor, bottleneck_tensor)
+
       how_many_bottlenecks += 1
       if how_many_bottlenecks % 100 == 0:
         print(str(how_many_bottlenecks) + ' bottleneck files created.')
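The retrain.py refactor above boils down to a validate-or-recreate cache pattern; a distilled, self-contained sketch (function names here are hypothetical, not part of retrain.py):

```python
def read_float_cache(path, recreate_fn):
  """Read a comma-separated float cache, rebuilding it once on corruption."""
  with open(path, 'r') as f:
    text = f.read()
  try:
    return [float(x) for x in text.split(',')]
  except ValueError:
    recreate_fn(path)  # cache was corrupt; rebuild it exactly once
    with open(path, 'r') as f:
      # A second failure propagates, mirroring the comment in the diff above.
      return [float(x) for x in f.read().split(',')]
```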
@@ -32,7 +32,7 @@ normalized from 0-1 in left top right bottom order.
 To build it, run this command:

 ```bash
-$ bazel build -c opt tensorflow/examples/multibox_detector/...
+$ bazel build --config opt tensorflow/examples/multibox_detector/...
 ```

 That should build a binary executable that you can then run like this:
@@ -111,6 +111,7 @@
     "source": [
       "url = 'http://commondatastorage.googleapis.com/books1000/'\n",
       "last_percent_reported = None\n",
+      "data_root = '.' # Change me to store data elsewhere\n",
       "\n",
       "def download_progress_hook(count, blockSize, totalSize):\n",
       "  \"\"\"A hook to report the progress of a download. This is mostly intended for users with\n",
@@ -131,17 +132,18 @@
       " \n",
       "def maybe_download(filename, expected_bytes, force=False):\n",
       "  \"\"\"Download a file if not present, and make sure it's the right size.\"\"\"\n",
-      "  if force or not os.path.exists(filename):\n",
+      "  dest_filename = os.path.join(data_root, filename)\n",
+      "  if force or not os.path.exists(dest_filename):\n",
       "    print('Attempting to download:', filename) \n",
-      "    filename, _ = urlretrieve(url + filename, filename, reporthook=download_progress_hook)\n",
+      "    filename, _ = urlretrieve(url + filename, dest_filename, reporthook=download_progress_hook)\n",
       "    print('\\nDownload Complete!')\n",
-      "  statinfo = os.stat(filename)\n",
+      "  statinfo = os.stat(dest_filename)\n",
       "  if statinfo.st_size == expected_bytes:\n",
-      "    print('Found and verified', filename)\n",
+      "    print('Found and verified', dest_filename)\n",
       "  else:\n",
       "    raise Exception(\n",
-      "      'Failed to verify ' + filename + '. Can you get to it with a browser?')\n",
+      "      'Failed to verify ' + dest_filename + '. Can you get to it with a browser?')\n",
-      "  return filename\n",
+      "  return dest_filename\n",
       "\n",
       "train_filename = maybe_download('notMNIST_large.tar.gz', 247336696)\n",
       "test_filename = maybe_download('notMNIST_small.tar.gz', 8458043)"
@@ -683,7 +685,7 @@
       "cellView": "both"
     },
     "source": [
-      "pickle_file = 'notMNIST.pickle'\n",
+      "pickle_file = os.path.join(data_root, 'notMNIST.pickle')\n",
       "\n",
       "try:\n",
       "  f = open(pickle_file, 'wb')\n",
@@ -1,4 +1,185 @@
+- - -
+
+#### `tf.summary.TaggedRunMetadata.ByteSize()` {#TaggedRunMetadata.ByteSize}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.Clear()` {#TaggedRunMetadata.Clear}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.ClearExtension(extension_handle)` {#TaggedRunMetadata.ClearExtension}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.ClearField(field_name)` {#TaggedRunMetadata.ClearField}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.CopyFrom(other_msg)` {#TaggedRunMetadata.CopyFrom}
+
+Copies the content of the specified message into the current message.
+
+The method clears the current message and then merges the specified
+message using MergeFrom.
+
+##### Args:
+
+
+* <b>`other_msg`</b>: Message to copy into the current one.
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.DiscardUnknownFields()` {#TaggedRunMetadata.DiscardUnknownFields}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.FindInitializationErrors()` {#TaggedRunMetadata.FindInitializationErrors}
+
+Finds required fields which are not initialized.
+
+##### Returns:
+
+  A list of strings. Each string is a path to an uninitialized field from
+  the top-level message, e.g. "foo.bar[5].baz".
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.FromString(s)` {#TaggedRunMetadata.FromString}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.HasExtension(extension_handle)` {#TaggedRunMetadata.HasExtension}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.HasField(field_name)` {#TaggedRunMetadata.HasField}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.IsInitialized(errors=None)` {#TaggedRunMetadata.IsInitialized}
+
+Checks if all required fields of a message are set.
+
+##### Args:
+
+
+* <b>`errors`</b>: A list which, if provided, will be populated with the field
+    paths of all missing required fields.
+
+##### Returns:
+
+  True iff the specified message has all required fields set.
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.ListFields()` {#TaggedRunMetadata.ListFields}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.MergeFrom(msg)` {#TaggedRunMetadata.MergeFrom}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.MergeFromString(serialized)` {#TaggedRunMetadata.MergeFromString}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.ParseFromString(serialized)` {#TaggedRunMetadata.ParseFromString}
+
+Parse serialized protocol buffer data into this message.
+
+Like MergeFromString(), except we clear the object first and
+do not return the value that MergeFromString returns.
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.RegisterExtension(extension_handle)` {#TaggedRunMetadata.RegisterExtension}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.SerializePartialToString()` {#TaggedRunMetadata.SerializePartialToString}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.SerializeToString()` {#TaggedRunMetadata.SerializeToString}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.SetInParent()` {#TaggedRunMetadata.SetInParent}
+
+Sets the _cached_byte_size_dirty bit to true,
+and propagates this to our listener iff this was a state change.
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.WhichOneof(oneof_name)` {#TaggedRunMetadata.WhichOneof}
+
+Returns the name of the currently set field inside a oneof, or None.
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.__deepcopy__(memo=None)` {#TaggedRunMetadata.__deepcopy__}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.__eq__(other)` {#TaggedRunMetadata.__eq__}
+
+
+
 - - -

 #### `tf.summary.TaggedRunMetadata.__getstate__()` {#TaggedRunMetadata.__getstate__}
@@ -6,3 +187,66 @@
 Support the pickle protocol.
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.__hash__()` {#TaggedRunMetadata.__hash__}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.__init__(**kwargs)` {#TaggedRunMetadata.__init__}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.__ne__(other_msg)` {#TaggedRunMetadata.__ne__}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.__repr__()` {#TaggedRunMetadata.__repr__}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.__setstate__(state)` {#TaggedRunMetadata.__setstate__}
+
+Support the pickle protocol.
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.__str__()` {#TaggedRunMetadata.__str__}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.__unicode__()` {#TaggedRunMetadata.__unicode__}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.run_metadata` {#TaggedRunMetadata.run_metadata}
+
+Magic attribute generated for "run_metadata" proto field.
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.tag` {#TaggedRunMetadata.tag}
+
+Magic attribute generated for "tag" proto field.
+
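A minimal round-trip sketch of the proto wrapper documented above, using only the methods listed (`tag` value is hypothetical):

```python
import tensorflow as tf

trm = tf.summary.TaggedRunMetadata(tag='step_42')  # sets the "tag" proto field
data = trm.SerializeToString()
restored = tf.summary.TaggedRunMetadata.FromString(data)
print(restored.tag)  # step_42
```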
@@ -1,4 +1,185 @@
+- - -
+
+#### `tf.summary.SummaryDescription.ByteSize()` {#SummaryDescription.ByteSize}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.Clear()` {#SummaryDescription.Clear}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.ClearExtension(extension_handle)` {#SummaryDescription.ClearExtension}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.ClearField(field_name)` {#SummaryDescription.ClearField}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.CopyFrom(other_msg)` {#SummaryDescription.CopyFrom}
+
+Copies the content of the specified message into the current message.
+
+The method clears the current message and then merges the specified
+message using MergeFrom.
+
+##### Args:
+
+
+* <b>`other_msg`</b>: Message to copy into the current one.
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.DiscardUnknownFields()` {#SummaryDescription.DiscardUnknownFields}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.FindInitializationErrors()` {#SummaryDescription.FindInitializationErrors}
+
+Finds required fields which are not initialized.
+
+##### Returns:
+
+  A list of strings. Each string is a path to an uninitialized field from
+  the top-level message, e.g. "foo.bar[5].baz".
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.FromString(s)` {#SummaryDescription.FromString}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.HasExtension(extension_handle)` {#SummaryDescription.HasExtension}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.HasField(field_name)` {#SummaryDescription.HasField}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.IsInitialized(errors=None)` {#SummaryDescription.IsInitialized}
+
+Checks if all required fields of a message are set.
+
+##### Args:
+
+
+* <b>`errors`</b>: A list which, if provided, will be populated with the field
+    paths of all missing required fields.
+
+##### Returns:
+
+  True iff the specified message has all required fields set.
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.ListFields()` {#SummaryDescription.ListFields}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.MergeFrom(msg)` {#SummaryDescription.MergeFrom}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.MergeFromString(serialized)` {#SummaryDescription.MergeFromString}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.ParseFromString(serialized)` {#SummaryDescription.ParseFromString}
+
+Parse serialized protocol buffer data into this message.
+
+Like MergeFromString(), except we clear the object first and
+do not return the value that MergeFromString returns.
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.RegisterExtension(extension_handle)` {#SummaryDescription.RegisterExtension}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.SerializePartialToString()` {#SummaryDescription.SerializePartialToString}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.SerializeToString()` {#SummaryDescription.SerializeToString}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.SetInParent()` {#SummaryDescription.SetInParent}
+
+Sets the _cached_byte_size_dirty bit to true,
+and propagates this to our listener iff this was a state change.
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.WhichOneof(oneof_name)` {#SummaryDescription.WhichOneof}
+
+Returns the name of the currently set field inside a oneof, or None.
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.__deepcopy__(memo=None)` {#SummaryDescription.__deepcopy__}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.__eq__(other)` {#SummaryDescription.__eq__}
+
+
+
 - - -

 #### `tf.summary.SummaryDescription.__getstate__()` {#SummaryDescription.__getstate__}
@@ -6,3 +187,59 @@
 Support the pickle protocol.
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.__hash__()` {#SummaryDescription.__hash__}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.__init__(**kwargs)` {#SummaryDescription.__init__}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.__ne__(other_msg)` {#SummaryDescription.__ne__}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.__repr__()` {#SummaryDescription.__repr__}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.__setstate__(state)` {#SummaryDescription.__setstate__}
+
+Support the pickle protocol.
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.__str__()` {#SummaryDescription.__str__}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.__unicode__()` {#SummaryDescription.__unicode__}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.type_hint` {#SummaryDescription.type_hint}
+
+Magic attribute generated for "type_hint" proto field.
+
|
|||||||
* <b>`err`</b>: a float value.
|
* <b>`err`</b>: a float value.
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertBetween(value, minv, maxv, msg=None)` {#TestCase.assertBetween}
|
|
||||||
|
|
||||||
Asserts that value is between minv and maxv (inclusive).
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertCommandFails(command, regexes, env=None, close_fds=True, msg=None)` {#TestCase.assertCommandFails}
|
|
||||||
|
|
||||||
Asserts a shell command fails and the error matches a regex in a list.
|
|
||||||
|
|
||||||
##### Args:
|
|
||||||
|
|
||||||
|
|
||||||
* <b>`command`</b>: List or string representing the command to run.
|
|
||||||
* <b>`regexes`</b>: the list of regular expression strings.
|
|
||||||
* <b>`env`</b>: Dictionary of environment variable settings.
|
|
||||||
* <b>`close_fds`</b>: Whether or not to close all open fd's in the child after
|
|
||||||
forking.
|
|
||||||
* <b>`msg`</b>: Optional message to report on failure.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertCommandSucceeds(command, regexes=('',), env=None, close_fds=True, msg=None)` {#TestCase.assertCommandSucceeds}
|
|
||||||
|
|
||||||
Asserts that a shell command succeeds (i.e. exits with code 0).
|
|
||||||
|
|
||||||
##### Args:
|
|
||||||
|
|
||||||
|
|
||||||
* <b>`command`</b>: List or string representing the command to run.
|
|
||||||
* <b>`regexes`</b>: List of regular expression byte strings that match success.
|
|
||||||
* <b>`env`</b>: Dictionary of environment variable settings.
|
|
||||||
* <b>`close_fds`</b>: Whether or not to close all open fd's in the child after
|
|
||||||
forking.
|
|
||||||
* <b>`msg`</b>: Optional message to report on failure.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertContainsExactSubsequence(container, subsequence, msg=None)` {#TestCase.assertContainsExactSubsequence}
|
|
||||||
|
|
||||||
Assert that "container" contains "subsequence" as an exact subsequence.
|
|
||||||
|
|
||||||
Asserts that "container" contains all the elements of "subsequence", in
|
|
||||||
order, and without other elements interspersed. For example, [1, 2, 3] is an
|
|
||||||
exact subsequence of [0, 0, 1, 2, 3, 0] but not of [0, 0, 1, 2, 0, 3, 0].
|
|
||||||
|
|
||||||
##### Args:
|
|
||||||
|
|
||||||
|
|
||||||
* <b>`container`</b>: the list we're testing for subsequence inclusion.
|
|
||||||
* <b>`subsequence`</b>: the list we hope will be an exact subsequence of container.
|
|
||||||
* <b>`msg`</b>: Optional message to report on failure.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertContainsInOrder(strings, target, msg=None)` {#TestCase.assertContainsInOrder}
|
|
||||||
|
|
||||||
Asserts that the strings provided are found in the target in order.
|
|
||||||
|
|
||||||
This may be useful for checking HTML output.
|
|
||||||
|
|
||||||
##### Args:
|
|
||||||
|
|
||||||
|
|
||||||
* <b>`strings`</b>: A list of strings, such as [ 'fox', 'dog' ]
|
|
||||||
* <b>`target`</b>: A target string in which to look for the strings, such as
|
|
||||||
'The quick brown fox jumped over the lazy dog'.
|
|
||||||
* <b>`msg`</b>: Optional message to report on failure.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertContainsSubsequence(container, subsequence, msg=None)` {#TestCase.assertContainsSubsequence}
|
|
||||||
|
|
||||||
Assert that "container" contains "subsequence" as a subsequence.
|
|
||||||
|
|
||||||
Asserts that "container" contains all the elements of "subsequence", in
|
|
||||||
order, but possibly with other elements interspersed. For example, [1, 2, 3]
|
|
||||||
is a subsequence of [0, 0, 1, 2, 0, 3, 0] but not of [0, 0, 1, 3, 0, 2, 0].
|
|
||||||
|
|
||||||
##### Args:
|
|
||||||
|
|
||||||
|
|
||||||
* <b>`container`</b>: the list we're testing for subsequence inclusion.
|
|
||||||
* <b>`subsequence`</b>: the list we hope will be a subsequence of container.
|
|
||||||
* <b>`msg`</b>: Optional message to report on failure.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertContainsSubset(expected_subset, actual_set, msg=None)` {#TestCase.assertContainsSubset}
|
|
||||||
|
|
||||||
Checks whether actual iterable is a superset of expected iterable.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertCountEqual(*args, **kwargs)` {#TestCase.assertCountEqual}
|
|
||||||
|
|
||||||
An unordered sequence specific comparison.
|
|
||||||
|
|
||||||
Equivalent to assertItemsEqual(). This method is a compatibility layer
|
|
||||||
for Python 3k, since 2to3 does not convert assertItemsEqual() calls into
|
|
||||||
assertCountEqual() calls.
|
|
||||||
|
|
||||||
##### Args:
|
|
||||||
|
|
||||||
|
|
||||||
* <b>`expected_seq`</b>: A sequence containing elements we are expecting.
|
|
||||||
* <b>`actual_seq`</b>: The sequence that we are testing.
|
|
||||||
* <b>`msg`</b>: The message to be printed if the test fails.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
- - -
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertDeviceEqual(device1, device2)` {#TestCase.assertDeviceEqual}
|
#### `tf.test.TestCase.assertDeviceEqual(device1, device2)` {#TestCase.assertDeviceEqual}
|
||||||
@ -314,48 +195,9 @@ Checks whether actual is a superset of expected.
|
|||||||
|
|
||||||
- - -
|
- - -
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertDictEqual(a, b, msg=None)` {#TestCase.assertDictEqual}
|
#### `tf.test.TestCase.assertDictEqual(d1, d2, msg=None)` {#TestCase.assertDictEqual}
|
||||||
|
|
||||||
Raises AssertionError if a and b are not equal dictionaries.
|
|
||||||
|
|
||||||
##### Args:
|
|
||||||
|
|
||||||
|
|
||||||
* <b>`a`</b>: A dict, the expected value.
|
|
||||||
* <b>`b`</b>: A dict, the actual value.
|
|
||||||
* <b>`msg`</b>: An optional str, the associated message.
|
|
||||||
|
|
||||||
##### Raises:
|
|
||||||
|
|
||||||
|
|
||||||
* <b>`AssertionError`</b>: if the dictionaries are not equal.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertEmpty(container, msg=None)` {#TestCase.assertEmpty}
|
|
||||||
|
|
||||||
Assert that an object has zero length.
|
|
||||||
|
|
||||||
##### Args:
|
|
||||||
|
|
||||||
|
|
||||||
* <b>`container`</b>: Anything that implements the collections.Sized interface.
|
|
||||||
* <b>`msg`</b>: Optional message to report on failure.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertEndsWith(actual, expected_end, msg=None)` {#TestCase.assertEndsWith}
|
|
||||||
|
|
||||||
Assert that actual.endswith(expected_end) is True.
|
|
||||||
|
|
||||||
##### Args:
|
|
||||||
|
|
||||||
|
|
||||||
* <b>`actual`</b>: str
|
|
||||||
* <b>`expected_end`</b>: str
|
|
||||||
* <b>`msg`</b>: Optional message to report on failure.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
- - -
|
||||||
@ -440,11 +282,10 @@ Included for symmetry with assertIsNone.
|
|||||||
|
|
||||||
- - -
|
- - -
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertItemsEqual(*args, **kwargs)` {#TestCase.assertItemsEqual}
|
#### `tf.test.TestCase.assertItemsEqual(expected_seq, actual_seq, msg=None)` {#TestCase.assertItemsEqual}
|
||||||
|
|
||||||
An unordered sequence specific comparison.
|
An unordered sequence specific comparison. It asserts that
|
||||||
|
actual_seq and expected_seq have the same element counts.
|
||||||
It asserts that actual_seq and expected_seq have the same element counts.
|
|
||||||
Equivalent to::
|
Equivalent to::
|
||||||
|
|
||||||
self.assertEqual(Counter(iter(actual_seq)),
|
self.assertEqual(Counter(iter(actual_seq)),
|
||||||
@ -457,30 +298,6 @@ Asserts that each element has the same count in both sequences.
|
|||||||
- [0, 1, 1] and [1, 0, 1] compare equal.
|
- [0, 1, 1] and [1, 0, 1] compare equal.
|
||||||
- [0, 0, 1] and [0, 1] compare unequal.
|
- [0, 0, 1] and [0, 1] compare unequal.
|
||||||
|
|
||||||
##### Args:
|
|
||||||
|
|
||||||
|
|
||||||
* <b>`expected_seq`</b>: A sequence containing elements we are expecting.
|
|
||||||
* <b>`actual_seq`</b>: The sequence that we are testing.
|
|
||||||
* <b>`msg`</b>: The message to be printed if the test fails.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertJsonEqual(first, second, msg=None)` {#TestCase.assertJsonEqual}
|
|
||||||
|
|
||||||
Asserts that the JSON objects defined in two strings are equal.
|
|
||||||
|
|
||||||
A summary of the differences will be included in the failure message
|
|
||||||
using assertSameStructure.
|
|
||||||
|
|
||||||
##### Args:
|
|
||||||
|
|
||||||
|
|
||||||
* <b>`first`</b>: A string contining JSON to decode and compare to second.
|
|
||||||
* <b>`second`</b>: A string contining JSON to decode and compare to first.
|
|
||||||
* <b>`msg`</b>: Additional text to include in the failure message.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
- - -
|
||||||
|
|
||||||
@ -550,13 +367,6 @@ if not.
|
|||||||
* <b>`msg`</b>: An optional string message to append to the failure message.
|
* <b>`msg`</b>: An optional string message to append to the failure message.
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertNoCommonElements(expected_seq, actual_seq, msg=None)` {#TestCase.assertNoCommonElements}
|
|
||||||
|
|
||||||
Checks whether actual iterable and expected iterable are disjoint.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
- - -
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertNotAlmostEqual(first, second, places=None, msg=None, delta=None)` {#TestCase.assertNotAlmostEqual}
|
#### `tf.test.TestCase.assertNotAlmostEqual(first, second, places=None, msg=None, delta=None)` {#TestCase.assertNotAlmostEqual}
|
||||||
@ -587,33 +397,6 @@ as significant digits (measured from the most signficant digit).
|
|||||||
Objects that are equal automatically fail.
|
Objects that are equal automatically fail.
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertNotEmpty(container, msg=None)` {#TestCase.assertNotEmpty}
|
|
||||||
|
|
||||||
Assert that an object has non-zero length.
|
|
||||||
|
|
||||||
##### Args:
|
|
||||||
|
|
||||||
|
|
||||||
* <b>`container`</b>: Anything that implements the collections.Sized interface.
|
|
||||||
* <b>`msg`</b>: Optional message to report on failure.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertNotEndsWith(actual, unexpected_end, msg=None)` {#TestCase.assertNotEndsWith}
|
|
||||||
|
|
||||||
Assert that actual.endswith(unexpected_end) is False.
|
|
||||||
|
|
||||||
##### Args:
|
|
||||||
|
|
||||||
|
|
||||||
* <b>`actual`</b>: str
|
|
||||||
* <b>`unexpected_end`</b>: str
|
|
||||||
* <b>`msg`</b>: Optional message to report on failure.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
- - -
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertNotEqual(first, second, msg=None)` {#TestCase.assertNotEqual}
|
#### `tf.test.TestCase.assertNotEqual(first, second, msg=None)` {#TestCase.assertNotEqual}
|
||||||
@ -651,20 +434,6 @@ Included for symmetry with assertIsInstance.
|
|||||||
Fail the test if the text matches the regular expression.
|
Fail the test if the text matches the regular expression.
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertNotStartsWith(actual, unexpected_start, msg=None)` {#TestCase.assertNotStartsWith}
|
|
||||||
|
|
||||||
Assert that actual.startswith(unexpected_start) is False.
|
|
||||||
|
|
||||||
##### Args:
|
|
||||||
|
|
||||||
|
|
||||||
* <b>`actual`</b>: str
|
|
||||||
* <b>`unexpected_start`</b>: str
|
|
||||||
* <b>`msg`</b>: Optional message to report on failure.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
- - -
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertProtoEquals(expected_message_maybe_ascii, message)` {#TestCase.assertProtoEquals}
|
#### `tf.test.TestCase.assertProtoEquals(expected_message_maybe_ascii, message)` {#TestCase.assertProtoEquals}
|
||||||
@ -739,38 +508,6 @@ Asserts that the message in a raised exception matches a regexp.
|
|||||||
* <b>`kwargs`</b>: Extra kwargs.
|
* <b>`kwargs`</b>: Extra kwargs.
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertRaisesWithLiteralMatch(expected_exception, expected_exception_message, callable_obj=None, *args, **kwargs)` {#TestCase.assertRaisesWithLiteralMatch}
|
|
||||||
|
|
||||||
Asserts that the message in a raised exception equals the given string.
|
|
||||||
|
|
||||||
Unlike assertRaisesRegexp, this method takes a literal string, not
|
|
||||||
a regular expression.
|
|
||||||
|
|
||||||
with self.assertRaisesWithLiteralMatch(ExType, 'message'):
|
|
||||||
DoSomething()
|
|
||||||
|
|
||||||
##### Args:
|
|
||||||
|
|
||||||
|
|
||||||
* <b>`expected_exception`</b>: Exception class expected to be raised.
|
|
||||||
* <b>`expected_exception_message`</b>: String message expected in the raised
|
|
||||||
exception. For a raise exception e, expected_exception_message must
|
|
||||||
equal str(e).
|
|
||||||
* <b>`callable_obj`</b>: Function to be called, or None to return a context.
|
|
||||||
* <b>`args`</b>: Extra args.
|
|
||||||
* <b>`kwargs`</b>: Extra kwargs.
|
|
||||||
|
|
||||||
##### Returns:
|
|
||||||
|
|
||||||
A context manager if callable_obj is None. Otherwise, None.
|
|
||||||
|
|
||||||
##### Raises:
|
|
||||||
|
|
||||||
self.failureException if callable_obj does not raise a macthing exception.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
- - -
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertRaisesWithPredicateMatch(exception_type, expected_err_re_or_predicate)` {#TestCase.assertRaisesWithPredicateMatch}
|
#### `tf.test.TestCase.assertRaisesWithPredicateMatch(exception_type, expected_err_re_or_predicate)` {#TestCase.assertRaisesWithPredicateMatch}
|
||||||
@ -795,71 +532,6 @@ predicate search.
|
|||||||
exception.
|
exception.
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertRaisesWithRegexpMatch(expected_exception, expected_regexp, callable_obj=None, *args, **kwargs)` {#TestCase.assertRaisesWithRegexpMatch}
|
|
||||||
|
|
||||||
Asserts that the message in a raised exception matches the given regexp.
|
|
||||||
|
|
||||||
This is just a wrapper around assertRaisesRegexp. Please use
|
|
||||||
assertRaisesRegexp instead of assertRaisesWithRegexpMatch.
|
|
||||||
|
|
||||||
##### Args:
|
|
||||||
|
|
||||||
|
|
||||||
* <b>`expected_exception`</b>: Exception class expected to be raised.
|
|
||||||
* <b>`expected_regexp`</b>: Regexp (re pattern object or string) expected to be
|
|
||||||
found in error message.
|
|
||||||
* <b>`callable_obj`</b>: Function to be called, or None to return a context.
|
|
||||||
* <b>`args`</b>: Extra args.
|
|
||||||
* <b>`kwargs`</b>: Extra keyword args.
|
|
||||||
|
|
||||||
##### Returns:
|
|
||||||
|
|
||||||
A context manager if callable_obj is None. Otherwise, None.
|
|
||||||
|
|
||||||
##### Raises:
|
|
||||||
|
|
||||||
self.failureException if callable_obj does not raise a macthing exception.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertRegexMatch(actual_str, regexes, message=None)` {#TestCase.assertRegexMatch}
|
|
||||||
|
|
||||||
Asserts that at least one regex in regexes matches str.
|
|
||||||
|
|
||||||
If possible you should use assertRegexpMatches, which is a simpler
|
|
||||||
version of this method. assertRegexpMatches takes a single regular
|
|
||||||
expression (a string or re compiled object) instead of a list.
|
|
||||||
|
|
||||||
Notes:
|
|
||||||
1. This function uses substring matching, i.e. the matching
|
|
||||||
succeeds if *any* substring of the error message matches *any*
|
|
||||||
regex in the list. This is more convenient for the user than
|
|
||||||
full-string matching.
|
|
||||||
|
|
||||||
2. If regexes is the empty list, the matching will always fail.
|
|
||||||
|
|
||||||
3. Use regexes=[''] for a regex that will always pass.
|
|
||||||
|
|
||||||
4. '.' matches any single character *except* the newline. To
|
|
||||||
match any character, use '(.|
|
|
||||||
)'.
|
|
||||||
|
|
||||||
5. '^' matches the beginning of each line, not just the beginning
|
|
||||||
of the string. Similarly, '$' matches the end of each line.
|
|
||||||
|
|
||||||
6. An exception will be thrown if regexes contains an invalid
|
|
||||||
regex.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
actual_str: The string we try to match with the items in regexes.
|
|
||||||
regexes: The regular expressions we want to match against str.
|
|
||||||
See "Notes" above for detailed notes on how this is interpreted.
|
|
||||||
message: The message to be printed if the test fails.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
- - -
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertRegexpMatches(text, expected_regexp, msg=None)` {#TestCase.assertRegexpMatches}
|
#### `tf.test.TestCase.assertRegexpMatches(text, expected_regexp, msg=None)` {#TestCase.assertRegexpMatches}
|
||||||
@ -867,79 +539,6 @@ Asserts that at least one regex in regexes matches str.
|
|||||||
Fail the test unless the text matches the regular expression.
|
Fail the test unless the text matches the regular expression.
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertSameElements(expected_seq, actual_seq, msg=None)` {#TestCase.assertSameElements}
|
|
||||||
|
|
||||||
Assert that two sequences have the same elements (in any order).
|
|
||||||
|
|
||||||
This method, unlike assertItemsEqual, doesn't care about any
|
|
||||||
duplicates in the expected and actual sequences.
|
|
||||||
|
|
||||||
>> assertSameElements([1, 1, 1, 0, 0, 0], [0, 1])
|
|
||||||
# Doesn't raise an AssertionError
|
|
||||||
|
|
||||||
If possible, you should use assertItemsEqual instead of
|
|
||||||
assertSameElements.
|
|
||||||
|
|
||||||
##### Args:
|
|
||||||
|
|
||||||
|
|
||||||
* <b>`expected_seq`</b>: A sequence containing elements we are expecting.
|
|
||||||
* <b>`actual_seq`</b>: The sequence that we are testing.
|
|
||||||
* <b>`msg`</b>: The message to be printed if the test fails.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertSameStructure(a, b, aname='a', bname='b', msg=None)` {#TestCase.assertSameStructure}
|
|
||||||
|
|
||||||
Asserts that two values contain the same structural content.
|
|
||||||
|
|
||||||
The two arguments should be data trees consisting of trees of dicts and
|
|
||||||
lists. They will be deeply compared by walking into the contents of dicts
|
|
||||||
and lists; other items will be compared using the == operator.
|
|
||||||
If the two structures differ in content, the failure message will indicate
|
|
||||||
the location within the structures where the first difference is found.
|
|
||||||
This may be helpful when comparing large structures.
|
|
||||||
|
|
||||||
##### Args:
|
|
||||||
|
|
||||||
|
|
||||||
* <b>`a`</b>: The first structure to compare.
|
|
||||||
* <b>`b`</b>: The second structure to compare.
|
|
||||||
* <b>`aname`</b>: Variable name to use for the first structure in assertion messages.
|
|
||||||
* <b>`bname`</b>: Variable name to use for the second structure.
|
|
||||||
* <b>`msg`</b>: Additional text to include in the failure message.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertSequenceAlmostEqual(expected_seq, actual_seq, places=None, msg=None, delta=None)` {#TestCase.assertSequenceAlmostEqual}
|
|
||||||
|
|
||||||
An approximate equality assertion for ordered sequences.
|
|
||||||
|
|
||||||
Fail if the two sequences are unequal as determined by their value
|
|
||||||
differences rounded to the given number of decimal places (default 7) and
|
|
||||||
comparing to zero, or by comparing that the difference between each value
|
|
||||||
in the two sequences is more than the given delta.
|
|
||||||
|
|
||||||
Note that decimal places (from zero) are usually not the same as significant
|
|
||||||
digits (measured from the most signficant digit).
|
|
||||||
|
|
||||||
If the two sequences compare equal then they will automatically compare
|
|
||||||
almost equal.
|
|
||||||
|
|
||||||
##### Args:
|
|
||||||
|
|
||||||
|
|
||||||
* <b>`expected_seq`</b>: A sequence containing elements we are expecting.
|
|
||||||
* <b>`actual_seq`</b>: The sequence that we are testing.
|
|
||||||
* <b>`places`</b>: The number of decimal places to compare.
|
|
||||||
* <b>`msg`</b>: The message to be printed if the test fails.
|
|
||||||
* <b>`delta`</b>: The OK difference between compared values.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
- - -
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertSequenceEqual(seq1, seq2, msg=None, seq_type=None)` {#TestCase.assertSequenceEqual}
|
#### `tf.test.TestCase.assertSequenceEqual(seq1, seq2, msg=None, seq_type=None)` {#TestCase.assertSequenceEqual}
|
||||||
@ -960,26 +559,6 @@ which can be indexed, has a length, and has an equality operator.
|
|||||||
differences.
|
differences.
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertSequenceStartsWith(prefix, whole, msg=None)` {#TestCase.assertSequenceStartsWith}
|
|
||||||
|
|
||||||
An equality assertion for the beginning of ordered sequences.
|
|
||||||
|
|
||||||
If prefix is an empty sequence, it will raise an error unless whole is also
|
|
||||||
an empty sequence.
|
|
||||||
|
|
||||||
If prefix is not a sequence, it will raise an error if the first element of
|
|
||||||
whole does not match.
|
|
||||||
|
|
||||||
##### Args:
|
|
||||||
|
|
||||||
|
|
||||||
* <b>`prefix`</b>: A sequence expected at the beginning of the whole parameter.
|
|
||||||
* <b>`whole`</b>: The sequence in which to look for prefix.
|
|
||||||
* <b>`msg`</b>: Optional message to report on failure.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
- - -
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertSetEqual(set1, set2, msg=None)` {#TestCase.assertSetEqual}
|
#### `tf.test.TestCase.assertSetEqual(set1, set2, msg=None)` {#TestCase.assertSetEqual}
|
||||||
@ -1031,51 +610,6 @@ Assert that actual.startswith(expected_start) is True.
|
|||||||
* <b>`msg`</b>: Optional message to report on failure.
|
* <b>`msg`</b>: Optional message to report on failure.
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertTotallyOrdered(*groups, **kwargs)` {#TestCase.assertTotallyOrdered}
|
|
||||||
|
|
||||||
Asserts that total ordering has been implemented correctly.
|
|
||||||
|
|
||||||
For example, say you have a class A that compares only on its attribute x.
|
|
||||||
Comparators other than __lt__ are omitted for brevity.
|
|
||||||
|
|
||||||
class A(object):
|
|
||||||
def __init__(self, x, y):
|
|
||||||
self.x = x
|
|
||||||
self.y = y
|
|
||||||
|
|
||||||
def __hash__(self):
|
|
||||||
return hash(self.x)
|
|
||||||
|
|
||||||
def __lt__(self, other):
|
|
||||||
try:
|
|
||||||
return self.x < other.x
|
|
||||||
except AttributeError:
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
assertTotallyOrdered will check that instances can be ordered correctly.
|
|
||||||
For example,
|
|
||||||
|
|
||||||
self.assertTotallyOrdered(
|
|
||||||
[None], # None should come before everything else.
|
|
||||||
[1], # Integers sort earlier.
|
|
||||||
[A(1, 'a')],
|
|
||||||
[A(2, 'b')], # 2 is after 1.
|
|
||||||
[A(3, 'c'), A(3, 'd')], # The second argument is irrelevant.
|
|
||||||
[A(4, 'z')],
|
|
||||||
['foo']) # Strings sort last.
|
|
||||||
|
|
||||||
##### Args:
|
|
||||||
|
|
||||||
|
|
||||||
* <b>`*groups`</b>: A list of groups of elements. Each group of elements is a list
|
|
||||||
of objects that are equal. The elements in each group must be less than
|
|
||||||
the elements in the group after it. For example, these groups are
|
|
||||||
totally ordered: [None], [1], [2, 2], [3].
|
|
||||||
* <b>`**kwargs`</b>: optional msg keyword argument can be passed.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
- - -
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertTrue(expr, msg=None)` {#TestCase.assertTrue}
|
#### `tf.test.TestCase.assertTrue(expr, msg=None)` {#TestCase.assertTrue}
|
||||||
@ -1098,13 +632,6 @@ A tuple-specific equality assertion.
|
|||||||
differences.
|
differences.
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.assertUrlEqual(a, b, msg=None)` {#TestCase.assertUrlEqual}
|
|
||||||
|
|
||||||
Asserts that urls are equal, ignoring ordering of query params.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
- - -
|
||||||
|
|
||||||
#### `tf.test.TestCase.assert_(expr, msg=None)` {#TestCase.assert_}
|
#### `tf.test.TestCase.assert_(expr, msg=None)` {#TestCase.assert_}
|
||||||
@ -1166,9 +693,9 @@ tearDown.
|
|||||||
|
|
||||||
- - -
|
- - -
|
||||||
|
|
||||||
#### `tf.test.TestCase.fail(msg=None, prefix=None)` {#TestCase.fail}
|
#### `tf.test.TestCase.fail(msg=None)` {#TestCase.fail}
|
||||||
|
|
||||||
Fail immediately with the given message, optionally prefixed.
|
Fail immediately, with the given message.
|
||||||
|
|
||||||
|
|
||||||
- - -
|
- - -
|
||||||
@ -1220,13 +747,6 @@ Fail immediately with the given message, optionally prefixed.
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.getRecordedProperties()` {#TestCase.getRecordedProperties}
|
|
||||||
|
|
||||||
Return any properties that the user has recorded.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
- - -
|
||||||
|
|
||||||
#### `tf.test.TestCase.get_temp_dir()` {#TestCase.get_temp_dir}
|
#### `tf.test.TestCase.get_temp_dir()` {#TestCase.get_temp_dir}
|
||||||
@ -1249,20 +769,6 @@ pollute each others environment.
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
#### `tf.test.TestCase.recordProperty(property_name, property_value)` {#TestCase.recordProperty}
|
|
||||||
|
|
||||||
Record an arbitrary property for later use.
|
|
||||||
|
|
||||||
##### Args:
|
|
||||||
|
|
||||||
|
|
||||||
* <b>`property_name`</b>: str, name of property to record; must be a valid XML
|
|
||||||
attribute name
|
|
||||||
* <b>`property_value`</b>: value of property; must be valid XML attribute value
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
- - -
|
||||||
|
|
||||||
#### `tf.test.TestCase.run(result=None)` {#TestCase.run}
|
#### `tf.test.TestCase.run(result=None)` {#TestCase.run}
|
||||||
@ -1288,18 +794,11 @@ Hook method for setting up class fixture before running tests in the class.
|
|||||||
|
|
||||||
#### `tf.test.TestCase.shortDescription()` {#TestCase.shortDescription}
|
#### `tf.test.TestCase.shortDescription()` {#TestCase.shortDescription}
|
||||||
|
|
||||||
Format both the test method name and the first line of its docstring.
|
Returns a one-line description of the test, or None if no
|
||||||
|
description has been provided.
|
||||||
|
|
||||||
If no docstring is given, only returns the method name.
|
The default implementation of this method returns the first line of
|
||||||
|
the specified test method's docstring.
|
||||||
This method overrides unittest.TestCase.shortDescription(), which
|
|
||||||
only returns the first line of the docstring, obscuring the name
|
|
||||||
of the test upon failure.
|
|
||||||
|
|
||||||
##### Returns:
|
|
||||||
|
|
||||||
|
|
||||||
* <b>`desc`</b>: A short description of a test method.
|
|
||||||
|
|
||||||
|
|
||||||
- - -
|
- - -
|
||||||
|
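A minimal sketch exercising two of the assertions that remain documented after this cleanup (class and test names are hypothetical; `assertItemsEqual` is the Python 2 spelling used by these TF 1.0 docs):

```python
import tensorflow as tf

class SmallTest(tf.test.TestCase):

  def testDictAndItems(self):
    self.assertDictEqual({'a': 1}, {'a': 1})
    # Order-insensitive, but element counts must match.
    self.assertItemsEqual([0, 1, 1], [1, 0, 1])

if __name__ == '__main__':
  tf.test.main()
```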
@@ -0,0 +1,4 @@
+#### `tf.summary.SummaryDescription.RegisterExtension(extension_handle)` {#SummaryDescription.RegisterExtension}
+
+
+
@@ -0,0 +1,4 @@
+#### `tf.summary.SummaryDescription.FromString(s)` {#SummaryDescription.FromString}
+
+
+
@@ -0,0 +1,4 @@
+#### `tf.summary.TaggedRunMetadata.RegisterExtension(extension_handle)` {#TaggedRunMetadata.RegisterExtension}
+
+
+
@@ -0,0 +1,4 @@
+#### `tf.summary.TaggedRunMetadata.FromString(s)` {#TaggedRunMetadata.FromString}
+
+
+
@@ -485,6 +485,187 @@ metadata is stored in its NodeDef. This method retrieves the description.
 ### `class tf.summary.SummaryDescription` {#SummaryDescription}
 
 
+- - -
+
+#### `tf.summary.SummaryDescription.ByteSize()` {#SummaryDescription.ByteSize}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.Clear()` {#SummaryDescription.Clear}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.ClearExtension(extension_handle)` {#SummaryDescription.ClearExtension}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.ClearField(field_name)` {#SummaryDescription.ClearField}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.CopyFrom(other_msg)` {#SummaryDescription.CopyFrom}
+
+Copies the content of the specified message into the current message.
+
+The method clears the current message and then merges the specified
+message using MergeFrom.
+
+##### Args:
+
+
+*  <b>`other_msg`</b>: Message to copy into the current one.
+
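A minimal sketch of the clear-then-merge semantics (ours, with illustrative field values):

```python
import tensorflow as tf

src = tf.summary.SummaryDescription(type_hint='histogram')
dst = tf.summary.SummaryDescription(type_hint='scalar')

# CopyFrom first clears dst, then merges src, so dst becomes an exact copy.
dst.CopyFrom(src)
assert dst.type_hint == 'histogram'
# MergeFrom, by contrast, overlays fields without clearing first.
```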
+- - -
+
+#### `tf.summary.SummaryDescription.DiscardUnknownFields()` {#SummaryDescription.DiscardUnknownFields}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.FindInitializationErrors()` {#SummaryDescription.FindInitializationErrors}
+
+Finds required fields which are not initialized.
+
+##### Returns:
+
+  A list of strings.  Each string is a path to an uninitialized field from
+  the top-level message, e.g. "foo.bar[5].baz".
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.FromString(s)` {#SummaryDescription.FromString}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.HasExtension(extension_handle)` {#SummaryDescription.HasExtension}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.HasField(field_name)` {#SummaryDescription.HasField}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.IsInitialized(errors=None)` {#SummaryDescription.IsInitialized}
+
+Checks if all required fields of a message are set.
+
+##### Args:
+
+
+*  <b>`errors`</b>: A list which, if provided, will be populated with the field
+           paths of all missing required fields.
+
+##### Returns:
+
+  True iff the specified message has all required fields set.
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.ListFields()` {#SummaryDescription.ListFields}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.MergeFrom(msg)` {#SummaryDescription.MergeFrom}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.MergeFromString(serialized)` {#SummaryDescription.MergeFromString}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.ParseFromString(serialized)` {#SummaryDescription.ParseFromString}
+
+Parse serialized protocol buffer data into this message.
+
+Like MergeFromString(), except we clear the object first and
+do not return the value that MergeFromString returns.
+
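A minimal sketch contrasting `ParseFromString` with `MergeFromString` (ours; the byte-count return value is standard protobuf behaviour):

```python
import tensorflow as tf

raw = tf.summary.SummaryDescription(type_hint='image').SerializeToString()

msg = tf.summary.SummaryDescription(type_hint='scalar')
msg.ParseFromString(raw)      # clears msg first; returns None
assert msg.type_hint == 'image'

n = msg.MergeFromString(raw)  # merges and returns the number of bytes read
assert n == len(raw)
```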
+- - -
+
+#### `tf.summary.SummaryDescription.RegisterExtension(extension_handle)` {#SummaryDescription.RegisterExtension}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.SerializePartialToString()` {#SummaryDescription.SerializePartialToString}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.SerializeToString()` {#SummaryDescription.SerializeToString}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.SetInParent()` {#SummaryDescription.SetInParent}
+
+Sets the _cached_byte_size_dirty bit to true,
+and propagates this to our listener iff this was a state change.
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.WhichOneof(oneof_name)` {#SummaryDescription.WhichOneof}
+
+Returns the name of the currently set field inside a oneof, or None.
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.__deepcopy__(memo=None)` {#SummaryDescription.__deepcopy__}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.__eq__(other)` {#SummaryDescription.__eq__}
+
+
+
 - - -
 
 #### `tf.summary.SummaryDescription.__getstate__()` {#SummaryDescription.__getstate__}

@@ -492,12 +673,249 @@ metadata is stored in its NodeDef. This method retrieves the description.
 Support the pickle protocol.
 
 
+- - -
+
+#### `tf.summary.SummaryDescription.__hash__()` {#SummaryDescription.__hash__}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.__init__(**kwargs)` {#SummaryDescription.__init__}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.__ne__(other_msg)` {#SummaryDescription.__ne__}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.__repr__()` {#SummaryDescription.__repr__}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.__setstate__(state)` {#SummaryDescription.__setstate__}
+
+Support the pickle protocol.
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.__str__()` {#SummaryDescription.__str__}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.__unicode__()` {#SummaryDescription.__unicode__}
+
+
+
+- - -
+
+#### `tf.summary.SummaryDescription.type_hint` {#SummaryDescription.type_hint}
+
+Magic attribute generated for "type_hint" proto field.
+
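`type_hint` is the message's single declared field. A minimal sketch of setting it directly (ours; the value is illustrative):

```python
import tensorflow as tf

# SummaryDescription carries a free-form hint about how a tensor summary
# should be displayed; 'scalar' here is an illustrative value.
desc = tf.summary.SummaryDescription()
desc.type_hint = 'scalar'  # generated proto field attribute
serialized = desc.SerializeToString()
assert tf.summary.SummaryDescription.FromString(serialized) == desc
```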
 - - -
 
 ### `class tf.summary.TaggedRunMetadata` {#TaggedRunMetadata}
 
 
+- - -
+
+#### `tf.summary.TaggedRunMetadata.ByteSize()` {#TaggedRunMetadata.ByteSize}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.Clear()` {#TaggedRunMetadata.Clear}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.ClearExtension(extension_handle)` {#TaggedRunMetadata.ClearExtension}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.ClearField(field_name)` {#TaggedRunMetadata.ClearField}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.CopyFrom(other_msg)` {#TaggedRunMetadata.CopyFrom}
+
+Copies the content of the specified message into the current message.
+
+The method clears the current message and then merges the specified
+message using MergeFrom.
+
+##### Args:
+
+
+*  <b>`other_msg`</b>: Message to copy into the current one.
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.DiscardUnknownFields()` {#TaggedRunMetadata.DiscardUnknownFields}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.FindInitializationErrors()` {#TaggedRunMetadata.FindInitializationErrors}
+
+Finds required fields which are not initialized.
+
+##### Returns:
+
+  A list of strings.  Each string is a path to an uninitialized field from
+  the top-level message, e.g. "foo.bar[5].baz".
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.FromString(s)` {#TaggedRunMetadata.FromString}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.HasExtension(extension_handle)` {#TaggedRunMetadata.HasExtension}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.HasField(field_name)` {#TaggedRunMetadata.HasField}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.IsInitialized(errors=None)` {#TaggedRunMetadata.IsInitialized}
+
+Checks if all required fields of a message are set.
+
+##### Args:
+
+
+*  <b>`errors`</b>: A list which, if provided, will be populated with the field
+           paths of all missing required fields.
+
+##### Returns:
+
+  True iff the specified message has all required fields set.
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.ListFields()` {#TaggedRunMetadata.ListFields}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.MergeFrom(msg)` {#TaggedRunMetadata.MergeFrom}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.MergeFromString(serialized)` {#TaggedRunMetadata.MergeFromString}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.ParseFromString(serialized)` {#TaggedRunMetadata.ParseFromString}
+
+Parse serialized protocol buffer data into this message.
+
+Like MergeFromString(), except we clear the object first and
+do not return the value that MergeFromString returns.
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.RegisterExtension(extension_handle)` {#TaggedRunMetadata.RegisterExtension}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.SerializePartialToString()` {#TaggedRunMetadata.SerializePartialToString}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.SerializeToString()` {#TaggedRunMetadata.SerializeToString}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.SetInParent()` {#TaggedRunMetadata.SetInParent}
+
+Sets the _cached_byte_size_dirty bit to true,
+and propagates this to our listener iff this was a state change.
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.WhichOneof(oneof_name)` {#TaggedRunMetadata.WhichOneof}
+
+Returns the name of the currently set field inside a oneof, or None.
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.__deepcopy__(memo=None)` {#TaggedRunMetadata.__deepcopy__}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.__eq__(other)` {#TaggedRunMetadata.__eq__}
+
+
+
 - - -
 
 #### `tf.summary.TaggedRunMetadata.__getstate__()` {#TaggedRunMetadata.__getstate__}

@@ -505,4 +923,67 @@ Support the pickle protocol.
 Support the pickle protocol.
 
 
+- - -
+
+#### `tf.summary.TaggedRunMetadata.__hash__()` {#TaggedRunMetadata.__hash__}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.__init__(**kwargs)` {#TaggedRunMetadata.__init__}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.__ne__(other_msg)` {#TaggedRunMetadata.__ne__}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.__repr__()` {#TaggedRunMetadata.__repr__}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.__setstate__(state)` {#TaggedRunMetadata.__setstate__}
+
+Support the pickle protocol.
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.__str__()` {#TaggedRunMetadata.__str__}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.__unicode__()` {#TaggedRunMetadata.__unicode__}
+
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.run_metadata` {#TaggedRunMetadata.run_metadata}
+
+Magic attribute generated for "run_metadata" proto field.
+
+
+- - -
+
+#### `tf.summary.TaggedRunMetadata.tag` {#TaggedRunMetadata.tag}
+
+Magic attribute generated for "tag" proto field.
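Putting the two fields together, a sketch of wrapping profiling data from a run (ours; names are illustrative):

```python
import tensorflow as tf

# TaggedRunMetadata pairs a tag with a *serialized* RunMetadata proto.
run_md = tf.RunMetadata()  # normally filled by session.run(..., run_metadata=run_md)
tagged = tf.summary.TaggedRunMetadata(
    tag='step_100',                           # illustrative tag
    run_metadata=run_md.SerializeToString())  # bytes field
print(tagged.tag)
```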
@@ -213,125 +213,6 @@ Checks that for all elements of farray1 and farray2
 *  <b>`err`</b>: a float value.
 
 
-- - -
-
-#### `tf.test.TestCase.assertBetween(value, minv, maxv, msg=None)` {#TestCase.assertBetween}
-
-Asserts that value is between minv and maxv (inclusive).
-
-
-- - -
-
-#### `tf.test.TestCase.assertCommandFails(command, regexes, env=None, close_fds=True, msg=None)` {#TestCase.assertCommandFails}
-
-Asserts a shell command fails and the error matches a regex in a list.
-
-##### Args:
-
-
-*  <b>`command`</b>: List or string representing the command to run.
-*  <b>`regexes`</b>: the list of regular expression strings.
-*  <b>`env`</b>: Dictionary of environment variable settings.
-*  <b>`close_fds`</b>: Whether or not to close all open fd's in the child after
-    forking.
-*  <b>`msg`</b>: Optional message to report on failure.
-
-
-- - -
-
-#### `tf.test.TestCase.assertCommandSucceeds(command, regexes=('',), env=None, close_fds=True, msg=None)` {#TestCase.assertCommandSucceeds}
-
-Asserts that a shell command succeeds (i.e. exits with code 0).
-
-##### Args:
-
-
-*  <b>`command`</b>: List or string representing the command to run.
-*  <b>`regexes`</b>: List of regular expression byte strings that match success.
-*  <b>`env`</b>: Dictionary of environment variable settings.
-*  <b>`close_fds`</b>: Whether or not to close all open fd's in the child after
-    forking.
-*  <b>`msg`</b>: Optional message to report on failure.
-
-
-- - -
-
-#### `tf.test.TestCase.assertContainsExactSubsequence(container, subsequence, msg=None)` {#TestCase.assertContainsExactSubsequence}
-
-Assert that "container" contains "subsequence" as an exact subsequence.
-
-Asserts that "container" contains all the elements of "subsequence", in
-order, and without other elements interspersed. For example, [1, 2, 3] is an
-exact subsequence of [0, 0, 1, 2, 3, 0] but not of [0, 0, 1, 2, 0, 3, 0].
-
-##### Args:
-
-
-*  <b>`container`</b>: the list we're testing for subsequence inclusion.
-*  <b>`subsequence`</b>: the list we hope will be an exact subsequence of container.
-*  <b>`msg`</b>: Optional message to report on failure.
-
-
-- - -
-
-#### `tf.test.TestCase.assertContainsInOrder(strings, target, msg=None)` {#TestCase.assertContainsInOrder}
-
-Asserts that the strings provided are found in the target in order.
-
-This may be useful for checking HTML output.
-
-##### Args:
-
-
-*  <b>`strings`</b>: A list of strings, such as [ 'fox', 'dog' ]
-*  <b>`target`</b>: A target string in which to look for the strings, such as
-    'The quick brown fox jumped over the lazy dog'.
-*  <b>`msg`</b>: Optional message to report on failure.
-
-
-- - -
-
-#### `tf.test.TestCase.assertContainsSubsequence(container, subsequence, msg=None)` {#TestCase.assertContainsSubsequence}
-
-Assert that "container" contains "subsequence" as a subsequence.
-
-Asserts that "container" contains all the elements of "subsequence", in
-order, but possibly with other elements interspersed. For example, [1, 2, 3]
-is a subsequence of [0, 0, 1, 2, 0, 3, 0] but not of [0, 0, 1, 3, 0, 2, 0].
-
-##### Args:
-
-
-*  <b>`container`</b>: the list we're testing for subsequence inclusion.
-*  <b>`subsequence`</b>: the list we hope will be a subsequence of container.
-*  <b>`msg`</b>: Optional message to report on failure.
-
-
-- - -
-
-#### `tf.test.TestCase.assertContainsSubset(expected_subset, actual_set, msg=None)` {#TestCase.assertContainsSubset}
-
-Checks whether actual iterable is a superset of expected iterable.
-
-
-- - -
-
-#### `tf.test.TestCase.assertCountEqual(*args, **kwargs)` {#TestCase.assertCountEqual}
-
-An unordered sequence specific comparison.
-
-Equivalent to assertItemsEqual(). This method is a compatibility layer
-for Python 3k, since 2to3 does not convert assertItemsEqual() calls into
-assertCountEqual() calls.
-
-##### Args:
-
-
-*  <b>`expected_seq`</b>: A sequence containing elements we are expecting.
-*  <b>`actual_seq`</b>: The sequence that we are testing.
-*  <b>`msg`</b>: The message to be printed if the test fails.
-
-
 - - -
 
 #### `tf.test.TestCase.assertDeviceEqual(device1, device2)` {#TestCase.assertDeviceEqual}
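This is one of the TensorFlow-specific assertions that remain; a minimal sketch of typical use (ours; device and op names are illustrative):

```python
import tensorflow as tf

class DevicePlacementTest(tf.test.TestCase):

  def testPlacedOnCpu(self):
    with tf.device('/cpu:0'):
      x = tf.constant(1.0)
    # Compares canonicalized device strings, so minor formatting
    # differences between the two specs are tolerated.
    self.assertDeviceEqual('/cpu:0', x.device)

if __name__ == '__main__':
  tf.test.main()
```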
@@ -354,48 +235,9 @@ Checks whether actual is a superset of expected.
 
 - - -
 
-#### `tf.test.TestCase.assertDictEqual(a, b, msg=None)` {#TestCase.assertDictEqual}
+#### `tf.test.TestCase.assertDictEqual(d1, d2, msg=None)` {#TestCase.assertDictEqual}
 
-Raises AssertionError if a and b are not equal dictionaries.
-
-##### Args:
-
-
-*  <b>`a`</b>: A dict, the expected value.
-*  <b>`b`</b>: A dict, the actual value.
-*  <b>`msg`</b>: An optional str, the associated message.
-
-##### Raises:
-
-
-*  <b>`AssertionError`</b>: if the dictionaries are not equal.
-
-
-- - -
-
-#### `tf.test.TestCase.assertEmpty(container, msg=None)` {#TestCase.assertEmpty}
-
-Assert that an object has zero length.
-
-##### Args:
-
-
-*  <b>`container`</b>: Anything that implements the collections.Sized interface.
-*  <b>`msg`</b>: Optional message to report on failure.
-
-
-- - -
-
-#### `tf.test.TestCase.assertEndsWith(actual, expected_end, msg=None)` {#TestCase.assertEndsWith}
-
-Assert that actual.endswith(expected_end) is True.
-
-##### Args:
-
-
-*  <b>`actual`</b>: str
-*  <b>`expected_end`</b>: str
-*  <b>`msg`</b>: Optional message to report on failure.
-
-
 - - -

@@ -480,11 +322,10 @@ Included for symmetry with assertIsNone.
 
 - - -
 
-#### `tf.test.TestCase.assertItemsEqual(*args, **kwargs)` {#TestCase.assertItemsEqual}
+#### `tf.test.TestCase.assertItemsEqual(expected_seq, actual_seq, msg=None)` {#TestCase.assertItemsEqual}
 
-An unordered sequence specific comparison.
-
-It asserts that actual_seq and expected_seq have the same element counts.
+An unordered sequence specific comparison. It asserts that
+actual_seq and expected_seq have the same element counts.
 
 Equivalent to::
 
     self.assertEqual(Counter(iter(actual_seq)),

@@ -497,30 +338,6 @@ Asserts that each element has the same count in both sequences.
 - [0, 1, 1] and [1, 0, 1] compare equal.
 - [0, 0, 1] and [0, 1] compare unequal.
 
-##### Args:
-
-
-*  <b>`expected_seq`</b>: A sequence containing elements we are expecting.
-*  <b>`actual_seq`</b>: The sequence that we are testing.
-*  <b>`msg`</b>: The message to be printed if the test fails.
-
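A short sketch of the two documented cases (ours):

```python
import tensorflow as tf

class ItemsEqualTest(tf.test.TestCase):

  def testElementCounts(self):
    self.assertItemsEqual([0, 1, 1], [1, 0, 1])  # same element counts
    with self.assertRaises(AssertionError):
      self.assertItemsEqual([0, 0, 1], [0, 1])   # counts differ
    # Under Python 3 the same check is spelled assertCountEqual.

if __name__ == '__main__':
  tf.test.main()
```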
-
-- - -
-
-#### `tf.test.TestCase.assertJsonEqual(first, second, msg=None)` {#TestCase.assertJsonEqual}
-
-Asserts that the JSON objects defined in two strings are equal.
-
-A summary of the differences will be included in the failure message
-using assertSameStructure.
-
-##### Args:
-
-
-*  <b>`first`</b>: A string containing JSON to decode and compare to second.
-*  <b>`second`</b>: A string containing JSON to decode and compare to first.
-*  <b>`msg`</b>: Additional text to include in the failure message.
-
-
 - - -

@@ -590,13 +407,6 @@ if not.
 *  <b>`msg`</b>: An optional string message to append to the failure message.
 
 
-- - -
-
-#### `tf.test.TestCase.assertNoCommonElements(expected_seq, actual_seq, msg=None)` {#TestCase.assertNoCommonElements}
-
-Checks whether actual iterable and expected iterable are disjoint.
-
-
 - - -
 
 #### `tf.test.TestCase.assertNotAlmostEqual(first, second, places=None, msg=None, delta=None)` {#TestCase.assertNotAlmostEqual}

@@ -627,33 +437,6 @@ as significant digits (measured from the most significant digit).
 Objects that are equal automatically fail.
 
 
-- - -
-
-#### `tf.test.TestCase.assertNotEmpty(container, msg=None)` {#TestCase.assertNotEmpty}
-
-Assert that an object has non-zero length.
-
-##### Args:
-
-
-*  <b>`container`</b>: Anything that implements the collections.Sized interface.
-*  <b>`msg`</b>: Optional message to report on failure.
-
-
-- - -
-
-#### `tf.test.TestCase.assertNotEndsWith(actual, unexpected_end, msg=None)` {#TestCase.assertNotEndsWith}
-
-Assert that actual.endswith(unexpected_end) is False.
-
-##### Args:
-
-
-*  <b>`actual`</b>: str
-*  <b>`unexpected_end`</b>: str
-*  <b>`msg`</b>: Optional message to report on failure.
-
-
 - - -
 
 #### `tf.test.TestCase.assertNotEqual(first, second, msg=None)` {#TestCase.assertNotEqual}

@@ -691,20 +474,6 @@ Included for symmetry with assertIsInstance.
 Fail the test if the text matches the regular expression.
 
 
-- - -
-
-#### `tf.test.TestCase.assertNotStartsWith(actual, unexpected_start, msg=None)` {#TestCase.assertNotStartsWith}
-
-Assert that actual.startswith(unexpected_start) is False.
-
-##### Args:
-
-
-*  <b>`actual`</b>: str
-*  <b>`unexpected_start`</b>: str
-*  <b>`msg`</b>: Optional message to report on failure.
-
-
 - - -
 
 #### `tf.test.TestCase.assertProtoEquals(expected_message_maybe_ascii, message)` {#TestCase.assertProtoEquals}
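A minimal sketch of the text-format comparison (ours, using the `SummaryDescription` message documented above):

```python
import tensorflow as tf

class ProtoEqualsTest(tf.test.TestCase):

  def testSummaryDescription(self):
    desc = tf.summary.SummaryDescription(type_hint='scalar')
    # The expected message may be given as a text-format (ASCII) proto
    # string; it is parsed as the same message type before comparison.
    self.assertProtoEquals("type_hint: 'scalar'", desc)

if __name__ == '__main__':
  tf.test.main()
```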
@@ -779,38 +548,6 @@ Asserts that the message in a raised exception matches a regexp.
 *  <b>`kwargs`</b>: Extra kwargs.
 
 
-- - -
-
-#### `tf.test.TestCase.assertRaisesWithLiteralMatch(expected_exception, expected_exception_message, callable_obj=None, *args, **kwargs)` {#TestCase.assertRaisesWithLiteralMatch}
-
-Asserts that the message in a raised exception equals the given string.
-
-Unlike assertRaisesRegexp, this method takes a literal string, not
-a regular expression.
-
-    with self.assertRaisesWithLiteralMatch(ExType, 'message'):
-      DoSomething()
-
-##### Args:
-
-
-*  <b>`expected_exception`</b>: Exception class expected to be raised.
-*  <b>`expected_exception_message`</b>: String message expected in the raised
-    exception. For a raised exception e, expected_exception_message must
-    equal str(e).
-*  <b>`callable_obj`</b>: Function to be called, or None to return a context.
-*  <b>`args`</b>: Extra args.
-*  <b>`kwargs`</b>: Extra kwargs.
-
-##### Returns:
-
-  A context manager if callable_obj is None. Otherwise, None.
-
-##### Raises:
-
-  self.failureException if callable_obj does not raise a matching exception.
-
-
 - - -
 
 #### `tf.test.TestCase.assertRaisesWithPredicateMatch(exception_type, expected_err_re_or_predicate)` {#TestCase.assertRaisesWithPredicateMatch}

@@ -835,71 +572,6 @@ predicate search.
 exception.
 
 
-- - -
-
-#### `tf.test.TestCase.assertRaisesWithRegexpMatch(expected_exception, expected_regexp, callable_obj=None, *args, **kwargs)` {#TestCase.assertRaisesWithRegexpMatch}
-
-Asserts that the message in a raised exception matches the given regexp.
-
-This is just a wrapper around assertRaisesRegexp. Please use
-assertRaisesRegexp instead of assertRaisesWithRegexpMatch.
-
-##### Args:
-
-
-*  <b>`expected_exception`</b>: Exception class expected to be raised.
-*  <b>`expected_regexp`</b>: Regexp (re pattern object or string) expected to be
-    found in error message.
-*  <b>`callable_obj`</b>: Function to be called, or None to return a context.
-*  <b>`args`</b>: Extra args.
-*  <b>`kwargs`</b>: Extra keyword args.
-
-##### Returns:
-
-  A context manager if callable_obj is None. Otherwise, None.
-
-##### Raises:
-
-  self.failureException if callable_obj does not raise a matching exception.
-
-
-- - -
-
-#### `tf.test.TestCase.assertRegexMatch(actual_str, regexes, message=None)` {#TestCase.assertRegexMatch}
-
-Asserts that at least one regex in regexes matches str.
-
-If possible you should use assertRegexpMatches, which is a simpler
-version of this method. assertRegexpMatches takes a single regular
-expression (a string or re compiled object) instead of a list.
-
-Notes:
-
-1. This function uses substring matching, i.e. the matching
-   succeeds if *any* substring of the error message matches *any*
-   regex in the list. This is more convenient for the user than
-   full-string matching.
-
-2. If regexes is the empty list, the matching will always fail.
-
-3. Use regexes=[''] for a regex that will always pass.
-
-4. '.' matches any single character *except* the newline. To
-   match any character, use '(.|\n)'.
-
-5. '^' matches the beginning of each line, not just the beginning
-   of the string. Similarly, '$' matches the end of each line.
-
-6. An exception will be thrown if regexes contains an invalid
-   regex.
-
-Args:
-
-  actual_str:  The string we try to match with the items in regexes.
-  regexes:  The regular expressions we want to match against str.
-      See "Notes" above for detailed notes on how this is interpreted.
-  message:  The message to be printed if the test fails.
-
-
 - - -
 
 #### `tf.test.TestCase.assertRegexpMatches(text, expected_regexp, msg=None)` {#TestCase.assertRegexpMatches}

@@ -907,79 +579,6 @@ Asserts that at least one regex in regexes matches str.
 Fail the test unless the text matches the regular expression.
 
 
-- - -
-
-#### `tf.test.TestCase.assertSameElements(expected_seq, actual_seq, msg=None)` {#TestCase.assertSameElements}
-
-Assert that two sequences have the same elements (in any order).
-
-This method, unlike assertItemsEqual, doesn't care about any
-duplicates in the expected and actual sequences.
-
-    >> assertSameElements([1, 1, 1, 0, 0, 0], [0, 1])
-    # Doesn't raise an AssertionError
-
-If possible, you should use assertItemsEqual instead of
-assertSameElements.
-
-##### Args:
-
-
-*  <b>`expected_seq`</b>: A sequence containing elements we are expecting.
-*  <b>`actual_seq`</b>: The sequence that we are testing.
-*  <b>`msg`</b>: The message to be printed if the test fails.
-
-
-- - -
-
-#### `tf.test.TestCase.assertSameStructure(a, b, aname='a', bname='b', msg=None)` {#TestCase.assertSameStructure}
-
-Asserts that two values contain the same structural content.
-
-The two arguments should be data trees consisting of trees of dicts and
-lists. They will be deeply compared by walking into the contents of dicts
-and lists; other items will be compared using the == operator.
-If the two structures differ in content, the failure message will indicate
-the location within the structures where the first difference is found.
-This may be helpful when comparing large structures.
-
-##### Args:
-
-
-*  <b>`a`</b>: The first structure to compare.
-*  <b>`b`</b>: The second structure to compare.
-*  <b>`aname`</b>: Variable name to use for the first structure in assertion messages.
-*  <b>`bname`</b>: Variable name to use for the second structure.
-*  <b>`msg`</b>: Additional text to include in the failure message.
-
-
-- - -
-
-#### `tf.test.TestCase.assertSequenceAlmostEqual(expected_seq, actual_seq, places=None, msg=None, delta=None)` {#TestCase.assertSequenceAlmostEqual}
-
-An approximate equality assertion for ordered sequences.
-
-Fail if the two sequences are unequal as determined by their value
-differences rounded to the given number of decimal places (default 7) and
-comparing to zero, or by comparing that the difference between each value
-in the two sequences is more than the given delta.
-
-Note that decimal places (from zero) are usually not the same as significant
-digits (measured from the most significant digit).
-
-If the two sequences compare equal then they will automatically compare
-almost equal.
-
-##### Args:
-
-
-*  <b>`expected_seq`</b>: A sequence containing elements we are expecting.
-*  <b>`actual_seq`</b>: The sequence that we are testing.
-*  <b>`places`</b>: The number of decimal places to compare.
-*  <b>`msg`</b>: The message to be printed if the test fails.
-*  <b>`delta`</b>: The OK difference between compared values.
-
-
 - - -
 
 #### `tf.test.TestCase.assertSequenceEqual(seq1, seq2, msg=None, seq_type=None)` {#TestCase.assertSequenceEqual}

@@ -1000,26 +599,6 @@ which can be indexed, has a length, and has an equality operator.
 differences.
 
 
-- - -
-
-#### `tf.test.TestCase.assertSequenceStartsWith(prefix, whole, msg=None)` {#TestCase.assertSequenceStartsWith}
-
-An equality assertion for the beginning of ordered sequences.
-
-If prefix is an empty sequence, it will raise an error unless whole is also
-an empty sequence.
-
-If prefix is not a sequence, it will raise an error if the first element of
-whole does not match.
-
-##### Args:
-
-
-*  <b>`prefix`</b>: A sequence expected at the beginning of the whole parameter.
-*  <b>`whole`</b>: The sequence in which to look for prefix.
-*  <b>`msg`</b>: Optional message to report on failure.
-
-
 - - -
 
 #### `tf.test.TestCase.assertSetEqual(set1, set2, msg=None)` {#TestCase.assertSetEqual}

@@ -1071,51 +650,6 @@ Assert that actual.startswith(expected_start) is True.
 *  <b>`msg`</b>: Optional message to report on failure.
 
 
-- - -
-
-#### `tf.test.TestCase.assertTotallyOrdered(*groups, **kwargs)` {#TestCase.assertTotallyOrdered}
-
-Asserts that total ordering has been implemented correctly.
-
-For example, say you have a class A that compares only on its attribute x.
-Comparators other than __lt__ are omitted for brevity.
-
-    class A(object):
-      def __init__(self, x, y):
-        self.x = x
-        self.y = y
-
-      def __hash__(self):
-        return hash(self.x)
-
-      def __lt__(self, other):
-        try:
-          return self.x < other.x
-        except AttributeError:
-          return NotImplemented
-
-assertTotallyOrdered will check that instances can be ordered correctly.
-For example,
-
-    self.assertTotallyOrdered(
-      [None],  # None should come before everything else.
-      [1],     # Integers sort earlier.
-      [A(1, 'a')],
-      [A(2, 'b')],             # 2 is after 1.
-      [A(3, 'c'), A(3, 'd')],  # The second argument is irrelevant.
-      [A(4, 'z')],
-      ['foo'])  # Strings sort last.
-
-##### Args:
-
-
-*  <b>`*groups`</b>: A list of groups of elements. Each group of elements is a list
-    of objects that are equal. The elements in each group must be less than
-    the elements in the group after it. For example, these groups are
-    totally ordered: [None], [1], [2, 2], [3].
-*  <b>`**kwargs`</b>: optional msg keyword argument can be passed.
-
-
 - - -
 #### `tf.test.TestCase.assertTrue(expr, msg=None)` {#TestCase.assertTrue}

@@ -1138,13 +672,6 @@ A tuple-specific equality assertion.
 differences.
 
 
-- - -
-
-#### `tf.test.TestCase.assertUrlEqual(a, b, msg=None)` {#TestCase.assertUrlEqual}
-
-Asserts that urls are equal, ignoring ordering of query params.
-
-
 - - -
 
 #### `tf.test.TestCase.assert_(expr, msg=None)` {#TestCase.assert_}

@@ -1206,9 +733,9 @@ tearDown.
 
 - - -
 
-#### `tf.test.TestCase.fail(msg=None, prefix=None)` {#TestCase.fail}
+#### `tf.test.TestCase.fail(msg=None)` {#TestCase.fail}
 
-Fail immediately with the given message, optionally prefixed.
+Fail immediately, with the given message.
 
 - - -

@@ -1260,13 +787,6 @@ Fail immediately with the given message, optionally prefixed.
 
 
-- - -
-
-#### `tf.test.TestCase.getRecordedProperties()` {#TestCase.getRecordedProperties}
-
-Return any properties that the user has recorded.
-
-
 - - -
 
 #### `tf.test.TestCase.get_temp_dir()` {#TestCase.get_temp_dir}

@@ -1289,20 +809,6 @@ pollute each others environment.
 
 
-- - -
-
-#### `tf.test.TestCase.recordProperty(property_name, property_value)` {#TestCase.recordProperty}
-
-Record an arbitrary property for later use.
-
-##### Args:
-
-
-*  <b>`property_name`</b>: str, name of property to record; must be a valid XML
-    attribute name
-*  <b>`property_value`</b>: value of property; must be valid XML attribute value
-
-
 - - -
 
 #### `tf.test.TestCase.run(result=None)` {#TestCase.run}

@@ -1328,18 +834,11 @@ Hook method for setting up class fixture before running tests in the class.
 
 #### `tf.test.TestCase.shortDescription()` {#TestCase.shortDescription}
 
-Format both the test method name and the first line of its docstring.
-
-If no docstring is given, only returns the method name.
-
-This method overrides unittest.TestCase.shortDescription(), which
-only returns the first line of the docstring, obscuring the name
-of the test upon failure.
-
-##### Returns:
-
-
-*  <b>`desc`</b>: A short description of a test method.
+Returns a one-line description of the test, or None if no
+description has been provided.
+
+The default implementation of this method returns the first line of
+the specified test method's docstring.
 
 - - -
@@ -78,37 +78,51 @@ If the above commands do not work on your system, you can follow these instructions:
 
 ```bash
 # Ubuntu/Linux 64-bit, CPU only, Python 2.7
-$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp27-none-linux_x86_64.whl
+$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp27-none-linux_x86_64.whl
 
 # Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
-# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
-$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp27-none-linux_x86_64.whl
+# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
+$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp27-none-linux_x86_64.whl
 
 # Mac OS X, CPU only, Python 2.7:
-$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.12.1-py2-none-any.whl
+$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.0.0rc1-py2-none-any.whl
 
 # Mac OS X, GPU enabled, Python 2.7:
-$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-0.12.1-py2-none-any.whl
+$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-1.0.0rc1-py2-none-any.whl
 
+# Ubuntu/Linux 64-bit, CPU only, Python 3.3
+$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp33-cp33m-linux_x86_64.whl
+
+# Ubuntu/Linux 64-bit, GPU enabled, Python 3.3
+# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
+$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp33-cp33m-linux_x86_64.whl
+
 # Ubuntu/Linux 64-bit, CPU only, Python 3.4
-$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp34-cp34m-linux_x86_64.whl
+$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp34-cp34m-linux_x86_64.whl
 
 # Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
-# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
-$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp34-cp34m-linux_x86_64.whl
+# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
+$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp34-cp34m-linux_x86_64.whl
 
 # Ubuntu/Linux 64-bit, CPU only, Python 3.5
-$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp35-cp35m-linux_x86_64.whl
+$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp35-cp35m-linux_x86_64.whl
 
 # Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
-# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
-$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp35-cp35m-linux_x86_64.whl
+# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
+$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp35-cp35m-linux_x86_64.whl
 
+# Ubuntu/Linux 64-bit, CPU only, Python 3.6
+$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp36-cp36m-linux_x86_64.whl
+
+# Ubuntu/Linux 64-bit, GPU enabled, Python 3.6
+# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
+$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp36-cp36m-linux_x86_64.whl
+
 # Mac OS X, CPU only, Python 3.4 or 3.5:
-$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.12.1-py3-none-any.whl
+$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.0.0rc1-py3-none-any.whl
 
 # Mac OS X, GPU enabled, Python 3.4 or 3.5:
-$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-0.12.1-py3-none-any.whl
+$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-1.0.0rc1-py3-none-any.whl
 ```
 
 Install TensorFlow:
@@ -150,14 +164,14 @@ Both distributions include pip. To install the CPU-only version of
 TensorFlow, enter the following command at a command prompt:
 
 ```bat
-C:\> pip install --upgrade https://storage.googleapis.com/tensorflow/windows/cpu/tensorflow-0.12.1-cp35-cp35m-win_amd64.whl
+C:\> pip install --upgrade https://storage.googleapis.com/tensorflow/windows/cpu/tensorflow-1.0.0rc1-cp35-cp35m-win_amd64.whl
 ```
 
 To install the GPU version of TensorFlow, enter the following command
 at a command prompt:
 
 ```bat
-C:\> pip install --upgrade https://storage.googleapis.com/tensorflow/windows/gpu/tensorflow_gpu-0.12.1-cp35-cp35m-win_amd64.whl
+C:\> pip install --upgrade https://storage.googleapis.com/tensorflow/windows/gpu/tensorflow_gpu-1.0.0rc1-cp35-cp35m-win_amd64.whl
 ```
 
 You can now [test your installation](#test-the-tensorflow-installation).
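After any of the installs above, a quick smoke test (ours; output depends on the wheel you chose) is to import the package and run a trivial graph:

```python
# Minimal post-install check; assumes the wheel installed into the
# active Python environment.
import tensorflow as tf

hello = tf.constant('Hello, TensorFlow!')
sess = tf.Session()
print(sess.run(hello))  # b'Hello, TensorFlow!'
print(tf.__version__)   # e.g. 1.0.0-rc1
```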
@@ -212,37 +226,51 @@ Now, install TensorFlow just as you would for a regular Pip installation. First
 
 ```bash
 # Ubuntu/Linux 64-bit, CPU only, Python 2.7
-(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp27-none-linux_x86_64.whl
+(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp27-none-linux_x86_64.whl
 
 # Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
-# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
-(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp27-none-linux_x86_64.whl
+# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
+(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp27-none-linux_x86_64.whl
 
 # Mac OS X, CPU only, Python 2.7:
-(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.12.1-py2-none-any.whl
+(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.0.0rc1-py2-none-any.whl
 
 # Mac OS X, GPU enabled, Python 2.7:
-(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-0.12.1-py2-none-any.whl
+(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-1.0.0rc1-py2-none-any.whl
 
+# Ubuntu/Linux 64-bit, CPU only, Python 3.3
+(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp33-cp33m-linux_x86_64.whl
+
+# Ubuntu/Linux 64-bit, GPU enabled, Python 3.3
+# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
+(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp33-cp33m-linux_x86_64.whl
+
 # Ubuntu/Linux 64-bit, CPU only, Python 3.4
-(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp34-cp34m-linux_x86_64.whl
+(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp34-cp34m-linux_x86_64.whl
 
 # Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
-# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
-(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp34-cp34m-linux_x86_64.whl
+# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
+(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp34-cp34m-linux_x86_64.whl
 
 # Ubuntu/Linux 64-bit, CPU only, Python 3.5
-(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp35-cp35m-linux_x86_64.whl
+(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp35-cp35m-linux_x86_64.whl
 
 # Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
-# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
-(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp35-cp35m-linux_x86_64.whl
+# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
+(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp35-cp35m-linux_x86_64.whl
 
+# Ubuntu/Linux 64-bit, CPU only, Python 3.6
+(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp36-cp36m-linux_x86_64.whl
+
+# Ubuntu/Linux 64-bit, GPU enabled, Python 3.6
+# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
+(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp36-cp36m-linux_x86_64.whl
+
 # Mac OS X, CPU only, Python 3.4 or 3.5:
-(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.12.1-py3-none-any.whl
+(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.0.0rc1-py3-none-any.whl
 
 # Mac OS X, GPU enabled, Python 3.4 or 3.5:
-(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-0.12.1-py3-none-any.whl
+(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-1.0.0rc1-py3-none-any.whl
 ```
 
 Finally install TensorFlow:
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ubuntu/Linux 64-bit, CPU only, Python 2.7
|
# Ubuntu/Linux 64-bit, CPU only, Python 2.7
|
||||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp27-none-linux_x86_64.whl
|
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp27-none-linux_x86_64.whl
|
||||||
|
|
||||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
|
# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
|
||||||
# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
|
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
|
||||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp27-none-linux_x86_64.whl
|
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp27-none-linux_x86_64.whl
|
||||||
|
|
||||||
# Mac OS X, CPU only, Python 2.7:
|
# Mac OS X, CPU only, Python 2.7:
|
||||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.12.1-py2-none-any.whl
|
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.0.0rc1-py2-none-any.whl
|
||||||
|
|
||||||
# Mac OS X, GPU enabled, Python 2.7:
|
# Mac OS X, GPU enabled, Python 2.7:
|
||||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-0.12.1-py2-none-any.whl
|
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-1.0.0rc1-py2-none-any.whl
|
||||||
|
|
||||||
|
# Ubuntu/Linux 64-bit, CPU only, Python 3.3
|
||||||
|
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp33-cp33m-linux_x86_64.whl
|
||||||
|
|
||||||
|
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.3
|
||||||
|
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
|
||||||
|
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp33-cp33m-linux_x86_64.whl
|
||||||
|
|
||||||
# Ubuntu/Linux 64-bit, CPU only, Python 3.4
|
# Ubuntu/Linux 64-bit, CPU only, Python 3.4
|
||||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp34-cp34m-linux_x86_64.whl
|
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp34-cp34m-linux_x86_64.whl
|
||||||
|
|
||||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
|
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
|
||||||
# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
|
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
|
||||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp34-cp34m-linux_x86_64.whl
|
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp34-cp34m-linux_x86_64.whl
|
||||||
|
|
||||||
# Ubuntu/Linux 64-bit, CPU only, Python 3.5
|
# Ubuntu/Linux 64-bit, CPU only, Python 3.5
|
||||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp35-cp35m-linux_x86_64.whl
|
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp35-cp35m-linux_x86_64.whl
|
||||||
|
|
||||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
|
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
|
||||||
# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
|
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
|
||||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp35-cp35m-linux_x86_64.whl
|
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp35-cp35m-linux_x86_64.whl
|
||||||
|
|
||||||
|
# Ubuntu/Linux 64-bit, CPU only, Python 3.6
|
||||||
|
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp36-cp36m-linux_x86_64.whl
|
||||||
|
|
||||||
|
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.6
|
||||||
|
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
|
||||||
|
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp36-cp36m-linux_x86_64.whl
|
||||||
|
|
||||||
# Mac OS X, CPU only, Python 3.4 or 3.5:
|
# Mac OS X, CPU only, Python 3.4 or 3.5:
|
||||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.12.1-py3-none-any.whl
|
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.0.0rc1-py3-none-any.whl
|
||||||
|
|
||||||
# Mac OS X, GPU enabled, Python 3.4 or 3.5:
|
# Mac OS X, GPU enabled, Python 3.4 or 3.5:
|
||||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-0.12.1-py3-none-any.whl
|
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-1.0.0rc1-py3-none-any.whl
|
||||||
```
|
```
|
||||||
|
|
||||||
Finally install TensorFlow:
|
Finally install TensorFlow:
|
||||||
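After installing one of the wheels above, a quick smoke test confirms the environment picked up the release candidate. A minimal sketch, assuming the `(tensorflow)` virtualenv from the steps above is still active:

```python
# Minimal smoke test for the freshly installed wheel: print the version and
# run one trivial op inside a session.
import tensorflow as tf

print(tf.__version__)  # expect something like 1.0.0-rc1

hello = tf.constant("Hello, TensorFlow!")
with tf.Session() as sess:
    print(sess.run(hello))
```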
@@ -462,7 +504,7 @@ code.
 code.

 We also have tags with `latest` replaced by a released version (e.g.,
-`0.12.1-gpu`).
+`1.0.0-rc1-gpu`).

 With Docker the installation is as follows:

@@ -557,7 +599,7 @@ To build TensorFlow from source on Windows, you can use experimental
 support for [Bazel on
 Windows](https://bazel.build/versions/master/docs/windows.html) or the
 [TensorFlow CMake
-build](https://github.com/tensorflow/tensorflow/tree/r0.12/tensorflow/contrib/cmake).
+build](https://github.com/tensorflow/tensorflow/tree/r1.0/tensorflow/contrib/cmake).

 ### Clone the TensorFlow repository

@@ -856,37 +898,38 @@ default and if you want to limit RAM usage you can add `--local_resources
 2048,.5,1.0` while invoking bazel.

 ```bash
-$ bazel build -c opt //tensorflow/tools/pip_package:build_pip_package
+$ bazel build --config opt //tensorflow/tools/pip_package:build_pip_package

 # To build with support for CUDA:
-$ bazel build -c opt --config=cuda //tensorflow/tools/pip_package:build_pip_package
+$ bazel build --config opt --config=cuda //tensorflow/tools/pip_package:build_pip_package

-# Alternatively, to build with support for OpenCL:
-$ bazel build -c opt --config=sycl //tensorflow/tools/pip_package:build_pip_package
+# Alternatively, to build with support for OpenCL (Experimental):
+$ bazel build --config opt --config=sycl //tensorflow/tools/pip_package:build_pip_package

 $ bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg

 # The name of the .whl file will depend on your platform.
-$ sudo pip install /tmp/tensorflow_pkg/tensorflow-0.12.1-py2-none-any.whl
+$ sudo pip install /tmp/tensorflow_pkg/tensorflow-1.0.0rc1-py2-none-any.whl
 ```

 ## Optimizing CPU performance

 To be compatible with as wide a range of machines as possible, TensorFlow
-defaults to only using SSE4.1 SIMD instructions on x86 machines. Most modern PCs
-and Macs support more advanced instructions, so if you're building a binary
-that you'll only be running on your own machine, you can enable these by using
-`--copt=-march=native` in your bazel build command. For example:
+defaults to only using SSE4 SIMD instructions. Most modern computers support
+more advanced instructions. So if you're building a binary that you'll only
+be running on your own machine, you can enable these by using `-march=native`
+for optimization options when running `configure`. Then you can build your
+optimized binaries with the following command:

 ``` bash
-$ bazel build --copt=-march=native -c opt //tensorflow/tools/pip_package:build_pip_package
+$ bazel build --config opt //tensorflow/tools/pip_package:build_pip_package
 ```

 If you are distributing a binary but know the capabilities of the machines
 you'll be running on, you can manually choose the right instructions with
-something like `--copt=-march=avx`. You may also want to enable multiple
+something like `-march=avx`. You may also want to enable multiple
 features using several arguments, for example
-`--copt=-mavx2 --copt=-mfma`.
+`-mavx2,-mfma`.

 If you run a binary built using SIMD instructions on a machine that doesn't
 support them, you'll see an illegal instruction error when that code is
@@ -902,10 +945,10 @@ system directories, run the following commands inside the TensorFlow root
 directory:

 ```bash
-bazel build -c opt //tensorflow/tools/pip_package:build_pip_package
+bazel build --config opt //tensorflow/tools/pip_package:build_pip_package

 # To build with GPU support:
-bazel build -c opt --config=cuda //tensorflow/tools/pip_package:build_pip_package
+bazel build --config opt --config=cuda //tensorflow/tools/pip_package:build_pip_package

 mkdir _python_build
 cd _python_build
@@ -177,7 +177,7 @@ tf_custom_op_library(
 Run the following command to build `zero_out.so`.

 ```bash
-$ bazel build -c opt //tensorflow/core/user_ops:zero_out.so
+$ bazel build --config opt //tensorflow/core/user_ops:zero_out.so
 ```

 > Note:
@@ -42,10 +42,10 @@ bazel build tensorflow/examples/image_retraining:retrain

 If you have a machine which supports [the AVX instruction set](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions)
 (common in x86 CPUs produced in the last few years) you can improve the running
-speed of the retraining by building for that architecture, like this:
+speed of the retraining by building for that architecture, like this (after choosing appropriate options in `configure`):

 ```sh
-bazel build -c opt --copt=-mavx tensorflow/examples/image_retraining:retrain
+bazel build --config opt tensorflow/examples/image_retraining:retrain
 ```

 The retrainer can then be run like this:
@@ -1,5 +1,5 @@
 # Roadmap
-**Last updated: June 3, 2016**
+**Last updated: January 23, 2017**

 TensorFlow is a fast moving project. In order for the community to better
 understand what the near future will bring, this document shares what we are
|
|||||||
we do not have timelines for these features.
|
we do not have timelines for these features.
|
||||||
|
|
||||||
### Improve non-Python language support
|
### Improve non-Python language support
|
||||||
C and C++ APIs for:
|
|
||||||
|
|
||||||
* Graph construction
|
* Improve C++ API for graph construction and gradients
|
||||||
* Gradients
|
* Java language support
|
||||||
* Shape Inference
|
* Go language support
|
||||||
|
|
||||||
### Making TensorFlow easier to use
|
### Making TensorFlow easier to use
|
||||||
* Easier setup for distributed training jobs
|
* High-level APIs
|
||||||
|
* Well-maintained models showing best practices
|
||||||
|
|
||||||
### Performance
|
### Performance
|
||||||
* Speed and memory benchmarks
|
* Speed and memory benchmarks
|
||||||
|
* Distributed full model benchmarks
|
||||||
* Performance and memory usage improvements
|
* Performance and memory usage improvements
|
||||||
|
|
||||||
### Core Features
|
### Core Features
|
||||||
* Repeated partial graph evaluation ([#672](https://github.com/tensorflow/tensorflow/issues/672))
|
|
||||||
* Automatic op placement ([#2126](https://github.com/tensorflow/tensorflow/issues/2126))
|
* Automatic op placement ([#2126](https://github.com/tensorflow/tensorflow/issues/2126))
|
||||||
|
* Support for graph-level functions
|
||||||
|
|
||||||
### Platforms
|
### Platforms
|
||||||
* OpenCL support ([#22](https://github.com/tensorflow/tensorflow/issues/22))
|
* OpenCL support ([#22](https://github.com/tensorflow/tensorflow/issues/22))
|
||||||
|
|
||||||
### Community
|
### Community
|
||||||
* More educational resources
|
* More educational resources
|
||||||
* Better integration of TensorFlow into the opensource big data ecosystem ([#1996](https://github.com/tensorflow/tensorflow/issues/1996),
|
* Better integration of TensorFlow into the opensource big data ecosystem (e.g.
|
||||||
[#2218](https://github.com/tensorflow/tensorflow/issues/2218),
|
|
||||||
[#2655](https://github.com/tensorflow/tensorflow/issues/2655))
|
[#2655](https://github.com/tensorflow/tensorflow/issues/2655))
|
||||||
* Models benchmarking and comparison tooling
|
|
||||||
|
@@ -30,7 +30,7 @@ then
 then
   echo "Protocol buffer compiler protoc not found in PATH or in ${PROTOC}"
   echo "Perhaps build it using:"
-  echo "bazel build -c opt @protobuf//:protoc"
+  echo "bazel build --config opt @protobuf//:protoc"
   exit 1
 fi
 PROTOC=$PATH_PROTOC
@@ -40,7 +40,7 @@ Configure and build the Java Archive (JAR) and native library:
 ./configure

 # Build the JAR and native library
-bazel build -c opt \
+bazel build --config opt \
   //tensorflow/java:tensorflow \
   //tensorflow/java:libtensorflow_jni
 ```
@@ -151,14 +151,15 @@ def _FindAttrInOpDef(attr_name, op_def):

 def import_graph_def(graph_def, input_map=None, return_elements=None,
                      name=None, op_dict=None, producer_op_list=None):
-  """Imports the TensorFlow graph in `graph_def` into the Python `Graph`.
+  """Imports the graph from `graph_def` into the current default `Graph`.

   This function provides a way to import a serialized TensorFlow
   [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
   protocol buffer, and extract individual objects in the `GraphDef` as
-  [`Tensor`](#Tensor) and [`Operation`](#Operation) objects. See
-  [`Graph.as_graph_def()`](#Graph.as_graph_def) for a way to create a
-  `GraphDef` proto.
+  [`Tensor`](#Tensor) and [`Operation`](#Operation) objects. Once extracted,
+  these objects are placed into the current default `Graph`. See
+  [`Graph.as_graph_def()`](#Graph.as_graph_def) for a way to create a `GraphDef`
+  proto.

   Args:
     graph_def: A `GraphDef` proto containing operations to be imported into
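The reworded docstring pins down the key behavior: imported objects land in the current default `Graph`. A minimal sketch of that behavior (graph and node names here are illustrative):

```python
import tensorflow as tf

# Build a small graph and serialize it to a GraphDef proto.
g1 = tf.Graph()
with g1.as_default():
    a = tf.constant(3.0, name="a")
    b = tf.multiply(a, 2.0, name="b")
graph_def = g1.as_graph_def()

# Importing places the extracted ops into the *current default* graph
# (g2 here), under the chosen name scope.
g2 = tf.Graph()
with g2.as_default():
    b_imported, = tf.import_graph_def(
        graph_def, return_elements=["b:0"], name="imported")
    with tf.Session() as sess:
        print(sess.run(b_imported))  # 6.0
```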
@@ -491,7 +491,9 @@ class TensorFlowTestCase(googletest.TestCase):
       print("dtype = %s, shape = %s" % (a.dtype, a.shape))
     np.testing.assert_allclose(a, b, rtol=rtol, atol=atol)

-  def assertAllCloseAccordingToType(self, a, b, rtol=1e-6, atol=1e-6):
+  def assertAllCloseAccordingToType(self, a, b, rtol=1e-6, atol=1e-6,
+                                    float_rtol=1e-6, float_atol=1e-6,
+                                    half_rtol=1e-3, half_atol=1e-3):
     """Like assertAllClose, but also suitable for comparing fp16 arrays.

     In particular, the tolerance is reduced to 1e-3 if at least
|
|||||||
b: a numpy ndarray or anything can be converted to one.
|
b: a numpy ndarray or anything can be converted to one.
|
||||||
rtol: relative tolerance
|
rtol: relative tolerance
|
||||||
atol: absolute tolerance
|
atol: absolute tolerance
|
||||||
|
float_rtol: relative tolerance for float32
|
||||||
|
float_atol: absolute tolerance for float32
|
||||||
|
half_rtol: relative tolerance for float16
|
||||||
|
half_atol: absolute tolerance for float16
|
||||||
"""
|
"""
|
||||||
a = self._GetNdArray(a)
|
a = self._GetNdArray(a)
|
||||||
b = self._GetNdArray(b)
|
b = self._GetNdArray(b)
|
||||||
|
if a.dtype == np.float32 or b.dtype == np.float32:
|
||||||
|
rtol = max(rtol, float_rtol)
|
||||||
|
atol = max(atol, float_atol)
|
||||||
if a.dtype == np.float16 or b.dtype == np.float16:
|
if a.dtype == np.float16 or b.dtype == np.float16:
|
||||||
rtol = max(rtol, 1e-3)
|
rtol = max(rtol, half_rtol)
|
||||||
atol = max(atol, 1e-3)
|
atol = max(atol, half_atol)
|
||||||
|
|
||||||
self.assertAllClose(a, b, rtol=rtol, atol=atol)
|
self.assertAllClose(a, b, rtol=rtol, atol=atol)
|
||||||
|
|
||||||
|
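The effective tolerance is the loosest one that applies to the dtype under comparison. A small standalone sketch of the widening logic (illustrative only; the real method treats `atol` the same way):

```python
import numpy as np

def effective_rtol(dtype, rtol=1e-6, float_rtol=1e-6, half_rtol=1e-3):
    # Mirrors the widening above: the less precise the dtype, the looser
    # the tolerance that may apply.
    if dtype == np.float32:
        rtol = max(rtol, float_rtol)
    if dtype == np.float16:
        rtol = max(rtol, half_rtol)
    return rtol

print(effective_rtol(np.float64))  # 1e-06 (defaults untouched)
print(effective_rtol(np.float16))  # 0.001 (half_rtol wins)
```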
@@ -193,6 +193,55 @@ class TestUtilTest(test_util.TensorFlowTestCase):
       y = [15]
       control_flow_ops.Assert(x, y).run()

+  def testAssertAllCloseAccordingToType(self):
+    # test float64
+    self.assertAllCloseAccordingToType(
+        np.asarray([1e-8], dtype=np.float64),
+        np.asarray([2e-8], dtype=np.float64),
+        rtol=1e-8, atol=1e-8
+    )
+
+    with (self.assertRaises(AssertionError)):
+      self.assertAllCloseAccordingToType(
+          np.asarray([1e-7], dtype=np.float64),
+          np.asarray([2e-7], dtype=np.float64),
+          rtol=1e-8, atol=1e-8
+      )
+
+    # test float32
+    self.assertAllCloseAccordingToType(
+        np.asarray([1e-7], dtype=np.float32),
+        np.asarray([2e-7], dtype=np.float32),
+        rtol=1e-8, atol=1e-8,
+        float_rtol=1e-7, float_atol=1e-7
+    )
+
+    with (self.assertRaises(AssertionError)):
+      self.assertAllCloseAccordingToType(
+          np.asarray([1e-6], dtype=np.float32),
+          np.asarray([2e-6], dtype=np.float32),
+          rtol=1e-8, atol=1e-8,
+          float_rtol=1e-7, float_atol=1e-7
+      )
+
+    # test float16
+    self.assertAllCloseAccordingToType(
+        np.asarray([1e-4], dtype=np.float16),
+        np.asarray([2e-4], dtype=np.float16),
+        rtol=1e-8, atol=1e-8,
+        float_rtol=1e-7, float_atol=1e-7,
+        half_rtol=1e-4, half_atol=1e-4
+    )
+
+    with (self.assertRaises(AssertionError)):
+      self.assertAllCloseAccordingToType(
+          np.asarray([1e-3], dtype=np.float16),
+          np.asarray([2e-3], dtype=np.float16),
+          rtol=1e-8, atol=1e-8,
+          float_rtol=1e-7, float_atol=1e-7,
+          half_rtol=1e-4, half_atol=1e-4
+      )
+
   def testRandomSeed(self):
     a = random.randint(1, 1000)
     a_np_rand = np.random.rand(1)
@@ -70,13 +70,13 @@ class ScalarStrictTest(test.TestCase):
     self.assertAllEqual(r, correct)

   def testConcat(self):
-    self.check(array_ops.concat_v2, (([2], [3], [7]), [0]),
+    self.check(array_ops.concat, (([2], [3], [7]), [0]),
                'axis tensor should be a scalar integer', [2, 3, 7])
     for data in (2, 3, 7), (2, [3], 7), (2, 3, [7]):
-      self.check(array_ops.concat_v2, (data, 0),
+      self.check(array_ops.concat, (data, 0),
                  r'Expected \w+ dimensions in the range \[0, 0\)', [2, 3, 7])
     for data in ([2], 3, 7), ([2], [3], 7):
-      self.check(array_ops.concat_v2, (data, 0),
+      self.check(array_ops.concat, (data, 0),
                  r'Ranks of all input tensors should match', [2, 3, 7])

   def testFill(self):
@@ -49,12 +49,21 @@ class SegmentReductionHelper(test.TestCase):
     slice_shape = x.shape[indices.ndim:]
     x_flat = x.reshape((indices.size,) + slice_shape)
     for i, index in enumerate(indices.ravel()):
-      if output[index] is not None:
+      if (output[index] is not None) and op1 == np.max:
+        for j in range(0, output[index].shape[0]):
+          output[index][j] = op1([output[index][j], x_flat[i][j]])
+      elif output[index] is not None:
         output[index] = op1(output[index], x_flat[i])
       else:
         output[index] = x_flat[i]
     # zero initialize values that are still uncalcuated.
-    output = [o if o is not None else np.zeros(slice_shape) for o in output]
+    # output = [o if o is not None else np.zeros(slice_shape) for o in output]
+    if not op1 == np.max:
+      output = [o if o is not None else np.zeros(slice_shape) for o in output]
+    else:
+      zeroslice = np.zeros(slice_shape)
+      zeroslice.fill(dtype.min)
+      output = [o if o is not None else zeroslice for o in output]
     if op2 is not None:
       output = [op2(o) for o in output]
     output = [o.reshape(slice_shape) for o in output]
@@ -245,7 +254,7 @@ class UnsortedSegmentSumTest(SegmentReductionHelper):
       self._assertAllClose(indices, np_ans, tf_ans)
       self.assertShapeEqual(np_ans, s)

-  def testGradient(self):
+  def testGradientSegmentSum(self):
     num_cols = 2
     indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
     num_segments = max(indices_flat) + 3
@@ -318,6 +327,23 @@ class UnsortedSegmentSumTest(SegmentReductionHelper):
         unsorted = math_ops.unsorted_segment_sum(data, segment_ids, 2)
         self.assertAllEqual(unsorted.eval(), np.zeros((2, 0), dtype=dtype))

+  def testGradientSegmentMax(self):
+    num_cols = 2
+    indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
+    num_segments = max(indices_flat) + 3
+    for indices in indices_flat, indices_flat.reshape(5, 2):
+      shape = indices.shape + (num_cols,)
+      with self.test_session():
+        tf_x, np_x = self._input(shape, dtype=dtypes_lib.float64)
+        s = math_ops.unsorted_segment_max(data=tf_x, segment_ids=indices,
+                                          num_segments=num_segments)
+        jacob_t, jacob_n = gradient_checker.compute_gradient(
+            tf_x,
+            shape,
+            s,
+            [num_segments, num_cols],
+            x_init_value=np_x.astype(np.double), delta=1)
+        self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
+
 class UnsortedSegmentSumGpuTest(UnsortedSegmentSumTest):
   use_gpu = True
@@ -979,12 +979,6 @@ def unstack(value, num=None, axis=0, name="unstack"):
   return gen_array_ops._unpack(value, num=num, axis=axis, name=name)


-# concat_v2 is an alias for concat. concat_v2 will be deprecated and removed
-# soon, please use concat.
-def concat_v2(values, axis, name="concat_v2"):
-  return concat(values, axis, name)
-
-
 def concat(values, axis, name="concat"):
   """Concatenates tensors along one dimension.

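With the alias removed, call sites use `tf.concat(values, axis)` directly. A minimal usage sketch:

```python
import tensorflow as tf

t1 = tf.constant([[1, 2], [3, 4]])
t2 = tf.constant([[5, 6], [7, 8]])

rows = tf.concat([t1, t2], axis=0)  # shape (4, 2)
cols = tf.concat([t1, t2], axis=1)  # shape (2, 4)

with tf.Session() as sess:
    print(sess.run(rows))
    print(sess.run(cols))
```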
@@ -70,6 +70,11 @@ def _Collect(val, collections, default_collections):
       ops.add_to_collection(key, val)


+@deprecated(
+    "2016-11-30", "Please switch to tf.summary.histogram. Note that "
+    "tf.summary.histogram uses the node name instead of the tag. "
+    "This means that TensorFlow will automatically de-duplicate summary "
+    "names based on the scope they are created in.")
 def histogram_summary(tag, values, collections=None, name=None):
   # pylint: disable=line-too-long
   """Outputs a `Summary` protocol buffer with a histogram.
@@ -304,6 +309,13 @@ def get_summary_op():
   return summary_op


+@deprecated(
+    "2016-11-30", "Please switch to tf.summary.scalar. Note that "
+    "tf.summary.scalar uses the node name instead of the tag. "
+    "This means that TensorFlow will automatically de-duplicate summary "
+    "names based on the scope they are created in. Also, passing a "
+    "tensor or list of tags to a scalar summary op is no longer "
+    "supported.")
 def scalar_summary(tags, values, collections=None, name=None):
   # pylint: disable=line-too-long
   """Outputs a `Summary` protocol buffer with scalar values.
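Both deprecation notices point at the `tf.summary` module, which keys summaries by node name rather than an explicit tag. A short migration sketch (tensor and tag names are illustrative):

```python
import tensorflow as tf

loss = tf.constant(0.25, name="loss")

# Deprecated (tag-based):
#   tf.histogram_summary("loss_hist", loss)
#   tf.scalar_summary("loss", loss)

# Current (name-based; names are de-duplicated per scope):
tf.summary.scalar("loss", loss)
tf.summary.histogram("loss_hist", loss)
merged = tf.summary.merge_all()
```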
@@ -188,35 +188,42 @@ def _SparseSegmentSqrtNGrad(op, grad):
                               dim0), None, None)


-def _SegmentMinOrMaxGrad(op, grad):
-  """Gradient for SegmentMin and SegmentMax. Both share the same code."""
-  zeros = array_ops.zeros(
-      array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype)
+def _SegmentMinOrMaxGrad(op, grad, is_sorted):
+  """Gradient for SegmentMin and (unsorted) SegmentMax. They share similar code."""
+  zeros = array_ops.zeros(array_ops.shape(op.inputs[0]),
+                          dtype=op.inputs[0].dtype)

   # Get the number of selected (minimum or maximum) elements in each segment.
   gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1])
   is_selected = math_ops.equal(op.inputs[0], gathered_outputs)
-  num_selected = math_ops.segment_sum(
-      math_ops.cast(is_selected, grad.dtype), op.inputs[1])
+  if is_sorted:
+    num_selected = math_ops.segment_sum(math_ops.cast(is_selected, grad.dtype),
+                                        op.inputs[1])
+  else:
+    num_selected = math_ops.unsorted_segment_sum(math_ops.cast(is_selected, grad.dtype),
+                                                 op.inputs[1], op.inputs[2])

   # Compute the gradient for each segment. The gradient for the ith segment is
   # divided evenly among the selected elements in that segment.
   weighted_grads = math_ops.div(grad, num_selected)
   gathered_grads = array_ops.gather(weighted_grads, op.inputs[1])

-  return array_ops.where(is_selected, gathered_grads, zeros), None
+  if is_sorted:
+    return array_ops.where(is_selected, gathered_grads, zeros), None
+  else:
+    return array_ops.where(is_selected, gathered_grads, zeros), None, None


 @ops.RegisterGradient("SegmentMin")
 def _SegmentMinGrad(op, grad):
   """Gradient for SegmentMin."""
-  return _SegmentMinOrMaxGrad(op, grad)
+  return _SegmentMinOrMaxGrad(op, grad, True)


 @ops.RegisterGradient("SegmentMax")
 def _SegmentMaxGrad(op, grad):
   """Gradient for SegmentMax."""
-  return _SegmentMinOrMaxGrad(op, grad)
+  return _SegmentMinOrMaxGrad(op, grad, True)


 @ops.RegisterGradient("UnsortedSegmentSum")
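The comment in the gradient code is the essential fact: when several elements tie for a segment's extremum, the incoming gradient is split evenly among them. A small numeric sketch of what the registered gradient computes, in plain numpy:

```python
import numpy as np

# One segment whose elements include a tie for the maximum (5.0).
data = np.array([5.0, 2.0, 5.0])
segment_max = data.max()              # 5.0
is_selected = data == segment_max     # [True, False, True]
num_selected = is_selected.sum()      # 2

upstream_grad = 1.0
# Each selected element receives grad / num_selected; the rest get zeros.
grad_data = np.where(is_selected, upstream_grad / num_selected, 0.0)
print(grad_data)  # [0.5 0.  0.5]
```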
@@ -225,6 +232,11 @@ def _UnsortedSegmentSumGrad(op, grad):
   return array_ops.gather(grad, op.inputs[1]), None, None


+@ops.RegisterGradient("UnsortedSegmentMax")
+def _UnsortedSegmentMaxGrad(op, grad):
+  return _SegmentMinOrMaxGrad(op, grad, False)
+
+
 @ops.RegisterGradient("Abs")
 def _AbsGrad(op, grad):
   x = op.inputs[0]
@@ -196,6 +196,7 @@ tf.segment_sum(c, tf.constant([0, 0, 1]))
 @@segment_mean

 @@unsorted_segment_sum
+@@unsorted_segment_max

 @@sparse_segment_sum
 @@sparse_segment_mean
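With the op now listed in the docs, it is used like `unsorted_segment_sum`: segment ids need not be sorted, and the segment count is passed explicitly. A minimal sketch:

```python
import tensorflow as tf

data = tf.constant([1.0, 3.0, 2.0, 5.0])
segment_ids = tf.constant([0, 1, 0, 1])  # ids need not be sorted

# Segment 0 holds {1.0, 2.0}; segment 1 holds {3.0, 5.0}.
result = tf.unsorted_segment_max(data, segment_ids, num_segments=2)
with tf.Session() as sess:
    print(sess.run(result))  # [2. 5.]
```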
@@ -34,7 +34,7 @@ def _has_valid_dims(weights_shape, values_shape):
   with ops.name_scope(
       None, "has_invalid_dims", (weights_shape, values_shape)) as scope:
     values_shape_2d = array_ops.expand_dims(values_shape, -1)
-    valid_dims = array_ops.concat_v2(
+    valid_dims = array_ops.concat(
         (values_shape_2d, array_ops.ones_like(values_shape_2d)), axis=1)
     weights_shape_2d = array_ops.expand_dims(weights_shape, -1)
     invalid_dims = sets.set_difference(weights_shape_2d, valid_dims)
@@ -64,7 +64,8 @@ def match_filenames_once(pattern, name=None):
   """
   with ops.name_scope(name, "matching_filenames", [pattern]) as name:
     return variables.Variable(io_ops.matching_files(pattern), trainable=False,
-                              name=name, validate_shape=False)
+                              name=name, validate_shape=False,
+                              collections=[ops.GraphKeys.LOCAL_VARIABLES])


 def limit_epochs(tensor, num_epochs=None, name=None):
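Because the variable now lives in the `LOCAL_VARIABLES` collection, callers initialize local variables before evaluating it. A sketch of the calling pattern (the glob pattern is hypothetical):

```python
import tensorflow as tf

filenames = tf.train.match_filenames_once("/tmp/data/*.csv")

with tf.Session() as sess:
    # The matcher is now registered as a *local* variable, so initialize
    # local variables (not just global ones) before evaluating it.
    sess.run(tf.local_variables_initializer())
    print(sess.run(filenames))
```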
@@ -164,7 +164,7 @@ class QueueRunnerTest(test.TestCase):
       coord.request_stop()
       # We should be able to join because the RequestStop() will cause
       # the queue to be closed and the enqueue to terminate.
-      coord.join(stop_grace_period_secs=0.05)
+      coord.join(stop_grace_period_secs=1.0)

   def testMultipleSessions(self):
     with self.test_session() as sess:
@@ -68,18 +68,6 @@ from tensorflow.python.training import saver as saver_module
 from tensorflow.python.util import compat


-# pylint: disable=invalid-name
-def _TestDir(test_name):
-  test_dir = os.path.join(test.get_temp_dir(), test_name)
-  if os.path.exists(test_dir):
-    shutil.rmtree(test_dir)
-  gfile.MakeDirs(test_dir)
-  return test_dir
-
-
-# pylint: enable=invalid-name
-
-
 class CheckpointedOp(object):
   """Op with a custom checkpointing implementation.

@@ -591,6 +579,11 @@ class SaverTest(test.TestCase):

 class SaveRestoreShardedTest(test.TestCase):

+  def _get_test_dir(self, dirname):
+    test_dir = os.path.join(self.get_temp_dir(), dirname)
+    gfile.MakeDirs(test_dir)
+    return test_dir
+
   def testBasics(self):
     save_path = os.path.join(self.get_temp_dir(), "sharded_basics")

@@ -719,7 +712,9 @@ class SaveRestoreShardedTest(test.TestCase):
     var_full_shape = [10, 3]
     # Allows save/restore mechanism to work w/ different slicings.
     var_name = "my_var"
-    saved_path = os.path.join(_TestDir("partitioned_variables"), "ckpt")
+    saved_dir = self._get_test_dir("partitioned_variables")
+    saved_path = os.path.join(saved_dir, "ckpt")

     call_saver_with_dict = False  # updated by test loop below

     def _save(slices=None, partitioner=None):

@@ -842,8 +837,13 @@

 class MaxToKeepTest(test.TestCase):

+  def _get_test_dir(self, dirname):
+    test_dir = os.path.join(self.get_temp_dir(), dirname)
+    gfile.MakeDirs(test_dir)
+    return test_dir
+
   def testNonSharded(self):
-    save_dir = _TestDir("max_to_keep_non_sharded")
+    save_dir = self._get_test_dir("max_to_keep_non_sharded")

     with self.test_session() as sess:
       v = variables.Variable(10.0, name="v")

@@ -963,7 +963,7 @@ class MaxToKeepTest(test.TestCase):
         saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))

   def testSharded(self):
-    save_dir = _TestDir("max_to_keep_sharded")
+    save_dir = self._get_test_dir("max_to_keep_sharded")

     with session.Session(
         target="",

@@ -1018,8 +1018,8 @@
       self.assertTrue(gfile.Exists(save._MetaGraphFilename(s3)))

   def testNoMaxToKeep(self):
-    save_dir = _TestDir("no_max_to_keep")
-    save_dir2 = _TestDir("max_to_keep_0")
+    save_dir = self._get_test_dir("no_max_to_keep")
+    save_dir2 = self._get_test_dir("max_to_keep_0")

     with self.test_session() as sess:
       v = variables.Variable(10.0, name="v")

@@ -1046,7 +1046,7 @@
     self.assertTrue(saver_module.checkpoint_exists(s2))

   def testNoMetaGraph(self):
-    save_dir = _TestDir("no_meta_graph")
+    save_dir = self._get_test_dir("no_meta_graph")

     with self.test_session() as sess:
       v = variables.Variable(10.0, name="v")

@@ -1060,8 +1060,13 @@

 class KeepCheckpointEveryNHoursTest(test.TestCase):

+  def _get_test_dir(self, dirname):
+    test_dir = os.path.join(self.get_temp_dir(), dirname)
+    gfile.MakeDirs(test_dir)
+    return test_dir
+
   def testNonSharded(self):
-    save_dir = _TestDir("keep_checkpoint_every_n_hours")
+    save_dir = self._get_test_dir("keep_checkpoint_every_n_hours")

     with self.test_session() as sess:
       v = variables.Variable([10.0], name="v")

@@ -1277,8 +1282,13 @@ class LatestCheckpointWithRelativePaths(test.TestCase):

 class CheckpointStateTest(test.TestCase):

+  def _get_test_dir(self, dirname):
+    test_dir = os.path.join(self.get_temp_dir(), dirname)
+    gfile.MakeDirs(test_dir)
+    return test_dir
+
   def testAbsPath(self):
-    save_dir = _TestDir("abs_paths")
+    save_dir = self._get_test_dir("abs_paths")
     abs_path = os.path.join(save_dir, "model-0")
     ckpt = saver_module.generate_checkpoint_state_proto(save_dir, abs_path)
     self.assertEqual(ckpt.model_checkpoint_path, abs_path)

@@ -1297,7 +1307,7 @@ class CheckpointStateTest(test.TestCase):
     self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path)

   def testAllModelCheckpointPaths(self):
-    save_dir = _TestDir("all_models_test")
+    save_dir = self._get_test_dir("all_models_test")
     abs_path = os.path.join(save_dir, "model-0")
     for paths in [None, [], ["model-2"]]:
       ckpt = saver_module.generate_checkpoint_state_proto(

@@ -1309,7 +1319,7 @@ class CheckpointStateTest(test.TestCase):
     self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)

   def testUpdateCheckpointState(self):
-    save_dir = _TestDir("update_checkpoint_state")
+    save_dir = self._get_test_dir("update_checkpoint_state")
     os.chdir(save_dir)
     # Make a temporary train directory.
     train_dir = "train"

@@ -1325,7 +1335,7 @@ class CheckpointStateTest(test.TestCase):
     self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path)

   def testCheckPointStateFailsWhenIncomplete(self):
-    save_dir = _TestDir("checkpoint_state_fails_when_incomplete")
+    save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete")
     os.chdir(save_dir)
     ckpt_path = os.path.join(save_dir, "checkpoint")
     ckpt_file = open(ckpt_path, "w")

@@ -1335,7 +1345,7 @@ class CheckpointStateTest(test.TestCase):
       saver_module.get_checkpoint_state(save_dir)

   def testCheckPointCompletesRelativePaths(self):
-    save_dir = _TestDir("checkpoint_completes_relative_paths")
+    save_dir = self._get_test_dir("checkpoint_completes_relative_paths")
     os.chdir(save_dir)
     ckpt_path = os.path.join(save_dir, "checkpoint")
     ckpt_file = open(ckpt_path, "w")

@@ -1356,8 +1366,13 @@

 class MetaGraphTest(test.TestCase):

+  def _get_test_dir(self, dirname):
+    test_dir = os.path.join(self.get_temp_dir(), dirname)
+    gfile.MakeDirs(test_dir)
+    return test_dir
+
   def testAddCollectionDef(self):
-    test_dir = _TestDir("good_collection")
+    test_dir = self._get_test_dir("good_collection")
     filename = os.path.join(test_dir, "metafile")
     with self.test_session():
       # Creates a graph.

@@ -1504,12 +1519,12 @@ class MetaGraphTest(test.TestCase):
       self.assertEqual(11.0, v1.eval())

   def testMultiSaverCollection(self):
-    test_dir = _TestDir("saver_collection")
+    test_dir = self._get_test_dir("saver_collection")
     self._testMultiSaverCollectionSave(test_dir)
     self._testMultiSaverCollectionRestore(test_dir)

   def testBinaryAndTextFormat(self):
-    test_dir = _TestDir("binary_and_text")
+    test_dir = self._get_test_dir("binary_and_text")
     filename = os.path.join(test_dir, "metafile")
     with self.test_session(graph=ops_lib.Graph()):
       # Creates a graph.

@@ -1541,7 +1556,7 @@ class MetaGraphTest(test.TestCase):
       saver_module.import_meta_graph(filename)

   def testSliceVariable(self):
-    test_dir = _TestDir("slice_saver")
+    test_dir = self._get_test_dir("slice_saver")
     filename = os.path.join(test_dir, "metafile")
     with self.test_session():
       v1 = variables.Variable([20.0], name="v1")

@@ -1679,7 +1694,7 @@ class MetaGraphTest(test.TestCase):
       sess.run(train_op)

   def testGraphExtension(self):
-    test_dir = _TestDir("graph_extension")
+    test_dir = self._get_test_dir("graph_extension")
     self._testGraphExtensionSave(test_dir)
     self._testGraphExtensionRestore(test_dir)
     self._testRestoreFromTrainGraphWithControlContext(test_dir)

@@ -1722,7 +1737,7 @@ class MetaGraphTest(test.TestCase):

   def testImportIntoNamescope(self):
     # Test that we can import a meta graph into a namescope.
-    test_dir = _TestDir("import_into_namescope")
+    test_dir = self._get_test_dir("import_into_namescope")
     filename = os.path.join(test_dir, "ckpt")
     image = array_ops.placeholder(dtypes.float32, [None, 784])
     label = array_ops.placeholder(dtypes.float32, [None, 10])

@@ -1870,8 +1885,13 @@ class CheckpointReaderForV2Test(CheckpointReaderTest):

 class WriteGraphTest(test.TestCase):

+  def _get_test_dir(self, dirname):
+    test_dir = os.path.join(self.get_temp_dir(), dirname)
+    gfile.MakeDirs(test_dir)
+    return test_dir
+
   def testWriteGraph(self):
-    test_dir = _TestDir("write_graph_dir")
+    test_dir = self._get_test_dir("write_graph_dir")
     variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
     path = graph_io.write_graph(ops_lib.get_default_graph(),
                                 os.path.join(test_dir, "l1"), "graph.pbtxt")

@@ -1881,7 +1901,7 @@ class WriteGraphTest(test.TestCase):

   def testRecursiveCreate(self):
-    test_dir = _TestDir("deep_dir")
+    test_dir = self._get_test_dir("deep_dir")
     variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
     path = graph_io.write_graph(ops_lib.get_default_graph().as_graph_def(),
                                 os.path.join(test_dir, "l1", "l2", "l3"),

@@ -1935,6 +1955,11 @@ class SaverUtilsTest(test.TestCase):

 class ScopedGraphTest(test.TestCase):

+  def _get_test_dir(self, dirname):
+    test_dir = os.path.join(self.get_temp_dir(), dirname)
+    gfile.MakeDirs(test_dir)
+    return test_dir
+
   def _testScopedSave(self, test_dir, exported_filename, ckpt_filename):
     graph = ops_lib.Graph()
     with graph.as_default():

@@ -2067,7 +2092,7 @@ class ScopedGraphTest(test.TestCase):
   # Verifies that we can save the subgraph under "hidden1" and restore it
   # into "new_hidden1" in the new graph.
   def testScopedSaveAndRestore(self):
-    test_dir = _TestDir("scoped_export_import")
+    test_dir = self._get_test_dir("scoped_export_import")
     ckpt_filename = "ckpt"
     self._testScopedSave(test_dir, "exported_hidden1.pbtxt", ckpt_filename)
     self._testScopedRestore(test_dir, "exported_hidden1.pbtxt",

@@ -2076,7 +2101,7 @@ class ScopedGraphTest(test.TestCase):
   # Verifies that we can copy the subgraph under "hidden1" and copy it
   # to different name scope in the same graph or different graph.
   def testCopyScopedGraph(self):
-    test_dir = _TestDir("scoped_copy")
+    test_dir = self._get_test_dir("scoped_copy")
     saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
     graph1 = ops_lib.Graph()
     with graph1.as_default():

@@ -2132,7 +2157,7 @@ class ScopedGraphTest(test.TestCase):
     self.assertAllClose(expected, sess.run("new_hidden1/relu:0"))

   def testExportGraphDefWithScope(self):
-    test_dir = _TestDir("export_graph_def")
+    test_dir = self._get_test_dir("export_graph_def")
     saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
     graph1 = ops_lib.Graph()
     with graph1.as_default():
@ -64,15 +64,14 @@ def _summary_iterator(test_dir):
|
|||||||
return summary_iterator.summary_iterator(event_paths[-1])
|
return summary_iterator.summary_iterator(event_paths[-1])
|
||||||
|
|
||||||
|
|
||||||
def _test_dir(test_name):
|
|
||||||
test_dir = os.path.join(test.get_temp_dir(), test_name)
|
|
||||||
if os.path.exists(test_dir):
|
|
||||||
shutil.rmtree(test_dir)
|
|
||||||
return test_dir
|
|
||||||
|
|
||||||
|
|
||||||
class SupervisorTest(test.TestCase):
|
class SupervisorTest(test.TestCase):
|
||||||
|
|
||||||
|
def _test_dir(self, test_name):
|
||||||
|
test_dir = os.path.join(self.get_temp_dir(), test_name)
|
||||||
|
if os.path.exists(test_dir):
|
||||||
|
shutil.rmtree(test_dir)
|
||||||
|
return test_dir
|
||||||
|
|
||||||
def _wait_for_glob(self, pattern, timeout_secs, for_checkpoint=True):
|
def _wait_for_glob(self, pattern, timeout_secs, for_checkpoint=True):
|
||||||
"""Wait for a checkpoint file to appear.
|
"""Wait for a checkpoint file to appear.
|
||||||
|
|
||||||
@ -94,7 +93,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
|
|
||||||
# This test does not test much.
|
# This test does not test much.
|
||||||
def testBasics(self):
|
def testBasics(self):
|
||||||
logdir = _test_dir("basics")
|
logdir = self._test_dir("basics")
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
my_op = constant_op.constant(1.0)
|
my_op = constant_op.constant(1.0)
|
||||||
sv = supervisor.Supervisor(logdir=logdir)
|
sv = supervisor.Supervisor(logdir=logdir)
|
||||||
@ -105,7 +104,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
sv.stop()
|
sv.stop()
|
||||||
|
|
||||||
def testManagedSession(self):
|
def testManagedSession(self):
|
||||||
logdir = _test_dir("managed_session")
|
logdir = self._test_dir("managed_session")
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
my_op = constant_op.constant(1.0)
|
my_op = constant_op.constant(1.0)
|
||||||
sv = supervisor.Supervisor(logdir=logdir)
|
sv = supervisor.Supervisor(logdir=logdir)
|
||||||
@ -116,7 +115,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
self.assertTrue(sv.should_stop())
|
self.assertTrue(sv.should_stop())
|
||||||
|
|
||||||
def testManagedSessionUserError(self):
|
def testManagedSessionUserError(self):
|
||||||
logdir = _test_dir("managed_user_error")
|
logdir = self._test_dir("managed_user_error")
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
my_op = constant_op.constant(1.0)
|
my_op = constant_op.constant(1.0)
|
||||||
sv = supervisor.Supervisor(logdir=logdir)
|
sv = supervisor.Supervisor(logdir=logdir)
|
||||||
@ -134,7 +133,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
self.assertEqual(1, last_step)
|
self.assertEqual(1, last_step)
|
||||||
|
|
||||||
def testManagedSessionIgnoreOutOfRangeError(self):
|
def testManagedSessionIgnoreOutOfRangeError(self):
|
||||||
logdir = _test_dir("managed_out_of_range")
|
logdir = self._test_dir("managed_out_of_range")
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
my_op = constant_op.constant(1.0)
|
my_op = constant_op.constant(1.0)
|
||||||
sv = supervisor.Supervisor(logdir=logdir)
|
sv = supervisor.Supervisor(logdir=logdir)
|
||||||
@ -152,7 +151,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
self.assertEqual(3, last_step)
|
self.assertEqual(3, last_step)
|
||||||
|
|
||||||
def testManagedSessionDoNotKeepSummaryWriter(self):
|
def testManagedSessionDoNotKeepSummaryWriter(self):
|
||||||
logdir = _test_dir("managed_not_keep_summary_writer")
|
logdir = self._test_dir("managed_not_keep_summary_writer")
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
summary.scalar("c1", constant_op.constant(1))
|
summary.scalar("c1", constant_op.constant(1))
|
||||||
summary.scalar("c2", constant_op.constant(2))
|
summary.scalar("c2", constant_op.constant(2))
|
||||||
@ -204,7 +203,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
next(rr)
|
next(rr)
|
||||||
|
|
||||||
def testManagedSessionKeepSummaryWriter(self):
|
def testManagedSessionKeepSummaryWriter(self):
|
||||||
logdir = _test_dir("managed_keep_summary_writer")
|
logdir = self._test_dir("managed_keep_summary_writer")
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
summary.scalar("c1", constant_op.constant(1))
|
summary.scalar("c1", constant_op.constant(1))
|
||||||
summary.scalar("c2", constant_op.constant(2))
|
summary.scalar("c2", constant_op.constant(2))
|
||||||
@@ -266,7 +265,7 @@ class SupervisorTest(test.TestCase):
   def testManagedEndOfInputOneQueue(self):
     # Tests that the supervisor finishes without an error when using
     # a fixed number of epochs, reading from a single queue.
-    logdir = _test_dir("managed_end_of_input_one_queue")
+    logdir = self._test_dir("managed_end_of_input_one_queue")
     os.makedirs(logdir)
     data_path = self._csv_data(logdir)
     with ops.Graph().as_default():
@@ -285,7 +284,7 @@ class SupervisorTest(test.TestCase):
     # Tests that the supervisor finishes without an error when using
     # a fixed number of epochs, reading from two queues, the second
     # one producing a batch from the first one.
-    logdir = _test_dir("managed_end_of_input_two_queues")
+    logdir = self._test_dir("managed_end_of_input_two_queues")
     os.makedirs(logdir)
     data_path = self._csv_data(logdir)
     with ops.Graph().as_default():
@@ -304,7 +303,7 @@ class SupervisorTest(test.TestCase):
   def testManagedMainErrorTwoQueues(self):
     # Tests that the supervisor correctly raises a main loop
     # error even when using multiple queues for input.
-    logdir = _test_dir("managed_main_error_two_queues")
+    logdir = self._test_dir("managed_main_error_two_queues")
     os.makedirs(logdir)
     data_path = self._csv_data(logdir)
     with self.assertRaisesRegexp(RuntimeError, "fail at step 3"):
@@ -327,7 +326,7 @@ class SupervisorTest(test.TestCase):
         sess.run(shuff_rec)

   def testSessionConfig(self):
-    logdir = _test_dir("session_config")
+    logdir = self._test_dir("session_config")
     with ops.Graph().as_default():
       with ops.device("/cpu:1"):
         my_op = constant_op.constant([1.0])
@@ -340,7 +339,7 @@ class SupervisorTest(test.TestCase):
       sv.stop()

   def testChiefCanWriteEvents(self):
-    logdir = _test_dir("can_write")
+    logdir = self._test_dir("can_write")
     with ops.Graph().as_default():
       summary.scalar("c1", constant_op.constant(1))
       summary.scalar("c2", constant_op.constant(2))
@@ -421,7 +420,7 @@ class SupervisorTest(test.TestCase):
       sv.summary_computed(sess, sess.run(summ))

   def testLogdirButExplicitlyNoSummaryWriter(self):
-    logdir = _test_dir("explicit_no_summary_writer")
+    logdir = self._test_dir("explicit_no_summary_writer")
     with ops.Graph().as_default():
       variables.Variable([1.0], name="foo")
       summary.scalar("c1", constant_op.constant(1))
@@ -437,7 +436,7 @@ class SupervisorTest(test.TestCase):
       sv.summary_computed(sess, sess.run(summ))

   def testNoLogdirButExplicitSummaryWriter(self):
-    logdir = _test_dir("explicit_summary_writer")
+    logdir = self._test_dir("explicit_summary_writer")
     with ops.Graph().as_default():
       summary.scalar("c1", constant_op.constant(1))
       summary.scalar("c2", constant_op.constant(2))
@@ -506,7 +505,7 @@ class SupervisorTest(test.TestCase):
       sv.prepare_or_wait_for_session("")

   def testInitOp(self):
-    logdir = _test_dir("default_init_op")
+    logdir = self._test_dir("default_init_op")
     with ops.Graph().as_default():
       v = variables.Variable([1.0, 2.0, 3.0])
       sv = supervisor.Supervisor(logdir=logdir)
@@ -515,7 +514,7 @@ class SupervisorTest(test.TestCase):
       sv.stop()

   def testInitFn(self):
-    logdir = _test_dir("default_init_op")
+    logdir = self._test_dir("default_init_op")
     with ops.Graph().as_default():
       v = variables.Variable([1.0, 2.0, 3.0])

@@ -528,7 +527,7 @@ class SupervisorTest(test.TestCase):
       sv.stop()

   def testInitOpWithFeedDict(self):
-    logdir = _test_dir("feed_dict_init_op")
+    logdir = self._test_dir("feed_dict_init_op")
     with ops.Graph().as_default():
       p = array_ops.placeholder(dtypes.float32, shape=(3,))
       v = variables.Variable(p, name="v")
@@ -542,7 +541,7 @@ class SupervisorTest(test.TestCase):

   def testReadyForLocalInitOp(self):
     server = server_lib.Server.create_local_server()
-    logdir = _test_dir("default_ready_for_local_init_op")
+    logdir = self._test_dir("default_ready_for_local_init_op")

     uid = uuid.uuid4().hex

@@ -584,7 +583,7 @@ class SupervisorTest(test.TestCase):

   def testReadyForLocalInitOpRestoreFromCheckpoint(self):
     server = server_lib.Server.create_local_server()
-    logdir = _test_dir("ready_for_local_init_op_restore")
+    logdir = self._test_dir("ready_for_local_init_op_restore")

     uid = uuid.uuid4().hex

@@ -639,7 +638,7 @@ class SupervisorTest(test.TestCase):
     sv1.stop()

   def testLocalInitOp(self):
-    logdir = _test_dir("default_local_init_op")
+    logdir = self._test_dir("default_local_init_op")
     with ops.Graph().as_default():
       # A local variable.
       v = variables.Variable(
@@ -664,7 +663,7 @@ class SupervisorTest(test.TestCase):
       sv.stop()

   def testLocalInitOpForNonChief(self):
-    logdir = _test_dir("default_local_init_op_non_chief")
+    logdir = self._test_dir("default_local_init_op_non_chief")
     with ops.Graph().as_default():
       with ops.device("/job:localhost"):
         # A local variable.
@@ -685,7 +684,7 @@ class SupervisorTest(test.TestCase):

   def testInitOpFails(self):
     server = server_lib.Server.create_local_server()
-    logdir = _test_dir("default_init_op_fails")
+    logdir = self._test_dir("default_init_op_fails")
     with ops.Graph().as_default():
       v = variables.Variable([1.0, 2.0, 3.0], name="v")
       variables.Variable([4.0, 5.0, 6.0], name="w")
@@ -697,7 +696,7 @@ class SupervisorTest(test.TestCase):

   def testInitOpFailsForTransientVariable(self):
     server = server_lib.Server.create_local_server()
-    logdir = _test_dir("default_init_op_fails_for_local_variable")
+    logdir = self._test_dir("default_init_op_fails_for_local_variable")
     with ops.Graph().as_default():
       v = variables.Variable(
           [1.0, 2.0, 3.0],
@@ -714,7 +713,7 @@ class SupervisorTest(test.TestCase):
       sv.prepare_or_wait_for_session(server.target)

   def testSetupFail(self):
-    logdir = _test_dir("setup_fail")
+    logdir = self._test_dir("setup_fail")
     with ops.Graph().as_default():
       variables.Variable([1.0, 2.0, 3.0], name="v")
       with self.assertRaisesRegexp(ValueError, "must have their device set"):
@@ -724,7 +723,7 @@ class SupervisorTest(test.TestCase):
       supervisor.Supervisor(logdir=logdir, is_chief=False)

   def testDefaultGlobalStep(self):
-    logdir = _test_dir("default_global_step")
+    logdir = self._test_dir("default_global_step")
     with ops.Graph().as_default():
       variables.Variable(287, name="global_step")
       sv = supervisor.Supervisor(logdir=logdir)
@@ -733,7 +732,7 @@ class SupervisorTest(test.TestCase):
       sv.stop()

   def testRestoreFromMetaGraph(self):
-    logdir = _test_dir("restore_from_meta_graph")
+    logdir = self._test_dir("restore_from_meta_graph")
     with ops.Graph().as_default():
       variables.Variable(1, name="v0")
       sv = supervisor.Supervisor(logdir=logdir)
@@ -754,7 +753,7 @@ class SupervisorTest(test.TestCase):
   # right away and get to run once before sv.stop() returns.
   # We still sleep a bit to make the test robust.
   def testStandardServicesWithoutGlobalStep(self):
-    logdir = _test_dir("standard_services_without_global_step")
+    logdir = self._test_dir("standard_services_without_global_step")
     # Create a checkpoint.
     with ops.Graph().as_default():
       v = variables.Variable([1.0], name="foo")
@@ -804,7 +803,7 @@ class SupervisorTest(test.TestCase):
   # Same as testStandardServicesNoGlobalStep but with a global step.
   # We should get a summary about the step time.
   def testStandardServicesWithGlobalStep(self):
-    logdir = _test_dir("standard_services_with_global_step")
+    logdir = self._test_dir("standard_services_with_global_step")
     # Create a checkpoint.
     with ops.Graph().as_default():
       v = variables.Variable([123], name="global_step")
@@ -867,12 +866,12 @@ class SupervisorTest(test.TestCase):

   def testNoQueueRunners(self):
     with ops.Graph().as_default(), self.test_session() as sess:
-      sv = supervisor.Supervisor(logdir=_test_dir("no_queue_runners"))
+      sv = supervisor.Supervisor(logdir=self._test_dir("no_queue_runners"))
       self.assertEqual(0, len(sv.start_queue_runners(sess)))
       sv.stop()

   def testPrepareSessionAfterStopForChief(self):
-    logdir = _test_dir("prepare_after_stop_chief")
+    logdir = self._test_dir("prepare_after_stop_chief")
     with ops.Graph().as_default():
       sv = supervisor.Supervisor(logdir=logdir, is_chief=True)

@@ -891,7 +890,7 @@ class SupervisorTest(test.TestCase):
       self.assertTrue(sv.should_stop())

   def testPrepareSessionAfterStopForNonChief(self):
-    logdir = _test_dir("prepare_after_stop_nonchief")
+    logdir = self._test_dir("prepare_after_stop_nonchief")
     with ops.Graph().as_default():
       sv = supervisor.Supervisor(logdir=logdir, is_chief=False)

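Every hunk in this series makes the same mechanical change: the module-level `_test_dir` helper becomes an instance method. The helper's definition is outside the hunks shown; the shape of the refactor is roughly this sketch (the body below is hypothetical — only the call-site change is in the diff):

```python
import os
import tempfile


class SupervisorTest(object):  # stand-in for the real test.TestCase subclass

  def _test_dir(self, test_name):
    # Hypothetical body: build a per-test directory from instance state;
    # making this a method lets each test case (or subclass) control where
    # its log directories live.
    return os.path.join(tempfile.gettempdir(),
                        self.__class__.__name__, test_name)


print(SupervisorTest()._test_dir("managed_end_of_input_one_queue"))
```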
@@ -51,18 +51,18 @@ work, but there may be bugs or performance issues.

 The first step in using TensorBoard is acquiring data from your TensorFlow run.
 For this, you need [summary
-ops](https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#summary-operations).
+ops](https://www.tensorflow.org/versions/r1.0/api_docs/python/train.html#summary-operations).
 Summary ops are ops, like
-[`tf.matmul`](https://www.tensorflow.org/versions/r0.12/api_docs/python/math_ops.html#matmul)
+[`tf.matmul`](https://www.tensorflow.org/versions/r1.0/api_docs/python/math_ops.html#matmul)
 or
-[`tf.nn.relu`](https://www.tensorflow.org/versions/r0.12/api_docs/python/nn.html#relu),
+[`tf.nn.relu`](https://www.tensorflow.org/versions/r1.0/api_docs/python/nn.html#relu),
 which means they take in tensors, produce tensors, and are evaluated from within
 a TensorFlow graph. However, summary ops have a twist: the Tensors they produce
 contain serialized protobufs, which are written to disk and sent to TensorBoard.
 To visualize the summary data in TensorBoard, you should evaluate the summary
 op, retrieve the result, and then write that result to disk using a
 summary.FileWriter. A full explanation, with examples, is in [the
-tutorial](https://www.tensorflow.org/versions/r0.12/how_tos/summaries_and_tensorboard/index.html).
+tutorial](https://www.tensorflow.org/versions/r1.0/how_tos/summaries_and_tensorboard/index.html).

 ### Tags: Giving names to data

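As a runnable sketch of the workflow that paragraph describes, against the 1.0-era API (the graph contents and the log directory are arbitrary):

```python
import tensorflow as tf

# A trivial graph with one summary op attached to a scalar.
c = tf.constant(42.0)
summary_op = tf.summary.scalar("answer", c)

with tf.Session() as sess:
    # Evaluating the summary op yields a serialized Summary protobuf.
    summary_proto = sess.run(summary_op)
    # Write it where TensorBoard can find it ("/tmp/demo_logs" is arbitrary).
    writer = tf.summary.FileWriter("/tmp/demo_logs", sess.graph)
    writer.add_summary(summary_proto, global_step=0)
    writer.close()
```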
@ -184,7 +184,7 @@ TensorFlow model. To get best use of the graph visualizer, you should use name
|
|||||||
scopes to hierarchically group the ops in your graph - otherwise, the graph may
|
scopes to hierarchically group the ops in your graph - otherwise, the graph may
|
||||||
be difficult to decipher. For more information, including examples, see [the
|
be difficult to decipher. For more information, including examples, see [the
|
||||||
graph visualizer
|
graph visualizer
|
||||||
tutorial](https://www.tensorflow.org/versions/r0.12/how_tos/graph_viz/index.html#tensorboard-graph-visualization).
|
tutorial](https://www.tensorflow.org/versions/r1.0/how_tos/graph_viz/index.html#tensorboard-graph-visualization).
|
||||||
|
|
||||||
# Frequently Asked Questions
|
# Frequently Asked Questions
|
||||||
|
|
||||||
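For example, a minimal sketch of the name scoping recommended above (scope and op names are arbitrary):

```python
import tensorflow as tf

# Ops grouped under a name scope collapse into a single expandable node
# in the graph visualizer.
with tf.name_scope("input"):
    x = tf.placeholder(tf.float32, [None, 784], name="x")
with tf.name_scope("layer1"):
    w = tf.Variable(tf.zeros([784, 10]), name="weights")
    y = tf.matmul(x, w, name="activations")
```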
@@ -20,6 +20,11 @@ load(
     "cuda_default_copts"
 )

+#load(
+#    "//third_party/mkl:build_defs.bzl",
+#    "if_mkl",
+#)
+
 # List of proto files for android builds
 def tf_android_core_proto_sources(core_proto_sources_relative):
   return ["//tensorflow/core:" + p
@@ -377,6 +382,10 @@ def tf_cc_tests(srcs, deps, name='', linkstatic=0, tags=[], size="medium",
              args=args,
              linkopts=linkopts)

+#def tf_cc_test_mkl(srcs, deps, name='', linkstatic=0, tags=[], size="medium",
+#                   args=None):
+#  tf_cc_tests(srcs, deps, linkstatic, tags=tags, size=size, args=args)
+
 def tf_cc_tests_gpu(srcs, deps, name='', linkstatic=0, tags=[], size="medium",
                     args=None):
   tf_cc_tests(srcs, deps, linkstatic, tags=tags, size=size, args=args)
@@ -39,7 +39,7 @@ $adb shell "/data/local/tmp/benchmark_model \
 ### On desktop:
 (1) build the binary
 ```bash
-$bazel build -c opt tensorflow/tools/benchmark:benchmark_model
+$bazel build --config opt tensorflow/tools/benchmark:benchmark_model
 ```

 (2) Run on your compute graph, similar to the Android case but without the need of adb shell.
@@ -54,4 +54,4 @@ $bazel-bin/tensorflow/tools/benchmark/benchmark_model \
 ```

 The Inception graph used as an example here may be downloaded from
 https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
@@ -17,7 +17,7 @@
 # Install protobuf3.

 # Select protobuf version.
-PROTOBUF_VERSION="3.1.0"
+PROTOBUF_VERSION="3.2.0"
 protobuf_ver_flat=$(echo $PROTOBUF_VERSION | sed 's/\.//g' | sed 's/^0*//g')
 local_protobuf_ver=$(protoc --version | awk '{print $2}')
 local_protobuf_ver_flat=$(echo $local_protobuf_ver | sed 's/\.//g' | sed 's/^0*//g')
@@ -14,7 +14,7 @@
 # limitations under the License.
 # ==============================================================================

-PROTOBUF_VERSION="3.1.0"
+PROTOBUF_VERSION="3.2.0"
 PYTHON_BIN=${PYTHON_BIN:-python}
 DIR=${PWD}/protobuf

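For clarity, the flattening the surrounding sed pipeline performs, restated in Python (illustration only, not part of the change):

```python
def ver_flat(version):
    # "3.2.0" -> "320", mirroring: sed 's/\.//g' | sed 's/^0*//g'
    return version.replace(".", "").lstrip("0")

assert ver_flat("3.2.0") == "320"
assert ver_flat("0.12.1") == "121"
```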
@@ -111,6 +111,7 @@ function get_failing_cpu_py_tests() {
     //$1/tensorflow/python:framework_ops_test + \
     //$1/tensorflow/python:framework_tensor_util_test + \
     //$1/tensorflow/python:framework_test_util_test + \
+    //$1/tensorflow/python:gradients_test + \
     //$1/tensorflow/python:image_ops_test + \
     //$1/tensorflow/python:localhost_cluster_performance_test + \
     //$1/tensorflow/python:monitored_session_test + \
@@ -36,6 +36,9 @@ particular, functions that have had reordered arguments like `tf.concat`,
 `tf.split` will cause the script to incorrectly add keyword arguments that
 mismap arguments.

+- This script wouldn't actually reorder arguments. Instead, the script will add
+keyword arguments to functions that had their arguments reordered.
+
 - This script is not able to upgrade all functions. One notable example is
 `tf.reverse()` which has been changed to take a list of indices rather than
 a tensor of bools. If the script detects this, it will report this to stdout
@@ -43,6 +46,12 @@ a tensor of bools. If the script detects this, it will report this to stdout
 `tf.reverse(a, [False, True, True])` you will need to manually change it to
 `tf.reverse(a, [1, 2])`.

+- There are some syntaxes that are not handleable with this script as this
+script was designed to use only standard python packages. If the script fails
+with "A necessary keyword argument failed to be inserted." or
+"Failed to find keyword lexicographically. Fix manually.", you can try
+[@machrisaa's fork of this script](https://github.com/machrisaa/tf0to1).
+[@machrisaa](https://github.com/machrisaa) has used the
+[RedBaron Python refactoring engine](https://redbaron.readthedocs.io/en/latest/)
+which is able to localize syntactic elements more reliably than the built-in
+`ast` module this script is based upon.
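A concrete before/after of the keyword renaming these notes describe, sketched as a toy re-implementation (the real tool walks the AST; this regex version only handles simple single-line calls):

```python
import re

# The rename tables below are taken from the APIChangeSpec in this commit.
KEYWORD_RENAMES = {
    "tf.concat": {"concat_dim": "axis"},
    "tf.split": {"split_dim": "axis", "num_split": "num_or_size_splits"},
}

def rename_keywords(line):
    # For each known function present on the line, rewrite its old keyword
    # names to the 1.0 names.
    for func, renames in KEYWORD_RENAMES.items():
        if func + "(" in line:
            for old, new in renames.items():
                line = re.sub(r"\b%s=" % old, new + "=", line)
    return line

print(rename_keywords("tf.concat(concat_dim=0, values=[a, b])"))
# -> tf.concat(axis=0, values=[a, b])
```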
@@ -95,11 +95,15 @@ class APIChangeSpec(object):
         "tf.split": {
             "split_dim": "axis",
             "num_split": "num_or_size_splits"
-        }
+        },
+        "tf.concat": {
+            "concat_dim": "axis"
+        },
     }

     # Mapping from function to the new name of the function
     self.function_renames = {
+        "tf.inv": "tf.reciprocal",
         "tf.contrib.deprecated.scalar_summary": "tf.summary.scalar",
         "tf.contrib.deprecated.histogram_summary": "tf.summary.histogram",
         "tf.listdiff": "tf.setdiff1d",
@@ -142,6 +146,13 @@ class APIChangeSpec(object):
         "tf.select": "tf.where",
         "tf.complex_abs": "tf.abs",
         "tf.batch_matmul": "tf.matmul",
+        "tf.pack": "tf.stack",
+        "tf.unpack": "tf.unstack",
+    }
+
+    self.change_to_function = {
+        "tf.ones_initializer",
+        "tf.zeros_initializer",
     }

     # Functions that were reordered should be changed to the new keyword args
@@ -149,6 +160,7 @@ class APIChangeSpec(object):
     # positional arguments yourself, this could do the wrong thing.
     self.function_reorders = {
         "tf.split": ["axis", "num_or_size_splits", "value", "name"],
+        "tf.sparse_split": ["axis", "num_or_size_splits", "value", "name"],
         "tf.concat": ["concat_dim", "values", "name"],
         "tf.svd": ["tensor", "compute_uv", "full_matrices", "name"],
         "tf.nn.softmax_cross_entropy_with_logits": [
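In miniature, how the `function_reorders` table is consumed: positional arguments receive explicit keywords from the listed order, and, as the `visit_Call` change below arranges, each inserted keyword is itself passed through `function_keyword_renames`. A simplified sketch, not the tool's actual code path:

```python
REORDERS = {"tf.concat": ["concat_dim", "values", "name"]}
KEYWORD_RENAMES = {"tf.concat": {"concat_dim": "axis"}}


def keywordify(func, positional_args):
    names = REORDERS[func]
    renames = KEYWORD_RENAMES.get(func, {})
    pairs = []
    for name, value in zip(names, positional_args):
        # Apply the keyword rename to the inserted keyword, as the updated
        # visit_Call does.
        pairs.append("%s=%s" % (renames.get(name, name), value))
    return "%s(%s)" % (func, ", ".join(pairs))


print(keywordify("tf.concat", ["0", "[a, b]"]))
# -> tf.concat(axis=0, values=[a, b])
```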
@ -335,6 +347,62 @@ class TensorFlowCallVisitor(ast.NodeVisitor):
|
|||||||
items.append(curr.id)
|
items.append(curr.id)
|
||||||
return ".".join(reversed(items))
|
return ".".join(reversed(items))
|
||||||
|
|
||||||
|
def _find_true_position(self, node):
|
||||||
|
"""Return correct line number and column offset for a given node.
|
||||||
|
|
||||||
|
This is necessary mainly because ListComp's location reporting reports
|
||||||
|
the next token after the list comprehension list opening.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: Node for which we wish to know the lineno and col_offset
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
find_open = re.compile("^\s*(\\[).*$")
|
||||||
|
find_string_chars = re.compile("['\"]")
|
||||||
|
|
||||||
|
if isinstance(node, ast.ListComp):
|
||||||
|
# Strangely, ast.ListComp returns the col_offset of the first token
|
||||||
|
# after the '[' token which appears to be a bug. Workaround by
|
||||||
|
# explicitly finding the real start of the list comprehension.
|
||||||
|
line = node.lineno
|
||||||
|
col = node.col_offset
|
||||||
|
# loop over lines
|
||||||
|
while 1:
|
||||||
|
# Reverse the text to and regular expression search for whitespace
|
||||||
|
text = self._lines[line-1]
|
||||||
|
reversed_preceding_text = text[:col][::-1]
|
||||||
|
# First find if a [ can be found with only whitespace between it and
|
||||||
|
# col.
|
||||||
|
m = find_open.match(reversed_preceding_text)
|
||||||
|
if m:
|
||||||
|
new_col_offset = col - m.start(1) - 1
|
||||||
|
return line, new_col_offset
|
||||||
|
else:
|
||||||
|
if (reversed_preceding_text=="" or
|
||||||
|
reversed_preceding_text.isspace()):
|
||||||
|
line = line - 1
|
||||||
|
prev_line = self._lines[line - 1]
|
||||||
|
# TODO(aselle):
|
||||||
|
# this is poor comment detection, but it is good enough for
|
||||||
|
# cases where the comment does not contain string literal starting/
|
||||||
|
# ending characters. If ast gave us start and end locations of the
|
||||||
|
# ast nodes rather than just start, we could use string literal
|
||||||
|
# node ranges to filter out spurious #'s that appear in string
|
||||||
|
# literals.
|
||||||
|
comment_start = prev_line.find("#")
|
||||||
|
if comment_start == -1:
|
||||||
|
col = len(prev_line) -1
|
||||||
|
elif find_string_chars.search(prev_line[comment_start:]) is None:
|
||||||
|
col = comment_start
|
||||||
|
else:
|
||||||
|
return None, None
|
||||||
|
else:
|
||||||
|
return None, None
|
||||||
|
# Most other nodes return proper locations (with notably does not), but
|
||||||
|
# it is not possible to use that in an argument.
|
||||||
|
return node.lineno, node.col_offset
|
||||||
|
|
||||||
|
|
||||||
def visit_Call(self, node): # pylint: disable=invalid-name
|
def visit_Call(self, node): # pylint: disable=invalid-name
|
||||||
"""Handle visiting a call node in the AST.
|
"""Handle visiting a call node in the AST.
|
||||||
|
|
||||||
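The location quirk `_find_true_position` works around is easy to reproduce. On the older CPython releases this script targeted, `ast.ListComp.col_offset` points just past the opening bracket; CPython 3.8+ reports the bracket itself:

```python
import ast

tree = ast.parse("tf.concat(0, [x for x in y])")
call = tree.body[0].value
listcomp = call.args[1]
# The '[' sits at column 13; older interpreters report col_offset 14 for
# the ListComp node, which is why the helper scans backwards for the bracket.
print(listcomp.lineno, listcomp.col_offset)
```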
@ -342,11 +410,13 @@ class TensorFlowCallVisitor(ast.NodeVisitor):
|
|||||||
node: Current Node
|
node: Current Node
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ast.NodeVisitor.generic_visit(self, node)
|
|
||||||
|
|
||||||
# Find a simple attribute name path e.g. "tf.foo.bar"
|
# Find a simple attribute name path e.g. "tf.foo.bar"
|
||||||
full_name = self._get_attribute_full_path(node.func)
|
full_name = self._get_attribute_full_path(node.func)
|
||||||
|
|
||||||
|
# Make sure the func is marked as being part of a call
|
||||||
|
node.func.is_function_for_call = True
|
||||||
|
|
||||||
if full_name and full_name.startswith("tf."):
|
if full_name and full_name.startswith("tf."):
|
||||||
# Call special handlers
|
# Call special handlers
|
||||||
function_handles = self._api_change_spec.function_handle
|
function_handles = self._api_change_spec.function_handle
|
||||||
@ -356,27 +426,60 @@ class TensorFlowCallVisitor(ast.NodeVisitor):
|
|||||||
# Examine any non-keyword argument and make it into a keyword argument
|
# Examine any non-keyword argument and make it into a keyword argument
|
||||||
# if reordering required.
|
# if reordering required.
|
||||||
function_reorders = self._api_change_spec.function_reorders
|
function_reorders = self._api_change_spec.function_reorders
|
||||||
|
function_keyword_renames = (
|
||||||
|
self._api_change_spec.function_keyword_renames)
|
||||||
|
|
||||||
if full_name in function_reorders:
|
if full_name in function_reorders:
|
||||||
reordered = function_reorders[full_name]
|
reordered = function_reorders[full_name]
|
||||||
for idx, arg in enumerate(node.args):
|
for idx, arg in enumerate(node.args):
|
||||||
self._file_edit.add("Added keyword %r to reordered function %r"
|
lineno, col_offset = self._find_true_position(arg)
|
||||||
% (reordered[idx], full_name), arg.lineno,
|
if lineno is None or col_offset is None:
|
||||||
arg.col_offset, "", reordered[idx] + "=")
|
self._file_edit.add(
|
||||||
|
"Failed to add keyword %r to reordered function %r"
|
||||||
|
% (reordered[idx], full_name), arg.lineno, arg.col_offset,
|
||||||
|
"", "",
|
||||||
|
error="A necessary keyword argument failed to be inserted.")
|
||||||
|
else:
|
||||||
|
keyword_arg = reordered[idx]
|
||||||
|
if (full_name in function_keyword_renames and
|
||||||
|
keyword_arg in function_keyword_renames[full_name]):
|
||||||
|
keyword_arg = function_keyword_renames[full_name][keyword_arg]
|
||||||
|
self._file_edit.add("Added keyword %r to reordered function %r"
|
||||||
|
% (reordered[idx], full_name), lineno,
|
||||||
|
col_offset, "", keyword_arg + "=")
|
||||||
|
|
||||||
# Examine each keyword argument and convert it to the final renamed form
|
# Examine each keyword argument and convert it to the final renamed form
|
||||||
function_keyword_renames = (
|
|
||||||
self._api_change_spec.function_keyword_renames)
|
|
||||||
renamed_keywords = ({} if full_name not in function_keyword_renames else
|
renamed_keywords = ({} if full_name not in function_keyword_renames else
|
||||||
function_keyword_renames[full_name])
|
function_keyword_renames[full_name])
|
||||||
for keyword in node.keywords:
|
for keyword in node.keywords:
|
||||||
argkey = keyword.arg
|
argkey = keyword.arg
|
||||||
argval = keyword.value
|
argval = keyword.value
|
||||||
|
|
||||||
if argkey in renamed_keywords:
|
if argkey in renamed_keywords:
|
||||||
self._file_edit.add("Renamed keyword argument from %r to %r" %
|
argval_lineno, argval_col_offset = self._find_true_position(argval)
|
||||||
|
if (argval_lineno is not None and argval_col_offset is not None):
|
||||||
|
# TODO(aselle): We should scan backward to find the start of the
|
||||||
|
# keyword key. Unfortunately ast does not give you the location of
|
||||||
|
# keyword keys, so we are forced to infer it from the keyword arg
|
||||||
|
# value.
|
||||||
|
key_start = argval_col_offset - len(argkey) - 1
|
||||||
|
key_end = key_start + len(argkey) + 1
|
||||||
|
if self._lines[argval_lineno - 1][key_start:key_end] == argkey + "=":
|
||||||
|
self._file_edit.add("Renamed keyword argument from %r to %r" %
|
||||||
(argkey, renamed_keywords[argkey]),
|
(argkey, renamed_keywords[argkey]),
|
||||||
argval.lineno,
|
argval_lineno,
|
||||||
argval.col_offset - len(argkey) - 1,
|
argval_col_offset - len(argkey) - 1,
|
||||||
argkey + "=", renamed_keywords[argkey] + "=")
|
argkey + "=", renamed_keywords[argkey] + "=")
|
||||||
|
continue
|
||||||
|
self._file_edit.add(
|
||||||
|
"Failed to rename keyword argument from %r to %r" %
|
||||||
|
(argkey, renamed_keywords[argkey]),
|
||||||
|
argval.lineno,
|
||||||
|
argval.col_offset - len(argkey) - 1,
|
||||||
|
"", "",
|
||||||
|
error="Failed to find keyword lexographically. Fix manually.")
|
||||||
|
|
||||||
|
ast.NodeVisitor.generic_visit(self, node)
|
||||||
|
|
||||||
def visit_Attribute(self, node): # pylint: disable=invalid-name
|
def visit_Attribute(self, node): # pylint: disable=invalid-name
|
||||||
"""Handle bare Attributes i.e. [tf.foo, tf.bar].
|
"""Handle bare Attributes i.e. [tf.foo, tf.bar].
|
||||||
@@ -387,6 +490,11 @@ class TensorFlowCallVisitor(ast.NodeVisitor):
     full_name = self._get_attribute_full_path(node)
     if full_name and full_name.startswith("tf."):
       self._rename_functions(node, full_name)
+    if full_name in self._api_change_spec.change_to_function:
+      if not hasattr(node, "is_function_for_call"):
+        new_text = full_name + "()"
+        self._file_edit.add("Changed %r to %r"%(full_name, new_text),
+                            node.lineno, node.col_offset, full_name, new_text)
+
     ast.NodeVisitor.generic_visit(self, node)

@@ -59,12 +59,45 @@ class TestUpgrade(test_util.TensorFlowTestCase):
     _, unused_report, unused_errors, new_text = self._upgrade(text)
     self.assertEqual(new_text, "tf.multiply(a, tf.subtract(b, c))\n")

+  def testRenamePack(self):
+    text = "tf.pack(a)\n"
+    _, unused_report, unused_errors, new_text = self._upgrade(text)
+    self.assertEqual(new_text, "tf.stack(a)\n")
+    text = "tf.unpack(a)\n"
+    _, unused_report, unused_errors, new_text = self._upgrade(text)
+    self.assertEqual(new_text, "tf.unstack(a)\n")
+
   def testReorder(self):
     text = "tf.concat(a, b)\ntf.split(a, b, c)\n"
     _, unused_report, unused_errors, new_text = self._upgrade(text)
-    self.assertEqual(new_text, "tf.concat(concat_dim=a, values=b)\n"
+    self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n"
                      "tf.split(axis=a, num_or_size_splits=b, value=c)\n")

+  def testConcatReorderWithKeywordArgs(self):
+    text = "tf.concat(concat_dim=a, values=b)\n"
+    _, unused_report, unused_errors, new_text = self._upgrade(text)
+    self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n")
+    text = "tf.concat(values=b, concat_dim=a)\n"
+    _, unused_report, unused_errors, new_text = self._upgrade(text)
+    self.assertEqual(new_text, "tf.concat(values=b, axis=a)\n")
+    text = "tf.concat(a, values=b)\n"
+    _, unused_report, unused_errors, new_text = self._upgrade(text)
+    self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n")
+
+  def testConcatReorderNested(self):
+    text = "tf.concat(a, tf.concat(c, d))\n"
+    _, unused_report, unused_errors, new_text = self._upgrade(text)
+    self.assertEqual(
+        new_text, "tf.concat(axis=a, values=tf.concat(axis=c, values=d))\n")
+
+  def testInitializers(self):
+    text = ("tf.zeros_initializer;tf.zeros_initializer ()\n"
+            "tf.ones_initializer;tf.ones_initializer ()\n")
+    _, unused_report, unused_errors, new_text = self._upgrade(text)
+    self.assertEqual(
+        new_text, "tf.zeros_initializer();tf.zeros_initializer ()\n"
+        "tf.ones_initializer();tf.ones_initializer ()\n")
+
   def testKeyword(self):
     text = "tf.reduce_any(a, reduction_indices=[1, 2])\n"
     _, unused_report, unused_errors, new_text = self._upgrade(text)
@@ -80,6 +113,19 @@ class TestUpgrade(test_util.TensorFlowTestCase):
     self.assertEqual(new_text, new_text)
     self.assertEqual(errors, ["test.py:1: tf.reverse requires manual check."])

+  def testListComprehension(self):
+    def _test(input, output):
+      _, unused_report, errors, new_text = self._upgrade(input)
+      self.assertEqual(new_text, output)
+    _test("tf.concat(0, \t[x for x in y])\n",
+          "tf.concat(axis=0, \tvalues=[x for x in y])\n")
+    _test("tf.concat(0,[x for x in y])\n",
+          "tf.concat(axis=0,values=[x for x in y])\n")
+    _test("tf.concat(0,[\nx for x in y])\n",
+          "tf.concat(axis=0,values=[\nx for x in y])\n")
+    _test("tf.concat(0,[\n \tx for x in y])\n",
+          "tf.concat(axis=0,values=[\n \tx for x in y])\n")
+
 # TODO(aselle): Explicitly not testing command line interface and process_tree
 # for now, since this is a one off utility.

@@ -30,6 +30,7 @@ RUN pip --no-cache-dir install \
         numpy \
         scipy \
         sklearn \
+        pandas \
         Pillow \
         && \
     python -m ipykernel.kernelspec
@@ -32,6 +32,7 @@ RUN pip --no-cache-dir install \
         numpy \
         scipy \
         sklearn \
+        pandas \
         && \
     python -m ipykernel.kernelspec
@@ -82,7 +83,7 @@ RUN mkdir /bazel && \

 RUN git clone https://github.com/tensorflow/tensorflow.git && \
     cd tensorflow && \
-    git checkout r0.12
+    git checkout r1.0
 WORKDIR /tensorflow

 # TODO(craigcitro): Don't install the pip package, since it makes it
@@ -32,6 +32,7 @@ RUN pip --no-cache-dir install \
         numpy \
         scipy \
         sklearn \
+        pandas \
         && \
     python -m ipykernel.kernelspec
@@ -82,7 +83,7 @@ RUN mkdir /bazel && \

 RUN git clone https://github.com/tensorflow/tensorflow.git && \
     cd tensorflow && \
-    git checkout r0.12
+    git checkout r1.0
 WORKDIR /tensorflow

 # Configure the build for our CUDA configuration.
@@ -30,6 +30,7 @@ RUN pip --no-cache-dir install \
         numpy \
         scipy \
         sklearn \
+        pandas \
         Pillow \
         && \
     python -m ipykernel.kernelspec
@@ -58,6 +59,9 @@ COPY notebooks /notebooks
 # We just add a little wrapper script.
 COPY run_jupyter.sh /

+# For CUDA profiling, TensorFlow requires CUPTI.
+ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
+
 # TensorBoard
 EXPOSE 6006
 # IPython
@@ -328,7 +328,9 @@ def _generate_signature(func, reverse_index):
                           len(argspec.args or []) - len(argspec.defaults or []))

   # Python documentation skips `self` when printing method signatures.
-  first_arg = 1 if inspect.ismethod(func) and 'self' in argspec.args[:1] else 0
+  # Note we cannot test for ismethod here since unbound methods do not register
+  # as methods (in Python 3).
+  first_arg = 1 if 'self' in argspec.args[:1] else 0

   # Add all args without defaults.
   for arg in argspec.args[first_arg:first_arg_with_default]:
@@ -679,6 +681,15 @@ def generate_global_index(library_name, index, duplicate_of):
   for full_name, py_object in six.iteritems(index):
     if (inspect.ismodule(py_object) or inspect.isfunction(py_object) or
         inspect.isclass(py_object)):
+      # In Python 3, unbound methods are functions, so eliminate those.
+      if inspect.isfunction(py_object):
+        if full_name.count('.') == 0:
+          parent_name = ''
+        else:
+          parent_name = full_name[:full_name.rfind('.')]
+        if parent_name in index and inspect.isclass(index[parent_name]):
+          # Skip methods (=functions with class parents).
+          continue
       symbol_links.append((full_name,
                            _markdown_link(full_name, full_name,
                                           '.', duplicate_of)))
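The Python 3 behavior both docs hunks account for can be checked directly: an unbound method is a plain function there, so `inspect.ismethod` no longer identifies it:

```python
import inspect

class C(object):
  def m(self):
    pass

# In Python 2, C.m is an "unbound method" and inspect.ismethod(C.m) is True.
# In Python 3 it is a plain function, which is why the signature generator
# now tests for a leading 'self' argument instead.
print(inspect.ismethod(C.m))    # False on Python 3
print(inspect.isfunction(C.m))  # True on Python 3
```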
@@ -13,7 +13,7 @@ and [Rust](https://github.com/tensorflow/rust).
 The command:

 ```sh
-bazel build -c opt //tensorflow/tools/lib_package:libtensorflow
+bazel build --config opt //tensorflow/tools/lib_package:libtensorflow
 ```

 produces `bazel-bin/tensorflow/tools/lib_package/libtensorflow.tar.gz`, which
@@ -4,6 +4,7 @@
 package(default_visibility = ["//visibility:private"])

 load("//tensorflow:tensorflow.bzl", "transitive_hdrs")
+load("//third_party/mkl:build_defs.bzl", "if_mkl")
 load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps")

 # This returns a list of headers of all public header libraries (e.g.,
@@ -131,5 +132,5 @@ sh_binary(
         "//tensorflow/python/tools:all_files",
         "//tensorflow/tensorboard",
     ],
-    }),
+    }) + if_mkl(["//third_party/mkl:intel_binary_blob"]),
 )
@@ -89,6 +89,13 @@ function main() {
       bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/external \
       "${TMPDIR}/external"
     RUNFILES=bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles
+    # Copy MKL libs over so they can be loaded at runtime
+    if [ -d bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/_solib_k8/_U_S_Sthird_Uparty_Smkl_Cintel_Ubinary_Ublob___Uthird_Uparty_Smkl ]; then
+      mkdir "${TMPDIR}/_solib_k8"
+      cp -R \
+        bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/_solib_k8/_U_S_Sthird_Uparty_Smkl_Cintel_Ubinary_Ublob___Uthird_Uparty_Smkl \
+        "${TMPDIR}/_solib_k8"
+    fi
   else
     if [ -d bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/external ]; then
       # Old-style runfiles structure (--legacy_external_runfiles).
@@ -99,6 +106,13 @@ function main() {
       cp_external \
         bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/external \
         "${TMPDIR}/external"
+      # Copy MKL libs over so they can be loaded at runtime
+      if [ -d bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/_solib_k8/_U_S_Sthird_Uparty_Smkl_Cintel_Ubinary_Ublob___Uthird_Uparty_Smkl ]; then
+        mkdir "${TMPDIR}/_solib_k8"
+        cp -R \
+          bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/_solib_k8/_U_S_Sthird_Uparty_Smkl_Cintel_Ubinary_Ublob___Uthird_Uparty_Smkl \
+          "${TMPDIR}/_solib_k8"
+      fi
     else
       # New-style runfiles structure (--nolegacy_external_runfiles).
       cp -R \
@@ -109,6 +123,13 @@ function main() {
       cp_external \
         bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles \
         "${TMPDIR}/external"
+      # Copy MKL libs over so they can be loaded at runtime
+      if [ -d bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/_solib_k8/_U_S_Sthird_Uparty_Smkl_Cintel_Ubinary_Ublob___Uthird_Uparty_Smkl ]; then
+        mkdir "${TMPDIR}/_solib_k8"
+        cp -R \
+          bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/_solib_k8/_U_S_Sthird_Uparty_Smkl_Cintel_Ubinary_Ublob___Uthird_Uparty_Smkl \
+          "${TMPDIR}/_solib_k8"
+      fi
     fi
     RUNFILES=bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow
   fi
@@ -29,7 +29,7 @@ from setuptools.dist import Distribution
 # This version string is semver compatible, but incompatible with pip.
 # For pip, we will remove all '-' characters from this string, and use the
 # result for pip.
-_VERSION = '0.12.1'
+_VERSION = '1.0.0-rc1'

 REQUIRED_PACKAGES = [
     'numpy >= 1.11.0',
@@ -151,6 +151,7 @@ def find_files(pattern, root):


 matches = ['../' + x for x in find_files('*', 'external') if '.py' not in x]
+matches += ['../' + x for x in find_files('*', '_solib_k8') if '.py' not in x]

 if os.name == 'nt':
   EXTENSION_NAME = 'python/_pywrap_tensorflow.pyd'
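For reference, `find_files` is defined earlier in setup.py and not shown in this hunk; it behaves essentially like this recursive-glob sketch:

```python
import fnmatch
import os


def find_files(pattern, root):
    # Yield paths under `root` whose basename matches `pattern` (a sketch of
    # the helper setup.py uses; the real definition is outside this hunk).
    for dirpath, _, filenames in os.walk(root):
        for filename in fnmatch.filter(filenames, pattern):
            yield os.path.join(dirpath, filename)
```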
@@ -98,7 +98,7 @@ TODO(xpan): Provide graph.pbtxt, model.ckpt, tfprof_log and run_meta download.

 ```shell
 # Build the tool.
-bazel build -c opt tensorflow/tools/tfprof/...
+bazel build --config opt tensorflow/tools/tfprof/...

 # Help information, including detail 'option' instructions.
 bazel-bin/tensorflow/tools/tfprof/tfprof help
@@ -78,11 +78,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
   native.new_http_archive(
       name = "libxsmm_archive",
       urls = [
-          "http://bazel-mirror.storage.googleapis.com/github.com/hfp/libxsmm/archive/1.6.6.tar.gz",
-          "https://github.com/hfp/libxsmm/archive/1.6.6.tar.gz",
+          "http://bazel-mirror.storage.googleapis.com/github.com/hfp/libxsmm/archive/1.7.tar.gz",
+          "https://github.com/hfp/libxsmm/archive/1.7.tar.gz",
       ],
-      sha256 = "7c048a48e17f7f14a475be7b83e6e941289e03debb42ce9e02a06353412f9f2a",
-      strip_prefix = "libxsmm-1.6.6",
+      sha256 = "2eea65624a697e74b939511cd2a686b4c957e90c99be168fe134d96771e811ad",
+      strip_prefix = "libxsmm-1.7",
       build_file = str(Label("//third_party:libxsmm.BUILD")),
   )

Some files were not shown because too many files have changed in this diff.