Merge changes from github.
Change: 142074581
This commit is contained in:
parent
811629aed4
commit
2e4869af1a
@ -33,10 +33,10 @@ and discussion.**
|
||||
|
||||
People who are a little more adventurous can also try our nightly binaries:
|
||||
|
||||
* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.12.0rc0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.12.0rc0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.12.0rc0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
|
||||
* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-0.12.0rc0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-0.12.0rc0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-0.12.0rc0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
|
||||
* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.12.0rc0-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.12.0rc0-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
|
||||
* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-0.12.0rc0-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-0.12.0rc0-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/))
|
||||
* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.12.0rc1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.12.0rc1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.12.0rc1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
|
||||
* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-0.12.0rc1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-0.12.0rc1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-0.12.0rc1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
|
||||
* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.12.0rc1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.12.0rc1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
|
||||
* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-0.12.0rc1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-0.12.0rc1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/))
|
||||
* [Android](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-android/TF_BUILD_CONTAINER_TYPE=ANDROID,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=NO_PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=android-slave/lastSuccessfulBuild/artifact/bazel-out/local_linux/bin/tensorflow/examples/android/tensorflow_demo.apk) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-android/TF_BUILD_CONTAINER_TYPE=ANDROID,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=NO_PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=android-slave/))
|
||||
|
||||
#### *Try your first TensorFlow program*
|
||||
|
@ -61,9 +61,7 @@
|
||||
acceleration. Known limitations include: It is not currently possible to load
|
||||
a custom op library. The GCS and HDFS file systems are not currently
|
||||
supported. The following ops are not currently implemented:
|
||||
DepthwiseConv2dNative, DepthwiseConv2dNativeBackpropFilter,
|
||||
DepthwiseConv2dNativeBackpropInput, Dequantize, Digamma, Erf, Erfc, Igamma,
|
||||
Igammac, Lgamma, Polygamma, QuantizeAndDequantize, QuantizedAvgPool,
|
||||
Dequantize, QuantizeAndDequantize, QuantizedAvgPool,
|
||||
QuantizedBatchNomWithGlobalNormalization, QuantizedBiasAdd, QuantizedConcat,
|
||||
QuantizedConv2D, QuantizedMatmul, QuantizedMaxPool,
|
||||
QuantizeDownAndShrinkRange, QuantizedRelu, QuantizedRelu6, QuantizedReshape,
|
||||
|
7
configure
vendored
7
configure
vendored
@ -24,7 +24,8 @@ function bazel_clean_and_fetch() {
|
||||
if ! is_windows; then
|
||||
bazel clean --expunge
|
||||
fi
|
||||
bazel fetch //tensorflow/...
|
||||
# TODO(https://github.com/bazelbuild/bazel/issues/2220) Remove the nested `bazel query`.
|
||||
bazel fetch $(bazel query "//tensorflow/... -//tensorflow/examples/android/...")
|
||||
}
|
||||
|
||||
## Set up python-related environment settings
|
||||
@ -279,7 +280,7 @@ while true; do
|
||||
TF_CUDNN_VERSION=${BASH_REMATCH[1]}
|
||||
echo "libcudnn.so resolves to libcudnn${TF_CUDNN_EXT}"
|
||||
elif [[ "$REALVAL" =~ ([0-9]*).dylib ]]; then
|
||||
TF_CUDNN_EXT=${BASH_REMATCH[1]}".dylib"
|
||||
TF_CUDNN_EXT="."${BASH_REMATCH[1]}".dylib"
|
||||
TF_CUDNN_VERSION=${BASH_REMATCH[1]}
|
||||
echo "libcudnn.dylib resolves to libcudnn${TF_CUDNN_EXT}"
|
||||
fi
|
||||
@ -435,7 +436,7 @@ while true; do
|
||||
# Point to ComputeCpp root
|
||||
if [ -z "$COMPUTECPP_TOOLKIT_PATH" ]; then
|
||||
default_computecpp_toolkit_path=/usr/local/computecpp
|
||||
read -p "Please specify the location where ComputeCpp $TF_OPENCL_VERSION is installed. Refer to README.md for more details. [Default is $default_computecpp_toolkit_path]: " COMPUTECPP_TOOLKIT_PATH
|
||||
read -p "Please specify the location where ComputeCpp for SYCL $TF_OPENCL_VERSION is installed. [Default is $default_computecpp_toolkit_path]: " COMPUTECPP_TOOLKIT_PATH
|
||||
fromuser="1"
|
||||
if [ -z "$COMPUTECPP_TOOLKIT_PATH" ]; then
|
||||
COMPUTECPP_TOOLKIT_PATH=$default_computecpp_toolkit_path
|
||||
|
@ -50,6 +50,12 @@ config_setting(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "no_tensorflow_py_deps",
|
||||
values = {"define": "no_tensorflow_py_deps=true"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "ios",
|
||||
values = {
|
||||
|
@ -13,12 +13,36 @@ Linux.
|
||||
Current Status
|
||||
--------------
|
||||
|
||||
CMake can be used to build TensorFlow on Windows. See the [getting started documentation](https://www.tensorflow.org/get_started/os_setup.html#pip-installation-on-windows)
|
||||
for instructions on how to install a pre-built TensorFlow package on Windows.
|
||||
|
||||
### Current known limitations
|
||||
* It is not possible to load a custom Op library.
|
||||
* GCS and HDFS file systems are not supported.
|
||||
* The following Ops are not currently implemented:
|
||||
- Dequantize
|
||||
- QuantizeAndDequantize
|
||||
- QuantizedAvgPool
|
||||
- QuantizedBatchNomWithGlobalNormalization
|
||||
- QuantizedBiasAdd
|
||||
- QuantizedConcat
|
||||
- QuantizedConv2D
|
||||
- QuantizedMatmul
|
||||
- QuantizedMaxPoo
|
||||
- QuantizeDownAndShrinkRange
|
||||
- QuantizedRelu
|
||||
- QuantizedRelu6
|
||||
- QuantizedReshape
|
||||
- QuantizeV2
|
||||
- RequantizationRange
|
||||
- Requantize
|
||||
|
||||
## Building with CMake
|
||||
|
||||
The CMake files in this directory can build the core TensorFlow runtime, an
|
||||
example C++ binary, and a PIP package containing the runtime and Python
|
||||
bindings.
|
||||
|
||||
Note: Windows support is in an **alpha** state, and we welcome your feedback.
|
||||
|
||||
### Pre-requisites
|
||||
|
||||
* CMake version 3.5 up to 3.6
|
||||
@ -46,6 +70,8 @@ Note: Windows support is in an **alpha** state, and we welcome your feedback.
|
||||
- [swigwin-3.0.10](http://www.swig.org/download.html)
|
||||
- [NVidia CUDA Toolkit 8.0] (https://developer.nvidia.com/cuda-downloads)
|
||||
- [NVidia CUDNN 5.1] (https://developer.nvidia.com/cudnn)
|
||||
- [CMake 3.6](https://cmake.org/files/v3.6/cmake-3.6.3-win64-x64.msi)
|
||||
|
||||
* Ubuntu 14.04
|
||||
- Makefile generator
|
||||
- Docker 1.9.1 (for automated testing)
|
||||
|
@ -26,7 +26,7 @@ from setuptools import find_packages, setup, Command
|
||||
from setuptools.command.install import install as InstallCommandBase
|
||||
from setuptools.dist import Distribution
|
||||
|
||||
_VERSION = '0.12.0-rc0-cmake-experimental'
|
||||
_VERSION = '0.12.0-rc1-cmake-experimental'
|
||||
|
||||
REQUIRED_PACKAGES = [
|
||||
'numpy >= 1.11.0',
|
||||
|
@ -62,7 +62,7 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_gradient_op.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/hybrid/core/ops/unpack_path_op.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.cc"
|
||||
)
|
||||
)
|
||||
list(APPEND tf_core_kernels_srcs ${tf_contrib_kernels_srcs})
|
||||
endif(tensorflow_BUILD_CONTRIB_KERNELS)
|
||||
|
||||
|
@ -48,7 +48,7 @@ GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contr
|
||||
GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc")
|
||||
GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc")
|
||||
GENERATE_CONTRIB_OP_LIBRARY(framework_variable "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc")
|
||||
|
||||
GENERATE_CONTRIB_OP_LIBRARY(tensor_forest "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/tensor_forest_ops.cc")
|
||||
|
||||
########################################################
|
||||
# tf_user_ops library
|
||||
|
@ -188,6 +188,7 @@ add_python_module("tensorflow/python/lib")
|
||||
add_python_module("tensorflow/python/lib/core")
|
||||
add_python_module("tensorflow/python/lib/io")
|
||||
add_python_module("tensorflow/python/ops")
|
||||
add_python_module("tensorflow/python/ops/losses")
|
||||
add_python_module("tensorflow/python/platform")
|
||||
add_python_module("tensorflow/python/platform/default")
|
||||
add_python_module("tensorflow/python/platform/summary")
|
||||
@ -220,6 +221,7 @@ add_python_module("tensorflow/contrib/bayesflow/examples/reinforce_simple")
|
||||
add_python_module("tensorflow/contrib/bayesflow/python")
|
||||
add_python_module("tensorflow/contrib/bayesflow/python/kernel_tests")
|
||||
add_python_module("tensorflow/contrib/bayesflow/python/ops")
|
||||
add_python_module("tensorflow/contrib/compiler")
|
||||
add_python_module("tensorflow/contrib/copy_graph")
|
||||
add_python_module("tensorflow/contrib/copy_graph/python")
|
||||
add_python_module("tensorflow/contrib/copy_graph/python/util")
|
||||
@ -261,6 +263,12 @@ add_python_module("tensorflow/contrib/grid_rnn")
|
||||
add_python_module("tensorflow/contrib/grid_rnn/python")
|
||||
add_python_module("tensorflow/contrib/grid_rnn/python/kernel_tests")
|
||||
add_python_module("tensorflow/contrib/grid_rnn/python/ops")
|
||||
add_python_module("tensorflow/contrib/image")
|
||||
add_python_module("tensorflow/contrib/image/python")
|
||||
add_python_module("tensorflow/contrib/image/python/ops")
|
||||
add_python_module("tensorflow/contrib/input_pipeline")
|
||||
add_python_module("tensorflow/contrib/input_pipeline/python")
|
||||
add_python_module("tensorflow/contrib/input_pipeline/python/ops")
|
||||
add_python_module("tensorflow/contrib/integrate")
|
||||
add_python_module("tensorflow/contrib/integrate/python")
|
||||
add_python_module("tensorflow/contrib/integrate/python/ops")
|
||||
@ -301,6 +309,7 @@ add_python_module("tensorflow/contrib/learn/python/learn/preprocessing/tests")
|
||||
add_python_module("tensorflow/contrib/learn/python/learn/tests")
|
||||
add_python_module("tensorflow/contrib/learn/python/learn/tests/dataframe")
|
||||
add_python_module("tensorflow/contrib/learn/python/learn/utils")
|
||||
add_python_module("tensorflow/contrib/legacy_seq2seq")
|
||||
add_python_module("tensorflow/contrib/linalg")
|
||||
add_python_module("tensorflow/contrib/linalg/python")
|
||||
add_python_module("tensorflow/contrib/linalg/python/ops")
|
||||
@ -392,6 +401,7 @@ add_python_module("tensorflow/contrib/training/python")
|
||||
add_python_module("tensorflow/contrib/training/python/training")
|
||||
add_python_module("tensorflow/contrib/util")
|
||||
|
||||
|
||||
# Additional directories with no Python sources.
|
||||
add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tensorboard/dist")
|
||||
@ -423,6 +433,7 @@ set(tf_python_op_lib_names
|
||||
)
|
||||
|
||||
function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name)
|
||||
set(options SHAPE_FUNCTIONS_NOT_REQUIRED)
|
||||
set(oneValueArgs DESTINATION)
|
||||
set(multiValueArgs ADDITIONAL_LIBRARIES)
|
||||
cmake_parse_arguments(GENERATE_PYTHON_OP_LIB
|
||||
@ -432,7 +443,12 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name)
|
||||
set(GENERATE_PYTHON_OP_LIB_DESTINATION
|
||||
"${python_ops_target_dir}/gen_${tf_python_op_lib_name}.py")
|
||||
endif()
|
||||
|
||||
if(GENERATE_PYTHON_OP_LIB_SHAPE_FUNCTIONS_NOT_REQUIRED)
|
||||
set(require_shape_fn 0)
|
||||
else()
|
||||
set(require_shape_fn 1)
|
||||
endif()
|
||||
|
||||
# Create a C++ executable that links in the appropriate op
|
||||
# registrations and generates Python wrapper code based on the
|
||||
# registered ops.
|
||||
@ -453,7 +469,7 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name)
|
||||
# containing the wrappers.
|
||||
add_custom_command(
|
||||
OUTPUT ${GENERATE_PYTHON_OP_LIB_DESTINATION}
|
||||
COMMAND ${tf_python_op_lib_name}_gen_python @${tensorflow_source_dir}/tensorflow/python/ops/hidden_ops.txt 1 > ${GENERATE_PYTHON_OP_LIB_DESTINATION}
|
||||
COMMAND ${tf_python_op_lib_name}_gen_python @${tensorflow_source_dir}/tensorflow/python/ops/hidden_ops.txt ${require_shape_fn} > ${GENERATE_PYTHON_OP_LIB_DESTINATION}
|
||||
DEPENDS ${tf_python_op_lib_name}_gen_python
|
||||
)
|
||||
|
||||
@ -496,6 +512,8 @@ GENERATE_PYTHON_OP_LIB("contrib_factorization_factorization_ops"
|
||||
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/factorization/python/ops/gen_factorization_ops.py)
|
||||
GENERATE_PYTHON_OP_LIB("contrib_framework_variable_ops"
|
||||
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/framework/python/ops/gen_variable_ops.py)
|
||||
GENERATE_PYTHON_OP_LIB("contrib_tensor_forest_ops"
|
||||
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/tensor_forest/python/ops/gen_tensor_forest_ops.py)
|
||||
|
||||
add_custom_target(tf_python_ops SOURCES ${tf_python_ops_generated_files} ${PYTHON_PROTO_GENFILES})
|
||||
add_dependencies(tf_python_ops tf_python_op_gen_main)
|
||||
|
@ -126,8 +126,9 @@ if (tensorflow_BUILD_PYTHON_TESTS)
|
||||
set(tf_test_src_py_exclude
|
||||
# generally not working
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/__init__.py"
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/benchmark_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/benchmark_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/resource_variable_ops_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/python/saved_model/saved_model_test.py"
|
||||
)
|
||||
if (WIN32)
|
||||
set(tf_test_src_py_exclude
|
||||
@ -144,7 +145,6 @@ if (tensorflow_BUILD_PYTHON_TESTS)
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/string_to_number_op_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/clip_ops_test.py"
|
||||
# misc
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/cwise_ops_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/variable_scope_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/reshape_op_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/tensorboard/backend/server_test.py"
|
||||
|
@ -532,7 +532,7 @@ class ConvolutionTest(tf.test.TestCase):
|
||||
self.assertListEqual(list(eval_output.shape), expected_size_dynamic)
|
||||
|
||||
def testDynamicOutputSizeWithRateOneValidPaddingNCHW(self):
|
||||
if tf.test.is_gpu_available():
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
num_filters = 32
|
||||
input_size = [5, 3, 9, 11]
|
||||
expected_size = [None, num_filters, None, None]
|
||||
@ -624,7 +624,7 @@ class Convolution2dTransposeTests(tf.test.TestCase):
|
||||
|
||||
def testOutputSizeWithStrideOneSamePaddingNCHW(self):
|
||||
# `NCHW` data fomat is only supported for `GPU` device.
|
||||
if tf.test.is_gpu_available():
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
num_filters = 32
|
||||
input_size = [5, 3, 10, 12]
|
||||
@ -1782,7 +1782,7 @@ class BatchNormTest(tf.test.TestCase):
|
||||
self._testNoneUpdatesCollections(False, data_format='NCHW')
|
||||
|
||||
def testNoneUpdatesCollectionsFusedNCHW(self):
|
||||
if tf.test.is_gpu_available():
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
self._testNoneUpdatesCollections(True, data_format='NCHW')
|
||||
|
||||
def testNoneUpdatesCollectionsFusedNHWC(self):
|
||||
@ -1849,7 +1849,7 @@ class BatchNormTest(tf.test.TestCase):
|
||||
self._testDelayedUpdateMovingVars(False, data_format='NCHW')
|
||||
|
||||
def testDelayedUpdateMovingVarsFusedNCHW(self):
|
||||
if tf.test.is_gpu_available():
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
self._testDelayedUpdateMovingVars(True, data_format='NCHW')
|
||||
|
||||
def testDelayedUpdateMovingVarsFusedNHWC(self):
|
||||
@ -2037,7 +2037,7 @@ class BatchNormTest(tf.test.TestCase):
|
||||
self._testIsTrainingVariable(False, data_format='NCHW')
|
||||
|
||||
def testIsTrainingVariableFusedNCHW(self):
|
||||
if tf.test.is_gpu_available():
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
self._testIsTrainingVariable(True, data_format='NCHW')
|
||||
|
||||
def testIsTrainingVariableFusedNHWC(self):
|
||||
@ -2176,7 +2176,7 @@ class BatchNormTest(tf.test.TestCase):
|
||||
self._testNoneUpdatesCollectionIsTrainingVariable(False, data_format='NCHW')
|
||||
|
||||
def testNoneUpdatesCollectionIsTrainingVariableFusedNCHW(self):
|
||||
if tf.test.is_gpu_available():
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
self._testNoneUpdatesCollectionIsTrainingVariable(
|
||||
True, data_format='NCHW')
|
||||
|
||||
@ -2255,7 +2255,7 @@ class BatchNormTest(tf.test.TestCase):
|
||||
self._testTrainMovingVars(False, data_format='NCHW')
|
||||
|
||||
def testTrainMovingVarsFusedNCHW(self):
|
||||
if tf.test.is_gpu_available():
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
self._testTrainMovingVars(True, data_format='NCHW')
|
||||
|
||||
def testTrainMovingVarsFusedNHWC(self):
|
||||
|
@ -90,7 +90,6 @@ def boston_eval_fn():
|
||||
return tf.concat_v2([features, features], 0), tf.concat_v2([labels, labels],
|
||||
0)
|
||||
|
||||
|
||||
def extract(data, key):
|
||||
if isinstance(data, dict):
|
||||
assert key in data
|
||||
|
@ -13,7 +13,8 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Experiment class collecting information needed for a single training run."""
|
||||
"""ExportStrategy class that provides strategies to export model so later it
|
||||
can be used for TensorFlow serving."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
@ -1,51 +0,0 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tools to work with checkpoints."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.framework import deprecated
|
||||
from tensorflow.contrib.framework.python.framework import checkpoint_utils
|
||||
|
||||
|
||||
@deprecated('2016-08-22', 'Please use tf.contrib.framework.load_checkpoint '
|
||||
'instead')
|
||||
def load_checkpoint(filepattern):
|
||||
"""See `tf.contrib.framework.load_checkpoint`."""
|
||||
return checkpoint_utils.load_checkpoint(filepattern)
|
||||
|
||||
|
||||
@deprecated('2016-08-22', 'Please use tf.contrib.framework.load_variable '
|
||||
'instead')
|
||||
def load_variable(checkpoint_dir, name):
|
||||
"""See `tf.contrib.framework.load_variable`."""
|
||||
return checkpoint_utils.load_variable(checkpoint_dir, name)
|
||||
|
||||
|
||||
@deprecated('2016-08-22', 'Please use tf.contrib.framework.list_variables '
|
||||
'instead')
|
||||
def list_variables(checkpoint_dir):
|
||||
"""See `tf.contrib.framework.list_variables`."""
|
||||
return checkpoint_utils.list_variables(checkpoint_dir)
|
||||
|
||||
|
||||
@deprecated('2016-08-22', 'Please use tf.contrib.framework.init_from_checkpoint'
|
||||
' instead')
|
||||
def init_from_checkpoint(checkpoint_dir, assignment_map):
|
||||
"""See `tf.contrib.framework.init_from_checkpoint`."""
|
||||
checkpoint_utils.init_from_checkpoint(checkpoint_dir, assignment_map)
|
@ -106,6 +106,7 @@ tensorflow/core/kernels/deep_conv2d.cc
|
||||
tensorflow/core/kernels/xsmm_conv2d.cc
|
||||
tensorflow/core/kernels/cwise_ops_common.cc
|
||||
tensorflow/core/kernels/cwise_op_tanh.cc
|
||||
tensorflow/core/kernels/cwise_op_pow.cc
|
||||
tensorflow/core/kernels/cwise_op_sub.cc
|
||||
tensorflow/core/kernels/cwise_op_squared_difference.cc
|
||||
tensorflow/core/kernels/cwise_op_square.cc
|
||||
|
@ -8,8 +8,6 @@ exports_files(["LICENSE"])
|
||||
|
||||
package(default_visibility = ["//tensorflow:__subpackages__"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_test")
|
||||
|
||||
py_library(
|
||||
name = "metrics_py",
|
||||
srcs = [
|
||||
|
@ -99,7 +99,7 @@ normal distribution, regularize it with an `l2_loss` and place it on the `CPU`,
|
||||
one need only declare the following:
|
||||
|
||||
```python
|
||||
weights = variables.variable('weights',
|
||||
weights = slim.variable('weights',
|
||||
shape=[10, 10, 3 , 3],
|
||||
initializer=tf.truncated_normal_initializer(stddev=0.1),
|
||||
regularizer=slim.l2_regularizer(0.05),
|
||||
@ -361,11 +361,11 @@ One can also nest `arg_scopes` and use multiple operations in the same scope.
|
||||
For example:
|
||||
|
||||
```python
|
||||
with slim.arg_scope([slim.conv2d, slim.fully_connected],
|
||||
with slim.arg_scope([slim.conv2d, slim.fully_connected],
|
||||
activation_fn=tf.nn.relu,
|
||||
weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
|
||||
weights_regularizer=slim.l2_regularizer(0.0005)):
|
||||
with arg_scope([slim.conv2d], stride=1, padding='SAME'):
|
||||
with slim.arg_scope([slim.conv2d], stride=1, padding='SAME'):
|
||||
net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', scope='conv1')
|
||||
net = slim.conv2d(net, 256, [5, 5],
|
||||
weights_initializer=tf.truncated_normal_initializer(stddev=0.03),
|
||||
@ -450,7 +450,7 @@ images, labels = ...
|
||||
predictions = vgg.vgg16(images)
|
||||
|
||||
# Define the loss functions and get the total loss.
|
||||
loss = losses.softmax_cross_entropy(predictions, labels)
|
||||
loss = slim.losses.softmax_cross_entropy(predictions, labels)
|
||||
```
|
||||
|
||||
In this example, we start by creating the model (using TF-Slim's VGG
|
||||
@ -477,7 +477,7 @@ total_loss = slim.losses.get_total_loss(add_regularization_losses=False)
|
||||
In this example, we have two losses which we add by calling
|
||||
`slim.losses.softmax_cross_entropy` and `slim.losses.sum_of_squares`. We can
|
||||
obtain the total loss by adding them together (`total_loss`) or by calling
|
||||
`slim.losses.GetTotalLoss()`. How did this work?
|
||||
`slim.losses.get_total_loss()`. How did this work?
|
||||
When you create a loss function via TF-Slim, TF-Slim adds the loss to a
|
||||
special TensorFlow collection of loss functions. This enables you to either
|
||||
manage the total loss manually, or allow TF-Slim to manage them for you.
|
||||
@ -566,11 +566,10 @@ vgg = tf.contrib.slim.nets.vgg
|
||||
...
|
||||
|
||||
train_log_dir = ...
|
||||
if not gfile.Exists(train_log_dir):
|
||||
gfile.MakeDirs(train_log_dir)
|
||||
if not tf.gfile.Exists(train_log_dir):
|
||||
tf.gfile.MakeDirs(train_log_dir)
|
||||
|
||||
g = tf.Graph()
|
||||
with g.as_default():
|
||||
with tf.Graph().as_default():
|
||||
# Set up the data loading:
|
||||
images, labels = ...
|
||||
|
||||
@ -581,7 +580,7 @@ with g.as_default():
|
||||
slim.losses.softmax_cross_entropy(predictions, labels)
|
||||
|
||||
total_loss = slim.losses.get_total_loss()
|
||||
tf.summary.scalar('losses/total loss', total_loss)
|
||||
tf.summary.scalar('losses/total_loss', total_loss)
|
||||
|
||||
# Specify the optimization scheme:
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=.001)
|
||||
@ -638,8 +637,8 @@ helper functions to select a subset of variables to restore:
|
||||
|
||||
```python
|
||||
# Create some variables.
|
||||
v1 = slim.variables.variable(name="v1", ...)
|
||||
v2 = slim.variables.variable(name="nested/v2", ...)
|
||||
v1 = slim.variable(name="v1", ...)
|
||||
v2 = slim.variable(name="nested/v2", ...)
|
||||
...
|
||||
|
||||
# Get list of variables to restore (which contains only 'v2'). These are all
|
||||
@ -748,7 +747,7 @@ We define a metric to be a performance measure that is not a loss function
|
||||
(losses are directly optimized during training), but which we are still
|
||||
interested in for the purpose of evaluating our model.
|
||||
For example, we might want to minimize log loss, but our metrics of interest
|
||||
might be F1 score, or Intersection Over Union score (which are not
|
||||
might be F1 score (test accuracy), or Intersection Over Union score (which are not
|
||||
differentiable, and therefore cannot be used as losses).
|
||||
|
||||
TF-Slim provides a set of metric operations that makes evaluating models
|
||||
@ -775,8 +774,8 @@ set (upon which the loss is computed), we'll assume we're using test data:
|
||||
images, labels = LoadTestData(...)
|
||||
predictions = MyModel(images)
|
||||
|
||||
mae_value_op, mae_update_op = slim.metrics.mean_absolute_error(predictions, labels)
|
||||
mre_value_op, mre_update_op = slim.metrics.mean_relative_error(predictions, labels, labels)
|
||||
mae_value_op, mae_update_op = slim.metrics.streaming_mean_absolute_error(predictions, labels)
|
||||
mre_value_op, mre_update_op = slim.metrics.streaming_mean_relative_error(predictions, labels, labels)
|
||||
pl_value_op, pl_update_op = slim.metrics.percentage_less(mean_relative_errors, 0.3)
|
||||
```
|
||||
|
||||
@ -793,13 +792,13 @@ this, TF-Slim provides two convenience functions:
|
||||
|
||||
# Aggregates the value and update ops in two lists:
|
||||
value_ops, update_ops = slim.metrics.aggregate_metrics(
|
||||
slim.metrics.mean_absolute_error(predictions, labels),
|
||||
slim.metrics.mean_squared_error(predictions, labels))
|
||||
slim.metrics.streaming_mean_absolute_error(predictions, labels),
|
||||
slim.metrics.streaming_mean_squared_error(predictions, labels))
|
||||
|
||||
# Aggregates the value and update ops in two dictionaries:
|
||||
names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
|
||||
"eval/mean_absolute_error": slim.metrics.mean_absolute_error(predictions, labels),
|
||||
"eval/mean_squared_error": slim.metrics.mean_squared_error(predictions, labels),
|
||||
"eval/mean_absolute_error": slim.metrics.streaming_mean_absolute_error(predictions, labels),
|
||||
"eval/mean_squared_error": slim.metrics.streaming_mean_squared_error(predictions, labels),
|
||||
})
|
||||
|
||||
```
|
||||
@ -823,8 +822,8 @@ predictions = vgg.vgg_16(images)
|
||||
|
||||
# Choose the metrics to compute:
|
||||
names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
|
||||
"eval/mean_absolute_error": slim.metrics.mean_absolute_error(predictions, labels),
|
||||
"eval/mean_squared_error": slim.metrics.mean_squared_error(predictions, labels),
|
||||
"eval/mean_absolute_error": slim.metrics.streaming_mean_absolute_error(predictions, labels),
|
||||
"eval/mean_squared_error": slim.metrics.streaming_mean_squared_error(predictions, labels),
|
||||
})
|
||||
|
||||
# Evaluate the model using 1000 batches of data:
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for nets.inception_v1."""
|
||||
"""Tests for nets.inception_v3."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
@ -2352,6 +2352,7 @@ filegroup(
|
||||
# GIF data with optimization
|
||||
"lib/gif/testdata/optimized.gif",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
|
@ -36,8 +36,10 @@ void SYCLAllocator::DeallocateRaw(void *ptr) {
|
||||
}
|
||||
|
||||
void SYCLAllocator::EnterLameDuckMode() {
|
||||
device_->deallocate_all();
|
||||
device_ = nullptr;
|
||||
if (device_) {
|
||||
device_->deallocate_all();
|
||||
device_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -22,7 +22,6 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#define EIGEN_USE_SYCL
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -23,11 +23,38 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
static std::unordered_set<SYCLDevice*> live_devices;
|
||||
|
||||
void ShutdownSycl() {
|
||||
for (auto device : live_devices) {
|
||||
device->EnterLameDuckMode();
|
||||
}
|
||||
live_devices.clear();
|
||||
}
|
||||
bool first_time = true;
|
||||
|
||||
void SYCLDevice::RegisterDevice() {
|
||||
if (first_time) {
|
||||
first_time = false;
|
||||
atexit(ShutdownSycl);
|
||||
}
|
||||
live_devices.insert(this);
|
||||
}
|
||||
|
||||
SYCLDevice::~SYCLDevice() {
|
||||
device_context_->Unref();
|
||||
sycl_allocator_->EnterLameDuckMode();
|
||||
delete sycl_device_;
|
||||
delete sycl_queue_;
|
||||
live_devices.erase(this);
|
||||
}
|
||||
|
||||
void SYCLDevice::EnterLameDuckMode() {
|
||||
sycl_allocator_->EnterLameDuckMode();
|
||||
delete sycl_device_;
|
||||
sycl_device_ = nullptr;
|
||||
delete sycl_queue_;
|
||||
sycl_queue_ = nullptr;
|
||||
}
|
||||
|
||||
void SYCLDevice::Compute(OpKernel *op_kernel, OpKernelContext *context) {
|
||||
@ -63,8 +90,8 @@ Status SYCLDevice::MakeTensorFromProto(const TensorProto &tensor_proto,
|
||||
Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
|
||||
device_context_->CopyCPUTensorToDevice(&parsed, this, ©,
|
||||
[&status](const Status &s) {
|
||||
status = s;
|
||||
});
|
||||
status = s;
|
||||
});
|
||||
*tensor = copy;
|
||||
}
|
||||
return status;
|
||||
|
@ -20,8 +20,6 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_DEVICE_H_
|
||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_DEVICE_H_
|
||||
|
||||
#define EIGEN_USE_SYCL
|
||||
|
||||
#include "tensorflow/core/common_runtime/local_device.h"
|
||||
#include "tensorflow/core/common_runtime/sycl/sycl_allocator.h"
|
||||
#include "tensorflow/core/common_runtime/sycl/sycl_device_context.h"
|
||||
@ -45,10 +43,13 @@ public:
|
||||
sycl_allocator_(new SYCLAllocator(sycl_queue_)),
|
||||
device_context_(new SYCLDeviceContext()) {
|
||||
set_eigen_sycl_device(sycl_device_);
|
||||
RegisterDevice();
|
||||
}
|
||||
|
||||
~SYCLDevice() override;
|
||||
|
||||
void EnterLameDuckMode();
|
||||
|
||||
void Compute(OpKernel *op_kernel, OpKernelContext *context) override;
|
||||
Allocator *GetAllocator(AllocatorAttributes attr) override;
|
||||
Status MakeTensorFromProto(const TensorProto &tensor_proto,
|
||||
@ -65,6 +66,8 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
void RegisterDevice();
|
||||
|
||||
Allocator *cpu_allocator_; // owned
|
||||
Eigen::QueueInterface* sycl_queue_; // owned
|
||||
Eigen::SyclDevice* sycl_device_; // owned
|
||||
|
@ -15,7 +15,6 @@ limitations under the License.
|
||||
|
||||
#if TENSORFLOW_USE_SYCL
|
||||
|
||||
#define EIGEN_USE_SYCL
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
#include "tensorflow/core/common_runtime/sycl/sycl_device_context.h"
|
||||
|
@ -438,7 +438,7 @@ static bool CopyIfNeeded(TensorProto* in, TensorProto* out) {
|
||||
} else {
|
||||
Tensor t(in->dtype());
|
||||
if (!t.FromProto(cpu_allocator(), *in)) return false;
|
||||
t.AsProtoField(out);
|
||||
t.AsProtoTensorContent(out);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -377,7 +377,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
||||
recv->set_key(key);
|
||||
// TODO(zhifengc): Deal with gpu -> cpu copy.
|
||||
TensorProto* proto = recv->mutable_val();
|
||||
val.AsProtoField(proto);
|
||||
val.AsProtoTensorContent(proto);
|
||||
}
|
||||
}
|
||||
delete collector;
|
||||
|
@ -28,8 +28,11 @@ message TensorProto {
|
||||
// to represent a constant Tensor with a single value.
|
||||
int32 version_number = 3;
|
||||
|
||||
// Serialized content from Tensor::AsProtoTensorContent(). This representation
|
||||
// can be used for all tensor types.
|
||||
// Serialized raw tensor content from either Tensor::AsProtoTensorContent or
|
||||
// memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation
|
||||
// can be used for all tensor types. The purpose of this representation is to
|
||||
// reduce serialization overhead during RPC call by avoiding serialization of
|
||||
// many repeated small items.
|
||||
bytes tensor_content = 4;
|
||||
|
||||
// Type specific representations that make it easy to create tensor protos in
|
||||
|
@ -57,19 +57,21 @@ class ArgOp : public OpKernel {
|
||||
const int32 dim = internal::SubtleMustCopy(dimension.scalar<int32>()());
|
||||
const int input_dims = input.dims();
|
||||
|
||||
OP_REQUIRES(context, dim >= 0, errors::InvalidArgument("dim must be >= 0"));
|
||||
OP_REQUIRES(context, dim < input_dims,
|
||||
errors::InvalidArgument("Minimum tensor rank: ", dim + 1,
|
||||
" but got: ", input_dims));
|
||||
int axis = dim < 0 ? dim + input_dims : dim;
|
||||
|
||||
OP_REQUIRES(context, axis >= 0 && axis < input_dims,
|
||||
errors::InvalidArgument(
|
||||
"Expected dimension in the range [", -input_dims, ", ",
|
||||
input_dims, "), but got ", dim));
|
||||
OP_REQUIRES(
|
||||
context, input.dim_size(dim) > 0,
|
||||
context, input.dim_size(axis) > 0,
|
||||
errors::InvalidArgument("Reduction axis ", dim, " is empty in shape ",
|
||||
input.shape().DebugString()));
|
||||
|
||||
TensorShape output_shape;
|
||||
const TensorShape& input_shape = input.shape();
|
||||
for (int d = 0; d < input_dims - 1; ++d) {
|
||||
output_shape.AddDim(input_shape.dim_size((d < dim) ? d : d + 1));
|
||||
output_shape.AddDim(input_shape.dim_size((d < axis) ? d : d + 1));
|
||||
}
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
|
||||
@ -77,7 +79,7 @@ class ArgOp : public OpKernel {
|
||||
#define HANDLE_DIM(NDIM) \
|
||||
case NDIM: \
|
||||
ArgFunctor::Reduce##NDIM(context->eigen_device<Device>(), \
|
||||
input.tensor<T, NDIM>(), dim, \
|
||||
input.tensor<T, NDIM>(), axis, \
|
||||
output->tensor<int64, NDIM - 1>()); \
|
||||
break;
|
||||
|
||||
|
@ -16,9 +16,6 @@ limitations under the License.
|
||||
// See docs in ../ops/array_ops.cc.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
#if TENSORFLOW_USE_SYCL
|
||||
#define EIGEN_USE_SYCL
|
||||
#endif
|
||||
|
||||
#include "tensorflow/core/kernels/constant_op.h"
|
||||
|
||||
@ -116,6 +113,9 @@ REGISTER_KERNEL_BUILDER(Name("Const")
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
typedef Eigen::SyclDevice SYCLDevice;
|
||||
#endif //TENSORFLOW_USE_SYCL
|
||||
|
||||
namespace functor {
|
||||
|
||||
@ -128,6 +128,17 @@ struct FillFunctor<CPUDevice, T> {
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
// Partial specialization of FillFunctor<Device=SYCLDevice, T>.
|
||||
template <typename T>
|
||||
struct FillFunctor<SYCLDevice, T> {
|
||||
void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
|
||||
typename TTypes<T>::ConstScalar in) {
|
||||
To32Bit(out).device(d) = To32Bit(out).constant(in());
|
||||
}
|
||||
};
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
} // end namespace functor
|
||||
|
||||
template <typename Device, typename T>
|
||||
@ -172,6 +183,17 @@ TF_CALL_ALL_TYPES(REGISTER_CPU_KERNEL);
|
||||
REGISTER_KERNEL(CPU, quint8);
|
||||
#undef REGISTER_CPU_KERNEL
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
REGISTER_KERNEL(SYCL, float)
|
||||
REGISTER_KERNEL_BUILDER(Name("Fill")
|
||||
.Device(DEVICE_SYCL)
|
||||
.TypeConstraint<int32>("T")
|
||||
.HostMemory("dims")
|
||||
.HostMemory("value")
|
||||
.HostMemory("output"),
|
||||
FillOp<CPUDevice, int32>);
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER_KERNEL(GPU, Eigen::half);
|
||||
REGISTER_KERNEL(GPU, float);
|
||||
@ -220,6 +242,15 @@ class ZerosLikeOp : public OpKernel {
|
||||
TF_CALL_POD_STRING_TYPES(REGISTER_CPU);
|
||||
#undef REGISTER_CPU
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
REGISTER_KERNEL(float, SYCL);
|
||||
REGISTER_KERNEL_BUILDER(Name("ZerosLike")
|
||||
.Device(DEVICE_SYCL)
|
||||
.TypeConstraint<int32>("T")
|
||||
.HostMemory("y"),
|
||||
ZerosLikeOp<CPUDevice, int32>);
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER_KERNEL(bool, GPU);
|
||||
REGISTER_KERNEL(Eigen::half, GPU);
|
||||
|
@ -16,12 +16,10 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/cwise_ops_common.h"
|
||||
|
||||
namespace tensorflow {
|
||||
#if EIGEN_HAS_C99_MATH
|
||||
REGISTER3(UnaryOp, CPU, "Digamma", functor::digamma, float, Eigen::half,
|
||||
double);
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER3(UnaryOp, GPU, "Digamma", functor::digamma, float, Eigen::half,
|
||||
double);
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // EIGEN_HAS_C99_MATH
|
||||
} // namespace tensorflow
|
||||
|
@ -37,6 +37,7 @@ REGISTER5(BinaryOp, CPU, "RealDiv", functor::div, float, Eigen::half, double,
|
||||
.TypeConstraint<TYPE>("T"), \
|
||||
BinaryOp<SYCLDevice, functor::div<TYPE>>);
|
||||
REGISTER_SYCL_KERNEL(float)
|
||||
REGISTER_SYCL_KERNEL(int32)
|
||||
#undef REGISTER_SYCL_KERNEL
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
#if GOOGLE_CUDA
|
||||
|
@ -16,10 +16,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/cwise_ops_common.h"
|
||||
|
||||
namespace tensorflow {
|
||||
#if EIGEN_HAS_C99_MATH
|
||||
REGISTER3(UnaryOp, CPU, "Erf", functor::erf, float, Eigen::half, double);
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER3(UnaryOp, GPU, "Erf", functor::erf, float, Eigen::half, double);
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // EIGEN_HAS_C99_MATH
|
||||
} // namespace tensorflow
|
||||
|
@ -16,10 +16,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/cwise_ops_common.h"
|
||||
|
||||
namespace tensorflow {
|
||||
#if EIGEN_HAS_C99_MATH
|
||||
REGISTER3(UnaryOp, CPU, "Erfc", functor::erfc, float, Eigen::half, double);
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER3(UnaryOp, GPU, "Erfc", functor::erfc, float, Eigen::half, double);
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // EIGEN_HAS_C99_MATH
|
||||
} // namespace tensorflow
|
||||
|
@ -16,8 +16,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/cwise_ops_common.h"
|
||||
|
||||
namespace tensorflow {
|
||||
#if EIGEN_HAS_C99_MATH
|
||||
REGISTER2(BinaryOp, CPU, "Igamma", functor::igamma, float, double);
|
||||
REGISTER2(BinaryOp, CPU, "Igammac", functor::igammac, float, double);
|
||||
#endif // EIGEN_HAS_C99_MATH
|
||||
} // namespace tensorflow
|
||||
|
@ -16,11 +16,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/cwise_ops_common.h"
|
||||
|
||||
namespace tensorflow {
|
||||
#if EIGEN_HAS_C99_MATH
|
||||
REGISTER3(UnaryOp, CPU, "Lgamma", functor::lgamma, float, Eigen::half, double);
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER3(UnaryOp, GPU, "Lgamma", functor::lgamma, float, Eigen::half, double);
|
||||
#endif
|
||||
#endif // EIGEN_HAS_C99_MATH
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -17,7 +17,5 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
REGISTER2(BinaryOp, CPU, "Zeta", functor::zeta, float, double);
|
||||
#if EIGEN_HAS_C99_MATH
|
||||
REGISTER2(BinaryOp, CPU, "Polygamma", functor::polygamma, float, double);
|
||||
#endif // EIGEN_HAS_C99_MATH
|
||||
} // namespace tensorflow
|
||||
|
@ -20,7 +20,6 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_SYCL_COMMON_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_SYCL_COMMON_H_
|
||||
|
||||
#define EIGEN_USE_SYCL
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
|
@ -14,9 +14,6 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
#if TENSORFLOW_USE_SYCL
|
||||
#define EIGEN_USE_SYCL
|
||||
#endif
|
||||
|
||||
#include "tensorflow/core/kernels/dense_update_ops.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
|
@ -165,6 +165,20 @@ class DynamicStitchOp : public OpKernel {
|
||||
TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH);
|
||||
#undef REGISTER_DYNAMIC_STITCH
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
#define REGISTER_DYNAMIC_STITCH_SYCL(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \
|
||||
.Device(DEVICE_SYCL) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.HostMemory("indices") \
|
||||
.HostMemory("data") \
|
||||
.HostMemory("merged"), \
|
||||
DynamicStitchOp<type>)
|
||||
|
||||
TF_CALL_ALL_TYPES(REGISTER_DYNAMIC_STITCH_SYCL);
|
||||
#undef REGISTER_DYNAMIC_STITCH_SYCL
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#define REGISTER_DYNAMIC_STITCH_GPU(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \
|
||||
|
@ -52,5 +52,18 @@ DEFINE_SETZERO_CPU(complex64);
|
||||
DEFINE_SETZERO_CPU(complex128);
|
||||
#undef DEFINE_SETZERO_CPU
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
template <typename T>
|
||||
void SetZeroFunctor<Eigen::SyclDevice, T>::operator()(
|
||||
const Eigen::SyclDevice& d, typename TTypes<T>::Flat out) {
|
||||
out.device(d) = out.constant(T(0));
|
||||
}
|
||||
|
||||
#define DEFINE_SETZERO_SYCL(T) \
|
||||
template struct SetZeroFunctor<Eigen::SyclDevice, T>;
|
||||
DEFINE_SETZERO_SYCL(float);
|
||||
#undef DEFINE_SETZERO_SYCL
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
@ -45,6 +45,15 @@ struct SetZeroFunctor<Eigen::ThreadPoolDevice, T> {
|
||||
typename TTypes<T>::Flat out);
|
||||
};
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
// Partial specialization of SetZeroFunctor<Device=Eigen::SyclDevice, T>.
|
||||
template <typename T>
|
||||
struct SetZeroFunctor<Eigen::SyclDevice, T> {
|
||||
void operator()(const Eigen::SyclDevice& d,
|
||||
typename TTypes<T>::Flat out);
|
||||
};
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
template <>
|
||||
struct SetZeroFunctor<Eigen::ThreadPoolDevice, string> {
|
||||
void operator()(const Eigen::ThreadPoolDevice& d,
|
||||
|
@ -93,7 +93,7 @@ REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_CPU), RetvalOp);
|
||||
Name("_Arg").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ArgOp);
|
||||
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
|
||||
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Arg")
|
||||
.Device(DEVICE_GPU)
|
||||
.Device(DEVICE_SYCL)
|
||||
.HostMemory("output")
|
||||
.TypeConstraint<int32>("T"),
|
||||
ArgOp);
|
||||
@ -104,7 +104,7 @@ TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Arg")
|
||||
RetvalOp);
|
||||
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
|
||||
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Retval")
|
||||
.Device(DEVICE_GPU)
|
||||
.Device(DEVICE_SYCL)
|
||||
.HostMemory("input")
|
||||
.TypeConstraint<int32>("T"),
|
||||
RetvalOp);
|
||||
@ -238,5 +238,9 @@ REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_CPU),
|
||||
SymbolicGradientOp);
|
||||
REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_GPU),
|
||||
SymbolicGradientOp);
|
||||
#if TENSORFLOW_USE_SYCL
|
||||
REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_SYCL),
|
||||
SymbolicGradientOp);
|
||||
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
} // namespace tensorflow
|
||||
|
@ -54,10 +54,29 @@ REGISTER_KERNEL_BUILDER(Name("RefIdentity").Device(DEVICE_CPU), IdentityOp);
|
||||
IdentityOp)
|
||||
|
||||
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
|
||||
REGISTER_SYCL_KERNEL(bfloat16);
|
||||
|
||||
#undef REGISTER_SYCL_KERNEL
|
||||
#endif
|
||||
|
||||
#define REGISTER_SYCL_HOST_KERNEL(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("Identity") \
|
||||
.Device(DEVICE_SYCL) \
|
||||
.HostMemory("input") \
|
||||
.HostMemory("output") \
|
||||
.TypeConstraint<type>("T"), \
|
||||
IdentityOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("RefIdentity") \
|
||||
.Device(DEVICE_SYCL) \
|
||||
.HostMemory("input") \
|
||||
.HostMemory("output") \
|
||||
.TypeConstraint<type>("T"), \
|
||||
IdentityOp)
|
||||
|
||||
REGISTER_SYCL_HOST_KERNEL(int32);
|
||||
REGISTER_SYCL_HOST_KERNEL(bool);
|
||||
|
||||
#undef REGISTER_SYCL_HOST_KERNEL
|
||||
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
#define REGISTER_GPU_KERNEL(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
|
@ -31,6 +31,27 @@ REGISTER_KERNEL_BUILDER(Name("Reshape").Device(DEVICE_CPU).HostMemory("shape"),
|
||||
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
|
||||
#undef REGISTER_GPU_KERNEL
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
#define REGISTER_SYCL_KERNEL(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("Reshape") \
|
||||
.Device(DEVICE_SYCL) \
|
||||
.HostMemory("shape") \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int32>("Tshape"), \
|
||||
ReshapeOp);
|
||||
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
|
||||
#undef REGISTER_SYCL_KERNEL
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("Reshape")
|
||||
.Device(DEVICE_SYCL)
|
||||
.HostMemory("tensor")
|
||||
.HostMemory("shape")
|
||||
.HostMemory("output")
|
||||
.TypeConstraint<int32>("T")
|
||||
.TypeConstraint<int32>("Tshape"),
|
||||
ReshapeOp);
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
// A special GPU kernel for int32.
|
||||
// TODO(b/25387198): Also enable int32 in device memory. This kernel
|
||||
|
@ -89,6 +89,12 @@ class RangeOp : public OpKernel {
|
||||
|
||||
#define REGISTER_CPU_KERNEL(T) REGISTER_KERNEL(DEVICE_CPU, T)
|
||||
#define REGISTER_GPU_KERNEL(T) REGISTER_KERNEL(DEVICE_GPU, T)
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
#define REGISTER_SYCL_KERNEL(T) REGISTER_KERNEL(DEVICE_SYCL, T)
|
||||
TF_CALL_float(REGISTER_SYCL_KERNEL);
|
||||
TF_CALL_int32(REGISTER_SYCL_KERNEL);
|
||||
TF_CALL_int64(REGISTER_SYCL_KERNEL);
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
TF_CALL_float(REGISTER_CPU_KERNEL);
|
||||
TF_CALL_double(REGISTER_CPU_KERNEL);
|
||||
|
@ -63,6 +63,40 @@ REGISTER_KERNEL_BUILDER(Name("Shape")
|
||||
.TypeConstraint<int64>("out_type"),
|
||||
ShapeOp<int64>);
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
#define REGISTER_SYCL_KERNEL(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("Shape") \
|
||||
.Device(DEVICE_SYCL) \
|
||||
.HostMemory("output") \
|
||||
.TypeConstraint<int32>("out_type") \
|
||||
.TypeConstraint<type>("T"), \
|
||||
ShapeOp<int32>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("Shape") \
|
||||
.Device(DEVICE_SYCL) \
|
||||
.HostMemory("output") \
|
||||
.TypeConstraint<int64>("out_type") \
|
||||
.TypeConstraint<type>("T"), \
|
||||
ShapeOp<int64>);
|
||||
|
||||
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
|
||||
#undef REGISTER_SYCL_KERNEL
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("Shape")
|
||||
.Device(DEVICE_SYCL)
|
||||
.HostMemory("input")
|
||||
.HostMemory("output")
|
||||
.TypeConstraint<int32>("T")
|
||||
.TypeConstraint<int32>("out_type"),
|
||||
ShapeOp<int32>);
|
||||
REGISTER_KERNEL_BUILDER(Name("Shape")
|
||||
.Device(DEVICE_SYCL)
|
||||
.HostMemory("input")
|
||||
.HostMemory("output")
|
||||
.TypeConstraint<int32>("T")
|
||||
.TypeConstraint<int64>("out_type"),
|
||||
ShapeOp<int64>);
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#define REGISTER_GPU_KERNEL(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("Shape") \
|
||||
@ -193,6 +227,34 @@ class RankOp : public OpKernel {
|
||||
REGISTER_KERNEL_BUILDER(Name("Rank").Device(DEVICE_CPU).HostMemory("output"),
|
||||
RankOp);
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
#define REGISTER_SYCL_KERNEL(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("Rank") \
|
||||
.Device(DEVICE_SYCL) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.HostMemory("output"), \
|
||||
RankOp);
|
||||
REGISTER_SYCL_KERNEL(float);
|
||||
#undef REGISTER_SYCL_KERNEL
|
||||
|
||||
// A special GPU kernel for int32 and bool.
|
||||
// TODO(b/25387198): Also enable int32 in device memory. This kernel
|
||||
// registration requires all int32 inputs and outputs to be in host memory.
|
||||
REGISTER_KERNEL_BUILDER(Name("Rank")
|
||||
.Device(DEVICE_SYCL)
|
||||
.TypeConstraint<int32>("T")
|
||||
.HostMemory("input")
|
||||
.HostMemory("output"),
|
||||
RankOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("Rank")
|
||||
.Device(DEVICE_SYCL)
|
||||
.TypeConstraint<bool>("T")
|
||||
.HostMemory("input")
|
||||
.HostMemory("output"),
|
||||
RankOp);
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#define REGISTER_GPU_KERNEL(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("Rank") \
|
||||
|
@ -25,6 +25,7 @@ namespace tensorflow {
|
||||
|
||||
using CPUDevice = Eigen::ThreadPoolDevice;
|
||||
using GPUDevice = Eigen::GpuDevice;
|
||||
using SYCLDevice = Eigen::SyclDevice;
|
||||
|
||||
namespace {
|
||||
template <class T>
|
||||
@ -220,9 +221,9 @@ struct ApplyMomentum<CPUDevice, T> {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ApplyAdam<CPUDevice, T> {
|
||||
void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
|
||||
template <typename Device, typename T>
|
||||
struct ApplyAdamNonCuda {
|
||||
void operator()(const Device& d, typename TTypes<T>::Flat var,
|
||||
typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
|
||||
typename TTypes<T>::ConstScalar beta1_power,
|
||||
typename TTypes<T>::ConstScalar beta2_power,
|
||||
@ -239,6 +240,12 @@ struct ApplyAdam<CPUDevice, T> {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ApplyAdam<CPUDevice, T> : ApplyAdamNonCuda<CPUDevice, T> {};
|
||||
template <typename T>
|
||||
struct ApplyAdam<SYCLDevice, T> : ApplyAdamNonCuda<SYCLDevice, T> {};
|
||||
|
||||
|
||||
template <typename T>
|
||||
struct ApplyRMSProp<CPUDevice, T> {
|
||||
void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
|
||||
@ -2139,6 +2146,12 @@ TF_CALL_half(REGISTER_CPU_KERNELS);
|
||||
TF_CALL_float(REGISTER_CPU_KERNELS);
|
||||
TF_CALL_double(REGISTER_CPU_KERNELS);
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
#define REGISTER_SYCL_KERNELS(T) REGISTER_KERNELS(SYCL, T);
|
||||
|
||||
TF_CALL_float(REGISTER_SYCL_KERNELS);
|
||||
#endif
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
// Forward declarations of the functor specializations for GPU.
|
||||
namespace functor {
|
||||
|
@ -388,6 +388,9 @@ class TestOp : public OpKernel {
|
||||
void Compute(OpKernelContext* ctx) override { ctx->set_output(0, Tensor()); }
|
||||
};
|
||||
REGISTER_KERNEL_BUILDER(Name("TestOpWithNoGrad").Device(DEVICE_CPU), TestOp);
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
REGISTER_KERNEL_BUILDER(Name("TestOpWithNoGrad").Device(DEVICE_SYCL), TestOp);
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
TEST_F(MathGradTest, Error_Reporting) {
|
||||
auto x = test::AsTensor<float>({-3.f});
|
||||
|
@ -1268,17 +1268,18 @@ Status ArgOpShape(shape_inference::InferenceContext* c) {
|
||||
dimension_val = dim_t->scalar<int64>()();
|
||||
}
|
||||
|
||||
if (dimension_val < 0 || dimension_val >= input_rank) {
|
||||
int64 axis = dimension_val < 0 ? dimension_val + input_rank : dimension_val;
|
||||
if (axis < 0 || axis >= input_rank) {
|
||||
return errors::InvalidArgument("Dimension (", dimension_val,
|
||||
") must be in the range [0, ", input_rank,
|
||||
"), where ", input_rank, " is the ",
|
||||
"number of dimensions in the input.");
|
||||
") must be in the range [", -input_rank,
|
||||
", ", input_rank, "), where ", input_rank,
|
||||
" is the number of dimensions in the input.");
|
||||
}
|
||||
|
||||
// Return the input shape without the dimension being reduced.
|
||||
std::vector<DimensionHandle> dims;
|
||||
for (int i = 0; i < input_rank; ++i) {
|
||||
if (dimension_val != i) {
|
||||
if (axis != i) {
|
||||
dims.emplace_back(c->Dim(input_shape, i));
|
||||
}
|
||||
}
|
||||
|
@ -450,11 +450,15 @@ TEST(MathOpsTest, ArgOps_ShapeFn) {
|
||||
// Dimension value out of bounds
|
||||
dimension = test::AsScalar(10);
|
||||
op.input_tensors[1] = &dimension;
|
||||
INFER_ERROR("must be in the range [0, 3)", op, "[2,3,4];[]");
|
||||
INFER_ERROR("must be in the range [-3, 3)", op, "[2,3,4];[]");
|
||||
|
||||
dimension = test::AsScalar(-10);
|
||||
op.input_tensors[1] = &dimension;
|
||||
INFER_ERROR("must be in the range [0, 3)", op, "[2,3,4];[]");
|
||||
INFER_ERROR("must be in the range [-3, 3)", op, "[2,3,4];[]");
|
||||
|
||||
dimension = test::AsScalar(-1);
|
||||
op.input_tensors[1] = &dimension;
|
||||
INFER_OK(op, "[2,3,4];[]", "[d0_0,d0_1]");
|
||||
}
|
||||
|
||||
TEST(MathOpsTest, Betainc_ShapeFn) {
|
||||
|
@ -16,6 +16,10 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_PLATFORM_CPU_INFO_H_
|
||||
#define TENSORFLOW_PLATFORM_CPU_INFO_H_
|
||||
|
||||
#if defined(PLATFORM_WINDOWS)
|
||||
#include "tensorflow/core/platform/windows/cpu_info.h"
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
namespace port {
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -13,19 +13,14 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// SWIG test helper for lib::tensorflow::Status
|
||||
#ifndef TENSORFLOW_PLATFORM_WINDOWS_CPU_INFO_H_
|
||||
#define TENSORFLOW_PLATFORM_WINDOWS_CPU_INFO_H_
|
||||
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
%import(module="tensorflow.python.pywrap_tensorflow") "tensorflow/python/lib/core/status.i"
|
||||
// Byte order defines provided by gcc. MSVC doesn't define those so
|
||||
// we define them here.
|
||||
// We assume that all windows platform out there are little endian.
|
||||
#define __ORDER_LITTLE_ENDIAN__ 0x4d2
|
||||
#define __ORDER_BIG_ENDIAN__ 0x10e1
|
||||
#define __BYTE_ORDER__ __ORDER_LITTLE_ENDIAN__
|
||||
|
||||
%inline %{
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
tensorflow::Status NotOkay() {
|
||||
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT, "Testing 1 2 3");
|
||||
}
|
||||
|
||||
tensorflow::Status Okay() {
|
||||
return tensorflow::Status();
|
||||
}
|
||||
%}
|
||||
#endif // TENSORFLOW_PLATFORM_WINDOWS_CPU_INFO_H_
|
@ -21,6 +21,7 @@ import os
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec
|
||||
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
|
||||
@ -65,6 +66,26 @@ def main(unused_argv):
|
||||
# Specify that all features have real-value data
|
||||
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
|
||||
|
||||
validation_metrics = {
|
||||
"accuracy": MetricSpec(
|
||||
metric_fn=tf.contrib.metrics.streaming_accuracy,
|
||||
prediction_key="classes"),
|
||||
"recall": MetricSpec(
|
||||
metric_fn=tf.contrib.metrics.streaming_recall,
|
||||
prediction_key="classes"),
|
||||
"precision": MetricSpec(
|
||||
metric_fn=tf.contrib.metrics.streaming_precision,
|
||||
prediction_key="classes")
|
||||
}
|
||||
validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
|
||||
test_set.data,
|
||||
test_set.target,
|
||||
every_n_steps=50,
|
||||
metrics=validation_metrics,
|
||||
early_stopping_metric="loss",
|
||||
early_stopping_metric_minimize=True,
|
||||
early_stopping_rounds=200)
|
||||
|
||||
# Build 3 layer DNN with 10, 20, 10 units respectively.
|
||||
classifier = tf.contrib.learn.DNNClassifier(
|
||||
feature_columns=feature_columns,
|
||||
|
@ -186,13 +186,13 @@ Renames file src to target. If target already exists, it will be replaced.
|
||||
|
||||
#### `virtual uint64 tensorflow::Env::NowMicros()=0` {#virtual_uint64_tensorflow_Env_NowMicros}
|
||||
|
||||
Returns the number of micro-seconds since some fixed point in time. Only useful for computing deltas of time.
|
||||
Returns the number of micro-seconds since the Unix epoch.
|
||||
|
||||
|
||||
|
||||
#### `virtual uint64 tensorflow::Env::NowSeconds()` {#virtual_uint64_tensorflow_Env_NowSeconds}
|
||||
|
||||
Returns the number of seconds since some fixed point in time. Only useful for computing deltas of time.
|
||||
Returns the number of seconds since the Unix epoch.
|
||||
|
||||
|
||||
|
||||
|
@ -50,7 +50,7 @@ Returns true if the path matches the given pattern. The wildcards allowed in pat
|
||||
|
||||
#### `uint64 tensorflow::EnvWrapper::NowMicros() override` {#uint64_tensorflow_EnvWrapper_NowMicros}
|
||||
|
||||
Returns the number of micro-seconds since some fixed point in time. Only useful for computing deltas of time.
|
||||
Returns the number of micro-seconds since the Unix epoch.
|
||||
|
||||
|
||||
|
||||
|
@ -14,7 +14,6 @@ python/sparse_ops.md
|
||||
python/io_ops.md
|
||||
python/python_io.md
|
||||
python/nn.md
|
||||
python/rnn_cell.md
|
||||
python/client.md
|
||||
python/train.md
|
||||
python/histogram_ops.md
|
||||
|
@ -1,4 +1,185 @@
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.ByteSize()` {#TaggedRunMetadata.ByteSize}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.Clear()` {#TaggedRunMetadata.Clear}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.ClearExtension(extension_handle)` {#TaggedRunMetadata.ClearExtension}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.ClearField(field_name)` {#TaggedRunMetadata.ClearField}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.CopyFrom(other_msg)` {#TaggedRunMetadata.CopyFrom}
|
||||
|
||||
Copies the content of the specified message into the current message.
|
||||
|
||||
The method clears the current message and then merges the specified
|
||||
message using MergeFrom.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`other_msg`</b>: Message to copy into the current one.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.DiscardUnknownFields()` {#TaggedRunMetadata.DiscardUnknownFields}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.FindInitializationErrors()` {#TaggedRunMetadata.FindInitializationErrors}
|
||||
|
||||
Finds required fields which are not initialized.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A list of strings. Each string is a path to an uninitialized field from
|
||||
the top-level message, e.g. "foo.bar[5].baz".
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.FromString(s)` {#TaggedRunMetadata.FromString}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.HasExtension(extension_handle)` {#TaggedRunMetadata.HasExtension}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.HasField(field_name)` {#TaggedRunMetadata.HasField}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.IsInitialized(errors=None)` {#TaggedRunMetadata.IsInitialized}
|
||||
|
||||
Checks if all required fields of a message are set.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`errors`</b>: A list which, if provided, will be populated with the field
|
||||
paths of all missing required fields.
|
||||
|
||||
##### Returns:
|
||||
|
||||
True iff the specified message has all required fields set.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.ListFields()` {#TaggedRunMetadata.ListFields}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.MergeFrom(msg)` {#TaggedRunMetadata.MergeFrom}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.MergeFromString(serialized)` {#TaggedRunMetadata.MergeFromString}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.ParseFromString(serialized)` {#TaggedRunMetadata.ParseFromString}
|
||||
|
||||
Parse serialized protocol buffer data into this message.
|
||||
|
||||
Like MergeFromString(), except we clear the object first and
|
||||
do not return the value that MergeFromString returns.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.RegisterExtension(extension_handle)` {#TaggedRunMetadata.RegisterExtension}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.SerializePartialToString()` {#TaggedRunMetadata.SerializePartialToString}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.SerializeToString()` {#TaggedRunMetadata.SerializeToString}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.SetInParent()` {#TaggedRunMetadata.SetInParent}
|
||||
|
||||
Sets the _cached_byte_size_dirty bit to true,
|
||||
and propagates this to our listener iff this was a state change.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.WhichOneof(oneof_name)` {#TaggedRunMetadata.WhichOneof}
|
||||
|
||||
Returns the name of the currently set field inside a oneof, or None.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.__deepcopy__(memo=None)` {#TaggedRunMetadata.__deepcopy__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.__eq__(other)` {#TaggedRunMetadata.__eq__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.__getstate__()` {#TaggedRunMetadata.__getstate__}
|
||||
@ -6,3 +187,66 @@
|
||||
Support the pickle protocol.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.__hash__()` {#TaggedRunMetadata.__hash__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.__init__(**kwargs)` {#TaggedRunMetadata.__init__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.__ne__(other_msg)` {#TaggedRunMetadata.__ne__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.__repr__()` {#TaggedRunMetadata.__repr__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.__setstate__(state)` {#TaggedRunMetadata.__setstate__}
|
||||
|
||||
Support the pickle protocol.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.__str__()` {#TaggedRunMetadata.__str__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.__unicode__()` {#TaggedRunMetadata.__unicode__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.run_metadata` {#TaggedRunMetadata.run_metadata}
|
||||
|
||||
Magic attribute generated for "run_metadata" proto field.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.tag` {#TaggedRunMetadata.tag}
|
||||
|
||||
Magic attribute generated for "tag" proto field.
|
||||
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
See `tf.global_variables`. (deprecated)
|
||||
|
||||
THIS FUNCTION IS DEPRECATED. It will be removed after 2016-03-02.
|
||||
THIS FUNCTION IS DEPRECATED. It will be removed after 2017-03-02.
|
||||
Instructions for updating:
|
||||
Please use tf.global_variables instead.
|
||||
|
||||
|
@ -0,0 +1,17 @@
|
||||
### `tf.merge_all_summaries(*args, **kwargs)` {#merge_all_summaries}
|
||||
|
||||
Merges all summaries collected in the default graph. (deprecated)
|
||||
|
||||
THIS FUNCTION IS DEPRECATED. It will be removed after 2016-11-30.
|
||||
Instructions for updating:
|
||||
Please switch to tf.summary.merge_all.
|
||||
|
||||
Args:
|
||||
key: `GraphKey` used to collect the summaries. Defaults to
|
||||
`GraphKeys.SUMMARIES`.
|
||||
|
||||
Returns:
|
||||
If no summaries were collected, returns None. Otherwise returns a scalar
|
||||
`Tensor` of type `string` containing the serialized `Summary` protocol
|
||||
buffer resulting from the merging.
|
||||
|
@ -0,0 +1,49 @@
|
||||
### `tf.image_summary(*args, **kwargs)` {#image_summary}
|
||||
|
||||
Outputs a `Summary` protocol buffer with images. (deprecated)
|
||||
|
||||
THIS FUNCTION IS DEPRECATED. It will be removed after 2016-11-30.
|
||||
Instructions for updating:
|
||||
Please switch to tf.summary.image. Note that tf.summary.histogram uses the node name instead of the tag. This means that TensorFlow will automatically de-duplicate summary names based on the scope they are created in. Also, the max_images argument was renamed to max_outputs.
|
||||
|
||||
The summary has up to `max_images` summary values containing images. The
|
||||
images are built from `tensor` which must be 4-D with shape `[batch_size,
|
||||
height, width, channels]` and where `channels` can be:
|
||||
|
||||
* 1: `tensor` is interpreted as Grayscale.
|
||||
* 3: `tensor` is interpreted as RGB.
|
||||
* 4: `tensor` is interpreted as RGBA.
|
||||
|
||||
The images have the same number of channels as the input tensor. For float
|
||||
input, the values are normalized one image at a time to fit in the range
|
||||
`[0, 255]`. `uint8` values are unchanged. The op uses two different
|
||||
normalization algorithms:
|
||||
|
||||
* If the input values are all positive, they are rescaled so the largest one
|
||||
is 255.
|
||||
|
||||
* If any input value is negative, the values are shifted so input value 0.0
|
||||
is at 127. They are then rescaled so that either the smallest value is 0,
|
||||
or the largest one is 255.
|
||||
|
||||
The `tag` argument is a scalar `Tensor` of type `string`. It is used to
|
||||
build the `tag` of the summary values:
|
||||
|
||||
* If `max_images` is 1, the summary value tag is '*tag*/image'.
|
||||
* If `max_images` is greater than 1, the summary value tags are
|
||||
generated sequentially as '*tag*/image/0', '*tag*/image/1', etc.
|
||||
|
||||
Args:
|
||||
tag: A scalar `Tensor` of type `string`. Used to build the `tag`
|
||||
of the summary values.
|
||||
tensor: A 4-D `uint8` or `float32` `Tensor` of shape `[batch_size, height,
|
||||
width, channels]` where `channels` is 1, 3, or 4.
|
||||
max_images: Max number of batch elements to generate images for.
|
||||
collections: Optional list of ops.GraphKeys. The collections to add the
|
||||
summary to. Defaults to [ops.GraphKeys.SUMMARIES]
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A scalar `Tensor` of type `string`. The serialized `Summary` protocol
|
||||
buffer.
|
||||
|
@ -1,4 +1,185 @@
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.ByteSize()` {#SummaryDescription.ByteSize}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.Clear()` {#SummaryDescription.Clear}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.ClearExtension(extension_handle)` {#SummaryDescription.ClearExtension}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.ClearField(field_name)` {#SummaryDescription.ClearField}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.CopyFrom(other_msg)` {#SummaryDescription.CopyFrom}
|
||||
|
||||
Copies the content of the specified message into the current message.
|
||||
|
||||
The method clears the current message and then merges the specified
|
||||
message using MergeFrom.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`other_msg`</b>: Message to copy into the current one.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.DiscardUnknownFields()` {#SummaryDescription.DiscardUnknownFields}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.FindInitializationErrors()` {#SummaryDescription.FindInitializationErrors}
|
||||
|
||||
Finds required fields which are not initialized.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A list of strings. Each string is a path to an uninitialized field from
|
||||
the top-level message, e.g. "foo.bar[5].baz".
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.FromString(s)` {#SummaryDescription.FromString}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.HasExtension(extension_handle)` {#SummaryDescription.HasExtension}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.HasField(field_name)` {#SummaryDescription.HasField}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.IsInitialized(errors=None)` {#SummaryDescription.IsInitialized}
|
||||
|
||||
Checks if all required fields of a message are set.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`errors`</b>: A list which, if provided, will be populated with the field
|
||||
paths of all missing required fields.
|
||||
|
||||
##### Returns:
|
||||
|
||||
True iff the specified message has all required fields set.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.ListFields()` {#SummaryDescription.ListFields}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.MergeFrom(msg)` {#SummaryDescription.MergeFrom}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.MergeFromString(serialized)` {#SummaryDescription.MergeFromString}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.ParseFromString(serialized)` {#SummaryDescription.ParseFromString}
|
||||
|
||||
Parse serialized protocol buffer data into this message.
|
||||
|
||||
Like MergeFromString(), except we clear the object first and
|
||||
do not return the value that MergeFromString returns.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.RegisterExtension(extension_handle)` {#SummaryDescription.RegisterExtension}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.SerializePartialToString()` {#SummaryDescription.SerializePartialToString}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.SerializeToString()` {#SummaryDescription.SerializeToString}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.SetInParent()` {#SummaryDescription.SetInParent}
|
||||
|
||||
Sets the _cached_byte_size_dirty bit to true,
|
||||
and propagates this to our listener iff this was a state change.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.WhichOneof(oneof_name)` {#SummaryDescription.WhichOneof}
|
||||
|
||||
Returns the name of the currently set field inside a oneof, or None.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.__deepcopy__(memo=None)` {#SummaryDescription.__deepcopy__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.__eq__(other)` {#SummaryDescription.__eq__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.__getstate__()` {#SummaryDescription.__getstate__}
|
||||
@ -6,3 +187,59 @@
|
||||
Support the pickle protocol.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.__hash__()` {#SummaryDescription.__hash__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.__init__(**kwargs)` {#SummaryDescription.__init__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.__ne__(other_msg)` {#SummaryDescription.__ne__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.__repr__()` {#SummaryDescription.__repr__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.__setstate__(state)` {#SummaryDescription.__setstate__}
|
||||
|
||||
Support the pickle protocol.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.__str__()` {#SummaryDescription.__str__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.__unicode__()` {#SummaryDescription.__unicode__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.type_hint` {#SummaryDescription.type_hint}
|
||||
|
||||
Magic attribute generated for "type_hint" proto field.
|
||||
|
||||
|
||||
|
@ -173,125 +173,6 @@ Checks that for all elements of farray1 and farray2
|
||||
* <b>`err`</b>: a float value.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertBetween(value, minv, maxv, msg=None)` {#TestCase.assertBetween}
|
||||
|
||||
Asserts that value is between minv and maxv (inclusive).
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertCommandFails(command, regexes, env=None, close_fds=True, msg=None)` {#TestCase.assertCommandFails}
|
||||
|
||||
Asserts a shell command fails and the error matches a regex in a list.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`command`</b>: List or string representing the command to run.
|
||||
* <b>`regexes`</b>: the list of regular expression strings.
|
||||
* <b>`env`</b>: Dictionary of environment variable settings.
|
||||
* <b>`close_fds`</b>: Whether or not to close all open fd's in the child after
|
||||
forking.
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertCommandSucceeds(command, regexes=('',), env=None, close_fds=True, msg=None)` {#TestCase.assertCommandSucceeds}
|
||||
|
||||
Asserts that a shell command succeeds (i.e. exits with code 0).
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`command`</b>: List or string representing the command to run.
|
||||
* <b>`regexes`</b>: List of regular expression byte strings that match success.
|
||||
* <b>`env`</b>: Dictionary of environment variable settings.
|
||||
* <b>`close_fds`</b>: Whether or not to close all open fd's in the child after
|
||||
forking.
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertContainsExactSubsequence(container, subsequence, msg=None)` {#TestCase.assertContainsExactSubsequence}
|
||||
|
||||
Assert that "container" contains "subsequence" as an exact subsequence.
|
||||
|
||||
Asserts that "container" contains all the elements of "subsequence", in
|
||||
order, and without other elements interspersed. For example, [1, 2, 3] is an
|
||||
exact subsequence of [0, 0, 1, 2, 3, 0] but not of [0, 0, 1, 2, 0, 3, 0].
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`container`</b>: the list we're testing for subsequence inclusion.
|
||||
* <b>`subsequence`</b>: the list we hope will be an exact subsequence of container.
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertContainsInOrder(strings, target, msg=None)` {#TestCase.assertContainsInOrder}
|
||||
|
||||
Asserts that the strings provided are found in the target in order.
|
||||
|
||||
This may be useful for checking HTML output.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`strings`</b>: A list of strings, such as [ 'fox', 'dog' ]
|
||||
* <b>`target`</b>: A target string in which to look for the strings, such as
|
||||
'The quick brown fox jumped over the lazy dog'.
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertContainsSubsequence(container, subsequence, msg=None)` {#TestCase.assertContainsSubsequence}
|
||||
|
||||
Assert that "container" contains "subsequence" as a subsequence.
|
||||
|
||||
Asserts that "container" contains all the elements of "subsequence", in
|
||||
order, but possibly with other elements interspersed. For example, [1, 2, 3]
|
||||
is a subsequence of [0, 0, 1, 2, 0, 3, 0] but not of [0, 0, 1, 3, 0, 2, 0].
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`container`</b>: the list we're testing for subsequence inclusion.
|
||||
* <b>`subsequence`</b>: the list we hope will be a subsequence of container.
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertContainsSubset(expected_subset, actual_set, msg=None)` {#TestCase.assertContainsSubset}
|
||||
|
||||
Checks whether actual iterable is a superset of expected iterable.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertCountEqual(*args, **kwargs)` {#TestCase.assertCountEqual}
|
||||
|
||||
An unordered sequence specific comparison.
|
||||
|
||||
Equivalent to assertItemsEqual(). This method is a compatibility layer
|
||||
for Python 3k, since 2to3 does not convert assertItemsEqual() calls into
|
||||
assertCountEqual() calls.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`expected_seq`</b>: A sequence containing elements we are expecting.
|
||||
* <b>`actual_seq`</b>: The sequence that we are testing.
|
||||
* <b>`msg`</b>: The message to be printed if the test fails.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertDeviceEqual(device1, device2)` {#TestCase.assertDeviceEqual}
|
||||
@ -314,48 +195,9 @@ Checks whether actual is a superset of expected.
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertDictEqual(a, b, msg=None)` {#TestCase.assertDictEqual}
|
||||
|
||||
Raises AssertionError if a and b are not equal dictionaries.
|
||||
|
||||
##### Args:
|
||||
#### `tf.test.TestCase.assertDictEqual(d1, d2, msg=None)` {#TestCase.assertDictEqual}
|
||||
|
||||
|
||||
* <b>`a`</b>: A dict, the expected value.
|
||||
* <b>`b`</b>: A dict, the actual value.
|
||||
* <b>`msg`</b>: An optional str, the associated message.
|
||||
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`AssertionError`</b>: if the dictionaries are not equal.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertEmpty(container, msg=None)` {#TestCase.assertEmpty}
|
||||
|
||||
Assert that an object has zero length.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`container`</b>: Anything that implements the collections.Sized interface.
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertEndsWith(actual, expected_end, msg=None)` {#TestCase.assertEndsWith}
|
||||
|
||||
Assert that actual.endswith(expected_end) is True.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`actual`</b>: str
|
||||
* <b>`expected_end`</b>: str
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
@ -440,11 +282,10 @@ Included for symmetry with assertIsNone.
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertItemsEqual(*args, **kwargs)` {#TestCase.assertItemsEqual}
|
||||
#### `tf.test.TestCase.assertItemsEqual(expected_seq, actual_seq, msg=None)` {#TestCase.assertItemsEqual}
|
||||
|
||||
An unordered sequence specific comparison.
|
||||
|
||||
It asserts that actual_seq and expected_seq have the same element counts.
|
||||
An unordered sequence specific comparison. It asserts that
|
||||
actual_seq and expected_seq have the same element counts.
|
||||
Equivalent to::
|
||||
|
||||
self.assertEqual(Counter(iter(actual_seq)),
|
||||
@ -457,30 +298,6 @@ Asserts that each element has the same count in both sequences.
|
||||
- [0, 1, 1] and [1, 0, 1] compare equal.
|
||||
- [0, 0, 1] and [0, 1] compare unequal.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`expected_seq`</b>: A sequence containing elements we are expecting.
|
||||
* <b>`actual_seq`</b>: The sequence that we are testing.
|
||||
* <b>`msg`</b>: The message to be printed if the test fails.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertJsonEqual(first, second, msg=None)` {#TestCase.assertJsonEqual}
|
||||
|
||||
Asserts that the JSON objects defined in two strings are equal.
|
||||
|
||||
A summary of the differences will be included in the failure message
|
||||
using assertSameStructure.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`first`</b>: A string contining JSON to decode and compare to second.
|
||||
* <b>`second`</b>: A string contining JSON to decode and compare to first.
|
||||
* <b>`msg`</b>: Additional text to include in the failure message.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
@ -550,13 +367,6 @@ if not.
|
||||
* <b>`msg`</b>: An optional string message to append to the failure message.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertNoCommonElements(expected_seq, actual_seq, msg=None)` {#TestCase.assertNoCommonElements}
|
||||
|
||||
Checks whether actual iterable and expected iterable are disjoint.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertNotAlmostEqual(first, second, places=None, msg=None, delta=None)` {#TestCase.assertNotAlmostEqual}
|
||||
@ -587,33 +397,6 @@ as significant digits (measured from the most signficant digit).
|
||||
Objects that are equal automatically fail.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertNotEmpty(container, msg=None)` {#TestCase.assertNotEmpty}
|
||||
|
||||
Assert that an object has non-zero length.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`container`</b>: Anything that implements the collections.Sized interface.
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertNotEndsWith(actual, unexpected_end, msg=None)` {#TestCase.assertNotEndsWith}
|
||||
|
||||
Assert that actual.endswith(unexpected_end) is False.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`actual`</b>: str
|
||||
* <b>`unexpected_end`</b>: str
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertNotEqual(first, second, msg=None)` {#TestCase.assertNotEqual}
|
||||
@ -651,20 +434,6 @@ Included for symmetry with assertIsInstance.
|
||||
Fail the test if the text matches the regular expression.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertNotStartsWith(actual, unexpected_start, msg=None)` {#TestCase.assertNotStartsWith}
|
||||
|
||||
Assert that actual.startswith(unexpected_start) is False.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`actual`</b>: str
|
||||
* <b>`unexpected_start`</b>: str
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertProtoEquals(expected_message_maybe_ascii, message)` {#TestCase.assertProtoEquals}
|
||||
@ -739,38 +508,6 @@ Asserts that the message in a raised exception matches a regexp.
|
||||
* <b>`kwargs`</b>: Extra kwargs.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertRaisesWithLiteralMatch(expected_exception, expected_exception_message, callable_obj=None, *args, **kwargs)` {#TestCase.assertRaisesWithLiteralMatch}
|
||||
|
||||
Asserts that the message in a raised exception equals the given string.
|
||||
|
||||
Unlike assertRaisesRegexp, this method takes a literal string, not
|
||||
a regular expression.
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(ExType, 'message'):
|
||||
DoSomething()
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`expected_exception`</b>: Exception class expected to be raised.
|
||||
* <b>`expected_exception_message`</b>: String message expected in the raised
|
||||
exception. For a raise exception e, expected_exception_message must
|
||||
equal str(e).
|
||||
* <b>`callable_obj`</b>: Function to be called, or None to return a context.
|
||||
* <b>`args`</b>: Extra args.
|
||||
* <b>`kwargs`</b>: Extra kwargs.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A context manager if callable_obj is None. Otherwise, None.
|
||||
|
||||
##### Raises:
|
||||
|
||||
self.failureException if callable_obj does not raise a macthing exception.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertRaisesWithPredicateMatch(exception_type, expected_err_re_or_predicate)` {#TestCase.assertRaisesWithPredicateMatch}
|
||||
@ -795,71 +532,6 @@ predicate search.
|
||||
exception.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertRaisesWithRegexpMatch(expected_exception, expected_regexp, callable_obj=None, *args, **kwargs)` {#TestCase.assertRaisesWithRegexpMatch}
|
||||
|
||||
Asserts that the message in a raised exception matches the given regexp.
|
||||
|
||||
This is just a wrapper around assertRaisesRegexp. Please use
|
||||
assertRaisesRegexp instead of assertRaisesWithRegexpMatch.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`expected_exception`</b>: Exception class expected to be raised.
|
||||
* <b>`expected_regexp`</b>: Regexp (re pattern object or string) expected to be
|
||||
found in error message.
|
||||
* <b>`callable_obj`</b>: Function to be called, or None to return a context.
|
||||
* <b>`args`</b>: Extra args.
|
||||
* <b>`kwargs`</b>: Extra keyword args.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A context manager if callable_obj is None. Otherwise, None.
|
||||
|
||||
##### Raises:
|
||||
|
||||
self.failureException if callable_obj does not raise a macthing exception.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertRegexMatch(actual_str, regexes, message=None)` {#TestCase.assertRegexMatch}
|
||||
|
||||
Asserts that at least one regex in regexes matches str.
|
||||
|
||||
If possible you should use assertRegexpMatches, which is a simpler
|
||||
version of this method. assertRegexpMatches takes a single regular
|
||||
expression (a string or re compiled object) instead of a list.
|
||||
|
||||
Notes:
|
||||
1. This function uses substring matching, i.e. the matching
|
||||
succeeds if *any* substring of the error message matches *any*
|
||||
regex in the list. This is more convenient for the user than
|
||||
full-string matching.
|
||||
|
||||
2. If regexes is the empty list, the matching will always fail.
|
||||
|
||||
3. Use regexes=[''] for a regex that will always pass.
|
||||
|
||||
4. '.' matches any single character *except* the newline. To
|
||||
match any character, use '(.|
|
||||
)'.
|
||||
|
||||
5. '^' matches the beginning of each line, not just the beginning
|
||||
of the string. Similarly, '$' matches the end of each line.
|
||||
|
||||
6. An exception will be thrown if regexes contains an invalid
|
||||
regex.
|
||||
|
||||
Args:
|
||||
actual_str: The string we try to match with the items in regexes.
|
||||
regexes: The regular expressions we want to match against str.
|
||||
See "Notes" above for detailed notes on how this is interpreted.
|
||||
message: The message to be printed if the test fails.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertRegexpMatches(text, expected_regexp, msg=None)` {#TestCase.assertRegexpMatches}
|
||||
@ -867,79 +539,6 @@ Asserts that at least one regex in regexes matches str.
|
||||
Fail the test unless the text matches the regular expression.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertSameElements(expected_seq, actual_seq, msg=None)` {#TestCase.assertSameElements}
|
||||
|
||||
Assert that two sequences have the same elements (in any order).
|
||||
|
||||
This method, unlike assertItemsEqual, doesn't care about any
|
||||
duplicates in the expected and actual sequences.
|
||||
|
||||
>> assertSameElements([1, 1, 1, 0, 0, 0], [0, 1])
|
||||
# Doesn't raise an AssertionError
|
||||
|
||||
If possible, you should use assertItemsEqual instead of
|
||||
assertSameElements.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`expected_seq`</b>: A sequence containing elements we are expecting.
|
||||
* <b>`actual_seq`</b>: The sequence that we are testing.
|
||||
* <b>`msg`</b>: The message to be printed if the test fails.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertSameStructure(a, b, aname='a', bname='b', msg=None)` {#TestCase.assertSameStructure}
|
||||
|
||||
Asserts that two values contain the same structural content.
|
||||
|
||||
The two arguments should be data trees consisting of trees of dicts and
|
||||
lists. They will be deeply compared by walking into the contents of dicts
|
||||
and lists; other items will be compared using the == operator.
|
||||
If the two structures differ in content, the failure message will indicate
|
||||
the location within the structures where the first difference is found.
|
||||
This may be helpful when comparing large structures.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`a`</b>: The first structure to compare.
|
||||
* <b>`b`</b>: The second structure to compare.
|
||||
* <b>`aname`</b>: Variable name to use for the first structure in assertion messages.
|
||||
* <b>`bname`</b>: Variable name to use for the second structure.
|
||||
* <b>`msg`</b>: Additional text to include in the failure message.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertSequenceAlmostEqual(expected_seq, actual_seq, places=None, msg=None, delta=None)` {#TestCase.assertSequenceAlmostEqual}
|
||||
|
||||
An approximate equality assertion for ordered sequences.
|
||||
|
||||
Fail if the two sequences are unequal as determined by their value
|
||||
differences rounded to the given number of decimal places (default 7) and
|
||||
comparing to zero, or by comparing that the difference between each value
|
||||
in the two sequences is more than the given delta.
|
||||
|
||||
Note that decimal places (from zero) are usually not the same as significant
|
||||
digits (measured from the most signficant digit).
|
||||
|
||||
If the two sequences compare equal then they will automatically compare
|
||||
almost equal.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`expected_seq`</b>: A sequence containing elements we are expecting.
|
||||
* <b>`actual_seq`</b>: The sequence that we are testing.
|
||||
* <b>`places`</b>: The number of decimal places to compare.
|
||||
* <b>`msg`</b>: The message to be printed if the test fails.
|
||||
* <b>`delta`</b>: The OK difference between compared values.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertSequenceEqual(seq1, seq2, msg=None, seq_type=None)` {#TestCase.assertSequenceEqual}
|
||||
@ -960,26 +559,6 @@ which can be indexed, has a length, and has an equality operator.
|
||||
differences.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertSequenceStartsWith(prefix, whole, msg=None)` {#TestCase.assertSequenceStartsWith}
|
||||
|
||||
An equality assertion for the beginning of ordered sequences.
|
||||
|
||||
If prefix is an empty sequence, it will raise an error unless whole is also
|
||||
an empty sequence.
|
||||
|
||||
If prefix is not a sequence, it will raise an error if the first element of
|
||||
whole does not match.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`prefix`</b>: A sequence expected at the beginning of the whole parameter.
|
||||
* <b>`whole`</b>: The sequence in which to look for prefix.
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertSetEqual(set1, set2, msg=None)` {#TestCase.assertSetEqual}
|
||||
@ -1031,51 +610,6 @@ Assert that actual.startswith(expected_start) is True.
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertTotallyOrdered(*groups, **kwargs)` {#TestCase.assertTotallyOrdered}
|
||||
|
||||
Asserts that total ordering has been implemented correctly.
|
||||
|
||||
For example, say you have a class A that compares only on its attribute x.
|
||||
Comparators other than __lt__ are omitted for brevity.
|
||||
|
||||
class A(object):
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.x)
|
||||
|
||||
def __lt__(self, other):
|
||||
try:
|
||||
return self.x < other.x
|
||||
except AttributeError:
|
||||
return NotImplemented
|
||||
|
||||
assertTotallyOrdered will check that instances can be ordered correctly.
|
||||
For example,
|
||||
|
||||
self.assertTotallyOrdered(
|
||||
[None], # None should come before everything else.
|
||||
[1], # Integers sort earlier.
|
||||
[A(1, 'a')],
|
||||
[A(2, 'b')], # 2 is after 1.
|
||||
[A(3, 'c'), A(3, 'd')], # The second argument is irrelevant.
|
||||
[A(4, 'z')],
|
||||
['foo']) # Strings sort last.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`*groups`</b>: A list of groups of elements. Each group of elements is a list
|
||||
of objects that are equal. The elements in each group must be less than
|
||||
the elements in the group after it. For example, these groups are
|
||||
totally ordered: [None], [1], [2, 2], [3].
|
||||
* <b>`**kwargs`</b>: optional msg keyword argument can be passed.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertTrue(expr, msg=None)` {#TestCase.assertTrue}
|
||||
@ -1098,13 +632,6 @@ A tuple-specific equality assertion.
|
||||
differences.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertUrlEqual(a, b, msg=None)` {#TestCase.assertUrlEqual}
|
||||
|
||||
Asserts that urls are equal, ignoring ordering of query params.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assert_(expr, msg=None)` {#TestCase.assert_}
|
||||
@ -1166,9 +693,9 @@ tearDown.
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.fail(msg=None, prefix=None)` {#TestCase.fail}
|
||||
#### `tf.test.TestCase.fail(msg=None)` {#TestCase.fail}
|
||||
|
||||
Fail immediately with the given message, optionally prefixed.
|
||||
Fail immediately, with the given message.
|
||||
|
||||
|
||||
- - -
|
||||
@ -1220,13 +747,6 @@ Fail immediately with the given message, optionally prefixed.
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.getRecordedProperties()` {#TestCase.getRecordedProperties}
|
||||
|
||||
Return any properties that the user has recorded.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.get_temp_dir()` {#TestCase.get_temp_dir}
|
||||
@ -1241,20 +761,6 @@ Return any properties that the user has recorded.
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.recordProperty(property_name, property_value)` {#TestCase.recordProperty}
|
||||
|
||||
Record an arbitrary property for later use.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`property_name`</b>: str, name of property to record; must be a valid XML
|
||||
attribute name
|
||||
* <b>`property_value`</b>: value of property; must be valid XML attribute value
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.run(result=None)` {#TestCase.run}
|
||||
@ -1280,18 +786,11 @@ Hook method for setting up class fixture before running tests in the class.
|
||||
|
||||
#### `tf.test.TestCase.shortDescription()` {#TestCase.shortDescription}
|
||||
|
||||
Format both the test method name and the first line of its docstring.
|
||||
Returns a one-line description of the test, or None if no
|
||||
description has been provided.
|
||||
|
||||
If no docstring is given, only returns the method name.
|
||||
|
||||
This method overrides unittest.TestCase.shortDescription(), which
|
||||
only returns the first line of the docstring, obscuring the name
|
||||
of the test upon failure.
|
||||
|
||||
##### Returns:
|
||||
|
||||
|
||||
* <b>`desc`</b>: A short description of a test method.
|
||||
The default implementation of this method returns the first line of
|
||||
the specified test method's docstring.
|
||||
|
||||
|
||||
- - -
|
||||
|
@ -0,0 +1,22 @@
|
||||
### `tf.scalar_summary(*args, **kwargs)` {#scalar_summary}
|
||||
|
||||
Outputs a `Summary` protocol buffer with scalar values. (deprecated)
|
||||
|
||||
THIS FUNCTION IS DEPRECATED. It will be removed after 2016-11-30.
|
||||
Instructions for updating:
|
||||
Please switch to tf.summary.scalar. Note that tf.summary.scalar uses the node name instead of the tag. This means that TensorFlow will automatically de-duplicate summary names based on the scope they are created in. Also, passing a tensor or list of tags to a scalar summary op is no longer supported.
|
||||
|
||||
The input `tags` and `values` must have the same shape. The generated
|
||||
summary has a summary value for each tag-value pair in `tags` and `values`.
|
||||
|
||||
Args:
|
||||
tags: A `string` `Tensor`. Tags for the summaries.
|
||||
values: A real numeric Tensor. Values for the summaries.
|
||||
collections: Optional list of graph collections keys. The new summary op is
|
||||
added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A scalar `Tensor` of type `string`. The serialized `Summary` protocol
|
||||
buffer.
|
||||
|
@ -0,0 +1,4 @@
|
||||
#### `tf.summary.SummaryDescription.RegisterExtension(extension_handle)` {#SummaryDescription.RegisterExtension}
|
||||
|
||||
|
||||
|
@ -0,0 +1,26 @@
|
||||
### `tf.histogram_summary(*args, **kwargs)` {#histogram_summary}
|
||||
|
||||
Outputs a `Summary` protocol buffer with a histogram. (deprecated)
|
||||
|
||||
THIS FUNCTION IS DEPRECATED. It will be removed after 2016-11-30.
|
||||
Instructions for updating:
|
||||
Please switch to tf.summary.histogram. Note that tf.summary.histogram uses the node name instead of the tag. This means that TensorFlow will automatically de-duplicate summary names based on their scope.
|
||||
|
||||
The generated
|
||||
[`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
|
||||
has one summary value containing a histogram for `values`.
|
||||
|
||||
This op reports an `InvalidArgument` error if any value is not finite.
|
||||
|
||||
Args:
|
||||
tag: A `string` `Tensor`. 0-D. Tag to use for the summary value.
|
||||
values: A real numeric `Tensor`. Any shape. Values to use to
|
||||
build the histogram.
|
||||
collections: Optional list of graph collections keys. The new summary op is
|
||||
added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A scalar `Tensor` of type `string`. The serialized `Summary` protocol
|
||||
buffer.
|
||||
|
@ -0,0 +1,27 @@
|
||||
### `tf.merge_summary(*args, **kwargs)` {#merge_summary}
|
||||
|
||||
Merges summaries. (deprecated)
|
||||
|
||||
THIS FUNCTION IS DEPRECATED. It will be removed after 2016-11-30.
|
||||
Instructions for updating:
|
||||
Please switch to tf.summary.merge.
|
||||
|
||||
This op creates a
|
||||
[`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
|
||||
protocol buffer that contains the union of all the values in the input
|
||||
summaries.
|
||||
|
||||
When the Op is run, it reports an `InvalidArgument` error if multiple values
|
||||
in the summaries to merge use the same tag.
|
||||
|
||||
Args:
|
||||
inputs: A list of `string` `Tensor` objects containing serialized `Summary`
|
||||
protocol buffers.
|
||||
collections: Optional list of graph collections keys. The new summary op is
|
||||
added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A scalar `Tensor` of type `string`. The serialized `Summary` protocol
|
||||
buffer resulting from the merging.
|
||||
|
@ -0,0 +1,4 @@
|
||||
#### `tf.summary.SummaryDescription.FromString(s)` {#SummaryDescription.FromString}
|
||||
|
||||
|
||||
|
@ -0,0 +1,4 @@
|
||||
#### `tf.summary.TaggedRunMetadata.RegisterExtension(extension_handle)` {#TaggedRunMetadata.RegisterExtension}
|
||||
|
||||
|
||||
|
@ -0,0 +1,207 @@
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.train.SummaryWriter.__init__(*args, **kwargs)` {#SummaryWriter.__init__}
|
||||
|
||||
Creates a `SummaryWriter` and an event file. (deprecated)
|
||||
|
||||
THIS FUNCTION IS DEPRECATED. It will be removed after 2016-11-30.
|
||||
Instructions for updating:
|
||||
Please switch to tf.summary.FileWriter. The interface and behavior is the same; this is just a rename.
|
||||
|
||||
This class is deprecated, and should be replaced with tf.summary.FileWriter.
|
||||
|
||||
On construction the summary writer creates a new event file in `logdir`.
|
||||
This event file will contain `Event` protocol buffers constructed when you
|
||||
call one of the following functions: `add_summary()`, `add_session_log()`,
|
||||
`add_event()`, or `add_graph()`.
|
||||
|
||||
If you pass a `Graph` to the constructor it is added to
|
||||
the event file. (This is equivalent to calling `add_graph()` later).
|
||||
|
||||
TensorBoard will pick the graph from the file and display it graphically so
|
||||
you can interactively explore the graph you built. You will usually pass
|
||||
the graph from the session in which you launched it:
|
||||
|
||||
```python
|
||||
...create a graph...
|
||||
# Launch the graph in a session.
|
||||
sess = tf.Session()
|
||||
# Create a summary writer, add the 'graph' to the event file.
|
||||
writer = tf.train.SummaryWriter(<some-directory>, sess.graph)
|
||||
```
|
||||
|
||||
The other arguments to the constructor control the asynchronous writes to
|
||||
the event file:
|
||||
|
||||
* `flush_secs`: How often, in seconds, to flush the added summaries
|
||||
and events to disk.
|
||||
* `max_queue`: Maximum number of summaries or events pending to be
|
||||
written to disk before one of the 'add' calls block.
|
||||
|
||||
Args:
|
||||
logdir: A string. Directory where event file will be written.
|
||||
graph: A `Graph` object, such as `sess.graph`.
|
||||
max_queue: Integer. Size of the queue for pending events and summaries.
|
||||
flush_secs: Number. How often, in seconds, to flush the
|
||||
pending events and summaries to disk.
|
||||
graph_def: DEPRECATED: Use the `graph` argument instead.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.train.SummaryWriter.add_event(event)` {#SummaryWriter.add_event}
|
||||
|
||||
Adds an event to the event file.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`event`</b>: An `Event` protocol buffer.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.train.SummaryWriter.add_graph(graph, global_step=None, graph_def=None)` {#SummaryWriter.add_graph}
|
||||
|
||||
Adds a `Graph` to the event file.
|
||||
|
||||
The graph described by the protocol buffer will be displayed by
|
||||
TensorBoard. Most users pass a graph in the constructor instead.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`graph`</b>: A `Graph` object, such as `sess.graph`.
|
||||
* <b>`global_step`</b>: Number. Optional global step counter to record with the
|
||||
graph.
|
||||
* <b>`graph_def`</b>: DEPRECATED. Use the `graph` parameter instead.
|
||||
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`ValueError`</b>: If both graph and graph_def are passed to the method.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.train.SummaryWriter.add_meta_graph(meta_graph_def, global_step=None)` {#SummaryWriter.add_meta_graph}
|
||||
|
||||
Adds a `MetaGraphDef` to the event file.
|
||||
|
||||
The `MetaGraphDef` allows running the given graph via
|
||||
`saver.import_meta_graph()`.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`meta_graph_def`</b>: A `MetaGraphDef` object, often as retured by
|
||||
`saver.export_meta_graph()`.
|
||||
* <b>`global_step`</b>: Number. Optional global step counter to record with the
|
||||
graph.
|
||||
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`TypeError`</b>: If both `meta_graph_def` is not an instance of `MetaGraphDef`.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.train.SummaryWriter.add_run_metadata(run_metadata, tag, global_step=None)` {#SummaryWriter.add_run_metadata}
|
||||
|
||||
Adds a metadata information for a single session.run() call.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`run_metadata`</b>: A `RunMetadata` protobuf object.
|
||||
* <b>`tag`</b>: The tag name for this metadata.
|
||||
* <b>`global_step`</b>: Number. Optional global step counter to record with the
|
||||
StepStats.
|
||||
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`ValueError`</b>: If the provided tag was already used for this type of event.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.train.SummaryWriter.add_session_log(session_log, global_step=None)` {#SummaryWriter.add_session_log}
|
||||
|
||||
Adds a `SessionLog` protocol buffer to the event file.
|
||||
|
||||
This method wraps the provided session in an `Event` protocol buffer
|
||||
and adds it to the event file.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`session_log`</b>: A `SessionLog` protocol buffer.
|
||||
* <b>`global_step`</b>: Number. Optional global step value to record with the
|
||||
summary.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.train.SummaryWriter.add_summary(summary, global_step=None)` {#SummaryWriter.add_summary}
|
||||
|
||||
Adds a `Summary` protocol buffer to the event file.
|
||||
|
||||
This method wraps the provided summary in an `Event` protocol buffer
|
||||
and adds it to the event file.
|
||||
|
||||
You can pass the result of evaluating any summary op, using
|
||||
[`Session.run()`](client.md#Session.run) or
|
||||
[`Tensor.eval()`](framework.md#Tensor.eval), to this
|
||||
function. Alternatively, you can pass a `tf.Summary` protocol
|
||||
buffer that you populate with your own data. The latter is
|
||||
commonly done to report evaluation results in event files.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`summary`</b>: A `Summary` protocol buffer, optionally serialized as a string.
|
||||
* <b>`global_step`</b>: Number. Optional global step value to record with the
|
||||
summary.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.train.SummaryWriter.close()` {#SummaryWriter.close}
|
||||
|
||||
Flushes the event file to disk and close the file.
|
||||
|
||||
Call this method when you do not need the summary writer anymore.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.train.SummaryWriter.flush()` {#SummaryWriter.flush}
|
||||
|
||||
Flushes the event file to disk.
|
||||
|
||||
Call this method to make sure that all pending events have been written to
|
||||
disk.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.train.SummaryWriter.get_logdir()` {#SummaryWriter.get_logdir}
|
||||
|
||||
Returns the directory where event file will be written.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.train.SummaryWriter.reopen()` {#SummaryWriter.reopen}
|
||||
|
||||
Reopens the EventFileWriter.
|
||||
|
||||
Can be called after `close()` to add more events in the same directory.
|
||||
The events will go into a new events file.
|
||||
|
||||
Does nothing if the EventFileWriter was not closed.
|
||||
|
||||
|
@ -16,7 +16,7 @@ in the summaries to merge use the same tag.
|
||||
* <b>`inputs`</b>: A list of `string` `Tensor` objects containing serialized `Summary`
|
||||
protocol buffers.
|
||||
* <b>`collections`</b>: Optional list of graph collections keys. The new summary op is
|
||||
added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
|
||||
added to these collections. Defaults to `[]`.
|
||||
* <b>`name`</b>: A name for the operation (optional).
|
||||
|
||||
##### Returns:
|
||||
|
@ -0,0 +1,37 @@
|
||||
### `tf.audio_summary(*args, **kwargs)` {#audio_summary}
|
||||
|
||||
Outputs a `Summary` protocol buffer with audio. (deprecated)
|
||||
|
||||
THIS FUNCTION IS DEPRECATED. It will be removed after 2016-11-30.
|
||||
Instructions for updating:
|
||||
Please switch to tf.summary.audio. Note that tf.summary.histogram uses the node name instead of the tag. This means that TensorFlow will automatically de-duplicate summary names based on the scope they are created in.
|
||||
|
||||
The summary has up to `max_outputs` summary values containing audio. The
|
||||
audio is built from `tensor` which must be 3-D with shape `[batch_size,
|
||||
frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are
|
||||
assumed to be in the range of `[-1.0, 1.0]` with a sample rate of
|
||||
`sample_rate`.
|
||||
|
||||
The `tag` argument is a scalar `Tensor` of type `string`. It is used to
|
||||
build the `tag` of the summary values:
|
||||
|
||||
* If `max_outputs` is 1, the summary value tag is '*tag*/audio'.
|
||||
* If `max_outputs` is greater than 1, the summary value tags are
|
||||
generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc.
|
||||
|
||||
Args:
|
||||
tag: A scalar `Tensor` of type `string`. Used to build the `tag`
|
||||
of the summary values.
|
||||
tensor: A 3-D `float32` `Tensor` of shape `[batch_size, frames, channels]`
|
||||
or a 2-D `float32` `Tensor` of shape `[batch_size, frames]`.
|
||||
sample_rate: A Scalar `float32` `Tensor` indicating the sample rate of the
|
||||
signal in hertz.
|
||||
max_outputs: Max number of batch elements to generate audio for.
|
||||
collections: Optional list of ops.GraphKeys. The collections to add the
|
||||
summary to. Defaults to [ops.GraphKeys.SUMMARIES]
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A scalar `Tensor` of type `string`. The serialized `Summary` protocol
|
||||
buffer.
|
||||
|
@ -0,0 +1,4 @@
|
||||
#### `tf.summary.TaggedRunMetadata.FromString(s)` {#TaggedRunMetadata.FromString}
|
||||
|
||||
|
||||
|
@ -3444,7 +3444,7 @@ device assignments have not changed.
|
||||
|
||||
See `tf.global_variables`. (deprecated)
|
||||
|
||||
THIS FUNCTION IS DEPRECATED. It will be removed after 2016-03-02.
|
||||
THIS FUNCTION IS DEPRECATED. It will be removed after 2017-03-02.
|
||||
Instructions for updating:
|
||||
Please use tf.global_variables instead.
|
||||
|
||||
|
@ -424,7 +424,7 @@ in the summaries to merge use the same tag.
|
||||
* <b>`inputs`</b>: A list of `string` `Tensor` objects containing serialized `Summary`
|
||||
protocol buffers.
|
||||
* <b>`collections`</b>: Optional list of graph collections keys. The new summary op is
|
||||
added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
|
||||
added to these collections. Defaults to `[]`.
|
||||
* <b>`name`</b>: A name for the operation (optional).
|
||||
|
||||
##### Returns:
|
||||
@ -485,6 +485,187 @@ metadata is stored in its NodeDef. This method retrieves the description.
|
||||
### `class tf.summary.SummaryDescription` {#SummaryDescription}
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.ByteSize()` {#SummaryDescription.ByteSize}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.Clear()` {#SummaryDescription.Clear}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.ClearExtension(extension_handle)` {#SummaryDescription.ClearExtension}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.ClearField(field_name)` {#SummaryDescription.ClearField}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.CopyFrom(other_msg)` {#SummaryDescription.CopyFrom}
|
||||
|
||||
Copies the content of the specified message into the current message.
|
||||
|
||||
The method clears the current message and then merges the specified
|
||||
message using MergeFrom.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`other_msg`</b>: Message to copy into the current one.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.DiscardUnknownFields()` {#SummaryDescription.DiscardUnknownFields}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.FindInitializationErrors()` {#SummaryDescription.FindInitializationErrors}
|
||||
|
||||
Finds required fields which are not initialized.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A list of strings. Each string is a path to an uninitialized field from
|
||||
the top-level message, e.g. "foo.bar[5].baz".
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.FromString(s)` {#SummaryDescription.FromString}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.HasExtension(extension_handle)` {#SummaryDescription.HasExtension}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.HasField(field_name)` {#SummaryDescription.HasField}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.IsInitialized(errors=None)` {#SummaryDescription.IsInitialized}
|
||||
|
||||
Checks if all required fields of a message are set.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`errors`</b>: A list which, if provided, will be populated with the field
|
||||
paths of all missing required fields.
|
||||
|
||||
##### Returns:
|
||||
|
||||
True iff the specified message has all required fields set.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.ListFields()` {#SummaryDescription.ListFields}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.MergeFrom(msg)` {#SummaryDescription.MergeFrom}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.MergeFromString(serialized)` {#SummaryDescription.MergeFromString}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.ParseFromString(serialized)` {#SummaryDescription.ParseFromString}
|
||||
|
||||
Parse serialized protocol buffer data into this message.
|
||||
|
||||
Like MergeFromString(), except we clear the object first and
|
||||
do not return the value that MergeFromString returns.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.RegisterExtension(extension_handle)` {#SummaryDescription.RegisterExtension}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.SerializePartialToString()` {#SummaryDescription.SerializePartialToString}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.SerializeToString()` {#SummaryDescription.SerializeToString}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.SetInParent()` {#SummaryDescription.SetInParent}
|
||||
|
||||
Sets the _cached_byte_size_dirty bit to true,
|
||||
and propagates this to our listener iff this was a state change.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.WhichOneof(oneof_name)` {#SummaryDescription.WhichOneof}
|
||||
|
||||
Returns the name of the currently set field inside a oneof, or None.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.__deepcopy__(memo=None)` {#SummaryDescription.__deepcopy__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.__eq__(other)` {#SummaryDescription.__eq__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.__getstate__()` {#SummaryDescription.__getstate__}
|
||||
@ -492,12 +673,249 @@ metadata is stored in its NodeDef. This method retrieves the description.
|
||||
Support the pickle protocol.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.__hash__()` {#SummaryDescription.__hash__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.__init__(**kwargs)` {#SummaryDescription.__init__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.__ne__(other_msg)` {#SummaryDescription.__ne__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.__repr__()` {#SummaryDescription.__repr__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.__setstate__(state)` {#SummaryDescription.__setstate__}
|
||||
|
||||
Support the pickle protocol.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.__str__()` {#SummaryDescription.__str__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.__unicode__()` {#SummaryDescription.__unicode__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.SummaryDescription.type_hint` {#SummaryDescription.type_hint}
|
||||
|
||||
Magic attribute generated for "type_hint" proto field.
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `class tf.summary.TaggedRunMetadata` {#TaggedRunMetadata}
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.ByteSize()` {#TaggedRunMetadata.ByteSize}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.Clear()` {#TaggedRunMetadata.Clear}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.ClearExtension(extension_handle)` {#TaggedRunMetadata.ClearExtension}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.ClearField(field_name)` {#TaggedRunMetadata.ClearField}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.CopyFrom(other_msg)` {#TaggedRunMetadata.CopyFrom}
|
||||
|
||||
Copies the content of the specified message into the current message.
|
||||
|
||||
The method clears the current message and then merges the specified
|
||||
message using MergeFrom.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`other_msg`</b>: Message to copy into the current one.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.DiscardUnknownFields()` {#TaggedRunMetadata.DiscardUnknownFields}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.FindInitializationErrors()` {#TaggedRunMetadata.FindInitializationErrors}
|
||||
|
||||
Finds required fields which are not initialized.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A list of strings. Each string is a path to an uninitialized field from
|
||||
the top-level message, e.g. "foo.bar[5].baz".
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.FromString(s)` {#TaggedRunMetadata.FromString}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.HasExtension(extension_handle)` {#TaggedRunMetadata.HasExtension}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.HasField(field_name)` {#TaggedRunMetadata.HasField}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.IsInitialized(errors=None)` {#TaggedRunMetadata.IsInitialized}
|
||||
|
||||
Checks if all required fields of a message are set.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`errors`</b>: A list which, if provided, will be populated with the field
|
||||
paths of all missing required fields.
|
||||
|
||||
##### Returns:
|
||||
|
||||
True iff the specified message has all required fields set.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.ListFields()` {#TaggedRunMetadata.ListFields}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.MergeFrom(msg)` {#TaggedRunMetadata.MergeFrom}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.MergeFromString(serialized)` {#TaggedRunMetadata.MergeFromString}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.ParseFromString(serialized)` {#TaggedRunMetadata.ParseFromString}
|
||||
|
||||
Parse serialized protocol buffer data into this message.
|
||||
|
||||
Like MergeFromString(), except we clear the object first and
|
||||
do not return the value that MergeFromString returns.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.RegisterExtension(extension_handle)` {#TaggedRunMetadata.RegisterExtension}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.SerializePartialToString()` {#TaggedRunMetadata.SerializePartialToString}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.SerializeToString()` {#TaggedRunMetadata.SerializeToString}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.SetInParent()` {#TaggedRunMetadata.SetInParent}
|
||||
|
||||
Sets the _cached_byte_size_dirty bit to true,
|
||||
and propagates this to our listener iff this was a state change.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.WhichOneof(oneof_name)` {#TaggedRunMetadata.WhichOneof}
|
||||
|
||||
Returns the name of the currently set field inside a oneof, or None.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.__deepcopy__(memo=None)` {#TaggedRunMetadata.__deepcopy__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.__eq__(other)` {#TaggedRunMetadata.__eq__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.__getstate__()` {#TaggedRunMetadata.__getstate__}
|
||||
@ -505,4 +923,67 @@ Support the pickle protocol.
|
||||
Support the pickle protocol.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.__hash__()` {#TaggedRunMetadata.__hash__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.__init__(**kwargs)` {#TaggedRunMetadata.__init__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.__ne__(other_msg)` {#TaggedRunMetadata.__ne__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.__repr__()` {#TaggedRunMetadata.__repr__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.__setstate__(state)` {#TaggedRunMetadata.__setstate__}
|
||||
|
||||
Support the pickle protocol.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.__str__()` {#TaggedRunMetadata.__str__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.__unicode__()` {#TaggedRunMetadata.__unicode__}
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.run_metadata` {#TaggedRunMetadata.run_metadata}
|
||||
|
||||
Magic attribute generated for "run_metadata" proto field.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.summary.TaggedRunMetadata.tag` {#TaggedRunMetadata.tag}
|
||||
|
||||
Magic attribute generated for "tag" proto field.
|
||||
|
||||
|
||||
|
||||
|
@ -213,125 +213,6 @@ Checks that for all elements of farray1 and farray2
|
||||
* <b>`err`</b>: a float value.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertBetween(value, minv, maxv, msg=None)` {#TestCase.assertBetween}
|
||||
|
||||
Asserts that value is between minv and maxv (inclusive).
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertCommandFails(command, regexes, env=None, close_fds=True, msg=None)` {#TestCase.assertCommandFails}
|
||||
|
||||
Asserts a shell command fails and the error matches a regex in a list.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`command`</b>: List or string representing the command to run.
|
||||
* <b>`regexes`</b>: the list of regular expression strings.
|
||||
* <b>`env`</b>: Dictionary of environment variable settings.
|
||||
* <b>`close_fds`</b>: Whether or not to close all open fd's in the child after
|
||||
forking.
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertCommandSucceeds(command, regexes=('',), env=None, close_fds=True, msg=None)` {#TestCase.assertCommandSucceeds}
|
||||
|
||||
Asserts that a shell command succeeds (i.e. exits with code 0).
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`command`</b>: List or string representing the command to run.
|
||||
* <b>`regexes`</b>: List of regular expression byte strings that match success.
|
||||
* <b>`env`</b>: Dictionary of environment variable settings.
|
||||
* <b>`close_fds`</b>: Whether or not to close all open fd's in the child after
|
||||
forking.
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertContainsExactSubsequence(container, subsequence, msg=None)` {#TestCase.assertContainsExactSubsequence}
|
||||
|
||||
Assert that "container" contains "subsequence" as an exact subsequence.
|
||||
|
||||
Asserts that "container" contains all the elements of "subsequence", in
|
||||
order, and without other elements interspersed. For example, [1, 2, 3] is an
|
||||
exact subsequence of [0, 0, 1, 2, 3, 0] but not of [0, 0, 1, 2, 0, 3, 0].
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`container`</b>: the list we're testing for subsequence inclusion.
|
||||
* <b>`subsequence`</b>: the list we hope will be an exact subsequence of container.
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertContainsInOrder(strings, target, msg=None)` {#TestCase.assertContainsInOrder}
|
||||
|
||||
Asserts that the strings provided are found in the target in order.
|
||||
|
||||
This may be useful for checking HTML output.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`strings`</b>: A list of strings, such as [ 'fox', 'dog' ]
|
||||
* <b>`target`</b>: A target string in which to look for the strings, such as
|
||||
'The quick brown fox jumped over the lazy dog'.
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertContainsSubsequence(container, subsequence, msg=None)` {#TestCase.assertContainsSubsequence}
|
||||
|
||||
Assert that "container" contains "subsequence" as a subsequence.
|
||||
|
||||
Asserts that "container" contains all the elements of "subsequence", in
|
||||
order, but possibly with other elements interspersed. For example, [1, 2, 3]
|
||||
is a subsequence of [0, 0, 1, 2, 0, 3, 0] but not of [0, 0, 1, 3, 0, 2, 0].
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`container`</b>: the list we're testing for subsequence inclusion.
|
||||
* <b>`subsequence`</b>: the list we hope will be a subsequence of container.
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertContainsSubset(expected_subset, actual_set, msg=None)` {#TestCase.assertContainsSubset}
|
||||
|
||||
Checks whether actual iterable is a superset of expected iterable.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertCountEqual(*args, **kwargs)` {#TestCase.assertCountEqual}
|
||||
|
||||
An unordered sequence specific comparison.
|
||||
|
||||
Equivalent to assertItemsEqual(). This method is a compatibility layer
|
||||
for Python 3k, since 2to3 does not convert assertItemsEqual() calls into
|
||||
assertCountEqual() calls.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`expected_seq`</b>: A sequence containing elements we are expecting.
|
||||
* <b>`actual_seq`</b>: The sequence that we are testing.
|
||||
* <b>`msg`</b>: The message to be printed if the test fails.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertDeviceEqual(device1, device2)` {#TestCase.assertDeviceEqual}
|
||||
@ -354,48 +235,9 @@ Checks whether actual is a superset of expected.
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertDictEqual(a, b, msg=None)` {#TestCase.assertDictEqual}
|
||||
|
||||
Raises AssertionError if a and b are not equal dictionaries.
|
||||
|
||||
##### Args:
|
||||
#### `tf.test.TestCase.assertDictEqual(d1, d2, msg=None)` {#TestCase.assertDictEqual}
|
||||
|
||||
|
||||
* <b>`a`</b>: A dict, the expected value.
|
||||
* <b>`b`</b>: A dict, the actual value.
|
||||
* <b>`msg`</b>: An optional str, the associated message.
|
||||
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`AssertionError`</b>: if the dictionaries are not equal.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertEmpty(container, msg=None)` {#TestCase.assertEmpty}
|
||||
|
||||
Assert that an object has zero length.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`container`</b>: Anything that implements the collections.Sized interface.
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertEndsWith(actual, expected_end, msg=None)` {#TestCase.assertEndsWith}
|
||||
|
||||
Assert that actual.endswith(expected_end) is True.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`actual`</b>: str
|
||||
* <b>`expected_end`</b>: str
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
@ -480,11 +322,10 @@ Included for symmetry with assertIsNone.
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertItemsEqual(*args, **kwargs)` {#TestCase.assertItemsEqual}
|
||||
#### `tf.test.TestCase.assertItemsEqual(expected_seq, actual_seq, msg=None)` {#TestCase.assertItemsEqual}
|
||||
|
||||
An unordered sequence specific comparison.
|
||||
|
||||
It asserts that actual_seq and expected_seq have the same element counts.
|
||||
An unordered sequence specific comparison. It asserts that
|
||||
actual_seq and expected_seq have the same element counts.
|
||||
Equivalent to::
|
||||
|
||||
self.assertEqual(Counter(iter(actual_seq)),
|
||||
@ -497,30 +338,6 @@ Asserts that each element has the same count in both sequences.
|
||||
- [0, 1, 1] and [1, 0, 1] compare equal.
|
||||
- [0, 0, 1] and [0, 1] compare unequal.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`expected_seq`</b>: A sequence containing elements we are expecting.
|
||||
* <b>`actual_seq`</b>: The sequence that we are testing.
|
||||
* <b>`msg`</b>: The message to be printed if the test fails.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertJsonEqual(first, second, msg=None)` {#TestCase.assertJsonEqual}
|
||||
|
||||
Asserts that the JSON objects defined in two strings are equal.
|
||||
|
||||
A summary of the differences will be included in the failure message
|
||||
using assertSameStructure.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`first`</b>: A string contining JSON to decode and compare to second.
|
||||
* <b>`second`</b>: A string contining JSON to decode and compare to first.
|
||||
* <b>`msg`</b>: Additional text to include in the failure message.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
@ -590,13 +407,6 @@ if not.
|
||||
* <b>`msg`</b>: An optional string message to append to the failure message.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertNoCommonElements(expected_seq, actual_seq, msg=None)` {#TestCase.assertNoCommonElements}
|
||||
|
||||
Checks whether actual iterable and expected iterable are disjoint.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertNotAlmostEqual(first, second, places=None, msg=None, delta=None)` {#TestCase.assertNotAlmostEqual}
|
||||
@ -627,33 +437,6 @@ as significant digits (measured from the most signficant digit).
|
||||
Objects that are equal automatically fail.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertNotEmpty(container, msg=None)` {#TestCase.assertNotEmpty}
|
||||
|
||||
Assert that an object has non-zero length.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`container`</b>: Anything that implements the collections.Sized interface.
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertNotEndsWith(actual, unexpected_end, msg=None)` {#TestCase.assertNotEndsWith}
|
||||
|
||||
Assert that actual.endswith(unexpected_end) is False.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`actual`</b>: str
|
||||
* <b>`unexpected_end`</b>: str
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertNotEqual(first, second, msg=None)` {#TestCase.assertNotEqual}
|
||||
@ -691,20 +474,6 @@ Included for symmetry with assertIsInstance.
|
||||
Fail the test if the text matches the regular expression.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertNotStartsWith(actual, unexpected_start, msg=None)` {#TestCase.assertNotStartsWith}
|
||||
|
||||
Assert that actual.startswith(unexpected_start) is False.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`actual`</b>: str
|
||||
* <b>`unexpected_start`</b>: str
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertProtoEquals(expected_message_maybe_ascii, message)` {#TestCase.assertProtoEquals}
|
||||
@ -779,38 +548,6 @@ Asserts that the message in a raised exception matches a regexp.
|
||||
* <b>`kwargs`</b>: Extra kwargs.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertRaisesWithLiteralMatch(expected_exception, expected_exception_message, callable_obj=None, *args, **kwargs)` {#TestCase.assertRaisesWithLiteralMatch}
|
||||
|
||||
Asserts that the message in a raised exception equals the given string.
|
||||
|
||||
Unlike assertRaisesRegexp, this method takes a literal string, not
|
||||
a regular expression.
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(ExType, 'message'):
|
||||
DoSomething()
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`expected_exception`</b>: Exception class expected to be raised.
|
||||
* <b>`expected_exception_message`</b>: String message expected in the raised
|
||||
exception. For a raise exception e, expected_exception_message must
|
||||
equal str(e).
|
||||
* <b>`callable_obj`</b>: Function to be called, or None to return a context.
|
||||
* <b>`args`</b>: Extra args.
|
||||
* <b>`kwargs`</b>: Extra kwargs.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A context manager if callable_obj is None. Otherwise, None.
|
||||
|
||||
##### Raises:
|
||||
|
||||
self.failureException if callable_obj does not raise a macthing exception.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertRaisesWithPredicateMatch(exception_type, expected_err_re_or_predicate)` {#TestCase.assertRaisesWithPredicateMatch}
|
||||
@ -835,71 +572,6 @@ predicate search.
|
||||
exception.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertRaisesWithRegexpMatch(expected_exception, expected_regexp, callable_obj=None, *args, **kwargs)` {#TestCase.assertRaisesWithRegexpMatch}
|
||||
|
||||
Asserts that the message in a raised exception matches the given regexp.
|
||||
|
||||
This is just a wrapper around assertRaisesRegexp. Please use
|
||||
assertRaisesRegexp instead of assertRaisesWithRegexpMatch.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`expected_exception`</b>: Exception class expected to be raised.
|
||||
* <b>`expected_regexp`</b>: Regexp (re pattern object or string) expected to be
|
||||
found in error message.
|
||||
* <b>`callable_obj`</b>: Function to be called, or None to return a context.
|
||||
* <b>`args`</b>: Extra args.
|
||||
* <b>`kwargs`</b>: Extra keyword args.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A context manager if callable_obj is None. Otherwise, None.
|
||||
|
||||
##### Raises:
|
||||
|
||||
self.failureException if callable_obj does not raise a macthing exception.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertRegexMatch(actual_str, regexes, message=None)` {#TestCase.assertRegexMatch}
|
||||
|
||||
Asserts that at least one regex in regexes matches str.
|
||||
|
||||
If possible you should use assertRegexpMatches, which is a simpler
|
||||
version of this method. assertRegexpMatches takes a single regular
|
||||
expression (a string or re compiled object) instead of a list.
|
||||
|
||||
Notes:
|
||||
1. This function uses substring matching, i.e. the matching
|
||||
succeeds if *any* substring of the error message matches *any*
|
||||
regex in the list. This is more convenient for the user than
|
||||
full-string matching.
|
||||
|
||||
2. If regexes is the empty list, the matching will always fail.
|
||||
|
||||
3. Use regexes=[''] for a regex that will always pass.
|
||||
|
||||
4. '.' matches any single character *except* the newline. To
|
||||
match any character, use '(.|
|
||||
)'.
|
||||
|
||||
5. '^' matches the beginning of each line, not just the beginning
|
||||
of the string. Similarly, '$' matches the end of each line.
|
||||
|
||||
6. An exception will be thrown if regexes contains an invalid
|
||||
regex.
|
||||
|
||||
Args:
|
||||
actual_str: The string we try to match with the items in regexes.
|
||||
regexes: The regular expressions we want to match against str.
|
||||
See "Notes" above for detailed notes on how this is interpreted.
|
||||
message: The message to be printed if the test fails.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertRegexpMatches(text, expected_regexp, msg=None)` {#TestCase.assertRegexpMatches}
|
||||
@ -907,79 +579,6 @@ Asserts that at least one regex in regexes matches str.
|
||||
Fail the test unless the text matches the regular expression.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertSameElements(expected_seq, actual_seq, msg=None)` {#TestCase.assertSameElements}
|
||||
|
||||
Assert that two sequences have the same elements (in any order).
|
||||
|
||||
This method, unlike assertItemsEqual, doesn't care about any
|
||||
duplicates in the expected and actual sequences.
|
||||
|
||||
>> assertSameElements([1, 1, 1, 0, 0, 0], [0, 1])
|
||||
# Doesn't raise an AssertionError
|
||||
|
||||
If possible, you should use assertItemsEqual instead of
|
||||
assertSameElements.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`expected_seq`</b>: A sequence containing elements we are expecting.
|
||||
* <b>`actual_seq`</b>: The sequence that we are testing.
|
||||
* <b>`msg`</b>: The message to be printed if the test fails.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertSameStructure(a, b, aname='a', bname='b', msg=None)` {#TestCase.assertSameStructure}
|
||||
|
||||
Asserts that two values contain the same structural content.
|
||||
|
||||
The two arguments should be data trees consisting of trees of dicts and
|
||||
lists. They will be deeply compared by walking into the contents of dicts
|
||||
and lists; other items will be compared using the == operator.
|
||||
If the two structures differ in content, the failure message will indicate
|
||||
the location within the structures where the first difference is found.
|
||||
This may be helpful when comparing large structures.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`a`</b>: The first structure to compare.
|
||||
* <b>`b`</b>: The second structure to compare.
|
||||
* <b>`aname`</b>: Variable name to use for the first structure in assertion messages.
|
||||
* <b>`bname`</b>: Variable name to use for the second structure.
|
||||
* <b>`msg`</b>: Additional text to include in the failure message.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertSequenceAlmostEqual(expected_seq, actual_seq, places=None, msg=None, delta=None)` {#TestCase.assertSequenceAlmostEqual}
|
||||
|
||||
An approximate equality assertion for ordered sequences.
|
||||
|
||||
Fail if the two sequences are unequal as determined by their value
|
||||
differences rounded to the given number of decimal places (default 7) and
|
||||
comparing to zero, or by comparing that the difference between each value
|
||||
in the two sequences is more than the given delta.
|
||||
|
||||
Note that decimal places (from zero) are usually not the same as significant
|
||||
digits (measured from the most signficant digit).
|
||||
|
||||
If the two sequences compare equal then they will automatically compare
|
||||
almost equal.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`expected_seq`</b>: A sequence containing elements we are expecting.
|
||||
* <b>`actual_seq`</b>: The sequence that we are testing.
|
||||
* <b>`places`</b>: The number of decimal places to compare.
|
||||
* <b>`msg`</b>: The message to be printed if the test fails.
|
||||
* <b>`delta`</b>: The OK difference between compared values.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertSequenceEqual(seq1, seq2, msg=None, seq_type=None)` {#TestCase.assertSequenceEqual}
|
||||
@ -1000,26 +599,6 @@ which can be indexed, has a length, and has an equality operator.
|
||||
differences.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertSequenceStartsWith(prefix, whole, msg=None)` {#TestCase.assertSequenceStartsWith}
|
||||
|
||||
An equality assertion for the beginning of ordered sequences.
|
||||
|
||||
If prefix is an empty sequence, it will raise an error unless whole is also
|
||||
an empty sequence.
|
||||
|
||||
If prefix is not a sequence, it will raise an error if the first element of
|
||||
whole does not match.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`prefix`</b>: A sequence expected at the beginning of the whole parameter.
|
||||
* <b>`whole`</b>: The sequence in which to look for prefix.
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertSetEqual(set1, set2, msg=None)` {#TestCase.assertSetEqual}
|
||||
@ -1071,51 +650,6 @@ Assert that actual.startswith(expected_start) is True.
|
||||
* <b>`msg`</b>: Optional message to report on failure.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertTotallyOrdered(*groups, **kwargs)` {#TestCase.assertTotallyOrdered}
|
||||
|
||||
Asserts that total ordering has been implemented correctly.
|
||||
|
||||
For example, say you have a class A that compares only on its attribute x.
|
||||
Comparators other than __lt__ are omitted for brevity.
|
||||
|
||||
class A(object):
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.x)
|
||||
|
||||
def __lt__(self, other):
|
||||
try:
|
||||
return self.x < other.x
|
||||
except AttributeError:
|
||||
return NotImplemented
|
||||
|
||||
assertTotallyOrdered will check that instances can be ordered correctly.
|
||||
For example,
|
||||
|
||||
self.assertTotallyOrdered(
|
||||
[None], # None should come before everything else.
|
||||
[1], # Integers sort earlier.
|
||||
[A(1, 'a')],
|
||||
[A(2, 'b')], # 2 is after 1.
|
||||
[A(3, 'c'), A(3, 'd')], # The second argument is irrelevant.
|
||||
[A(4, 'z')],
|
||||
['foo']) # Strings sort last.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`*groups`</b>: A list of groups of elements. Each group of elements is a list
|
||||
of objects that are equal. The elements in each group must be less than
|
||||
the elements in the group after it. For example, these groups are
|
||||
totally ordered: [None], [1], [2, 2], [3].
|
||||
* <b>`**kwargs`</b>: optional msg keyword argument can be passed.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertTrue(expr, msg=None)` {#TestCase.assertTrue}
|
||||
@ -1138,13 +672,6 @@ A tuple-specific equality assertion.
|
||||
differences.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assertUrlEqual(a, b, msg=None)` {#TestCase.assertUrlEqual}
|
||||
|
||||
Asserts that urls are equal, ignoring ordering of query params.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.assert_(expr, msg=None)` {#TestCase.assert_}
|
||||
@ -1206,9 +733,9 @@ tearDown.
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.fail(msg=None, prefix=None)` {#TestCase.fail}
|
||||
#### `tf.test.TestCase.fail(msg=None)` {#TestCase.fail}
|
||||
|
||||
Fail immediately with the given message, optionally prefixed.
|
||||
Fail immediately, with the given message.
|
||||
|
||||
|
||||
- - -
|
||||
@ -1260,13 +787,6 @@ Fail immediately with the given message, optionally prefixed.
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.getRecordedProperties()` {#TestCase.getRecordedProperties}
|
||||
|
||||
Return any properties that the user has recorded.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.get_temp_dir()` {#TestCase.get_temp_dir}
|
||||
@ -1281,20 +801,6 @@ Return any properties that the user has recorded.
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.recordProperty(property_name, property_value)` {#TestCase.recordProperty}
|
||||
|
||||
Record an arbitrary property for later use.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`property_name`</b>: str, name of property to record; must be a valid XML
|
||||
attribute name
|
||||
* <b>`property_value`</b>: value of property; must be valid XML attribute value
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.test.TestCase.run(result=None)` {#TestCase.run}
|
||||
@ -1320,18 +826,11 @@ Hook method for setting up class fixture before running tests in the class.
|
||||
|
||||
#### `tf.test.TestCase.shortDescription()` {#TestCase.shortDescription}
|
||||
|
||||
Format both the test method name and the first line of its docstring.
|
||||
Returns a one-line description of the test, or None if no
|
||||
description has been provided.
|
||||
|
||||
If no docstring is given, only returns the method name.
|
||||
|
||||
This method overrides unittest.TestCase.shortDescription(), which
|
||||
only returns the first line of the docstring, obscuring the name
|
||||
of the test upon failure.
|
||||
|
||||
##### Returns:
|
||||
|
||||
|
||||
* <b>`desc`</b>: A short description of a test method.
|
||||
The default implementation of this method returns the first line of
|
||||
the specified test method's docstring.
|
||||
|
||||
|
||||
- - -
|
||||
|
@ -82,37 +82,37 @@ of the binary on Linux or Mac, you can follow these instructions:
|
||||
|
||||
```bash
|
||||
# Ubuntu/Linux 64-bit, CPU only, Python 2.7
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.0rc0-cp27-none-linux_x86_64.whl
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.0rc1-cp27-none-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
|
||||
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.0rc0-cp27-none-linux_x86_64.whl
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.0rc1-cp27-none-linux_x86_64.whl
|
||||
|
||||
# Mac OS X, CPU only, Python 2.7:
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.12.0rc0-py2-none-any.whl
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.12.0rc1-py2-none-any.whl
|
||||
|
||||
# Mac OS X, GPU enabled, Python 2.7:
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-0.12.0rc0-py2-none-any.whl
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-0.12.0rc1-py2-none-any.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, CPU only, Python 3.4
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.0rc0-cp34-cp34m-linux_x86_64.whl
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.0rc1-cp34-cp34m-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
|
||||
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.0rc0-cp34-cp34m-linux_x86_64.whl
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.0rc1-cp34-cp34m-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, CPU only, Python 3.5
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.0rc0-cp35-cp35m-linux_x86_64.whl
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.0rc1-cp35-cp35m-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
|
||||
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.0rc0-cp35-cp35m-linux_x86_64.whl
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.0rc1-cp35-cp35m-linux_x86_64.whl
|
||||
|
||||
# Mac OS X, CPU only, Python 3.4 or 3.5:
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.12.0rc0-py3-none-any.whl
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.12.0rc1-py3-none-any.whl
|
||||
|
||||
# Mac OS X, GPU enabled, Python 3.4 or 3.5:
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-0.12.0rc0-py3-none-any.whl
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-0.12.0rc1-py3-none-any.whl
|
||||
```
|
||||
|
||||
Install TensorFlow:
|
||||
@ -136,24 +136,32 @@ You can now [test your installation](#test-the-tensorflow-installation).
|
||||
|
||||
### Pip installation on Windows
|
||||
|
||||
TensorFlow supports only 64-bit Python 3.5 on Windows. We have tested
|
||||
the pip packages with the following distributions of Python:
|
||||
TensorFlow supports only 64-bit Python 3.5 on Windows. We have tested the pip packages
|
||||
with the following distributions of Python:
|
||||
|
||||
* [Python 3.5 from python.org](https://www.python.org/downloads/release/python-352/)
|
||||
* [Python 3.5 from Anaconda](https://www.continuum.io/downloads#windows)
|
||||
|
||||
* [Python 3.5 from python.org](https://www.python.org/downloads/release/python-352/).
|
||||
|
||||
NOTE: TensorFlow requires `MSVCP140.DLL`, which may not be installed on your system.
|
||||
If, when you `import tensorflow as tf`, you see an error about `No module named
|
||||
"_pywrap_tensorflow"` and/or `DLL load failed`, check whether `MSVCP140.DLL` is in
|
||||
your `%PATH%` and, if not, you should install the [Visual C++ 2015
|
||||
redistributable](https://www.microsoft.com/en-us/download/details.aspx?id=53587)
|
||||
(x64 version).
|
||||
|
||||
Both distributions include pip. To install the CPU-only version of
|
||||
TensorFlow, enter the following command at a command prompt:
|
||||
|
||||
```bat
|
||||
C:\> pip install --upgrade https://storage.googleapis.com/tensorflow/windows/cpu/tensorflow-0.12.0rc0-cp35-cp35m-win_amd64.whl
|
||||
C:\> pip install --upgrade https://storage.googleapis.com/tensorflow/windows/cpu/tensorflow-0.12.0rc1-cp35-cp35m-win_amd64.whl
|
||||
```
|
||||
|
||||
To install the GPU version of TensorFlow, enter the following command
|
||||
at a command prompt:
|
||||
|
||||
```bat
|
||||
C:\> pip install --upgrade https://storage.googleapis.com/tensorflow/windows/gpu/tensorflow_gpu-0.12.0rc0-cp35-cp35m-win_amd64.whl
|
||||
C:\> pip install --upgrade https://storage.googleapis.com/tensorflow/windows/gpu/tensorflow_gpu-0.12.0rc1-cp35-cp35m-win_amd64.whl
|
||||
```
|
||||
|
||||
You can now [test your installation](#test-the-tensorflow-installation).
|
||||
@ -208,37 +216,37 @@ Now, install TensorFlow just as you would for a regular Pip installation. First
|
||||
|
||||
```bash
|
||||
# Ubuntu/Linux 64-bit, CPU only, Python 2.7
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.0rc0-cp27-none-linux_x86_64.whl
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.0rc1-cp27-none-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
|
||||
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.0rc0-cp27-none-linux_x86_64.whl
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.0rc1-cp27-none-linux_x86_64.whl
|
||||
|
||||
# Mac OS X, CPU only, Python 2.7:
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.12.0rc0-py2-none-any.whl
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.12.0rc1-py2-none-any.whl
|
||||
|
||||
# Mac OS X, GPU enabled, Python 2.7:
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-0.12.0rc0-py2-none-any.whl
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-0.12.0rc1-py2-none-any.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, CPU only, Python 3.4
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.0rc0-cp34-cp34m-linux_x86_64.whl
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.0rc1-cp34-cp34m-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
|
||||
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.0rc0-cp34-cp34m-linux_x86_64.whl
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.0rc1-cp34-cp34m-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, CPU only, Python 3.5
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.0rc0-cp35-cp35m-linux_x86_64.whl
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.0rc1-cp35-cp35m-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
|
||||
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.0rc0-cp35-cp35m-linux_x86_64.whl
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.0rc1-cp35-cp35m-linux_x86_64.whl
|
||||
|
||||
# Mac OS X, CPU only, Python 3.4 or 3.5:
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.12.0rc0-py3-none-any.whl
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.12.0rc1-py3-none-any.whl
|
||||
|
||||
# Mac OS X, GPU enabled, Python 3.4 or 3.5:
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-0.12.0rc0-py3-none-any.whl
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-0.12.0rc1-py3-none-any.whl
|
||||
```
|
||||
|
||||
Finally install TensorFlow:
|
||||
@ -360,37 +368,37 @@ select the correct binary to install:
|
||||
|
||||
```bash
|
||||
# Ubuntu/Linux 64-bit, CPU only, Python 2.7
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.0rc0-cp27-none-linux_x86_64.whl
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.0rc1-cp27-none-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
|
||||
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.0rc0-cp27-none-linux_x86_64.whl
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.0rc1-cp27-none-linux_x86_64.whl
|
||||
|
||||
# Mac OS X, CPU only, Python 2.7:
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.12.0rc0-py2-none-any.whl
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.12.0rc1-py2-none-any.whl
|
||||
|
||||
# Mac OS X, GPU enabled, Python 2.7:
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-0.12.0rc0-py2-none-any.whl
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-0.12.0rc1-py2-none-any.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, CPU only, Python 3.4
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.0rc0-cp34-cp34m-linux_x86_64.whl
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.0rc1-cp34-cp34m-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
|
||||
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.0rc0-cp34-cp34m-linux_x86_64.whl
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.0rc1-cp34-cp34m-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, CPU only, Python 3.5
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.0rc0-cp35-cp35m-linux_x86_64.whl
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.0rc1-cp35-cp35m-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
|
||||
# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.0rc0-cp35-cp35m-linux_x86_64.whl
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.0rc1-cp35-cp35m-linux_x86_64.whl
|
||||
|
||||
# Mac OS X, CPU only, Python 3.4 or 3.5:
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.12.0rc0-py3-none-any.whl
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.12.0rc1-py3-none-any.whl
|
||||
|
||||
# Mac OS X, GPU enabled, Python 3.4 or 3.5:
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-0.12.0rc0-py3-none-any.whl
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-0.12.0rc1-py3-none-any.whl
|
||||
```
|
||||
|
||||
Finally install TensorFlow:
|
||||
@ -458,7 +466,7 @@ code.
|
||||
code.
|
||||
|
||||
We also have tags with `latest` replaced by a released version (e.g.,
|
||||
`0.12.0-rc0-gpu`).
|
||||
`0.12.0-rc1-gpu`).
|
||||
|
||||
With Docker the installation is as follows:
|
||||
|
||||
@ -860,7 +868,7 @@ $ bazel build -c opt --config=cuda //tensorflow/tools/pip_package:build_pip_pack
|
||||
$ bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
|
||||
|
||||
# The name of the .whl file will depend on your platform.
|
||||
$ sudo pip install /tmp/tensorflow_pkg/tensorflow-0.12.0rc0-py2-none-any.whl
|
||||
$ sudo pip install /tmp/tensorflow_pkg/tensorflow-0.12.0rc1-py2-none-any.whl
|
||||
```
|
||||
|
||||
## Optimizing CPU performance
|
||||
@ -1083,6 +1091,27 @@ Exception:
|
||||
|
||||
Solution: Add an `--ignore-installed` flag to the pip command.
|
||||
|
||||
#### Cannot remove entries from nonexistent file: easy-install.pth
|
||||
|
||||
If during a `pip` installation using an Anaconda Python distribution you encounter the error:
|
||||
|
||||
```
|
||||
Cannot remove entries from nonexistent file <path-to-anaconda-instalation>/anaconda[version]/lib/site-packages/easy-install.pth
|
||||
```
|
||||
|
||||
1. Upgrade setuptools:
|
||||
`pip install --upgrade -I setuptools`
|
||||
|
||||
2. Install TensorFlow again adding `--ignore-installed` flag:
|
||||
`pip install --ignore-installed --upgrade <tensorflow_url>`
|
||||
|
||||
|
||||
|
||||
Step #1 might already solve the problem, however if it still persists, execute step #2.
|
||||
|
||||
This issue occurs with new Anaconda installations when `pip` tries to remove `easy-install.pth`.
|
||||
This file is not included in Anaconda packages, which causes the `pip` installation to fail.
|
||||
|
||||
|
||||
### Linux issues
|
||||
|
||||
|
@ -34,11 +34,7 @@ definitions. If you see a standalone TensorFlow file representing a model, it's
|
||||
likely to contain a serialized version of one of these `GraphDef` objects
|
||||
saved out by the protobuf code.
|
||||
|
||||
This generated code is used to save and load the GraphDef files from disk. A
|
||||
good example to look at as we dig into this is
|
||||
[graph_metrics.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/graph_metrics.py). This Python script takes a saved graph
|
||||
definition, and analyzes the model to estimate performance and resource
|
||||
statistics. The code that actually loads the model looks like this:
|
||||
This generated code is used to save and load the GraphDef files from disk. The code that actually loads the model looks like this:
|
||||
|
||||
```python
|
||||
graph_def = graph_pb2.GraphDef()
|
||||
@ -67,7 +63,7 @@ There are actually two different formats that a ProtoBuf can be saved in.
|
||||
TextFormat is a human-readable form, which makes it nice for debugging and
|
||||
editing, but can get large when there's numerical data like weights stored in
|
||||
it. You can see a small example of that in
|
||||
[graph_run_run2.pbtxt](https://github.com/tensorflow/tensorflow/blob/ae3c8479f88da1cd5636b974f653f27755cb0034/tensorflow/tensorboard/components/tf_tensorboard/test/data/graph_run_run2.pbtxt).
|
||||
[graph_run_run2.pbtxt](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/components/tf_tensorboard/test/data/graph_run_run2.pbtxt).
|
||||
|
||||
Binary format files are a lot smaller than their text equivalents, even though
|
||||
they're not as readable for us. In this script, we ask the user to supply a
|
||||
|
@ -69,3 +69,19 @@ Yuan Yu, and Xiaoqiang Zheng.
|
||||
TensorFlow: Large-scale machine learning on heterogeneous systems,
|
||||
2015. Software available from tensorflow.org.
|
||||
```
|
||||
|
||||
If you use [TF.Learn](https://www.tensorflow.org/tutorials/tflearn/) in your research and would like to cite it, we suggest you cite the [whitepaper](https://arxiv.org/abs/1612.04251):
|
||||
|
||||
```
|
||||
@article{tang2016tflearn,
|
||||
title={TF.Learn: TensorFlow's High-level Module for Distributed Machine Learning},
|
||||
author={Tang, Yuan},
|
||||
journal={arXiv preprint arXiv:1612.04251},
|
||||
year={2016}
|
||||
}
|
||||
```
|
||||
|
||||
In textual form:
|
||||
```
|
||||
Tang, Yuan. "TF.Learn: TensorFlow's High-level Module for Distributed Machine Learning." arXiv preprint arXiv:1612.04251 (2016).
|
||||
```
|
@ -1,12 +1,17 @@
|
||||
# Additional Resources
|
||||
|
||||
## TensorFlow WhitePaper
|
||||
## TensorFlow WhitePapers
|
||||
|
||||
Additional details about the TensorFlow programming model and the underlying
|
||||
implementation can be found in our white paper:
|
||||
|
||||
* [TensorFlow: Large-scale machine learning on heterogeneous systems](http://download.tensorflow.org/paper/whitepaper2015.pdf)
|
||||
|
||||
|
||||
A white paper is also available for [TF.Learn](https://www.tensorflow.org/tutorials/tflearn/):
|
||||
|
||||
* [TF.Learn: TensorFlow's High-level Module for Distributed Machine Learning](https://arxiv.org/abs/1612.04251)
|
||||
|
||||
### Citation
|
||||
|
||||
If you use TensorFlow in your research and would like to cite the TensorFlow
|
||||
|
@ -152,6 +152,8 @@ def maybe_download():
|
||||
print("Training data is downloaded to %s" % train_file_name)
|
||||
|
||||
if FLAGS.test_data:
|
||||
test_file_name = FLAGS.test_data
|
||||
else:
|
||||
test_file = tempfile.NamedTemporaryFile(delete=False)
|
||||
urllib.urlretrieve("http://download.tensorflow.org/data/abalone_test.csv", test_file.name)
|
||||
test_file_name = test_file.name
|
||||
@ -379,7 +381,7 @@ tf.contrib.layers provides the following convenience functions for constructing
|
||||
fully connected layers:
|
||||
|
||||
* `relu(inputs, num_outputs)`. Create a layer of `num_outputs` nodes fully
|
||||
connected to the previous layer `inputs` with a [ReLu activation
|
||||
connected to the previous layer `inputs` with a [ReLU activation
|
||||
function](https://en.wikipedia.org/wiki/Rectifier_\(neural_networks\))
|
||||
([tf.nn.relu](../../api_docs/python/nn.md#relu)):
|
||||
|
||||
@ -388,7 +390,7 @@ fully connected layers:
|
||||
```
|
||||
|
||||
* `relu6(inputs, num_outputs)`. Create a layer of `num_outputs` nodes fully
|
||||
connected to the previous layer `hidden_layer` with a ReLu 6 activation
|
||||
connected to the previous layer `hidden_layer` with a ReLU 6 activation
|
||||
function ([tf.nn.relu6](../../api_docs/python/nn.md#relu6)):
|
||||
|
||||
```python
|
||||
@ -448,7 +450,7 @@ def model_fn(features, targets, mode, params):
|
||||
Here, because you'll be passing the abalone `Datasets` directly to `fit()`,
|
||||
`evaluate()`, and `predict()` via `x` and `y` arguments, the input layer is the
|
||||
`features` `Tensor` passed to the `model_fn`. The network contains two hidden
|
||||
layers, each with 10 nodes and a ReLu activation function. The output layer
|
||||
layers, each with 10 nodes and a ReLU activation function. The output layer
|
||||
contains no activation function, and is
|
||||
[reshaped](../../api_docs/python/array_ops.md#reshape) to a one-dimensional
|
||||
tensor to capture the model's predictions, which are stored in
|
||||
|
@ -98,7 +98,7 @@ The `inference()` function builds the graph as far as needed to
|
||||
return the tensor that would contain the output predictions.
|
||||
|
||||
It takes the images placeholder as input and builds on top
|
||||
of it a pair of fully connected layers with [ReLu](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) activation followed by a ten
|
||||
of it a pair of fully connected layers with [ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) activation followed by a ten
|
||||
node linear layer specifying the output logits.
|
||||
|
||||
Each layer is created beneath a unique [`tf.name_scope`](../../../api_docs/python/framework.md#name_scope)
|
||||
|
@ -13,6 +13,7 @@ exports_files(["LICENSE"])
|
||||
load("//tensorflow:tensorflow.bzl", "if_not_windows")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cuda_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "py_tests")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
|
||||
|
@ -37,7 +37,6 @@ static std::vector<string> ListDevices(TF_Status* out_status) {
|
||||
|
||||
std::vector<std::unique_ptr<Device>> device_holder(devices.begin(), devices.end());
|
||||
|
||||
|
||||
for (const Device* device : devices) {
|
||||
const DeviceAttributes& attr = device->attributes();
|
||||
string attr_serialized;
|
||||
|
@ -14,6 +14,7 @@ licenses(["notice"]) # Apache 2.0
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
|
||||
py_library(
|
||||
name = "debug_py",
|
||||
|
@ -59,6 +59,7 @@ class DType(object):
|
||||
@@name
|
||||
@@base_dtype
|
||||
@@real_dtype
|
||||
@@is_bool
|
||||
@@is_floating
|
||||
@@is_complex
|
||||
@@is_integer
|
||||
@ -141,6 +142,11 @@ class DType(object):
|
||||
"""Returns a `types_pb2.DataType` enum value based on this `DType`."""
|
||||
return self._type_enum
|
||||
|
||||
@property
|
||||
def is_bool(self):
|
||||
"""Returns whether this is a boolean data type"""
|
||||
return self.base_dtype == bool
|
||||
|
||||
@property
|
||||
def is_integer(self):
|
||||
"""Returns whether this is a (non-quantized) integer type."""
|
||||
|
@ -54,7 +54,7 @@ class ArgMaxTest(tf.test.TestCase):
|
||||
x = np.asarray(100*np.random.randn(3, 2, 4, 5, 6), dtype=dtype)
|
||||
|
||||
# Check that argmin and argmax match numpy along all dimensions
|
||||
for dim in range(5):
|
||||
for dim in range(-5, 5):
|
||||
self._testBothArg(tf.argmax, x, dim, x.argmax(dim))
|
||||
self._testBothArg(tf.argmin, x, dim, x.argmin(dim))
|
||||
|
||||
|
@ -100,8 +100,8 @@ class AtrousConvolutionTest(tf.test.TestCase):
|
||||
dilation_rate=[rate])
|
||||
|
||||
def testAtrousConvolutionNC(self):
|
||||
if tf.test.is_gpu_available():
|
||||
# "NCW" and "NCHW" formats are not currently supported on CPU.
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
# "NCW" and "NCHW" formats are currently supported only on CUDA.
|
||||
with self.test_session(use_gpu=True):
|
||||
for padding in ["SAME", "VALID"]:
|
||||
self._test_atrous_convolution(
|
||||
|
@ -29,8 +29,8 @@ def GetTestConfigs():
|
||||
all the valid test configs as tuples of data_format and use_gpu.
|
||||
"""
|
||||
test_configs = [("NHWC", False), ("NHWC", True)]
|
||||
if tf.test.is_gpu_available():
|
||||
# "NCHW" format is not currently supported on CPU.
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
# "NCHW" format is currently only supported on CUDA.
|
||||
test_configs += [("NCHW", True)]
|
||||
return test_configs
|
||||
|
||||
@ -89,7 +89,7 @@ class BiasAddTest(tf.test.TestCase):
|
||||
self._testBias(np_inputs, np_bias, use_gpu=False)
|
||||
if np_inputs.dtype in [np.float16, np.float32, np.float64]:
|
||||
self._testBias(np_inputs, np_bias, use_gpu=True)
|
||||
if tf.test.is_gpu_available():
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
self._testBiasNCHW(np_inputs, np_bias, use_gpu=True)
|
||||
|
||||
def testInputDims(self):
|
||||
|
@ -159,8 +159,8 @@ class Conv2DTransposeTest(tf.test.TestCase):
|
||||
self.assertLess(err, err_tolerance)
|
||||
|
||||
def testConv2DTransposeSingleStrideNCHW(self):
|
||||
# `NCHW` data fomat is only supported for `GPU` device.
|
||||
if tf.test.is_gpu_available():
|
||||
# `NCHW` data fomat is only supported for CUDA device.
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
with self.test_session(use_gpu=True):
|
||||
strides = [1, 1, 1, 1]
|
||||
|
||||
@ -192,8 +192,8 @@ class Conv2DTransposeTest(tf.test.TestCase):
|
||||
self.assertAllClose(target, value[n, k, h, w])
|
||||
|
||||
def testConv2DTransposeSameNCHW(self):
|
||||
# `NCHW` data fomat is only supported for `GPU` device.
|
||||
if tf.test.is_gpu_available():
|
||||
# `NCHW` data fomat is only supported for CUDA device.
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
with self.test_session(use_gpu=True):
|
||||
strides = [1, 1, 2, 2]
|
||||
|
||||
@ -226,8 +226,8 @@ class Conv2DTransposeTest(tf.test.TestCase):
|
||||
self.assertAllClose(target, value[n, k, h, w])
|
||||
|
||||
def testConv2DTransposeValidNCHW(self):
|
||||
# `NCHW` data fomat is only supported for `GPU` device.
|
||||
if tf.test.is_gpu_available():
|
||||
# `NCHW` data fomat is only supported for CUDA device.
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
with self.test_session(use_gpu=True):
|
||||
strides = [1, 1, 2, 2]
|
||||
|
||||
|
@ -232,8 +232,8 @@ class PoolingTest(tf.test.TestCase):
|
||||
strides=strides)
|
||||
|
||||
def testPoolNC(self):
|
||||
if tf.test.is_gpu_available():
|
||||
# "NC*" format is not currently supported on CPU.
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
# "NC*" format is currently only supported on CUDA.
|
||||
with self.test_session(use_gpu=True):
|
||||
for padding in ["SAME", "VALID"]:
|
||||
self._test(input_shape=[2, 2, 9],
|
||||
|
@ -1,770 +0,0 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for functional style sequence-to-sequence models."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class Seq2SeqTest(tf.test.TestCase):
|
||||
|
||||
def testRNNDecoder(self):
|
||||
with self.test_session() as sess:
|
||||
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
|
||||
inp = [tf.constant(0.5, shape=[2, 2])] * 2
|
||||
_, enc_state = tf.nn.rnn(
|
||||
tf.nn.rnn_cell.GRUCell(2), inp, dtype=tf.float32)
|
||||
dec_inp = [tf.constant(0.4, shape=[2, 2])] * 3
|
||||
cell = tf.nn.rnn_cell.OutputProjectionWrapper(
|
||||
tf.nn.rnn_cell.GRUCell(2), 4)
|
||||
dec, mem = tf.nn.seq2seq.rnn_decoder(dec_inp, enc_state, cell)
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
res = sess.run(dec)
|
||||
self.assertEqual(3, len(res))
|
||||
self.assertEqual((2, 4), res[0].shape)
|
||||
|
||||
res = sess.run([mem])
|
||||
self.assertEqual((2, 2), res[0].shape)
|
||||
|
||||
def testBasicRNNSeq2Seq(self):
|
||||
with self.test_session() as sess:
|
||||
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
|
||||
inp = [tf.constant(0.5, shape=[2, 2])] * 2
|
||||
dec_inp = [tf.constant(0.4, shape=[2, 2])] * 3
|
||||
cell = tf.nn.rnn_cell.OutputProjectionWrapper(
|
||||
tf.nn.rnn_cell.GRUCell(2), 4)
|
||||
dec, mem = tf.nn.seq2seq.basic_rnn_seq2seq(inp, dec_inp, cell)
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
res = sess.run(dec)
|
||||
self.assertEqual(3, len(res))
|
||||
self.assertEqual((2, 4), res[0].shape)
|
||||
|
||||
res = sess.run([mem])
|
||||
self.assertEqual((2, 2), res[0].shape)
|
||||
|
||||
def testTiedRNNSeq2Seq(self):
|
||||
with self.test_session() as sess:
|
||||
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
|
||||
inp = [tf.constant(0.5, shape=[2, 2])] * 2
|
||||
dec_inp = [tf.constant(0.4, shape=[2, 2])] * 3
|
||||
cell = tf.nn.rnn_cell.OutputProjectionWrapper(
|
||||
tf.nn.rnn_cell.GRUCell(2), 4)
|
||||
dec, mem = tf.nn.seq2seq.tied_rnn_seq2seq(inp, dec_inp, cell)
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
res = sess.run(dec)
|
||||
self.assertEqual(3, len(res))
|
||||
self.assertEqual((2, 4), res[0].shape)
|
||||
|
||||
res = sess.run([mem])
|
||||
self.assertEqual(1, len(res))
|
||||
self.assertEqual((2, 2), res[0].shape)
|
||||
|
||||
def testEmbeddingRNNDecoder(self):
|
||||
with self.test_session() as sess:
|
||||
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
|
||||
inp = [tf.constant(0.5, shape=[2, 2])] * 2
|
||||
cell = tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
|
||||
_, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32)
|
||||
dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in range(3)]
|
||||
dec, mem = tf.nn.seq2seq.embedding_rnn_decoder(
|
||||
dec_inp, enc_state, cell, num_symbols=4, embedding_size=2)
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
res = sess.run(dec)
|
||||
self.assertEqual(3, len(res))
|
||||
self.assertEqual((2, 2), res[0].shape)
|
||||
|
||||
res = sess.run([mem])
|
||||
self.assertEqual(1, len(res))
|
||||
self.assertEqual((2, 2), res[0].c.shape)
|
||||
self.assertEqual((2, 2), res[0].h.shape)
|
||||
|
||||
def testEmbeddingRNNSeq2Seq(self):
|
||||
with self.test_session() as sess:
|
||||
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
|
||||
enc_inp = [tf.constant(1, tf.int32, shape=[2]) for i in range(2)]
|
||||
dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in range(3)]
|
||||
cell = tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
|
||||
dec, mem = tf.nn.seq2seq.embedding_rnn_seq2seq(
|
||||
enc_inp, dec_inp, cell, num_encoder_symbols=2,
|
||||
num_decoder_symbols=5, embedding_size=2)
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
res = sess.run(dec)
|
||||
self.assertEqual(3, len(res))
|
||||
self.assertEqual((2, 5), res[0].shape)
|
||||
|
||||
res = sess.run([mem])
|
||||
self.assertEqual((2, 2), res[0].c.shape)
|
||||
self.assertEqual((2, 2), res[0].h.shape)
|
||||
|
||||
# Test with state_is_tuple=False.
|
||||
with tf.variable_scope("no_tuple"):
|
||||
cell1 = tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
|
||||
dec, mem = tf.nn.seq2seq.embedding_rnn_seq2seq(
|
||||
enc_inp, dec_inp, cell1, num_encoder_symbols=2,
|
||||
num_decoder_symbols=5, embedding_size=2)
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
res = sess.run(dec)
|
||||
self.assertEqual(3, len(res))
|
||||
self.assertEqual((2, 5), res[0].shape)
|
||||
|
||||
res = sess.run([mem])
|
||||
self.assertEqual((2, 4), res[0].shape)
|
||||
|
||||
# Test externally provided output projection.
|
||||
w = tf.get_variable("proj_w", [2, 5])
|
||||
b = tf.get_variable("proj_b", [5])
|
||||
with tf.variable_scope("proj_seq2seq"):
|
||||
dec, _ = tf.nn.seq2seq.embedding_rnn_seq2seq(
|
||||
enc_inp, dec_inp, cell, num_encoder_symbols=2,
|
||||
num_decoder_symbols=5, embedding_size=2, output_projection=(w, b))
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
res = sess.run(dec)
|
||||
self.assertEqual(3, len(res))
|
||||
self.assertEqual((2, 2), res[0].shape)
|
||||
|
||||
# Test that previous-feeding model ignores inputs after the first.
|
||||
dec_inp2 = [tf.constant(0, tf.int32, shape=[2]) for _ in range(3)]
|
||||
with tf.variable_scope("other"):
|
||||
d3, _ = tf.nn.seq2seq.embedding_rnn_seq2seq(
|
||||
enc_inp, dec_inp2, cell, num_encoder_symbols=2,
|
||||
num_decoder_symbols=5, embedding_size=2,
|
||||
feed_previous=tf.constant(True))
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
tf.get_variable_scope().reuse_variables()
|
||||
d1, _ = tf.nn.seq2seq.embedding_rnn_seq2seq(
|
||||
enc_inp, dec_inp, cell, num_encoder_symbols=2,
|
||||
num_decoder_symbols=5, embedding_size=2, feed_previous=True)
|
||||
d2, _ = tf.nn.seq2seq.embedding_rnn_seq2seq(
|
||||
enc_inp, dec_inp2, cell, num_encoder_symbols=2,
|
||||
num_decoder_symbols=5, embedding_size=2, feed_previous=True)
|
||||
res1 = sess.run(d1)
|
||||
res2 = sess.run(d2)
|
||||
res3 = sess.run(d3)
|
||||
self.assertAllClose(res1, res2)
|
||||
self.assertAllClose(res1, res3)
|
||||
|
||||
def testEmbeddingTiedRNNSeq2Seq(self):
|
||||
with self.test_session() as sess:
|
||||
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
|
||||
enc_inp = [tf.constant(1, tf.int32, shape=[2]) for i in range(2)]
|
||||
dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in range(3)]
|
||||
cell = tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
|
||||
dec, mem = tf.nn.seq2seq.embedding_tied_rnn_seq2seq(
|
||||
enc_inp, dec_inp, cell, num_symbols=5, embedding_size=2)
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
res = sess.run(dec)
|
||||
self.assertEqual(3, len(res))
|
||||
self.assertEqual((2, 5), res[0].shape)
|
||||
|
||||
res = sess.run([mem])
|
||||
self.assertEqual((2, 2), res[0].c.shape)
|
||||
self.assertEqual((2, 2), res[0].h.shape)
|
||||
|
||||
# Test when num_decoder_symbols is provided, the size of decoder output
|
||||
# is num_decoder_symbols.
|
||||
with tf.variable_scope("decoder_symbols_seq2seq"):
|
||||
dec, mem = tf.nn.seq2seq.embedding_tied_rnn_seq2seq(
|
||||
enc_inp, dec_inp, cell, num_symbols=5, num_decoder_symbols=3,
|
||||
embedding_size=2)
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
res = sess.run(dec)
|
||||
self.assertEqual(3, len(res))
|
||||
self.assertEqual((2, 3), res[0].shape)
|
||||
|
||||
# Test externally provided output projection.
|
||||
w = tf.get_variable("proj_w", [2, 5])
|
||||
b = tf.get_variable("proj_b", [5])
|
||||
with tf.variable_scope("proj_seq2seq"):
|
||||
dec, _ = tf.nn.seq2seq.embedding_tied_rnn_seq2seq(
|
||||
enc_inp, dec_inp, cell, num_symbols=5, embedding_size=2,
|
||||
output_projection=(w, b))
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
res = sess.run(dec)
|
||||
self.assertEqual(3, len(res))
|
||||
self.assertEqual((2, 2), res[0].shape)
|
||||
|
||||
# Test that previous-feeding model ignores inputs after the first.
|
||||
dec_inp2 = [tf.constant(0, tf.int32, shape=[2])] * 3
|
||||
with tf.variable_scope("other"):
|
||||
d3, _ = tf.nn.seq2seq.embedding_tied_rnn_seq2seq(
|
||||
enc_inp, dec_inp2, cell, num_symbols=5, embedding_size=2,
|
||||
feed_previous=tf.constant(True))
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
tf.get_variable_scope().reuse_variables()
|
||||
d1, _ = tf.nn.seq2seq.embedding_tied_rnn_seq2seq(
|
||||
enc_inp, dec_inp, cell, num_symbols=5, embedding_size=2,
|
||||
feed_previous=True)
|
||||
d2, _ = tf.nn.seq2seq.embedding_tied_rnn_seq2seq(
|
||||
enc_inp, dec_inp2, cell, num_symbols=5, embedding_size=2,
|
||||
feed_previous=True)
|
||||
res1 = sess.run(d1)
|
||||
res2 = sess.run(d2)
|
||||
res3 = sess.run(d3)
|
||||
self.assertAllClose(res1, res2)
|
||||
self.assertAllClose(res1, res3)
|
||||
|
||||
def testAttentionDecoder1(self):
|
||||
with self.test_session() as sess:
|
||||
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
|
||||
cell = tf.nn.rnn_cell.GRUCell(2)
|
||||
inp = [tf.constant(0.5, shape=[2, 2])] * 2
|
||||
enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32)
|
||||
attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size])
|
||||
for e in enc_outputs])
|
||||
dec_inp = [tf.constant(0.4, shape=[2, 2])] * 3
|
||||
dec, mem = tf.nn.seq2seq.attention_decoder(
|
||||
dec_inp, enc_state,
|
||||
attn_states, cell, output_size=4)
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
res = sess.run(dec)
|
||||
self.assertEqual(3, len(res))
|
||||
self.assertEqual((2, 4), res[0].shape)
|
||||
|
||||
res = sess.run([mem])
|
||||
self.assertEqual((2, 2), res[0].shape)
|
||||
|
||||
def testAttentionDecoder2(self):
|
||||
with self.test_session() as sess:
|
||||
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
|
||||
cell = tf.nn.rnn_cell.GRUCell(2)
|
||||
inp = [tf.constant(0.5, shape=[2, 2])] * 2
|
||||
enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32)
|
||||
attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size])
|
||||
for e in enc_outputs])
|
||||
dec_inp = [tf.constant(0.4, shape=[2, 2])] * 3
|
||||
dec, mem = tf.nn.seq2seq.attention_decoder(
|
||||
dec_inp, enc_state,
|
||||
attn_states, cell, output_size=4,
|
||||
num_heads=2)
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
res = sess.run(dec)
|
||||
self.assertEqual(3, len(res))
|
||||
self.assertEqual((2, 4), res[0].shape)
|
||||
|
||||
res = sess.run([mem])
|
||||
self.assertEqual((2, 2), res[0].shape)
|
||||
|
||||
def testDynamicAttentionDecoder1(self):
|
||||
with self.test_session() as sess:
|
||||
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
|
||||
cell = tf.nn.rnn_cell.GRUCell(2)
|
||||
inp = tf.constant(0.5, shape=[2, 2, 2])
|
||||
enc_outputs, enc_state = tf.nn.dynamic_rnn(cell, inp, dtype=tf.float32)
|
||||
attn_states = enc_outputs
|
||||
dec_inp = [tf.constant(0.4, shape=[2, 2])] * 3
|
||||
dec, mem = tf.nn.seq2seq.attention_decoder(
|
||||
dec_inp, enc_state,
|
||||
attn_states, cell, output_size=4)
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
res = sess.run(dec)
|
||||
self.assertEqual(3, len(res))
|
||||
self.assertEqual((2, 4), res[0].shape)
|
||||
|
||||
res = sess.run([mem])
|
||||
self.assertEqual((2, 2), res[0].shape)
|
||||
|
||||
def testDynamicAttentionDecoder2(self):
|
||||
with self.test_session() as sess:
|
||||
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
|
||||
cell = tf.nn.rnn_cell.GRUCell(2)
|
||||
inp = tf.constant(0.5, shape=[2, 2, 2])
|
||||
enc_outputs, enc_state = tf.nn.dynamic_rnn(cell, inp, dtype=tf.float32)
|
||||
attn_states = enc_outputs
|
||||
dec_inp = [tf.constant(0.4, shape=[2, 2])] * 3
|
||||
dec, mem = tf.nn.seq2seq.attention_decoder(
|
||||
dec_inp, enc_state,
|
||||
attn_states, cell, output_size=4,
|
||||
num_heads=2)
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
res = sess.run(dec)
|
||||
self.assertEqual(3, len(res))
|
||||
self.assertEqual((2, 4), res[0].shape)
|
||||
|
||||
res = sess.run([mem])
|
||||
self.assertEqual((2, 2), res[0].shape)
|
||||
|
||||
def testAttentionDecoderStateIsTuple(self):
|
||||
with self.test_session() as sess:
|
||||
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
|
||||
cell = tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
|
||||
cell = tf.nn.rnn_cell.MultiRNNCell(cells=[cell] * 2,
|
||||
state_is_tuple=True)
|
||||
inp = [tf.constant(0.5, shape=[2, 2])] * 2
|
||||
enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32)
|
||||
attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size])
|
||||
for e in enc_outputs])
|
||||
dec_inp = [tf.constant(0.4, shape=[2, 2])] * 3
|
||||
dec, mem = tf.nn.seq2seq.attention_decoder(
|
||||
dec_inp, enc_state,
|
||||
attn_states, cell, output_size=4)
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
res = sess.run(dec)
|
||||
self.assertEqual(3, len(res))
|
||||
self.assertEqual((2, 4), res[0].shape)
|
||||
|
||||
res = sess.run([mem])
|
||||
self.assertEqual(2, len(res[0]))
|
||||
self.assertEqual((2, 2), res[0][0].c.shape)
|
||||
self.assertEqual((2, 2), res[0][0].h.shape)
|
||||
self.assertEqual((2, 2), res[0][1].c.shape)
|
||||
self.assertEqual((2, 2), res[0][1].h.shape)
|
||||
|
||||
# pylint: disable=unused-variable,invalid-name
|
||||
def testDynamicAttentionDecoderStateIsTuple(self):
|
||||
with self.test_session() as sess:
|
||||
with tf.variable_scope(
|
||||
"root", initializer=tf.constant_initializer(0.5)):
|
||||
cell = tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
|
||||
cell = tf.nn.rnn_cell.MultiRNNCell(cells=[cell] * 2,
|
||||
state_is_tuple=True)
|
||||
inp = tf.constant(0.5, shape=[2, 2, 2])
|
||||
enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32)
|
||||
attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size])
|
||||
for e in enc_outputs])
|
||||
dec_inp = [tf.constant(0.4, shape=[2, 2])] * 3
|
||||
dec, mem = tf.nn.seq2seq.attention_decoder(
|
||||
dec_inp, enc_state,
|
||||
attn_states, cell, output_size=4)
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
res = sess.run(dec)
|
||||
self.assertEqual(3, len(res))
|
||||
self.assertEqual((2, 4), res[0].shape)
|
||||
|
||||
res = sess.run([mem])
|
||||
self.assertEqual(2, len(res[0]))
|
||||
self.assertEqual((2, 2), res[0][0].c.shape)
|
||||
self.assertEqual((2, 2), res[0][0].h.shape)
|
||||
self.assertEqual((2, 2), res[0][1].c.shape)
|
||||
self.assertEqual((2, 2), res[0][1].h.shape)
|
||||
|
||||
def testEmbeddingAttentionDecoder(self):
|
||||
with self.test_session() as sess:
|
||||
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
|
||||
inp = [tf.constant(0.5, shape=[2, 2])] * 2
|
||||
cell = tf.nn.rnn_cell.GRUCell(2)
|
||||
enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32)
|
||||
attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size])
|
||||
for e in enc_outputs])
|
||||
dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in range(3)]
|
||||
dec, mem = tf.nn.seq2seq.embedding_attention_decoder(
|
||||
dec_inp, enc_state, attn_states, cell, num_symbols=4,
|
||||
embedding_size=2, output_size=3)
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
res = sess.run(dec)
|
||||
self.assertEqual(3, len(res))
|
||||
self.assertEqual((2, 3), res[0].shape)
|
||||
|
||||
res = sess.run([mem])
|
||||
self.assertEqual((2, 2), res[0].shape)
|
||||
|
||||
def testEmbeddingAttentionSeq2Seq(self):
|
||||
with self.test_session() as sess:
|
||||
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
|
||||
enc_inp = [tf.constant(1, tf.int32, shape=[2]) for i in range(2)]
|
||||
dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in range(3)]
|
||||
cell = tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
|
||||
dec, mem = tf.nn.seq2seq.embedding_attention_seq2seq(
|
||||
enc_inp, dec_inp, cell, num_encoder_symbols=2,
|
||||
num_decoder_symbols=5, embedding_size=2)
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
res = sess.run(dec)
|
||||
self.assertEqual(3, len(res))
|
||||
self.assertEqual((2, 5), res[0].shape)
|
||||
|
||||
res = sess.run([mem])
|
||||
self.assertEqual((2, 2), res[0].c.shape)
|
||||
self.assertEqual((2, 2), res[0].h.shape)
|
||||
|
||||
# Test with state_is_tuple=False.
|
||||
with tf.variable_scope("no_tuple"):
|
||||
cell = tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
|
||||
dec, mem = tf.nn.seq2seq.embedding_attention_seq2seq(
|
||||
enc_inp, dec_inp, cell, num_encoder_symbols=2,
|
||||
num_decoder_symbols=5, embedding_size=2)
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
res = sess.run(dec)
|
||||
self.assertEqual(3, len(res))
|
||||
self.assertEqual((2, 5), res[0].shape)
|
||||
|
||||
res = sess.run([mem])
|
||||
self.assertEqual((2, 4), res[0].shape)
|
||||
|
||||
# Test externally provided output projection.
|
||||
w = tf.get_variable("proj_w", [2, 5])
|
||||
b = tf.get_variable("proj_b", [5])
|
||||
with tf.variable_scope("proj_seq2seq"):
|
||||
dec, _ = tf.nn.seq2seq.embedding_attention_seq2seq(
|
||||
enc_inp, dec_inp, cell, num_encoder_symbols=2,
|
||||
num_decoder_symbols=5, embedding_size=2, output_projection=(w, b))
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
res = sess.run(dec)
|
||||
self.assertEqual(3, len(res))
|
||||
self.assertEqual((2, 2), res[0].shape)
|
||||
|
||||
# Test that previous-feeding model ignores inputs after the first.
|
||||
dec_inp2 = [tf.constant(0, tf.int32, shape=[2]) for _ in range(3)]
|
||||
with tf.variable_scope("other"):
|
||||
d3, _ = tf.nn.seq2seq.embedding_attention_seq2seq(
|
||||
enc_inp, dec_inp2, cell, num_encoder_symbols=2,
|
||||
num_decoder_symbols=5, embedding_size=2,
|
||||
feed_previous=tf.constant(True))
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
tf.get_variable_scope().reuse_variables()
|
||||
d1, _ = tf.nn.seq2seq.embedding_attention_seq2seq(
|
||||
enc_inp, dec_inp, cell, num_encoder_symbols=2,
|
||||
num_decoder_symbols=5, embedding_size=2, feed_previous=True)
|
||||
d2, _ = tf.nn.seq2seq.embedding_attention_seq2seq(
|
||||
enc_inp, dec_inp2, cell, num_encoder_symbols=2,
|
||||
num_decoder_symbols=5, embedding_size=2, feed_previous=True)
|
||||
res1 = sess.run(d1)
|
||||
res2 = sess.run(d2)
|
||||
res3 = sess.run(d3)
|
||||
self.assertAllClose(res1, res2)
|
||||
self.assertAllClose(res1, res3)
|
||||
|
||||
def testOne2ManyRNNSeq2Seq(self):
|
||||
with self.test_session() as sess:
|
||||
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
|
||||
enc_inp = [tf.constant(1, tf.int32, shape=[2]) for i in range(2)]
|
||||
dec_inp_dict = {}
|
||||
dec_inp_dict["0"] = [
|
||||
tf.constant(i, tf.int32, shape=[2]) for i in range(3)]
|
||||
dec_inp_dict["1"] = [
|
||||
tf.constant(i, tf.int32, shape=[2]) for i in range(4)]
|
||||
dec_symbols_dict = {"0": 5, "1": 6}
|
||||
cell = tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
|
||||
outputs_dict, state_dict = tf.nn.seq2seq.one2many_rnn_seq2seq(
|
||||
enc_inp, dec_inp_dict, cell, 2, dec_symbols_dict, embedding_size=2)
|
||||
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
res = sess.run(outputs_dict["0"])
|
||||
self.assertEqual(3, len(res))
|
||||
self.assertEqual((2, 5), res[0].shape)
|
||||
res = sess.run(outputs_dict["1"])
|
||||
self.assertEqual(4, len(res))
|
||||
self.assertEqual((2, 6), res[0].shape)
|
||||
res = sess.run([state_dict["0"]])
|
||||
self.assertEqual((2, 2), res[0].c.shape)
|
||||
self.assertEqual((2, 2), res[0].h.shape)
|
||||
res = sess.run([state_dict["1"]])
|
||||
self.assertEqual((2, 2), res[0].c.shape)
|
||||
self.assertEqual((2, 2), res[0].h.shape)
|
||||
|
||||
# Test that previous-feeding model ignores inputs after the first, i.e.
|
||||
# dec_inp_dict2 has different inputs from dec_inp_dict after the first
|
||||
# time-step.
|
||||
dec_inp_dict2 = {}
|
||||
dec_inp_dict2["0"] = [
|
||||
tf.constant(0, tf.int32, shape=[2]) for _ in range(3)]
|
||||
dec_inp_dict2["1"] = [
|
||||
tf.constant(0, tf.int32, shape=[2]) for _ in range(4)]
|
||||
with tf.variable_scope("other"):
|
||||
outputs_dict3, _ = tf.nn.seq2seq.one2many_rnn_seq2seq(
|
||||
enc_inp, dec_inp_dict2, cell, 2, dec_symbols_dict,
|
||||
embedding_size=2, feed_previous=tf.constant(True))
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
tf.get_variable_scope().reuse_variables()
|
||||
outputs_dict1, _ = tf.nn.seq2seq.one2many_rnn_seq2seq(
|
||||
enc_inp, dec_inp_dict, cell, 2, dec_symbols_dict,
|
||||
embedding_size=2, feed_previous=True)
|
||||
outputs_dict2, _ = tf.nn.seq2seq.one2many_rnn_seq2seq(
|
||||
enc_inp, dec_inp_dict2, cell, 2, dec_symbols_dict,
|
||||
embedding_size=2, feed_previous=True)
|
||||
res1 = sess.run(outputs_dict1["0"])
|
||||
res2 = sess.run(outputs_dict2["0"])
|
||||
res3 = sess.run(outputs_dict3["0"])
|
||||
self.assertAllClose(res1, res2)
|
||||
self.assertAllClose(res1, res3)
|
||||
|
||||
def testSequenceLoss(self):
|
||||
with self.test_session() as sess:
|
||||
logits = [tf.constant(i + 0.5, shape=[2, 5]) for i in range(3)]
|
||||
targets = [tf.constant(i, tf.int32, shape=[2]) for i in range(3)]
|
||||
weights = [tf.constant(1.0, shape=[2]) for i in range(3)]
|
||||
|
||||
average_loss_per_example = tf.nn.seq2seq.sequence_loss(
|
||||
logits, targets, weights,
|
||||
average_across_timesteps=True,
|
||||
average_across_batch=True)
|
||||
res = sess.run(average_loss_per_example)
|
||||
self.assertAllClose(1.60944, res)
|
||||
|
||||
average_loss_per_sequence = tf.nn.seq2seq.sequence_loss(
|
||||
logits, targets, weights,
|
||||
average_across_timesteps=False,
|
||||
average_across_batch=True)
|
||||
res = sess.run(average_loss_per_sequence)
|
||||
self.assertAllClose(4.828314, res)
|
||||
|
||||
total_loss = tf.nn.seq2seq.sequence_loss(
|
||||
logits, targets, weights,
|
||||
average_across_timesteps=False,
|
||||
average_across_batch=False)
|
||||
res = sess.run(total_loss)
|
||||
self.assertAllClose(9.656628, res)
|
||||
|
||||
def testSequenceLossByExample(self):
|
||||
with self.test_session() as sess:
|
||||
output_classes = 5
|
||||
logits = [tf.constant(i + 0.5, shape=[2, output_classes])
|
||||
for i in range(3)]
|
||||
targets = [tf.constant(i, tf.int32, shape=[2]) for i in range(3)]
|
||||
weights = [tf.constant(1.0, shape=[2]) for i in range(3)]
|
||||
|
||||
average_loss_per_example = tf.nn.seq2seq.sequence_loss_by_example(
|
||||
logits, targets, weights,
|
||||
average_across_timesteps=True)
|
||||
res = sess.run(average_loss_per_example)
|
||||
self.assertAllClose(np.asarray([1.609438, 1.609438]), res)
|
||||
|
||||
loss_per_sequence = tf.nn.seq2seq.sequence_loss_by_example(
|
||||
logits, targets, weights,
|
||||
average_across_timesteps=False)
|
||||
res = sess.run(loss_per_sequence)
|
||||
self.assertAllClose(np.asarray([4.828314, 4.828314]), res)
|
||||
|
||||
def testModelWithBucketsScopeAndLoss(self):
|
||||
"""Test that variable scope reuse is not reset after model_with_buckets."""
|
||||
classes = 10
|
||||
buckets = [(4, 4), (8, 8)]
|
||||
|
||||
with self.test_session():
|
||||
# Here comes a sample Seq2Seq model using GRU cells.
|
||||
def SampleGRUSeq2Seq(enc_inp, dec_inp, weights, per_example_loss):
|
||||
"""Example sequence-to-sequence model that uses GRU cells."""
|
||||
def GRUSeq2Seq(enc_inp, dec_inp):
|
||||
cell = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.GRUCell(24)] * 2,
|
||||
state_is_tuple=True)
|
||||
return tf.nn.seq2seq.embedding_attention_seq2seq(
|
||||
enc_inp, dec_inp, cell, num_encoder_symbols=classes,
|
||||
num_decoder_symbols=classes, embedding_size=24)
|
||||
targets = [dec_inp[i+1] for i in range(len(dec_inp) - 1)] + [0]
|
||||
return tf.nn.seq2seq.model_with_buckets(
|
||||
enc_inp, dec_inp, targets, weights, buckets, GRUSeq2Seq,
|
||||
per_example_loss=per_example_loss)
|
||||
|
||||
# Now we construct the copy model.
|
||||
inp = [tf.placeholder(tf.int32, shape=[None]) for _ in range(8)]
|
||||
out = [tf.placeholder(tf.int32, shape=[None]) for _ in range(8)]
|
||||
weights = [tf.ones_like(inp[0], dtype=tf.float32) for _ in range(8)]
|
||||
with tf.variable_scope("root"):
|
||||
_, losses1 = SampleGRUSeq2Seq(inp, out, weights, per_example_loss=False)
|
||||
# Now check that we did not accidentally set reuse.
|
||||
self.assertEqual(False, tf.get_variable_scope().reuse)
|
||||
# Construct one more model with per-example loss.
|
||||
tf.get_variable_scope().reuse_variables()
|
||||
_, losses2 = SampleGRUSeq2Seq(inp, out, weights, per_example_loss=True)
|
||||
# First loss is scalar, the second one is a 1-dimensinal tensor.
|
||||
self.assertEqual([], losses1[0].get_shape().as_list())
|
||||
self.assertEqual([None], losses2[0].get_shape().as_list())
|
||||
|
||||
def testModelWithBuckets(self):
|
||||
"""Larger tests that does full sequence-to-sequence model training."""
|
||||
# We learn to copy 10 symbols in 2 buckets: length 4 and length 8.
|
||||
classes = 10
|
||||
buckets = [(4, 4), (8, 8)]
|
||||
perplexities = [[], []] # Results for each bucket.
|
||||
tf.set_random_seed(111)
|
||||
random.seed(111)
|
||||
np.random.seed(111)
|
||||
|
||||
with self.test_session() as sess:
|
||||
# We use sampled softmax so we keep output projection separate.
|
||||
w = tf.get_variable("proj_w", [24, classes])
|
||||
w_t = tf.transpose(w)
|
||||
b = tf.get_variable("proj_b", [classes])
|
||||
# Here comes a sample Seq2Seq model using GRU cells.
|
||||
def SampleGRUSeq2Seq(enc_inp, dec_inp, weights):
|
||||
"""Example sequence-to-sequence model that uses GRU cells."""
|
||||
def GRUSeq2Seq(enc_inp, dec_inp):
|
||||
cell = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.GRUCell(24)] * 2,
|
||||
state_is_tuple=True)
|
||||
return tf.nn.seq2seq.embedding_attention_seq2seq(
|
||||
enc_inp, dec_inp, cell, num_encoder_symbols=classes,
|
||||
num_decoder_symbols=classes, embedding_size=24,
|
||||
output_projection=(w, b))
|
||||
targets = [dec_inp[i+1] for i in range(len(dec_inp) - 1)] + [0]
|
||||
def SampledLoss(labels, inputs):
|
||||
labels = tf.reshape(labels, [-1, 1])
|
||||
return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, 8, classes)
|
||||
return tf.nn.seq2seq.model_with_buckets(
|
||||
enc_inp, dec_inp, targets, weights, buckets, GRUSeq2Seq,
|
||||
softmax_loss_function=SampledLoss)
|
||||
|
||||
# Now we construct the copy model.
|
||||
batch_size = 8
|
||||
inp = [tf.placeholder(tf.int32, shape=[None]) for _ in range(8)]
|
||||
out = [tf.placeholder(tf.int32, shape=[None]) for _ in range(8)]
|
||||
weights = [tf.ones_like(inp[0], dtype=tf.float32) for _ in range(8)]
|
||||
with tf.variable_scope("root"):
|
||||
_, losses = SampleGRUSeq2Seq(inp, out, weights)
|
||||
updates = []
|
||||
params = tf.global_variables()
|
||||
optimizer = tf.train.AdamOptimizer(0.03, epsilon=1e-5)
|
||||
for i in range(len(buckets)):
|
||||
full_grads = tf.gradients(losses[i], params)
|
||||
grads, _ = tf.clip_by_global_norm(full_grads, 30.0)
|
||||
update = optimizer.apply_gradients(zip(grads, params))
|
||||
updates.append(update)
|
||||
sess.run([tf.global_variables_initializer()])
|
||||
steps = 6
|
||||
for _ in range(steps):
|
||||
bucket = random.choice(np.arange(len(buckets)))
|
||||
length = buckets[bucket][0]
|
||||
i = [np.array([np.random.randint(9) + 1 for _ in range(batch_size)],
|
||||
dtype=np.int32) for _ in range(length)]
|
||||
# 0 is our "GO" symbol here.
|
||||
o = [np.array([0] * batch_size, dtype=np.int32)] + i
|
||||
feed = {}
|
||||
for i1, i2, o1, o2 in zip(inp[:length], i[:length],
|
||||
out[:length], o[:length]):
|
||||
feed[i1.name] = i2
|
||||
feed[o1.name] = o2
|
||||
if length < 8: # For the 4-bucket, we need the 5th as target.
|
||||
feed[out[length].name] = o[length]
|
||||
res = sess.run([updates[bucket], losses[bucket]], feed)
|
||||
perplexities[bucket].append(math.exp(float(res[1])))
|
||||
for bucket in range(len(buckets)):
|
||||
if len(perplexities[bucket]) > 1: # Assert that perplexity went down.
|
||||
self.assertLess(perplexities[bucket][-1], perplexities[bucket][0])
|
||||
|
||||
def testModelWithBooleanFeedPrevious(self):
|
||||
"""Test the model behavior when feed_previous is True.
|
||||
|
||||
For example, the following two cases have the same effect:
|
||||
- Train `embedding_rnn_seq2seq` with `feed_previous=True`, which contains
|
||||
a `embedding_rnn_decoder` with `feed_previous=True` and
|
||||
`update_embedding_for_previous=True`. The decoder is fed with "<Go>"
|
||||
and outputs "A, B, C".
|
||||
- Train `embedding_rnn_seq2seq` with `feed_previous=False`. The decoder
|
||||
is fed with "<Go>, A, B".
|
||||
"""
|
||||
num_encoder_symbols = 3
|
||||
num_decoder_symbols = 5
|
||||
batch_size = 2
|
||||
num_enc_timesteps = 2
|
||||
num_dec_timesteps = 3
|
||||
|
||||
def TestModel(seq2seq):
|
||||
with self.test_session(graph=tf.Graph()) as sess:
|
||||
tf.set_random_seed(111)
|
||||
random.seed(111)
|
||||
np.random.seed(111)
|
||||
|
||||
enc_inp = [tf.constant(i + 1, tf.int32, shape=[batch_size])
|
||||
for i in range(num_enc_timesteps)]
|
||||
dec_inp_fp_true = [tf.constant(i, tf.int32, shape=[batch_size])
|
||||
for i in range(num_dec_timesteps)]
|
||||
dec_inp_holder_fp_false = [tf.placeholder(tf.int32, shape=[batch_size])
|
||||
for _ in range(num_dec_timesteps)]
|
||||
targets = [tf.constant(i + 1, tf.int32, shape=[batch_size])
|
||||
for i in range(num_dec_timesteps)]
|
||||
weights = [tf.constant(1.0, shape=[batch_size])
|
||||
for i in range(num_dec_timesteps)]
|
||||
|
||||
def ForwardBackward(enc_inp, dec_inp, feed_previous):
|
||||
scope_name = "fp_{}".format(feed_previous)
|
||||
with tf.variable_scope(scope_name):
|
||||
dec_op, _ = seq2seq(enc_inp, dec_inp, feed_previous=feed_previous)
|
||||
net_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
|
||||
scope_name)
|
||||
optimizer = tf.train.AdamOptimizer(0.03, epsilon=1e-5)
|
||||
update_op = optimizer.minimize(
|
||||
tf.nn.seq2seq.sequence_loss(dec_op, targets, weights),
|
||||
var_list=net_variables)
|
||||
return dec_op, update_op, net_variables
|
||||
|
||||
dec_op_fp_true, update_fp_true, variables_fp_true = ForwardBackward(
|
||||
enc_inp, dec_inp_fp_true, feed_previous=True)
|
||||
_, update_fp_false, variables_fp_false = ForwardBackward(
|
||||
enc_inp, dec_inp_holder_fp_false, feed_previous=False)
|
||||
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# We only check consistencies between the variables existing in both
|
||||
# the models with True and False feed_previous. Variables created by
|
||||
# the loop_function in the model with True feed_previous are ignored.
|
||||
v_false_name_dict = {v.name.split("/", 1)[-1]: v
|
||||
for v in variables_fp_false}
|
||||
matched_variables = [(v, v_false_name_dict[v.name.split("/", 1)[-1]])
|
||||
for v in variables_fp_true]
|
||||
for v_true, v_false in matched_variables:
|
||||
sess.run(tf.assign(v_false, v_true))
|
||||
|
||||
# Take the symbols generated by the decoder with feed_previous=True as
|
||||
# the true input symbols for the decoder with feed_previous=False.
|
||||
dec_fp_true = sess.run(dec_op_fp_true)
|
||||
output_symbols_fp_true = np.argmax(dec_fp_true, axis=2)
|
||||
dec_inp_fp_false = np.vstack((dec_inp_fp_true[0].eval(),
|
||||
output_symbols_fp_true[:-1]))
|
||||
sess.run(update_fp_true)
|
||||
sess.run(update_fp_false,
|
||||
{holder: inp for holder, inp in zip(dec_inp_holder_fp_false,
|
||||
dec_inp_fp_false)})
|
||||
|
||||
for v_true, v_false in matched_variables:
|
||||
self.assertAllClose(v_true.eval(), v_false.eval())
|
||||
|
||||
def EmbeddingRNNSeq2SeqF(enc_inp, dec_inp, feed_previous):
|
||||
cell = tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
|
||||
return tf.nn.seq2seq.embedding_rnn_seq2seq(
|
||||
enc_inp, dec_inp, cell, num_encoder_symbols,
|
||||
num_decoder_symbols, embedding_size=2, feed_previous=feed_previous)
|
||||
|
||||
def EmbeddingRNNSeq2SeqNoTupleF(enc_inp, dec_inp, feed_previous):
|
||||
cell = tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
|
||||
return tf.nn.seq2seq.embedding_rnn_seq2seq(
|
||||
enc_inp, dec_inp, cell, num_encoder_symbols,
|
||||
num_decoder_symbols, embedding_size=2, feed_previous=feed_previous)
|
||||
|
||||
def EmbeddingTiedRNNSeq2Seq(enc_inp, dec_inp, feed_previous):
|
||||
cell = tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
|
||||
return tf.nn.seq2seq.embedding_tied_rnn_seq2seq(
|
||||
enc_inp, dec_inp, cell, num_decoder_symbols, embedding_size=2,
|
||||
feed_previous=feed_previous)
|
||||
|
||||
def EmbeddingTiedRNNSeq2SeqNoTuple(enc_inp, dec_inp, feed_previous):
|
||||
cell = tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
|
||||
return tf.nn.seq2seq.embedding_tied_rnn_seq2seq(
|
||||
enc_inp, dec_inp, cell, num_decoder_symbols, embedding_size=2,
|
||||
feed_previous=feed_previous)
|
||||
|
||||
def EmbeddingAttentionSeq2Seq(enc_inp, dec_inp, feed_previous):
|
||||
cell = tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
|
||||
return tf.nn.seq2seq.embedding_attention_seq2seq(
|
||||
enc_inp, dec_inp, cell, num_encoder_symbols,
|
||||
num_decoder_symbols, embedding_size=2, feed_previous=feed_previous)
|
||||
|
||||
def EmbeddingAttentionSeq2SeqNoTuple(enc_inp, dec_inp, feed_previous):
|
||||
cell = tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
|
||||
return tf.nn.seq2seq.embedding_attention_seq2seq(
|
||||
enc_inp, dec_inp, cell, num_encoder_symbols,
|
||||
num_decoder_symbols, embedding_size=2, feed_previous=feed_previous)
|
||||
|
||||
for model in (EmbeddingRNNSeq2SeqF, EmbeddingRNNSeq2SeqNoTupleF,
|
||||
EmbeddingTiedRNNSeq2Seq, EmbeddingTiedRNNSeq2SeqNoTuple,
|
||||
EmbeddingAttentionSeq2Seq, EmbeddingAttentionSeq2SeqNoTuple):
|
||||
TestModel(model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
@ -22,6 +22,7 @@ import numpy
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
|
||||
@ -95,6 +96,21 @@ class VariableScopeTest(tf.test.TestCase):
|
||||
with self.assertRaises(TypeError):
|
||||
tf.get_variable("x", initializer={})
|
||||
|
||||
def testInitFromNonInitializer(self):
|
||||
with self.test_session() as sess:
|
||||
# Test various dtypes with zeros initializer as following:
|
||||
types = [tf.int8, tf.uint8, tf.int16, tf.uint16, tf.int32, tf.int64,
|
||||
tf.bool]
|
||||
|
||||
# Use different varibale_name to distinguish various dtypes
|
||||
for (i, dtype) in enumerate(types):
|
||||
x = tf.get_variable(name='x%d' % i, shape=(3, 4), dtype=dtype)
|
||||
y = tf.get_variable(name='y%d' % i, shape=(3, 4), dtype=dtype,
|
||||
initializer=init_ops.zeros_initializer(dtype=dtype))
|
||||
|
||||
tf.global_variables_initializer().run()
|
||||
self.assertAllEqual(x.eval(), y.eval())
|
||||
|
||||
def testVarScopeCachingDevice(self):
|
||||
with self.test_session():
|
||||
caching_device = "/job:moo"
|
||||
@ -672,6 +688,27 @@ def axis0_into3_partitioner(shape=None, **unused_kwargs):
|
||||
|
||||
class VariableScopeWithPartitioningTest(tf.test.TestCase):
|
||||
|
||||
def testInitFromNonInitializer(self):
|
||||
with self.test_session() as sess:
|
||||
# Test various dtypes with zeros initializer as following:
|
||||
types = [tf.int8, tf.uint8, tf.int16, tf.uint16, tf.int32, tf.int64,
|
||||
tf.bool]
|
||||
|
||||
# Use different varibale_name to distinguish various dtypes
|
||||
for (i, dtype) in enumerate(types):
|
||||
x = tf.get_variable(name='x%d' % i, shape=(3, 4), dtype=dtype,
|
||||
partitioner=axis0_into2_partitioner)
|
||||
y = tf.get_variable(name='y%d' % i, shape=(6, 4), dtype=dtype,
|
||||
partitioner=axis0_into2_partitioner,
|
||||
initializer=init_ops.zeros_initializer(dtype=dtype))
|
||||
|
||||
tf.global_variables_initializer().run()
|
||||
# x and y would become var list after partition
|
||||
val_x = sess.run(list(x))
|
||||
val_y = sess.run(list(y))
|
||||
|
||||
self.assertAllEqual(val_x, val_y)
|
||||
|
||||
def testResultNameMatchesRequested(self):
|
||||
with tf.variable_scope("scope0", partitioner=axis0_into2_partitioner):
|
||||
v = tf.get_variable("name0", shape=(3, 1, 1))
|
||||
|
@ -366,6 +366,14 @@ class VariablesTestCase(tf.test.TestCase):
|
||||
for i in v2.initializer.inputs:
|
||||
self.assertEqual(expected_group_v2, i.op.colocation_groups())
|
||||
|
||||
def testLoad(self):
|
||||
with self.test_session():
|
||||
var = tf.Variable(np.zeros((5,5), np.float32))
|
||||
tf.global_variables_initializer().run()
|
||||
var.load(np.ones((5, 5), np.float32))
|
||||
|
||||
self.assertAllClose(np.ones((5, 5), np.float32), var.eval())
|
||||
|
||||
|
||||
class IsInitializedTest(tf.test.TestCase):
|
||||
|
||||
|
@ -66,7 +66,8 @@ def zeros_initializer(dtype=dtypes.float32):
|
||||
"""Returns an initializer that generates tensors initialized to 0."""
|
||||
|
||||
def _initializer(shape, dtype=dtype, partition_info=None):
|
||||
return constant_op.constant(0, dtype=dtype, shape=shape)
|
||||
return constant_op.constant(False if dtype is dtypes.bool else 0,
|
||||
dtype=dtype, shape=shape)
|
||||
|
||||
return _initializer
|
||||
|
||||
|
@ -139,66 +139,66 @@ class BatchNormalizationTest(tf.test.TestCase):
|
||||
|
||||
def testInference(self):
|
||||
x_shape = [1, 1, 6, 1]
|
||||
if tf.test.is_gpu_available():
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
self._test_inference(x_shape, [1], use_gpu=True, data_format='NHWC')
|
||||
self._test_inference(x_shape, [1], use_gpu=True, data_format='NCHW')
|
||||
self._test_inference(x_shape, [1], use_gpu=False, data_format='NHWC')
|
||||
|
||||
x_shape = [1, 1, 6, 2]
|
||||
if tf.test.is_gpu_available():
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
self._test_inference(x_shape, [2], use_gpu=True, data_format='NHWC')
|
||||
self._test_inference(x_shape, [2], use_gpu=False, data_format='NHWC')
|
||||
|
||||
x_shape = [1, 2, 1, 6]
|
||||
if tf.test.is_gpu_available():
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
self._test_inference(x_shape, [2], use_gpu=True, data_format='NCHW')
|
||||
|
||||
x_shape = [27, 131, 127, 6]
|
||||
if tf.test.is_gpu_available():
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
self._test_inference(x_shape, [131], use_gpu=True, data_format='NCHW')
|
||||
self._test_inference(x_shape, [6], use_gpu=True, data_format='NHWC')
|
||||
self._test_inference(x_shape, [6], use_gpu=False, data_format='NHWC')
|
||||
|
||||
def testTraining(self):
|
||||
x_shape = [1, 1, 6, 1]
|
||||
if tf.test.is_gpu_available():
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
self._test_training(x_shape, [1], use_gpu=True, data_format='NHWC')
|
||||
self._test_training(x_shape, [1], use_gpu=True, data_format='NCHW')
|
||||
self._test_training(x_shape, [1], use_gpu=False, data_format='NHWC')
|
||||
|
||||
x_shape = [1, 1, 6, 2]
|
||||
if tf.test.is_gpu_available():
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
self._test_training(x_shape, [2], use_gpu=True, data_format='NHWC')
|
||||
self._test_training(x_shape, [2], use_gpu=False, data_format='NHWC')
|
||||
|
||||
x_shape = [1, 2, 1, 6]
|
||||
if tf.test.is_gpu_available():
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
self._test_training(x_shape, [2], use_gpu=True, data_format='NCHW')
|
||||
|
||||
x_shape = [27, 131, 127, 6]
|
||||
if tf.test.is_gpu_available():
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
self._test_training(x_shape, [131], use_gpu=True, data_format='NCHW')
|
||||
self._test_training(x_shape, [6], use_gpu=True, data_format='NHWC')
|
||||
self._test_training(x_shape, [6], use_gpu=False, data_format='NHWC')
|
||||
|
||||
def testBatchNormGrad(self):
|
||||
x_shape = [1, 1, 6, 1]
|
||||
if tf.test.is_gpu_available():
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
self._test_gradient(x_shape, [1], use_gpu=True, data_format='NHWC')
|
||||
self._test_gradient(x_shape, [1], use_gpu=True, data_format='NCHW')
|
||||
self._test_gradient(x_shape, [1], use_gpu=False, data_format='NHWC')
|
||||
|
||||
x_shape = [1, 1, 6, 2]
|
||||
if tf.test.is_gpu_available():
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
self._test_gradient(x_shape, [2], use_gpu=True, data_format='NHWC')
|
||||
self._test_gradient(x_shape, [2], use_gpu=False, data_format='NHWC')
|
||||
|
||||
x_shape = [1, 2, 1, 6]
|
||||
if tf.test.is_gpu_available():
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
self._test_gradient(x_shape, [2], use_gpu=True, data_format='NCHW')
|
||||
|
||||
x_shape = [7, 9, 13, 6]
|
||||
if tf.test.is_gpu_available():
|
||||
if tf.test.is_gpu_available(cuda_only=True):
|
||||
self._test_gradient(x_shape, [9], use_gpu=True, data_format='NCHW')
|
||||
self._test_gradient(x_shape, [6], use_gpu=True, data_format='NHWC')
|
||||
self._test_gradient(x_shape, [6], use_gpu=False, data_format='NHWC')
|
||||
|
@ -526,9 +526,14 @@ class _VariableStore(object):
|
||||
|
||||
var_full_name = "%s/part_%d" % (name, i)
|
||||
with ops.name_scope(var_full_name + "/PartitionedInitializer"):
|
||||
# Create the tensor to initialize the variable with default value.
|
||||
if initializer is None:
|
||||
init = init_ops.uniform_unit_scaling_initializer()
|
||||
init_shape = var_shape
|
||||
init, initializing_from_value = self._get_default_initializer(
|
||||
name=name, shape=shape, dtype=dtype)
|
||||
if initializing_from_value:
|
||||
init_shape = None
|
||||
else:
|
||||
init_shape = var_shape
|
||||
elif callable(initializer):
|
||||
init = initializer
|
||||
init_shape = var_shape
|
||||
@ -653,9 +658,10 @@ class _VariableStore(object):
|
||||
raise ValueError("Shape of a new variable (%s) must be fully defined, "
|
||||
"but instead was %s." % (name, shape))
|
||||
|
||||
# Create the tensor to initialize the variable.
|
||||
# Create the tensor to initialize the variable with default value.
|
||||
if initializer is None:
|
||||
initializer = init_ops.uniform_unit_scaling_initializer()
|
||||
initializer, initializing_from_value = self._get_default_initializer(
|
||||
name=name, shape=shape, dtype=dtype)
|
||||
# Clear control dependencies while creating the initializer.
|
||||
with ops.control_dependencies(None):
|
||||
if initializing_from_value:
|
||||
@ -692,6 +698,39 @@ class _VariableStore(object):
|
||||
return v
|
||||
|
||||
|
||||
# Initialize variable when no initializer provided
|
||||
def _get_default_initializer(self, name, shape=None, dtype=dtypes.float32):
|
||||
"""Provide a default initializer and a corresponding value.
|
||||
|
||||
Args:
|
||||
name: see get_variable.
|
||||
shape: see get_variable.
|
||||
dtype: see get_variable.
|
||||
|
||||
Returns:
|
||||
initializer and initializing_from_value. See get_variable above.
|
||||
|
||||
Raises:
|
||||
ValueError: When giving unsupported dtype.
|
||||
"""
|
||||
# If dtype is DT_FLOAT, provide a uniform unit scaling initializer
|
||||
if dtype.is_floating:
|
||||
initializer = init_ops.uniform_unit_scaling_initializer()
|
||||
initializing_from_value = False
|
||||
# If dtype is DT_INT/DT_UINT, provide a default value `zero`
|
||||
# If dtype is DT_BOOL, provide a default value `FALSE`
|
||||
elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool:
|
||||
initializer = init_ops.zeros_initializer()(
|
||||
shape=shape, dtype=dtype.base_dtype)
|
||||
initializing_from_value = True
|
||||
# NOTES:Do we need to support for handling DT_STRING and DT_COMPLEX here?
|
||||
else:
|
||||
raise ValueError("An initializer for variable %s of %s is required"
|
||||
% (name, dtype.base_dtype))
|
||||
|
||||
return initializer, initializing_from_value
|
||||
|
||||
|
||||
# To stop regularization, use this regularizer
|
||||
def no_regularizer(_):
|
||||
"""Use this function to prevent regularization of variables."""
|
||||
|
@ -650,6 +650,46 @@ class Variable(object):
|
||||
"""
|
||||
return state_ops.count_up_to(self._variable, limit=limit)
|
||||
|
||||
def load(self, value, session=None):
|
||||
"""Load new value into this variable
|
||||
|
||||
Writes new value to variable's memory. Doesn't add ops to the graph.
|
||||
|
||||
This convenience method requires a session where the graph containing this
|
||||
variable has been launched. If no session is passed, the default session is
|
||||
used. See the [Session class](../../api_docs/python/client.md#Session) for
|
||||
more information on launching a graph and on sessions.
|
||||
|
||||
```python
|
||||
v = tf.Variable([1, 2])
|
||||
init = tf.global_variables_initializer()
|
||||
|
||||
with tf.Session() as sess:
|
||||
sess.run(init)
|
||||
# Usage passing the session explicitly.
|
||||
v.load([2, 3], sess)
|
||||
print(v.eval(sess)) # prints [2 3]
|
||||
# Usage with the default session. The 'with' block
|
||||
# above makes 'sess' the default session.
|
||||
v.load([3, 4], sess)
|
||||
print(v.eval()) # prints [3 4]
|
||||
```
|
||||
|
||||
Args:
|
||||
value: New variable value
|
||||
session: The session to use to evaluate this variable. If
|
||||
none, the default session is used.
|
||||
|
||||
Raises:
|
||||
ValueError: Session is not passed and no default session
|
||||
"""
|
||||
session = session or ops.get_default_session()
|
||||
if session is None:
|
||||
raise ValueError(
|
||||
"Either session argument should be provided or default session "
|
||||
"should be established")
|
||||
session.run(self._initializer_op, {self._initializer_op.inputs[1]: value})
|
||||
|
||||
# Conversion to tensor.
|
||||
@staticmethod
|
||||
def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False): # pylint: disable=invalid-name
|
||||
@ -1070,7 +1110,7 @@ def global_variables():
|
||||
return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||
|
||||
|
||||
@deprecated("2016-03-02", "Please use tf.global_variables instead.")
|
||||
@deprecated("2017-03-02", "Please use tf.global_variables instead.")
|
||||
def all_variables():
|
||||
"""See `tf.global_variables`."""
|
||||
return global_variables()
|
||||
|
@ -9,6 +9,8 @@ licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
|
||||
py_library(
|
||||
name = "constants",
|
||||
srcs = ["constants.py"],
|
||||
|
@ -273,7 +273,7 @@ def merge(inputs, collections=None, name=None):
|
||||
inputs: A list of `string` `Tensor` objects containing serialized `Summary`
|
||||
protocol buffers.
|
||||
collections: Optional list of graph collections keys. The new summary op is
|
||||
added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
|
||||
added to these collections. Defaults to `[]`.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
|
@ -7,6 +7,8 @@ licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
|
||||
py_library(
|
||||
name = "freeze_graph_lib",
|
||||
srcs = ["freeze_graph.py"],
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user