Merge changes from GitHub.
Change: 146918929
parent 15ff7b7027
commit 639b4e71f5
@@ -33,10 +33,11 @@ and discussion.**

People who are a little more adventurous can also try our nightly binaries:

-* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.12.1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.12.1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.12.1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
-* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-0.12.1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-0.12.1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-0.12.1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
-* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.12.1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.12.1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
-* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-0.12.1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-0.12.1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/))
+* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.0.0rc1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.0.0rc1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.0.0rc1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
+* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.0.0rc1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.0.0rc1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.0.0rc1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
+* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.0.0rc1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.0.0rc1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
+* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.0.0rc1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.0.0rc1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/))
-* [Android](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-android/TF_BUILD_CONTAINER_TYPE=ANDROID,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=NO_PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=android-slave/lastSuccessfulBuild/artifact/bazel-out/local_linux/bin/tensorflow/examples/android/tensorflow_demo.apk) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-android/TF_BUILD_CONTAINER_TYPE=ANDROID,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=NO_PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=android-slave/))
+* Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/)
+([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/))
RELEASE.md (112 lines changed)
@@ -1,7 +1,20 @@
-# Changes since the last release
+# Release 1.0.0

## Major Features and Improvements
* XLA (experimental): initial release of [XLA](https://www.tensorflow.org/versions/master/experimental/xla/), a domain-specific compiler for TensorFlow graphs that targets CPUs and GPUs.
* TensorFlow Debugger (tfdbg): command-line interface and API.
* New Python 3 Docker images added.
* Made pip packages PyPI compliant. TensorFlow can now be installed with the
  `pip install tensorflow` command.
* Several Python API calls have been changed to resemble NumPy more closely.
* Android: person detection + tracking demo implementing Scalable Object
  Detection using Deep Neural Networks.
* New (experimental) [Java API](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/java).
* Added a new Android image stylization demo based on "A Learned Representation For Artistic Style", and added YOLO object detector support.

## Breaking Changes to the API

To help you upgrade your existing TensorFlow Python code to match the API changes below, we have prepared a [conversion script](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/compatibility).
* TensorFlow/models have been moved to a separate GitHub repository.
* Division and modulus operators (/, //, %) now match Python (flooring)
  semantics. This applies to `tf.div` and `tf.mod` as well. To obtain forced
  integer-truncation behavior you can use `tf.truncatediv`
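The flooring semantics mostly matter for negative operands. Here is a minimal sketch, assuming a TensorFlow 1.0 environment, contrasting the flooring operators with the truncating op that preserves the old C-style behavior:

```python
# Hedged sketch: flooring vs. truncating integer division in TF 1.0.
import tensorflow as tf

a = tf.constant(-7)
b = tf.constant(2)

floor_div = a // b                # -4: floors toward negative infinity, like Python
floor_mod = a % b                 #  1: remainder takes the sign of the divisor
trunc_div = tf.truncatediv(a, b)  # -3: truncates toward zero (old C-style behavior)

with tf.Session() as sess:
    print(sess.run([floor_div, floor_mod, trunc_div]))  # [-4, 1, -3]
```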
@@ -51,16 +64,93 @@
  keywords. In particular we now match NumPy order as
  `tf.sparse_split(sp_input, num_split, axis)`. NOTE: we have temporarily
  made `tf.sparse_split` require keyword arguments.
-* Deprecated `tf.concat` operator. Please switch to use `tf.concat_v2` for now.
-  In the Beta release, we will update `tf.concat` to match argument order of
-  `tf.concat_v2.
-* tf.image.decode_jpeg by default uses the faster DCT method, sacrificing
+* `tf.concat` now takes arguments in reversed order and with different keywords. In particular, we now match NumPy order as `tf.concat(values, axis, name)` (see the migration sketch after this list).
+* `tf.image.decode_jpeg` by default uses the faster DCT method, sacrificing
   a little fidelity for improved speed. One can revert to the old
-  behavior by specifying the attribute dct_method='INTEGER_ACCURATE'.
+  behavior by specifying the attribute `dct_method='INTEGER_ACCURATE'`.
* `tf.complex_abs` has been removed from the Python interface. `tf.abs`
  supports complex tensors and should be used instead.
-* In the C++ API (in tensorflow/cc), Input, Output, etc. have moved
-  from the tensorflow::ops namespace to tensorflow.
* `Template.var_scope` property renamed to `Template.variable_scope`.
* `SyncReplicasOptimizer` is removed and `SyncReplicasOptimizerV2` renamed to `SyncReplicasOptimizer`.
* `tf.zeros_initializer()` and `tf.ones_initializer()` now return a callable
  that must be called with initializer arguments; in your code, replace
  `tf.zeros_initializer` with `tf.zeros_initializer()`.
* `SparseTensor.shape` has been renamed to `SparseTensor.dense_shape`. Same for
  `SparseTensorValue.shape`.
* Replace `tf.scalar_summary`, `tf.histogram_summary`, `tf.audio_summary`, and `tf.image_summary` with `tf.summary.scalar`, `tf.summary.histogram`, `tf.summary.audio`, and `tf.summary.image`, respectively. The new summary ops take `name` rather than `tag` as their first argument, meaning summary ops now respect TensorFlow name scopes.
* Replace `tf.train.SummaryWriter` and `tf.train.SummaryWriterCache` with `tf.summary.FileWriter` and `tf.summary.FileWriterCache`.
* Removes `RegisterShape` from the public API. Use C++ shape function registration
  instead.
* Deprecated `_ref` dtypes from the Python API.
+* In the C++ API (in tensorflow/cc), Input, Output, etc. have moved
+  from the tensorflow::ops namespace to tensorflow.
* Change arg order for `{softmax,sparse_softmax,sigmoid}_cross_entropy_with_logits` to be (labels, predictions), and force use of named args.
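A minimal before/after migration sketch for a few of the renames above, assuming a TensorFlow 1.0 environment; the "old" spellings in the comments are the pre-1.0 APIs being replaced:

```python
import tensorflow as tf

x = tf.ones([2, 3])
y = tf.zeros([2, 3])

# Old: tf.concat_v2([x, y], 1); NumPy-style order is now values first, then axis.
z = tf.concat([x, y], 1)

# Old: tf.zeros_initializer (used directly); it is now a callable factory.
v = tf.get_variable("v", shape=[3], initializer=tf.zeros_initializer())

# Old: tf.scalar_summary("loss", loss) took a tag and ignored name scopes;
# tf.summary.scalar takes a name and respects them.
loss = tf.reduce_sum(z)
tf.summary.scalar("loss", loss)

# Old: positional (logits, labels); named labels=/logits= arguments are now required.
xent = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=x)
```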

## Bug Fixes and Other Changes
* New op: `parallel_stack`.
* Introduced common TF I/O compression-option constants for
  RecordReader/RecordWriter.
* Added `sparse_column_with_vocabulary_file`, to specify a feature column that
  transforms string features to IDs, where the mapping is defined by a vocabulary
  file.
* Added `index_to_string_table`, which returns a lookup table that maps indices to
  strings.
* Added `string_to_index_table`, which returns a lookup table that matches strings
  to indices.
* Added a `ParallelForWithWorkerId` function.
* Support restoring sessions from v2 checkpoint files in `contrib/session_bundle`.
* Added a `tf.contrib.image.rotate` function for arbitrary angles.
* Added `tf.contrib.framework.filter_variables` as a convenience function to
  filter lists of variables based on regular expressions (see the sketch after this list).
* `make_template()` takes an optional `custom_getter_` param.
* Added a comment about how existing directories are handled by
  `recursive_create_dir`.
* Added an op for QR factorizations.
* Divides and mods in the Python API now use flooring (Python) semantics.
* Android: pre-built libs are now built nightly.
* Android: cmake/gradle build for the TensorFlow Inference library under
  `contrib/android/cmake`.
* Android: much more robust Session initialization code.
* Android: TF stats now exposed directly in the demo and logged when debug mode is
  active.
* Android: new/better README.md documentation.
* saved_model is available as `tf.saved_model`.
* `Empty` op is now stateful.
* Improved speed of `scatter_update` on the CPU for ASSIGN operations.
* Change `reduce_join` to treat `reduction_indices` in the same way as other `reduce_` ops.
* Move `TensorForestEstimator` to `contrib/tensor_forest`.
* Enable compiler optimizations by default and allow configuration in configure.
* `tf.divide` now honors the name field.
* Make metrics weight broadcasting more strict.
* Add new queue-like `StagingArea` and new ops: `stage` and `unstage`.
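A hedged sketch of the variable-filtering convenience mentioned above; the `include_patterns`/`exclude_patterns` keyword names are assumptions about the contrib API rather than confirmed signatures:

```python
import tensorflow as tf

with tf.variable_scope("conv1"):
    weights = tf.get_variable("weights", shape=[3, 3, 1, 8])
    biases = tf.get_variable("biases", shape=[8])

# Keep conv1 variables but drop the bias terms, using regex matching.
filtered = tf.contrib.framework.filter_variables(
    tf.trainable_variables(),
    include_patterns=["conv1"],
    exclude_patterns=["biases"])
print([v.name for v in filtered])  # expected: ['conv1/weights:0']
```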

## Thanks to our Contributors

This release contains contributions from many people at Google, as well as:

Aaron Hu, Abhishek Aggarwal, Adam Michael, Adriano Carmezim, @AfirSraftGarrier,
Alexander Novikov, Alexander Rosenberg Johansen, Andrew Gibiansky, Andrew Hundt,
Anish Shah, Anton Loss, @b0noI, @BoyuanJiang, Carl Thomé, Chad Kennedy, Comic
Chang, Connor Braa, Daniel N. Lang, Daniel Trebbien,
@danielgordon10, Darcy Liu, Darren Garvey, Dmitri Lapin, Eron Wright, Evan
Cofer, Fabrizio Milo, Finbarr Timbers, Franck Dernoncourt, Garrett Smith,
@guschmue, Hao Wei, Henrik Holst, Huazuo Gao, @Ian, @Issac, Jacob Israel,
Jangsoo Park, Jin Kim, Jingtian Peng, John Pope, Kye Bostelmann, Liangliang He,
Ling Zhang, Luheng He, Luke Iwanski, @lvli, Michael Basilyan, Mihir Patel,
Mikalai Drabovich, Morten Just, @newge, Nick Butlin, Nishant Shukla,
Pengfei Ni, Przemyslaw Tredak, @rasbt, @Ronny, Rudolf Rosa, @RustingSword,
Sam Abrahams, Sam Putnam, @SeongAhJo, Shi Jiaxin, @skavulya, Steffen MüLler,
@TheUSER123, @tiriplicamihai, @vhasanov, Victor Costan, Vit Stepanovs,
Wangda Tan, Wenjian Huang, Xingdong Zuo, Yaroslav Bulatov, Yota Toyama,
Yuan (Terry) Tang, Yuxin Wu

We are also grateful to all who filed issues or helped resolve them, asked and
answered questions, and were part of inspiring discussions.


# Release 0.12.0
@@ -100,15 +190,15 @@

## Breaking Changes to the API

* `BusAdjacency` enum replaced with a protocol buffer `DeviceLocality`. PCI bus
-  indexing now starts from 1 instead of 0, and bus_id==0 is used where
-  previously BUS_ANY was used.
+  indexing now starts from 1 instead of 0, and `bus_id==0` is used where
+  previously `BUS_ANY` was used.
* `Env::FileExists` and `FileSystem::FileExists` now return a tensorflow::Status
  instead of a bool. Any caller of this function can be converted to a bool
  by adding .ok() to the call.
* The C API type `TF_SessionWithGraph` has been renamed to `TF_Session`,
  indicating its preferred use in language bindings for TensorFlow.
  What was previously `TF_Session` has been renamed to `TF_DeprecatedSession`.
-* Renamed TF_Port to TF_Output in the C API.
+* Renamed `TF_Port` to `TF_Output` in the C API.
* Removes `RegisterShape` from the public API. Use C++ shape function registration instead.
@@ -143,7 +233,7 @@
  `tf.global_variables_initializer` respectively.
* `tf.zeros_initializer()` and `tf.ones_initializer()` now return a callable
  that must be called with initializer arguments; in your code, replace
-  tf.zeros_initializer with tf.zeros_initializer()
+  `tf.zeros_initializer` with `tf.zeros_initializer()`

## Bug Fixes and Other Changes

configure (116 lines changed, vendored)
@@ -41,7 +41,8 @@ function bazel_clean_and_fetch() {
  if ! is_windows; then
    bazel clean --expunge
  fi
-  bazel fetch "//tensorflow/... -//tensorflow/examples/android/..."
+  bazel fetch "//tensorflow/... -//tensorflow/contrib/nccl/... \
+      -//tensorflow/examples/android/..."
}

# Delete any leftover BUILD files from the Makefile build, which would interfere
@@ -73,10 +74,77 @@ while true; do
  # Retry
done

## Set up MKL related environment settings
if false; then # Disable building with MKL for now
  while [ "$TF_NEED_MKL" == "" ]; do
    fromuser=""
    read -p "Do you wish to build TensorFlow with MKL support? [y/N] " INPUT
    fromuser="1"
    case $INPUT in
      [Yy]* ) echo "MKL support will be enabled for TensorFlow"; TF_NEED_MKL=1;;
      [Nn]* ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
      "" ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
      * ) echo "Invalid selection: " $INPUT;;
    esac
  done

  OSNAME=`uname -s`

  if [ "$TF_NEED_MKL" == "1" ]; then # TF_NEED_MKL
    DST=`dirname $0`
    ARCHIVE_BASENAME=mklml_lnx_2017.0.2.20170110.tgz
    GITHUB_RELEASE_TAG=v0.3
    MKLURL="https://github.com/01org/mkl-dnn/releases/download/$GITHUB_RELEASE_TAG/$ARCHIVE_BASENAME"
    if ! [ -e "$DST/third_party/mkl/$ARCHIVE_BASENAME" ]; then
      wget --no-check-certificate -P $DST/third_party/mkl/ $MKLURL
    fi
    tar -xzf $DST/third_party/mkl/$ARCHIVE_BASENAME -C $DST/third_party/mkl/
    extracted_dir_name="${ARCHIVE_BASENAME%.*}"
    MKL_INSTALL_PATH=$DST/third_party/mkl/$extracted_dir_name
    MKL_INSTALL_PATH=`${PYTHON_BIN_PATH} -c "import os; print(os.path.realpath(os.path.expanduser('${MKL_INSTALL_PATH}')))"`

    if [ "$OSNAME" == "Linux" ]; then
      # Full MKL configuration
      MKL_RT_LIB_PATH="lib/intel64/libmkl_rt.so" #${TF_MKL_EXT}#TODO version?
      MKL_RT_OMP_LIB_PATH="../compiler/lib/intel64/libiomp5.so" #TODO VERSION?

      # MKL-ML configuration
      MKL_ML_LIB_PATH="lib/libmklml_intel.so" #${TF_MKL_EXT}#TODO version?
      MKL_ML_OMP_LIB_PATH="lib/libiomp5.so" #TODO VERSION?
    elif [ "$OSNAME" == "Darwin" ]; then
      echo "Darwin is unsupported yet";
      exit 1
    fi

    if [ -e "$MKL_INSTALL_PATH/${MKL_ML_LIB_PATH}" ]; then
      ln -sf $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} third_party/mkl/
      ln -sf $MKL_INSTALL_PATH/${MKL_ML_OMP_LIB_PATH} third_party/mkl/
      ln -sf $MKL_INSTALL_PATH/include third_party/mkl/
      ln -sf $MKL_INSTALL_PATH/include third_party/eigen3/mkl_include
    else
      echo "ERROR: $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} does not exist";
      exit 1
    fi

    if [ -z "$fromuser" ]; then
      exit 1
    fi

cat > third_party/mkl/mkl.config <<EOF
# MKL_INSTALL_PATH refers to the location of MKL root folder. The MKL header and library
# files can be either in this directory, or under include/ and lib64/
MKL_INSTALL_PATH=$MKL_INSTALL_PATH
EOF

  fi # TF_NEED_MKL
  ################## MKL
fi # Disable building with MKL for now

## Set up architecture-dependent optimization flags.
if [ -z "$CC_OPT_FLAGS" ]; then
  default_cc_opt_flags="-march=native"
-  read -p "Please specify optimization flags to use during compilation [Default is $default_cc_opt_flags]: " CC_OPT_FLAGS
+  read -p "Please specify optimization flags to use during compilation when bazel option "\
+"\"--config=opt\" is specified [Default is $default_cc_opt_flags]: " CC_OPT_FLAGS
  if [ -z "$CC_OPT_FLAGS" ]; then
    CC_OPT_FLAGS=$default_cc_opt_flags
  fi
@@ -328,46 +396,8 @@ while true; do

  if [[ -z "$TF_CUDNN_VERSION" ]]; then
    TF_CUDNN_EXT=""
-    cudnn_lib_path=""
-    cudnn_alt_lib_path=""
-    if is_windows; then
-      cudnn_lib_path="${CUDNN_INSTALL_PATH}/lib/x64/cudnn.lib"
-      cudnn_alt_lib_path="${CUDNN_INSTALL_PATH}/lib/x64/cudnn.lib"
-    elif is_linux; then
-      cudnn_lib_path="${CUDNN_INSTALL_PATH}/lib64/libcudnn.so"
-      cudnn_alt_lib_path="${CUDNN_INSTALL_PATH}/libcudnn.so"
-    elif is_macos; then
-      cudnn_lib_path="${CUDNN_INSTALL_PATH}/lib/libcudnn.dylib"
-      cudnn_alt_lib_path="${CUDNN_INSTALL_PATH}/libcudnn.dylib"
-    fi
-    # Resolve to the SONAME of the symlink. Use readlink without -f
-    # to resolve exactly once to the SONAME. E.g, libcudnn.so ->
-    # libcudnn.so.4.
-    # If the path is not a symlink, readlink will exit with an error code, so
-    # in that case, we return the path itself.
-    if [ -f "$cudnn_lib_path" ]; then
-      REALVAL=`readlink "${cudnn_lib_path}" || echo "${cudnn_lib_path}"`
-    else
-      REALVAL=`readlink "${cudnn_alt_lib_path}" || echo "${cudnn_alt_lib_path}"`
-    fi
-
-    # Extract the version of the SONAME, if it was indeed symlinked to
-    # the SONAME version of the file.
-    if [[ "$REALVAL" =~ .so[.]+([0-9]*) ]]; then
-      TF_CUDNN_EXT="."${BASH_REMATCH[1]}
-      TF_CUDNN_VERSION=${BASH_REMATCH[1]}
-      echo "libcudnn.so resolves to libcudnn${TF_CUDNN_EXT}"
-    elif [[ "$REALVAL" =~ ([0-9]*).dylib ]]; then
-      TF_CUDNN_EXT=${BASH_REMATCH[1]}".dylib"
-      TF_CUDNN_VERSION=${BASH_REMATCH[1]}
-      echo "libcudnn.dylib resolves to libcudnn${TF_CUDNN_EXT}"
-    fi
  else
-    if is_macos; then
-      TF_CUDNN_EXT=".${TF_CUDNN_VERSION}.dylib"
-    else
-      TF_CUDNN_EXT=".$TF_CUDNN_VERSION"
-    fi
+    TF_CUDNN_EXT=".$TF_CUDNN_VERSION"
  fi

  if is_windows; then
@@ -377,8 +407,8 @@ while true; do
    CUDA_DNN_LIB_PATH="lib64/libcudnn.so${TF_CUDNN_EXT}"
    CUDA_DNN_LIB_ALT_PATH="libcudnn.so${TF_CUDNN_EXT}"
  elif is_macos; then
-    CUDA_DNN_LIB_PATH="lib/libcudnn${TF_CUDNN_EXT}"
-    CUDA_DNN_LIB_ALT_PATH="libcudnn${TF_CUDNN_EXT}"
+    CUDA_DNN_LIB_PATH="lib/libcudnn${TF_CUDNN_EXT}.dylib"
+    CUDA_DNN_LIB_ALT_PATH="libcudnn${TF_CUDNN_EXT}.dylib"
  fi

  if [ -e "$CUDNN_INSTALL_PATH/${CUDA_DNN_LIB_ALT_PATH}" -o -e "$CUDNN_INSTALL_PATH/${CUDA_DNN_LIB_PATH}" ]; then
@@ -171,6 +171,7 @@ filegroup(
        "//tensorflow/contrib/slim/python/slim/data:all_files",
        "//tensorflow/contrib/slim/python/slim/nets:all_files",
        "//tensorflow/contrib/solvers:all_files",
+        "//tensorflow/contrib/sparsemax:all_files",
        "//tensorflow/contrib/specs:all_files",
        "//tensorflow/contrib/stat_summarizer:all_files",
        "//tensorflow/contrib/tensor_forest:all_files",

@@ -246,6 +247,20 @@ filegroup(
    visibility = [":__subpackages__"],
)

#load(
#    "//third_party/mkl:build_defs.bzl",
#    "if_mkl",
#)

#filegroup(
#    name = "intel_binary_blob",
#    data = if_mkl(
#        [
#            "//third_party/mkl:intel_binary_blob",
#        ],
#    ),
#)

# -------------------------------------------
# New rules should be added above this target.
# -------------------------------------------
@@ -47,6 +47,7 @@ py_library(
        "//tensorflow/contrib/slim",
        "//tensorflow/contrib/slim:nets",
        "//tensorflow/contrib/solvers:solvers_py",
+        "//tensorflow/contrib/sparsemax:sparsemax_py",
        "//tensorflow/contrib/specs",
        "//tensorflow/contrib/stat_summarizer:stat_summarizer_py",
        "//tensorflow/contrib/tensor_forest:init_py",
@@ -49,6 +49,7 @@ from tensorflow.contrib import rnn
from tensorflow.contrib import seq2seq
from tensorflow.contrib import slim
from tensorflow.contrib import solvers
+from tensorflow.contrib import sparsemax
from tensorflow.contrib import stat_summarizer
from tensorflow.contrib import tensor_forest
from tensorflow.contrib import tensorboard
@@ -170,7 +170,8 @@ if (tensorflow_ENABLE_GPU)

  # add cudnn
  include_directories(${CUDNN_HOME})
-  set(CUDA_LIBRARIES ${CUDA_LIBRARIES} ${CUDNN_HOME}/lib/x64/cudnn.lib)
+  set(CUDA_LIBRARIES ${CUDA_LIBRARIES} ${CUDA_CUDA_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_CUFFT_LIBRARIES}
+      ${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDNN_HOME}/lib/x64/cudnn.lib)

  # create cuda_config.h
  FILE(WRITE ${tensorflow_source_dir}/third_party/gpus/cuda/cuda_config.h

@@ -179,6 +180,7 @@ if (tensorflow_ENABLE_GPU)
    "#define TF_CUDA_CAPABILITIES CudaVersion(\"3.0\"),CudaVersion(\"3.5\"),CudaVersion(\"5.2\")\n"
    "#define TF_CUDA_VERSION \"64_80\"\n"
    "#define TF_CUDNN_VERSION \"64_5\"\n"
+    "#define TF_CUDA_TOOLKIT_PATH \"${CUDA_TOOLKIT_ROOT_DIR}\"\n"
    "#endif  // CUDA_CUDA_CONFIG_H_\n"
  )
@@ -71,7 +71,7 @@ foreach(tf_cc_op_lib_name ${tf_cc_op_lib_names})
    COMMAND ${tf_cc_op_lib_name}_gen_cc ${cc_ops_target_dir}/${tf_cc_op_lib_name}.h ${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc ${tensorflow_source_dir}/tensorflow/cc/ops/op_gen_overrides.pbtxt ${cc_ops_include_internal}
    DEPENDS ${tf_cc_op_lib_name}_gen_cc create_cc_ops_header_dir
  )

  list(APPEND tf_cc_ops_generated_files ${cc_ops_target_dir}/${tf_cc_op_lib_name}.h)
  list(APPEND tf_cc_ops_generated_files ${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc)
  list(APPEND tf_cc_ops_generated_files ${cc_ops_target_dir}/${tf_cc_op_lib_name}_internal.h)

@@ -79,6 +79,7 @@ foreach(tf_cc_op_lib_name ${tf_cc_op_lib_names})
endforeach()

########################################################
# tf_cc_ops library
########################################################
@@ -372,6 +372,9 @@ add_python_module("tensorflow/contrib/slim/python/slim/nets")
add_python_module("tensorflow/contrib/solvers")
add_python_module("tensorflow/contrib/solvers/python")
add_python_module("tensorflow/contrib/solvers/python/ops")
+add_python_module("tensorflow/contrib/sparsemax")
+add_python_module("tensorflow/contrib/sparsemax/python")
+add_python_module("tensorflow/contrib/sparsemax/python/ops")
add_python_module("tensorflow/contrib/specs")
add_python_module("tensorflow/contrib/specs/python")
add_python_module("tensorflow/contrib/stat_summarizer")
@@ -102,7 +102,12 @@ class GMM(estimator.Estimator):
    results = self.evaluate(input_fn=input_fn, batch_size=batch_size,
                            steps=steps)
    return np.sum(results[GMM.SCORES])

+  def weights(self):
+    """Returns the cluster weights."""
+    return checkpoint_utils.load_variable(
+        self.model_dir, gmm_ops.GmmAlgorithm.CLUSTERS_WEIGHT)
+
  def clusters(self):
    """Returns cluster centers."""
    clusters = checkpoint_utils.load_variable(
@@ -92,6 +92,7 @@ def _init_clusters_random(data, num_clusters, random_seed):

class GmmAlgorithm(object):
  """Tensorflow Gaussian mixture model clustering class."""
+  CLUSTERS_WEIGHT = 'alphas'
  CLUSTERS_VARIABLE = 'clusters'
  CLUSTERS_COVS_VARIABLE = 'clusters_covs'
@@ -187,11 +188,13 @@ class GmmAlgorithm(object):
          array_ops.expand_dims(array_ops.diag_part(cov), 0),
          [self._num_classes, 1])
      self._covs = variables.Variable(
-          covs, name='clusters_covs', validate_shape=False)
+          covs, name=self.CLUSTERS_COVS_VARIABLE, validate_shape=False)
    # Mixture weights, representing the probability that a randomly
    # selected unobservable data (in EM terms) was generated by component k.
    self._alpha = variables.Variable(
-        array_ops.tile([1.0 / self._num_classes], [self._num_classes]))
+        array_ops.tile([1.0 / self._num_classes], [self._num_classes]),
+        name=self.CLUSTERS_WEIGHT,
+        validate_shape=False)

  def training_ops(self):
    """Returns the training operation."""
@@ -109,6 +109,16 @@ class GMMTest(test.TestCase):
                  np.linalg.inv(covs[assignments[r]])), points[r, :] -
                  means[assignments[r]])))
    return (points, assignments, scores)

+  def test_weights(self):
+    """Tests the shape of the weights."""
+    gmm = gmm_lib.GMM(self.num_centers,
+                      initial_clusters=self.initial_means,
+                      random_seed=4,
+                      config=run_config.RunConfig(tf_random_seed=2))
+    gmm.fit(input_fn=self.input_fn(), steps=0)
+    weights = gmm.weights()
+    self.assertAllEqual(list(weights.shape), [self.num_centers])
+
  def test_clusters(self):
    """Tests the shape of the clusters."""
@@ -480,6 +480,7 @@ py_test(
    size = "medium",
    srcs = ["python/learn/estimators/estimator_test.py"],
    srcs_version = "PY2AND3",
+    tags = ["manual"],
    deps = [
        ":learn",
        "//tensorflow/contrib/framework:framework_py",
@@ -191,6 +191,9 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None):
  if not dnn_feature_columns:
    dnn_logits = None
  else:
+    if not dnn_hidden_units:
+      raise ValueError(
+          "dnn_hidden_units must be defined when dnn_feature_columns is specified.")
    dnn_partitioner = (
        partitioned_variables.min_max_variable_partitioner(
            max_partitions=num_ps_replicas))
@@ -241,6 +241,26 @@ class DNNLinearCombinedClassifierTest(test.TestCase):
          dnn_feature_columns=None,
          dnn_hidden_units=[3, 3])

+  def testNoDnnHiddenUnits(self):
+    def _input_fn():
+      return {
+          'age':
+              constant_op.constant([1]),
+          'language':
+              sparse_tensor.SparseTensor(
+                  values=['english'], indices=[[0, 0]], dense_shape=[1, 1])
+      }, constant_op.constant([[1]])
+
+    language = feature_column.sparse_column_with_hash_bucket('language', 100)
+    age = feature_column.real_valued_column('age')
+
+    with self.assertRaisesRegexp(
+        ValueError,
+        'dnn_hidden_units must be defined when dnn_feature_columns is specified'):
+      classifier = dnn_linear_combined.DNNLinearCombinedClassifier(
+          dnn_feature_columns=[age, language])
+      classifier.fit(input_fn=_input_fn, steps=2)
+
  def testEmbeddingMultiplier(self):
    embedding_language = feature_column.embedding_column(
        feature_column.sparse_column_with_hash_bucket('language', 10),
@@ -274,10 +274,10 @@ def bidirectional_rnn(cell_fw,
  output_bw = _reverse_seq(tmp, sequence_length)
  # Concat each of the forward/backward outputs
  outputs = [
-      array_ops_.concat_v2([fw, bw], 1) for fw, bw in zip(output_fw, output_bw)
+      array_ops_.concat([fw, bw], 1) for fw, bw in zip(output_fw, output_bw)
  ]

-  return outputs, array_ops_.concat_v2([state_fw, state_bw], 1)
+  return outputs, array_ops_.concat([state_fw, state_bw], 1)


# End of TensorFlow 0.7
@@ -59,7 +59,7 @@ def embedding_lookup(params, ids, name='embedding_lookup'):
    ids_flat = array_ops_.reshape(
        ids, math_ops.reduce_prod(shape, keep_dims=True))
    embeds_flat = nn.embedding_lookup(params, ids_flat, name)
-    embed_shape = array_ops_.concat_v2([shape, [-1]], 0)
+    embed_shape = array_ops_.concat([shape, [-1]], 0)
    embeds = array_ops_.reshape(embeds_flat, embed_shape)
    embeds.set_shape(ids.get_shape().concatenate(params.get_shape()[1:]))
    return embeds
@@ -427,7 +427,6 @@ def sparse_softmax_cross_entropy(logits, labels, weights=1.0, scope=None):
  with ops.name_scope(scope, "sparse_softmax_cross_entropy_loss",
                      [logits, labels, weights]) as scope:
-    labels = array_ops.reshape(labels, shape=[array_ops.shape(labels)[0]])
    weights = array_ops.squeeze(weights)

    losses = nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                         logits=logits,
@@ -243,6 +243,34 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
      expected_value = 400.0 * label_smoothing / 3.0
      self.assertAlmostEqual(loss.eval(), expected_value, 3)

+  def testLossWithDynamicallyShapedWeights1D(self):
+    logits = constant_op.constant([[10.0, 0.0, 0.0],
+                                   [0.0, 10.0, 0.0],
+                                   [0.0, 0.0, 10.0]])
+    labels = constant_op.constant([[0, 0, 1],
+                                   [1, 0, 0],
+                                   [0, 1, 0]])
+    weights = [2.3, 2.4, 2.5]
+    weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None])
+    loss = loss_ops.softmax_cross_entropy(logits, labels, weights_placeholder)
+    with self.test_session() as sess:
+      loss = sess.run(loss, {weights_placeholder: weights})
+      self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
+
+  def testLossWithDynamicallyShapedWeights2D(self):
+    logits = constant_op.constant([[10.0, 0.0, 0.0],
+                                   [0.0, 10.0, 0.0],
+                                   [0.0, 0.0, 10.0]])
+    labels = constant_op.constant([[0, 0, 1],
+                                   [1, 0, 0],
+                                   [0, 1, 0]])
+    weights = [[2.3], [2.4], [2.5]]
+    weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None, None])
+    loss = loss_ops.softmax_cross_entropy(logits, labels, weights_placeholder)
+    with self.test_session() as sess:
+      loss = sess.run(loss, {weights_placeholder: weights})
+      self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
+

class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
@@ -445,6 +473,30 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
        loss_ops.sparse_softmax_cross_entropy(
            logits, labels, weights=weights).eval()

+  def testLossWithDynamicallyShapedWeights1D(self):
+    logits = constant_op.constant([[10.0, 0.0, 0.0],
+                                   [0.0, 10.0, 0.0],
+                                   [0.0, 0.0, 10.0]])
+    labels = constant_op.constant([2, 0, 1])
+    weights = [2.3, 2.4, 2.5]
+    weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None])
+    loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights_placeholder)
+    with self.test_session() as sess:
+      loss = sess.run(loss, {weights_placeholder: weights})
+      self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
+
+  def testLossWithDynamicallyShapedWeights2D(self):
+    logits = constant_op.constant([[10.0, 0.0, 0.0],
+                                   [0.0, 10.0, 0.0],
+                                   [0.0, 0.0, 10.0]])
+    labels = constant_op.constant([2, 0, 1])
+    weights = [[2.3], [2.4], [2.5]]
+    weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None, None])
+    loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights_placeholder)
+    with self.test_session() as sess:
+      loss = sess.run(loss, {weights_placeholder: weights})
+      self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
+

class SigmoidCrossEntropyLossTest(test.TestCase):
@@ -84,7 +84,7 @@ cuda_py_test(

tf_cuda_cc_test(
    name = "nccl_manager_test",
-    size = "small",
+    size = "medium",
    srcs = if_cuda(
        [
            "kernels/nccl_manager.cc",
@@ -95,7 +95,7 @@ class RNNCellTest(test.TestCase):
      input_size = 4
      feature_size = 2
      frequency_skip = 1
-      num_shifts = (input_size - feature_size) / frequency_skip + 1
+      num_shifts = (input_size - feature_size) // frequency_skip + 1
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        x = array_ops.zeros([batch_size, input_size])
@@ -880,7 +880,7 @@ names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({

# Create the summary ops such that they also print out to std output:
summary_ops = []
-for metric_name, metric_value in metrics_to_values.iteritems():
+for metric_name, metric_value in names_to_values.iteritems():
  op = tf.summary.scalar(metric_name, metric_value)
  op = tf.Print(op, [metric_value], metric_name)
  summary_ops.append(op)
tensorflow/contrib/sparsemax/BUILD (new file, 76 lines)
@@ -0,0 +1,76 @@
# Description:
#   Contains ops that implement the sparsemax transformation and its loss.
#   APIs here are meant to evolve over time.

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

package(default_visibility = ["//visibility:public"])

load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
load(
    "//tensorflow:tensorflow.bzl",
    "tf_custom_op_library",
    "tf_py_test",
)
load(
    "//tensorflow/core:platform/default/build_config.bzl",
    "tf_kernel_tests_linkstatic",
)

py_library(
    name = "sparsemax_py",
    srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/contrib/util:util_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:nn",
    ],
)

cuda_py_tests(
    name = "sparsemax_test",
    size = "small",
    srcs = ["python/kernel_tests/sparsemax_test.py"],
    additional_deps = [
        ":sparsemax_py",
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:gradients",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:platform_test",
    ],
)

cuda_py_tests(
    name = "sparsemax_loss_test",
    size = "medium",
    srcs = ["python/kernel_tests/sparsemax_loss_test.py"],
    additional_deps = [
        ":sparsemax_py",
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:gradients",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:platform_test",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
tensorflow/contrib/sparsemax/__init__.py (new file, 30 lines)
@@ -0,0 +1,30 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module that implements sparsemax and sparsemax loss, see [1].

[1] https://arxiv.org/abs/1602.02068

## Sparsemax

@@sparsemax
@@sparsemax_loss
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.sparsemax.python.ops.sparsemax import sparsemax
from tensorflow.contrib.sparsemax.python.ops.sparsemax_loss \
    import sparsemax_loss
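A minimal usage sketch of the two exports above, matching the call pattern used in the kernel tests that follow (`sparsemax_loss` takes the logits, the sparsemax output, and the one-hot labels):

```python
import tensorflow as tf
from tensorflow.contrib.sparsemax import sparsemax, sparsemax_loss

logits = tf.constant([[1.0, 1.5, 0.1]])
labels = tf.constant([[0.0, 1.0, 0.0]])  # one-hot targets

probs = sparsemax(logits)                     # sparse probabilities; some entries exactly 0
loss = sparsemax_loss(logits, probs, labels)  # loss paired with the sparsemax activation

with tf.Session() as sess:
    print(sess.run([probs, loss]))
```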
tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py (new file, 224 lines)
@@ -0,0 +1,224 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for SparsemaxLossOp."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.sparsemax import sparsemax, sparsemax_loss
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test

test_obs = 10


class SparsemaxLossTest(test.TestCase):

  def _np_sparsemax(self, z):
    z = z - np.mean(z, axis=1)[:, np.newaxis]

    # sort z
    z_sorted = np.sort(z, axis=1)[:, ::-1]

    # calculate k(z)
    z_cumsum = np.cumsum(z_sorted, axis=1)
    k = np.arange(1, z.shape[1] + 1)
    z_check = 1 + k * z_sorted > z_cumsum
    # use argmax to get the index by row as .nonzero() doesn't
    # take an axis argument. np.argmax returns the first index, but the last
    # index is required here, use np.flip to get the last index and
    # `z.shape[axis]` to compensate for np.flip afterwards.
    k_z = z.shape[1] - np.argmax(z_check[:, ::-1], axis=1)

    # calculate tau(z)
    tau_sum = z_cumsum[np.arange(0, z.shape[0]), k_z - 1]
    tau_z = ((tau_sum - 1) / k_z).reshape(-1, 1)

    # calculate p
    return np.maximum(0, z - tau_z)

  def _np_sparsemax_loss(self, z, q):
    z = z - np.mean(z, axis=1)[:, np.newaxis]

    # Calculate q^T * z
    z_k = np.sum(q * z, axis=1)

    # calculate sum over S(z)
    p = self._np_sparsemax(z)
    s = p > 0
    # z_i^2 - tau(z)^2 = p_i (2 * z_i - p_i) for i \in S(z)
    S_sum = np.sum(s * p * (2 * z - p), axis=1)

    # because q is binary, sum([q_1^2, q_2^2, ...]) is just sum(q)
    q_norm = np.sum(q, axis=1)

    return -z_k + 0.5 * S_sum + 0.5 * q_norm

  def _np_sparsemax_loss_grad(self, z, q):
    # chain rule
    grad = 1

    return grad * (-q + self._np_sparsemax(z))

  def _tf_sparsemax(self, z, dtype, use_gpu):
    with self.test_session(use_gpu=use_gpu):
      tf_sparsemax_op = sparsemax(z.astype(dtype))
      tf_sparsemax_out = tf_sparsemax_op.eval()

    return tf_sparsemax_op, tf_sparsemax_out

  def _tf_sparsemax_loss(self, z, q, dtype, use_gpu):
    z = z.astype(dtype)
    q = q.astype(dtype)

    with self.test_session(use_gpu=use_gpu):
      tf_sparsemax_op = sparsemax(z)
      tf_loss_op = sparsemax_loss(z, tf_sparsemax_op, q)
      tf_loss_out = tf_loss_op.eval()

    return tf_loss_op, tf_loss_out

  def _test_sparsemax_loss_against_numpy(self, dtype, random, use_gpu):
    """check sparsemax-loss kernel against numpy"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10))
    q = np.zeros((test_obs, 10))
    q[np.arange(0, test_obs), random.randint(0, 10, size=test_obs)] = 1

    tf_loss_op, tf_loss_out = self._tf_sparsemax_loss(z, q, dtype, use_gpu)
    np_loss = self._np_sparsemax_loss(z, q).astype(dtype)

    self.assertAllCloseAccordingToType(np_loss, tf_loss_out,
                                       half_atol=1e-2, half_rtol=5e-3)
    self.assertShapeEqual(np_loss, tf_loss_op)

  def _test_constant_add(self, dtype, random, use_gpu):
    """check sparsemax-loss proposition 3"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10))
    c = random.uniform(low=-3, high=3, size=(test_obs, 1))
    q = np.zeros((test_obs, 10))
    q[np.arange(0, test_obs), np.random.randint(0, 10, size=test_obs)] = 1

    _, tf_loss_zpc = self._tf_sparsemax_loss(
        z + c, q, dtype, use_gpu
    )

    _, tf_loss_z = self._tf_sparsemax_loss(
        z, q, dtype, use_gpu
    )

    self.assertAllCloseAccordingToType(tf_loss_zpc, tf_loss_z,
                                       float_atol=5e-6, float_rtol=5e-6,
                                       half_atol=1e-2, half_rtol=1e-2)

  def _test_sparsemax_loss_positive(self, dtype, random, use_gpu):
    """check sparsemax-loss proposition 4"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10))
    q = np.zeros((test_obs, 10))
    q[np.arange(0, test_obs), random.randint(0, 10, size=test_obs)] = 1

    tf_loss_op, tf_loss_out = self._tf_sparsemax_loss(z, q, dtype, use_gpu)

    self.assertAllCloseAccordingToType(np.abs(tf_loss_out), tf_loss_out)
    self.assertShapeEqual(np.zeros(test_obs), tf_loss_op)

  def _test_sparsemax_loss_zero(self, dtype, random, use_gpu):
    """check sparsemax-loss proposition 5"""
    # construct z and q, such that z_k >= 1 + max_{j!=k} z_k holds for
    # delta_0 = 1.
    z = random.uniform(low=-3, high=3, size=(test_obs, 10))
    z[:, 0] = np.max(z, axis=1) + 1.05

    q = np.zeros((test_obs, 10))
    q[:, 0] = 1

    tf_loss_op, tf_loss_out = self._tf_sparsemax_loss(z, q, dtype, use_gpu)
    tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu)

    self.assertAllCloseAccordingToType(np.zeros(test_obs), tf_loss_out)
    self.assertShapeEqual(np.zeros(test_obs), tf_loss_op)

    self.assertAllCloseAccordingToType(q, tf_sparsemax_out)
    self.assertShapeEqual(q, tf_sparsemax_op)

  def _test_gradient_against_estimate(self, dtype, random, use_gpu):
    """check sparsemax-loss Rop, against estimated-loss Rop"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype)
    q = np.zeros((test_obs, 10)).astype(dtype)
    q[np.arange(0, test_obs), np.random.randint(0, 10, size=test_obs)] = 1

    logits = array_ops.placeholder(dtype, name='z')
    sparsemax_op = sparsemax(logits)
    loss_op = sparsemax_loss(logits, sparsemax_op, q)

    with self.test_session(use_gpu=use_gpu):
      err = gradient_checker.compute_gradient_error(
          logits, z.shape,
          loss_op, (test_obs, ),
          x_init_value=z, delta=1e-9
      )

    self.assertLess(err, 1e-4)

  def _test_gradient_against_numpy(self, dtype, random, use_gpu):
    """check sparsemax-loss Rop, against numpy Rop"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10))
    q = np.zeros((test_obs, 10))
    q[np.arange(0, test_obs), np.random.randint(0, 10, size=test_obs)] = 1

    logits = constant_op.constant(z.astype(dtype), name='z')
    sparsemax_op = sparsemax(logits)
    loss_op = sparsemax_loss(logits, sparsemax_op, q.astype(dtype))
    loss_grad_op = gradients_impl.gradients(loss_op, [logits])[0]

    with self.test_session(use_gpu=use_gpu):
      tf_grad = loss_grad_op.eval()
      np_grad = self._np_sparsemax_loss_grad(z, q).astype(dtype)

      self.assertAllCloseAccordingToType(np_grad, tf_grad,
                                         half_atol=1e-2, half_rtol=5e-3)
      self.assertShapeEqual(np_grad, loss_grad_op)

  def _test_dtype(self, dtype):
    random = np.random.RandomState(1)

    self._test_sparsemax_loss_against_numpy(dtype, random, use_gpu=False)

    self._test_constant_add(dtype, random, use_gpu=False)

    self._test_sparsemax_loss_positive(dtype, random, use_gpu=False)

    self._test_sparsemax_loss_zero(dtype, random, use_gpu=False)

    # sparsemax is not a smooth function so gradient estimation is only
    # possible for float64.
    if dtype == 'float64':
      self._test_gradient_against_estimate(dtype, random, use_gpu=False)

      self._test_gradient_against_numpy(dtype, random, use_gpu=False)

  def testFloat(self):
    self._test_dtype('float32')

  def testDouble(self):
    self._test_dtype('float64')

if __name__ == "__main__":
  test.main()
@ -0,0 +1,252 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for SparsemaxOp."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.sparsemax import sparsemax
|
||||
from tensorflow.python.ops import gradient_checker
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
test_obs = 10
|
||||
|
||||
|
||||
class SparsemaxTest(test.TestCase):
|
||||
|
||||
def _np_sparsemax(self, z):
|
||||
z = z - np.mean(z, axis=1)[:, np.newaxis]
|
||||
|
||||
# sort z
|
||||
z_sorted = np.sort(z, axis=1)[:, ::-1]
|
||||
|
||||
# calculate k(z)
|
||||
z_cumsum = np.cumsum(z_sorted, axis=1)
|
||||
k = np.arange(1, z.shape[1] + 1)
|
||||
z_check = 1 + k * z_sorted > z_cumsum
|
||||
# use argmax to get the index by row as .nonzero() doesn't
|
||||
# take an axis argument. np.argmax return the first index, but the last
|
||||
# index is required here, use np.flip to get the last index and
|
||||
# `z.shape[axis]` to compensate for np.flip afterwards.
|
||||
k_z = z.shape[1] - np.argmax(z_check[:, ::-1], axis=1)
|
||||
|
||||
# calculate tau(z)
|
||||
tau_sum = z_cumsum[np.arange(0, z.shape[0]), k_z - 1]
|
||||
tau_z = ((tau_sum - 1) / k_z).reshape(-1, 1)
|
||||
|
||||
# calculate p
|
||||
return np.maximum(0, z - tau_z)
|
||||
|
||||
def _np_sparsemax_grad(self, z):
|
||||
# chain rule
|
||||
grad = np.ones_like(z)
|
||||
|
||||
# Construct S(z)
|
||||
probability = self._np_sparsemax(z)
|
||||
support = probability > 0
|
||||
|
||||
# Calculate \hat{v}, which will be a vector (scalar for each z)
|
||||
v_hat = np.sum(grad * support, axis=1) / np.sum(support, axis=1)
|
||||
|
||||
# Calculates J(z) * v
|
||||
return support * (grad - v_hat[:, np.newaxis])
|
||||
|
||||
  def _tf_sparsemax(self, z, dtype, use_gpu):
    with self.test_session(use_gpu=use_gpu):
      tf_sparsemax_op = sparsemax(z.astype(dtype))
      tf_sparsemax_out = tf_sparsemax_op.eval()

    return tf_sparsemax_op, tf_sparsemax_out

  def _test_sparsemax_against_numpy(self, dtype, random, use_gpu):
    """check sparsemax kernel against numpy"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10))

    tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu)
    p_sparsemax = self._np_sparsemax(z).astype(dtype)

    self.assertAllCloseAccordingToType(p_sparsemax, tf_sparsemax_out,
                                       half_atol=5e-3)
    self.assertShapeEqual(p_sparsemax, tf_sparsemax_op)

  def _test_sparsemax_of_zero(self, dtype, random, use_gpu):
    """check sparsemax proposition 1, part 1"""
    z = np.zeros((1, 10))

    tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu)
    p_sparsemax = np.ones_like(z, dtype=dtype) / z.size

    self.assertAllCloseAccordingToType(p_sparsemax, tf_sparsemax_out)
    self.assertShapeEqual(p_sparsemax, tf_sparsemax_op)

  def _test_sparsemax_of_inf(self, dtype, random, use_gpu):
    """check sparsemax proposition 1, part 2"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10))

    # assume |A(z)| = 1, as z is continuous random
    z_sort_arg = np.argsort(z, axis=1)[:, ::-1]
    z_sort = np.sort(z, axis=-1)[:, ::-1]
    gamma_z = z_sort[:, 0] - z_sort[:, 1]
    epsilon = (0.99 * gamma_z * 1).reshape(-1, 1)

    # construct the expected 1_A(z) array
    p_expected = np.zeros((test_obs, 10), dtype=dtype)
    p_expected[np.arange(0, test_obs), z_sort_arg[:, 0]] = 1

    tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(
        (1 / epsilon) * z, dtype, use_gpu
    )

    self.assertAllCloseAccordingToType(p_expected, tf_sparsemax_out)
    self.assertShapeEqual(p_expected, tf_sparsemax_op)

  def _test_constant_add(self, dtype, random, use_gpu):
    """check sparsemax proposition 2"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype)
    c = random.uniform(low=-3, high=3, size=(test_obs, 1)).astype(dtype)

    _, tf_sparsemax_zpc = self._tf_sparsemax(
        z + c, dtype, use_gpu
    )

    _, tf_sparsemax_z = self._tf_sparsemax(
        z, dtype, use_gpu
    )

    self.assertAllCloseAccordingToType(tf_sparsemax_zpc, tf_sparsemax_z,
                                       half_atol=5e-3)

  def _test_permutation(self, dtype, random, use_gpu):
    """check sparsemax proposition 3"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10))
    _, p = self._tf_sparsemax(z, dtype, use_gpu)

    for i in range(test_obs):
      per = random.permutation(10)

      tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(
          z[i, per].reshape(1, -1), dtype, use_gpu
      )
      p_expected = p[i, per].reshape(1, -1)

      self.assertAllCloseAccordingToType(p_expected, tf_sparsemax_out,
                                         half_atol=5e-3)
      self.assertShapeEqual(p_expected, tf_sparsemax_op)

  def _test_difference(self, dtype, random, use_gpu):
    """check sparsemax proposition 4"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10))
    _, p = self._tf_sparsemax(z, dtype, use_gpu)

    etol = {'float16': 1e-2, 'float32': 1e-6, 'float64': 1e-9}[dtype]

    for val in range(0, test_obs):
      for i in range(0, 10):
        for j in range(0, 10):
          # check condition, the opposite pair will be checked anyway
          if z[val, i] > z[val, j]:
            continue

          self.assertTrue(
              0 <= p[val, j] - p[val, i] <= z[val, j] - z[val, i] + etol,
              "0 <= %.10f <= %.10f" % (
                  p[val, j] - p[val, i], z[val, j] - z[val, i] + etol
              )
          )

  def _test_two_dimensional(self, dtype, random, use_gpu):
    """check the two-dimensional sparsemax case"""
    t = np.linspace(-2, 2, test_obs, dtype=dtype)
    z = np.vstack([
        t, np.zeros(test_obs, dtype=dtype)
    ]).T

    tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu)

    p0_expected = np.select([t < -1, t <= 1, t > 1], [0, (t + 1) / 2, 1])

    self.assertAllCloseAccordingToType(p0_expected, tf_sparsemax_out[:, 0])
    self.assertAllCloseAccordingToType(1 - p0_expected, tf_sparsemax_out[:, 1])
    self.assertShapeEqual(z, tf_sparsemax_op)

  def _test_gradient_against_estimate(self, dtype, random, use_gpu):
    """check sparsemax Rop, against estimated Rop"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype)

    logits = array_ops.placeholder(dtype, name='z')
    sparsemax_op = sparsemax(logits)

    with self.test_session(use_gpu=use_gpu):
      err = gradient_checker.compute_gradient_error(
          logits, z.shape,
          sparsemax_op, z.shape,
          x_init_value=z, delta=1e-9
      )

    self.assertLess(err, 1e-4)

  def _test_gradient_against_numpy(self, dtype, random, use_gpu):
    """check sparsemax Rop, against numpy Rop"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype)

    logits = constant_op.constant(z, name='z')
    sparsemax_op = sparsemax(logits)
    sparsemax_grad_op = gradients_impl.gradients(sparsemax_op, [logits])[0]

    with self.test_session(use_gpu=use_gpu):
      tf_grad = sparsemax_grad_op.eval()
      np_grad = self._np_sparsemax_grad(z)

      self.assertAllCloseAccordingToType(np_grad, tf_grad)
      self.assertShapeEqual(np_grad, sparsemax_grad_op)

  def _test_dtype(self, dtype):
    random = np.random.RandomState(1)

    self._test_sparsemax_against_numpy(dtype, random, use_gpu=False)

    self._test_sparsemax_of_zero(dtype, random, use_gpu=False)

    self._test_sparsemax_of_inf(dtype, random, use_gpu=False)

    self._test_constant_add(dtype, random, use_gpu=False)

    self._test_permutation(dtype, random, use_gpu=False)

    self._test_difference(dtype, random, use_gpu=False)

    self._test_two_dimensional(dtype, random, use_gpu=False)

    # sparsemax is not a smooth function so gradient estimation is only
    # possible for float64.
    if dtype == 'float64':
      self._test_gradient_against_estimate(dtype, random, use_gpu=False)

    self._test_gradient_against_numpy(dtype, random, use_gpu=False)

  def testFloat(self):
    self._test_dtype('float32')

  def testDouble(self):
    self._test_dtype('float64')


if __name__ == "__main__":
  test.main()
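
As a concrete check of the k(z) / tau(z) computation these tests exercise, here
is a minimal NumPy sketch (illustrative only, not part of the commit; the input
row is arbitrary):

    import numpy as np

    z = np.array([[1.0, 0.5, -1.2]])       # one observation, three classes
    z = z - z.mean(axis=1, keepdims=True)  # [[0.9, 0.4, -1.3]]
    z_sorted = np.sort(z, axis=1)[:, ::-1]
    z_cumsum = np.cumsum(z_sorted, axis=1)
    k = np.arange(1, z.shape[1] + 1)
    # like the op itself, count the 1s instead of using the argmax trick above
    k_z = (1 + k * z_sorted > z_cumsum).sum(axis=1)                # support size: 2
    tau_z = (z_cumsum[np.arange(z.shape[0]), k_z - 1] - 1) / k_z   # threshold: 0.15
    p = np.maximum(0, z - tau_z.reshape(-1, 1))                    # [[0.75, 0.25, 0.0]]
    print(p.sum(axis=1))                                           # [ 1.] -- a valid distribution
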
tensorflow/contrib/sparsemax/python/ops/sparsemax.py (new file, 74 lines)
@@ -0,0 +1,74 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sparsemax op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.util import loader
from tensorflow.python.platform import resource_loader
from tensorflow.python.framework import ops, dtypes
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn


def sparsemax(logits, name=None):
  """Computes sparsemax activations [1].

  For each batch `i` and class `j` we have
    sparsemax[i, j] = max(logits[i, j] - tau(logits[i, :]), 0)

  [1]: https://arxiv.org/abs/1602.02068

  Args:
    logits: A `Tensor`. Must be one of the following types: `half`, `float32`,
      `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `logits`.
  """

  with ops.name_scope(name, "sparsemax", [logits]) as name:
    logits = ops.convert_to_tensor(logits, name="logits")
    obs = array_ops.shape(logits)[0]
    dims = array_ops.shape(logits)[1]

    z = logits - math_ops.reduce_mean(logits, axis=1)[:, array_ops.newaxis]

    # sort z
    z_sorted, _ = nn.top_k(z, k=dims)

    # calculate k(z)
    z_cumsum = math_ops.cumsum(z_sorted, axis=1)
    k = math_ops.range(
        1, math_ops.cast(dims, logits.dtype) + 1, dtype=logits.dtype
    )
    z_check = 1 + k * z_sorted > z_cumsum
    # because the z_check vector is always [1,1,...1,0,0,...0], finding the
    # (index + 1) of the last `1` is the same as just summing the number of 1s.
    k_z = math_ops.reduce_sum(math_ops.cast(z_check, dtypes.int32), axis=1)

    # calculate tau(z)
    indices = array_ops.stack([math_ops.range(0, obs), k_z - 1], axis=1)
    tau_sum = array_ops.gather_nd(z_cumsum, indices)
    tau_z = (tau_sum - 1) / math_ops.cast(k_z, logits.dtype)

    # calculate p
    return math_ops.maximum(
        math_ops.cast(0, logits.dtype),
        z - tau_z[:, array_ops.newaxis]
    )
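
A usage sketch for the op above (assuming it is exported as
tf.contrib.sparsemax.sparsemax, matching the contrib package path; TF 0.12-era
graph-mode API):

    import tensorflow as tf

    logits = tf.constant([[1.0, 0.5, -1.2]], dtype=tf.float32)
    probs = tf.contrib.sparsemax.sparsemax(logits)

    with tf.Session() as sess:
        print(sess.run(probs))  # [[0.75, 0.25, 0.0]] -- exactly sparse, unlike softmax
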
tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py (new file, 59 lines)
@@ -0,0 +1,59 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sparsemax Loss op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.util import loader
from tensorflow.python.platform import resource_loader
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


def sparsemax_loss(logits, sparsemax, labels, name=None):
  """Computes sparsemax loss function [1].

  [1]: https://arxiv.org/abs/1602.02068

  Args:
    logits: A `Tensor`. Must be one of the following types: `half`, `float32`,
      `float64`.
    sparsemax: A `Tensor`. Must have the same type as `logits`.
    labels: A `Tensor`. Must have the same type as `logits`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `logits`.
  """

  with ops.name_scope(name, "sparsemax_loss",
                      [logits, sparsemax, labels]) as name:
    logits = ops.convert_to_tensor(logits, name="logits")
    sparsemax = ops.convert_to_tensor(sparsemax, name="sparsemax")
    labels = ops.convert_to_tensor(labels, name="labels")

    shifted_logits = logits - \
        math_ops.reduce_mean(logits, axis=1)[:, array_ops.newaxis]

    # sum over support
    support = math_ops.cast(sparsemax > 0, sparsemax.dtype)
    sum_s = support * sparsemax * (shifted_logits - 0.5 * sparsemax)

    # - z_k + ||q||^2
    q_part = labels * (0.5 * labels - shifted_logits)

    return math_ops.reduce_sum(sum_s + q_part, axis=1)
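
For reference, the same loss restated in NumPy directly from the formulas above
(a sketch, not part of the commit; q is the one-hot, or probabilistic, label
matrix):

    import numpy as np

    def np_sparsemax_loss(z, p, q):
        # z: logits, p: sparsemax(z), q: labels; all of shape (obs, classes).
        z = z - z.mean(axis=1, keepdims=True)  # same shift as shifted_logits
        support = (p > 0).astype(z.dtype)
        sum_s = support * p * (z - 0.5 * p)    # sum over the support S(z)
        q_part = q * (0.5 * q - z)             # -z_k + 0.5 * ||q||^2 for one-hot q
        return (sum_s + q_part).sum(axis=1)
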
@@ -77,6 +77,8 @@ load(
    "tf_opts_nortti_if_android",
    "cc_header_only_library",
)

#load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl")
load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu")
load("//tensorflow:tensorflow.bzl", "tf_cc_tests_gpu")
load("//tensorflow:tensorflow.bzl", "tf_version_info_genrule")
@@ -111,7 +113,10 @@ load(
    "//tensorflow/core:platform/default/build_config_root.bzl",
    "tf_cuda_tests_tags",
)

#load(
#    "//third_party/mkl:build_defs.bzl",
#    "if_mkl",
#)
# -----------------------------------------------------------------------------
# Public targets

@@ -1863,6 +1868,35 @@ tf_cc_tests(
    ],
)

#if_mkl(
#    tf_cc_test_mkl(
#        name = "mkl_related_tests",
#        size = "small",
#        srcs = ["graph/mkl_optimizer_merge_test.cc"],
#        linkstatic = tf_kernel_tests_linkstatic(),
#        deps = [
#            ":core",
#            ":core_cpu",
#            ":core_cpu_internal",
#            ":direct_session_internal",
#            ":framework",
#            ":framework_internal",
#            ":lib",
#            ":lib_internal",
#            ":ops",
#            ":protos_all_cc",
#            ":test",
#            ":test_main",
#            ":testlib",
#            "//third_party/eigen3",
#            "//tensorflow/cc:cc_ops",
#            "//tensorflow/cc:scope",
#            "//tensorflow/cc:sendrecv_ops",
#            "//tensorflow/core/kernels:ops_util",
#        ],
#    ),
#)

tf_cc_tests_gpu(
    name = "gpu_related_tests",
    size = "small",

tensorflow/core/graph/mkl_optimizer_merge.cc (new file, 596 lines)
@@ -0,0 +1,596 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef INTEL_MKL
// This module implements node merging optimization on the graph.
// We process the nodes in the graph in reverse postorder
// (i.e. inputs before their downstream dependencies).
//
#include <set>
#include <vector>
#include <queue>
#include <utility>

#include "tensorflow/core/graph/mkl_optimizer_merge.h"

#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/common_runtime/function.h"

namespace tensorflow {

// How many hops do we search for a matching node in the backward dataflow
// graph? We use a maxhop of 10 based on empirical observations. Also, these
// are maxhops in the backward data-flow graph. Since input of forward nodes
// (Conv2D) directly goes to backward nodes, we do not expect the hop-distance
// to be more than a few nodes.
static size_t kNodeMergeContextMaxDepth = 10;

// This optimization pass performs two tasks: merge
// nodes in the forward pass, and rewrite the gradient ops
// corresponding to merged forward ops.
//
// Merging nodes in the graph: Currently, it merges Conv2D+AddBias together.
//
// Rewriting nodes in the graph: This is needed in order to optimize
// gradient ops of Conv2D+AddBias. The gradient op of both Conv2D and
// MatMul is BiasAddGrad, and we need to rewrite BiasAddGrad into
// Conv2D-specific BiasAddGrad, and MatMul-specific BiasAddGrad.
// This is a context-specific optimization, where the context is the
// forward operator that the BiasAddGrad corresponds to.
class NodeMergeRewritePass : public GraphOptimizationPass {
 public:
  NodeMergeRewritePass() {
    csinfo_.conv2d = "Conv2D";
    csinfo_.conv2dwithbias = "Conv2DWithBias";
    csinfo_.conv2dwithbiasbackpropbias = "Conv2DWithBiasBackpropBias";
    csinfo_.biasadd = "BiasAdd";
    csinfo_.matmul = "MatMul";
    csinfo_.biasaddgrad = "BiasAddGrad";

    minfo_.push_back({csinfo_.conv2d, csinfo_.biasadd, 0,
                      csinfo_.conv2dwithbias});

    // We use a maxhop of 10 based on empirical observations. Also, these are
    // maxhops in the backward data-flow graph. Since input of forward nodes
    // (Conv2D) directly goes to backward nodes, we do not expect the
    // hop-distance to be more than a few nodes.
    rinfo_.push_back({csinfo_.biasaddgrad, csinfo_.conv2dwithbiasbackpropbias,
                      {csinfo_.conv2dwithbias, kNodeMergeContextMaxDepth}});
    rinfo_.push_back({csinfo_.biasaddgrad, csinfo_.conv2dwithbiasbackpropbias,
                      {csinfo_.conv2d, kNodeMergeContextMaxDepth}});
    // For now, we are rewriting BiasAddGrad to BiasAddGrad for MatMul. This is
    // because we do not have a separate Op for MatMulwithBias.
    rinfo_.push_back({csinfo_.biasaddgrad, csinfo_.biasaddgrad,
                      {csinfo_.matmul, kNodeMergeContextMaxDepth}});
  }

  // Standard interface to run optimization pass
  Status Run(const GraphOptimizationPassOptions& options);

  // Helper function which does most of the heavy lifting for node merge
  //
  // Extracts common functionality between the Run public interface and
  // the test interface.
  //
  // @return true, if and only if graph is mutated; false otherwise.
  bool RunPass(std::unique_ptr<Graph>* g);

 private:
  /// Structure to specify information used in node merge
  typedef struct {
    string pred;     // Predecessor node string
    string succ;     // Successor node string
    int op;          // Which operand number of the successor node the
                     // predecessor node corresponds to
    string newnode;  // Name of the node after merge
  } MergeInfo;

  /// Structure to specify information used in node rewrite
  typedef struct {
    string node;     // Name of the node to be rewritten
    string rewrite;  // New name of the node after rewrite
    typedef struct {
      string fwd;     // Node name in forward pass that this node
                      // corresponds to
      size_t maxhop;  // Maximum number of hops the mfwd_ is located
                      // from this node. If mfwd_ is farther than mmaxhop_
                      // then we do not rewrite the node.
    } ContextInfo;
    ContextInfo cinfo;  // Context for rewrite
  } RewriteInfo;

  /// Structure to store all constant strings
  typedef struct {
    string conv2d;
    string conv2dwithbias;
    string conv2dwithbiasbackpropbias;
    string biasadd;
    string matmul;
    string biasaddgrad;
  } ConstStringInfo;

  ConstStringInfo csinfo_;
  std::vector<MergeInfo> minfo_;
  std::vector<RewriteInfo> rinfo_;

 private:
  // Return a node that can be merged with input node
  //
  // @return pointer to the node if we can find such a
  // node. Otherwise, it returns nullptr.
  Node* FindNodeForMerge(const Node* a) const;

  // Merge predecessor node with its successor.
  // Currently, we merge Conv2D with AddBias only.
  //
  // Input nodes succ and pred may be deleted if the call to
  // this function is successful. Attempting to use the pointers
  // after the call may result in undefined behavior.
  //
  // @input g - input graph, succ - successor node, pred - predecessor node
  // @return Status::OK(), if merging is successful and supported.
  //         Returns appropriate Status error code otherwise.
  //         Graph is updated in case nodes are merged. Otherwise, it is
  //         not updated.
  Status MergeNode(std::unique_ptr<Graph>* g, Node* succ, Node* pred);

  // Is input node (n) a candidate for rewrite?
  //
  // @return true, if it can be rewritten; false, otherwise.
  bool IsApplicableRewriteNode(const Node* n) const;

  // Rewrites input node to a new node specified by its matching rewrite info.
  //
  // Method first searches matching rewrite info for input node and then
  // uses that info to rewrite.
  //
  // Input node may be deleted in case of rewrite. Attempting to use the node
  // after the call may result in undefined behavior.
  //
  // @input g - input graph, n - Node to be rewritten
  // @return Status::OK(), if the input node is rewritten;
  //         Returns appropriate Status error code otherwise.
  //         Graph is updated in case the input node is rewritten.
  //         Otherwise, it is not updated.
  Status RewriteNode(std::unique_ptr<Graph>* g, Node* n);

  // Helper function that searches the matching rewriteinfo for the node.
  // Implements breadth-first search in the data dependence graph for the
  // gradient op in backward direction.
  //
  // @input n - Node (gradient op) whose rewriteinfo is to be searched,
  //        fwdn - pointer to node from the forward pass that this node
  //        belongs to
  // @return Matching rewriteinfo in case a match is found; null otherwise.
  const RewriteInfo* FindMatchingRewriteInfo(const Node* n,
                                             const Node** fwdn) const;
};

/// We register merge optimizer for phase 1 and MKLToTF insertion for phase 2.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 1,
                      NodeMergeRewritePass);

static void FillInputs(const Node* n,
                       gtl::InlinedVector<Node*, 4>* control_edges,
                       gtl::InlinedVector<std::pair<Node*, int>, 4>* in) {
  DCHECK_EQ(in->size(), n->num_inputs());
  control_edges->clear();
  for (const Edge* e : n->in_edges()) {
    if (e->IsControlEdge()) {
      control_edges->push_back(e->src());
    } else {
      (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output());
    }
  }
  std::sort(control_edges->begin(), control_edges->end());
  if (n->op_def().is_commutative()) {
    // For commutative inputs, we sort the input by the input Node*
    // to get a canonical ordering (so that add(a,b) and add(b, a) will
    // hash to the same value if is_commutative is true for 'add').
    std::sort(in->begin(), in->end());
  }
}

Node* NodeMergeRewritePass::FindNodeForMerge(const Node* a) const {
  // Search for all matching mergeinfo.
  // We allow more than one match for extensibility.
  std::vector<const MergeInfo*> matching_mi;
  for (auto mi = minfo_.cbegin(); mi != minfo_.cend(); ++mi) {
    if (a->type_string() == mi->succ) {
      matching_mi.push_back(&*mi);
    }
  }

  VLOG(1) << "FindNodeForMerge: " << a->type_string();

  for (const MergeInfo* mi : matching_mi) {
    const int N_in = a->num_inputs();
    if (mi->op >= N_in) {
      continue;
    }

    // Get the control edges and input of node
    gtl::InlinedVector<Node*, 4> a_control_edges;
    gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in);
    FillInputs(a, &a_control_edges, &a_in);

    // Get operand op of the operator
    Node* b = nullptr;
    b = a_in[mi->op].first;
    if (b == nullptr || (b->type_string() != mi->pred)) {
      // NOTE: Should the first check be assert?
      continue;
    }

    VLOG(1) << " FindNode: " << b->type_string();

    gtl::InlinedVector<Node*, 4> b_control_edges;
    gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(N_in);
    FillInputs(b, &b_control_edges, &b_in);

    // Shouldn't merge if a and b have different control edges.
    if (a_control_edges != b_control_edges) {
      continue;
    } else {
      // We found a match.
      return b;
    }
  }

  return nullptr;
}

Status NodeMergeRewritePass::MergeNode(std::unique_ptr<Graph>* g,
                                       Node* succ, Node* pred) {
  CHECK_NOTNULL(succ);
  CHECK_NOTNULL(pred);

  if (succ->type_string() == csinfo_.biasadd &&
      pred->type_string() == csinfo_.conv2d) {
    // 1. Get all attributes from input nodes.
    DataType T_pred, T_succ;
    string padding;
    std::vector<int32> strides;
    string data_format_pred, data_format_succ;
    bool use_cudnn_on_gpu;
    int groups = 1;
    TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
    TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
    TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding));
    TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides));
    TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred));
    TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
    TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu",
                            &use_cudnn_on_gpu));
    // Groups attribute may not be there on the input node. So we do not
    // check for error in GetNodeAttr call.
    GetNodeAttr(pred->def(), "groups", &groups);
    // We check to ensure that data formats of both succ and pred are same.
    // We expect them to be same, so we could enforce this as an assert.
    // But an assert can be too strict, so we enforce this as a check.
    // If the check fails, then we do not merge the two nodes.
    if (data_format_pred != data_format_succ ||
        T_pred != T_succ) {
      return Status(error::Code::INVALID_ARGUMENT,
                    "data_format or T attribute of Conv2D and BiasAdd "
                    "do not match. Will skip node merge optimization");
    }

    // 2. Get inputs from both the nodes.
    // Find the 2 inputs from the conv and the bias from the add Bias.
    Node* oper1 = nullptr;
    Node* oper2 = nullptr;
    Node* oper3 = nullptr;

    const int succ_num = succ->num_inputs();
    gtl::InlinedVector<Node*, 4> succ_control_edges;
    gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
    FillInputs(succ, &succ_control_edges, &succ_in);

    const int pred_num = pred->num_inputs();
    gtl::InlinedVector<Node*, 4> pred_control_edges;
    gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
    FillInputs(pred, &pred_control_edges, &pred_in);

    // We need to ensure that there is only 1 edge between Conv2D and AddBias.
    // Otherwise, merging is semantically incorrect.
    if (pred->out_edges().size() != 1) {
      return Status(error::Code::INVALID_ARGUMENT,
                    "Conv2D has multiple outputs. "
                    "Will skip node merge optimization");
    }

    for (const Edge* e : pred->out_edges()) {
      if (e->dst() != succ) {
        return Status(error::Code::INVALID_ARGUMENT,
                      "Conv2D does not feed to BiasAdd. "
                      "Will skip node merge optimization");
      }
    }

    // Get operand 0, 1 of conv2D
    oper1 = pred_in[0].first;
    oper2 = pred_in[1].first;
    // Get operand 1 of add_bias
    oper3 = succ_in[1].first;

    Node* ret;
    // We will use the node name of BiasAdd as the name of the new node
    TF_CHECK_OK(NodeBuilder(succ->name(), csinfo_.conv2dwithbias)
                    .Input(oper1)
                    .Input(oper2)
                    .Input(oper3)
                    .Attr("T", T_pred)
                    .Attr("strides", strides)
                    .Attr("padding", padding)
                    .Attr("data_format", data_format_pred)
                    .Attr("use_cudnn_on_gpu", use_cudnn_on_gpu)
                    .Attr("groups", groups)
                    .Finalize(&**g, &ret));
    CHECK_NOTNULL(ret);

    // Incoming edges are fixed, we will fix the outgoing edges now.
    for (const Edge* e : succ->out_edges()) {
      (*g)->AddEdge(ret, e->src_output(), e->dst(), e->dst_input());
    }

    (*g)->RemoveNode(succ);
    (*g)->RemoveNode(pred);

    return Status::OK();
  }

  return Status(error::Code::UNIMPLEMENTED,
                "Unimplemented case for node merge optimization.");
}

Status NodeMergeRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node* n) {
  CHECK_NOTNULL(n);

  // Get the matching rewriteinfo for the node
  const Node* fwdn = nullptr;
  const RewriteInfo* ri = FindMatchingRewriteInfo(n, &fwdn);
  if (ri == nullptr || fwdn == nullptr) {
    VLOG(1) << "Rewriteinfo not found for: " << n->type_string();
    return Status(error::Code::INVALID_ARGUMENT,
                  "Rewrite info not found for the node. "
                  "Will skip node rewrite optimization");
  }

  VLOG(1) << "Rewrite called for: " << n->type_string();

  if (n->type_string() == csinfo_.biasaddgrad &&
      ri->node == csinfo_.biasaddgrad &&
      (ri->rewrite == csinfo_.conv2dwithbiasbackpropbias ||
       ri->rewrite == csinfo_.biasaddgrad)) {
    DataType T;
    string data_format;
    TF_CHECK_OK(GetNodeAttr(n->def(), "T", &T));
    TF_CHECK_OK(GetNodeAttr(n->def(), "data_format", &data_format));

    int n_num = n->num_inputs();  // this must be 1.
    CHECK_EQ(n_num, 1);

    gtl::InlinedVector<Node*, 4> n_control_edges;
    gtl::InlinedVector<std::pair<Node*, int>, 4> n_in(n_num);
    FillInputs(n, &n_control_edges, &n_in);

    Node *ret = nullptr, *op = n_in[0].first;

    if (ri->rewrite == csinfo_.conv2dwithbiasbackpropbias) {
      // Get strides info from Conv2D (node in the forward pass that this
      // node corresponds to).
      std::vector<int32> strides;
      TF_CHECK_OK(GetNodeAttr(fwdn->def(), "strides", &strides));

      // We use the same name as the original node name as there may be
      // fetchoutputs associated with it.
      TF_CHECK_OK(NodeBuilder(n->name(), ri->rewrite)
                      .Input(op)
                      .Attr("T", T)
                      .Attr("data_format", data_format)
                      .Attr("strides", strides)
                      .Finalize(&**g, &ret));
    } else {
      CHECK_EQ(ri->rewrite, csinfo_.biasaddgrad);
      TF_CHECK_OK(NodeBuilder(n->name(), ri->rewrite)
                      .Input(op)
                      .Attr("T", T)
                      .Attr("data_format", data_format)
                      .Finalize(&**g, &ret));
    }

    CHECK_NOTNULL(ret);

    // Incoming edges are fixed, we will fix the outgoing edges now.
    for (const Edge* e : n->out_edges()) {
      (*g)->AddEdge(ret, e->src_output(), e->dst(), e->dst_input());
    }

    VLOG(1) << "Rewrite node: " << n->type_string() << " successful";
    (*g)->RemoveNode(n);

    return Status::OK();
  }

  return Status(error::Code::UNIMPLEMENTED,
                "Unimplemented case for node rewrite optimization.");
}

const NodeMergeRewritePass::RewriteInfo*
NodeMergeRewritePass::FindMatchingRewriteInfo(const Node* n,
                                              const Node** fwdn) const {
  CHECK_NOTNULL(n);
  CHECK_NOTNULL(fwdn);
  *fwdn = nullptr;

  // Search for matching rewriteinfo based on node name.
  // There could be more than one matching rewriteinfo.
  std::vector<const RewriteInfo*> matching_ri;
  for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
    if (n->type_string() == ri->node) {
      matching_ri.push_back(&*ri);
    }
  }

  VLOG(1) << "Searching graph backwards for: " << n->type_string();

  // Now we will check for the forward op name for rewrite info in the data
  // flow graph. Get the max hops we should search for the fwd node.
  // We are now going to search (breadth-first) backwards in the data
  // dependence graph (for up to max hops) from n for the node
  // specified in fwd.
  // queue to maintain nodes to be visited and depth info for
  // breadth-first search
  std::queue<std::pair<const Node*, int>> nqueue;
  const Node* curr_node = n;
  size_t curr_depth = 0;
  nqueue.push(std::make_pair(curr_node, curr_depth));

  while (curr_depth < kNodeMergeContextMaxDepth && !nqueue.empty()) {
    std::pair<const Node*, int> curr_pair = nqueue.front();
    nqueue.pop();

    std::set<const Node*> visited_nodes;
    curr_node = curr_pair.first;
    curr_depth = curr_pair.second;
    CHECK_NOTNULL(curr_node);

    VLOG(1) << "Visiting node: " << curr_node->type_string()
            << " at depth: " << curr_depth
            << " for node: " << n->type_string();

    // If we find a match, we return immediately with the matching rewrite
    // info.
    for (const RewriteInfo* ri : matching_ri) {
      if (curr_node->type_string() == ri->cinfo.fwd) {
        *fwdn = curr_node;
        return ri;
      }
    }

    // Else we explore backward edges from the current node.
    // Add the source nodes of all incoming edges of the node to the queue.
    for (const Edge* e : curr_node->in_edges()) {
      // We do not visit an already visited node.
      if (visited_nodes.find(e->src()) == visited_nodes.end()) {
        // Depth of these nodes is 1 more than the depth of current node.
        nqueue.push(std::make_pair(e->src(), curr_depth + 1));
        visited_nodes.insert(e->src());
      }
    }
  } /* while */

  return nullptr;
}

bool NodeMergeRewritePass::IsApplicableRewriteNode(const Node* n) const {
  CHECK_NOTNULL(n);

  // Search for matching rewriteinfo
  // Even if we find one match, we return true.
  bool match_found = false;
  for (const RewriteInfo& ri : rinfo_) {
    if (n->type_string() == ri.node) {
      match_found = true;
      break;
    }
  }

  return match_found;
}

bool NodeMergeRewritePass::RunPass(std::unique_ptr<Graph>* g) {
  bool result = false;
  CHECK_NOTNULL(g);

  DumpGraph("Before OptimizeMerge", &**g);

  std::vector<Node*> order;
  GetReversePostOrder(**g, &order);
  std::vector<std::pair<Node*, Node*>> nodes_to_be_merged;
  std::vector<Node*> nodes_to_be_rewritten;

  VLOG(1) << "Running NodeMerge Optimization";

  for (Node* n : order) {
    if (!n->IsOp()) continue;
    Node* n1 = nullptr;
    if ((n1 = FindNodeForMerge(n)) != nullptr) {
      VLOG(1) << "Scheduled nodes " << n->name() << " and "
              << n1->name() << " for merging";
      nodes_to_be_merged.push_back(std::make_pair(n, n1));
    } else if (IsApplicableRewriteNode(n)) {
      VLOG(1) << "Scheduled node " << n->name() << " for rewrite";
      nodes_to_be_rewritten.push_back(n);
    }
  }

  for (std::pair<Node*, Node*> i : nodes_to_be_merged) {
    // Even if MergeNode merges a single pair of nodes, we
    // need to return true.
    string n1_name = i.first->name();
    string n2_name = i.second->name();
    if (MergeNode(g, i.first, i.second) == Status::OK()) {
      VLOG(1) << "Merged nodes " << n1_name << " and " << n2_name;
      result = true;
    }
  }

  DumpGraph("After OptimizeMerge(nodemerge)", &**g);

  for (Node* i : nodes_to_be_rewritten) {
    string name = i->name();
    if (RewriteNode(g, i) == Status::OK()) {
      VLOG(1) << "Rewrite node: " << name << " successful.";
      result = true;
    }
  }

  DumpGraph("After OptimizeMerge(noderewrite)", &**g);

  return result;
}

bool OptimizeNodeMerge(std::unique_ptr<Graph>* g) {
  return NodeMergeRewritePass().RunPass(g);
}

Status NodeMergeRewritePass::Run(const GraphOptimizationPassOptions& options) {
  // Currently checking only for two cases - Conv2D+Bias and Matmul+Bias.
  // It is possible to extend it to other operators in future.
  if (options.graph == nullptr) {
    return Status::OK();
  }

  // Get the ownership of graph
  std::unique_ptr<Graph>* g = std::move(options.graph);

  RunPass(g);

  // Return the ownership of graph back
  options.graph->reset(g->release());

  return Status::OK();
}

}  // namespace tensorflow

#endif  // INTEL_MKL
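
The context search in FindMatchingRewriteInfo is the subtle part of this pass.
The same bounded backward breadth-first search in a short Python sketch (the
node interface with `.op` and `.in_nodes()` is hypothetical; unlike the code
above, which re-creates `visited_nodes` on every loop iteration, this keeps a
single visited set for the whole search):

    from collections import deque

    def find_forward_context(node, fwd_ops, max_hops=10):
        # Walk backwards through data dependences from `node` (a gradient op),
        # looking for a forward op whose type string is in `fwd_ops`.
        queue = deque([(node, 0)])
        visited = {node}
        while queue:
            current, depth = queue.popleft()
            if current.op in fwd_ops:
                return current  # e.g. the Conv2D this BiasAddGrad belongs to
            if depth < max_hops:
                for src in current.in_nodes():
                    if src not in visited:
                        visited.add(src)
                        queue.append((src, depth + 1))
        return None  # no matching forward node within max_hops
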
tensorflow/core/graph/mkl_optimizer_merge.h (new file, 42 lines)
@@ -0,0 +1,42 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// An optimization pass that performs node merging and rewrite on graph nodes

#ifndef TENSORFLOW_GRAPH_MKL_OPTIMIZER_MERGE_H_
#define TENSORFLOW_GRAPH_MKL_OPTIMIZER_MERGE_H_

#ifdef INTEL_MKL

#include <sys/types.h>
#include <vector>
#include <string>
#include <memory>
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {

// Interface to invoke the pass for unit test
//
// Returns true if and only if 'g' is mutated.
extern bool OptimizeNodeMerge(std::unique_ptr<Graph>* g);

}  // namespace tensorflow

#endif  // INTEL_MKL

#endif  // TENSORFLOW_GRAPH_MKL_OPTIMIZER_MERGE_H_

tensorflow/core/graph/mkl_optimizer_merge_test.cc (new file, 397 lines)
@@ -0,0 +1,397 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef INTEL_MKL

#include "tensorflow/core/graph/mkl_optimizer_merge.h"

#include <vector>
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

namespace tensorflow {
namespace {

class OptimizerMergeTest : public ::testing::Test {
 public:
  OptimizerMergeTest() : graph_(OpRegistry::Global()) {}

  static void InitGraph(const string& s, Graph* graph) {
    GraphDef graph_def;

    auto parser = protobuf::TextFormat::Parser();
    CHECK(parser.MergeFromString(s, &graph_def)) << s;
    GraphConstructorOptions opts;
    TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));
  }

  void InitGraph(const string& s) {
    InitGraph(s, &graph_);
    original_ = CanonicalGraphString(&graph_);
  }

  static bool IncludeNode(const Node* n) { return n->IsOp(); }

  static string EdgeId(const Node* n, int index) {
    if (index == 0) {
      return n->name();
    } else if (index == Graph::kControlSlot) {
      return strings::StrCat(n->name(), ":control");
    } else {
      return strings::StrCat(n->name(), ":", index);
    }
  }

  string CanonicalGraphString(Graph* g) {
    std::vector<string> nodes;
    std::vector<string> edges;
    for (const Node* n : g->nodes()) {
      if (IncludeNode(n)) {
        nodes.push_back(strings::StrCat(n->name(), "(", n->type_string(), ")"));
      }
    }
    for (const Edge* e : g->edges()) {
      if (IncludeNode(e->src()) && IncludeNode(e->dst())) {
        edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->",
                                        EdgeId(e->dst(), e->dst_input())));
      }
    }
    // Canonicalize
    std::sort(nodes.begin(), nodes.end());
    std::sort(edges.begin(), edges.end());
    return strings::StrCat(str_util::Join(nodes, ";"), "|",
                           str_util::Join(edges, ";"));
  }

  string DoNodeMerge() {
    string before = CanonicalGraphString(&graph_);
    LOG(ERROR) << "Before node merge optimize: " << before;

    std::unique_ptr<Graph>* ug = new std::unique_ptr<Graph>(&graph_);
    OptimizeNodeMerge(ug);

    string result = CanonicalGraphString(&graph_);
    LOG(ERROR) << "After node merge optimize: " << result;
    return result;
  }

  const string& OriginalGraph() const { return original_; }

  Graph graph_;
  string original_;
};

REGISTER_OP("Input").Output("o: float").SetIsStateful();

TEST_F(OptimizerMergeTest, Basic) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoNodeMerge(),
            "A(Input);B(Input);C(Mul);D(Mul)|"
            "A->C;A->D;B->C:1;B->D:1");
}

// Test set 1: Conv2D + AddBias

// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y)
TEST_F(OptimizerMergeTest, Conv2DWithBias_Positive) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding' value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'BiasAdd'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " input: ['C', 'D'] }"
      "node { name: 'Y' op: 'Input'}"
      "node { name: 'Z' op: 'Sub'"
      " attr {key: 'T' value { type: DT_FLOAT } }"
      " input: ['E', 'Y']}");
  EXPECT_EQ(DoNodeMerge(),
            "A(Input);B(Input);D(Input);E(Conv2DWithBias);Y(Input);Z(Sub)|"
            "A->E;B->E:1;D->E:2;E->Z;Y->Z:1");
}

// Graph contains only Conv2D, no AddBias.
TEST_F(OptimizerMergeTest, Conv2DWithBias_Negative_NoAddBias) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding' value { s: 'SAME' } }"
      " input: ['A', 'B']}");
  EXPECT_EQ(DoNodeMerge(),
            "A(Input);B(Input);C(Conv2D)|"
            "A->C;B->C:1");
}

// Conv2D output does not go to BiasAdd.
TEST_F(OptimizerMergeTest, Conv2DWithBias_Negative_Dataflow1) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding' value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'Input'}"
      "node { name: 'F' op: 'BiasAdd'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " input: ['D', 'E'] }");  // Output of Conv2D does not go to BiasAdd.
  EXPECT_EQ(DoNodeMerge(),
            "A(Input);B(Input);C(Conv2D);D(Input);E(Input);F(BiasAdd)|"
            "A->C;B->C:1;D->F;E->F:1");
}

// Conv2D has two outgoing edges: BiasAdd and some other dummy node (Add).
// Merge should not be done in such case.
TEST_F(OptimizerMergeTest, Conv2DWithBias_Negative_Dataflow2) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding' value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'Input'}"
      "node { name: 'F' op: 'BiasAdd'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " input: ['D', 'E'] }"  // Conv2D has two outputs.
                              // No merge should happen.
      "node { name: 'G' op: 'Add'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'E'] }");
  EXPECT_EQ(DoNodeMerge(),
            "A(Input);B(Input);C(Conv2D);D(Input);E(Input);F(BiasAdd);G(Add)|"
            "A->C;B->C:1;C->G;D->F;E->F:1;E->G:1");
}

// data_format attribute value mismatch. Merge should not be done
// in such case.
TEST_F(OptimizerMergeTest, Conv2DWithBias_Negative_AttrMismatch) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding' value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Input'}"
      "node { name: 'E' op: 'BiasAdd'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NHCW' } }"
      " input: ['C', 'D'] }");
  EXPECT_EQ(DoNodeMerge(),
            "A(Input);B(Input);C(Conv2D);D(Input);E(BiasAdd)|"
            "A->C;B->C:1;C->E;D->E:1");
}

// Test set 2: Conv2D..BiasAddGrad -> Conv2DWithBiasBackpropBias rewrite tests

// C=Conv2D(A,B); D=Sub(C,A); E=BiasAddGrad(D)
TEST_F(OptimizerMergeTest, Conv2DBackprop_Positive) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding' value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Sub'"
      " attr {key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'A']}"
      "node { name: 'E' op: 'BiasAddGrad'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " input: ['D'] }");
  EXPECT_EQ(DoNodeMerge(),
            "A(Input);B(Input);C(Conv2D);D(Sub);E(Conv2DWithBiasBackpropBias)|"
            "A->C;A->D:1;B->C:1;C->D;D->E");
}

// No Conv2D in the context for BiasAddGrad. No rewrite should happen.
// C=Add(A,B); D=Sub(C,A); E=BiasAddGrad(D)
TEST_F(OptimizerMergeTest, Conv2DBackprop_Negative_NoConv2D) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Add'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Sub'"
      " attr {key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'A']}"
      "node { name: 'E' op: 'BiasAddGrad'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " input: ['D'] }");
  EXPECT_EQ(DoNodeMerge(),
            "A(Input);B(Input);C(Add);D(Sub);E(BiasAddGrad)|"
            "A->C;A->D:1;B->C:1;C->D;D->E");
}

// No Conv2D in the context for BiasAddGrad, but MatMul in context.
// Rewrite should happen, but name of BiasAddGrad does not change.
// C=MatMul(A,B); D=Sub(C,A); E=BiasAddGrad(D)
TEST_F(OptimizerMergeTest, Conv2DBackprop_Negative_NoConv2D_MatMul) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'MatMul'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'transpose_a' value { b: false } }"
      " attr { key: 'transpose_b' value { b: false } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Sub'"
      " attr {key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'A']}"
      "node { name: 'E' op: 'BiasAddGrad'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " input: ['D'] }");
  EXPECT_EQ(DoNodeMerge(),
            "A(Input);B(Input);C(MatMul);D(Sub);E(BiasAddGrad)|"
            "A->C;A->D:1;B->C:1;C->D;D->E");
}

// Test set 3: MatMul..BiasAddGrad -> BiasAddGrad rewrite tests
// C=MatMul(A,B); D=Sub(C,A); E=BiasAddGrad(D)
TEST_F(OptimizerMergeTest, MatMulBiasAddGrad_Positive) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'MatMul'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'transpose_a' value { b: false } }"
      " attr { key: 'transpose_b' value { b: false } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Sub'"
      " attr {key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'A']}"
      "node { name: 'E' op: 'BiasAddGrad'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " input: ['D'] }");
  EXPECT_EQ(DoNodeMerge(),
            "A(Input);B(Input);C(MatMul);D(Sub);E(BiasAddGrad)|"
            "A->C;A->D:1;B->C:1;C->D;D->E");
}

// No MatMul in the context for BiasAddGrad. No rewrite should happen.
// C=Add(A,B); D=Sub(C,A); E=BiasAddGrad(D)
TEST_F(OptimizerMergeTest, MatMulBiasAddGrad_Negative_NoMatMul) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Add'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Sub'"
      " attr {key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'A']}"
      "node { name: 'E' op: 'BiasAddGrad'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " input: ['D'] }");
  EXPECT_EQ(DoNodeMerge(),
            "A(Input);B(Input);C(Add);D(Sub);E(BiasAddGrad)|"
            "A->C;A->D:1;B->C:1;C->D;D->E");
}


static void BM_NodeMerge(int iters, int op_nodes) {
  testing::StopTiming();
  string s;
  for (int in = 0; in < 10; in++) {
    s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in);
  }
  random::PhiloxRandom philox(301, 17);
  random::SimplePhilox rnd(&philox);
  for (int op = 0; op < op_nodes; op++) {
    s += strings::Printf(
        "node { name: 'op%04d' op: 'Mul' attr { key: 'T' value { "
        "type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }",
        op, rnd.Uniform(10), rnd.Uniform(10));
  }

  bool first = true;
  while (iters > 0) {
    Graph* graph = new Graph(OpRegistry::Global());
    OptimizerMergeTest::InitGraph(s, graph);
    int N = graph->num_node_ids();
    if (first) {
      testing::SetLabel(strings::StrCat("Per graph node. Nodes: ", N));
      first = false;
    }
    {
      testing::StartTiming();
      std::unique_ptr<Graph> ug(graph);
      OptimizeNodeMerge(&ug);
      testing::StopTiming();
    }
    iters -= N;  // Our benchmark units are individual graph nodes,
                 // not whole graphs
    // delete graph;
  }
}
BENCHMARK(BM_NodeMerge)->Arg(1000)->Arg(10000);

}  // namespace
}  // namespace tensorflow

#endif /* INTEL_MKL */
@@ -7,9 +7,9 @@
 # append "_gpu" to the test name to invoke the GPU benchmarks. Example:
 #
 # # for CPU tests
-# $ bazel test -c opt --copt=-mavx //third_party/tensorflow/core/kernels:my_op_test
+# $ bazel test --config opt //third_party/tensorflow/core/kernels:my_op_test
 # # for GPU benchmarks
-# $ bazel run -c opt --copt=-mavx --config=cuda //third_party/tensorflow/core/kernels:my_op_test_gpu -- --benchmarks=..
+# $ bazel run --config opt --config=cuda //third_party/tensorflow/core/kernels:my_op_test_gpu -- --benchmarks=..
 #
 package(default_visibility = ["//visibility:public"])

@@ -170,12 +170,11 @@ struct LaunchXsmmBackwardInputConvolution<CPUDevice, float> {
     desc.pad_w_out = 0;
     desc.threads = num_threads;
     desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
-    desc.buffer_format = LIBXSMM_DNN_CONV_FORMAT_NHWC;
-    desc.filter_format = LIBXSMM_DNN_CONV_FORMAT_LIBXSMM;//LIBXSMM_DNN_CONV_FORMAT_RSCK;
+    desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
+    desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;//LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
     desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
     desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
-    desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
-    desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
+    desc.datatype = LIBXSMM_DNN_DATATYPE_F32;

     auto input_ptr = input_backward.data();
     auto filter_ptr = kernel.data();

@@ -219,12 +219,11 @@ class LaunchXsmmConvOp<CPUDevice, float> {
     desc.pad_w_out = 0;
     desc.threads = num_threads;
     desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
-    desc.buffer_format = LIBXSMM_DNN_CONV_FORMAT_NHWC;
-    desc.filter_format = LIBXSMM_DNN_CONV_FORMAT_LIBXSMM;
+    desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
+    desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;
     desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
     desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
-    desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
-    desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
+    desc.datatype = LIBXSMM_DNN_DATATYPE_F32;

     if (!CanUseXsmmConv2D(desc, data_format)) {
       return false;

@ -220,13 +220,15 @@ REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
namespace functor {

// UnsortedSegmentSumFunctor implementation for CPUDevice.
// todo: Remove duplicate code in UnsortedSegmentSumFunctor and UnsortedSegmentMaxFunctor.
template <typename T, typename Index>
struct UnsortedSegmentSumFunctor<CPUDevice, T, Index> {
struct UnsortedSegmentSumFunctor<CPUDevice, T, Index>
    : UnsortedSegmentBaseFunctor<CPUDevice, T, Index> {
void operator()(OpKernelContext* ctx, const CPUDevice& d,
const Index output_rows, const TensorShape& segment_ids_shape,
typename TTypes<Index>::ConstFlat segment_ids,
const Index data_size, const T* data,
typename TTypes<T, 2>::Tensor output) {
typename TTypes<T, 2>::Tensor output) override {
output.setZero();
if (data_size == 0) {
return;
@ -243,16 +245,44 @@ struct UnsortedSegmentSumFunctor<CPUDevice, T, Index> {
}
}
};

// UnsortedSegmentMaxFunctor implementation for CPUDevice.
template <typename T, typename Index>
struct UnsortedSegmentMaxFunctor<CPUDevice, T, Index>
    : UnsortedSegmentBaseFunctor<CPUDevice, T, Index> {
void operator()(OpKernelContext* ctx, const CPUDevice& d,
const Index output_rows, const TensorShape& segment_ids_shape,
typename TTypes<Index>::ConstFlat segment_ids,
const Index data_size, const T* data,
typename TTypes<T, 2>::Tensor output) override {
output.setConstant(std::numeric_limits<T>::min());
if (data_size == 0) {
return;
}
const int64 N = segment_ids.dimension(0);
auto data_flat = typename TTypes<T, 2>::ConstTensor(data, N, data_size / N);
for (int64 i = 0; i < N; ++i) {
Index j = internal::SubtleMustCopy(segment_ids(i));
OP_REQUIRES(ctx, FastBoundsCheck(j, output_rows),
errors::InvalidArgument(
"segment_ids", SliceDebugString(segment_ids_shape, i),
" = ", j, " is out of range [0, ", output_rows, ")"));
output.template chip<0>(j) =
data_flat.template chip<0>(i).cwiseMax(output.template chip<0>(j));
}
}
};
} // namespace functor

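An aside on the reduction above: stripped of the Eigen and OpKernel machinery, unsorted segment max is a single pass that folds each input row into the output row named by its segment id. A minimal standalone sketch under those assumptions (illustrative only — the names are ours, not TensorFlow's; the real kernel raises InvalidArgument on out-of-range ids rather than skipping, and it initializes with `std::numeric_limits<T>::min()` exactly as the functor does):

```cpp
#include <algorithm>
#include <limits>
#include <vector>

// Unsorted segment max over data viewed as [n, inner]:
// out[j] = elementwise max of all rows i with segment_ids[i] == j.
std::vector<float> UnsortedSegmentMax(const std::vector<float>& data,
                                      const std::vector<int>& segment_ids,
                                      int num_segments, int inner) {
  const int n = static_cast<int>(segment_ids.size());
  // Mirror the functor: start every output row from numeric_limits<T>::min().
  std::vector<float> out(static_cast<size_t>(num_segments) * inner,
                         std::numeric_limits<float>::min());
  for (int i = 0; i < n; ++i) {
    const int j = segment_ids[i];
    if (j < 0 || j >= num_segments) continue;  // real op: OP_REQUIRES error
    for (int k = 0; k < inner; ++k) {
      out[j * inner + k] = std::max(out[j * inner + k], data[i * inner + k]);
    }
  }
  return out;
}
```
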
// Similar to SegmentReductionOp but can handle unsorted segment definitions and
// specifying size of output.
// Base class for SegmentReductionOps that can handle unsorted segment
// definitions
// and specifying the size of the output in addition to a reduction function
template <typename Device, class T, class Index>
class UnsortedSegmentSumOp : public OpKernel {
class UnsortedSegmentBaseOp : public OpKernel {
 public:
explicit UnsortedSegmentSumOp(OpKernelConstruction* context)
    : OpKernel(context) {}
explicit UnsortedSegmentBaseOp(
    OpKernelConstruction* context,
    functor::UnsortedSegmentBaseFunctor<Device, T, Index>& functor)
    : OpKernel(context), reduction_functor_(functor) {}

void Compute(OpKernelContext* context) override {
const Tensor& data = context->input(0);
@ -288,27 +318,70 @@ class UnsortedSegmentSumOp : public OpKernel {
auto output_flat = output->flat_outer_dims<T>();

auto data_ptr = data.template flat<T>().data();
functor::UnsortedSegmentSumFunctor<Device, T, Index>()(
    context, context->template eigen_device<Device>(), output_rows,
    segment_ids.shape(), segment_flat, data.NumElements(), data_ptr,
    output_flat);
reduction_functor_(context, context->template eigen_device<Device>(),
                   output_rows, segment_ids.shape(), segment_flat,
                   data.NumElements(), data_ptr, output_flat);
}
 private:
functor::UnsortedSegmentBaseFunctor<Device, T, Index>& reduction_functor_;
};

#define REGISTER_CPU_UNSORTED_KERNELS(type, index_type) \
template <typename Device, class T, class Index>
class UnsortedSegmentSumOp : public UnsortedSegmentBaseOp<Device, T, Index> {
 public:
explicit UnsortedSegmentSumOp(OpKernelConstruction* context)
    : UnsortedSegmentBaseOp<Device, T, Index>(
          context,
          sum_functor_) {}
 private:
functor::UnsortedSegmentSumFunctor<Device, T, Index> sum_functor_;
};

template <typename Device, class T, class Index>
class UnsortedSegmentMaxOp : public UnsortedSegmentBaseOp<Device, T, Index> {
 public:
explicit UnsortedSegmentMaxOp(OpKernelConstruction* context)
    : UnsortedSegmentBaseOp<Device, T, Index>(
          context,
          max_functor_) {}
 private:
functor::UnsortedSegmentMaxFunctor<Device, T, Index> max_functor_;
};

#define REGISTER_REAL_CPU_UNSORTED_KERNELS(type, index_type) \
  REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentSum") \
                              .Device(DEVICE_CPU) \
                              .TypeConstraint<type>("T") \
                              .TypeConstraint<index_type>("Tindices"), \
                          UnsortedSegmentSumOp<CPUDevice, type, index_type>); \
  REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentMax") \
                              .Device(DEVICE_CPU) \
                              .TypeConstraint<type>("T") \
                              .TypeConstraint<index_type>("Tindices"), \
                          UnsortedSegmentMaxOp<CPUDevice, type, index_type>);

#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, index_type) \
  REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentSum") \
                              .Device(DEVICE_CPU) \
                              .TypeConstraint<type>("T") \
                              .TypeConstraint<index_type>("Tindices"), \
                          UnsortedSegmentSumOp<CPUDevice, type, index_type>);

#define REGISTER_CPU_UNSORTED_KERNELS_ALL(type) \
  REGISTER_CPU_UNSORTED_KERNELS(type, int32); \
  REGISTER_CPU_UNSORTED_KERNELS(type, int64);
#define REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL(type) \
  REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int32); \
  REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int64)

TF_CALL_NUMBER_TYPES(REGISTER_CPU_UNSORTED_KERNELS_ALL);
#undef REGISTER_CPU_UNSORTED_KERNELS
#undef REGISTER_CPU_UNSORTED_KERNELS_ALL
#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(type) \
  REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, int32); \
  REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, int64)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL);
REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex64);
REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128);
#undef REGISTER_REAL_CPU_UNSORTED_KERNELS
#undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS
#undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL
#undef REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL

#if GOOGLE_CUDA
#define REGISTER_GPU_UNSORTED_KERNELS(type, index_type) \
@ -26,6 +26,17 @@ namespace tensorflow {
class OpKernelContext;

namespace functor {
// BaseFunctor for definition of UnsortedSegmentReductionOp
// for usage without templates.
template <typename Device, typename T, typename Index>
struct UnsortedSegmentBaseFunctor{
  virtual ~UnsortedSegmentBaseFunctor(){}
  virtual void operator()(OpKernelContext* ctx, const Device& d,
      const Index output_rows, const TensorShape& segment_ids_shape,
      typename TTypes<Index>::ConstFlat segment_ids,
      const Index data_size, const T* data,
      typename TTypes<T, 2>::Tensor output){};
};

// Functor for UnsortedSegmentSumOp.
// 'output_rows': the number of output segments (unique segment ids in
@ -37,7 +48,7 @@ namespace functor {
// 'data': input data tensor.
// 'output': output reshaped to {output_rows, output.size/output_rows}
template <typename Device, typename T, typename Index>
struct UnsortedSegmentSumFunctor {
struct UnsortedSegmentSumFunctor: public UnsortedSegmentBaseFunctor<Device, T, Index> {
void operator()(OpKernelContext* ctx, const Device& d,
const Index output_rows, const TensorShape& segment_ids_shape,
typename TTypes<Index>::ConstFlat segment_ids,
@ -45,6 +56,23 @@ struct UnsortedSegmentSumFunctor {
typename TTypes<T, 2>::Tensor output);
};

// Functor for UnsortedSegmentMaxOp.
// 'output_rows': the number of output segments (unique segment ids in
// 'segment_ids').
// 'segment_ids_shape': shape of 'segment_ids' tensor.
// 'segment_ids': unsorted map from input to output segment ids at which to
// perform segment max operation.
// 'data_size': size of input data tensor.
// 'data': input data tensor.
// 'output': output reshaped to {output_rows, output.size/output_rows}
template <typename Device, typename T, typename Index>
struct UnsortedSegmentMaxFunctor: public UnsortedSegmentBaseFunctor<Device, T, Index> {
void operator()(OpKernelContext* ctx, const Device& d,
const Index output_rows, const TensorShape& segment_ids_shape,
typename TTypes<Index>::ConstFlat segment_ids,
const Index data_size, const T* data,
typename TTypes<T, 2>::Tensor output);
};
} // namespace functor
} // namespace tensorflow

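Worth noting about the refactor above: the virtual `UnsortedSegmentBaseFunctor` gives the sum and max reductions a common non-template interface, so the shared `UnsortedSegmentBaseOp` can hold a plain reference to whichever concrete functor its subclass was constructed with instead of being specialized on the reduction type.
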
@ -56,12 +56,12 @@ namespace functor {

// UnsortedSegmentSumFunctor implementation for GPUDevice.
template <typename T, typename Index>
struct UnsortedSegmentSumFunctor<GPUDevice, T, Index> {
struct UnsortedSegmentSumFunctor<GPUDevice, T, Index>: UnsortedSegmentBaseFunctor<GPUDevice, T, Index> {
void operator()(OpKernelContext* ctx, const GPUDevice& d,
const Index output_rows, const TensorShape& segment_ids_shape,
typename TTypes<Index>::ConstFlat segment_ids,
const Index data_size, const T* data,
typename TTypes<T, 2>::Tensor output) {
typename TTypes<T, 2>::Tensor output) override {
if (output.size() == 0) {
return;
}
@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#ifdef TENSORFLOW_USE_LIBXSMM
#include "include/libxsmm_intrinsics_x86.h"
#include "include/libxsmm_malloc.h"
#include "include/libxsmm_spmdm.h"
#endif

@ -896,6 +897,8 @@ class LibxsmmSparseMatMul {
} else {
std::unique_ptr<TensorInfoCacheEntry> e{
    new TensorInfoCacheEntry{M, K, N, max_threads, {}, nullptr}};
// setup scoped allocator, which uses cpu_allocator() for this scope
const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator;
libxsmm_spmdm_init(M, N, K, max_threads, &e->handle, &e->output_csr);
return e;
}
@ -33,6 +33,7 @@ void dummy_xsmm_conv2d_ensure_file_is_not_empty(void);

#include "include/libxsmm_cpuid.h"
#include "libxsmm_dnn_handle.h"
#include "libxsmm_malloc.h"

namespace tensorflow {

@ -143,26 +144,28 @@ struct HashFunction{
S << w.d.S; u << w.d.u;
v << w.d.v; padh << w.d.pad_h_in;
padw << w.d.pad_w_in;

std::string out_ = N.str() + C.str()\
                 + H.str() + W.str()\
                 + K.str() + R.str()\
                 + S.str() + u.str()\
                 + v.str() + padh.str()\
                 + padw.str();

return ( std::hash<std::string>()(out_));
}
};

class handles{
 public:
libxsmm_dnn_conv_handle* find( const libxsmm_dnn_conv_desc_wrap &w) {
std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_conv_handle*, HashFunction>::iterator i = libxsmm_handles.find(w);
libxsmm_dnn_layer* find( const libxsmm_dnn_conv_desc_wrap &w) {
std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*,
                   HashFunction>::iterator i = libxsmm_handles.find(w);
if (i == libxsmm_handles.end()){
libxsmm_dnn_err_t status;
libxsmm_dnn_conv_handle* libxsmm_handle = libxsmm_dnn_create_conv_handle_check(w.d, &status);
libxsmm_dnn_layer* libxsmm_handle =
    libxsmm_dnn_create_conv_layer(w.d, &status);
chk_libxsmm_err(status, "Create handle");
libxsmm_handles.insert(std::make_pair(w, libxsmm_handle));
return libxsmm_handle;
@ -171,15 +174,14 @@ class handles{
return i->second;
}
~handles(){
std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_conv_handle*, HashFunction>::iterator i;
std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*,
                   HashFunction>::iterator i;
for (i= libxsmm_handles.begin(); i != libxsmm_handles.end(); i++)
chk_libxsmm_err(libxsmm_dnn_destroy_conv_handle(i->second),
chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(i->second),
                "Destroy handle");
}
 private:

std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_conv_handle*, HashFunction> libxsmm_handles;

std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*, HashFunction> libxsmm_handles;
};

static handles libxsmm_handles;
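The `handles` class above memoizes expensive `libxsmm_dnn_create_conv_layer` calls behind an `unordered_map` keyed by a hashable wrapper around the descriptor. A minimal sketch of the same memoization pattern, with hypothetical types standing in for the libxsmm ones (not the libxsmm API):

```cpp
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>

// Toy descriptor and cache in the same shape as the handles class above:
// the hasher serializes the identifying fields, and the cache creates an
// entry only on the first lookup of each distinct key.
struct Desc { int n, c, h, w; };

struct DescHash {
  std::size_t operator()(const Desc& d) const {
    const std::string key = std::to_string(d.n) + "," + std::to_string(d.c) +
                            "," + std::to_string(d.h) + "," + std::to_string(d.w);
    return std::hash<std::string>()(key);
  }
};

struct DescEq {
  bool operator()(const Desc& a, const Desc& b) const {
    return a.n == b.n && a.c == b.c && a.h == b.h && a.w == b.w;
  }
};

class HandleCache {
 public:
  int* Find(const Desc& d) {
    auto it = cache_.find(d);
    if (it == cache_.end()) {
      int* handle = new int(d.n * d.c);  // stand-in for expensive creation
      cache_.emplace(d, handle);
      return handle;
    }
    return it->second;
  }
  ~HandleCache() {
    for (auto& kv : cache_) delete kv.second;  // destroy all cached handles
  }
 private:
  std::unordered_map<Desc, int*, DescHash, DescEq> cache_;
};
```
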
@ -187,22 +189,25 @@ static handles libxsmm_handles;
template <typename InputPtr, typename FilterPtr, typename OutputPtr>
static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
                                   const libxsmm_dnn_conv_desc& desc,
                                   libxsmm_dnn_conv_kind kind, InputPtr input,
                                   libxsmm_dnn_compute_kind kind, InputPtr input,
                                   FilterPtr filter, OutputPtr output) {
// setup scoped allocator, which adopts the allocator from the context
const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator(*ctx);
libxsmm_dnn_err_t status;
libxsmm_dnn_conv_handle* libxsmm_handle;
libxsmm_dnn_layer* libxsmm_handle;
libxsmm_dnn_conv_desc_wrap w(desc);
void* scratch;

if(kind == LIBXSMM_DNN_CONV_KIND_FWD)
if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD)
libxsmm_handle = libxsmm_handles.find(w);
else{
libxsmm_handle = libxsmm_dnn_create_conv_handle_check(desc, &status);
else {
libxsmm_handle = libxsmm_dnn_create_conv_layer(desc, &status);
chk_libxsmm_err(status, "Create handle");
}

status = libxsmm_dnn_get_codegen_success(libxsmm_handle, kind);
if (status == LIBXSMM_DNN_WARN_FALLBACK) {
chk_libxsmm_err(libxsmm_dnn_destroy_conv_handle(libxsmm_handle),
chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle),
                "Destroy handle");
return false; // Use non-libxsmm code
}
@ -211,23 +216,23 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
libxsmm_dnn_buffer* libxsmm_input;
libxsmm_dnn_buffer* libxsmm_output;
libxsmm_dnn_filter* libxsmm_filter;

/*

/*
const DeviceBase::CpuWorkerThreads* worker_threads =
    ctx->device()->tensorflow_cpu_worker_threads();

int num_threads = worker_threads->num_threads;
*/

int ifmblock = (libxsmm_handle->ifmblock);
int ofmblock = (libxsmm_handle->ofmblock);
int ofmblock = (libxsmm_handle->ofmblock);

int blocksifm = desc.C%ifmblock ==0 ? desc.C/ifmblock :desc.C/ifmblock + 1;
int blocksifm = desc.C%ifmblock ==0 ? desc.C/ifmblock :desc.C/ifmblock + 1;
int blocksofm = desc.K%ofmblock ==0 ? desc.K/ofmblock :desc.K/ofmblock + 1;
float *native_filter = (float*)libxsmm_aligned_malloc( blocksofm*blocksifm*desc.R*desc.S*ifmblock*ofmblock*sizeof(float), 2097152);

float *native_filter = (float*)libxsmm_aligned_scratch(
    blocksofm*blocksifm*desc.R*desc.S*ifmblock*ofmblock*sizeof(float),
    2097152);

const DeviceBase::CpuWorkerThreads* worker_threads =
    ctx->device()->tensorflow_cpu_worker_threads();

@ -264,50 +269,78 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
count.Wait();
}

libxsmm_input = libxsmm_dnn_link_input_buffer_check(
    libxsmm_handle, input, LIBXSMM_DNN_CONV_FORMAT_NHWC_PTR, &status);
libxsmm_input = libxsmm_dnn_link_buffer(
    libxsmm_handle, LIBXSMM_DNN_INPUT, input, LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
chk_libxsmm_err(status, "Link input buffer");
libxsmm_output = libxsmm_dnn_link_output_buffer_check(
    libxsmm_handle, output, LIBXSMM_DNN_CONV_FORMAT_NHWC_PTR, &status);
libxsmm_output = libxsmm_dnn_link_buffer(
    libxsmm_handle, LIBXSMM_DNN_OUTPUT, output, LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
chk_libxsmm_err(status, "Link output buffer");
libxsmm_filter = libxsmm_dnn_link_filter_check(
    libxsmm_handle, native_filter, LIBXSMM_DNN_CONV_FORMAT_LIBXSMM_PTR, &status);
libxsmm_filter = libxsmm_dnn_link_filter(
    libxsmm_handle, LIBXSMM_DNN_FILTER, native_filter, LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR, &status);
chk_libxsmm_err(status, "Link filter");

chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_output), "Zero output");

chk_libxsmm_err(libxsmm_dnn_bind_input_buffer(libxsmm_handle, libxsmm_input),
                "Bind input");
chk_libxsmm_err(
    libxsmm_dnn_bind_output_buffer(libxsmm_handle, libxsmm_output),
    "Bind output");
chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter),
                "Bind filter");

if (kind == LIBXSMM_DNN_CONV_KIND_BWD) {
libxsmm_dnn_transpose_filter(libxsmm_handle);
if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, LIBXSMM_DNN_REGULAR_INPUT),
                "Bind input forward");
chk_libxsmm_err(
    libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, LIBXSMM_DNN_REGULAR_OUTPUT),
    "Bind output forward");
chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, LIBXSMM_DNN_REGULAR_FILTER),
                "Bind filter forward");
} else {
chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, LIBXSMM_DNN_GRADIENT_INPUT),
                "Bind input backward");
chk_libxsmm_err(
    libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, LIBXSMM_DNN_GRADIENT_OUTPUT),
    "Bind output backward");
chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, LIBXSMM_DNN_REGULAR_FILTER),
                "Bind filter backward");
}

/* bind scratch */
scratch = (void*)libxsmm_aligned_scratch( libxsmm_dnn_get_scratch_size( libxsmm_handle, kind, &status ), 2097152);
chk_libxsmm_err( status, "scratch allocation" );
chk_libxsmm_err( libxsmm_dnn_bind_scratch( libxsmm_handle, kind, scratch ), "binding scratch" );

if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
libxsmm_dnn_transpose_filter(libxsmm_handle, LIBXSMM_DNN_FILTER);
}

BlockingCounter counter(num_threads);

for (int i = 0; i < num_threads; ++i) {
worker_threads->workers->Schedule([=, &counter]() {
chk_libxsmm_err(libxsmm_dnn_convolve_st(libxsmm_handle, kind, 0, i),
chk_libxsmm_err(libxsmm_dnn_execute_st(libxsmm_handle, kind, 0, i),
                "Worker");
counter.DecrementCount();
});
}
counter.Wait();

/* clean up */
chk_libxsmm_err( libxsmm_dnn_release_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL ), "release scratch" );
if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT ), "release input" );
chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_REGULAR_OUTPUT ), "release output" );
chk_libxsmm_err( libxsmm_dnn_release_filter( libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER ), "release filter" );
} else {
chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_GRADIENT_INPUT ), "release input" );
chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT ), "release output" );
chk_libxsmm_err( libxsmm_dnn_release_filter( libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER ), "release filter" );
}
chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_input), "Destroy input");
chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_output), "Destroy output");
chk_libxsmm_err(libxsmm_dnn_destroy_filter(libxsmm_filter), "Destroy filter");

if(kind != LIBXSMM_DNN_CONV_KIND_FWD)
chk_libxsmm_err(libxsmm_dnn_destroy_conv_handle(libxsmm_handle),
if(kind != LIBXSMM_DNN_COMPUTE_KIND_FWD)
chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle),
                "Destroy handle");

libxsmm_free(native_filter);
libxsmm_free(scratch);
return true; // Succeeded
}

@ -315,7 +348,7 @@ template <typename T>
struct XsmmFwdConv2D<CPUDevice, T> {
bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                const T* input, const T* filter, T* output) {
return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_CONV_KIND_FWD, input,
return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_FWD, input,
                              filter, output);
}
};
@ -324,7 +357,7 @@ template <typename T>
struct XsmmBkwInputConv2D<CPUDevice, T> {
bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                T* input, const T* filter, const T* output) {
return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_CONV_KIND_BWD, input,
return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_BWD, input,
                              filter, output);
}
};
@ -333,7 +366,7 @@ template <typename T>
struct XsmmBkwFilterConv2D<CPUDevice, T> {
bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                const T* input, T* filter, const T* output) {
return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_CONV_KIND_UPD, input,
return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_UPD, input,
                              filter, output);
}
};

@ -188,6 +188,8 @@ class XsmmConv2DTest : public OpsTestBase {
TEST_F(XsmmConv2DTest, Basic) {
MakeOp(1);

// setup scoped allocator, which uses cpu_allocator() for this scope
const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator;

int ifw = 14; /* input width, "W" */
int ifh = 14; /* input height, "H" */
@ -223,9 +225,9 @@ TEST_F(XsmmConv2DTest, Basic) {
//Initialization of Filter and Image

/* allocate data */
float *naive_input = (float*)libxsmm_aligned_malloc( nImg*nIfm*ifhp*ifwp*sizeof(float), 2097152);
float *naive_output = (float*)libxsmm_aligned_malloc( nImg*nOfm*ofhp*ofwp*sizeof(float), 2097152);
float *naive_filter = (float*)libxsmm_aligned_malloc( nOfm*nIfm*kh*kw* sizeof(float), 2097152);
float *naive_input = (float*)libxsmm_aligned_scratch( nImg*nIfm*ifhp*ifwp*sizeof(float), 2097152);
float *naive_output = (float*)libxsmm_aligned_scratch( nImg*nOfm*ofhp*ofwp*sizeof(float), 2097152);
float *naive_filter = (float*)libxsmm_aligned_scratch( nOfm*nIfm*kh*kw* sizeof(float), 2097152);
/* initialize data */
init_buf(naive_input, nImg*nIfm*ifhp*ifwp, 0, 0);
zero_buf(naive_output, nImg*nOfm*ofhp*ofwp);
@ -322,12 +324,11 @@ TEST(XsmmConv2DTest, Basic) {
desc.pad_w_out = 0;
desc.threads = num_threads;
desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
desc.buffer_format = LIBXSMM_DNN_CONV_FORMAT_NHWC;
desc.filter_format = LIBXSMM_DNN_CONV_FORMAT_LIBXSMM;//LIBXSMM_DNN_CONV_FORMAT_RSCK;
desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;//LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
desc.datatype = LIBXSMM_DNN_DATATYPE_F32;

if (!CanUseXsmmConv2D(desc, data_format)) {
return false;

@ -588,6 +588,7 @@ REGISTER_OP_GRADIENT("Mean", MeanGrad);
// REGISTER_OP_GRADIENT("SegmentMin", SegmentMinGrad);
// REGISTER_OP_GRADIENT("SegmentMax", SegmentMaxGrad);
// REGISTER_OP_GRADIENT("UnsortedSegmentSum", UnsortedSegmentSumGrad);
// REGISTER_OP_GRADIENT("UnsortedSegmentMax", UnsortedSegmentMaxGrad);

Status MinMaxGradHelper(const string& op, const AttrSlice& attrs,
                        FunctionDef* g) {
@ -1342,6 +1342,36 @@ Status SparseSegmentReductionGradShapeFn(InferenceContext* c) {
return Status::OK();
}

Status UnsortedSegmentReductionShapeFn(InferenceContext* c) {
ShapeHandle s_data = c->input(0);
ShapeHandle s_segment_ids = c->input(1);
ShapeHandle s_num_segments = c->input(2);
TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments));

ShapeHandle out;

// Leading dimensions of data must be compatible with dimensions of
// <s_segment_ids>.
if (c->RankKnown(s_segment_ids)) {
TF_RETURN_IF_ERROR(
    c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids));

// Get the value of the num_segments input tensor.
DimensionHandle num_segments_dim;
TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim));

// Output is {segment_id_rank} + s_data[segment_id_rank:].
ShapeHandle s_data_suffix;
TF_RETURN_IF_ERROR(
    c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix));
TF_RETURN_IF_ERROR(
    c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out));
} else {
out = c->UnknownShape();
}
c->set_output(0, out);
return Status::OK();
}
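A worked example of the shape rule above (illustrative shapes, not from the source): for `data` of shape [3, 4, 5] and `segment_ids` of shape [3, 4], `MergePrefix` checks that [3, 4] is a prefix of the data shape, `Subshape` from rank(`segment_ids`) = 2 yields the suffix [5], and the output shape is [`num_segments`, 5]. If the rank of `segment_ids` is unknown, the output shape is unknown.
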
} // namespace

REGISTER_OP("SegmentSum")
@ -1495,36 +1525,7 @@ REGISTER_OP("UnsortedSegmentSum")
.Output("output: T")
.Attr("T: numbertype")
.Attr("Tindices: {int32,int64}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle s_data = c->input(0);
ShapeHandle s_segment_ids = c->input(1);
ShapeHandle s_num_segments = c->input(2);
TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments));

ShapeHandle out;

// Leading dimensions of data must be compatible with dimensions of
// <s_segment_ids>.
if (c->RankKnown(s_segment_ids)) {
TF_RETURN_IF_ERROR(
    c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids));

// Get the value of the num_segments input tensor.
DimensionHandle num_segments_dim;
TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim));

// Output is {segment_id_rank} + s_data[segment_id_rank:].
ShapeHandle s_data_suffix;
TF_RETURN_IF_ERROR(
    c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix));
TF_RETURN_IF_ERROR(
    c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out));
} else {
out = c->UnknownShape();
}
c->set_output(0, out);
return Status::OK();
})
.SetShapeFn(UnsortedSegmentReductionShapeFn)
.Doc(R"doc(
Computes the sum along segments of a tensor.

@ -1554,6 +1555,43 @@ output: Has same shape as data, except for the first `segment_ids.rank`

)doc");

REGISTER_OP("UnsortedSegmentMax")
.Input("data: T")
.Input("segment_ids: Tindices")
.Input("num_segments: int32")
.Output("output: T")
.Attr("T: realnumbertype")
.Attr("Tindices: {int32,int64}")
.SetShapeFn(UnsortedSegmentReductionShapeFn)
.Doc(R"doc(
Computes the max along segments of a tensor.

Read [the section on
Segmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation
of segments.

This operator is similar to the [unsorted segment sum operator](../../api_docs/python/math_ops.md#UnsortedSegmentSum).
Instead of computing the sum over segments, it computes the maximum
such that:

\\(output_i = \max_j data_j\\) where max is over `j` such
that `segment_ids[j] == i`.

If the maximum is empty for a given segment ID `i`, it outputs the smallest possible value for the specific numeric type,
`output[i] = numeric_limits<T>::min()`.

<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../../images/UnsortedSegmentSum.png" alt>
</div>

segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
first dimension.

output: Has same shape as data, except for dimension 0 which
has size `num_segments`.

)doc");
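A quick worked example of these semantics (illustrative values, not from the source): with `data = [5, 1, 7, 2, 3, 4]`, `segment_ids = [0, 0, 1, 1, 0, 1]` and `num_segments = 2`, segment 0 collects {5, 1, 3} and segment 1 collects {7, 2, 4}, so the op returns `output = [5, 7]`. A segment id that never appears would get `numeric_limits<T>::min()` instead.
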
REGISTER_OP("SparseSegmentSum")
|
||||
.Input("data: T")
|
||||
.Input("indices: Tidx")
|
||||
|
@ -25261,6 +25261,59 @@ op {
summary: "Computes the sum along segments of a tensor."
description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nComputes a tensor such that\n`(output[i] = sum_{j...} data[j...]` where the sum is over tuples `j...` such\nthat `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids`\nneed not be sorted and need not cover all values in the full\nrange of valid values.\n\nIf the sum is empty for a given segment ID `i`, `output[i] = 0`.\n\n`num_segments` should equal the number of distinct segment IDs.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/UnsortedSegmentSum.png\" alt>\n</div>"
}
op {
  name: "UnsortedSegmentSum"
  input_arg {
    name: "data"
    type_attr: "T"
  }
  input_arg {
    name: "segment_ids"
    description: "A tensor whose shape is a prefix of `data.shape`."
    type_attr: "Tindices"
  }
  input_arg {
    name: "num_segments"
    type: DT_INT32
  }
  output_arg {
    name: "output"
    description: "Has same shape as data, except for the first `segment_ids.rank`\ndimensions, which are replaced with a single dimension which has size\n`num_segments`."
    type_attr: "T"
  }
  attr {
    name: "T"
    type: "type"
    allowed_values {
      list {
        type: DT_FLOAT
        type: DT_DOUBLE
        type: DT_INT64
        type: DT_INT32
        type: DT_UINT8
        type: DT_UINT16
        type: DT_INT16
        type: DT_INT8
        type: DT_QINT8
        type: DT_QUINT8
        type: DT_QINT32
        type: DT_HALF
      }
    }
  }
  attr {
    name: "Tindices"
    type: "type"
    allowed_values {
      list {
        type: DT_INT32
        type: DT_INT64
      }
    }
  }
  summary: "Computes the max along segments of a tensor."
  description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nComputes a tensor such that\n\\\\(output_i = \\sum_j data_j\\\\) where sum is over `j` such\nthat `segment_ids[j] == i`. Unlike `SegmentSum`, `segment_ids`\nneed not be sorted and need not cover all values in the full\n range of valid values.\n\nIf the sum is empty for a given segment ID `i`, `output[i] = 0`.\n\n`num_segments` should equal the number of distinct segment IDs.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/UnsortedSegmentSum.png\" alt>\n</div>"
}
op {
  name: "Unstage"
  output_arg {
@ -77,14 +77,17 @@ void LogMessage::GenerateLogMessage() {

void LogMessage::GenerateLogMessage() {
static EnvTime* env_time = tensorflow::EnvTime::Default();
time_t now = static_cast<time_t>(env_time->NowSeconds());
uint64 now_micros = env_time->NowMicros();
time_t now_seconds = static_cast<time_t>(now_micros / 1000000);
int32 micros_remainder = static_cast<int32>(now_micros % 1000000);
const size_t time_buffer_size = 30;
char time_buffer[time_buffer_size];
strftime(time_buffer, time_buffer_size, "%Y-%m-%d %H:%M:%S", localtime(&now));
strftime(time_buffer, time_buffer_size, "%Y-%m-%d %H:%M:%S",
         localtime(&now_seconds));

// TODO(jeff,sanjay): Replace this with something that logs through the env.
fprintf(stderr, "%s: %c %s:%d] %s\n", time_buffer, "IWEF"[severity_], fname_,
        line_, str().c_str());
fprintf(stderr, "%s.%06d: %c %s:%d] %s\n", time_buffer, micros_remainder,
        "IWEF"[severity_], fname_, line_, str().c_str());
}
#endif

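With the format string `"%s.%06d: ..."`, a log line now carries microsecond precision; an illustrative (not captured) example of the new output shape: `2017-02-07 10:00:00.123456: I foo.cc:42] message`, where `.123456` is the zero-padded microsecond remainder.
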
@ -18,9 +18,9 @@ limitations under the License.

// TensorFlow uses semantic versioning, see http://semver.org/.

#define TF_MAJOR_VERSION 0
#define TF_MINOR_VERSION 12
#define TF_PATCH_VERSION head
#define TF_MAJOR_VERSION 1
#define TF_MINOR_VERSION 0
#define TF_PATCH_VERSION 0-rc1

// TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
// "-beta", "-rc", "-rc.1")

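Stitched together major.minor.patch, these macros compose the release string `1.0.0-rc1`, replacing the previous development version `0.12.head`.
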
@ -17,7 +17,7 @@

Run using bazel:

bazel run -c opt \
bazel run --config opt \
    <...>/tensorflow/examples/how_tos/reading_data:fully_connected_preloaded

or, if installed via pip:

@ -17,7 +17,7 @@

Run using bazel:

bazel run -c opt \
bazel run --config opt \
    <...>/tensorflow/examples/how_tos/reading_data:fully_connected_preloaded_var

or, if installed via pip:

@ -346,6 +346,17 @@ def read_list_of_floats_from_file(file_path):

bottleneck_path_2_bottleneck_values = {}

def create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
                           image_dir, category, sess, jpeg_data_tensor, bottleneck_tensor):
  print('Creating bottleneck at ' + bottleneck_path)
  image_path = get_image_path(image_lists, label_name, index, image_dir, category)
  if not gfile.Exists(image_path):
    tf.logging.fatal('File does not exist %s', image_path)
  image_data = gfile.FastGFile(image_path, 'rb').read()
  bottleneck_values = run_bottleneck_on_image(sess, image_data, jpeg_data_tensor, bottleneck_tensor)
  bottleneck_string = ','.join(str(x) for x in bottleneck_values)
  with open(bottleneck_path, 'w') as bottleneck_file:
    bottleneck_file.write(bottleneck_string)

def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir,
                             category, bottleneck_dir, jpeg_data_tensor,
@ -376,28 +387,25 @@ def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir,
  sub_dir = label_lists['dir']
  sub_dir_path = os.path.join(bottleneck_dir, sub_dir)
  ensure_dir_exists(sub_dir_path)
  bottleneck_path = get_bottleneck_path(image_lists, label_name, index,
                                        bottleneck_dir, category)
  bottleneck_path = get_bottleneck_path(image_lists, label_name, index, bottleneck_dir, category)
  if not os.path.exists(bottleneck_path):
    print('Creating bottleneck at ' + bottleneck_path)
    image_path = get_image_path(image_lists, label_name, index, image_dir,
                                category)
    if not gfile.Exists(image_path):
      tf.logging.fatal('File does not exist %s', image_path)
    image_data = gfile.FastGFile(image_path, 'rb').read()
    bottleneck_values = run_bottleneck_on_image(sess, image_data,
                                                jpeg_data_tensor,
                                                bottleneck_tensor)
    bottleneck_string = ','.join(str(x) for x in bottleneck_values)
    with open(bottleneck_path, 'w') as bottleneck_file:
      bottleneck_file.write(bottleneck_string)

    create_bottleneck_file(bottleneck_path, image_lists, label_name, index, image_dir, category, sess, jpeg_data_tensor, bottleneck_tensor)
  with open(bottleneck_path, 'r') as bottleneck_file:
    bottleneck_string = bottleneck_file.read()
  bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
  did_hit_error = False
  try:
    bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
  except:
    print("Invalid float found, recreating bottleneck")
    did_hit_error = True
  if did_hit_error:
    create_bottleneck_file(bottleneck_path, image_lists, label_name, index, image_dir, category, sess, jpeg_data_tensor, bottleneck_tensor)
    with open(bottleneck_path, 'r') as bottleneck_file:
      bottleneck_string = bottleneck_file.read()
    # Allow exceptions to propagate here, since they shouldn't happen after a fresh creation
    bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
  return bottleneck_values


def cache_bottlenecks(sess, image_lists, image_dir, bottleneck_dir,
                      jpeg_data_tensor, bottleneck_tensor):
  """Ensures all the training, testing, and validation bottlenecks are cached.
@ -430,6 +438,7 @@ def cache_bottlenecks(sess, image_lists, image_dir, bottleneck_dir,
      get_or_create_bottleneck(sess, image_lists, label_name, index,
                               image_dir, category, bottleneck_dir,
                               jpeg_data_tensor, bottleneck_tensor)

      how_many_bottlenecks += 1
      if how_many_bottlenecks % 100 == 0:
        print(str(how_many_bottlenecks) + ' bottleneck files created.')

@ -32,7 +32,7 @@ normalized from 0-1 in left top right bottom order.
To build it, run this command:

```bash
$ bazel build -c opt tensorflow/examples/multibox_detector/...
$ bazel build --config opt tensorflow/examples/multibox_detector/...
```

That should build a binary executable that you can then run like this:

@ -111,6 +111,7 @@
"source": [
"url = 'http://commondatastorage.googleapis.com/books1000/'\n",
"last_percent_reported = None\n",
"data_root = '.' # Change me to store data elsewhere\n",
"\n",
"def download_progress_hook(count, blockSize, totalSize):\n",
"  \"\"\"A hook to report the progress of a download. This is mostly intended for users with\n",
@ -131,17 +132,18 @@
"  \n",
"def maybe_download(filename, expected_bytes, force=False):\n",
"  \"\"\"Download a file if not present, and make sure it's the right size.\"\"\"\n",
"  if force or not os.path.exists(filename):\n",
"  dest_filename = os.path.join(data_root, filename)\n",
"  if force or not os.path.exists(dest_filename):\n",
"    print('Attempting to download:', filename) \n",
"    filename, _ = urlretrieve(url + filename, filename, reporthook=download_progress_hook)\n",
"    filename, _ = urlretrieve(url + filename, dest_filename, reporthook=download_progress_hook)\n",
"    print('\\nDownload Complete!')\n",
"  statinfo = os.stat(filename)\n",
"  statinfo = os.stat(dest_filename)\n",
"  if statinfo.st_size == expected_bytes:\n",
"    print('Found and verified', filename)\n",
"    print('Found and verified', dest_filename)\n",
"  else:\n",
"    raise Exception(\n",
"      'Failed to verify ' + filename + '. Can you get to it with a browser?')\n",
"  return filename\n",
"      'Failed to verify ' + dest_filename + '. Can you get to it with a browser?')\n",
"  return dest_filename\n",
"\n",
"train_filename = maybe_download('notMNIST_large.tar.gz', 247336696)\n",
"test_filename = maybe_download('notMNIST_small.tar.gz', 8458043)"
@ -683,7 +685,7 @@
"cellView": "both"
},
"source": [
"pickle_file = 'notMNIST.pickle'\n",
"pickle_file = os.path.join(data_root, 'notMNIST.pickle')\n",
"\n",
"try:\n",
"  f = open(pickle_file, 'wb')\n",

@ -1,4 +1,185 @@

- - -

#### `tf.summary.TaggedRunMetadata.ByteSize()` {#TaggedRunMetadata.ByteSize}


- - -

#### `tf.summary.TaggedRunMetadata.Clear()` {#TaggedRunMetadata.Clear}


- - -

#### `tf.summary.TaggedRunMetadata.ClearExtension(extension_handle)` {#TaggedRunMetadata.ClearExtension}


- - -

#### `tf.summary.TaggedRunMetadata.ClearField(field_name)` {#TaggedRunMetadata.ClearField}


- - -

#### `tf.summary.TaggedRunMetadata.CopyFrom(other_msg)` {#TaggedRunMetadata.CopyFrom}

Copies the content of the specified message into the current message.

The method clears the current message and then merges the specified
message using MergeFrom.

##### Args:

* <b>`other_msg`</b>: Message to copy into the current one.


- - -

#### `tf.summary.TaggedRunMetadata.DiscardUnknownFields()` {#TaggedRunMetadata.DiscardUnknownFields}


- - -

#### `tf.summary.TaggedRunMetadata.FindInitializationErrors()` {#TaggedRunMetadata.FindInitializationErrors}

Finds required fields which are not initialized.

##### Returns:

A list of strings. Each string is a path to an uninitialized field from
the top-level message, e.g. "foo.bar[5].baz".


- - -

#### `tf.summary.TaggedRunMetadata.FromString(s)` {#TaggedRunMetadata.FromString}


- - -

#### `tf.summary.TaggedRunMetadata.HasExtension(extension_handle)` {#TaggedRunMetadata.HasExtension}


- - -

#### `tf.summary.TaggedRunMetadata.HasField(field_name)` {#TaggedRunMetadata.HasField}


- - -

#### `tf.summary.TaggedRunMetadata.IsInitialized(errors=None)` {#TaggedRunMetadata.IsInitialized}

Checks if all required fields of a message are set.

##### Args:

* <b>`errors`</b>: A list which, if provided, will be populated with the field
  paths of all missing required fields.

##### Returns:

True iff the specified message has all required fields set.


- - -

#### `tf.summary.TaggedRunMetadata.ListFields()` {#TaggedRunMetadata.ListFields}


- - -

#### `tf.summary.TaggedRunMetadata.MergeFrom(msg)` {#TaggedRunMetadata.MergeFrom}


- - -

#### `tf.summary.TaggedRunMetadata.MergeFromString(serialized)` {#TaggedRunMetadata.MergeFromString}


- - -

#### `tf.summary.TaggedRunMetadata.ParseFromString(serialized)` {#TaggedRunMetadata.ParseFromString}

Parse serialized protocol buffer data into this message.

Like MergeFromString(), except we clear the object first and
do not return the value that MergeFromString returns.


- - -

#### `tf.summary.TaggedRunMetadata.RegisterExtension(extension_handle)` {#TaggedRunMetadata.RegisterExtension}


- - -

#### `tf.summary.TaggedRunMetadata.SerializePartialToString()` {#TaggedRunMetadata.SerializePartialToString}


- - -

#### `tf.summary.TaggedRunMetadata.SerializeToString()` {#TaggedRunMetadata.SerializeToString}


- - -

#### `tf.summary.TaggedRunMetadata.SetInParent()` {#TaggedRunMetadata.SetInParent}

Sets the _cached_byte_size_dirty bit to true,
and propagates this to our listener iff this was a state change.


- - -

#### `tf.summary.TaggedRunMetadata.WhichOneof(oneof_name)` {#TaggedRunMetadata.WhichOneof}

Returns the name of the currently set field inside a oneof, or None.


- - -

#### `tf.summary.TaggedRunMetadata.__deepcopy__(memo=None)` {#TaggedRunMetadata.__deepcopy__}


- - -

#### `tf.summary.TaggedRunMetadata.__eq__(other)` {#TaggedRunMetadata.__eq__}


- - -

#### `tf.summary.TaggedRunMetadata.__getstate__()` {#TaggedRunMetadata.__getstate__}
@ -6,3 +187,66 @@
Support the pickle protocol.


- - -

#### `tf.summary.TaggedRunMetadata.__hash__()` {#TaggedRunMetadata.__hash__}


- - -

#### `tf.summary.TaggedRunMetadata.__init__(**kwargs)` {#TaggedRunMetadata.__init__}


- - -

#### `tf.summary.TaggedRunMetadata.__ne__(other_msg)` {#TaggedRunMetadata.__ne__}


- - -

#### `tf.summary.TaggedRunMetadata.__repr__()` {#TaggedRunMetadata.__repr__}


- - -

#### `tf.summary.TaggedRunMetadata.__setstate__(state)` {#TaggedRunMetadata.__setstate__}

Support the pickle protocol.


- - -

#### `tf.summary.TaggedRunMetadata.__str__()` {#TaggedRunMetadata.__str__}


- - -

#### `tf.summary.TaggedRunMetadata.__unicode__()` {#TaggedRunMetadata.__unicode__}


- - -

#### `tf.summary.TaggedRunMetadata.run_metadata` {#TaggedRunMetadata.run_metadata}

Magic attribute generated for "run_metadata" proto field.


- - -

#### `tf.summary.TaggedRunMetadata.tag` {#TaggedRunMetadata.tag}

Magic attribute generated for "tag" proto field.

@ -1,4 +1,185 @@

- - -

#### `tf.summary.SummaryDescription.ByteSize()` {#SummaryDescription.ByteSize}


- - -

#### `tf.summary.SummaryDescription.Clear()` {#SummaryDescription.Clear}


- - -

#### `tf.summary.SummaryDescription.ClearExtension(extension_handle)` {#SummaryDescription.ClearExtension}


- - -

#### `tf.summary.SummaryDescription.ClearField(field_name)` {#SummaryDescription.ClearField}


- - -

#### `tf.summary.SummaryDescription.CopyFrom(other_msg)` {#SummaryDescription.CopyFrom}

Copies the content of the specified message into the current message.

The method clears the current message and then merges the specified
message using MergeFrom.

##### Args:

* <b>`other_msg`</b>: Message to copy into the current one.


- - -

#### `tf.summary.SummaryDescription.DiscardUnknownFields()` {#SummaryDescription.DiscardUnknownFields}


- - -

#### `tf.summary.SummaryDescription.FindInitializationErrors()` {#SummaryDescription.FindInitializationErrors}

Finds required fields which are not initialized.

##### Returns:

A list of strings. Each string is a path to an uninitialized field from
the top-level message, e.g. "foo.bar[5].baz".


- - -

#### `tf.summary.SummaryDescription.FromString(s)` {#SummaryDescription.FromString}


- - -

#### `tf.summary.SummaryDescription.HasExtension(extension_handle)` {#SummaryDescription.HasExtension}


- - -

#### `tf.summary.SummaryDescription.HasField(field_name)` {#SummaryDescription.HasField}


- - -

#### `tf.summary.SummaryDescription.IsInitialized(errors=None)` {#SummaryDescription.IsInitialized}

Checks if all required fields of a message are set.

##### Args:

* <b>`errors`</b>: A list which, if provided, will be populated with the field
  paths of all missing required fields.

##### Returns:

True iff the specified message has all required fields set.


- - -

#### `tf.summary.SummaryDescription.ListFields()` {#SummaryDescription.ListFields}


- - -

#### `tf.summary.SummaryDescription.MergeFrom(msg)` {#SummaryDescription.MergeFrom}


- - -

#### `tf.summary.SummaryDescription.MergeFromString(serialized)` {#SummaryDescription.MergeFromString}


- - -

#### `tf.summary.SummaryDescription.ParseFromString(serialized)` {#SummaryDescription.ParseFromString}

Parse serialized protocol buffer data into this message.

Like MergeFromString(), except we clear the object first and
do not return the value that MergeFromString returns.


- - -

#### `tf.summary.SummaryDescription.RegisterExtension(extension_handle)` {#SummaryDescription.RegisterExtension}


- - -

#### `tf.summary.SummaryDescription.SerializePartialToString()` {#SummaryDescription.SerializePartialToString}


- - -

#### `tf.summary.SummaryDescription.SerializeToString()` {#SummaryDescription.SerializeToString}


- - -

#### `tf.summary.SummaryDescription.SetInParent()` {#SummaryDescription.SetInParent}

Sets the _cached_byte_size_dirty bit to true,
and propagates this to our listener iff this was a state change.


- - -

#### `tf.summary.SummaryDescription.WhichOneof(oneof_name)` {#SummaryDescription.WhichOneof}

Returns the name of the currently set field inside a oneof, or None.


- - -

#### `tf.summary.SummaryDescription.__deepcopy__(memo=None)` {#SummaryDescription.__deepcopy__}


- - -

#### `tf.summary.SummaryDescription.__eq__(other)` {#SummaryDescription.__eq__}


- - -

#### `tf.summary.SummaryDescription.__getstate__()` {#SummaryDescription.__getstate__}
@ -6,3 +187,59 @@
Support the pickle protocol.


- - -

#### `tf.summary.SummaryDescription.__hash__()` {#SummaryDescription.__hash__}


- - -

#### `tf.summary.SummaryDescription.__init__(**kwargs)` {#SummaryDescription.__init__}


- - -

#### `tf.summary.SummaryDescription.__ne__(other_msg)` {#SummaryDescription.__ne__}


- - -

#### `tf.summary.SummaryDescription.__repr__()` {#SummaryDescription.__repr__}


- - -

#### `tf.summary.SummaryDescription.__setstate__(state)` {#SummaryDescription.__setstate__}

Support the pickle protocol.


- - -

#### `tf.summary.SummaryDescription.__str__()` {#SummaryDescription.__str__}


- - -

#### `tf.summary.SummaryDescription.__unicode__()` {#SummaryDescription.__unicode__}


- - -

#### `tf.summary.SummaryDescription.type_hint` {#SummaryDescription.type_hint}

Magic attribute generated for "type_hint" proto field.

@ -173,125 +173,6 @@ Checks that for all elements of farray1 and farray2
|
||||
* <b>`err`</b>: a float value.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertBetween(value, minv, maxv, msg=None)` {#TestCase.assertBetween}
|
||||
|
||||
Asserts that value is between minv and maxv (inclusive).
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertCommandFails(command, regexes, env=None, close_fds=True, msg=None)` {#TestCase.assertCommandFails}
|
||||
|
||||
Asserts a shell command fails and the error matches a regex in a list.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`command`</b>: List or string representing the command to run.
|
||||
* <b>`regexes`</b>: the list of regular expression strings.
|
||||
* <b>`env`</b>: Dictionary of environment variable settings.
|
||||
* <b>`close_fds`</b>: Whether or not to close all open fd's in the child after
|
||||
forking.
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertCommandSucceeds(command, regexes=('',), env=None, close_fds=True, msg=None)` {#TestCase.assertCommandSucceeds}
|
||||
|
||||
Asserts that a shell command succeeds (i.e. exits with code 0).
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`command`</b>: List or string representing the command to run.
|
||||
* <b>`regexes`</b>: List of regular expression byte strings that match success.
|
||||
* <b>`env`</b>: Dictionary of environment variable settings.
|
||||
* <b>`close_fds`</b>: Whether or not to close all open fd's in the child after
|
||||
forking.
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertContainsExactSubsequence(container, subsequence, msg=None)` {#TestCase.assertContainsExactSubsequence}
|
||||
|
||||
Assert that "container" contains "subsequence" as an exact subsequence.
|
||||
|
||||
Asserts that "container" contains all the elements of "subsequence", in
|
||||
order, and without other elements interspersed. For example, [1, 2, 3] is an
|
||||
exact subsequence of [0, 0, 1, 2, 3, 0] but not of [0, 0, 1, 2, 0, 3, 0].
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`container`</b>: the list we're testing for subsequence inclusion.
|
||||
* <b>`subsequence`</b>: the list we hope will be an exact subsequence of container.
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertContainsInOrder(strings, target, msg=None)` {#TestCase.assertContainsInOrder}
|
||||
|
||||
Asserts that the strings provided are found in the target in order.
|
||||
|
||||
This may be useful for checking HTML output.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`strings`</b>: A list of strings, such as [ 'fox', 'dog' ]
|
||||
* <b>`target`</b>: A target string in which to look for the strings, such as
|
||||
'The quick brown fox jumped over the lazy dog'.
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertContainsSubsequence(container, subsequence, msg=None)` {#TestCase.assertContainsSubsequence}
|
||||
|
||||
Assert that "container" contains "subsequence" as a subsequence.
|
||||
|
||||
Asserts that "container" contains all the elements of "subsequence", in
|
||||
order, but possibly with other elements interspersed. For example, [1, 2, 3]
|
||||
is a subsequence of [0, 0, 1, 2, 0, 3, 0] but not of [0, 0, 1, 3, 0, 2, 0].
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`container`</b>: the list we're testing for subsequence inclusion.
|
||||
* <b>`subsequence`</b>: the list we hope will be a subsequence of container.
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertContainsSubset(expected_subset, actual_set, msg=None)` {#TestCase.assertContainsSubset}
|
||||
|
||||
Checks whether actual iterable is a superset of expected iterable.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertCountEqual(*args, **kwargs)` {#TestCase.assertCountEqual}
|
||||
|
||||
An unordered sequence specific comparison.
|
||||
|
||||
Equivalent to assertItemsEqual(). This method is a compatibility layer
|
||||
for Python 3k, since 2to3 does not convert assertItemsEqual() calls into
|
||||
assertCountEqual() calls.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`expected_seq`</b>: A sequence containing elements we are expecting.
|
||||
* <b>`actual_seq`</b>: The sequence that we are testing.
|
||||
* <b>`msg`</b>: The message to be printed if the test fails.
|
||||
|
||||
|
||||
- - -

#### `tf.test.TestCase.assertDeviceEqual(device1, device2)` {#TestCase.assertDeviceEqual}

@ -314,48 +195,9 @@ Checks whether actual is a superset of expected.

- - -

#### `tf.test.TestCase.assertDictEqual(a, b, msg=None)` {#TestCase.assertDictEqual}
#### `tf.test.TestCase.assertDictEqual(d1, d2, msg=None)` {#TestCase.assertDictEqual}

Raises AssertionError if a and b are not equal dictionaries.

##### Args:


* <b>`a`</b>: A dict, the expected value.
* <b>`b`</b>: A dict, the actual value.
* <b>`msg`</b>: An optional str, the associated message.

##### Raises:


* <b>`AssertionError`</b>: if the dictionaries are not equal.


- - -

#### `tf.test.TestCase.assertEmpty(container, msg=None)` {#TestCase.assertEmpty}

Assert that an object has zero length.

##### Args:


* <b>`container`</b>: Anything that implements the collections.Sized interface.
* <b>`msg`</b>: Optional message to report on failure.


- - -

#### `tf.test.TestCase.assertEndsWith(actual, expected_end, msg=None)` {#TestCase.assertEndsWith}

Assert that actual.endswith(expected_end) is True.

##### Args:


* <b>`actual`</b>: str
* <b>`expected_end`</b>: str
* <b>`msg`</b>: Optional message to report on failure.


- - -

@ -440,11 +282,10 @@ Included for symmetry with assertIsNone.

- - -

#### `tf.test.TestCase.assertItemsEqual(*args, **kwargs)` {#TestCase.assertItemsEqual}
#### `tf.test.TestCase.assertItemsEqual(expected_seq, actual_seq, msg=None)` {#TestCase.assertItemsEqual}

An unordered sequence specific comparison.

It asserts that actual_seq and expected_seq have the same element counts.
An unordered sequence specific comparison. It asserts that
actual_seq and expected_seq have the same element counts.
Equivalent to::

    self.assertEqual(Counter(iter(actual_seq)),

@ -457,30 +298,6 @@ Asserts that each element has the same count in both sequences.

- [0, 1, 1] and [1, 0, 1] compare equal.
- [0, 0, 1] and [0, 1] compare unequal.

##### Args:


* <b>`expected_seq`</b>: A sequence containing elements we are expecting.
* <b>`actual_seq`</b>: The sequence that we are testing.
* <b>`msg`</b>: The message to be printed if the test fails.


- - -

#### `tf.test.TestCase.assertJsonEqual(first, second, msg=None)` {#TestCase.assertJsonEqual}

Asserts that the JSON objects defined in two strings are equal.

A summary of the differences will be included in the failure message
using assertSameStructure.

##### Args:


* <b>`first`</b>: A string containing JSON to decode and compare to second.
* <b>`second`</b>: A string containing JSON to decode and compare to first.
* <b>`msg`</b>: Additional text to include in the failure message.
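As a minimal sketch, key order and whitespace in the JSON text are irrelevant (the payload is invented):

```python
self.assertJsonEqual('{"a": 1, "b": [2, 3]}',
                     '{ "b": [2, 3], "a": 1 }')
```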
- - -

@ -550,13 +367,6 @@ if not.

* <b>`msg`</b>: An optional string message to append to the failure message.


- - -

#### `tf.test.TestCase.assertNoCommonElements(expected_seq, actual_seq, msg=None)` {#TestCase.assertNoCommonElements}

Checks whether actual iterable and expected iterable are disjoint.

- - -

#### `tf.test.TestCase.assertNotAlmostEqual(first, second, places=None, msg=None, delta=None)` {#TestCase.assertNotAlmostEqual}

@ -587,33 +397,6 @@ as significant digits (measured from the most significant digit).

Objects that are equal automatically fail.

- - -

#### `tf.test.TestCase.assertNotEmpty(container, msg=None)` {#TestCase.assertNotEmpty}

Assert that an object has non-zero length.

##### Args:


* <b>`container`</b>: Anything that implements the collections.Sized interface.
* <b>`msg`</b>: Optional message to report on failure.


- - -

#### `tf.test.TestCase.assertNotEndsWith(actual, unexpected_end, msg=None)` {#TestCase.assertNotEndsWith}

Assert that actual.endswith(unexpected_end) is False.

##### Args:


* <b>`actual`</b>: str
* <b>`unexpected_end`</b>: str
* <b>`msg`</b>: Optional message to report on failure.


- - -

#### `tf.test.TestCase.assertNotEqual(first, second, msg=None)` {#TestCase.assertNotEqual}

@ -651,20 +434,6 @@ Included for symmetry with assertIsInstance.

Fail the test if the text matches the regular expression.

- - -

#### `tf.test.TestCase.assertNotStartsWith(actual, unexpected_start, msg=None)` {#TestCase.assertNotStartsWith}

Assert that actual.startswith(unexpected_start) is False.

##### Args:


* <b>`actual`</b>: str
* <b>`unexpected_start`</b>: str
* <b>`msg`</b>: Optional message to report on failure.


- - -

#### `tf.test.TestCase.assertProtoEquals(expected_message_maybe_ascii, message)` {#TestCase.assertProtoEquals}

@ -739,38 +508,6 @@ Asserts that the message in a raised exception matches a regexp.

* <b>`kwargs`</b>: Extra kwargs.


- - -

#### `tf.test.TestCase.assertRaisesWithLiteralMatch(expected_exception, expected_exception_message, callable_obj=None, *args, **kwargs)` {#TestCase.assertRaisesWithLiteralMatch}

Asserts that the message in a raised exception equals the given string.

Unlike assertRaisesRegexp, this method takes a literal string, not
a regular expression.

    with self.assertRaisesWithLiteralMatch(ExType, 'message'):
      DoSomething()

##### Args:


* <b>`expected_exception`</b>: Exception class expected to be raised.
* <b>`expected_exception_message`</b>: String message expected in the raised
  exception. For a raised exception e, expected_exception_message must
  equal str(e).
* <b>`callable_obj`</b>: Function to be called, or None to return a context.
* <b>`args`</b>: Extra args.
* <b>`kwargs`</b>: Extra kwargs.

##### Returns:

A context manager if callable_obj is None. Otherwise, None.

##### Raises:

self.failureException if callable_obj does not raise a matching exception.
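A concrete sketch of the context-manager form above (the exception type and message are invented for illustration):

```python
with self.assertRaisesWithLiteralMatch(ValueError, 'bad value: 7'):
  raise ValueError('bad value: 7')  # str(e) must equal the literal exactly.
```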
- - -

#### `tf.test.TestCase.assertRaisesWithPredicateMatch(exception_type, expected_err_re_or_predicate)` {#TestCase.assertRaisesWithPredicateMatch}

@ -795,71 +532,6 @@ predicate search.

exception.

- - -

#### `tf.test.TestCase.assertRaisesWithRegexpMatch(expected_exception, expected_regexp, callable_obj=None, *args, **kwargs)` {#TestCase.assertRaisesWithRegexpMatch}

Asserts that the message in a raised exception matches the given regexp.

This is just a wrapper around assertRaisesRegexp. Please use
assertRaisesRegexp instead of assertRaisesWithRegexpMatch.

##### Args:


* <b>`expected_exception`</b>: Exception class expected to be raised.
* <b>`expected_regexp`</b>: Regexp (re pattern object or string) expected to be
  found in error message.
* <b>`callable_obj`</b>: Function to be called, or None to return a context.
* <b>`args`</b>: Extra args.
* <b>`kwargs`</b>: Extra keyword args.

##### Returns:

A context manager if callable_obj is None. Otherwise, None.

##### Raises:

self.failureException if callable_obj does not raise a matching exception.

- - -

#### `tf.test.TestCase.assertRegexMatch(actual_str, regexes, message=None)` {#TestCase.assertRegexMatch}

Asserts that at least one regex in regexes matches str.

If possible you should use assertRegexpMatches, which is a simpler
version of this method. assertRegexpMatches takes a single regular
expression (a string or re compiled object) instead of a list.

Notes:

1. This function uses substring matching, i.e. the matching
   succeeds if *any* substring of the error message matches *any*
   regex in the list. This is more convenient for the user than
   full-string matching.

2. If regexes is the empty list, the matching will always fail.

3. Use regexes=[''] for a regex that will always pass.

4. '.' matches any single character *except* the newline. To
   match any character, use '(.|\n)'.

5. '^' matches the beginning of each line, not just the beginning
   of the string. Similarly, '$' matches the end of each line.

6. An exception will be thrown if regexes contains an invalid
   regex.

Args:

  actual_str: The string we try to match with the items in regexes.
  regexes: The regular expressions we want to match against str.
    See "Notes" above for detailed notes on how this is interpreted.
  message: The message to be printed if the test fails.
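A brief sketch of the substring semantics from note 1 (the strings are invented):

```python
# Passes: some substring of actual_str matches the second regex.
self.assertRegexMatch('deadline exceeded', ['quota', 'deadline'])
```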
- - -

#### `tf.test.TestCase.assertRegexpMatches(text, expected_regexp, msg=None)` {#TestCase.assertRegexpMatches}

@ -867,79 +539,6 @@ Asserts that at least one regex in regexes matches str.

Fail the test unless the text matches the regular expression.

- - -

#### `tf.test.TestCase.assertSameElements(expected_seq, actual_seq, msg=None)` {#TestCase.assertSameElements}

Assert that two sequences have the same elements (in any order).

This method, unlike assertItemsEqual, doesn't care about any
duplicates in the expected and actual sequences.

    >> assertSameElements([1, 1, 1, 0, 0, 0], [0, 1])
    # Doesn't raise an AssertionError

If possible, you should use assertItemsEqual instead of
assertSameElements.

##### Args:


* <b>`expected_seq`</b>: A sequence containing elements we are expecting.
* <b>`actual_seq`</b>: The sequence that we are testing.
* <b>`msg`</b>: The message to be printed if the test fails.


- - -

#### `tf.test.TestCase.assertSameStructure(a, b, aname='a', bname='b', msg=None)` {#TestCase.assertSameStructure}

Asserts that two values contain the same structural content.

The two arguments should be data trees consisting of trees of dicts and
lists. They will be deeply compared by walking into the contents of dicts
and lists; other items will be compared using the == operator.
If the two structures differ in content, the failure message will indicate
the location within the structures where the first difference is found.
This may be helpful when comparing large structures.

##### Args:


* <b>`a`</b>: The first structure to compare.
* <b>`b`</b>: The second structure to compare.
* <b>`aname`</b>: Variable name to use for the first structure in assertion messages.
* <b>`bname`</b>: Variable name to use for the second structure.
* <b>`msg`</b>: Additional text to include in the failure message.
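A minimal sketch of the deep comparison and the kind of mismatch it pinpoints (values are invented):

```python
# Passes: identical nested structure and leaf values.
self.assertSameStructure({'a': [1, {'b': 2}]}, {'a': [1, {'b': 2}]})
# Fails, and the message names the path to the first difference, e.g. a[1]['b'].
self.assertSameStructure({'a': [1, {'b': 2}]}, {'a': [1, {'b': 3}]})
```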
- - -

#### `tf.test.TestCase.assertSequenceAlmostEqual(expected_seq, actual_seq, places=None, msg=None, delta=None)` {#TestCase.assertSequenceAlmostEqual}

An approximate equality assertion for ordered sequences.

Fail if the two sequences are unequal as determined by their value
differences rounded to the given number of decimal places (default 7) and
comparing to zero, or by comparing that the difference between each value
in the two sequences is more than the given delta.

Note that decimal places (from zero) are usually not the same as significant
digits (measured from the most significant digit).

If the two sequences compare equal then they will automatically compare
almost equal.

##### Args:


* <b>`expected_seq`</b>: A sequence containing elements we are expecting.
* <b>`actual_seq`</b>: The sequence that we are testing.
* <b>`places`</b>: The number of decimal places to compare.
* <b>`msg`</b>: The message to be printed if the test fails.
* <b>`delta`</b>: The OK difference between compared values.
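For instance, a delta-based sketch (the numbers are chosen arbitrarily):

```python
# Passes: each pairwise difference is within delta.
self.assertSequenceAlmostEqual([1.0, 2.0], [1.0004, 1.9996], delta=1e-3)
```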
- - -

#### `tf.test.TestCase.assertSequenceEqual(seq1, seq2, msg=None, seq_type=None)` {#TestCase.assertSequenceEqual}

@ -960,26 +559,6 @@ which can be indexed, has a length, and has an equality operator.

differences.

- - -

#### `tf.test.TestCase.assertSequenceStartsWith(prefix, whole, msg=None)` {#TestCase.assertSequenceStartsWith}

An equality assertion for the beginning of ordered sequences.

If prefix is an empty sequence, it will raise an error unless whole is also
an empty sequence.

If prefix is not a sequence, it will raise an error if the first element of
whole does not match.

##### Args:


* <b>`prefix`</b>: A sequence expected at the beginning of the whole parameter.
* <b>`whole`</b>: The sequence in which to look for prefix.
* <b>`msg`</b>: Optional message to report on failure.
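A short sketch covering the ordinary and empty-prefix cases described above:

```python
self.assertSequenceStartsWith(['a', 'b'], ['a', 'b', 'c'])  # Passes.
self.assertSequenceStartsWith([], [])  # Passes; an empty prefix only matches an empty whole.
```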
- - -

#### `tf.test.TestCase.assertSetEqual(set1, set2, msg=None)` {#TestCase.assertSetEqual}

@ -1031,51 +610,6 @@ Assert that actual.startswith(expected_start) is True.

* <b>`msg`</b>: Optional message to report on failure.


- - -

#### `tf.test.TestCase.assertTotallyOrdered(*groups, **kwargs)` {#TestCase.assertTotallyOrdered}

Asserts that total ordering has been implemented correctly.

For example, say you have a class A that compares only on its attribute x.
Comparators other than __lt__ are omitted for brevity.

    class A(object):
      def __init__(self, x, y):
        self.x = x
        self.y = y

      def __hash__(self):
        return hash(self.x)

      def __lt__(self, other):
        try:
          return self.x < other.x
        except AttributeError:
          return NotImplemented

assertTotallyOrdered will check that instances can be ordered correctly.
For example,

    self.assertTotallyOrdered(
      [None],  # None should come before everything else.
      [1],  # Integers sort earlier.
      [A(1, 'a')],
      [A(2, 'b')],  # 2 is after 1.
      [A(3, 'c'), A(3, 'd')],  # The second argument is irrelevant.
      [A(4, 'z')],
      ['foo'])  # Strings sort last.

##### Args:


* <b>`*groups`</b>: A list of groups of elements. Each group of elements is a list
  of objects that are equal. The elements in each group must be less than
  the elements in the group after it. For example, these groups are
  totally ordered: [None], [1], [2, 2], [3].
* <b>`**kwargs`</b>: optional msg keyword argument can be passed.


- - -

#### `tf.test.TestCase.assertTrue(expr, msg=None)` {#TestCase.assertTrue}

@ -1098,13 +632,6 @@ A tuple-specific equality assertion.

differences.

- - -

#### `tf.test.TestCase.assertUrlEqual(a, b, msg=None)` {#TestCase.assertUrlEqual}

Asserts that urls are equal, ignoring ordering of query params.

- - -

#### `tf.test.TestCase.assert_(expr, msg=None)` {#TestCase.assert_}

@ -1166,9 +693,9 @@ tearDown.

- - -

#### `tf.test.TestCase.fail(msg=None, prefix=None)` {#TestCase.fail}
#### `tf.test.TestCase.fail(msg=None)` {#TestCase.fail}

Fail immediately with the given message, optionally prefixed.
Fail immediately, with the given message.

- - -

@ -1220,13 +747,6 @@ Fail immediately with the given message, optionally prefixed.

- - -

#### `tf.test.TestCase.getRecordedProperties()` {#TestCase.getRecordedProperties}

Return any properties that the user has recorded.

- - -

#### `tf.test.TestCase.get_temp_dir()` {#TestCase.get_temp_dir}

@ -1249,20 +769,6 @@ pollute each others environment.

- - -

#### `tf.test.TestCase.recordProperty(property_name, property_value)` {#TestCase.recordProperty}

Record an arbitrary property for later use.

##### Args:


* <b>`property_name`</b>: str, name of property to record; must be a valid XML
  attribute name
* <b>`property_value`</b>: value of property; must be valid XML attribute value


- - -

#### `tf.test.TestCase.run(result=None)` {#TestCase.run}

@ -1288,18 +794,11 @@ Hook method for setting up class fixture before running tests in the class.

#### `tf.test.TestCase.shortDescription()` {#TestCase.shortDescription}

Format both the test method name and the first line of its docstring.
Returns a one-line description of the test, or None if no
description has been provided.

If no docstring is given, only returns the method name.

This method overrides unittest.TestCase.shortDescription(), which
only returns the first line of the docstring, obscuring the name
of the test upon failure.

##### Returns:


* <b>`desc`</b>: A short description of a test method.
The default implementation of this method returns the first line of
the specified test method's docstring.


- - -
@ -0,0 +1,4 @@

#### `tf.summary.SummaryDescription.RegisterExtension(extension_handle)` {#SummaryDescription.RegisterExtension}

@ -0,0 +1,4 @@

#### `tf.summary.SummaryDescription.FromString(s)` {#SummaryDescription.FromString}

@ -0,0 +1,4 @@

#### `tf.summary.TaggedRunMetadata.RegisterExtension(extension_handle)` {#TaggedRunMetadata.RegisterExtension}

@ -0,0 +1,4 @@

#### `tf.summary.TaggedRunMetadata.FromString(s)` {#TaggedRunMetadata.FromString}

@ -485,6 +485,187 @@ metadata is stored in its NodeDef. This method retrieves the description.

### `class tf.summary.SummaryDescription` {#SummaryDescription}

- - -

#### `tf.summary.SummaryDescription.ByteSize()` {#SummaryDescription.ByteSize}

- - -

#### `tf.summary.SummaryDescription.Clear()` {#SummaryDescription.Clear}

- - -

#### `tf.summary.SummaryDescription.ClearExtension(extension_handle)` {#SummaryDescription.ClearExtension}

- - -

#### `tf.summary.SummaryDescription.ClearField(field_name)` {#SummaryDescription.ClearField}

- - -

#### `tf.summary.SummaryDescription.CopyFrom(other_msg)` {#SummaryDescription.CopyFrom}

Copies the content of the specified message into the current message.

The method clears the current message and then merges the specified
message using MergeFrom.

##### Args:


* <b>`other_msg`</b>: Message to copy into the current one.


- - -

#### `tf.summary.SummaryDescription.DiscardUnknownFields()` {#SummaryDescription.DiscardUnknownFields}

- - -

#### `tf.summary.SummaryDescription.FindInitializationErrors()` {#SummaryDescription.FindInitializationErrors}

Finds required fields which are not initialized.

##### Returns:

A list of strings. Each string is a path to an uninitialized field from
the top-level message, e.g. "foo.bar[5].baz".

- - -

#### `tf.summary.SummaryDescription.FromString(s)` {#SummaryDescription.FromString}

- - -

#### `tf.summary.SummaryDescription.HasExtension(extension_handle)` {#SummaryDescription.HasExtension}

- - -

#### `tf.summary.SummaryDescription.HasField(field_name)` {#SummaryDescription.HasField}

- - -

#### `tf.summary.SummaryDescription.IsInitialized(errors=None)` {#SummaryDescription.IsInitialized}

Checks if all required fields of a message are set.

##### Args:


* <b>`errors`</b>: A list which, if provided, will be populated with the field
  paths of all missing required fields.

##### Returns:

True iff the specified message has all required fields set.

- - -

#### `tf.summary.SummaryDescription.ListFields()` {#SummaryDescription.ListFields}

- - -

#### `tf.summary.SummaryDescription.MergeFrom(msg)` {#SummaryDescription.MergeFrom}

- - -

#### `tf.summary.SummaryDescription.MergeFromString(serialized)` {#SummaryDescription.MergeFromString}

- - -

#### `tf.summary.SummaryDescription.ParseFromString(serialized)` {#SummaryDescription.ParseFromString}

Parse serialized protocol buffer data into this message.

Like MergeFromString(), except we clear the object first and
do not return the value that MergeFromString returns.
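Since these are the standard protocol buffer message methods, a serialization round trip is the natural usage sketch (the `type_hint` value is invented for illustration):

```python
import tensorflow as tf

# Construct, serialize to bytes, and re-parse a SummaryDescription message.
desc = tf.summary.SummaryDescription(type_hint='scalar')
data = desc.SerializeToString()
same = tf.summary.SummaryDescription.FromString(data)
assert same.type_hint == 'scalar'
```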
- - -

#### `tf.summary.SummaryDescription.RegisterExtension(extension_handle)` {#SummaryDescription.RegisterExtension}

- - -

#### `tf.summary.SummaryDescription.SerializePartialToString()` {#SummaryDescription.SerializePartialToString}

- - -

#### `tf.summary.SummaryDescription.SerializeToString()` {#SummaryDescription.SerializeToString}

- - -

#### `tf.summary.SummaryDescription.SetInParent()` {#SummaryDescription.SetInParent}

Sets the _cached_byte_size_dirty bit to true,
and propagates this to our listener iff this was a state change.

- - -

#### `tf.summary.SummaryDescription.WhichOneof(oneof_name)` {#SummaryDescription.WhichOneof}

Returns the name of the currently set field inside a oneof, or None.

- - -

#### `tf.summary.SummaryDescription.__deepcopy__(memo=None)` {#SummaryDescription.__deepcopy__}

- - -

#### `tf.summary.SummaryDescription.__eq__(other)` {#SummaryDescription.__eq__}

- - -

#### `tf.summary.SummaryDescription.__getstate__()` {#SummaryDescription.__getstate__}

@ -492,12 +673,249 @@ metadata is stored in its NodeDef. This method retrieves the description.

Support the pickle protocol.

- - -

#### `tf.summary.SummaryDescription.__hash__()` {#SummaryDescription.__hash__}

- - -

#### `tf.summary.SummaryDescription.__init__(**kwargs)` {#SummaryDescription.__init__}

- - -

#### `tf.summary.SummaryDescription.__ne__(other_msg)` {#SummaryDescription.__ne__}

- - -

#### `tf.summary.SummaryDescription.__repr__()` {#SummaryDescription.__repr__}

- - -

#### `tf.summary.SummaryDescription.__setstate__(state)` {#SummaryDescription.__setstate__}

Support the pickle protocol.

- - -

#### `tf.summary.SummaryDescription.__str__()` {#SummaryDescription.__str__}

- - -

#### `tf.summary.SummaryDescription.__unicode__()` {#SummaryDescription.__unicode__}

- - -

#### `tf.summary.SummaryDescription.type_hint` {#SummaryDescription.type_hint}

Magic attribute generated for "type_hint" proto field.

- - -

### `class tf.summary.TaggedRunMetadata` {#TaggedRunMetadata}

- - -

#### `tf.summary.TaggedRunMetadata.ByteSize()` {#TaggedRunMetadata.ByteSize}

- - -

#### `tf.summary.TaggedRunMetadata.Clear()` {#TaggedRunMetadata.Clear}

- - -

#### `tf.summary.TaggedRunMetadata.ClearExtension(extension_handle)` {#TaggedRunMetadata.ClearExtension}

- - -

#### `tf.summary.TaggedRunMetadata.ClearField(field_name)` {#TaggedRunMetadata.ClearField}

- - -

#### `tf.summary.TaggedRunMetadata.CopyFrom(other_msg)` {#TaggedRunMetadata.CopyFrom}

Copies the content of the specified message into the current message.

The method clears the current message and then merges the specified
message using MergeFrom.

##### Args:


* <b>`other_msg`</b>: Message to copy into the current one.


- - -

#### `tf.summary.TaggedRunMetadata.DiscardUnknownFields()` {#TaggedRunMetadata.DiscardUnknownFields}

- - -

#### `tf.summary.TaggedRunMetadata.FindInitializationErrors()` {#TaggedRunMetadata.FindInitializationErrors}

Finds required fields which are not initialized.

##### Returns:

A list of strings. Each string is a path to an uninitialized field from
the top-level message, e.g. "foo.bar[5].baz".

- - -

#### `tf.summary.TaggedRunMetadata.FromString(s)` {#TaggedRunMetadata.FromString}

- - -

#### `tf.summary.TaggedRunMetadata.HasExtension(extension_handle)` {#TaggedRunMetadata.HasExtension}

- - -

#### `tf.summary.TaggedRunMetadata.HasField(field_name)` {#TaggedRunMetadata.HasField}

- - -

#### `tf.summary.TaggedRunMetadata.IsInitialized(errors=None)` {#TaggedRunMetadata.IsInitialized}

Checks if all required fields of a message are set.

##### Args:


* <b>`errors`</b>: A list which, if provided, will be populated with the field
  paths of all missing required fields.

##### Returns:

True iff the specified message has all required fields set.

- - -

#### `tf.summary.TaggedRunMetadata.ListFields()` {#TaggedRunMetadata.ListFields}

- - -

#### `tf.summary.TaggedRunMetadata.MergeFrom(msg)` {#TaggedRunMetadata.MergeFrom}

- - -

#### `tf.summary.TaggedRunMetadata.MergeFromString(serialized)` {#TaggedRunMetadata.MergeFromString}

- - -

#### `tf.summary.TaggedRunMetadata.ParseFromString(serialized)` {#TaggedRunMetadata.ParseFromString}

Parse serialized protocol buffer data into this message.

Like MergeFromString(), except we clear the object first and
do not return the value that MergeFromString returns.

- - -

#### `tf.summary.TaggedRunMetadata.RegisterExtension(extension_handle)` {#TaggedRunMetadata.RegisterExtension}

- - -

#### `tf.summary.TaggedRunMetadata.SerializePartialToString()` {#TaggedRunMetadata.SerializePartialToString}

- - -

#### `tf.summary.TaggedRunMetadata.SerializeToString()` {#TaggedRunMetadata.SerializeToString}

- - -

#### `tf.summary.TaggedRunMetadata.SetInParent()` {#TaggedRunMetadata.SetInParent}

Sets the _cached_byte_size_dirty bit to true,
and propagates this to our listener iff this was a state change.

- - -

#### `tf.summary.TaggedRunMetadata.WhichOneof(oneof_name)` {#TaggedRunMetadata.WhichOneof}

Returns the name of the currently set field inside a oneof, or None.

- - -

#### `tf.summary.TaggedRunMetadata.__deepcopy__(memo=None)` {#TaggedRunMetadata.__deepcopy__}

- - -

#### `tf.summary.TaggedRunMetadata.__eq__(other)` {#TaggedRunMetadata.__eq__}

- - -

#### `tf.summary.TaggedRunMetadata.__getstate__()` {#TaggedRunMetadata.__getstate__}

@ -505,4 +923,67 @@ Support the pickle protocol.

Support the pickle protocol.

- - -

#### `tf.summary.TaggedRunMetadata.__hash__()` {#TaggedRunMetadata.__hash__}

- - -

#### `tf.summary.TaggedRunMetadata.__init__(**kwargs)` {#TaggedRunMetadata.__init__}

- - -

#### `tf.summary.TaggedRunMetadata.__ne__(other_msg)` {#TaggedRunMetadata.__ne__}

- - -

#### `tf.summary.TaggedRunMetadata.__repr__()` {#TaggedRunMetadata.__repr__}

- - -

#### `tf.summary.TaggedRunMetadata.__setstate__(state)` {#TaggedRunMetadata.__setstate__}

Support the pickle protocol.

- - -

#### `tf.summary.TaggedRunMetadata.__str__()` {#TaggedRunMetadata.__str__}

- - -

#### `tf.summary.TaggedRunMetadata.__unicode__()` {#TaggedRunMetadata.__unicode__}

- - -

#### `tf.summary.TaggedRunMetadata.run_metadata` {#TaggedRunMetadata.run_metadata}

Magic attribute generated for "run_metadata" proto field.

- - -

#### `tf.summary.TaggedRunMetadata.tag` {#TaggedRunMetadata.tag}

Magic attribute generated for "tag" proto field.
@ -78,37 +78,51 @@ If the above commands do not work on your system, you can follow these instructions:

```bash
# Ubuntu/Linux 64-bit, CPU only, Python 2.7
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp27-none-linux_x86_64.whl
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp27-none-linux_x86_64.whl

# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp27-none-linux_x86_64.whl
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp27-none-linux_x86_64.whl

# Mac OS X, CPU only, Python 2.7:
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.12.1-py2-none-any.whl
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.0.0rc1-py2-none-any.whl

# Mac OS X, GPU enabled, Python 2.7:
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-0.12.1-py2-none-any.whl
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-1.0.0rc1-py2-none-any.whl

# Ubuntu/Linux 64-bit, CPU only, Python 3.3
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp33-cp33m-linux_x86_64.whl

# Ubuntu/Linux 64-bit, GPU enabled, Python 3.3
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp33-cp33m-linux_x86_64.whl

# Ubuntu/Linux 64-bit, CPU only, Python 3.4
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp34-cp34m-linux_x86_64.whl
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp34-cp34m-linux_x86_64.whl

# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp34-cp34m-linux_x86_64.whl
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp34-cp34m-linux_x86_64.whl

# Ubuntu/Linux 64-bit, CPU only, Python 3.5
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp35-cp35m-linux_x86_64.whl
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp35-cp35m-linux_x86_64.whl

# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp35-cp35m-linux_x86_64.whl
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp35-cp35m-linux_x86_64.whl

# Ubuntu/Linux 64-bit, CPU only, Python 3.6
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp36-cp36m-linux_x86_64.whl

# Ubuntu/Linux 64-bit, GPU enabled, Python 3.6
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp36-cp36m-linux_x86_64.whl

# Mac OS X, CPU only, Python 3.4 or 3.5:
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.12.1-py3-none-any.whl
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.0.0rc1-py3-none-any.whl

# Mac OS X, GPU enabled, Python 3.4 or 3.5:
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-0.12.1-py3-none-any.whl
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-1.0.0rc1-py3-none-any.whl
```
Install TensorFlow:

@ -150,14 +164,14 @@ Both distributions include pip. To install the CPU-only version of
TensorFlow, enter the following command at a command prompt:

```bat
C:\> pip install --upgrade https://storage.googleapis.com/tensorflow/windows/cpu/tensorflow-0.12.1-cp35-cp35m-win_amd64.whl
C:\> pip install --upgrade https://storage.googleapis.com/tensorflow/windows/cpu/tensorflow-1.0.0rc1-cp35-cp35m-win_amd64.whl
```

To install the GPU version of TensorFlow, enter the following command
at a command prompt:

```bat
C:\> pip install --upgrade https://storage.googleapis.com/tensorflow/windows/gpu/tensorflow_gpu-0.12.1-cp35-cp35m-win_amd64.whl
C:\> pip install --upgrade https://storage.googleapis.com/tensorflow/windows/gpu/tensorflow_gpu-1.0.0rc1-cp35-cp35m-win_amd64.whl
```

You can now [test your installation](#test-the-tensorflow-installation).
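
A quick way to smoke-test any of the wheels above is to run a tiny graph; a minimal sketch (the version string shown assumes the release candidate above):

```python
# Smoke test for a fresh TensorFlow install (minimal sketch, r1.0 API).
import tensorflow as tf

hello = tf.constant('Hello, TensorFlow!')
sess = tf.Session()
print(sess.run(hello))  # Should print the greeting (as bytes on Python 3).
print(tf.__version__)   # Should report 1.0.0rc1 for the wheels above.
sess.close()
```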

@ -212,37 +226,51 @@ Now, install TensorFlow just as you would for a regular Pip installation. First

```bash
# Ubuntu/Linux 64-bit, CPU only, Python 2.7
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp27-none-linux_x86_64.whl
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp27-none-linux_x86_64.whl

# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp27-none-linux_x86_64.whl
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp27-none-linux_x86_64.whl

# Mac OS X, CPU only, Python 2.7:
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.12.1-py2-none-any.whl
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.0.0rc1-py2-none-any.whl

# Mac OS X, GPU enabled, Python 2.7:
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-0.12.1-py2-none-any.whl
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-1.0.0rc1-py2-none-any.whl

# Ubuntu/Linux 64-bit, CPU only, Python 3.3
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp33-cp33m-linux_x86_64.whl

# Ubuntu/Linux 64-bit, GPU enabled, Python 3.3
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp33-cp33m-linux_x86_64.whl

# Ubuntu/Linux 64-bit, CPU only, Python 3.4
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp34-cp34m-linux_x86_64.whl
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp34-cp34m-linux_x86_64.whl

# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp34-cp34m-linux_x86_64.whl
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp34-cp34m-linux_x86_64.whl

# Ubuntu/Linux 64-bit, CPU only, Python 3.5
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp35-cp35m-linux_x86_64.whl
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp35-cp35m-linux_x86_64.whl

# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp35-cp35m-linux_x86_64.whl
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp35-cp35m-linux_x86_64.whl

# Ubuntu/Linux 64-bit, CPU only, Python 3.6
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp36-cp36m-linux_x86_64.whl

# Ubuntu/Linux 64-bit, GPU enabled, Python 3.6
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp36-cp36m-linux_x86_64.whl

# Mac OS X, CPU only, Python 3.4 or 3.5:
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.12.1-py3-none-any.whl
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.0.0rc1-py3-none-any.whl

# Mac OS X, GPU enabled, Python 3.4 or 3.5:
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-0.12.1-py3-none-any.whl
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-1.0.0rc1-py3-none-any.whl
```

Finally install TensorFlow:

@ -364,37 +392,51 @@ select the correct binary to install:

```bash
# Ubuntu/Linux 64-bit, CPU only, Python 2.7
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp27-none-linux_x86_64.whl
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp27-none-linux_x86_64.whl

# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp27-none-linux_x86_64.whl
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp27-none-linux_x86_64.whl

# Mac OS X, CPU only, Python 2.7:
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.12.1-py2-none-any.whl
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.0.0rc1-py2-none-any.whl

# Mac OS X, GPU enabled, Python 2.7:
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-0.12.1-py2-none-any.whl
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-1.0.0rc1-py2-none-any.whl

# Ubuntu/Linux 64-bit, CPU only, Python 3.3
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp33-cp33m-linux_x86_64.whl

# Ubuntu/Linux 64-bit, GPU enabled, Python 3.3
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp33-cp33m-linux_x86_64.whl

# Ubuntu/Linux 64-bit, CPU only, Python 3.4
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp34-cp34m-linux_x86_64.whl
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp34-cp34m-linux_x86_64.whl

# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp34-cp34m-linux_x86_64.whl
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp34-cp34m-linux_x86_64.whl

# Ubuntu/Linux 64-bit, CPU only, Python 3.5
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp35-cp35m-linux_x86_64.whl
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp35-cp35m-linux_x86_64.whl

# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp35-cp35m-linux_x86_64.whl
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp35-cp35m-linux_x86_64.whl

# Ubuntu/Linux 64-bit, CPU only, Python 3.6
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc1-cp36-cp36m-linux_x86_64.whl

# Ubuntu/Linux 64-bit, GPU enabled, Python 3.6
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.0.0rc1-cp36-cp36m-linux_x86_64.whl

# Mac OS X, CPU only, Python 3.4 or 3.5:
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.12.1-py3-none-any.whl
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.0.0rc1-py3-none-any.whl

# Mac OS X, GPU enabled, Python 3.4 or 3.5:
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-0.12.1-py3-none-any.whl
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-1.0.0rc1-py3-none-any.whl
```

Finally install TensorFlow:

@ -462,7 +504,7 @@ code.
code.

We also have tags with `latest` replaced by a released version (e.g.,
`0.12.1-gpu`).
`1.0.0-rc1-gpu`).

With Docker the installation is as follows:

@ -557,7 +599,7 @@ To build TensorFlow from source on Windows, you can use experimental
support for [Bazel on
Windows](https://bazel.build/versions/master/docs/windows.html) or the
[TensorFlow CMake
build](https://github.com/tensorflow/tensorflow/tree/r0.12/tensorflow/contrib/cmake).
build](https://github.com/tensorflow/tensorflow/tree/r1.0/tensorflow/contrib/cmake).

### Clone the TensorFlow repository

@ -856,37 +898,38 @@ default and if you want to limit RAM usage you can add `--local_resources
2048,.5,1.0` while invoking bazel.

```bash
$ bazel build -c opt //tensorflow/tools/pip_package:build_pip_package
$ bazel build --config opt //tensorflow/tools/pip_package:build_pip_package

# To build with support for CUDA:
$ bazel build -c opt --config=cuda //tensorflow/tools/pip_package:build_pip_package
$ bazel build --config opt --config=cuda //tensorflow/tools/pip_package:build_pip_package

# Alternatively, to build with support for OpenCL:
$ bazel build -c opt --config=sycl //tensorflow/tools/pip_package:build_pip_package
# Alternatively, to build with support for OpenCL (Experimental):
$ bazel build --config opt --config=sycl //tensorflow/tools/pip_package:build_pip_package

$ bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg

# The name of the .whl file will depend on your platform.
$ sudo pip install /tmp/tensorflow_pkg/tensorflow-0.12.1-py2-none-any.whl
$ sudo pip install /tmp/tensorflow_pkg/tensorflow-1.0.0rc1-py2-none-any.whl
```

## Optimizing CPU performance

To be compatible with as wide a range of machines as possible, TensorFlow
defaults to only using SSE4.1 SIMD instructions on x86 machines. Most modern PCs
and Macs support more advanced instructions, so if you're building a binary
that you'll only be running on your own machine, you can enable these by using
`--copt=-march=native` in your bazel build command. For example:
defaults to only using SSE4 SIMD instructions. Most modern computers support
more advanced instructions. So if you're building a binary that you'll only
be running on your own machine, you can enable these by using `-march=native`
for optimization options when running `configure`. Then you can build your
optimized binaries with the following command:

``` bash
$ bazel build --copt=-march=native -c opt //tensorflow/tools/pip_package:build_pip_package
$ bazel build --config opt //tensorflow/tools/pip_package:build_pip_package
```
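
If you are unsure which SIMD features a machine supports, a rough check is to read the CPU flags; a minimal sketch, assuming a Linux host with `/proc/cpuinfo`:

```python
# Rough check of available x86 SIMD features (sketch; Linux /proc/cpuinfo only).
def cpu_flags():
    with open('/proc/cpuinfo') as f:
        for line in f:
            if line.startswith('flags'):
                return set(line.split(':', 1)[1].split())
    return set()

flags = cpu_flags()
for feature in ('sse4_1', 'avx', 'avx2', 'fma'):
    print(feature, 'yes' if feature in flags else 'no')
```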

If you are distributing a binary but know the capabilities of the machines
you'll be running on, you can manually choose the right instructions with
something like `--copt=-march=avx`. You may also want to enable multiple
something like `-march=avx`. You may also want to enable multiple
features using several arguments, for example
`--copt=-mavx2 --copt=-mfma`.
`-mavx2,-mfma`.

If you run a binary built using SIMD instructions on a machine that doesn't
support them, you'll see an illegal instruction error when that code is

@ -902,10 +945,10 @@ system directories, run the following commands inside the TensorFlow root
directory:

```bash
bazel build -c opt //tensorflow/tools/pip_package:build_pip_package
bazel build --config opt //tensorflow/tools/pip_package:build_pip_package

# To build with GPU support:
bazel build -c opt --config=cuda //tensorflow/tools/pip_package:build_pip_package
bazel build --config opt --config=cuda //tensorflow/tools/pip_package:build_pip_package

mkdir _python_build
cd _python_build

@ -177,7 +177,7 @@ tf_custom_op_library(
Run the following command to build `zero_out.so`.

```bash
$ bazel build -c opt //tensorflow/core/user_ops:zero_out.so
$ bazel build --config opt //tensorflow/core/user_ops:zero_out.so
```

> Note:

@ -42,10 +42,10 @@ bazel build tensorflow/examples/image_retraining:retrain

If you have a machine which supports [the AVX instruction set](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions)
(common in x86 CPUs produced in the last few years) you can improve the running
speed of the retraining by building for that architecture, like this:
speed of the retraining by building for that architecture, like this (after choosing appropriate options in `configure`):

```sh
bazel build -c opt --copt=-mavx tensorflow/examples/image_retraining:retrain
bazel build --config opt tensorflow/examples/image_retraining:retrain
```

The retrainer can then be run like this:

@ -1,5 +1,5 @@
# Roadmap
**Last updated: June 3, 2016**
**Last updated: January 23, 2017**

TensorFlow is a fast moving project. In order for the community to better
understand what the near future will bring, this document shares what we are

@ -11,29 +11,28 @@ The features on this list are targeted for the next few months. At this point,
we do not have timelines for these features.

### Improve non-Python language support
C and C++ APIs for:

* Graph construction
* Gradients
* Shape Inference
* Improve C++ API for graph construction and gradients
* Java language support
* Go language support

### Making TensorFlow easier to use
* Easier setup for distributed training jobs
* High-level APIs
* Well-maintained models showing best practices

### Performance
* Speed and memory benchmarks
* Distributed full model benchmarks
* Performance and memory usage improvements

### Core Features
* Repeated partial graph evaluation ([#672](https://github.com/tensorflow/tensorflow/issues/672))
* Automatic op placement ([#2126](https://github.com/tensorflow/tensorflow/issues/2126))
* Support for graph-level functions

### Platforms
* OpenCL support ([#22](https://github.com/tensorflow/tensorflow/issues/22))

### Community
* More educational resources
* Better integration of TensorFlow into the opensource big data ecosystem ([#1996](https://github.com/tensorflow/tensorflow/issues/1996),
[#2218](https://github.com/tensorflow/tensorflow/issues/2218),
* Better integration of TensorFlow into the opensource big data ecosystem (e.g.
[#2655](https://github.com/tensorflow/tensorflow/issues/2655))
* Models benchmarking and comparison tooling

@ -30,7 +30,7 @@ then
then
  echo "Protocol buffer compiler protoc not found in PATH or in ${PROTOC}"
  echo "Perhaps build it using:"
  echo "bazel build -c opt @protobuf//:protoc"
  echo "bazel build --config opt @protobuf//:protoc"
  exit 1
fi
PROTOC=$PATH_PROTOC

@ -40,7 +40,7 @@ Configure and build the Java Archive (JAR) and native library:
./configure

# Build the JAR and native library
bazel build -c opt \
bazel build --config opt \
  //tensorflow/java:tensorflow \
  //tensorflow/java:libtensorflow_jni
```

@ -151,14 +151,15 @@ def _FindAttrInOpDef(attr_name, op_def):

def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the TensorFlow graph in `graph_def` into the Python `Graph`.
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  [`Tensor`](#Tensor) and [`Operation`](#Operation) objects. See
  [`Graph.as_graph_def()`](#Graph.as_graph_def) for a way to create a
  `GraphDef` proto.
  [`Tensor`](#Tensor) and [`Operation`](#Operation) objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  [`Graph.as_graph_def()`](#Graph.as_graph_def) for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into

@ -491,7 +491,9 @@ class TensorFlowTestCase(googletest.TestCase):
      print("dtype = %s, shape = %s" % (a.dtype, a.shape))
    np.testing.assert_allclose(a, b, rtol=rtol, atol=atol)

  def assertAllCloseAccordingToType(self, a, b, rtol=1e-6, atol=1e-6):
  def assertAllCloseAccordingToType(self, a, b, rtol=1e-6, atol=1e-6,
                                    float_rtol=1e-6, float_atol=1e-6,
                                    half_rtol=1e-3, half_atol=1e-3):
    """Like assertAllClose, but also suitable for comparing fp16 arrays.

    In particular, the tolerance is reduced to 1e-3 if at least

@ -502,12 +504,19 @@ class TensorFlowTestCase(googletest.TestCase):
      b: a numpy ndarray or anything can be converted to one.
      rtol: relative tolerance
      atol: absolute tolerance
      float_rtol: relative tolerance for float32
      float_atol: absolute tolerance for float32
      half_rtol: relative tolerance for float16
      half_atol: absolute tolerance for float16
    """
    a = self._GetNdArray(a)
    b = self._GetNdArray(b)
    if a.dtype == np.float32 or b.dtype == np.float32:
      rtol = max(rtol, float_rtol)
      atol = max(atol, float_atol)
    if a.dtype == np.float16 or b.dtype == np.float16:
      rtol = max(rtol, 1e-3)
      atol = max(atol, 1e-3)
      rtol = max(rtol, half_rtol)
      atol = max(atol, half_atol)

    self.assertAllClose(a, b, rtol=rtol, atol=atol)

@ -193,6 +193,55 @@ class TestUtilTest(test_util.TensorFlowTestCase):
      y = [15]
      control_flow_ops.Assert(x, y).run()

  def testAssertAllCloseAccordingToType(self):
    # test float64
    self.assertAllCloseAccordingToType(
        np.asarray([1e-8], dtype=np.float64),
        np.asarray([2e-8], dtype=np.float64),
        rtol=1e-8, atol=1e-8
    )

    with (self.assertRaises(AssertionError)):
      self.assertAllCloseAccordingToType(
          np.asarray([1e-7], dtype=np.float64),
          np.asarray([2e-7], dtype=np.float64),
          rtol=1e-8, atol=1e-8
      )

    # test float32
    self.assertAllCloseAccordingToType(
        np.asarray([1e-7], dtype=np.float32),
        np.asarray([2e-7], dtype=np.float32),
        rtol=1e-8, atol=1e-8,
        float_rtol=1e-7, float_atol=1e-7
    )

    with (self.assertRaises(AssertionError)):
      self.assertAllCloseAccordingToType(
          np.asarray([1e-6], dtype=np.float32),
          np.asarray([2e-6], dtype=np.float32),
          rtol=1e-8, atol=1e-8,
          float_rtol=1e-7, float_atol=1e-7
      )

    # test float16
    self.assertAllCloseAccordingToType(
        np.asarray([1e-4], dtype=np.float16),
        np.asarray([2e-4], dtype=np.float16),
        rtol=1e-8, atol=1e-8,
        float_rtol=1e-7, float_atol=1e-7,
        half_rtol=1e-4, half_atol=1e-4
    )

    with (self.assertRaises(AssertionError)):
      self.assertAllCloseAccordingToType(
          np.asarray([1e-3], dtype=np.float16),
          np.asarray([2e-3], dtype=np.float16),
          rtol=1e-8, atol=1e-8,
          float_rtol=1e-7, float_atol=1e-7,
          half_rtol=1e-4, half_atol=1e-4
      )

  def testRandomSeed(self):
    a = random.randint(1, 1000)
    a_np_rand = np.random.rand(1)

@ -70,13 +70,13 @@ class ScalarStrictTest(test.TestCase):
    self.assertAllEqual(r, correct)

  def testConcat(self):
    self.check(array_ops.concat_v2, (([2], [3], [7]), [0]),
    self.check(array_ops.concat, (([2], [3], [7]), [0]),
               'axis tensor should be a scalar integer', [2, 3, 7])
    for data in (2, 3, 7), (2, [3], 7), (2, 3, [7]):
      self.check(array_ops.concat_v2, (data, 0),
      self.check(array_ops.concat, (data, 0),
                 r'Expected \w+ dimensions in the range \[0, 0\)', [2, 3, 7])
    for data in ([2], 3, 7), ([2], [3], 7):
      self.check(array_ops.concat_v2, (data, 0),
      self.check(array_ops.concat, (data, 0),
                 r'Ranks of all input tensors should match', [2, 3, 7])
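
For reference, `tf.concat` as exercised by these tests takes the values first and the axis second; a minimal sketch of the r1.0 call:

```python
import tensorflow as tf

t1 = tf.constant([[1, 2], [3, 4]])
t2 = tf.constant([[5, 6], [7, 8]])
# r1.0 argument order: values first, then axis.
joined = tf.concat([t1, t2], axis=0)  # shape (4, 2)
```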

  def testFill(self):

@ -49,12 +49,21 @@ class SegmentReductionHelper(test.TestCase):
    slice_shape = x.shape[indices.ndim:]
    x_flat = x.reshape((indices.size,) + slice_shape)
    for i, index in enumerate(indices.ravel()):
      if output[index] is not None:
      if (output[index] is not None) and op1 == np.max:
        for j in range(0, output[index].shape[0]):
          output[index][j] = op1([output[index][j], x_flat[i][j]])
      elif output[index] is not None:
        output[index] = op1(output[index], x_flat[i])
      else:
        output[index] = x_flat[i]
    # zero initialize values that are still uncalculated.
    output = [o if o is not None else np.zeros(slice_shape) for o in output]
    # output = [o if o is not None else np.zeros(slice_shape) for o in output]
    if not op1 == np.max:
      output = [o if o is not None else np.zeros(slice_shape) for o in output]
    else:
      zeroslice = np.zeros(slice_shape)
      zeroslice.fill(dtype.min)
      output = [o if o is not None else zeroslice for o in output]
    if op2 is not None:
      output = [op2(o) for o in output]
    output = [o.reshape(slice_shape) for o in output]

@ -245,7 +254,7 @@ class UnsortedSegmentSumTest(SegmentReductionHelper):
      self._assertAllClose(indices, np_ans, tf_ans)
      self.assertShapeEqual(np_ans, s)

  def testGradient(self):
  def testGradientSegmentSum(self):
    num_cols = 2
    indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
    num_segments = max(indices_flat) + 3

@ -318,6 +327,23 @@ class UnsortedSegmentSumTest(SegmentReductionHelper):
      unsorted = math_ops.unsorted_segment_sum(data, segment_ids, 2)
      self.assertAllEqual(unsorted.eval(), np.zeros((2, 0), dtype=dtype))

  def testGradientSegmentMax(self):
    num_cols = 2
    indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
    num_segments = max(indices_flat) + 3
    for indices in indices_flat, indices_flat.reshape(5, 2):
      shape = indices.shape + (num_cols,)
      with self.test_session():
        tf_x, np_x = self._input(shape, dtype=dtypes_lib.float64)
        s = math_ops.unsorted_segment_max(data=tf_x, segment_ids=indices,
                                          num_segments=num_segments)
        jacob_t, jacob_n = gradient_checker.compute_gradient(
            tf_x,
            shape,
            s,
            [num_segments, num_cols],
            x_init_value=np_x.astype(np.double), delta=1)
      self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)

class UnsortedSegmentSumGpuTest(UnsortedSegmentSumTest):
  use_gpu = True

@ -979,12 +979,6 @@ def unstack(value, num=None, axis=0, name="unstack"):
  return gen_array_ops._unpack(value, num=num, axis=axis, name=name)


# concat_v2 is an alias for concat. concat_v2 will be deprecated and removed
# soon, please use concat.
def concat_v2(values, axis, name="concat_v2"):
  return concat(values, axis, name)


def concat(values, axis, name="concat"):
  """Concatenates tensors along one dimension.

@ -70,6 +70,11 @@ def _Collect(val, collections, default_collections):
      ops.add_to_collection(key, val)


@deprecated(
    "2016-11-30", "Please switch to tf.summary.histogram. Note that "
    "tf.summary.histogram uses the node name instead of the tag. "
    "This means that TensorFlow will automatically de-duplicate summary "
    "names based on the scope they are created in.")
def histogram_summary(tag, values, collections=None, name=None):
  # pylint: disable=line-too-long
  """Outputs a `Summary` protocol buffer with a histogram.

@ -304,6 +309,13 @@ def get_summary_op():
  return summary_op


@deprecated(
    "2016-11-30", "Please switch to tf.summary.scalar. Note that "
    "tf.summary.scalar uses the node name instead of the tag. "
    "This means that TensorFlow will automatically de-duplicate summary "
    "names based on the scope they are created in. Also, passing a "
    "tensor or list of tags to a scalar summary op is no longer "
    "supported.")
def scalar_summary(tags, values, collections=None, name=None):
  # pylint: disable=line-too-long
  """Outputs a `Summary` protocol buffer with scalar values.

@ -188,35 +188,42 @@ def _SparseSegmentSqrtNGrad(op, grad):
                   dim0), None, None)


def _SegmentMinOrMaxGrad(op, grad):
  """Gradient for SegmentMin and SegmentMax. Both share the same code."""
  zeros = array_ops.zeros(
      array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype)
def _SegmentMinOrMaxGrad(op, grad, is_sorted):
  """Gradient for SegmentMin and (unsorted) SegmentMax. They share similar code."""
  zeros = array_ops.zeros(array_ops.shape(op.inputs[0]),
                          dtype=op.inputs[0].dtype)

  # Get the number of selected (minimum or maximum) elements in each segment.
  gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1])
  is_selected = math_ops.equal(op.inputs[0], gathered_outputs)
  num_selected = math_ops.segment_sum(
      math_ops.cast(is_selected, grad.dtype), op.inputs[1])
  if is_sorted:
    num_selected = math_ops.segment_sum(math_ops.cast(is_selected, grad.dtype),
                                        op.inputs[1])
  else:
    num_selected = math_ops.unsorted_segment_sum(math_ops.cast(is_selected, grad.dtype),
                                                 op.inputs[1], op.inputs[2])

  # Compute the gradient for each segment. The gradient for the ith segment is
  # divided evenly among the selected elements in that segment.
  weighted_grads = math_ops.div(grad, num_selected)
  gathered_grads = array_ops.gather(weighted_grads, op.inputs[1])

  return array_ops.where(is_selected, gathered_grads, zeros), None
  if is_sorted:
    return array_ops.where(is_selected, gathered_grads, zeros), None
  else:
    return array_ops.where(is_selected, gathered_grads, zeros), None, None
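
To make the "divided evenly among the selected elements" step concrete, here is a small numpy sketch (illustrative only, not the TensorFlow implementation) for one segment whose maximum is attained twice:

```python
import numpy as np

data = np.array([3.0, 1.0, 3.0])   # one segment; the max 3.0 appears twice
grad = 4.0                          # upstream gradient for this segment's output
is_selected = data == data.max()    # [True, False, True]
num_selected = is_selected.sum()    # 2
# Each selected element receives an equal share of the gradient; others get zero.
print(np.where(is_selected, grad / num_selected, 0.0))  # [2. 0. 2.]
```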
@ops.RegisterGradient("SegmentMin")
|
||||
def _SegmentMinGrad(op, grad):
|
||||
"""Gradient for SegmentMin."""
|
||||
return _SegmentMinOrMaxGrad(op, grad)
|
||||
return _SegmentMinOrMaxGrad(op, grad, True)
|
||||
|
||||
|
||||
@ops.RegisterGradient("SegmentMax")
|
||||
def _SegmentMaxGrad(op, grad):
|
||||
"""Gradient for SegmentMax."""
|
||||
return _SegmentMinOrMaxGrad(op, grad)
|
||||
return _SegmentMinOrMaxGrad(op, grad, True)
|
||||
|
||||
|
||||
@ops.RegisterGradient("UnsortedSegmentSum")
|
||||
@ -225,6 +232,11 @@ def _UnsortedSegmentSumGrad(op, grad):
|
||||
return array_ops.gather(grad, op.inputs[1]), None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("UnsortedSegmentMax")
|
||||
def _UnsortedSegmentMaxGrad(op, grad):
|
||||
return _SegmentMinOrMaxGrad(op, grad, False)
|
||||
|
||||
|
||||
@ops.RegisterGradient("Abs")
|
||||
def _AbsGrad(op, grad):
|
||||
x = op.inputs[0]
|
||||
|

@ -196,6 +196,7 @@ tf.segment_sum(c, tf.constant([0, 0, 1]))
@@segment_mean

@@unsorted_segment_sum
@@unsorted_segment_max

@@sparse_segment_sum
@@sparse_segment_mean
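
A usage sketch for the newly documented op, assuming the public `tf.unsorted_segment_max` endpoint:

```python
import tensorflow as tf

data = tf.constant([1.0, 5.0, 2.0, 4.0])
segment_ids = tf.constant([0, 0, 1, 1])
# Per-segment maximum; segment ids need not be sorted.
result = tf.unsorted_segment_max(data, segment_ids, num_segments=2)
with tf.Session() as sess:
    print(sess.run(result))  # [5. 4.]
```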

@ -34,7 +34,7 @@ def _has_valid_dims(weights_shape, values_shape):
  with ops.name_scope(
      None, "has_invalid_dims", (weights_shape, values_shape)) as scope:
    values_shape_2d = array_ops.expand_dims(values_shape, -1)
    valid_dims = array_ops.concat_v2(
    valid_dims = array_ops.concat(
        (values_shape_2d, array_ops.ones_like(values_shape_2d)), axis=1)
    weights_shape_2d = array_ops.expand_dims(weights_shape, -1)
    invalid_dims = sets.set_difference(weights_shape_2d, valid_dims)

@ -64,7 +64,8 @@ def match_filenames_once(pattern, name=None):
  """
  with ops.name_scope(name, "matching_filenames", [pattern]) as name:
    return variables.Variable(io_ops.matching_files(pattern), trainable=False,
                              name=name, validate_shape=False)
                              name=name, validate_shape=False,
                              collections=[ops.GraphKeys.LOCAL_VARIABLES])
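
Since the matching-files variable now lives in the LOCAL_VARIABLES collection, callers initialize it with the local rather than the global initializer; a minimal sketch (the glob path is illustrative):

```python
import tensorflow as tf

filenames = tf.train.match_filenames_once('/tmp/data/*.csv')
with tf.Session() as sess:
    # The variable is in GraphKeys.LOCAL_VARIABLES, so use the local initializer.
    sess.run(tf.local_variables_initializer())
    print(sess.run(filenames))
```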

def limit_epochs(tensor, num_epochs=None, name=None):

@ -164,7 +164,7 @@ class QueueRunnerTest(test.TestCase):
      coord.request_stop()
      # We should be able to join because the RequestStop() will cause
      # the queue to be closed and the enqueue to terminate.
      coord.join(stop_grace_period_secs=0.05)
      coord.join(stop_grace_period_secs=1.0)

  def testMultipleSessions(self):
    with self.test_session() as sess:

@ -68,18 +68,6 @@ from tensorflow.python.training import saver as saver_module
from tensorflow.python.util import compat


# pylint: disable=invalid-name
def _TestDir(test_name):
  test_dir = os.path.join(test.get_temp_dir(), test_name)
  if os.path.exists(test_dir):
    shutil.rmtree(test_dir)
  gfile.MakeDirs(test_dir)
  return test_dir


# pylint: enable=invalid-name


class CheckpointedOp(object):
  """Op with a custom checkpointing implementation.

@ -591,6 +579,11 @@ class SaverTest(test.TestCase):

class SaveRestoreShardedTest(test.TestCase):

  def _get_test_dir(self, dirname):
    test_dir = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(test_dir)
    return test_dir

  def testBasics(self):
    save_path = os.path.join(self.get_temp_dir(), "sharded_basics")

@ -719,7 +712,9 @@ class SaveRestoreShardedTest(test.TestCase):
    var_full_shape = [10, 3]
    # Allows save/restore mechanism to work w/ different slicings.
    var_name = "my_var"
    saved_path = os.path.join(_TestDir("partitioned_variables"), "ckpt")
    saved_dir = self._get_test_dir("partitioned_variables")
    saved_path = os.path.join(saved_dir, "ckpt")

    call_saver_with_dict = False  # updated by test loop below

    def _save(slices=None, partitioner=None):

@ -842,8 +837,13 @@ class SaveRestoreShardedTest(test.TestCase):

class MaxToKeepTest(test.TestCase):

  def _get_test_dir(self, dirname):
    test_dir = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(test_dir)
    return test_dir

  def testNonSharded(self):
    save_dir = _TestDir("max_to_keep_non_sharded")
    save_dir = self._get_test_dir("max_to_keep_non_sharded")

    with self.test_session() as sess:
      v = variables.Variable(10.0, name="v")

@ -963,7 +963,7 @@ class MaxToKeepTest(test.TestCase):
          saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))

  def testSharded(self):
    save_dir = _TestDir("max_to_keep_sharded")
    save_dir = self._get_test_dir("max_to_keep_sharded")

    with session.Session(
        target="",

@ -1018,8 +1018,8 @@ class MaxToKeepTest(test.TestCase):
    self.assertTrue(gfile.Exists(save._MetaGraphFilename(s3)))

  def testNoMaxToKeep(self):
    save_dir = _TestDir("no_max_to_keep")
    save_dir2 = _TestDir("max_to_keep_0")
    save_dir = self._get_test_dir("no_max_to_keep")
    save_dir2 = self._get_test_dir("max_to_keep_0")

    with self.test_session() as sess:
      v = variables.Variable(10.0, name="v")

@ -1046,7 +1046,7 @@ class MaxToKeepTest(test.TestCase):
    self.assertTrue(saver_module.checkpoint_exists(s2))

  def testNoMetaGraph(self):
    save_dir = _TestDir("no_meta_graph")
    save_dir = self._get_test_dir("no_meta_graph")

    with self.test_session() as sess:
      v = variables.Variable(10.0, name="v")

@ -1060,8 +1060,13 @@ class MaxToKeepTest(test.TestCase):

class KeepCheckpointEveryNHoursTest(test.TestCase):

  def _get_test_dir(self, dirname):
    test_dir = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(test_dir)
    return test_dir

  def testNonSharded(self):
    save_dir = _TestDir("keep_checkpoint_every_n_hours")
    save_dir = self._get_test_dir("keep_checkpoint_every_n_hours")

    with self.test_session() as sess:
      v = variables.Variable([10.0], name="v")

@ -1277,8 +1282,13 @@ class LatestCheckpointWithRelativePaths(test.TestCase):

class CheckpointStateTest(test.TestCase):

  def _get_test_dir(self, dirname):
    test_dir = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(test_dir)
    return test_dir

  def testAbsPath(self):
    save_dir = _TestDir("abs_paths")
    save_dir = self._get_test_dir("abs_paths")
    abs_path = os.path.join(save_dir, "model-0")
    ckpt = saver_module.generate_checkpoint_state_proto(save_dir, abs_path)
    self.assertEqual(ckpt.model_checkpoint_path, abs_path)

@ -1297,7 +1307,7 @@ class CheckpointStateTest(test.TestCase):
    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path)

  def testAllModelCheckpointPaths(self):
    save_dir = _TestDir("all_models_test")
    save_dir = self._get_test_dir("all_models_test")
    abs_path = os.path.join(save_dir, "model-0")
    for paths in [None, [], ["model-2"]]:
      ckpt = saver_module.generate_checkpoint_state_proto(

@ -1309,7 +1319,7 @@ class CheckpointStateTest(test.TestCase):
    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)

  def testUpdateCheckpointState(self):
    save_dir = _TestDir("update_checkpoint_state")
    save_dir = self._get_test_dir("update_checkpoint_state")
    os.chdir(save_dir)
    # Make a temporary train directory.
    train_dir = "train"

@ -1325,7 +1335,7 @@ class CheckpointStateTest(test.TestCase):
    self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path)

  def testCheckPointStateFailsWhenIncomplete(self):
    save_dir = _TestDir("checkpoint_state_fails_when_incomplete")
    save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete")
    os.chdir(save_dir)
    ckpt_path = os.path.join(save_dir, "checkpoint")
    ckpt_file = open(ckpt_path, "w")

@ -1335,7 +1345,7 @@ class CheckpointStateTest(test.TestCase):
      saver_module.get_checkpoint_state(save_dir)

  def testCheckPointCompletesRelativePaths(self):
    save_dir = _TestDir("checkpoint_completes_relative_paths")
    save_dir = self._get_test_dir("checkpoint_completes_relative_paths")
    os.chdir(save_dir)
    ckpt_path = os.path.join(save_dir, "checkpoint")
    ckpt_file = open(ckpt_path, "w")
|
||||
|
||||
class MetaGraphTest(test.TestCase):
|
||||
|
||||
def _get_test_dir(self, dirname):
|
||||
test_dir = os.path.join(self.get_temp_dir(), dirname)
|
||||
gfile.MakeDirs(test_dir)
|
||||
return test_dir
|
||||
|
||||
def testAddCollectionDef(self):
|
||||
test_dir = _TestDir("good_collection")
|
||||
test_dir = self._get_test_dir("good_collection")
|
||||
filename = os.path.join(test_dir, "metafile")
|
||||
with self.test_session():
|
||||
# Creates a graph.
|
||||
@ -1504,12 +1519,12 @@ class MetaGraphTest(test.TestCase):
|
||||
self.assertEqual(11.0, v1.eval())
|
||||
|
||||
def testMultiSaverCollection(self):
|
||||
test_dir = _TestDir("saver_collection")
|
||||
test_dir = self._get_test_dir("saver_collection")
|
||||
self._testMultiSaverCollectionSave(test_dir)
|
||||
self._testMultiSaverCollectionRestore(test_dir)
|
||||
|
||||
def testBinaryAndTextFormat(self):
|
||||
test_dir = _TestDir("binary_and_text")
|
||||
test_dir = self._get_test_dir("binary_and_text")
|
||||
filename = os.path.join(test_dir, "metafile")
|
||||
with self.test_session(graph=ops_lib.Graph()):
|
||||
# Creates a graph.
|
||||
@ -1541,7 +1556,7 @@ class MetaGraphTest(test.TestCase):
|
||||
saver_module.import_meta_graph(filename)
|
||||
|
||||
def testSliceVariable(self):
|
||||
test_dir = _TestDir("slice_saver")
|
||||
test_dir = self._get_test_dir("slice_saver")
|
||||
filename = os.path.join(test_dir, "metafile")
|
||||
with self.test_session():
|
||||
v1 = variables.Variable([20.0], name="v1")
|
||||
@ -1679,7 +1694,7 @@ class MetaGraphTest(test.TestCase):
|
||||
sess.run(train_op)
|
||||
|
||||
def testGraphExtension(self):
|
||||
test_dir = _TestDir("graph_extension")
|
||||
test_dir = self._get_test_dir("graph_extension")
|
||||
self._testGraphExtensionSave(test_dir)
|
||||
self._testGraphExtensionRestore(test_dir)
|
||||
self._testRestoreFromTrainGraphWithControlContext(test_dir)
|
||||
@ -1722,7 +1737,7 @@ class MetaGraphTest(test.TestCase):
|
||||
|
||||
def testImportIntoNamescope(self):
|
||||
# Test that we can import a meta graph into a namescope.
|
||||
test_dir = _TestDir("import_into_namescope")
|
||||
test_dir = self._get_test_dir("import_into_namescope")
|
||||
filename = os.path.join(test_dir, "ckpt")
|
||||
image = array_ops.placeholder(dtypes.float32, [None, 784])
|
||||
label = array_ops.placeholder(dtypes.float32, [None, 10])
|
||||

@ -1870,8 +1885,13 @@ class CheckpointReaderForV2Test(CheckpointReaderTest):

class WriteGraphTest(test.TestCase):

  def _get_test_dir(self, dirname):
    test_dir = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(test_dir)
    return test_dir

  def testWriteGraph(self):
    test_dir = _TestDir("write_graph_dir")
    test_dir = self._get_test_dir("write_graph_dir")
    variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
    path = graph_io.write_graph(ops_lib.get_default_graph(),
                                os.path.join(test_dir, "l1"), "graph.pbtxt")

@ -1881,7 +1901,7 @@ class WriteGraphTest(test.TestCase):


  def testRecursiveCreate(self):
    test_dir = _TestDir("deep_dir")
    test_dir = self._get_test_dir("deep_dir")
    variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
    path = graph_io.write_graph(ops_lib.get_default_graph().as_graph_def(),
                                os.path.join(test_dir, "l1", "l2", "l3"),

@ -1935,6 +1955,11 @@ class SaverUtilsTest(test.TestCase):

class ScopedGraphTest(test.TestCase):

  def _get_test_dir(self, dirname):
    test_dir = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(test_dir)
    return test_dir

  def _testScopedSave(self, test_dir, exported_filename, ckpt_filename):
    graph = ops_lib.Graph()
    with graph.as_default():

@ -2067,7 +2092,7 @@ class ScopedGraphTest(test.TestCase):
  # Verifies that we can save the subgraph under "hidden1" and restore it
  # into "new_hidden1" in the new graph.
  def testScopedSaveAndRestore(self):
    test_dir = _TestDir("scoped_export_import")
    test_dir = self._get_test_dir("scoped_export_import")
    ckpt_filename = "ckpt"
    self._testScopedSave(test_dir, "exported_hidden1.pbtxt", ckpt_filename)
    self._testScopedRestore(test_dir, "exported_hidden1.pbtxt",

@ -2076,7 +2101,7 @@ class ScopedGraphTest(test.TestCase):
  # Verifies that we can copy the subgraph under "hidden1" and copy it
  # to different name scope in the same graph or different graph.
  def testCopyScopedGraph(self):
    test_dir = _TestDir("scoped_copy")
    test_dir = self._get_test_dir("scoped_copy")
    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
    graph1 = ops_lib.Graph()
    with graph1.as_default():

@ -2132,7 +2157,7 @@ class ScopedGraphTest(test.TestCase):
    self.assertAllClose(expected, sess.run("new_hidden1/relu:0"))

  def testExportGraphDefWithScope(self):
    test_dir = _TestDir("export_graph_def")
    test_dir = self._get_test_dir("export_graph_def")
    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
    graph1 = ops_lib.Graph()
    with graph1.as_default():

@ -64,15 +64,14 @@ def _summary_iterator(test_dir):
  return summary_iterator.summary_iterator(event_paths[-1])


def _test_dir(test_name):
  test_dir = os.path.join(test.get_temp_dir(), test_name)
  if os.path.exists(test_dir):
    shutil.rmtree(test_dir)
  return test_dir


class SupervisorTest(test.TestCase):

  def _test_dir(self, test_name):
    test_dir = os.path.join(self.get_temp_dir(), test_name)
    if os.path.exists(test_dir):
      shutil.rmtree(test_dir)
    return test_dir

  def _wait_for_glob(self, pattern, timeout_secs, for_checkpoint=True):
    """Wait for a checkpoint file to appear.

@ -94,7 +93,7 @@ class SupervisorTest(test.TestCase):

  # This test does not test much.
  def testBasics(self):
    logdir = _test_dir("basics")
    logdir = self._test_dir("basics")
    with ops.Graph().as_default():
      my_op = constant_op.constant(1.0)
      sv = supervisor.Supervisor(logdir=logdir)

@ -105,7 +104,7 @@ class SupervisorTest(test.TestCase):
      sv.stop()

  def testManagedSession(self):
    logdir = _test_dir("managed_session")
    logdir = self._test_dir("managed_session")
    with ops.Graph().as_default():
      my_op = constant_op.constant(1.0)
      sv = supervisor.Supervisor(logdir=logdir)

@ -116,7 +115,7 @@ class SupervisorTest(test.TestCase):
    self.assertTrue(sv.should_stop())

  def testManagedSessionUserError(self):
    logdir = _test_dir("managed_user_error")
    logdir = self._test_dir("managed_user_error")
    with ops.Graph().as_default():
      my_op = constant_op.constant(1.0)
      sv = supervisor.Supervisor(logdir=logdir)

@ -134,7 +133,7 @@ class SupervisorTest(test.TestCase):
    self.assertEqual(1, last_step)

  def testManagedSessionIgnoreOutOfRangeError(self):
    logdir = _test_dir("managed_out_of_range")
    logdir = self._test_dir("managed_out_of_range")
    with ops.Graph().as_default():
      my_op = constant_op.constant(1.0)
      sv = supervisor.Supervisor(logdir=logdir)

@ -152,7 +151,7 @@ class SupervisorTest(test.TestCase):
    self.assertEqual(3, last_step)

  def testManagedSessionDoNotKeepSummaryWriter(self):
    logdir = _test_dir("managed_not_keep_summary_writer")
    logdir = self._test_dir("managed_not_keep_summary_writer")
    with ops.Graph().as_default():
      summary.scalar("c1", constant_op.constant(1))
      summary.scalar("c2", constant_op.constant(2))

@ -204,7 +203,7 @@ class SupervisorTest(test.TestCase):
        next(rr)

  def testManagedSessionKeepSummaryWriter(self):
    logdir = _test_dir("managed_keep_summary_writer")
    logdir = self._test_dir("managed_keep_summary_writer")
    with ops.Graph().as_default():
      summary.scalar("c1", constant_op.constant(1))
      summary.scalar("c2", constant_op.constant(2))

@ -266,7 +265,7 @@ class SupervisorTest(test.TestCase):
  def testManagedEndOfInputOneQueue(self):
    # Tests that the supervisor finishes without an error when using
    # a fixed number of epochs, reading from a single queue.
    logdir = _test_dir("managed_end_of_input_one_queue")
    logdir = self._test_dir("managed_end_of_input_one_queue")
    os.makedirs(logdir)
    data_path = self._csv_data(logdir)
    with ops.Graph().as_default():

@ -285,7 +284,7 @@ class SupervisorTest(test.TestCase):
    # Tests that the supervisor finishes without an error when using
    # a fixed number of epochs, reading from two queues, the second
    # one producing a batch from the first one.
    logdir = _test_dir("managed_end_of_input_two_queues")
    logdir = self._test_dir("managed_end_of_input_two_queues")
    os.makedirs(logdir)
    data_path = self._csv_data(logdir)
    with ops.Graph().as_default():

@ -304,7 +303,7 @@ class SupervisorTest(test.TestCase):
  def testManagedMainErrorTwoQueues(self):
    # Tests that the supervisor correctly raises a main loop
    # error even when using multiple queues for input.
    logdir = _test_dir("managed_main_error_two_queues")
    logdir = self._test_dir("managed_main_error_two_queues")
    os.makedirs(logdir)
    data_path = self._csv_data(logdir)
    with self.assertRaisesRegexp(RuntimeError, "fail at step 3"):

@ -327,7 +326,7 @@ class SupervisorTest(test.TestCase):
          sess.run(shuff_rec)

  def testSessionConfig(self):
    logdir = _test_dir("session_config")
    logdir = self._test_dir("session_config")
    with ops.Graph().as_default():
      with ops.device("/cpu:1"):
        my_op = constant_op.constant([1.0])

@ -340,7 +339,7 @@ class SupervisorTest(test.TestCase):
      sv.stop()

  def testChiefCanWriteEvents(self):
    logdir = _test_dir("can_write")
    logdir = self._test_dir("can_write")
    with ops.Graph().as_default():
      summary.scalar("c1", constant_op.constant(1))
      summary.scalar("c2", constant_op.constant(2))

@ -421,7 +420,7 @@ class SupervisorTest(test.TestCase):
      sv.summary_computed(sess, sess.run(summ))

  def testLogdirButExplicitlyNoSummaryWriter(self):
    logdir = _test_dir("explicit_no_summary_writer")
    logdir = self._test_dir("explicit_no_summary_writer")
    with ops.Graph().as_default():
      variables.Variable([1.0], name="foo")
      summary.scalar("c1", constant_op.constant(1))

@ -437,7 +436,7 @@ class SupervisorTest(test.TestCase):
        sv.summary_computed(sess, sess.run(summ))

  def testNoLogdirButExplicitSummaryWriter(self):
    logdir = _test_dir("explicit_summary_writer")
    logdir = self._test_dir("explicit_summary_writer")
    with ops.Graph().as_default():
      summary.scalar("c1", constant_op.constant(1))
      summary.scalar("c2", constant_op.constant(2))

@ -506,7 +505,7 @@ class SupervisorTest(test.TestCase):
      sv.prepare_or_wait_for_session("")

  def testInitOp(self):
    logdir = _test_dir("default_init_op")
    logdir = self._test_dir("default_init_op")
    with ops.Graph().as_default():
      v = variables.Variable([1.0, 2.0, 3.0])
      sv = supervisor.Supervisor(logdir=logdir)

@ -515,7 +514,7 @@ class SupervisorTest(test.TestCase):
      sv.stop()

  def testInitFn(self):
    logdir = _test_dir("default_init_op")
    logdir = self._test_dir("default_init_op")
    with ops.Graph().as_default():
      v = variables.Variable([1.0, 2.0, 3.0])

@ -528,7 +527,7 @@ class SupervisorTest(test.TestCase):
      sv.stop()

  def testInitOpWithFeedDict(self):
    logdir = _test_dir("feed_dict_init_op")
    logdir = self._test_dir("feed_dict_init_op")
    with ops.Graph().as_default():
      p = array_ops.placeholder(dtypes.float32, shape=(3,))
      v = variables.Variable(p, name="v")

@ -542,7 +541,7 @@ class SupervisorTest(test.TestCase):

  def testReadyForLocalInitOp(self):
    server = server_lib.Server.create_local_server()
    logdir = _test_dir("default_ready_for_local_init_op")
    logdir = self._test_dir("default_ready_for_local_init_op")

    uid = uuid.uuid4().hex

@ -584,7 +583,7 @@ class SupervisorTest(test.TestCase):

  def testReadyForLocalInitOpRestoreFromCheckpoint(self):
    server = server_lib.Server.create_local_server()
    logdir = _test_dir("ready_for_local_init_op_restore")
    logdir = self._test_dir("ready_for_local_init_op_restore")

    uid = uuid.uuid4().hex

@ -639,7 +638,7 @@ class SupervisorTest(test.TestCase):
      sv1.stop()

  def testLocalInitOp(self):
    logdir = _test_dir("default_local_init_op")
    logdir = self._test_dir("default_local_init_op")
    with ops.Graph().as_default():
      # A local variable.
      v = variables.Variable(

@ -664,7 +663,7 @@ class SupervisorTest(test.TestCase):
      sv.stop()

  def testLocalInitOpForNonChief(self):
    logdir = _test_dir("default_local_init_op_non_chief")
    logdir = self._test_dir("default_local_init_op_non_chief")
    with ops.Graph().as_default():
      with ops.device("/job:localhost"):
        # A local variable.

@ -685,7 +684,7 @@ class SupervisorTest(test.TestCase):

  def testInitOpFails(self):
    server = server_lib.Server.create_local_server()
    logdir = _test_dir("default_init_op_fails")
    logdir = self._test_dir("default_init_op_fails")
    with ops.Graph().as_default():
      v = variables.Variable([1.0, 2.0, 3.0], name="v")
      variables.Variable([4.0, 5.0, 6.0], name="w")

@ -697,7 +696,7 @@ class SupervisorTest(test.TestCase):

  def testInitOpFailsForTransientVariable(self):
    server = server_lib.Server.create_local_server()
    logdir = _test_dir("default_init_op_fails_for_local_variable")
    logdir = self._test_dir("default_init_op_fails_for_local_variable")
    with ops.Graph().as_default():
      v = variables.Variable(
          [1.0, 2.0, 3.0],

@ -714,7 +713,7 @@ class SupervisorTest(test.TestCase):
      sv.prepare_or_wait_for_session(server.target)

  def testSetupFail(self):
    logdir = _test_dir("setup_fail")
    logdir = self._test_dir("setup_fail")
    with ops.Graph().as_default():
      variables.Variable([1.0, 2.0, 3.0], name="v")
      with self.assertRaisesRegexp(ValueError, "must have their device set"):

@ -724,7 +723,7 @@ class SupervisorTest(test.TestCase):
      supervisor.Supervisor(logdir=logdir, is_chief=False)

  def testDefaultGlobalStep(self):
    logdir = _test_dir("default_global_step")
    logdir = self._test_dir("default_global_step")
    with ops.Graph().as_default():
      variables.Variable(287, name="global_step")
      sv = supervisor.Supervisor(logdir=logdir)

@ -733,7 +732,7 @@ class SupervisorTest(test.TestCase):
      sv.stop()

  def testRestoreFromMetaGraph(self):
    logdir = _test_dir("restore_from_meta_graph")
    logdir = self._test_dir("restore_from_meta_graph")
    with ops.Graph().as_default():
      variables.Variable(1, name="v0")
      sv = supervisor.Supervisor(logdir=logdir)

@ -754,7 +753,7 @@ class SupervisorTest(test.TestCase):
  # right away and get to run once before sv.stop() returns.
  # We still sleep a bit to make the test robust.
  def testStandardServicesWithoutGlobalStep(self):
    logdir = _test_dir("standard_services_without_global_step")
    logdir = self._test_dir("standard_services_without_global_step")
    # Create a checkpoint.
    with ops.Graph().as_default():
      v = variables.Variable([1.0], name="foo")

@ -804,7 +803,7 @@ class SupervisorTest(test.TestCase):
  # Same as testStandardServicesNoGlobalStep but with a global step.
  # We should get a summary about the step time.
  def testStandardServicesWithGlobalStep(self):
    logdir = _test_dir("standard_services_with_global_step")
    logdir = self._test_dir("standard_services_with_global_step")
    # Create a checkpoint.
    with ops.Graph().as_default():
      v = variables.Variable([123], name="global_step")

@ -867,12 +866,12 @@ class SupervisorTest(test.TestCase):

  def testNoQueueRunners(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      sv = supervisor.Supervisor(logdir=_test_dir("no_queue_runners"))
      sv = supervisor.Supervisor(logdir=self._test_dir("no_queue_runners"))
      self.assertEqual(0, len(sv.start_queue_runners(sess)))
      sv.stop()

  def testPrepareSessionAfterStopForChief(self):
    logdir = _test_dir("prepare_after_stop_chief")
    logdir = self._test_dir("prepare_after_stop_chief")
    with ops.Graph().as_default():
      sv = supervisor.Supervisor(logdir=logdir, is_chief=True)

@ -891,7 +890,7 @@ class SupervisorTest(test.TestCase):
      self.assertTrue(sv.should_stop())

  def testPrepareSessionAfterStopForNonChief(self):
    logdir = _test_dir("prepare_after_stop_nonchief")
    logdir = self._test_dir("prepare_after_stop_nonchief")
    with ops.Graph().as_default():
      sv = supervisor.Supervisor(logdir=logdir, is_chief=False)
@@ -51,18 +51,18 @@ work, but there may be bugs or performance issues.
 
 The first step in using TensorBoard is acquiring data from your TensorFlow run.
 For this, you need [summary
-ops](https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#summary-operations).
+ops](https://www.tensorflow.org/versions/r1.0/api_docs/python/train.html#summary-operations).
 Summary ops are ops, like
-[`tf.matmul`](https://www.tensorflow.org/versions/r0.12/api_docs/python/math_ops.html#matmul)
+[`tf.matmul`](https://www.tensorflow.org/versions/r1.0/api_docs/python/math_ops.html#matmul)
 or
-[`tf.nn.relu`](https://www.tensorflow.org/versions/r0.12/api_docs/python/nn.html#relu),
+[`tf.nn.relu`](https://www.tensorflow.org/versions/r1.0/api_docs/python/nn.html#relu),
 which means they take in tensors, produce tensors, and are evaluated from within
 a TensorFlow graph. However, summary ops have a twist: the Tensors they produce
 contain serialized protobufs, which are written to disk and sent to TensorBoard.
 To visualize the summary data in TensorBoard, you should evaluate the summary
 op, retrieve the result, and then write that result to disk using a
 summary.FileWriter. A full explanation, with examples, is in [the
-tutorial](https://www.tensorflow.org/versions/r0.12/how_tos/summaries_and_tensorboard/index.html).
+tutorial](https://www.tensorflow.org/versions/r1.0/how_tos/summaries_and_tensorboard/index.html).
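
As a minimal sketch of the flow just described (build a summary op, evaluate it, write the serialized result with a `summary.FileWriter`); the log directory and tag names here are illustrative:

```python
import tensorflow as tf

# A summary op: takes a tensor, produces a serialized Summary protobuf.
loss = tf.placeholder(tf.float32, name="loss")
loss_summary = tf.summary.scalar("loss", loss)

writer = tf.summary.FileWriter("/tmp/demo_logs")
with tf.Session() as sess:
  for step in range(3):
    # Evaluate the summary op, then write the result to disk for TensorBoard.
    serialized = sess.run(loss_summary, feed_dict={loss: 1.0 / (step + 1)})
    writer.add_summary(serialized, global_step=step)
writer.close()
```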
 
 ### Tags: Giving names to data
 
@@ -184,7 +184,7 @@ TensorFlow model. To get best use of the graph visualizer, you should use name
 scopes to hierarchically group the ops in your graph - otherwise, the graph may
 be difficult to decipher. For more information, including examples, see [the
 graph visualizer
-tutorial](https://www.tensorflow.org/versions/r0.12/how_tos/graph_viz/index.html#tensorboard-graph-visualization).
+tutorial](https://www.tensorflow.org/versions/r1.0/how_tos/graph_viz/index.html#tensorboard-graph-visualization).
 
 # Frequently Asked Questions
 
@@ -20,6 +20,11 @@ load(
     "cuda_default_copts"
 )
 
+#load(
+#    "//third_party/mkl:build_defs.bzl",
+#    "if_mkl",
+#)
+
 # List of proto files for android builds
 def tf_android_core_proto_sources(core_proto_sources_relative):
   return ["//tensorflow/core:" + p
 
@@ -377,6 +382,10 @@ def tf_cc_tests(srcs, deps, name='', linkstatic=0, tags=[], size="medium",
            args=args,
            linkopts=linkopts)
 
+#def tf_cc_test_mkl(srcs, deps, name='', linkstatic=0, tags=[], size="medium",
+#                   args=None):
+#  tf_cc_tests(srcs, deps, linkstatic, tags=tags, size=size, args=args)
+
 def tf_cc_tests_gpu(srcs, deps, name='', linkstatic=0, tags=[], size="medium",
                     args=None):
   tf_cc_tests(srcs, deps, linkstatic, tags=tags, size=size, args=args)
 
@@ -39,7 +39,7 @@ $adb shell "/data/local/tmp/benchmark_model \
 ### On desktop:
 (1) build the binary
 ```bash
-$bazel build -c opt tensorflow/tools/benchmark:benchmark_model
+$bazel build --config opt tensorflow/tools/benchmark:benchmark_model
 ```
 
 (2) Run on your compute graph, similar to the Android case but without the need of adb shell.
@@ -54,4 +54,4 @@ $bazel-bin/tensorflow/tools/benchmark/benchmark_model \
 ```
 
 The Inception graph used as an example here may be downloaded from
-https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
+https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
@@ -17,7 +17,7 @@
 # Install protobuf3.
 
 # Select protobuf version.
-PROTOBUF_VERSION="3.1.0"
+PROTOBUF_VERSION="3.2.0"
 protobuf_ver_flat=$(echo $PROTOBUF_VERSION | sed 's/\.//g' | sed 's/^0*//g')
 local_protobuf_ver=$(protoc --version | awk '{print $2}')
 local_protobuf_ver_flat=$(echo $local_protobuf_ver | sed 's/\.//g' | sed 's/^0*//g')
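
For reference, the flattening performed by the two `sed` calls above can be mirrored in Python; this assumes plain numeric versions like the ones used here:

```python
# "3.2.0" -> "320"; leading zeros are stripped so the result can be
# compared numerically, e.g. "0.12.1" -> "121".
def flatten_version(version):
  return version.replace(".", "").lstrip("0")

assert flatten_version("3.2.0") == "320"
assert flatten_version("0.12.1") == "121"
```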
 
@@ -14,7 +14,7 @@
 # limitations under the License.
 # ==============================================================================
 
-PROTOBUF_VERSION="3.1.0"
+PROTOBUF_VERSION="3.2.0"
 PYTHON_BIN=${PYTHON_BIN:-python}
 DIR=${PWD}/protobuf
 
@@ -111,6 +111,7 @@ function get_failing_cpu_py_tests() {
     //$1/tensorflow/python:framework_ops_test + \
     //$1/tensorflow/python:framework_tensor_util_test + \
     //$1/tensorflow/python:framework_test_util_test + \
+    //$1/tensorflow/python:gradients_test + \
     //$1/tensorflow/python:image_ops_test + \
     //$1/tensorflow/python:localhost_cluster_performance_test + \
     //$1/tensorflow/python:monitored_session_test + \
 
@@ -36,6 +36,9 @@ particular, functions that have had reordered arguments like `tf.concat`,
 `tf.split` will cause the script to incorrectly add keyword arguments that
 mismap arguments.
 
+- This script will not actually reorder arguments. Instead, the script will add
+  keyword arguments to functions that had their arguments reordered.
+
 - This script is not able to upgrade all functions. One notable example is
 `tf.reverse()` which has been changed to take a list of indices rather than
 a tensor of bools. If the script detects this, it will report this to stdout
@@ -43,6 +46,12 @@ a tensor of bools. If the script detects this, it will report this to stdout
 `tf.reverse(a, [False, True, True])` you will need to manually change it to
 `tf.reverse(a, [1, 2])`.
 
+- There are some syntaxes that are not handleable with this script as this
+  script was designed to use only standard Python packages. If the script fails
+  with "A necessary keyword argument failed to be inserted." or
+  "Failed to find keyword lexicographically. Fix manually.", you can try
+  [@machrisaa's fork of this script](https://github.com/machrisaa/tf0to1).
+  [@machrisaa](https://github.com/machrisaa) has used the
+  [RedBaron Python refactoring engine](https://redbaron.readthedocs.io/en/latest/),
+  which is able to localize syntactic elements more reliably than the built-in
+  `ast` module this script is based upon.
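
To make the keyword-insertion behavior concrete, here is a before/after illustration matching the `tf.concat`/`tf.split` entries this commit adds to the reorder spec (`t1`, `t2`, and `value` are placeholder names):

```python
# Before: TensorFlow 0.x style, positional arguments in the old order.
tf.concat(0, [t1, t2])
tf.split(0, 3, value)

# After running the script: the arguments are not moved; keyword names are
# inserted (and renamed, e.g. concat_dim -> axis) so the old positions
# remain valid under the 1.0 signatures.
tf.concat(axis=0, values=[t1, t2])
tf.split(axis=0, num_or_size_splits=3, value=value)
```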
 
@@ -95,11 +95,15 @@ class APIChangeSpec(object):
         "tf.split": {
             "split_dim": "axis",
             "num_split": "num_or_size_splits"
-        }
+        },
+        "tf.concat": {
+            "concat_dim": "axis"
+        },
     }
 
     # Mapping from function to the new name of the function
     self.function_renames = {
         "tf.inv": "tf.reciprocal",
         "tf.contrib.deprecated.scalar_summary": "tf.summary.scalar",
         "tf.contrib.deprecated.histogram_summary": "tf.summary.histogram",
         "tf.listdiff": "tf.setdiff1d",
 
@@ -142,6 +146,13 @@ class APIChangeSpec(object):
         "tf.select": "tf.where",
         "tf.complex_abs": "tf.abs",
         "tf.batch_matmul": "tf.matmul",
+        "tf.pack": "tf.stack",
+        "tf.unpack": "tf.unstack",
     }
 
+    self.change_to_function = {
+        "tf.ones_initializer",
+        "tf.zeros_initializer",
+    }
+
     # Functions that were reordered should be changed to the new keyword args
 
@@ -149,6 +160,7 @@ class APIChangeSpec(object):
     # positional arguments yourself, this could do the wrong thing.
     self.function_reorders = {
         "tf.split": ["axis", "num_or_size_splits", "value", "name"],
+        "tf.sparse_split": ["axis", "num_or_size_splits", "value", "name"],
         "tf.concat": ["concat_dim", "values", "name"],
         "tf.svd": ["tensor", "compute_uv", "full_matrices", "name"],
         "tf.nn.softmax_cross_entropy_with_logits": [
@@ -335,6 +347,62 @@ class TensorFlowCallVisitor(ast.NodeVisitor):
         items.append(curr.id)
     return ".".join(reversed(items))
 
+  def _find_true_position(self, node):
+    """Return the correct line number and column offset for a given node.
+
+    This is necessary mainly because ListComp's location reporting reports
+    the next token after the list comprehension list opening.
+
+    Args:
+      node: Node for which we wish to know the lineno and col_offset
+
+    Returns:
+      A (lineno, col_offset) tuple, or (None, None) if the position could
+      not be determined.
+    """
+    import re
+    find_open = re.compile("^\s*(\\[).*$")
+    find_string_chars = re.compile("['\"]")
+
+    if isinstance(node, ast.ListComp):
+      # Strangely, ast.ListComp returns the col_offset of the first token
+      # after the '[' token, which appears to be a bug. Work around it by
+      # explicitly finding the real start of the list comprehension.
+      line = node.lineno
+      col = node.col_offset
+      # Loop over lines.
+      while 1:
+        # Reverse the text and use a regular expression to search for
+        # whitespace.
+        text = self._lines[line - 1]
+        reversed_preceding_text = text[:col][::-1]
+        # First find if a [ can be found with only whitespace between it and
+        # col.
+        m = find_open.match(reversed_preceding_text)
+        if m:
+          new_col_offset = col - m.start(1) - 1
+          return line, new_col_offset
+        else:
+          if (reversed_preceding_text == "" or
+              reversed_preceding_text.isspace()):
+            line = line - 1
+            prev_line = self._lines[line - 1]
+            # TODO(aselle):
+            # This is poor comment detection, but it is good enough for
+            # cases where the comment does not contain string literal starting/
+            # ending characters. If ast gave us start and end locations of the
+            # ast nodes rather than just start, we could use string literal
+            # node ranges to filter out spurious #'s that appear in string
+            # literals.
+            comment_start = prev_line.find("#")
+            if comment_start == -1:
+              col = len(prev_line) - 1
+            elif find_string_chars.search(prev_line[comment_start:]) is None:
+              col = comment_start
+            else:
+              return None, None
+          else:
+            return None, None
+    # Most other nodes return proper locations (`With`, notably, does not),
+    # but it is not possible to use that in an argument.
+    return node.lineno, node.col_offset
+
   def visit_Call(self, node):  # pylint: disable=invalid-name
     """Handle visiting a call node in the AST.
 
@@ -342,11 +410,13 @@ class TensorFlowCallVisitor(ast.NodeVisitor):
       node: Current Node
     """
 
     ast.NodeVisitor.generic_visit(self, node)
 
     # Find a simple attribute name path e.g. "tf.foo.bar"
     full_name = self._get_attribute_full_path(node.func)
 
+    # Make sure the func is marked as being part of a call
+    node.func.is_function_for_call = True
 
     if full_name and full_name.startswith("tf."):
       # Call special handlers
       function_handles = self._api_change_spec.function_handle
 
@@ -356,27 +426,60 @@ class TensorFlowCallVisitor(ast.NodeVisitor):
       # Examine any non-keyword argument and make it into a keyword argument
       # if reordering required.
       function_reorders = self._api_change_spec.function_reorders
+      function_keyword_renames = (
+          self._api_change_spec.function_keyword_renames)
 
       if full_name in function_reorders:
         reordered = function_reorders[full_name]
         for idx, arg in enumerate(node.args):
-          self._file_edit.add("Added keyword %r to reordered function %r"
-                              % (reordered[idx], full_name), arg.lineno,
-                              arg.col_offset, "", reordered[idx] + "=")
+          lineno, col_offset = self._find_true_position(arg)
+          if lineno is None or col_offset is None:
+            self._file_edit.add(
+                "Failed to add keyword %r to reordered function %r"
+                % (reordered[idx], full_name), arg.lineno, arg.col_offset,
+                "", "",
+                error="A necessary keyword argument failed to be inserted.")
+          else:
+            keyword_arg = reordered[idx]
+            if (full_name in function_keyword_renames and
+                keyword_arg in function_keyword_renames[full_name]):
+              keyword_arg = function_keyword_renames[full_name][keyword_arg]
+            self._file_edit.add("Added keyword %r to reordered function %r"
                                % (reordered[idx], full_name), lineno,
+                                col_offset, "", keyword_arg + "=")
 
       # Examine each keyword argument and convert it to the final renamed form
-      function_keyword_renames = (
-          self._api_change_spec.function_keyword_renames)
       renamed_keywords = ({} if full_name not in function_keyword_renames else
                           function_keyword_renames[full_name])
       for keyword in node.keywords:
         argkey = keyword.arg
         argval = keyword.value
 
         if argkey in renamed_keywords:
-          self._file_edit.add("Renamed keyword argument from %r to %r" %
-                              (argkey, renamed_keywords[argkey]),
-                              argval.lineno,
-                              argval.col_offset - len(argkey) - 1,
-                              argkey + "=", renamed_keywords[argkey] + "=")
+          argval_lineno, argval_col_offset = self._find_true_position(argval)
+          if (argval_lineno is not None and argval_col_offset is not None):
+            # TODO(aselle): We should scan backward to find the start of the
+            # keyword key. Unfortunately ast does not give you the location of
+            # keyword keys, so we are forced to infer it from the keyword arg
+            # value.
+            key_start = argval_col_offset - len(argkey) - 1
+            key_end = key_start + len(argkey) + 1
+            if self._lines[argval_lineno - 1][key_start:key_end] == argkey + "=":
+              self._file_edit.add("Renamed keyword argument from %r to %r" %
+                                  (argkey, renamed_keywords[argkey]),
+                                  argval_lineno,
+                                  argval_col_offset - len(argkey) - 1,
+                                  argkey + "=", renamed_keywords[argkey] + "=")
+              continue
+          self._file_edit.add(
+              "Failed to rename keyword argument from %r to %r" %
+              (argkey, renamed_keywords[argkey]),
+              argval.lineno,
+              argval.col_offset - len(argkey) - 1,
+              "", "",
+              error="Failed to find keyword lexicographically. Fix manually.")
 
       ast.NodeVisitor.generic_visit(self, node)
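
As a standalone illustration (not part of the diff) of the column arithmetic used above, which infers where the keyword text starts from the position of its value:

```python
# For a call like tf.concat(concat_dim=a), ast reports the column of the
# value `a`; the keyword text "concat_dim=" is assumed to occupy the
# len(argkey) + 1 columns immediately before it.
line = "tf.concat(concat_dim=a)"
argkey = "concat_dim"
argval_col_offset = line.rindex("a")  # column of the value `a`
key_start = argval_col_offset - len(argkey) - 1
key_end = key_start + len(argkey) + 1
assert line[key_start:key_end] == argkey + "="
```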
 
   def visit_Attribute(self, node):  # pylint: disable=invalid-name
     """Handle bare Attributes i.e. [tf.foo, tf.bar].
@@ -387,6 +490,11 @@ class TensorFlowCallVisitor(ast.NodeVisitor):
     full_name = self._get_attribute_full_path(node)
     if full_name and full_name.startswith("tf."):
       self._rename_functions(node, full_name)
+      if full_name in self._api_change_spec.change_to_function:
+        if not hasattr(node, "is_function_for_call"):
+          new_text = full_name + "()"
+          self._file_edit.add("Changed %r to %r" % (full_name, new_text),
+                              node.lineno, node.col_offset, full_name, new_text)
 
     ast.NodeVisitor.generic_visit(self, node)
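
Concretely, the new branch rewrites a bare initializer reference into a call while leaving existing calls untouched; an illustration (not executable on its own):

```python
# Before: bare attribute reference, no call.
init = tf.zeros_initializer
# After the script: tf.zeros_initializer is now a function, so the bare
# reference becomes a call; a site already written as
# tf.zeros_initializer() carries the is_function_for_call mark and is
# left alone.
init = tf.zeros_initializer()
```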
 
@@ -59,12 +59,45 @@ class TestUpgrade(test_util.TensorFlowTestCase):
     _, unused_report, unused_errors, new_text = self._upgrade(text)
     self.assertEqual(new_text, "tf.multiply(a, tf.subtract(b, c))\n")
 
+  def testRenamePack(self):
+    text = "tf.pack(a)\n"
+    _, unused_report, unused_errors, new_text = self._upgrade(text)
+    self.assertEqual(new_text, "tf.stack(a)\n")
+    text = "tf.unpack(a)\n"
+    _, unused_report, unused_errors, new_text = self._upgrade(text)
+    self.assertEqual(new_text, "tf.unstack(a)\n")
+
   def testReorder(self):
     text = "tf.concat(a, b)\ntf.split(a, b, c)\n"
     _, unused_report, unused_errors, new_text = self._upgrade(text)
-    self.assertEqual(new_text, "tf.concat(concat_dim=a, values=b)\n"
+    self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n"
                      "tf.split(axis=a, num_or_size_splits=b, value=c)\n")
 
+  def testConcatReorderWithKeywordArgs(self):
+    text = "tf.concat(concat_dim=a, values=b)\n"
+    _, unused_report, unused_errors, new_text = self._upgrade(text)
+    self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n")
+    text = "tf.concat(values=b, concat_dim=a)\n"
+    _, unused_report, unused_errors, new_text = self._upgrade(text)
+    self.assertEqual(new_text, "tf.concat(values=b, axis=a)\n")
+    text = "tf.concat(a, values=b)\n"
+    _, unused_report, unused_errors, new_text = self._upgrade(text)
+    self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n")
+
+  def testConcatReorderNested(self):
+    text = "tf.concat(a, tf.concat(c, d))\n"
+    _, unused_report, unused_errors, new_text = self._upgrade(text)
+    self.assertEqual(
+        new_text, "tf.concat(axis=a, values=tf.concat(axis=c, values=d))\n")
+
+  def testInitializers(self):
+    text = ("tf.zeros_initializer;tf.zeros_initializer ()\n"
+            "tf.ones_initializer;tf.ones_initializer ()\n")
+    _, unused_report, unused_errors, new_text = self._upgrade(text)
+    self.assertEqual(
+        new_text, "tf.zeros_initializer();tf.zeros_initializer ()\n"
+        "tf.ones_initializer();tf.ones_initializer ()\n")
+
   def testKeyword(self):
     text = "tf.reduce_any(a, reduction_indices=[1, 2])\n"
     _, unused_report, unused_errors, new_text = self._upgrade(text)
 
@@ -80,6 +113,19 @@ class TestUpgrade(test_util.TensorFlowTestCase):
     self.assertEqual(text, new_text)
     self.assertEqual(errors, ["test.py:1: tf.reverse requires manual check."])
 
+  def testListComprehension(self):
+    def _test(input, output):
+      _, unused_report, errors, new_text = self._upgrade(input)
+      self.assertEqual(new_text, output)
+    _test("tf.concat(0, \t[x for x in y])\n",
+          "tf.concat(axis=0, \tvalues=[x for x in y])\n")
+    _test("tf.concat(0,[x for x in y])\n",
+          "tf.concat(axis=0,values=[x for x in y])\n")
+    _test("tf.concat(0,[\nx for x in y])\n",
+          "tf.concat(axis=0,values=[\nx for x in y])\n")
+    _test("tf.concat(0,[\n \tx for x in y])\n",
+          "tf.concat(axis=0,values=[\n \tx for x in y])\n")
+
 # TODO(aselle): Explicitly not testing command line interface and process_tree
 # for now, since this is a one off utility.
 
@@ -30,6 +30,7 @@ RUN pip --no-cache-dir install \
         numpy \
         scipy \
         sklearn \
+        pandas \
         Pillow \
         && \
         python -m ipykernel.kernelspec
 
@@ -32,6 +32,7 @@ RUN pip --no-cache-dir install \
         numpy \
         scipy \
         sklearn \
+        pandas \
        && \
         python -m ipykernel.kernelspec
 
@@ -82,7 +83,7 @@ RUN mkdir /bazel && \
 
 RUN git clone https://github.com/tensorflow/tensorflow.git && \
     cd tensorflow && \
-    git checkout r0.12
+    git checkout r1.0
 WORKDIR /tensorflow
 
 # TODO(craigcitro): Don't install the pip package, since it makes it
 
@@ -32,6 +32,7 @@ RUN pip --no-cache-dir install \
         numpy \
         scipy \
         sklearn \
+        pandas \
         && \
         python -m ipykernel.kernelspec
 
@@ -82,7 +83,7 @@ RUN mkdir /bazel && \
 
 RUN git clone https://github.com/tensorflow/tensorflow.git && \
     cd tensorflow && \
-    git checkout r0.12
+    git checkout r1.0
 WORKDIR /tensorflow
 
 # Configure the build for our CUDA configuration.
 
@@ -30,6 +30,7 @@ RUN pip --no-cache-dir install \
         numpy \
         scipy \
         sklearn \
+        pandas \
         Pillow \
         && \
         python -m ipykernel.kernelspec
@@ -58,6 +59,9 @@ COPY notebooks /notebooks
 # We just add a little wrapper script.
 COPY run_jupyter.sh /
 
+# For CUDA profiling, TensorFlow requires CUPTI.
+ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
+
 # TensorBoard
 EXPOSE 6006
 # IPython
 
@@ -328,7 +328,9 @@ def _generate_signature(func, reverse_index):
                           len(argspec.args or []) - len(argspec.defaults or []))
 
   # Python documentation skips `self` when printing method signatures.
-  first_arg = 1 if inspect.ismethod(func) and 'self' in argspec.args[:1] else 0
+  # Note we cannot test for ismethod here since unbound methods do not register
+  # as methods (in Python 3).
+  first_arg = 1 if 'self' in argspec.args[:1] else 0
 
   # Add all args without defaults.
   for arg in argspec.args[first_arg:first_arg_with_default]:
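
The new comment can be checked directly: in Python 3 a function looked up on a class is a plain function, not an unbound method:

```python
import inspect

class Example(object):
  def method(self):
    pass

print(inspect.ismethod(Example.method))    # False in Python 3
print(inspect.isfunction(Example.method))  # True
print(inspect.ismethod(Example().method))  # True: bound methods still register
```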
 
@@ -679,6 +681,15 @@ def generate_global_index(library_name, index, duplicate_of):
   for full_name, py_object in six.iteritems(index):
     if (inspect.ismodule(py_object) or inspect.isfunction(py_object) or
         inspect.isclass(py_object)):
+      # In Python 3, unbound methods are functions, so eliminate those.
+      if inspect.isfunction(py_object):
+        if full_name.count('.') == 0:
+          parent_name = ''
+        else:
+          parent_name = full_name[:full_name.rfind('.')]
+        if parent_name in index and inspect.isclass(index[parent_name]):
+          # Skip methods (=functions with class parents).
+          continue
       symbol_links.append((full_name,
                            _markdown_link(full_name, full_name,
                                           '.', duplicate_of)))
 
@@ -13,7 +13,7 @@ and [Rust](https://github.com/tensorflow/rust).
 The command:
 
 ```sh
-bazel build -c opt //tensorflow/tools/lib_package:libtensorflow
+bazel build --config opt //tensorflow/tools/lib_package:libtensorflow
 ```
 
 produces `bazel-bin/tensorflow/tools/lib_package/libtensorflow.tar.gz`, which
 
@@ -4,6 +4,7 @@
 package(default_visibility = ["//visibility:private"])
 
 load("//tensorflow:tensorflow.bzl", "transitive_hdrs")
+load("//third_party/mkl:build_defs.bzl", "if_mkl")
 load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps")
 
 # This returns a list of headers of all public header libraries (e.g.,
 
@@ -131,5 +132,5 @@ sh_binary(
             "//tensorflow/python/tools:all_files",
             "//tensorflow/tensorboard",
         ],
-    }),
+    }) + if_mkl(["//third_party/mkl:intel_binary_blob"]),
 )
 
@@ -89,6 +89,13 @@ function main() {
       bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/external \
       "${TMPDIR}/external"
     RUNFILES=bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles
+    # Copy MKL libs over so they can be loaded at runtime
+    if [ -d bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/_solib_k8/_U_S_Sthird_Uparty_Smkl_Cintel_Ubinary_Ublob___Uthird_Uparty_Smkl ]; then
+      mkdir "${TMPDIR}/_solib_k8"
+      cp -R \
+        bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/_solib_k8/_U_S_Sthird_Uparty_Smkl_Cintel_Ubinary_Ublob___Uthird_Uparty_Smkl \
+        "${TMPDIR}/_solib_k8"
+    fi
   else
     if [ -d bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/external ]; then
       # Old-style runfiles structure (--legacy_external_runfiles).
@@ -99,6 +106,13 @@ function main() {
       cp_external \
         bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/external \
         "${TMPDIR}/external"
+      # Copy MKL libs over so they can be loaded at runtime
+      if [ -d bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/_solib_k8/_U_S_Sthird_Uparty_Smkl_Cintel_Ubinary_Ublob___Uthird_Uparty_Smkl ]; then
+        mkdir "${TMPDIR}/_solib_k8"
+        cp -R \
+          bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/_solib_k8/_U_S_Sthird_Uparty_Smkl_Cintel_Ubinary_Ublob___Uthird_Uparty_Smkl \
+          "${TMPDIR}/_solib_k8"
+      fi
     else
       # New-style runfiles structure (--nolegacy_external_runfiles).
       cp -R \
@@ -109,6 +123,13 @@ function main() {
       cp_external \
         bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles \
         "${TMPDIR}/external"
+      # Copy MKL libs over so they can be loaded at runtime
+      if [ -d bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/_solib_k8/_U_S_Sthird_Uparty_Smkl_Cintel_Ubinary_Ublob___Uthird_Uparty_Smkl ]; then
+        mkdir "${TMPDIR}/_solib_k8"
+        cp -R \
+          bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/_solib_k8/_U_S_Sthird_Uparty_Smkl_Cintel_Ubinary_Ublob___Uthird_Uparty_Smkl \
+          "${TMPDIR}/_solib_k8"
+      fi
     fi
     RUNFILES=bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow
   fi
 
@@ -29,7 +29,7 @@ from setuptools.dist import Distribution
 # This version string is semver compatible, but incompatible with pip.
 # For pip, we will remove all '-' characters from this string, and use the
 # result for pip.
-_VERSION = '0.12.1'
+_VERSION = '1.0.0-rc1'
 
 REQUIRED_PACKAGES = [
     'numpy >= 1.11.0',
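
A small sketch of the pip-version derivation that the comment above describes (stripping the '-' from the semver string):

```python
_VERSION = '1.0.0-rc1'
# pip cannot parse the semver pre-release dash, so it is removed:
pip_version = _VERSION.replace('-', '')
assert pip_version == '1.0.0rc1'
```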
 
@@ -151,6 +151,7 @@ def find_files(pattern, root):
 
 
 matches = ['../' + x for x in find_files('*', 'external') if '.py' not in x]
+matches += ['../' + x for x in find_files('*', '_solib_k8') if '.py' not in x]
 
 if os.name == 'nt':
   EXTENSION_NAME = 'python/_pywrap_tensorflow.pyd'
 
@@ -98,7 +98,7 @@ TODO(xpan): Provide graph.pbtxt, model.ckpt, tfprof_log and run_meta download.
 
 ```shell
 # Build the tool.
-bazel build -c opt tensorflow/tools/tfprof/...
+bazel build --config opt tensorflow/tools/tfprof/...
 
 # Help information, including detail 'option' instructions.
 bazel-bin/tensorflow/tools/tfprof/tfprof help
 
@@ -78,11 +78,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
   native.new_http_archive(
       name = "libxsmm_archive",
       urls = [
-          "http://bazel-mirror.storage.googleapis.com/github.com/hfp/libxsmm/archive/1.6.6.tar.gz",
-          "https://github.com/hfp/libxsmm/archive/1.6.6.tar.gz",
+          "http://bazel-mirror.storage.googleapis.com/github.com/hfp/libxsmm/archive/1.7.tar.gz",
+          "https://github.com/hfp/libxsmm/archive/1.7.tar.gz",
       ],
-      sha256 = "7c048a48e17f7f14a475be7b83e6e941289e03debb42ce9e02a06353412f9f2a",
-      strip_prefix = "libxsmm-1.6.6",
+      sha256 = "2eea65624a697e74b939511cd2a686b4c957e90c99be168fe134d96771e811ad",
+      strip_prefix = "libxsmm-1.7",
       build_file = str(Label("//third_party:libxsmm.BUILD")),
   )
 
Some files were not shown because too many files have changed in this diff.