From b0bdff4827f867a67f572ed99d85f9a847788326 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 26 Aug 2016 12:50:48 -0800 Subject: [PATCH] Merge changes from github. Change: 131437429 --- README.md | 4 +- configure | 10 +- tensorflow/BUILD | 9 + tensorflow/cc/framework/cc_op_gen.cc | 2 +- tensorflow/contrib/cmake/CMakeLists.txt | 3 +- .../contrib/cmake/tf_core_kernels.cmake | 2 + tensorflow/contrib/ios_examples/README.md | 3 + .../benchmark.xcodeproj/project.pbxproj | 4 + .../camera_example.xcodeproj/project.pbxproj | 4 + .../project.pbxproj | 4 + tensorflow/contrib/layers/__init__.py | 1 + .../contrib/layers/python/layers/layers.py | 109 ++- .../layers/python/layers/layers_test.py | 104 +++ tensorflow/contrib/learn/BUILD | 3 + tensorflow/contrib/makefile/Makefile | 10 + tensorflow/contrib/makefile/README.md | 4 +- tensorflow/contrib/makefile/tf_op_files.txt | 1 + .../python/ops/confusion_matrix_ops.py | 2 +- .../core/common_runtime/gpu/gpu_device.cc | 3 +- tensorflow/core/debug/debug_graph_utils.cc | 2 +- tensorflow/core/framework/log_memory.cc | 2 +- tensorflow/core/framework/op.cc | 4 +- tensorflow/core/framework/shape_inference.h | 2 +- tensorflow/core/framework/tensor.cc | 2 +- tensorflow/core/kernels/conv_ops.cc | 4 + .../core/kernels/conv_ops_using_gemm.cc | 622 ++++++++++++++++++ tensorflow/core/kernels/reduce_join_op.cc | 2 +- .../core/kernels/reverse_sequence_op.cc | 86 ++- tensorflow/core/kernels/reverse_sequence_op.h | 15 +- .../kernels/reverse_sequence_op_gpu.cu.cc | 18 +- .../kernels/sparse_dense_binary_op_shared.cc | 2 +- tensorflow/core/ops/array_ops.cc | 3 +- tensorflow/core/ops/ops.pbtxt | 36 + tensorflow/core/util/example_proto_helper.cc | 12 +- tensorflow/core/util/sparse/sparse_tensor.h | 2 +- .../examples/image_retraining/retrain.py | 2 +- .../tutorials/mnist/mnist_with_summaries.py | 2 +- .../shard5/tf.train.Optimizer.md | 10 +- .../tf.contrib.metrics.confusion_matrix.md | 2 +- tensorflow/g3doc/api_docs/python/train.md | 10 +- tensorflow/g3doc/get_started/os_setup.md | 13 +- .../summaries_and_tensorboard/index.md | 2 +- tensorflow/python/BUILD | 28 + .../python/kernel_tests/clip_ops_test.py | 40 ++ .../kernel_tests/reverse_sequence_op_test.py | 7 +- .../python/kernel_tests/sparse_ops_test.py | 23 + tensorflow/python/ops/clip_ops.py | 4 +- tensorflow/python/ops/io_ops.py | 1 + tensorflow/python/ops/nn_ops.py | 4 +- tensorflow/python/ops/sparse_ops.py | 47 ++ tensorflow/python/platform/app.py | 4 +- tensorflow/python/platform/app_test.py | 45 ++ tensorflow/python/platform/flags.py | 3 +- tensorflow/python/training/device_setter.py | 2 +- .../python/training/device_setter_test.py | 17 + .../tf-dashboard-common/dashboard-style.html | 10 +- .../tf-event-dashboard.html | 20 + .../vz-line-chart/vz-line-chart.html | 16 +- .../components/vz-line-chart/vz-line-chart.ts | 18 +- tensorflow/tensorboard/tensorboard.py | 7 +- tensorflow/tools/ci_build/Dockerfile.gpu | 2 +- .../docker/parameterized_docker_build.sh | 16 +- tensorflow/workspace.bzl | 40 +- third_party/gpus/cuda/BUILD.tpl | 2 +- third_party/gpus/cuda_configure.bzl | 40 +- 65 files changed, 1374 insertions(+), 159 deletions(-) create mode 100644 tensorflow/core/kernels/conv_ops_using_gemm.cc create mode 100644 tensorflow/python/platform/app_test.py diff --git a/README.md b/README.md index 3de86a21654..e356ff931ce 100644 --- a/README.md +++ b/README.md @@ -34,9 +34,9 @@ and discussion.** People who are a little more adventurous can also try our nightly binaries: * Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/)) -* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/140/artifact/pip_test/whl/tensorflow-0.8.0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/)) +* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/140/artifact/pip_test/whl/tensorflow-0.8.0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/)) * Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac1-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac1-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac1-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac1-slave/)) -* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/)) +* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/)) * [Android](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-android/TF_BUILD_CONTAINER_TYPE=ANDROID,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=NO_PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=android-slave/lastSuccessfulBuild/artifact/bazel-out/local_linux/bin/tensorflow/examples/android/tensorflow_demo.apk) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-android/TF_BUILD_CONTAINER_TYPE=ANDROID,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=NO_PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=android-slave/)) #### *Try your first TensorFlow program* diff --git a/configure b/configure index bcef37bd26b..f4b772c55ef 100755 --- a/configure +++ b/configure @@ -98,7 +98,7 @@ while true; do fi fi if [ -e "$GCC_HOST_COMPILER_PATH" ]; then - export CC=$GCC_HOST_COMPILER_PATH + export GCC_HOST_COMPILER_PATH break fi echo "Invalid gcc path. ${GCC_HOST_COMPILER_PATH} cannot be found" 1>&2 @@ -142,7 +142,7 @@ while true; do if [ -e "${CUDA_TOOLKIT_PATH}/${CUDA_RT_LIB_PATH}" ]; then export CUDA_TOOLKIT_PATH - export CUDA_VERSION=$TF_CUDA_VERSION + export TF_CUDA_VERSION break fi echo "Invalid path to CUDA $TF_CUDA_VERSION toolkit. ${CUDA_TOOLKIT_PATH}/${CUDA_RT_LIB_PATH} cannot be found" @@ -203,7 +203,7 @@ while true; do fi if [ -e "$CUDNN_INSTALL_PATH/${CUDA_DNN_LIB_ALT_PATH}" -o -e "$CUDNN_INSTALL_PATH/${CUDA_DNN_LIB_PATH}" ]; then - export CUDNN_VERSION=$TF_CUDNN_VERSION + export TF_CUDNN_VERSION export CUDNN_INSTALL_PATH break fi @@ -211,7 +211,7 @@ while true; do if [ "$OSNAME" == "Linux" ]; then CUDNN_PATH_FROM_LDCONFIG="$(ldconfig -p | sed -n 's/.*libcudnn.so .* => \(.*\)/\1/p')" if [ -e "${CUDNN_PATH_FROM_LDCONFIG}${TF_CUDNN_EXT}" ]; then - export CUDNN_VERSION=$TF_CUDNN_VERSION + export TF_CUDNN_VERSION export CUDNN_INSTALL_PATH="$(dirname ${CUDNN_PATH_FROM_LDCONFIG})" break fi @@ -263,7 +263,7 @@ EOF exit 1 fi else - export CUDA_COMPUTE_CAPABILITIES=$TF_CUDA_COMPUTE_CAPABILITIES + export TF_CUDA_COMPUTE_CAPABILITIES break fi TF_CUDA_COMPUTE_CAPABILITIES="" diff --git a/tensorflow/BUILD b/tensorflow/BUILD index add5d97169a..0a73364471f 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -174,6 +174,15 @@ cc_binary( ], ) +cc_binary( + name = "libtensorflow_c.so", + linkshared = 1, + deps = [ + "//tensorflow/c:c_api", + "//tensorflow/core:tensorflow", + ], +) + cc_binary( name = "libtensorflow_cc.so", linkshared = 1, diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index 659fdfb153e..8ed53404f1e 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -232,7 +232,7 @@ string PrintAttrValue(string op, const AttrValue& attr_value) { string ToCamelCase(const string& str) { string result; const char joiner = '_'; - int i = 0; + size_t i = 0; bool cap = true; while (i < str.size()) { const char c = str[i++]; diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index e5a0790ff4a..4fe5960d3b7 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -51,7 +51,8 @@ include(highwayhash) # Let's get to work! include(tf_core_framework.cmake) -include(tf_stream_executor.cmake) +# NOTE: Disabled until issue #3996 is fixed. +# include(tf_stream_executor.cmake) include(tf_core_cpu.cmake) include(tf_models.cmake) include(tf_core_ops.cmake) diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 2fff5c2dd37..8f911e3cb10 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -13,6 +13,8 @@ file(GLOB_RECURSE tf_core_kernels_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/kernels/*testutil.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/*main.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/*.cu.cc" + "${tensorflow_source_dir}/tensorflow/core/kernels/debug_ops.h" + "${tensorflow_source_dir}/tensorflow/core/kernels/debug_ops.cc" ) list(REMOVE_ITEM tf_core_kernels_srcs ${tf_core_kernels_exclude_srcs}) diff --git a/tensorflow/contrib/ios_examples/README.md b/tensorflow/contrib/ios_examples/README.md index 1c29c74e51e..83b1d7868a0 100644 --- a/tensorflow/contrib/ios_examples/README.md +++ b/tensorflow/contrib/ios_examples/README.md @@ -71,6 +71,9 @@ rundown: inside the library are not stripped out. To the linker, they can appear unused because no other code references the variables, but in fact their constructors have the important side effect of registering the class. + + - You'll need to include the Accelerate framework in the "Link Binary with + Libraries" build phase of your project. - C++11 support (or later) should be enabled by setting `C++ Language Dialect` to `GNU++11` (or `GNU++14`), and `C++ Standard Library` to `libc++`. diff --git a/tensorflow/contrib/ios_examples/benchmark/benchmark.xcodeproj/project.pbxproj b/tensorflow/contrib/ios_examples/benchmark/benchmark.xcodeproj/project.pbxproj index 934bafdeb94..a7266987475 100644 --- a/tensorflow/contrib/ios_examples/benchmark/benchmark.xcodeproj/project.pbxproj +++ b/tensorflow/contrib/ios_examples/benchmark/benchmark.xcodeproj/project.pbxproj @@ -9,6 +9,7 @@ /* Begin PBXBuildFile section */ 590E7D881D02091F00DF5523 /* libprotobuf-lite.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 590E7D861D02091F00DF5523 /* libprotobuf-lite.a */; }; 590E7D8A1D0209DD00DF5523 /* libprotobuf.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 590E7D871D02091F00DF5523 /* libprotobuf.a */; }; + 5993C7701D5D4E7F0048CE6A /* Accelerate.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 5993C76F1D5D4E7F0048CE6A /* Accelerate.framework */; }; 59A3D0011CF4E68100C4259F /* AppDelegate.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */; }; 59A3D0031CF4E68100C4259F /* grace_hopper.jpg in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */; }; 59A3D0051CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */; }; @@ -25,6 +26,7 @@ 590E7D861D02091F00DF5523 /* libprotobuf-lite.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "libprotobuf-lite.a"; path = "../../makefile/gen/protobuf_ios/lib/libprotobuf-lite.a"; sourceTree = ""; }; 590E7D871D02091F00DF5523 /* libprotobuf.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = libprotobuf.a; path = ../../makefile/gen/protobuf_ios/lib/libprotobuf.a; sourceTree = ""; }; 5911579B1CF4011C00C31E3A /* benchmark.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = benchmark.app; sourceTree = BUILT_PRODUCTS_DIR; }; + 5993C76F1D5D4E7F0048CE6A /* Accelerate.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Accelerate.framework; path = System/Library/Frameworks/Accelerate.framework; sourceTree = SDKROOT; }; 59A3CFF11CF4E68100C4259F /* AppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = ""; }; 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = AppDelegate.mm; sourceTree = ""; }; 59A3CFF41CF4E68100C4259F /* cropped_panda.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = cropped_panda.jpg; sourceTree = ""; }; @@ -50,6 +52,7 @@ isa = PBXFrameworksBuildPhase; buildActionMask = 2147483647; files = ( + 5993C7701D5D4E7F0048CE6A /* Accelerate.framework in Frameworks */, 590E7D8A1D0209DD00DF5523 /* libprotobuf.a in Frameworks */, 590E7D881D02091F00DF5523 /* libprotobuf-lite.a in Frameworks */, 59A3D0181CF4E86100C4259F /* UIKit.framework in Frameworks */, @@ -63,6 +66,7 @@ 591157921CF4011C00C31E3A = { isa = PBXGroup; children = ( + 5993C76F1D5D4E7F0048CE6A /* Accelerate.framework */, 590E7D861D02091F00DF5523 /* libprotobuf-lite.a */, 590E7D871D02091F00DF5523 /* libprotobuf.a */, 59A3D0171CF4E86100C4259F /* UIKit.framework */, diff --git a/tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj/project.pbxproj b/tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj/project.pbxproj index 14785cc708e..451f536a7c4 100644 --- a/tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj/project.pbxproj +++ b/tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj/project.pbxproj @@ -24,6 +24,7 @@ 592FF90D18EDD0DA00C164F8 /* MainStoryboard_iPhone.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 592FF90A18EDD0DA00C164F8 /* MainStoryboard_iPhone.storyboard */; }; 592FF92518EE240200C164F8 /* CameraExampleAppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 592FF92218EE240200C164F8 /* CameraExampleAppDelegate.m */; }; 592FF92618EE240200C164F8 /* CameraExampleViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 592FF92418EE240200C164F8 /* CameraExampleViewController.mm */; }; + 5993C7721D5D4E980048CE6A /* Accelerate.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 5993C7711D5D4E980048CE6A /* Accelerate.framework */; }; /* End PBXBuildFile section */ /* Begin PBXFileReference section */ @@ -52,6 +53,7 @@ 592FF92218EE240200C164F8 /* CameraExampleAppDelegate.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = CameraExampleAppDelegate.m; sourceTree = SOURCE_ROOT; }; 592FF92318EE240200C164F8 /* CameraExampleViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = CameraExampleViewController.h; sourceTree = SOURCE_ROOT; }; 592FF92418EE240200C164F8 /* CameraExampleViewController.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = CameraExampleViewController.mm; sourceTree = SOURCE_ROOT; }; + 5993C7711D5D4E980048CE6A /* Accelerate.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Accelerate.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS9.3.sdk/System/Library/Frameworks/Accelerate.framework; sourceTree = DEVELOPER_DIR; }; /* End PBXFileReference section */ /* Begin PBXFrameworksBuildPhase section */ @@ -59,6 +61,7 @@ isa = PBXFrameworksBuildPhase; buildActionMask = 2147483647; files = ( + 5993C7721D5D4E980048CE6A /* Accelerate.framework in Frameworks */, 591D3EDF1CFFAD230059011C /* libprotobuf-lite.a in Frameworks */, 591D3EE01CFFAD230059011C /* libprotobuf.a in Frameworks */, 591D3ECF1CFF7FCE0059011C /* ImageIO.framework in Frameworks */, @@ -103,6 +106,7 @@ 592FF8B718ECBD7600C164F8 /* Frameworks */ = { isa = PBXGroup; children = ( + 5993C7711D5D4E980048CE6A /* Accelerate.framework */, 591D3EDD1CFFAD230059011C /* libprotobuf-lite.a */, 591D3EDE1CFFAD230059011C /* libprotobuf.a */, 591D3ECE1CFF7FCE0059011C /* ImageIO.framework */, diff --git a/tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj/project.pbxproj b/tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj/project.pbxproj index a81286a5b19..9e058be2814 100644 --- a/tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj/project.pbxproj +++ b/tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj/project.pbxproj @@ -9,6 +9,7 @@ /* Begin PBXBuildFile section */ 590E7D881D02091F00DF5523 /* libprotobuf-lite.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 590E7D861D02091F00DF5523 /* libprotobuf-lite.a */; }; 590E7D8A1D0209DD00DF5523 /* libprotobuf.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 590E7D871D02091F00DF5523 /* libprotobuf.a */; }; + 5993C7741D5D4EAF0048CE6A /* Accelerate.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 5993C7731D5D4EAF0048CE6A /* Accelerate.framework */; }; 59A3D0011CF4E68100C4259F /* AppDelegate.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */; }; 59A3D0031CF4E68100C4259F /* grace_hopper.jpg in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */; }; 59A3D0051CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */; }; @@ -25,6 +26,7 @@ 590E7D861D02091F00DF5523 /* libprotobuf-lite.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "libprotobuf-lite.a"; path = "../../makefile/gen/protobuf_ios/lib/libprotobuf-lite.a"; sourceTree = ""; }; 590E7D871D02091F00DF5523 /* libprotobuf.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = libprotobuf.a; path = ../../makefile/gen/protobuf_ios/lib/libprotobuf.a; sourceTree = ""; }; 5911579B1CF4011C00C31E3A /* tf_ios_makefile_example.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = tf_ios_makefile_example.app; sourceTree = BUILT_PRODUCTS_DIR; }; + 5993C7731D5D4EAF0048CE6A /* Accelerate.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Accelerate.framework; path = System/Library/Frameworks/Accelerate.framework; sourceTree = SDKROOT; }; 59A3CFF11CF4E68100C4259F /* AppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = ""; }; 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = AppDelegate.mm; sourceTree = ""; }; 59A3CFF41CF4E68100C4259F /* cropped_panda.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = cropped_panda.jpg; sourceTree = ""; }; @@ -50,6 +52,7 @@ isa = PBXFrameworksBuildPhase; buildActionMask = 2147483647; files = ( + 5993C7741D5D4EAF0048CE6A /* Accelerate.framework in Frameworks */, 590E7D8A1D0209DD00DF5523 /* libprotobuf.a in Frameworks */, 590E7D881D02091F00DF5523 /* libprotobuf-lite.a in Frameworks */, 59A3D0181CF4E86100C4259F /* UIKit.framework in Frameworks */, @@ -63,6 +66,7 @@ 591157921CF4011C00C31E3A = { isa = PBXGroup; children = ( + 5993C7731D5D4EAF0048CE6A /* Accelerate.framework */, 590E7D861D02091F00DF5523 /* libprotobuf-lite.a */, 590E7D871D02091F00DF5523 /* libprotobuf.a */, 59A3D0171CF4E86100C4259F /* UIKit.framework */, diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index 20a82a235e2..4bec7c2661a 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -27,6 +27,7 @@ common machine learning algorithms. @@convolution2d_transpose @@flatten @@fully_connected +@@layer_norm @@max_pool2d @@one_hot_encoding @@repeat diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 3290d925515..0947480d7c3 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -52,6 +52,7 @@ __all__ = ['avg_pool2d', 'dropout', 'flatten', 'fully_connected', + 'layer_norm', 'linear', 'max_pool2d', 'one_hot_encoding', @@ -276,7 +277,8 @@ def batch_norm(inputs, outputs.set_shape(inputs_shape) if activation_fn: outputs = activation_fn(outputs) - return utils.collect_named_outputs(outputs_collections, sc.name, outputs) + return utils.collect_named_outputs(outputs_collections, + sc.original_name_scope, outputs) @add_arg_scope @@ -328,7 +330,8 @@ def bias_add(inputs, outputs = nn.bias_add(inputs, biases) if activation_fn: outputs = activation_fn(outputs) - return utils.collect_named_outputs(outputs_collections, sc.name, outputs) + return utils.collect_named_outputs(outputs_collections, + sc.original_name_scope, outputs) @add_arg_scope @@ -441,7 +444,8 @@ def convolution2d(inputs, outputs = nn.bias_add(outputs, biases) if activation_fn: outputs = activation_fn(outputs) - return utils.collect_named_outputs(outputs_collections, sc.name, outputs) + return utils.collect_named_outputs(outputs_collections, + sc.original_name_scope, outputs) @add_arg_scope @@ -541,7 +545,8 @@ def convolution2d_in_plane( if activation_fn: outputs = activation_fn(outputs) - return utils.collect_named_outputs(outputs_collections, sc.name, outputs) + return utils.collect_named_outputs(outputs_collections, + sc.original_name_scope, outputs) @add_arg_scope @@ -668,7 +673,8 @@ def convolution2d_transpose( if activation_fn: outputs = activation_fn(outputs) - return utils.collect_named_outputs(outputs_collections, sc.name, outputs) + return utils.collect_named_outputs(outputs_collections, + sc.original_name_scope, outputs) @add_arg_scope @@ -845,7 +851,95 @@ def fully_connected(inputs, # Reshape back outputs outputs = array_ops.reshape(outputs, array_ops.pack(out_shape)) outputs.set_shape(static_shape) - return utils.collect_named_outputs(outputs_collections, sc.name, outputs) + return utils.collect_named_outputs(outputs_collections, + sc.original_name_scope, outputs) + + +@add_arg_scope +def layer_norm(inputs, + center=True, + scale=True, + activation_fn=None, + reuse=None, + variables_collections=None, + outputs_collections=None, + trainable=True, + scope=None): + """Adds a Layer Normalization layer from https://arxiv.org/abs/1607.06450. + + "Layer Normalization" + + Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton + + Can be used as a normalizer function for conv2d and fully_connected. + + Args: + inputs: a tensor with 2 or more dimensions. The normalization + occurs over all but the first dimension. + center: If True, subtract `beta`. If False, `beta` is ignored. + scale: If True, multiply by `gamma`. If False, `gamma` is + not used. When the next layer is linear (also e.g. `nn.relu`), this can be + disabled since the scaling can be done by the next layer. + activation_fn: Optional activation function. + reuse: whether or not the layer and its variables should be reused. To be + able to reuse the layer scope must be given. + variables_collections: optional collections for the variables. + outputs_collections: collections to add the outputs. + trainable: If `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). + scope: Optional scope for `variable_op_scope`. + + Returns: + A `Tensor` representing the output of the operation. + + Raises: + ValueError: if rank or last dimension of `inputs` is undefined. + """ + with variable_scope.variable_scope(scope, 'LayerNorm', [inputs], + reuse=reuse) as sc: + inputs = ops.convert_to_tensor(inputs) + inputs_shape = inputs.get_shape() + inputs_rank = inputs_shape.ndims + if inputs_rank is None: + raise ValueError('Inputs %s has undefined rank.' % inputs.name) + dtype = inputs.dtype.base_dtype + axis = list(range(1, inputs_rank)) + params_shape = inputs_shape[-1:] + if not params_shape.is_fully_defined(): + raise ValueError('Inputs %s has undefined last dimension %s.' % ( + inputs.name, params_shape)) + # Allocate parameters for the beta and gamma of the normalization. + beta, gamma = None, None + if center: + beta_collections = utils.get_variable_collections(variables_collections, + 'beta') + beta = variables.model_variable('beta', + shape=params_shape, + dtype=dtype, + initializer=init_ops.zeros_initializer, + collections=beta_collections, + trainable=trainable) + if scale: + gamma_collections = utils.get_variable_collections(variables_collections, + 'gamma') + gamma = variables.model_variable('gamma', + shape=params_shape, + dtype=dtype, + initializer=init_ops.ones_initializer, + collections=gamma_collections, + trainable=trainable) + # Calculate the moments on the last axis (layer activations). + mean, variance = nn.moments(inputs, axis, keep_dims=True) + # Compute layer normalization using the batch_normalization function. + variance_epsilon = 1E-12 + outputs = nn.batch_normalization( + inputs, mean, variance, beta, gamma, variance_epsilon) + outputs.set_shape(inputs_shape) + if activation_fn: + outputs = activation_fn(outputs) + return utils.collect_named_outputs(outputs_collections, + sc.original_name_scope, + outputs) @add_arg_scope @@ -1094,7 +1188,8 @@ def separable_convolution2d( if activation_fn: outputs = activation_fn(outputs) - return utils.collect_named_outputs(outputs_collections, sc.name, outputs) + return utils.collect_named_outputs(outputs_collections, + sc.original_name_scope, outputs) @add_arg_scope diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 711d17cb291..4fc2e07b9bc 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -203,6 +203,16 @@ class Convolution2dTest(tf.test.TestCase): scope='conv1') self.assertEquals(output.op.name, 'conv1/Relu') + def testCreateConvWithCollection(self): + height, width = 3, 3 + images = tf.random_uniform((5, height, width, 3), seed=1) + with tf.name_scope('fe'): + conv = tf.contrib.layers.convolution2d( + images, 32, [3, 3], outputs_collections='outputs', + scope='Conv') + namedOutputs = tf.get_collection('outputs')[0] + self.assertEquals(namedOutputs.name, 'fe/Conv') + def testCreateConvWithoutActivation(self): height, width = 3, 3 with self.test_session(): @@ -989,6 +999,16 @@ class FCTest(tf.test.TestCase): output = tf.contrib.layers.fully_connected(inputs, 32, scope='fc1') self.assertEquals(output.op.name, 'fc1/Relu') + def testCreateFCWithCollection(self): + height, width = 3, 3 + inputs = tf.random_uniform((5, height * width * 3), seed=1) + with tf.name_scope('fe'): + fc = tf.contrib.layers.fully_connected( + inputs, 7, outputs_collections='outputs', + scope='fc') + namedOutputs = tf.get_collection('outputs')[0] + self.assertEquals(namedOutputs.name, 'fe/fc') + def testCreateFcCreatesWeightsAndBiasesVars(self): height, width = 3, 3 inputs = tf.random_uniform((5, height * width * 3), seed=1) @@ -1542,6 +1562,90 @@ class BatchNormTest(tf.test.TestCase): self.assertAllClose(moving_variance.eval(), expected_var) +class LayerNormTest(tf.test.TestCase): + + def testUnknownShape(self): + with tf.Graph().as_default() as g, self.test_session(g): + inputs = tf.placeholder(dtype=tf.float32) + with self.assertRaisesRegexp(ValueError, 'undefined rank'): + tf.contrib.layers.layer_norm(inputs) + + def testUnknownLastDim(self): + with tf.Graph().as_default() as g, self.test_session(g): + inputs = tf.placeholder(dtype=tf.float32) + inputs.set_shape(tf.TensorShape((5, 3, 3, None))) + with self.assertRaisesRegexp(ValueError, 'undefined last dimension'): + tf.contrib.layers.layer_norm(inputs) + + def testCreateOp(self): + height, width = 3, 3 + with self.test_session(): + images = np.random.uniform(size=(5, height, width, 3)) + output = tf.contrib.layers.layer_norm(images) + self.assertTrue(output.op.name.startswith('LayerNorm/batchnorm')) + self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3]) + + def testCreateVariables(self): + height, width = 3, 3 + with self.test_session(): + images = tf.random_uniform((5, height, width, 3), seed=1) + tf.contrib.layers.layer_norm(images) + beta = tf.contrib.framework.get_variables_by_name('beta')[0] + gamma = tf.contrib.framework.get_variables_by_name('gamma')[0] + self.assertEquals(beta.op.name, 'LayerNorm/beta') + self.assertEquals(gamma.op.name, 'LayerNorm/gamma') + + def testReuseVariables(self): + height, width = 3, 3 + with self.test_session(): + images = tf.random_uniform((5, height, width, 3), seed=1) + tf.contrib.layers.layer_norm(images, scope='ln') + tf.contrib.layers.layer_norm(images, scope='ln', reuse=True) + beta = tf.contrib.framework.get_variables_by_name('beta') + gamma = tf.contrib.framework.get_variables_by_name('gamma') + self.assertEquals(len(beta), 1) + self.assertEquals(len(gamma), 1) + + def testReuseVars(self): + height, width = 3, 3 + with self.test_session() as sess: + image_shape = (10, height, width, 3) + image_values = np.random.rand(*image_shape) + images = tf.constant(image_values, shape=image_shape, dtype=tf.float32) + output_train = tf.contrib.layers.layer_norm(images, scope='LN') + output_eval = tf.contrib.layers.layer_norm(images, + scope='LN', + reuse=True) + # Initialize all variables + sess.run(tf.initialize_all_variables()) + # output_train and output_eval should be the same. + self.assertAllClose(sess.run([output_train]), sess.run([output_eval])) + + def doOutputTest(self, input_shape): + with self.test_session() as sess: + input_values = np.random.rand(*input_shape) + inputs = tf.constant(input_values, shape=input_shape, dtype=tf.float32) + output_op = tf.contrib.layers.layer_norm(inputs, scope='LN') + # Initialize all variables + sess.run(tf.initialize_all_variables()) + # The mean and variance of the output should be close to 0 and 1 + # respectively. + moments_axis = tuple([i for i in range(1, len(input_shape))]) + outputs = sess.run(output_op) + expected_mean = np.zeros(input_shape[0]) + expected_var = np.ones(input_shape[0]) + mean = np.mean(outputs, axis=moments_axis) + var = np.var(outputs, axis=moments_axis) + tol = 1e-5 + self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol) + self.assertAllClose(var, expected_var, rtol=tol, atol=tol) + + def testOutput2DInput(self): + self.doOutputTest((10, 300)) + + def testOutput4DInput(self): + self.doOutputTest((100, 10, 10, 3)) + class MaxPool2DTest(tf.test.TestCase): def testCreateMaxPool(self): diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 8278ab89a4b..ad1706a6f9d 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -684,6 +684,9 @@ py_test( size = "small", srcs = ["python/learn/utils/export_test.py"], srcs_version = "PY2AND3", + tags = [ + "manual", # http://b/31032996 + ], deps = [ ":learn", "//tensorflow:tensorflow_py", diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index e68e8dc1bbb..69cb84c2bb2 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -285,6 +285,7 @@ ifeq ($(TARGET),IOS) CXXFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \ -arch armv7 \ -D__thread= \ + -DUSE_GEMM_FOR_CONV \ -Wno-c++11-narrowing \ -mno-thumb \ -DTF_LEAN_BINARY \ @@ -295,6 +296,7 @@ ifeq ($(TARGET),IOS) ${IPHONEOS_SYSROOT} LDFLAGS := -arch armv7 \ -miphoneos-version-min=${MIN_SDK_VERSION} \ + -framework Accelerate \ -Xlinker -S \ -Xlinker -x \ -Xlinker -dead_strip \ @@ -306,6 +308,7 @@ ifeq ($(TARGET),IOS) CXXFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \ -arch armv7s \ -D__thread= \ + -DUSE_GEMM_FOR_CONV \ -Wno-c++11-narrowing \ -mno-thumb \ -DTF_LEAN_BINARY \ @@ -316,6 +319,7 @@ ifeq ($(TARGET),IOS) ${IPHONEOS_SYSROOT} LDFLAGS := -arch armv7s \ -miphoneos-version-min=${MIN_SDK_VERSION} \ + -framework Accelerate \ -Xlinker -S \ -Xlinker -x \ -Xlinker -dead_strip \ @@ -327,6 +331,7 @@ ifeq ($(TARGET),IOS) CXXFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \ -arch arm64 \ -D__thread= \ + -DUSE_GEMM_FOR_CONV \ -Wno-c++11-narrowing \ -DTF_LEAN_BINARY \ -D__ANDROID_TYPES_SLIM__ \ @@ -336,6 +341,7 @@ ifeq ($(TARGET),IOS) ${IPHONEOS_SYSROOT} LDFLAGS := -arch arm64 \ -miphoneos-version-min=${MIN_SDK_VERSION} \ + -framework Accelerate \ -Xlinker -S \ -Xlinker -x \ -Xlinker -dead_strip \ @@ -347,6 +353,7 @@ ifeq ($(TARGET),IOS) CXXFLAGS += -mios-simulator-version-min=$(MIN_SDK_VERSION) \ -arch i386 \ -D__thread= \ + -DUSE_GEMM_FOR_CONV \ -Wno-c++11-narrowing \ -DTF_LEAN_BINARY \ -D__ANDROID_TYPES_SLIM__ \ @@ -356,6 +363,7 @@ ifeq ($(TARGET),IOS) ${IPHONESIMULATOR_SYSROOT} LDFLAGS := -arch i386 \ -mios-simulator-version-min=${MIN_SDK_VERSION} \ + -framework Accelerate \ -Xlinker -S \ -Xlinker -x \ -Xlinker -dead_strip \ @@ -367,6 +375,7 @@ ifeq ($(TARGET),IOS) CXXFLAGS += -mios-simulator-version-min=$(MIN_SDK_VERSION) \ -arch x86_64 \ -D__thread= \ + -DUSE_GEMM_FOR_CONV \ -Wno-c++11-narrowing \ -DTF_LEAN_BINARY \ -D__ANDROID_TYPES_SLIM__ \ @@ -376,6 +385,7 @@ ifeq ($(TARGET),IOS) ${IPHONESIMULATOR_SYSROOT} LDFLAGS := -arch x86_64 \ -mios-simulator-version-min=${MIN_SDK_VERSION} \ + -framework Accelerate \ -Xlinker -S \ -Xlinker -x \ -Xlinker -dead_strip \ diff --git a/tensorflow/contrib/makefile/README.md b/tensorflow/contrib/makefile/README.md index acfcf8f220f..c0a68470ba9 100644 --- a/tensorflow/contrib/makefile/README.md +++ b/tensorflow/contrib/makefile/README.md @@ -260,15 +260,17 @@ For other variations of valid optimization flags, see [clang optimization levels ## Raspberry Pi Building on the Raspberry Pi is similar to a normal Linux system. First -download the dependencies and build protobuf: +download the dependencies, install the required packages and build protobuf: ```bash tensorflow/contrib/makefile/download_dependencies.sh +sudo apt-get install autoconf automake libtool cd tensorflow/contrib/makefile/downloads/protobuf/ ./autogen.sh ./configure make sudo make install +sudo ldconfig # refresh shared library cache cd ../../../../.. ``` diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index 5c0128db787..4f5c3176dee 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -99,6 +99,7 @@ tensorflow/core/kernels/cwise_op_equal_to.cc tensorflow/core/kernels/cwise_op_div.cc tensorflow/core/kernels/cwise_op_add.cc tensorflow/core/kernels/ctc_decoder_ops.cc +tensorflow/core/kernels/conv_ops_using_gemm.cc tensorflow/core/kernels/conv_ops.cc tensorflow/core/kernels/conv_grad_ops.cc tensorflow/core/kernels/control_flow_ops.cc diff --git a/tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py b/tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py index cf000f06d3a..5620001f1a7 100644 --- a/tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py +++ b/tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py @@ -59,7 +59,7 @@ def confusion_matrix(predictions, labels, num_classes=None, name: Scope name. Returns: - A l X l matrix represeting the confusion matrix, where l in the number of + A k X k matrix represeting the confusion matrix, where k is the number of possible labels in the classification task. Raises: diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index 38a759dbb74..033d67772ce 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -787,7 +787,8 @@ struct CudaVersion { }; std::vector supported_cuda_compute_capabilities = { - TF_CUDA_CAPABILITIES,}; + TF_CUDA_CAPABILITIES, +}; std::vector GetSupportedCudaComputeCapabilities() { auto cuda_caps = supported_cuda_compute_capabilities; diff --git a/tensorflow/core/debug/debug_graph_utils.cc b/tensorflow/core/debug/debug_graph_utils.cc index 118847686d3..c2e50f4d1b1 100644 --- a/tensorflow/core/debug/debug_graph_utils.cc +++ b/tensorflow/core/debug/debug_graph_utils.cc @@ -154,7 +154,7 @@ Status DebugNodeInserter::InsertNodes( // Create all requested debug nodes and their edges to the Copy node. std::vector node_added_debug_nodes; - for (int i = 0; i < tensor_watches[tensor_name].size(); ++i) { + for (size_t i = 0; i < tensor_watches[tensor_name].size(); ++i) { const string& debug_op_name = tensor_watches[tensor_name][i]; Node* debug_node; diff --git a/tensorflow/core/framework/log_memory.cc b/tensorflow/core/framework/log_memory.cc index ba221179cea..dc82504d43d 100644 --- a/tensorflow/core/framework/log_memory.cc +++ b/tensorflow/core/framework/log_memory.cc @@ -30,7 +30,7 @@ namespace { template void OutputToLog(const T& proto) { string type_name = proto.GetTypeName(); - const int index = type_name.find_last_of("."); + const size_t index = type_name.find_last_of("."); if (index != string::npos) type_name = type_name.substr(index + 1); LOG(INFO) << LogMemory::kLogMemoryLabel << " " << type_name << " { " << ProtoShortDebugString(proto) << " }"; diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc index 65bfd47433f..6bff192b1ec 100644 --- a/tensorflow/core/framework/op.cc +++ b/tensorflow/core/framework/op.cc @@ -156,7 +156,7 @@ string OpRegistry::DebugString(bool include_internal) const { bool OpRegistry::MustCallDeferred() const { if (initialized_) return false; initialized_ = true; - for (int i = 0; i < deferred_.size(); ++i) { + for (size_t i = 0; i < deferred_.size(); ++i) { TF_QCHECK_OK(RegisterAlreadyLocked(deferred_[i])); } deferred_.clear(); @@ -166,7 +166,7 @@ bool OpRegistry::MustCallDeferred() const { Status OpRegistry::CallDeferred() const { if (initialized_) return Status::OK(); initialized_ = true; - for (int i = 0; i < deferred_.size(); ++i) { + for (size_t i = 0; i < deferred_.size(); ++i) { Status s = RegisterAlreadyLocked(deferred_[i]); if (!s.ok()) { return s; diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index f65dac16860..74cfbba72ea 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -57,7 +57,7 @@ class DimensionHandle { const Dimension* ptr_ = nullptr; - friend class DimensionOrConstant; + friend struct DimensionOrConstant; friend class InferenceContext; friend class ShapeInferenceTest; friend class ShapeInferenceTestutil; diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index 6361791cb00..c6c5aab4764 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -740,7 +740,7 @@ string Tensor::SummarizeValue(int64 max_entries) const { string ret; // TODO(irving): Don't call flat every time around this // loop. - for (int64 i = 0; i < limit; ++i) { + for (size_t i = 0; i < limit; ++i) { if (i > 0) strings::StrAppend(&ret, " "); switch (dtype()) { case DT_STRING: diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index e2182d0ec05..86d498fe4a5 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -242,8 +242,12 @@ class Conv2DOp : public BinaryOp { Name("Conv2D").Device(DEVICE_CPU).TypeConstraint("T"), \ Conv2DOp); +// If we're using the alternative GEMM-based implementation of Conv2D for the +// CPU implementation, don't register this EigenTensor-based version. +#if !defined(USE_GEMM_FOR_CONV) TF_CALL_half(REGISTER_CPU); TF_CALL_float(REGISTER_CPU); +#endif // USE_GEMM_FOR_CONV // To be used inside depthwise_conv_op.cc. template class LaunchConv2DOp; diff --git a/tensorflow/core/kernels/conv_ops_using_gemm.cc b/tensorflow/core/kernels/conv_ops_using_gemm.cc new file mode 100644 index 00000000000..c39510a11a2 --- /dev/null +++ b/tensorflow/core/kernels/conv_ops_using_gemm.cc @@ -0,0 +1,622 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file contains a set of different implementations of the two-dimensional +// convolution operation. The standard TensorFlow Conv2d kernel uses EigenTensor +// to implement the computation, but this module has a variety of different ways +// of producing the same result. These methods are designed to be easier to +// understand and connect to other libraries, so that we can take advantage of +// platforms that have specialized implementations of GEMM for example. +// +// The basic interface is a Conv functor object that's templated by the types +// of the data it will be operating on, and is passed in the arguments needed to +// calculate the convolution. The simplest implementation of this functor is +// ReferenceConvFunctor, which is a readable but slow reference version. +// +// A faster version uses the approach of packing image patches into a matrix +// before calling a matrix multiply, the Im2ColConvFunctor. In turn, this can +// use a variety of different methods to calculate the matrix multiplication, +// or GEMM. The simplest but slowest is the ReferenceGemmFunctor, but the +// FastGemmFunctor will use whatever optimized libraries are available. By +// default it uses Eigen, but on Apple platforms it will take advantage of the +// system's Accelerate BLAS library to get better performance than the standard +// TensorFlow convolution kernel. +// +// The version actually used is defined at the bottom of this file using the +// REGISTER_KERNEL_BUILDER() macro. To try out different implementations (for +// example to switch to a reference one for easier debugging) you can swap out +// the default functors in that call. +// +// The registration itself is guarded with the USE_GEMM_FOR_CONV macro. The iOS +// makefile build defines this, but if you want to enable this implementation +// and disable the standard EigenTensor one in other build setups, you'll need +// to define it there too. + +#include +#include +#include +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +#if defined(__APPLE__) +#include +#define USE_ACCELERATE_GEMM +#endif // __APPLE__ + +namespace tensorflow { + +namespace { +// This function implements the convolution operation in as simple a form as +// possible. It won't give great performance, but it is very useful for +// stepping through and instrumenting for debugging, creating minimal benchmarks +// to prototype with, and sharing with teams that want to run this outside of +// our environment. +// With that in mind, I've avoided using anything except pretty standard C++ +// types. This is especially noticeable in the data access through raw array +// indexing. It's deliberate in this case though, since it makes the underlying +// memory order very explicit, which is important for both inspecting memory +// contents during debugging and for specifying what we expect to others. +// The memory layout of the data is, from biggest stride to smallest: +// input_data = [input_batches, input_height, input_width, input_depth] +// filter_data = [filter_height, filter_width, input_depth, filter_count] +// output_data = [input_batches, output_height, output_width, filter_count] +template +class ReferenceConvFunctor { + public: + void operator()(OpKernelContext* context, const T1* input_data, + int input_batches, int input_height, int input_width, + int input_depth, const T2* filter_data, int filter_height, + int filter_width, int filter_count, int stride_rows, + int stride_cols, Padding padding, T3* output_data, + int output_height, int output_width) { + // The two different padding modes we support can be a bit confusing. SAME + // means we're trying to produce an output image that's the same size as the + // input. It's complicated by stride, which shrinks the output image by a + // a factor, but it means we end up sampling from outside the borders of the + // input. These out-of-bounds values are read as zeroes. VALID means only + // produce output values where the filters can read all their values from + // within the input image. It effectively removes the margins of the output + // image compared to the one produced by SAME. Stride complicates this + // definition though, because it can result in the right and bottom filter + // patches sampling from outside the borders if it's greater than 1. + // Most of the logic for sorting this all out is done before this function, + // when we calculate the output size, but the positioning of the origin of + // the filters is different between the two modes, since SAME positions the + // first filter off the edge of the input. + int filter_left_offset; + int filter_top_offset; + if (padding == VALID) { + filter_left_offset = + ((output_width - 1) * stride_cols + filter_width - input_width + 1) / + 2; + filter_top_offset = ((output_height - 1) * stride_rows + filter_height - + input_height + 1) / + 2; + } else { + filter_left_offset = + ((output_width - 1) * stride_cols + filter_width - input_width) / 2; + filter_top_offset = + ((output_height - 1) * stride_rows + filter_height - input_height) / + 2; + } + + // If we've got multiple images in our input, work through each of them. + for (int batch = 0; batch < input_batches; ++batch) { + // Walk through all the output image values, sliding the filter to + // different positions in the input. + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + // Each filter kernel produces one output channel. + for (int out_channel = 0; out_channel < filter_count; ++out_channel) { + // We're going to calculate a single output value, which means we + // need to multiply a three dimensional kernel of weights against + // the current location within the input image. + /* + *-------------------------------... + |\ ^ + | \in_depth + | \ v + | *-------------------------------... + | | ^ + | | in_y_origin + | | v \ + | |*---*^ + | | \| |filter_height + . | *---*v + . | <---> + . filter_width + . + */ + const int in_x_origin = (out_x * stride_cols) - filter_left_offset; + const int in_y_origin = (out_y * stride_rows) - filter_top_offset; + T3 total(0); + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + for (int in_channel = 0; in_channel < input_depth; + ++in_channel) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + T1 input_value; + // If the location is outside the bounds of the input image, + // use zero as a default value. + if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && + (in_y < input_height)) { + input_value = + input_data[(batch * input_height * input_width * + input_depth) + + (in_y * input_width * input_depth) + + (in_x * input_depth) + in_channel]; + } else { + input_value = T1(0); + } + const T2 filter_value = + filter_data[(filter_y * filter_width * input_depth * + filter_count) + + (filter_x * input_depth * filter_count) + + (in_channel * filter_count) + out_channel]; + total += (input_value * filter_value); + } + } + } + output_data[(batch * output_height * output_width * filter_count) + + (out_y * output_width * filter_count) + + (out_x * filter_count) + out_channel] = total; + } + } + } + } + } +}; + +// A readable but slow implementation of matrix multiplication, useful for +// debugging and understanding the algorithm. Use instead of FastGemmFunctor in +// the Im2ColConvFunctor template definition inside the op registration to +// enable. Assumes row-major ordering of the values in memory. +template +class ReferenceGemmFunctor { + public: + void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda, + const T2* b, size_t ldb, T3* c, size_t ldc) { + const size_t a_i_stride = lda; + const size_t a_l_stride = 1; + const size_t b_j_stride = 1; + const size_t b_l_stride = ldb; + const size_t c_i_stride = ldc; + const size_t c_j_stride = 1; + size_t i, j, l; + for (j = 0; j < n; j++) { + for (i = 0; i < m; i++) { + T3 total(0); + for (l = 0; l < k; l++) { + const size_t a_index = ((i * a_i_stride) + (l * a_l_stride)); + const T1 a_value = a[a_index]; + const size_t b_index = ((j * b_j_stride) + (l * b_l_stride)); + const T2 b_value = b[b_index]; + total += (a_value * b_value); + } + const size_t c_index = ((i * c_i_stride) + (j * c_j_stride)); + c[c_index] = total; + } + } + } +}; + +// Uses the optimized Eigen library to implement the matrix multiplication +// required by the Im2ColConvFunctor class. We supply the two input and one +// output types so that the accumulator can potentially be higher-precision than +// the inputs, even though we don't currently take advantage of this. +template +class FastGemmFunctor { + public: + // Convenience wrappers for the Eigen matrix types we'll be using. + typedef Eigen::Map< + const Eigen::Matrix> + ConstMatrixT1; + typedef Eigen::Map< + const Eigen::Matrix> + ConstMatrixT2; + typedef Eigen::Map< + Eigen::Matrix> + MatrixT3; + void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda, + const T2* b, size_t ldb, T3* c, size_t ldc) { + ConstMatrixT1 a_matrix(a, m, k); + ConstMatrixT2 b_matrix(b, k, n); + MatrixT3 c_matrix(c, m, n); + c_matrix.noalias() = a_matrix * b_matrix; + } +}; + +// If we have Apple's Accelerate framework, use their implementation of GEMM to +// get a performance boost for float. +#if defined(USE_ACCELERATE_GEMM) +template <> +class FastGemmFunctor { + public: + void operator()(size_t m, size_t n, size_t k, const float* a, size_t lda, + const float* b, size_t ldb, float* c, size_t ldc) { + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1.0f, a, + lda, b, ldb, 0.0f, c, ldc); + } +}; +#endif // USE_ACCELERATE_GEMM + +// Used to keep track of persistent memory buffers used within the op. +template +struct Im2ColBufferResource : public ResourceBase { + mutex mu; + T data[size]; + string DebugString() { return "Im2ColBufferResource"; } +}; + +// Implements convolution as a two stage process, first packing the patches of +// the input image into columns (im2col) and then running GEMM to produce the +// final result. +template +class Im2ColConvFunctor { + public: + void operator()(OpKernelContext* context, const T1* input_data, + int input_batches, int input_height, int input_width, + int input_depth, const T2* filter_data, int filter_height, + int filter_width, int filter_count, int stride_rows, + int stride_cols, Padding padding, T3* output_data, + int output_height, int output_width) { + if ((input_batches <= 0) || (input_width <= 0) || (input_height <= 0) || + (input_depth <= 0)) { + LOG(WARNING) << "Conv2D was called with bad input dimensions: " + << input_batches << ", " << input_height << ", " + << input_width << ", " << input_depth; + return; + } + if ((filter_width <= 0) || (filter_height <= 0) || (filter_count <= 0)) { + LOG(WARNING) << "Conv2D was called with bad filter dimensions: " + << filter_width << ", " << filter_height << ", " + << filter_count; + return; + } + if ((output_width <= 0) || (output_height <= 0)) { + LOG(WARNING) << "Conv2D was called with bad output width or height: " + << output_width << ", " << output_height; + return; + } + + // These calculations define how the patches will be positioned within the + // input image. The actual definitions are quite complex, and rely on the + // previously-calculated output size. + int filter_left_offset; + int filter_top_offset; + if (padding == VALID) { + filter_left_offset = + ((output_width - 1) * stride_cols + filter_width - input_width + 1) / + 2; + filter_top_offset = ((output_height - 1) * stride_rows + filter_height - + input_height + 1) / + 2; + } else { + filter_left_offset = + ((output_width - 1) * stride_cols + filter_width - input_width) / 2; + filter_top_offset = + ((output_height - 1) * stride_rows + filter_height - input_height) / + 2; + } + + // The im2col buffer has # of patches rows, and # of filters cols. + // It's laid out like this, in row major order in memory: + // < filter value count > + // ^ +---------------------+ + // patch | | + // count | | + // v +---------------------+ + // Each patch row contains a filter_width x filter_height patch of the + // input, with the depth channel as the most contiguous in memory, followed + // by the width, then the height. This is the standard memory order in the + // image world if it helps to visualize it. + const int filter_value_count = filter_width * filter_height * input_depth; + + // We don't want to allocate a buffer to hold all the patches if the size is + // going to be extremely large, so break it into chunks if it's bigger than + // a limit. Each chunk will be processed serially, so we can refill the + // buffer for the next chunk and reuse it, keeping maximum memory size down. + // In this case, we've picked 16 megabytes as a reasonable limit. + const size_t max_chunk_size = (16 * 1024 * 1024); + OP_REQUIRES(context, (filter_value_count * sizeof(T1)) <= max_chunk_size, + errors::InvalidArgument("Im2Col patch too large for buffer")); + const size_t patches_per_chunk = + max_chunk_size / (filter_value_count * sizeof(T1)); + + // Because memory allocation is very expensive on mobile platforms, try to + // allocate a persistent buffer that will be kept around between calls. We + // use TensorFlow's resource management to ensure that the memory will be + // released when the session is over. + Im2ColBufferResource* im2col_buffer_resource; + std::function**)> creator = + [](Im2ColBufferResource** resource) { + *resource = new Im2ColBufferResource(); + return Status::OK(); + }; + OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate( + "Conv2d", "im2col_buffer", + &im2col_buffer_resource, creator)); + // This means that multiple ops can't be run simultaneously on different + // threads, because we have a single shared resource. The platforms this is + // aimed at have intra-op parallelism as their focus though, so it shouldn't + // be an issue. + mutex_lock lock_buffer(im2col_buffer_resource->mu); + core::ScopedUnref unref_buffer(im2col_buffer_resource); + T1* im2col_buffer = im2col_buffer_resource->data; + + for (int batch = 0; batch < input_batches; ++batch) { + const T1* input_batch_start = + input_data + (batch * input_height * input_width * input_depth); + for (int out_y = 0; out_y < output_height; ++out_y) { + const int in_y_origin = (out_y * stride_rows) - filter_top_offset; + for (int out_x = 0; out_x < output_width; ++out_x) { + const int in_x_origin = (out_x * stride_cols) - filter_left_offset; + const int patch_index = (batch * output_width * output_height) + + (out_y * output_width) + out_x; + const int patch_index_within_chunk = patch_index % patches_per_chunk; + T1* im2col_patch_start = + im2col_buffer + (patch_index_within_chunk * filter_value_count); + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + const int in_y = in_y_origin + filter_y; + T1* im2col_row_start = + im2col_patch_start + (filter_y * filter_width * input_depth); + // If we're off the top or the bottom of the input, fill the whole + // row with zeroes. + if ((in_y < 0) || (in_y >= input_height)) { + T1* im2col_row_end = + im2col_row_start + (filter_width * input_depth); + std::fill(im2col_row_start, im2col_row_end, T1(0)); + } else { + // What we're doing here is trying to copy and fill the im2col + // buffer as efficiently as possible, using functions to set or + // duplicate values en masse. We know we don't have to worry about + // vertical edges because we dealt with that case above, so we + // just need to handle filters that overlap the left or right + // edges. Here's what that looks like: + // + // < left_zero_count > < center_copy_count > < right_zero_count > + // +------------------+---------------------+--------------------+ + // | (filter) | (image) | (filter) | + // +------------------+---------------------+--------------------+ + // in_x_origin 0 input_width in_x_end + // + // In reality it's unlikely that a filter patch will be wider + // than an input, but this shows all the edge cases. + // We use std::fill() to set the left and right sections to zeroes + // and std::copy() to copy over the input data for the center. + const int in_x_end = in_x_origin + filter_width; + const int left_zero_count = std::max(0, 0 - in_x_origin); + const int right_zero_count = std::max(0, in_x_end - input_width); + const int center_copy_count = + filter_width - (left_zero_count + right_zero_count); + if (left_zero_count > 0) { + T1* im2col_left_start = im2col_row_start; + T1* im2col_left_end = + im2col_left_start + (left_zero_count * input_depth); + std::fill(im2col_left_start, im2col_left_end, T1(0)); + } + if (center_copy_count > 0) { + const T1* input_row_start = + input_batch_start + (in_y * input_width * input_depth) + + (std::max(0, in_x_origin) * input_depth); + const T1* input_row_end = + input_row_start + (center_copy_count * input_depth); + T1* im2col_center_start = + im2col_row_start + (left_zero_count * input_depth); + std::copy(input_row_start, input_row_end, im2col_center_start); + } + if (right_zero_count > 0) { + T1* im2col_right_start = + im2col_row_start + + ((left_zero_count + center_copy_count) * input_depth); + T1* im2col_right_end = + im2col_right_start + (right_zero_count * input_depth); + std::fill(im2col_right_start, im2col_right_end, T1(0)); + } + } + } + const bool is_last_in_chunk = + (patch_index_within_chunk == (patches_per_chunk - 1)); + const bool is_last_overall = + ((batch == (input_batches - 1)) && + (out_y == (output_height - 1)) && (out_x == (output_width - 1))); + if (is_last_in_chunk || is_last_overall) { + // Now we've assembled a set of image patches into a matrix, apply a + // GEMM matrix multiply of the patches as rows, times the filter + // weights in columns, to get partial results in the output matrix. + const int how_many_patches = patch_index_within_chunk + 1; + const int m = how_many_patches; + const int n = filter_count; + const int k = filter_value_count; + const int lda = filter_value_count; + const int ldb = filter_count; + const int ldc = filter_count; + const size_t start_patch_index = + patch_index - (how_many_patches - 1); + T3* chunk_output_data = + output_data + (start_patch_index * filter_count); + TGemmFunctor gemm_functor; + gemm_functor(m, n, k, im2col_buffer, lda, filter_data, ldb, + chunk_output_data, ldc); + } + } + } + } + } +}; + +} // namespace + +// This TensorFlow kernel class handles all of the IO and housekeeping for the +// functors that actually implement the underlying algorithm. To swap in +// different implementations of the main calculations, use a different +// TConvFunctor parameter when instantiating the template. +template +class Conv2DUsingGemmOp : public BinaryOp { + public: + explicit Conv2DUsingGemmOp(OpKernelConstruction* context) + : BinaryOp(context) { + OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); + string data_format; + OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); + OP_REQUIRES(context, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES(context, data_format_ == FORMAT_NHWC, + errors::InvalidArgument( + "Data format not supported by this kernel", data_format)); + OP_REQUIRES(context, strides_.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + const int64 stride_n = GetTensorDim(strides_, data_format_, 'N'); + const int64 stride_c = GetTensorDim(strides_, data_format_, 'C'); + OP_REQUIRES( + context, stride_n == 1 && stride_c == 1, + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* context) override { + // Input tensor is of the following dimensions: + // [ batch, in_rows, in_cols, in_depth ] + const Tensor& input = context->input(0); + + // Input filter is of the following dimensions: + // [ filter_rows, filter_cols, in_depth, out_depth] + const Tensor& filter = context->input(1); + + // For 2D convolution, there should be 4 dimensions. + OP_REQUIRES(context, input.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input.shape().DebugString())); + OP_REQUIRES(context, filter.dims() == 4, + errors::InvalidArgument("filter must be 4-dimensional: ", + filter.shape().DebugString())); + + for (int i = 0; i < 3; i++) { + OP_REQUIRES(context, FastBoundsCheck(filter.dim_size(i), + std::numeric_limits::max()), + errors::InvalidArgument("filter too large")); + } + + // The last dimension for input is in_depth. It must be the same as the + // filter's in_depth. + const int64 in_depth = GetTensorDim(input, data_format_, 'C'); + OP_REQUIRES( + context, in_depth == filter.dim_size(2), + errors::InvalidArgument("input and filter must have the same depth: ", + in_depth, " vs ", filter.dim_size(2))); + + // The last dimension for filter is out_depth. + const int out_depth = static_cast(filter.dim_size(3)); + + // The second dimension for input is rows/height. + // The first dimension for filter is rows/height. + const int64 input_rows_raw = GetTensorDim(input, data_format_, 'H'); + OP_REQUIRES(context, FastBoundsCheck(input_rows_raw, + std::numeric_limits::max()), + errors::InvalidArgument("Input rows too large")); + const int input_rows = static_cast(input_rows_raw); + const int filter_rows = static_cast(filter.dim_size(0)); + + // The third dimension for input is columns/width. + // The second dimension for filter is columns/width. + const int64 input_cols_raw = GetTensorDim(input, data_format_, 'W'); + OP_REQUIRES(context, FastBoundsCheck(input_cols_raw, + std::numeric_limits::max()), + errors::InvalidArgument("Input cols too large")); + const int input_cols = static_cast(input_cols_raw); + const int filter_cols = static_cast(filter.dim_size(1)); + + // The first dimension for input is batch. + const int64 batch_raw = GetTensorDim(input, data_format_, 'N'); + OP_REQUIRES(context, + FastBoundsCheck(batch_raw, std::numeric_limits::max()), + errors::InvalidArgument("batch is too large")); + const int batch = static_cast(batch_raw); + + // For now we take the stride from the second and third dimensions only (we + // do not support striding on the batch or depth dimension). + const int stride_rows = GetTensorDim(strides_, data_format_, 'H'); + const int stride_cols = GetTensorDim(strides_, data_format_, 'W'); + + int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; + OP_REQUIRES_OK(context, + GetWindowedOutputSize(input_rows, filter_rows, stride_rows, + padding_, &out_rows, &pad_rows)); + OP_REQUIRES_OK(context, + GetWindowedOutputSize(input_cols, filter_cols, stride_cols, + padding_, &out_cols, &pad_cols)); + TensorShape out_shape = + ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth); + + // Output tensor is of the following dimensions: + // [ in_batch, out_rows, out_cols, out_depth ] + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); + + VLOG(2) << "Conv2D: in_depth = " << in_depth + << ", input_cols = " << input_cols + << ", filter_cols = " << filter_cols + << ", input_rows = " << input_rows + << ", filter_rows = " << filter_rows + << ", stride_rows = " << stride_rows + << ", stride_cols = " << stride_cols + << ", out_depth = " << out_depth; + + // If there is nothing to compute, return. + if (out_shape.num_elements() == 0) { + return; + } + TConvFunctor conv_functor; + conv_functor(context, input.flat().data(), batch, input_rows, input_cols, + in_depth, filter.flat().data(), filter_rows, filter_cols, + out_depth, stride_rows, stride_cols, padding_, + output->flat().data(), out_rows, out_cols); + } + + private: + std::vector strides_; + Padding padding_; + TensorFormat data_format_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv2DUsingGemmOp); +}; + +#define REGISTER_CPU(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("Conv2D").Device(DEVICE_CPU).TypeConstraint("T"), \ + Conv2DUsingGemmOp< \ + T, Im2ColConvFunctor>>); + +// Only register this GEMM-based implementation of Conv2d if the compiler flags +// request the implementation explicitly, since otherwise it will clash with the +// default EigenTensor-based kernel. +#if defined(USE_GEMM_FOR_CONV) +TF_CALL_half(REGISTER_CPU); +TF_CALL_float(REGISTER_CPU); +#endif // USE_GEMM_FOR_CONV + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/reduce_join_op.cc b/tensorflow/core/kernels/reduce_join_op.cc index 1bd415c0d12..b72bfeb15ff 100644 --- a/tensorflow/core/kernels/reduce_join_op.cc +++ b/tensorflow/core/kernels/reduce_join_op.cc @@ -105,7 +105,7 @@ void MakeUnreducedIndices(gtl::InlinedVector index_is_reduced, TensorShape GetOutputShape(gtl::InlinedVector index_is_reduced, const TensorShape& input_shape, bool keep_dims) { TensorShape output_shape; - for (int32 index = 0; index < index_is_reduced.size(); ++index) { + for (size_t index = 0; index < index_is_reduced.size(); ++index) { if (index_is_reduced[index]) { if (keep_dims) output_shape.AddDim(1); } else { diff --git a/tensorflow/core/kernels/reverse_sequence_op.cc b/tensorflow/core/kernels/reverse_sequence_op.cc index 5825b1cd4d1..505c512cc42 100644 --- a/tensorflow/core/kernels/reverse_sequence_op.cc +++ b/tensorflow/core/kernels/reverse_sequence_op.cc @@ -40,19 +40,19 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -template +template void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) { const Tensor& input = context->input(0); const Tensor& seq_lens = context->input(1); - auto seq_lens_t = seq_lens.vec(); + auto seq_lens_t = seq_lens.vec(); - std::vector seq_lens_vec(seq_lens_t.size()); + std::vector seq_lens_vec(seq_lens_t.size()); // Copy seq_len info down for validity checks context->eigen_device().memcpyDeviceToHost( seq_lens_vec.data(), seq_lens_t.data(), - sizeof(int64) * seq_lens_t.size()); + sizeof(Tlen) * seq_lens_t.size()); OP_REQUIRES(context, batch_dim != seq_dim, errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim)); @@ -76,8 +76,7 @@ void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) { } } -template <> -void CheckErrors(OpKernelContext* context, int batch_dim, +void CheckErrorsGPU(OpKernelContext* context, int batch_dim, int seq_dim) { const Tensor& input = context->input(0); const Tensor& seq_lens = context->input(1); @@ -97,7 +96,19 @@ void CheckErrors(OpKernelContext* context, int batch_dim, " vs. ", input.dim_size(batch_dim))); } -template +template <> +void CheckErrors(OpKernelContext* context, int batch_dim, + int seq_dim) { + CheckErrorsGPU(context, batch_dim, seq_dim); +} + +template <> +void CheckErrors(OpKernelContext* context, int batch_dim, + int seq_dim) { + CheckErrorsGPU(context, batch_dim, seq_dim); +} + +template class ReverseSequenceOp : public OpKernel { public: explicit ReverseSequenceOp(OpKernelConstruction* context) @@ -115,9 +126,9 @@ class ReverseSequenceOp : public OpKernel { errors::InvalidArgument("seq_lens input must be 1-dim, not ", seq_lens.dims())); - auto seq_lens_t = seq_lens.vec(); + auto seq_lens_t = seq_lens.vec(); - CheckErrors(context, batch_dim_, seq_dim_); + CheckErrors(context, batch_dim_, seq_dim_); const int input_dims = input.dims(); @@ -127,7 +138,7 @@ class ReverseSequenceOp : public OpKernel { #define HANDLE_DIM(NDIM) \ case NDIM: \ - functor::ReverseSequence::Compute( \ + functor::ReverseSequence::Compute( \ context->eigen_device(), input.tensor(), batch_dim_, \ seq_dim_, seq_lens_t, output->tensor()); \ break; @@ -153,42 +164,57 @@ class ReverseSequenceOp : public OpKernel { TF_DISALLOW_COPY_AND_ASSIGN(ReverseSequenceOp); }; -#define REGISTER_REVERSE_SEQUENCE(type) \ +#define REGISTER_REVERSE_SEQUENCE(type, len_type) \ REGISTER_KERNEL_BUILDER( \ - Name("ReverseSequence").Device(DEVICE_CPU).TypeConstraint("T"), \ - ReverseSequenceOp); + Name("ReverseSequence").Device(DEVICE_CPU).TypeConstraint("T"). \ + TypeConstraint("Tlen"), \ + ReverseSequenceOp); -TF_CALL_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE); +#define REGISTER_REVERSE_SEQUENCE_LEN(type) \ + REGISTER_REVERSE_SEQUENCE(type, int32); \ + REGISTER_REVERSE_SEQUENCE(type, int64); + +TF_CALL_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE_LEN); #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. namespace functor { -#define DECLARE_GPU_SPEC(T, Dims) \ - template <> \ - void ReverseSequence::Compute( \ - const GPUDevice& d, typename TTypes::ConstTensor input, \ - int32 batch_dim, int32 seq_dim, TTypes::ConstVec seq_lens, \ - typename TTypes::Tensor output); \ - extern template struct ReverseSequence; +#define DECLARE_GPU_SPEC(T, Tlen, Dims) \ + template <> \ + void ReverseSequence::Compute( \ + const GPUDevice& d, typename TTypes::ConstTensor input, \ + int32 batch_dim, int32 seq_dim, \ + typename TTypes::ConstVec seq_lens, \ + typename TTypes::Tensor output); \ + extern template struct ReverseSequence; -#define DECLARE_GPU_SPECS(T) \ - DECLARE_GPU_SPEC(T, 2); \ - DECLARE_GPU_SPEC(T, 3); \ - DECLARE_GPU_SPEC(T, 4); \ - DECLARE_GPU_SPEC(T, 5); +#define DECLARE_GPU_SPEC_LEN(T, Dims) \ + DECLARE_GPU_SPEC(T, int32, Dims); \ + DECLARE_GPU_SPEC(T, int64, Dims); + +#define DECLARE_GPU_SPECS(T) \ + DECLARE_GPU_SPEC_LEN(T, 2); \ + DECLARE_GPU_SPEC_LEN(T, 3); \ + DECLARE_GPU_SPEC_LEN(T, 4); \ + DECLARE_GPU_SPEC_LEN(T, 5); TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS); } // namespace functor // Registration of the GPU implementations. -#define REGISTER_REVERSE_SEQUENCE_GPU(type) \ +#define REGISTER_REVERSE_SEQUENCE_GPU(type, len_type) \ REGISTER_KERNEL_BUILDER( \ - Name("ReverseSequence").Device(DEVICE_GPU).TypeConstraint("T"), \ - ReverseSequenceOp); + Name("ReverseSequence").Device(DEVICE_GPU).TypeConstraint("T"). \ + TypeConstraint("Tlen"), \ + ReverseSequenceOp); -TF_CALL_GPU_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE_GPU); +#define REGISTER_REVERSE_SEQUENCE_GPU_LEN(type) \ + REGISTER_REVERSE_SEQUENCE_GPU(type, int32); \ + REGISTER_REVERSE_SEQUENCE_GPU(type, int64); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE_GPU_LEN); #undef REGISTER_REVERSE_SEQUENCE_GPU diff --git a/tensorflow/core/kernels/reverse_sequence_op.h b/tensorflow/core/kernels/reverse_sequence_op.h index 72c59d59aad..8ccd32ea160 100644 --- a/tensorflow/core/kernels/reverse_sequence_op.h +++ b/tensorflow/core/kernels/reverse_sequence_op.h @@ -25,12 +25,12 @@ namespace tensorflow { namespace generator { -template +template class ReverseGenerator { public: EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ReverseGenerator(typename TTypes::ConstTensor input, int32 batch_dim, - int32 seq_dim, TTypes::ConstVec seq_lengths) + int32 seq_dim, typename TTypes::ConstVec seq_lengths) : input_(input), batch_dim_(batch_dim), seq_dim_(seq_dim), @@ -51,21 +51,22 @@ class ReverseGenerator { typename TTypes::ConstTensor input_; int32 batch_dim_; int32 seq_dim_; - TTypes::ConstVec seq_lengths_; + typename TTypes::ConstVec seq_lengths_; }; } // namespace generator namespace functor { -template +template struct ReverseSequence { EIGEN_ALWAYS_INLINE static void Compute( const Device& d, typename TTypes::ConstTensor input, - int32 batch_dim, int32 seq_dim, TTypes::ConstVec seq_lengths, + int32 batch_dim, int32 seq_dim, + typename TTypes::ConstVec seq_lengths, typename TTypes::Tensor output) { - generator::ReverseGenerator generator(input, batch_dim, seq_dim, - seq_lengths); + generator::ReverseGenerator generator(input, batch_dim, + seq_dim, seq_lengths); output.device(d) = input.generate(generator); } }; diff --git a/tensorflow/core/kernels/reverse_sequence_op_gpu.cu.cc b/tensorflow/core/kernels/reverse_sequence_op_gpu.cu.cc index bcc265c624d..373fd606873 100644 --- a/tensorflow/core/kernels/reverse_sequence_op_gpu.cu.cc +++ b/tensorflow/core/kernels/reverse_sequence_op_gpu.cu.cc @@ -24,15 +24,19 @@ namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; -#define DEFINE_GPU_SPEC(T, dims) \ - template class generator::ReverseGenerator; \ - template struct functor::ReverseSequence; +#define DEFINE_GPU_SPEC(T, Tlen, dims) \ + template class generator::ReverseGenerator; \ + template struct functor::ReverseSequence; + +#define DEFINE_GPU_SPEC_LEN(T, dims) \ + DEFINE_GPU_SPEC(T, int32, dims); \ + DEFINE_GPU_SPEC(T, int64, dims); #define DEFINE_GPU_SPECS(T) \ - DEFINE_GPU_SPEC(T, 2); \ - DEFINE_GPU_SPEC(T, 3); \ - DEFINE_GPU_SPEC(T, 4); \ - DEFINE_GPU_SPEC(T, 5); + DEFINE_GPU_SPEC_LEN(T, 2); \ + DEFINE_GPU_SPEC_LEN(T, 3); \ + DEFINE_GPU_SPEC_LEN(T, 4); \ + DEFINE_GPU_SPEC_LEN(T, 5); TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS); diff --git a/tensorflow/core/kernels/sparse_dense_binary_op_shared.cc b/tensorflow/core/kernels/sparse_dense_binary_op_shared.cc index dbd1d9466c0..2d1539fb9d9 100644 --- a/tensorflow/core/kernels/sparse_dense_binary_op_shared.cc +++ b/tensorflow/core/kernels/sparse_dense_binary_op_shared.cc @@ -91,7 +91,7 @@ class SparseDenseBinaryOpShared : public OpKernel { auto VecGreaterEq = [](ArraySlice lhs, ArraySlice rhs) { if (lhs.size() > rhs.size()) return true; if (lhs.size() < rhs.size()) return false; - for (int i = 0; i < lhs.size(); ++i) { + for (size_t i = 0; i < lhs.size(); ++i) { if (lhs[i] < rhs[i]) return false; } return true; diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 65052e565cc..82d9924665c 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -1471,11 +1471,12 @@ This operation returns N 1-D integer tensors representing shape of `input[i]s`. // -------------------------------------------------------------------------- REGISTER_OP("ReverseSequence") .Input("input: T") - .Input("seq_lengths: int64") + .Input("seq_lengths: Tlen") .Output("output: T") .Attr("seq_dim: int") .Attr("batch_dim: int = 0") .Attr("T: type") + .Attr("Tlen: {int32, int64} = DT_INT64") .SetShapeFn([](InferenceContext* c) { ShapeHandle input = c->input(0); ShapeHandle seq_lens_shape; diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index fab9001ff56..3495f6ea14f 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -4711,6 +4711,42 @@ op { summary: "Decode a PNG-encoded image to a uint8 or uint16 tensor." description: "The attr `channels` indicates the desired number of color channels for the\ndecoded image.\n\nAccepted values are:\n\n* 0: Use the number of channels in the PNG-encoded image.\n* 1: output a grayscale image.\n* 3: output an RGB image.\n* 4: output an RGBA image.\n\nIf needed, the PNG-encoded image is transformed to match the requested number\nof color channels." } +op { + name: "DecodeGif" + input_arg { + name: "contents" + description: "0-D. The GIF-encoded image." + type: DT_STRING + } + output_arg { + name: "image" + description: "3-D with shape `[height, width, channels]`." + type_attr: "dtype" + } + attr { + name: "channels" + type: "int" + default_value { + i: 0 + } + description: "Number of color channels for the decoded image." + } + attr { + name: "dtype" + type: "type" + default_value { + type: DT_UINT8 + } + allowed_values { + list { + type: DT_UINT8 + type: DT_UINT16 + } + } + } + summary: "Decode a GIF-encoded image to a uint8 or uint16 tensor." + description: "The attr `channels` indicates the desired number of color channels for the\ndecoded image.\n\nAccepted values are:\n\n* 0: Use the number of channels in the GIF-encoded image.\n* 1: output a grayscale image.\n* 3: output an RGB image.\n* 4: output an RGBA image.\n\nIf needed, the GIF-encoded image is transformed to match the requested number\nof color channels." +} op { name: "DecodeRaw" input_arg { diff --git a/tensorflow/core/util/example_proto_helper.cc b/tensorflow/core/util/example_proto_helper.cc index acab74732e4..658b3588e6b 100644 --- a/tensorflow/core/util/example_proto_helper.cc +++ b/tensorflow/core/util/example_proto_helper.cc @@ -222,7 +222,7 @@ Status SingleExampleProtoToTensors( const auto& feature_dict = features.feature(); // Handle dense features. - for (int d = 0; d < fixed_len_features.size(); ++d) { + for (size_t d = 0; d < fixed_len_features.size(); ++d) { const FixedLenFeature& feature_config = fixed_len_features[d]; const string& key = feature_config.key; const DataType& dtype = feature_config.dtype; @@ -263,7 +263,7 @@ Status SingleExampleProtoToTensors( } // Handle sparse features. - for (int d = 0; d < var_len_features.size(); ++d) { + for (size_t d = 0; d < var_len_features.size(); ++d) { const VarLenFeature& feature_config = var_len_features[d]; const string& key = feature_config.key; const DataType& dtype = feature_config.dtype; @@ -338,7 +338,7 @@ Status BatchExampleProtoToTensors( fixed_len_features.size()); // Preallocate dense_values, since we know their sizes. - for (int d = 0; d < fixed_len_features.size(); ++d) { + for (size_t d = 0; d < fixed_len_features.size(); ++d) { const FixedLenFeature& config = fixed_len_features[d]; TensorShape out_shape; out_shape.AddDim(batch_size); @@ -352,11 +352,11 @@ Status BatchExampleProtoToTensors( // Temporary vector to hold sparse values. std::vector> sparse_values_tmp(var_len_features.size()); - for (int d = 0; d < var_len_features.size(); ++d) { + for (size_t d = 0; d < var_len_features.size(); ++d) { sparse_values_tmp[d] = std::vector(batch_size); } - for (int b = 0; b < examples.size(); ++b) { + for (size_t b = 0; b < examples.size(); ++b) { const Example& ex = *(examples[b]); const string& example_name = (has_names) ? names[b] : ""; SingleExampleProtoToTensors( @@ -364,7 +364,7 @@ Status BatchExampleProtoToTensors( &output_dense_values_tensor_ptrs, &sparse_values_tmp); } - for (int d = 0; d < var_len_features.size(); ++d) { + for (size_t d = 0; d < var_len_features.size(); ++d) { const VarLenFeature& feature_config = var_len_features[d]; const DataType& dtype = feature_config.dtype; const std::vector& sparse_values_tensor = sparse_values_tmp[d]; diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h index 79064e9988d..1665446aee4 100644 --- a/tensorflow/core/util/sparse/sparse_tensor.h +++ b/tensorflow/core/util/sparse/sparse_tensor.h @@ -283,7 +283,7 @@ void SparseTensor::Reorder(const VarDimArray& order) { // permutation (the inverse). This can be calculated with O(1) // additional // and O(n) time (INVPERM) but we just do the simple thing here. - std::vector permutation(reorder.size()); + std::vector permutation(reorder.size()); for (std::size_t n = 0; n < reorder.size(); ++n) { permutation[reorder[n]] = n; } diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py index 056664778d7..acd3005c483 100644 --- a/tensorflow/examples/image_retraining/retrain.py +++ b/tensorflow/examples/image_retraining/retrain.py @@ -703,7 +703,7 @@ def variable_summaries(var, name): mean = tf.reduce_mean(var) tf.scalar_summary('mean/' + name, mean) with tf.name_scope('stddev'): - stddev = tf.sqrt(tf.reduce_sum(tf.square(var - mean))) + stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))) tf.scalar_summary('sttdev/' + name, stddev) tf.scalar_summary('max/' + name, tf.reduce_max(var)) tf.scalar_summary('min/' + name, tf.reduce_min(var)) diff --git a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py index c83239217ae..23492e5122c 100644 --- a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py +++ b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py @@ -75,7 +75,7 @@ def train(): mean = tf.reduce_mean(var) tf.scalar_summary('mean/' + name, mean) with tf.name_scope('stddev'): - stddev = tf.sqrt(tf.reduce_sum(tf.square(var - mean))) + stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))) tf.scalar_summary('sttdev/' + name, stddev) tf.scalar_summary('max/' + name, tf.reduce_max(var)) tf.scalar_summary('min/' + name, tf.reduce_min(var)) diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.Optimizer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.Optimizer.md index cc3b53d1193..cfd86f5bee2 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.Optimizer.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.Optimizer.md @@ -184,9 +184,9 @@ applies gradients. ### Gating Gradients -Both `minimize()` and `compute_gradients()` accept a `gate_gradients` -argument that controls the degree of parallelism during the application of -the gradients. +Both `minimize()` and `compute_gradients()` accept a `gate_gradients` argument +that controls the degree of parallelism during the application of the +gradients. The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`. @@ -258,7 +258,3 @@ Use `get_slot_names()` to get the list of slot names created by the - - - #### `tf.train.Optimizer.get_name()` {#Optimizer.get_name} - - - - diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.metrics.confusion_matrix.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.metrics.confusion_matrix.md index f06295f0811..50d7917926c 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.metrics.confusion_matrix.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.metrics.confusion_matrix.md @@ -35,7 +35,7 @@ the same shape in order for this function to work. ##### Returns: - A l X l matrix represeting the confusion matrix, where l in the number of + A k X k matrix represeting the confusion matrix, where k is the number of possible labels in the classification task. ##### Raises: diff --git a/tensorflow/g3doc/api_docs/python/train.md b/tensorflow/g3doc/api_docs/python/train.md index e9c9ad63e3a..d408a404890 100644 --- a/tensorflow/g3doc/api_docs/python/train.md +++ b/tensorflow/g3doc/api_docs/python/train.md @@ -204,9 +204,9 @@ applies gradients. ### Gating Gradients -Both `minimize()` and `compute_gradients()` accept a `gate_gradients` -argument that controls the degree of parallelism during the application of -the gradients. +Both `minimize()` and `compute_gradients()` accept a `gate_gradients` argument +that controls the degree of parallelism during the application of the +gradients. The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`. @@ -311,7 +311,7 @@ Construct a new gradient descent optimizer. ### `class tf.train.AdadeltaOptimizer` {#AdadeltaOptimizer} -Optimizer that implements the Adadelta algorithm. +Optimizer that implements the Adadelta algorithm. See [M. D. Zeiler](http://arxiv.org/abs/1212.5701) ([pdf](http://arxiv.org/pdf/1212.5701v1.pdf)) @@ -3771,5 +3771,3 @@ Generates a checkpoint state proto. CheckpointState proto with model_checkpoint_path and all_model_checkpoint_paths updated to either absolute paths or relative paths to the current save_dir. - - diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md index 53afca4ccbf..fe8e0137676 100644 --- a/tensorflow/g3doc/get_started/os_setup.md +++ b/tensorflow/g3doc/get_started/os_setup.md @@ -8,9 +8,10 @@ github source. The TensorFlow Python API supports Python 2.7 and Python 3.3+. The GPU version (Linux & Mac OS X only) works best with Cuda Toolkit 7.5 and -cuDNN v4. other versions are supported (Cuda toolkit >= 7.0 and cuDNN 6.5(v2), -7.0(v3), v5) only when installing from sources. Please see [Cuda installation] -(#optional-install-cuda-gpus-on-linux) for details. +cuDNN v4. other versions are supported (Cuda toolkit >= 7.0 and +cuDNN 6.5(v2), 7.0(v3), v5) only when installing from sources. +Please see [Cuda installation](#optional-install-cuda-gpus-on-linux) +for details. ## Overview @@ -239,7 +240,7 @@ packages needed by TensorFlow. * Activate the conda environment and install TensorFlow in it. * After the install you will activate the conda environment each time you want to use TensorFlow. -* Optionally install ipython and other packages into the conda environment +* Optionally install ipython and other packages into the conda environment Install Anaconda: @@ -357,7 +358,7 @@ $ source activate tensorflow ### Install IPython -To use tensorflow with IPython it may be necessary to install IPython into the tensorflow environment: +To use tensorflow with IPython it may be necessary to install IPython into the tensorflow environment: ```bash $ source activate tensorflow @@ -365,7 +366,7 @@ $ source activate tensorflow ``` Similarly, other Python packages like pandas may need to get installed into the tensorflow environment -before they can be used together with tensorflow. +before they can be used together with tensorflow. ## Docker installation diff --git a/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md b/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md index 8183cdf0247..e487d8f04cf 100644 --- a/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md +++ b/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md @@ -86,7 +86,7 @@ def variable_summaries(var, name): mean = tf.reduce_mean(var) tf.scalar_summary('mean/' + name, mean) with tf.name_scope('stddev'): - stddev = tf.sqrt(tf.reduce_sum(tf.square(var - mean))) + stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))) tf.scalar_summary('sttdev/' + name, stddev) tf.scalar_summary('max/' + name, tf.reduce_max(var)) tf.scalar_summary('min/' + name, tf.reduce_min(var)) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index e1f62a81f93..4766a673a4d 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -71,6 +71,34 @@ tf_py_test( ], ) +tf_py_test( + name = "flags_test", + size = "small", + srcs = ["platform/flags_test.py"], + additional_deps = [ + ":platform", + ":platform_test", + ], + tags = [ + "manual", + "notap", + ], +) + +tf_py_test( + name = "app_test", + size = "small", + srcs = ["platform/app_test.py"], + additional_deps = [ + ":platform", + ":platform_test", + ], + tags = [ + "manual", + "notap", + ], +) + cc_library( name = "numpy_lib", srcs = ["lib/core/numpy.cc"], diff --git a/tensorflow/python/kernel_tests/clip_ops_test.py b/tensorflow/python/kernel_tests/clip_ops_test.py index 1f9e69713a9..bb842a2998b 100644 --- a/tensorflow/python/kernel_tests/clip_ops_test.py +++ b/tensorflow/python/kernel_tests/clip_ops_test.py @@ -57,8 +57,13 @@ class ClipTest(tf.test.TestCase): clip_norm = 4.0 ans = tf.clip_by_norm(x, clip_norm) tf_ans = ans.eval() + + clip_tensor = tf.constant(4.0) + ans = tf.clip_by_norm(x, clip_norm) + tf_ans_tensor = ans.eval() self.assertAllClose(np_ans, tf_ans) + self.assertAllClose(np_ans, tf_ans_tensor) def testClipByNormNotClipped(self): # No norm clipping when clip_norm >= 5 @@ -148,6 +153,28 @@ class ClipTest(tf.test.TestCase): self.assertAllClose(np_ans_0, tf_ans_1) self.assertAllClose(np_ans_1, tf_ans_2) + def testClipByGlobalNormClippedTensor(self): + # Norm clipping when clip_norm < 5 + with self.test_session(): + x0 = tf.constant([-2.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3]) + x1 = tf.constant([1.0, -2.0]) + # Global norm of x0 and x1 = sqrt(1 + 4^2 + 2^2 + 2^2) = 5 + clip_norm = tf.constant(4.0) + + # Answers are the original tensors scaled by 4.0/5.0 + np_ans_0 = [[-1.6, 0.0, 0.0], + [3.2, 0.0, 0.0]] + np_ans_1 = [0.8, -1.6] + + ans, norm = tf.clip_by_global_norm((x0, x1), clip_norm) + tf_ans_1 = ans[0].eval() + tf_ans_2 = ans[1].eval() + tf_norm = norm.eval() + + self.assertAllClose(tf_norm, 5.0) + self.assertAllClose(np_ans_0, tf_ans_1) + self.assertAllClose(np_ans_1, tf_ans_2) + def testClipByGlobalNormSupportsNone(self): # Norm clipping when clip_norm < 5 with self.test_session(): @@ -259,6 +286,19 @@ class ClipTest(tf.test.TestCase): self.assertAllClose(np_ans, tf_ans) + def testClipByAverageNormClippedTensor(self): + # Norm clipping when average clip_norm < 0.83333333 + with self.test_session(): + x = tf.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3]) + # Average norm of x = sqrt(3^2 + 4^2) / 6 = 0.83333333 + np_ans = [[-2.88, 0.0, 0.0], + [3.84, 0.0, 0.0]] + clip_norm = tf.constant(0.8) + ans = tf.clip_by_average_norm(x, clip_norm) + tf_ans = ans.eval() + + self.assertAllClose(np_ans, tf_ans) + def testClipByAverageNormNotClipped(self): # No norm clipping when average clip_norm >= 0.83333333 with self.test_session(): diff --git a/tensorflow/python/kernel_tests/reverse_sequence_op_test.py b/tensorflow/python/kernel_tests/reverse_sequence_op_test.py index 48fbccd686a..cea9711d322 100644 --- a/tensorflow/python/kernel_tests/reverse_sequence_op_test.py +++ b/tensorflow/python/kernel_tests/reverse_sequence_op_test.py @@ -47,7 +47,7 @@ class ReverseSequenceTest(tf.test.TestCase): self._testReverseSequence(x, batch_dim, seq_dim, seq_lengths, truth, False, expected_err_re) - def _testBasic(self, dtype): + def _testBasic(self, dtype, len_dtype=np.int64): x = np.asarray([ [[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], [13, 14, 15, 16]], @@ -56,7 +56,7 @@ class ReverseSequenceTest(tf.test.TestCase): x = x.transpose([2, 1, 0, 3, 4]) # permute axes 0 <=> 2 # reverse dim 2 up to (0:3, none, 0:4) along dim=0 - seq_lengths = np.asarray([3, 0, 4], dtype=np.int64) + seq_lengths = np.asarray([3, 0, 4], dtype=len_dtype) truth_orig = np.asarray( [[[3, 2, 1, 4], [7, 6, 5, 8]], # reverse 0:3 @@ -70,6 +70,9 @@ class ReverseSequenceTest(tf.test.TestCase): batch_dim = 2 self._testBothReverseSequence(x, batch_dim, seq_dim, seq_lengths, truth) + def testSeqLenghtInt32(self): + self._testBasic(np.float32, np.int32) + def testFloatBasic(self): self._testBasic(np.float32) diff --git a/tensorflow/python/kernel_tests/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops_test.py index 29b57e80944..d945af0081c 100644 --- a/tensorflow/python/kernel_tests/sparse_ops_test.py +++ b/tensorflow/python/kernel_tests/sparse_ops_test.py @@ -743,5 +743,28 @@ class SparseMinimumMaximumTest(test_util.TensorFlowTestCase): tf.sparse_maximum(sp_zero, sp_one).eval() +class SparseTransposeTest(tf.test.TestCase): + + def _SparseTensorPlaceholder(self): + return tf.SparseTensor( + tf.placeholder(tf.int64), + tf.placeholder(tf.float64), + tf.placeholder(tf.int64)) + + def testTranspose(self): + with self.test_session(use_gpu=False) as sess: + np.random.seed(1618) + shapes = [np.random.randint(1, 10, size=rank) for rank in range(1, 6)] + for shape in shapes: + for dtype in [np.int32, np.int64, np.float32, np.float64]: + dn_input = np.random.randn(*shape).astype(dtype) + rank = tf.rank(dn_input).eval() + perm = np.random.choice(rank, rank, False) + sp_input, unused_a_nnz = _sparsify(dn_input) + sp_trans = tf.sparse_transpose(sp_input, perm=perm) + dn_trans = tf.sparse_tensor_to_dense(sp_trans).eval() + expected_trans = tf.transpose(dn_input, perm=perm).eval() + self.assertAllEqual(dn_trans, expected_trans) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py index 9f2f972f7db..65e05626f04 100644 --- a/tensorflow/python/ops/clip_ops.py +++ b/tensorflow/python/ops/clip_ops.py @@ -206,7 +206,7 @@ def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None): # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm scale = clip_norm * math_ops.minimum( 1.0 / use_norm, - constant_op.constant(1.0 / clip_norm, dtype=use_norm.dtype)) + constant_op.constant(1.0, dtype=use_norm.dtype) / clip_norm) values = [ ops.convert_to_tensor( @@ -268,7 +268,7 @@ def clip_by_average_norm(t, clip_norm, name=None): math_ops.reduce_sum(t * t, math_ops.range(array_ops.rank(t)))) tclip = array_ops.identity( t * clip_norm * math_ops.minimum( - l2norm_inv * n_element, constant_op.constant(1.0 / clip_norm)), + l2norm_inv * n_element, constant_op.constant(1.0) / clip_norm), name=name) return tclip diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py index db8f043607e..7990dba3b63 100644 --- a/tensorflow/python/ops/io_ops.py +++ b/tensorflow/python/ops/io_ops.py @@ -80,6 +80,7 @@ Queues](../../how_tos/threading_and_queues/index.md). @@FIFOQueue @@PaddingFIFOQueue @@RandomShuffleQueue +@@PriorityQueue ## Dealing with the filesystem diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 1ac5314ad2b..11ca3ff0d72 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -19,6 +19,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numbers + import numpy as np from tensorflow.python.framework import common_shapes @@ -1131,7 +1133,7 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): """ with ops.name_scope(name, "dropout", [x]) as name: x = ops.convert_to_tensor(x, name="x") - if isinstance(keep_prob, float) and not 0 < keep_prob <= 1: + if isinstance(keep_prob, numbers.Real) and not 0 < keep_prob <= 1: raise ValueError("keep_prob must be a scalar tensor or a float in the " "range (0, 1], got %g" % keep_prob) keep_prob = ops.convert_to_tensor(keep_prob, diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index c42da9af5d7..198ccab0211 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -40,6 +40,7 @@ dimension, and dense along all other dimensions. @@sparse_retain @@sparse_reset_shape @@sparse_fill_empty_rows +@@sparse_transpose ## Reduction @@sparse_reduce_sum @@ -1582,3 +1583,49 @@ def _SparseSparseMaximumMinimumShape(op): # pylint: disable=invalid-name op.inputs[4].get_shape().assert_has_rank(1) # b_values op.inputs[5].get_shape().assert_has_rank(1) # b_shape return [tensor_shape.unknown_shape(2), tensor_shape.unknown_shape(1)] + + +def sparse_transpose(sp_input, perm=None, name=None): + """Transposes a `SparseTensor` + + The returned tensor's dimension i will correspond to the input dimension + `perm[i]`. If `perm` is not given, it is set to (n-1...0), where n is + the rank of the input tensor. Hence by default, this operation performs a + regular matrix transpose on 2-D input Tensors. + + For example, if `sp_input` has shape `[4, 5]` and `indices` / `values`: + + [0, 3]: b + [0, 1]: a + [3, 1]: d + [2, 0]: c + + then the output will be a `SparseTensor` of shape `[5, 4]` and + `indices` / `values`: + + [0, 2]: c + [1, 0]: a + [1, 3]: d + [3, 0]: b + + Args: + sp_input: The input `SparseTensor`. + perm: A permutation of the dimensions of `sp_input`. + name: A name prefix for the returned tensors (optional) + Returns: + A transposed `SparseTensor`. + + Raises: + TypeError: If `sp_input` is not a `SparseTensor`. + """ + with ops.op_scope([sp_input], name, "SparseTranspose") as name: + if perm is None: + rank = array_ops.rank(sp_input) + perm = (rank - 1) - math_ops.range(0, rank, 1) + indices = sp_input.indices + transposed_indices = array_ops.transpose(array_ops.gather(array_ops.transpose(indices), perm)) + dense_shape = sp_input.shape + transposed_dense_shape = array_ops.gather(dense_shape, perm) + transposed_st = ops.SparseTensor(transposed_indices, sp_input.values, transposed_dense_shape) + transposed_st = sparse_reorder(transposed_st) + return transposed_st diff --git a/tensorflow/python/platform/app.py b/tensorflow/python/platform/app.py index ebca21e075a..7419d844ffd 100644 --- a/tensorflow/python/platform/app.py +++ b/tensorflow/python/platform/app.py @@ -25,6 +25,6 @@ from tensorflow.python.platform import flags def run(main=None): f = flags.FLAGS - f._parse_flags() + flags_passthrough = f._parse_flags() main = main or sys.modules['__main__'].main - sys.exit(main(sys.argv)) + sys.exit(main(sys.argv[:1] + flags_passthrough)) diff --git a/tensorflow/python/platform/app_test.py b/tensorflow/python/platform/app_test.py new file mode 100644 index 00000000000..b7a58dd9fb4 --- /dev/null +++ b/tensorflow/python/platform/app_test.py @@ -0,0 +1,45 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for our flags implementation.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +from tensorflow.python.platform import app +from tensorflow.python.platform import flags + +FLAGS = flags.FLAGS +flags.DEFINE_boolean('myflag', False, '') + +def main(argv): + if (len(argv) != 3): + print("Length of argv was not 3: ", argv) + sys.exit(-1) + + if argv[1] != "--passthrough": + print("--passthrough argument not in argv") + sys.exit(-1) + + if argv[2] != "extra": + print("'extra' argument not in argv") + sys.exit(-1) + + +if __name__ == '__main__': + sys.argv.extend(["--myflag", "--passthrough", "extra"]) + app.run() diff --git a/tensorflow/python/platform/flags.py b/tensorflow/python/platform/flags.py index 85f9e2cb860..645f8acecb0 100644 --- a/tensorflow/python/platform/flags.py +++ b/tensorflow/python/platform/flags.py @@ -30,10 +30,11 @@ class _FlagValues(object): self.__dict__['__parsed'] = False def _parse_flags(self): - result, _ = _global_parser.parse_known_args() + result, unparsed = _global_parser.parse_known_args() for flag_name, val in vars(result).items(): self.__dict__['__flags'][flag_name] = val self.__dict__['__parsed'] = True + return unparsed def __getattr__(self, name): """Retrieves the 'value' attribute of the flag --name.""" diff --git a/tensorflow/python/training/device_setter.py b/tensorflow/python/training/device_setter.py index c7a8af85d1c..a414116a741 100644 --- a/tensorflow/python/training/device_setter.py +++ b/tensorflow/python/training/device_setter.py @@ -150,7 +150,7 @@ def replica_device_setter(ps_tasks=0, ps_device="/job:ps", else: cluster_spec = server_lib.ClusterSpec(cluster).as_dict() # Get ps_job_name from ps_device by striping "/job:". - ps_job_name = ps_device.lstrip("/job:") + ps_job_name = pydev.DeviceSpec.from_string(ps_device).job if ps_job_name not in cluster_spec or cluster_spec[ps_job_name] is None: return None ps_tasks = len(cluster_spec[ps_job_name]) diff --git a/tensorflow/python/training/device_setter_test.py b/tensorflow/python/training/device_setter_test.py index 8fb2768c5f7..607cf2d20db 100644 --- a/tensorflow/python/training/device_setter_test.py +++ b/tensorflow/python/training/device_setter_test.py @@ -77,6 +77,23 @@ class DeviceSetterTest(tf.test.TestCase): self.assertDeviceEqual("/job:moon/task:1", w.device) self.assertDeviceEqual("/job:moon/task:1", w.initializer.device) self.assertDeviceEqual("/job:sun", a.device) + + def testPS2TasksWithCPUConstraint(self): + cluster_spec = tf.train.ClusterSpec({ + "sun": ["sun0:2222", "sun1:2222", "sun2:2222"], + "moon": ["moon0:2222", "moon1:2222"]}) + + with tf.device(tf.train.replica_device_setter( + ps_device="/job:moon/cpu:0", worker_device="/job:sun", + cluster=cluster_spec.as_cluster_def())): + v = tf.Variable([1, 2]) + w = tf.Variable([2, 1]) + a = v + w + self.assertDeviceEqual("/job:moon/task:0/cpu:0", v.device) + self.assertDeviceEqual("/job:moon/task:0/cpu:0", v.initializer.device) + self.assertDeviceEqual("/job:moon/task:1/cpu:0", w.device) + self.assertDeviceEqual("/job:moon/task:1/cpu:0", w.initializer.device) + self.assertDeviceEqual("/job:sun", a.device) if __name__ == "__main__": diff --git a/tensorflow/tensorboard/components/tf-dashboard-common/dashboard-style.html b/tensorflow/tensorboard/components/tf-dashboard-common/dashboard-style.html index 15919942923..4152472cd50 100644 --- a/tensorflow/tensorboard/components/tf-dashboard-common/dashboard-style.html +++ b/tensorflow/tensorboard/components/tf-dashboard-common/dashboard-style.html @@ -53,7 +53,7 @@ limitations under the License. } .card .card-bottom-row { position: absolute; - left: 50px; + left: 75px; bottom: 0; padding-right: 10px; } @@ -71,6 +71,14 @@ limitations under the License. display: block; } + .log-option-button { + position: absolute; + left: 25px; + bottom: 0px; + color: #2196F3; + display: block; + } + #content-container{ display: block; } diff --git a/tensorflow/tensorboard/components/tf-event-dashboard/tf-event-dashboard.html b/tensorflow/tensorboard/components/tf-event-dashboard/tf-event-dashboard.html index 417f64b42b3..bbbfdc1393a 100644 --- a/tensorflow/tensorboard/components/tf-event-dashboard/tf-event-dashboard.html +++ b/tensorflow/tensorboard/components/tf-event-dashboard/tf-event-dashboard.html @@ -145,6 +145,13 @@ The #center div contains tf-line-charts embedded inside tf-collapsable-panes. icon="fullscreen" on-tap="toggleSelected" > + +