Merge remote-tracking branch 'upstream/master'

commit 1f094bccdb

.mention-bot (new file, 11 lines)
@@ -0,0 +1,11 @@
+{
+  "maxReviewers": 2,
+  "numFilesToCheck": 10, // Number of files to check against, default is 5
+  "userBlacklist": ["tensorflower-gardener"], // users in this list will never be mentioned by mention-bot
+  "requiredOrgs": ["tensorflow"], // mention-bot will only mention user who are a member of one of these organizations
+  "skipAlreadyAssignedPR": true, // mention-bot will ignore already assigned PR's
+  "skipAlreadyMentionedPR": true, // mention-bot will ignore if there is already existing an exact mention
+  "skipTitle": "Branch", // mention-bot will ignore PR that includes text in the title,
+  "delayed": true, // mention-bot will wait to comment until specified time in `delayedUntil` value
+  "delayedUntil": "10m",
+}
README.md
@@ -33,10 +33,10 @@ and discussion.**

 People who are a little more adventurous can also try our nightly binaries:

-* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
+* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc2-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc2-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc2-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
-* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
+* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc2-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc2-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc2-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
-* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac1-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac1-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac1-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac1-slave/))
+* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac1-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc2-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac1-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac1-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc2-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac1-slave/))
-* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/))
+* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc2-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc2-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/))
 * [Android](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-android/TF_BUILD_CONTAINER_TYPE=ANDROID,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=NO_PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=android-slave/lastSuccessfulBuild/artifact/bazel-out/local_linux/bin/tensorflow/examples/android/tensorflow_demo.apk) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-android/TF_BUILD_CONTAINER_TYPE=ANDROID,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=NO_PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=android-slave/))

 #### *Try your first TensorFlow program*
RELEASE.md
@@ -10,6 +10,7 @@ BUS_ANY was used.

 ## Major Features and Improvements

+* CUDA 8 support.
 * cuDNN 5 support.
 * HDFS Support.
 * Adds Fused LSTM support via cuDNN 5 in `tensorflow/contrib/cudnn_rnn`.
WORKSPACE (19 lines changed)
@@ -153,8 +153,8 @@ new_http_archive(
 new_http_archive(
     name = "iron_iconset_svg",
     build_file = "bower.BUILD",
-    url = "https://github.com/polymerelements/iron-iconset-svg/archive/v1.0.10.tar.gz",
+    url = "https://github.com/polymerelements/iron-iconset-svg/archive/v1.1.0.tar.gz",
-    strip_prefix = "iron-iconset-svg-1.0.10",
+    strip_prefix = "iron-iconset-svg-1.1.0",
 )

 new_http_archive(
@@ -188,8 +188,8 @@ new_http_archive(
 new_http_archive(
     name = "iron_overlay_behavior",
     build_file = "bower.BUILD",
-    url = "https://github.com/polymerelements/iron-overlay-behavior/archive/v1.9.0.tar.gz",
+    url = "https://github.com/polymerelements/iron-overlay-behavior/archive/v1.10.1.tar.gz",
-    strip_prefix = "iron-overlay-behavior-1.9.0",
+    strip_prefix = "iron-overlay-behavior-1.10.1",
 )

 new_http_archive(
@@ -206,6 +206,13 @@ new_http_archive(
     strip_prefix = "iron-resizable-behavior-1.0.3",
 )

+new_http_archive(
+    name = "iron_scroll_target_behavior",
+    build_file = "bower.BUILD",
+    url = "https://github.com/polymerelements/iron-scroll-target-behavior/archive/v1.0.3.tar.gz",
+    strip_prefix = "iron-scroll-target-behavior-1.0.3",
+)
+
 new_http_archive(
     name = "iron_selector",
     build_file = "bower.BUILD",
@@ -291,8 +298,8 @@ new_http_archive(
 new_http_archive(
     name = "paper_icon_button",
     build_file = "bower.BUILD",
-    url = "https://github.com/polymerelements/paper-icon-button/archive/v1.1.2.tar.gz",
+    url = "https://github.com/polymerelements/paper-icon-button/archive/v1.1.3.tar.gz",
-    strip_prefix = "paper-icon-button-1.1.2",
+    strip_prefix = "paper-icon-button-1.1.3",
 )

 new_http_archive(
bower.BUILD
@@ -209,6 +209,7 @@ filegroup(
     name = "iron_overlay_behavior",
     srcs = [
         "index.html",
+        "iron-focusables-helper.html",
         "iron-overlay-backdrop.html",
         "iron-overlay-behavior.html",
         "iron-overlay-manager.html",
@@ -232,6 +233,14 @@ filegroup(
     ],
 )

+filegroup(
+    name = "iron_scroll_target_behavior",
+    srcs = [
+        "index.html",
+        "iron-scroll-target-behavior.html",
+    ],
+)
+
 filegroup(
     name = "iron_selector",
     srcs = [
eigen.BUILD
@@ -62,8 +62,6 @@ cc_library(
         # This define (mostly) guarantees we don't link any problematic
         # code. We use it, but we do not rely on it, as evidenced above.
        "EIGEN_MPL2_ONLY",
-        # TODO(jart): Use EIGEN_USE_NONBLOCKING_THREAD_POOL but first add an
-        # eigen_initialize.cc file and alwayslink=1.
     ],
     includes = ["."],
     visibility = ["//visibility:public"],
tensorflow/BUILD
@@ -105,6 +105,7 @@ filegroup(
         "//tensorflow/contrib/framework:all_files",
         "//tensorflow/contrib/graph_editor:all_files",
         "//tensorflow/contrib/grid_rnn:all_files",
+        "//tensorflow/contrib/integrate:all_files",
         "//tensorflow/contrib/layers:all_files",
         "//tensorflow/contrib/layers/kernels:all_files",
         "//tensorflow/contrib/learn:all_files",
@@ -148,7 +149,6 @@ filegroup(
         "//tensorflow/examples/image_retraining:all_files",
         "//tensorflow/examples/label_image:all_files",
         "//tensorflow/examples/learn:all_files",
-        "//tensorflow/examples/skflow:all_files",
         "//tensorflow/examples/tutorials/estimators:all_files",
         "//tensorflow/examples/tutorials/mnist:all_files",
         "//tensorflow/examples/tutorials/word2vec:all_files",
tensorflow/cc/BUILD
@@ -264,6 +264,36 @@ tf_cc_test(
     ],
 )

+cc_library(
+    name = "nn_grad",
+    srcs = ["gradients/nn_grad.cc"],
+    deps = [
+        ":cc_ops",
+        ":grad_op_registry",
+        ":ops",
+        ":scope",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+    ],
+)
+
+tf_cc_test(
+    name = "gradients_nn_grad_test",
+    srcs = ["gradients/nn_grad_test.cc"],
+    deps = [
+        ":cc_ops",
+        ":grad_op_registry",
+        ":grad_testutil",
+        ":gradient_checker",
+        ":nn_grad",
+        ":testutil",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 tf_gen_op_wrappers_cc(
     name = "cc_ops",
     op_lib_names = [
@@ -411,6 +441,7 @@ cc_library(
     srcs = ["training/queue_runner.cc"],
     hdrs = ["training/queue_runner.h"],
     deps = [
+        ":coordinator",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -425,6 +456,7 @@ tf_cc_test(
     name = "queue_runner_test",
     srcs = ["training/queue_runner_test.cc"],
     deps = [
+        "coordinator",
         ":cc_ops",
         ":queue_runner",
         ":scope",
@@ -439,3 +471,37 @@ tf_cc_test(
         "//tensorflow/core:testlib",
     ],
 )
+
+cc_library(
+    name = "coordinator",
+    srcs = ["training/coordinator.cc"],
+    hdrs = ["training/coordinator.h"],
+    deps = [
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:tensorflow",
+    ],
+)
+
+tf_cc_test(
+    name = "coordinator_test",
+    srcs = ["training/coordinator_test.cc"],
+    deps = [
+        ":cc_ops",
+        ":coordinator",
+        ":queue_runner",
+        ":scope",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:tensorflow",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
tensorflow/cc/framework/gradient_checker.cc
@@ -110,20 +110,15 @@ Status ComputeNumericJacobianTranspose(const Scope& scope, const ops::Output& x,
   return Status::OK();
 }

-}  // namespace
-
 template <typename T>
-Status ComputeGradientError(const Scope& scope, const ops::Output& x,
-                            const TensorShape& x_shape, const ops::Output& y,
-                            const TensorShape& y_shape, T* max_error) {
+Status ComputeGradientErrorInternal(const Scope& scope, const ops::Output& x,
+                                    const TensorShape& x_shape,
+                                    const ops::Output& y,
+                                    const TensorShape& y_shape, Tensor* x_data,
+                                    T* max_error) {
   const int64 x_size = x_shape.num_elements();
   const int64 y_size = y_shape.num_elements();
-
-  // Initialize 'x_data' to random values.
-  Tensor x_data(x.type(), x_shape);
-  auto x_data_flat = x_data.flat<T>();
-  x_data_flat.setRandom();
-
   // Initialize theoretical Jacobian to zeros.
   Tensor jacobian_t(x.type(), {x_size, y_size});
   auto jacobian_t_flat = jacobian_t.flat<T>();
@@ -131,7 +126,7 @@ Status ComputeGradientError(const Scope& scope, const ops::Output& x,

   // Compute theoretical Jacobian.
   TF_RETURN_IF_ERROR(ComputeTheoreticalJacobianTranspose<T>(
-      scope, x, x_shape, x_data, y, y_shape, &jacobian_t));
+      scope, x, x_shape, *x_data, y, y_shape, &jacobian_t));

   // Initialize numeric Jacobian to zeros.
   Tensor jacobian_n(x.type(), {x_size, y_size});
@@ -140,7 +135,7 @@ Status ComputeGradientError(const Scope& scope, const ops::Output& x,

   // Compute numeric Jacobian.
   TF_RETURN_IF_ERROR(ComputeNumericJacobianTranspose<T>(
-      scope, x, x_shape, y, y_shape, 1e-3, &x_data, &jacobian_n));
+      scope, x, x_shape, y, y_shape, 1e-3, x_data, &jacobian_n));

   // Compute the maximum error between theoretical and numeric Jacobians.
   *max_error = 0.0;
@@ -154,10 +149,39 @@ Status ComputeGradientError(const Scope& scope, const ops::Output& x,
   return Status::OK();
 }

+}  // namespace
+
+template <typename T>
+Status ComputeGradientError(const Scope& scope, const ops::Output& x,
+                            const TensorShape& x_shape, const ops::Output& y,
+                            const TensorShape& y_shape, T* max_error) {
+  // Initialize 'x_data' to random values.
+  Tensor x_data(x.type(), x_shape);
+  auto x_data_flat = x_data.flat<T>();
+  x_data_flat.setRandom();
+  // Compute gradient error.
+  return ComputeGradientErrorInternal(scope, x, x_shape, y, y_shape, &x_data,
+                                      max_error);
+}
+
+template <typename T>
+Status ComputeGradientError(const Scope& scope, const ops::Output& x,
+                            const Tensor& x_init_value, const ops::Output& y,
+                            const TensorShape& y_shape, T* max_error) {
+  // Initialize 'x_data' from 'x_init_value'.
+  Tensor x_data(x_init_value);
+  // Compute gradient error.
+  return ComputeGradientErrorInternal(scope, x, x_data.shape(), y, y_shape,
+                                      &x_data, max_error);
+}
+
 #define INSTANTIATE_GRAD_ERR_TYPE(T)                                         \
   template Status ComputeGradientError<T>(                                   \
       const Scope& scope, const ops::Output& x, const TensorShape& x_shape,  \
-      const ops::Output& y, const TensorShape& y_shape, T* max_error)
+      const ops::Output& y, const TensorShape& y_shape, T* max_error);       \
+  template Status ComputeGradientError<T>(                                   \
+      const Scope& scope, const ops::Output& x, const Tensor& x_init_value,  \
+      const ops::Output& y, const TensorShape& y_shape, T* max_error);

 INSTANTIATE_GRAD_ERR_TYPE(float);
 INSTANTIATE_GRAD_ERR_TYPE(double);
tensorflow/cc/framework/gradient_checker.h
@@ -30,6 +30,12 @@ Status ComputeGradientError(const Scope& scope, const ops::Output& x,
                             const TensorShape& x_shape, const ops::Output& y,
                             const TensorShape& y_shape, T* max_error);

+// Overload of ComputeGradientError which takes an initial value for 'x'.
+template <typename T>
+Status ComputeGradientError(const Scope& scope, const ops::Output& x,
+                            const Tensor& x_init_value, const ops::Output& y,
+                            const TensorShape& y_shape, T* max_error);
+
 }  // namespace tensorflow

 #endif  // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_
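The point of the overload: the original ComputeGradientError always draws x at random, which can place samples where an op's gradient is ill-defined. A minimal sketch of calling the new overload with pinned inputs (not part of the commit; the shape and values are illustrative, and the names mirror the nn_grad_test.cc added below):

// Sketch: gradient-check Relu away from x == 0, where its gradient is
// not well defined.
Scope scope = Scope::NewRootScope();
auto x = ops::Placeholder(scope, DT_FLOAT, ops::Placeholder::Shape({2, 2}));
auto y = ops::Relu(scope, x);
// Pin the input values so none of them falls near zero.
Tensor x_init_value =
    test::AsTensor<float>({-0.5f, -0.25f, 0.25f, 0.5f}, {2, 2});
float max_error;
TF_CHECK_OK(ComputeGradientError(scope, x, x_init_value, y,
                                 TensorShape({2, 2}), &max_error));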
tensorflow/cc/gradients/nn_grad.cc (new file, 77 lines)
@@ -0,0 +1,77 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+
+#include "tensorflow/cc/framework/grad_op_registry.h"
+
+namespace tensorflow {
+namespace ops {
+namespace {
+
+Status SoftmaxGrad(const Scope& scope, const Operation& op,
+                   const std::vector<Output>& grad_inputs,
+                   std::vector<Output>* grad_outputs) {
+  // Softmax gradient function.
+  // p = softmax(x) maps from [batch, n] to [batch, m]
+  // dp/dx = [dp0/dx0   ... dp0/dxn-1  ]
+  //         [  ...           ...      ]
+  //         [dpm-1/dx0 ... dpm-1/dxn-1]
+  // dL/dx = dp/dx * dL/dy
+  //
+  // Using alternative formula:
+  // dL/dx = dL/dy * y - sum(dL/dy * y) * y
+  //       = (dL/dy - sum(dL/dy * y)) * y
+  auto y = op.output(0);
+  auto dyy = Mul(scope, grad_inputs[0], y);
+  auto sum = Reshape(scope, Sum(scope, dyy, {1}), {-1, 1});
+  auto sub = Sub(scope, grad_inputs[0], sum);
+  auto dx = Mul(scope, sub, y);
+  grad_outputs->push_back(dx);
+  return scope.status();
+}
+REGISTER_GRADIENT_OP("Softmax", SoftmaxGrad);
+
+Status ReluGradHelper(const Scope& scope, const Operation& op,
+                      const std::vector<Output>& grad_inputs,
+                      std::vector<Output>* grad_outputs) {
+  auto dx = ReluGrad(scope, grad_inputs[0], op.input(0));
+  grad_outputs->push_back(dx);
+  return scope.status();
+}
+REGISTER_GRADIENT_OP("Relu", ReluGradHelper);
+
+Status Relu6GradHelper(const Scope& scope, const Operation& op,
+                       const std::vector<Output>& grad_inputs,
+                       std::vector<Output>* grad_outputs) {
+  auto dx = Relu6Grad(scope, grad_inputs[0], op.input(0));
+  grad_outputs->push_back(dx);
+  return scope.status();
+}
+REGISTER_GRADIENT_OP("Relu6", Relu6GradHelper);
+
+Status EluGradHelper(const Scope& scope, const Operation& op,
+                     const std::vector<Output>& grad_inputs,
+                     std::vector<Output>* grad_outputs) {
+  auto dx = EluGrad(scope, grad_inputs[0], op.output(0));
+  grad_outputs->push_back(dx);
+  return scope.status();
+}
+REGISTER_GRADIENT_OP("Elu", EluGradHelper);
+
+}  // anonymous namespace
+}  // namespace ops
+}  // namespace tensorflow
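For reference, the "alternative formula" in the SoftmaxGrad comment follows from the softmax Jacobian dy_i/dx_j = y_i(delta_ij - y_j); a short derivation (standard calculus, not part of the commit):

\frac{\partial L}{\partial x_j}
  = \sum_i \frac{\partial L}{\partial y_i}\,\frac{\partial y_i}{\partial x_j}
  = \sum_i \frac{\partial L}{\partial y_i}\, y_i\,(\delta_{ij} - y_j)
  = \frac{\partial L}{\partial y_j}\, y_j - y_j \sum_i \frac{\partial L}{\partial y_i}\, y_i
  = \Bigl(\frac{\partial L}{\partial y_j} - \sum_i \frac{\partial L}{\partial y_i}\, y_i\Bigr)\, y_j

which is exactly (dL/dy - sum(dL/dy * y)) * y. In the code, Sum(scope, dyy, {1}) computes the inner sum per batch row, and the Reshape to {-1, 1} lets it broadcast back across the class dimension.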
tensorflow/cc/gradients/nn_grad_test.cc (new file, 91 lines)
@@ -0,0 +1,91 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/framework/grad_op_registry.h"
+#include "tensorflow/cc/framework/gradient_checker.h"
+#include "tensorflow/cc/framework/testutil.h"
+#include "tensorflow/cc/gradients/grad_testutil.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/random/random.h"
+
+namespace tensorflow {
+using namespace ops;  // NOLINT(build/namespaces)
+
+namespace {
+
+class NNGradTest : public ::testing::Test {
+ protected:
+  NNGradTest() : scope_(Scope::NewRootScope()) {}
+
+  void RunTest(const Output& x, const TensorShape& x_shape, const Output& y,
+               const TensorShape& y_shape) {
+    float max_error;
+    TF_ASSERT_OK(
+        ComputeGradientError(scope_, x, x_shape, y, y_shape, &max_error));
+    EXPECT_LT(max_error, 1e-4);
+  }
+
+  void RunTest(const Output& x, const Tensor& x_init_value, const Output& y,
+               const TensorShape& y_shape) {
+    float max_error;
+    TF_ASSERT_OK(
+        ComputeGradientError(scope_, x, x_init_value, y, y_shape, &max_error));
+    EXPECT_LT(max_error, 1e-4);
+  }
+
+  Scope scope_;
+};
+
+TEST_F(NNGradTest, SoftmaxGrad) {
+  TensorShape shape({32, 10});
+  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
+  auto y = Softmax(scope_, x);
+  RunTest(x, shape, y, shape);
+}
+
+TEST_F(NNGradTest, ReluGrad) {
+  TensorShape shape({5, 2});
+  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
+  auto y = Relu(scope_, x);
+  // Avoid input values where ReLU gradient is not well defined (around zero).
+  Tensor x_init_value = test::AsTensor<float>(
+      {-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9}, {5, 2});
+  RunTest(x, x_init_value, y, shape);
+}
+
+TEST_F(NNGradTest, Relu6Grad) {
+  TensorShape shape({5, 2});
+  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
+  auto y = Relu6(scope_, x);
+  // Avoid input values where ReLU gradient is not well defined (around zero
+  // and six).
+  Tensor x_init_value = test::AsTensor<float>(
+      {-0.9, -0.7, -0.5, -0.3, -0.1, 6.1, 6.3, 6.5, 6.7, 6.9}, {5, 2});
+  RunTest(x, x_init_value, y, shape);
+}
+
+TEST_F(NNGradTest, EluGrad) {
+  TensorShape shape({5, 2});
+  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
+  auto y = Elu(scope_, x);
+  Tensor x_init_value = test::AsTensor<float>(
+      {-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9}, {5, 2});
+  RunTest(x, x_init_value, y, shape);
+}
+
+}  // namespace
+}  // namespace tensorflow
tensorflow/cc/training/coordinator.cc (new file, 90 lines)
@@ -0,0 +1,90 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/training/coordinator.h"
+
+namespace tensorflow {
+
+Coordinator::Coordinator() : Coordinator(std::vector<error::Code>()) {}
+
+Coordinator::Coordinator(const std::vector<error::Code>& clean_stop_errors)
+    : should_stop_(false) {
+  if (clean_stop_errors.empty()) {
+    clean_stop_errors_.insert(error::OUT_OF_RANGE);
+  } else {
+    for (const auto& code : clean_stop_errors) {
+      clean_stop_errors_.insert(static_cast<int>(code));
+    }
+  }
+}
+
+Coordinator::~Coordinator() {
+  RequestStop();
+  Join();
+}
+
+Status Coordinator::RegisterRunner(std::unique_ptr<RunnerInterface> runner) {
+  runners_.push_back(std::move(runner));
+  return Status::OK();
+}
+
+Status Coordinator::RequestStop() {
+  mutex_lock l(mu_);
+  if (should_stop_) {
+    return Status(error::FAILED_PRECONDITION,
+                  "The Coordinator is not running.");
+  }
+  should_stop_ = true;
+  wait_for_stop_.notify_all();
+  return Status::OK();
+}
+
+bool Coordinator::ShouldStop() {
+  mutex_lock l(mu_);
+  return should_stop_;
+}
+
+Status Coordinator::Join() {
+  // TODO(yuefengz): deal with unexpected calls to Join().
+  // TODO(yuefengz): deal with stragglers.
+  for (const auto& t : runners_) {
+    ReportStatus(t->Join());
+  }
+  runners_.clear();
+  return status_;
+}
+
+void Coordinator::ReportStatus(const Status& status) {
+  mutex_lock l(status_lock_);
+  if (status.ok() || !status_.ok() ||
+      clean_stop_errors_.count(static_cast<int>(status.code())) > 0) {
+    return;
+  }
+  status_ = status;
+}
+
+Status Coordinator::GetStatus() {
+  mutex_lock l(status_lock_);
+  return status_;
+}
+
+void Coordinator::WaitForStop() {
+  mutex_lock l(mu_);
+  while (!should_stop_) {
+    wait_for_stop_.wait(l);
+  }
+}
+
+}  // namespace
tensorflow/cc/training/coordinator.h (new file, 109 lines)
@@ -0,0 +1,109 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CC_TRAINING_COORDINATOR_H_
+#define THIRD_PARTY_TENSORFLOW_CC_TRAINING_COORDINATOR_H_
+
+#include <memory>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+// The abstract interface for runners which must implement the Join function.
+class RunnerInterface {
+ public:
+  virtual ~RunnerInterface() {}
+  virtual Status Join() = 0;
+};
+
+// Coordinator class manages the termination of a collection of QueueRunners.
+// Without a coordinator, QueueRunners have to be joined in a specific order;
+// otherwise the QueueRunner::Join() could sometimes hang. The
+// Coordinator::RequestStop() plays the key role which notifies all running
+// threads under a coordinator to stop. This function could be called by any
+// thread or any client.
+// Usage, in the client:
+//   Coordinator coord;
+//   std::unique_ptr<QueueRunner> qr(&coord, ...);
+//   qr.Start(session);
+//   coord.RegisterRunner(std::move(qr));
+//   // do some work
+//   TF_CHECK_OK(coord.Join());
+// In each thread of QueueRunner, the coordinator needs to be used as:
+//   void Run() {
+//     while (!coord->ShouldStop()) {
+//       // do some work
+//       if (error) {
+//         coord->RequestStop();
+//         coord->ReportStatus(error_status);
+//       }
+//     }
+//   }
+class Coordinator {
+ public:
+  Coordinator();
+
+  // Constructor with a list of error codes which would not be taken as errors
+  // in status reporting.
+  Coordinator(const std::vector<error::Code>& clean_stop_errors);
+
+  // In the destructor, RequestStop() and Join() would be called.
+  ~Coordinator();
+
+  // Registers a runner, i.e. a unit of running threads which is usually a
+  // QueueRunner. It takes the ownership of runner to avoid lifecycle-related
+  // problems. Note, the coordinator would not start these threads; they are
+  // supposed to be in running state when they are registered here.
+  Status RegisterRunner(std::unique_ptr<RunnerInterface> runner);
+
+  // Requests all running threads to stop.
+  Status RequestStop();
+
+  // Returns true if its RequestStop() has been called.
+  bool ShouldStop();
+
+  // Joins all threads, returns OK or the first reported and unexpected status.
+  Status Join();
+
+  // Reports status to the coordinator. This is usually called by threads.
+  void ReportStatus(const Status& status);
+
+  // Returns the latest status.
+  Status GetStatus();
+
+  // Returns immediately if the coordinator is stopped or blocks until
+  // RequestStop() is called.
+  void WaitForStop();
+
+ private:
+  std::vector<std::unique_ptr<RunnerInterface>> runners_;
+  std::unordered_set<int> clean_stop_errors_;
+  mutex mu_;
+  bool should_stop_ GUARDED_BY(mu_);
+  mutex status_lock_;
+  Status status_;
+  condition_variable wait_for_stop_;
+  TF_DISALLOW_COPY_AND_ASSIGN(Coordinator);
+};
+
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CC_TRAINING_COORDINATOR_H_
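Tying the pieces together, a hedged sketch of the client flow the header comment describes, using the coordinator-aware QueueRunner::New added later in this commit (graph and session construction are elided; queue_runner_def and session are assumed to exist):

// Sketch, not part of the commit; error handling trimmed.
Coordinator coord;
std::unique_ptr<QueueRunner> qr;
TF_CHECK_OK(QueueRunner::New(queue_runner_def, &coord, &qr));
TF_CHECK_OK(qr->Start(session));                   // threads start feeding the queues
TF_CHECK_OK(coord.RegisterRunner(std::move(qr)));  // coordinator takes ownership
// ... training loop; any thread may call coord.RequestStop() on error ...
coord.RequestStop();
TF_CHECK_OK(coord.Join());  // joins all runners, returns first unexpected status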
tensorflow/cc/training/coordinator_test.cc (new file, 183 lines)
@@ -0,0 +1,183 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/training/coordinator.h"
+
+#include "tensorflow/cc/training/queue_runner.h"
+#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace {
+
+using error::Code;
+
+void WaitForStopThread(Coordinator* coord, bool* stopped, Notification* done) {
+  coord->WaitForStop();
+  *stopped = true;
+  done->Notify();
+}
+
+TEST(CoordinatorTest, TestStopAndWaitOnStop) {
+  Coordinator coord;
+  EXPECT_EQ(coord.ShouldStop(), false);
+
+  bool stopped = false;
+  Notification done;
+  Env::Default()->SchedClosure(
+      std::bind(&WaitForStopThread, &coord, &stopped, &done));
+  Env::Default()->SleepForMicroseconds(10000000);
+  EXPECT_EQ(stopped, false);
+
+  coord.RequestStop();
+  done.WaitForNotification();
+  EXPECT_EQ(stopped, true);
+  EXPECT_EQ(coord.ShouldStop(), true);
+}
+
+class MockQueueRunner : public RunnerInterface {
+ public:
+  MockQueueRunner(Coordinator* coord) {
+    coord_ = coord;
+    join_counter_ = nullptr;
+    thread_pool_.reset(new thread::ThreadPool(Env::Default(), "test-pool", 10));
+  }
+
+  MockQueueRunner(Coordinator* coord, int* join_counter)
+      : MockQueueRunner(coord) {
+    join_counter_ = join_counter;
+  }
+
+  void StartCounting(std::atomic<int>* counter, int until) {
+    thread_pool_->Schedule(
+        std::bind(&MockQueueRunner::CountThread, this, counter, until));
+  }
+
+  void StartSettingStatus(const Status& status, BlockingCounter* counter) {
+    thread_pool_->Schedule(
+        std::bind(&MockQueueRunner::SetStatusThread, this, status, counter));
+  }
+
+  Status Join() {
+    if (join_counter_ != nullptr) {
+      (*join_counter_)++;
+    }
+    thread_pool_.reset();
+    return status_;
+  }
+
+  Status GetStatus() { return status_; }
+
+  void SetStatus(const Status& status) { status_ = status; }
+
+ private:
+  void CountThread(std::atomic<int>* counter, int until) {
+    while (!coord_->ShouldStop() && counter->load() < until) {
+      (*counter)++;
+      Env::Default()->SleepForMicroseconds(100000);
+    }
+    coord_->RequestStop();
+  }
+  void SetStatusThread(const Status& status, BlockingCounter* counter) {
+    Env::Default()->SleepForMicroseconds(100000);
+    SetStatus(status);
+    counter->DecrementCount();
+  }
+  std::unique_ptr<thread::ThreadPool> thread_pool_;
+  Status status_;
+  Coordinator* coord_;
+  int* join_counter_;
+};
+
+TEST(CoordinatorTest, TestRealStop) {
+  std::atomic<int> counter(0);
+  Coordinator coord;
+
+  std::unique_ptr<MockQueueRunner> qr1(new MockQueueRunner(&coord));
+  qr1->StartCounting(&counter, 100);
+  coord.RegisterRunner(std::move(qr1));
+
+  std::unique_ptr<MockQueueRunner> qr2(new MockQueueRunner(&coord));
+  qr2->StartCounting(&counter, 100);
+  coord.RegisterRunner(std::move(qr2));
+
+  // Wait until the counting has started
+  while (counter.load() == 0)
+    ;
+  coord.RequestStop();
+
+  int temp_counter = counter.load();
+  Env::Default()->SleepForMicroseconds(10000000);
+  EXPECT_EQ(temp_counter, counter.load());
+  TF_EXPECT_OK(coord.Join());
+}
+
+TEST(CoordinatorTest, TestRequestStop) {
+  Coordinator coord;
+  std::atomic<int> counter(0);
+  std::unique_ptr<MockQueueRunner> qr;
+  for (int i = 0; i < 10; i++) {
+    qr.reset(new MockQueueRunner(&coord));
+    qr->StartCounting(&counter, 10);
+    coord.RegisterRunner(std::move(qr));
+  }
+
+  coord.WaitForStop();
+  EXPECT_EQ(coord.ShouldStop(), true);
+  EXPECT_EQ(counter.load(), 10);
+  TF_EXPECT_OK(coord.Join());
+}
+
+TEST(CoordinatorTest, TestJoin) {
+  Coordinator coord;
+  int join_counter = 0;
+  std::unique_ptr<MockQueueRunner> qr1(
+      new MockQueueRunner(&coord, &join_counter));
+  coord.RegisterRunner(std::move(qr1));
+  std::unique_ptr<MockQueueRunner> qr2(
+      new MockQueueRunner(&coord, &join_counter));
+  coord.RegisterRunner(std::move(qr2));
+
+  TF_EXPECT_OK(coord.Join());
+  EXPECT_EQ(join_counter, 2);
+}
+
+TEST(CoordinatorTest, StatusReporting) {
+  Coordinator coord({Code::CANCELLED, Code::OUT_OF_RANGE});
+  BlockingCounter counter(3);
+
+  std::unique_ptr<MockQueueRunner> qr1(new MockQueueRunner(&coord));
+  qr1->StartSettingStatus(Status(Code::CANCELLED, ""), &counter);
+  coord.RegisterRunner(std::move(qr1));
+
+  std::unique_ptr<MockQueueRunner> qr2(new MockQueueRunner(&coord));
+  qr2->StartSettingStatus(Status(Code::INVALID_ARGUMENT, ""), &counter);
+  coord.RegisterRunner(std::move(qr2));
+
+  std::unique_ptr<MockQueueRunner> qr3(new MockQueueRunner(&coord));
+  qr3->StartSettingStatus(Status(Code::OUT_OF_RANGE, ""), &counter);
+  coord.RegisterRunner(std::move(qr3));
+
+  counter.Wait();
+  EXPECT_EQ(coord.Join().code(), Code::INVALID_ARGUMENT);
+}
+
+}  // namespace
+}  // namespace tensorflow
|
@ -25,6 +25,14 @@ Status QueueRunner::New(const QueueRunnerDef& queue_runner_def,
|
|||||||
return (*result)->Init(queue_runner_def);
|
return (*result)->Init(queue_runner_def);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status QueueRunner::New(const QueueRunnerDef& queue_runner_def,
|
||||||
|
Coordinator* coord,
|
||||||
|
std::unique_ptr<QueueRunner>* result) {
|
||||||
|
result->reset(new QueueRunner());
|
||||||
|
(*result)->coord_ = coord;
|
||||||
|
return (*result)->Init(queue_runner_def);
|
||||||
|
}
|
||||||
|
|
||||||
Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) {
|
Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) {
|
||||||
queue_name_ = queue_runner_def.queue_name();
|
queue_name_ = queue_runner_def.queue_name();
|
||||||
enqueue_op_names_.clear();
|
enqueue_op_names_.clear();
|
||||||
@ -46,8 +54,8 @@ Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
thread_pool_.reset(new thread::ThreadPool(
|
thread_pool_.reset(new thread::ThreadPool(
|
||||||
Env::Default(), SanitizeThreadSuffix(queue_name_), runs_));
|
Env::Default(), SanitizeThreadSuffix(queue_name_), runs_ + 1));
|
||||||
should_stop_ = false;
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -57,63 +65,108 @@ QueueRunner::~QueueRunner() {
|
|||||||
Join();
|
Join();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status QueueRunner::Start(Session* sess) {
|
Status QueueRunner::Start(Session* sess) { return Start(sess, 0); }
|
||||||
|
|
||||||
|
Status QueueRunner::Start(Session* sess, int wait_for) {
|
||||||
|
counter_.reset(new BlockingCounter(runs_));
|
||||||
for (const string& enqueue_op : enqueue_op_names_) {
|
for (const string& enqueue_op : enqueue_op_names_) {
|
||||||
thread_pool_->Schedule(
|
thread_pool_->Schedule(
|
||||||
std::bind(&QueueRunner::Run, this, sess, enqueue_op));
|
std::bind(&QueueRunner::Run, this, sess, enqueue_op));
|
||||||
}
|
}
|
||||||
|
if (coord_) {
|
||||||
|
thread_pool_->Schedule(std::bind(&QueueRunner::Stop, this, sess));
|
||||||
|
}
|
||||||
|
// Wait for up to 'wait_for' milliseconds.
|
||||||
|
if (wait_for > 0) {
|
||||||
|
if (!counter_->WaitFor(std::chrono::milliseconds(wait_for))) {
|
||||||
|
return Status(error::DEADLINE_EXCEEDED,
|
||||||
|
"Queues not fed before the timeout");
|
||||||
|
}
|
||||||
|
// Check the status of the queue runner as well as the result of the enqueue
|
||||||
|
// operations.
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
if (!enqueue_status_.ok()) {
|
||||||
|
return enqueue_status_;
|
||||||
|
} else {
|
||||||
|
return status_;
|
||||||
|
}
|
||||||
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status QueueRunner::Stop(Session* sess) {
|
void QueueRunner::Stop(Session* sess) {
|
||||||
should_stop_ = true;
|
|
||||||
if (cancel_op_name_.empty()) {
|
if (cancel_op_name_.empty()) {
|
||||||
return Status::OK();
|
return;
|
||||||
} else {
|
} else {
|
||||||
return sess->Run({}, {}, {cancel_op_name_}, nullptr);
|
CHECK(coord_ != nullptr);
|
||||||
|
coord_->WaitForStop();
|
||||||
|
UpdateStatus(sess->Run({}, {}, {cancel_op_name_}, nullptr));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Status QueueRunner::Join() {
|
Status QueueRunner::Join() {
|
||||||
thread_pool_.reset();
|
thread_pool_.reset();
|
||||||
|
mutex_lock l(mu_);
|
||||||
return status_;
|
return status_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void QueueRunner::UpdateStatus(const Status& status) {
|
||||||
|
{
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
if (!status_.ok() || status.ok() ||
|
||||||
|
queue_closed_exception_types_.count(static_cast<int>(status.code())) >
|
||||||
|
0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
status_ = status;
|
||||||
|
}
|
||||||
|
if (coord_) {
|
||||||
|
coord_->ReportStatus(status);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void QueueRunner::Run(Session* sess, const string& enqueue_op) {
|
void QueueRunner::Run(Session* sess, const string& enqueue_op) {
|
||||||
bool decremented = false;
|
bool decremented = false;
|
||||||
while (!should_stop_.load()) {
|
bool first_iteration = true;
|
||||||
|
while (true) {
|
||||||
|
if (coord_ && coord_->ShouldStop()) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
auto status = sess->Run({}, {}, {enqueue_op}, nullptr);
|
auto status = sess->Run({}, {}, {enqueue_op}, nullptr);
|
||||||
|
if (first_iteration) {
|
||||||
|
if (!status.ok()) {
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
enqueue_status_ = status;
|
||||||
|
}
|
||||||
|
counter_->DecrementCount();
|
||||||
|
first_iteration = false;
|
||||||
|
}
|
||||||
if (status.ok()) {
|
if (status.ok()) {
|
||||||
continue;
|
continue;
|
||||||
} else if (queue_closed_exception_types_.count(
|
} else if (queue_closed_exception_types_.count(
|
||||||
static_cast<int>(status.code())) > 0) {
|
static_cast<int>(status.code())) > 0) {
|
||||||
|
{
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
runs_--;
|
runs_--;
|
||||||
decremented = true;
|
decremented = true;
|
||||||
should_stop_ = true;
|
}
|
||||||
|
|
||||||
// If all enqueue ops have finished, run the close op.
|
// If all enqueue ops have finished, run the close op.
|
||||||
if (runs_ == 0 && !close_op_name_.empty()) {
|
if (runs_ == 0) {
|
||||||
|
if (!close_op_name_.empty()) {
|
||||||
auto s = sess->Run({}, {}, {close_op_name_}, nullptr);
|
auto s = sess->Run({}, {}, {close_op_name_}, nullptr);
|
||||||
if (!s.ok() && status_.ok() &&
|
UpdateStatus(status);
|
||||||
queue_closed_exception_types_.count(static_cast<int>(s.code())) ==
|
|
||||||
0) {
|
|
||||||
status_ = s;
|
|
||||||
}
|
}
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
{
|
UpdateStatus(status);
|
||||||
mutex_lock l(mu_);
|
if (coord_) {
|
||||||
should_stop_ = true;
|
coord_->RequestStop();
|
||||||
// Only record the first failure status.
|
|
||||||
if (status_.ok()) {
|
|
||||||
status_ = status;
|
|
||||||
}
|
}
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
// Stop the queue runner immediately to propagate the error to
|
first_iteration = false;
|
||||||
// subsequent queues.
|
|
||||||
Stop(sess);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!decremented) {
|
if (!decremented) {
|
||||||
|
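A note on the new Start(Session*, int wait_for) overload: counter_ is a BlockingCounter sized to runs_, and each enqueue thread decrements it after its first Run() call, so Start can block until every queue has been fed at least once or fail with DEADLINE_EXCEEDED. A hedged usage sketch (the 2000 ms timeout and the qr/session/coord names are illustrative, carried over from the earlier sketch):

// Sketch, not part of the commit: give the queues up to two seconds to
// start filling before treating the runner as stuck.
Status s = qr->Start(session, 2000);
if (!s.ok()) {
  // DEADLINE_EXCEEDED means no enqueue op completed a first run in time;
  // any other error is the first enqueue or runner status.
  coord.RequestStop();
}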
@ -21,6 +21,8 @@ limitations under the License.
|
|||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/cc/training/coordinator.h"
|
||||||
|
#include "tensorflow/core/lib/core/blocking_counter.h"
|
||||||
#include "tensorflow/core/lib/core/error_codes.pb.h"
|
#include "tensorflow/core/lib/core/error_codes.pb.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/core/threadpool.h"
|
#include "tensorflow/core/lib/core/threadpool.h"
|
||||||
@ -32,7 +34,7 @@ namespace tensorflow {
|
|||||||
|
|
||||||
// QueueRunner class imitates the behavior of the python version of QueueRunner
|
// QueueRunner class imitates the behavior of the python version of QueueRunner
|
||||||
// which creates a thread for each enqueue op, runs close op on completion.
|
// which creates a thread for each enqueue op, runs close op on completion.
|
||||||
class QueueRunner {
|
class QueueRunner : public RunnerInterface {
|
||||||
public:
|
public:
|
||||||
// Creates a new QueueRunner from proto.
|
// Creates a new QueueRunner from proto.
|
||||||
// TODO(yuefengz): we may want to initialize from queues and ops in the
|
// TODO(yuefengz): we may want to initialize from queues and ops in the
|
||||||
@ -40,24 +42,29 @@ class QueueRunner {
|
|||||||
static Status New(const QueueRunnerDef& queue_runner_def,
|
static Status New(const QueueRunnerDef& queue_runner_def,
|
||||||
std::unique_ptr<QueueRunner>* result);
|
std::unique_ptr<QueueRunner>* result);
|
||||||
|
|
||||||
|
// Creates a new QueueRunner with a coordinator, see coordinator.h for usage.
|
||||||
|
static Status New(const QueueRunnerDef& queue_runner_def, Coordinator* coord,
|
||||||
|
std::unique_ptr<QueueRunner>* result);
|
||||||
|
|
||||||
// The destructor would join all the threads.
|
// The destructor would join all the threads.
|
||||||
~QueueRunner();
|
~QueueRunner();
|
||||||
|
|
||||||
// Starts the queue runner with the given session.
|
// Starts the queue runner with the given session.
|
||||||
Status Start(Session* sess);
|
Status Start(Session* sess);
|
||||||
|
|
||||||
// Requests to stop and runs the cancel op.
|
// Starts the queue runner with the given session, and wait for up to the
|
||||||
Status Stop(Session* sess);
|
// specified time (in milliseconds) for the queues to start to fill up.
|
||||||
|
Status Start(Session* sess, int wait_for);
|
||||||
|
|
||||||
// Joins all the threads. Returns okay if all threads run successfully;
|
// Joins all the threads. Returns okay if all threads run successfully;
|
||||||
// otherwise returns the first captured failure status.
|
// otherwise returns the first captured failure status.
|
||||||
Status Join();
|
Status Join() final;
|
||||||
|
|
||||||
// Returns the lastest status.
|
// Returns the lastest status.
|
||||||
Status GetStatus();
|
Status GetStatus();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
QueueRunner() {}
|
QueueRunner() : coord_(nullptr) {}
|
||||||
|
|
||||||
// Initializes the instance with the QueueRunnerDef proto.
|
// Initializes the instance with the QueueRunnerDef proto.
|
||||||
Status Init(const QueueRunnerDef& queue_runner_def);
|
Status Init(const QueueRunnerDef& queue_runner_def);
|
||||||
@ -65,6 +72,14 @@ class QueueRunner {
|
|||||||
// The Run function for each thread.
|
// The Run function for each thread.
|
||||||
void Run(Session* sess, const string& enqueue_op);
|
void Run(Session* sess, const string& enqueue_op);
|
||||||
|
|
||||||
|
// Requests to stop and runs the cancel op. It would be called in a separate
|
||||||
|
// thread when coordinator is set.
|
||||||
|
void Stop(Session* sess);
|
||||||
|
|
||||||
|
// Updates the internal status; it only keeps OK or the first unexpected error
|
||||||
|
// status.
|
||||||
|
void UpdateStatus(const Status& status);
|
||||||
|
|
||||||
string queue_name_;
|
string queue_name_;
|
||||||
std::vector<string> enqueue_op_names_;
|
std::vector<string> enqueue_op_names_;
|
||||||
string close_op_name_;
|
string close_op_name_;
|
||||||
@ -73,12 +88,15 @@ class QueueRunner {
|
|||||||
std::unordered_set<int> queue_closed_exception_types_;
|
std::unordered_set<int> queue_closed_exception_types_;
|
||||||
|
|
||||||
std::unique_ptr<thread::ThreadPool> thread_pool_;
|
std::unique_ptr<thread::ThreadPool> thread_pool_;
|
||||||
std::atomic<bool> should_stop_;
|
|
||||||
condition_variable wait_to_close_;
|
condition_variable wait_to_close_;
|
||||||
mutex mu_;
|
mutex mu_;
|
||||||
// TODO(yuefengz): implement c++ coordinator.
|
// TODO(yuefengz): implement c++ coordinator.
|
||||||
int runs_ = 0;
|
int runs_ = 0;
|
||||||
Status status_;
|
Status status_ GUARDED_BY(mu_);
|
||||||
|
Status enqueue_status_ GUARDED_BY(mu_);
|
||||||
|
std::unique_ptr<BlockingCounter> counter_;
|
||||||
|
|
||||||
|
Coordinator* coord_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
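The header comment above notes that this C++ `QueueRunner` imitates the Python version. For orientation, here is a minimal sketch of the Python coordinator pattern being mirrored, using the `tf.train` API of this release; the toy queue and constant are illustrative only:

```python
import tensorflow as tf

# A toy FIFO queue fed by one enqueue op.
queue = tf.FIFOQueue(capacity=3, dtypes=[tf.int32])
enqueue_op = queue.enqueue([tf.constant(1)])

# The QueueRunner owns the enqueue ops; the Coordinator fields stop
# requests, mirroring the Coordinator* plumbing added above.
qr = tf.train.QueueRunner(queue, [enqueue_op] * 2)
coord = tf.train.Coordinator()

with tf.Session() as sess:
  threads = qr.create_threads(sess, coord=coord, start=True)
  print(sess.run(queue.dequeue()))
  coord.request_stop()  # analogous to coord_->RequestStop()
  coord.join(threads)   # analogous to Coordinator::Join()
```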
@@ -20,6 +20,7 @@ limitations under the License.
 
 #include "tensorflow/cc/framework/scope.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/cc/training/coordinator.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
@@ -111,7 +112,7 @@ TEST(QueueRunnerTest, BasicTest) {
   auto session = BuildSessionAndInitVariable(graph_def);
 
   QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
-      kQueueName, {kCountUpToOpName, kCountUpToOpName}, kSquareOpName, "", {});
+      kQueueName, {kCountUpToOpName}, kSquareOpName, "", {});
 
   std::unique_ptr<QueueRunner> qr;
   TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
@@ -164,7 +165,8 @@ GraphDef BuildDoubleQueueGraph() {
   auto close0 = QueueClose(root.WithOpName(kCloseOp0), q0);
   auto cancel0 = QueueClose(root.WithOpName(kCancelOp0), q0,
                             QueueClose::CancelPendingEnqueues(true));
-  auto q1 = FIFOQueue(root.WithOpName(kQueueName1), {DataType::DT_INT32});
+  auto q1 = FIFOQueue(root.WithOpName(kQueueName1), {DataType::DT_INT32},
+                      FIFOQueue::Capacity(3));
   auto dequeue0 =
       QueueDequeue(root.WithOpName(kDequeueOp0), q0, {DataType::DT_INT32});
   auto enqueue1 = QueueEnqueue(root.WithOpName(kEnqueueOp1), q1, {dequeue0[0]});
@@ -252,34 +254,34 @@ TEST(QueueRunnerTest, SessionCloseCancelPendingEnqueue) {
   EXPECT_EQ(join_succeeded, true);
 }
 
-TEST(QueueRunnerTest, Stop) {
-  auto graph_def = BuildDoubleQueueGraph();
+TEST(QueueRunnerTest, EmptyEnqueueOps) {
+  QueueRunnerDef queue_runner_def =
+      BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, "", {});
+
+  std::unique_ptr<QueueRunner> qr;
+  EXPECT_EQ(QueueRunner::New(queue_runner_def, &qr).code(),
+            Code::INVALID_ARGUMENT);
+}
+
+TEST(QueueRunnerTest, StartTimeout) {
+  GraphDef graph_def = BuildDoubleQueueGraph();
   SessionOptions options;
   std::unique_ptr<Session> session(NewSession(options));
   TF_CHECK_OK(session->Create(graph_def));
 
-  QueueRunnerDef queue_runner_def =
-      BuildQueueRunnerDef(kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1,
-                          {Code::OUT_OF_RANGE, Code::CANCELLED});
+  QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
+      kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {});
 
   std::unique_ptr<QueueRunner> qr;
   TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
-  TF_CHECK_OK(qr->Start(session.get()));
-
-  TF_EXPECT_OK(qr->Stop(session.get()));
-  TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));
-
-  EXPECT_EQ(session->Run({}, {kDequeueOp1}, {}, nullptr).code(),
-            Code::OUT_OF_RANGE);
-
-  // qr is already stopped
-  TF_EXPECT_OK(qr->Join());
+  // This will timeout since queue0 is not fed and queue1 is fetching data from
+  // queue0.
+  EXPECT_EQ(qr->Start(session.get(), 1).code(), Code::DEADLINE_EXCEEDED);
+  session->Close();
 }
 
-TEST(QueueRunnerTest, StopTwoQueues) {
+TEST(QueueRunnerTest, TestCoordinatorStop) {
   auto graph_def = BuildDoubleQueueGraph();
 
   SessionOptions options;
   std::unique_ptr<Session> session(NewSession(options));
   TF_CHECK_OK(session->Create(graph_def));
@@ -290,31 +292,24 @@ TEST(QueueRunnerTest, StopTwoQueues) {
   QueueRunnerDef queue_runner1 =
       BuildQueueRunnerDef(kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1,
                          {Code::OUT_OF_RANGE, Code::CANCELLED});
 
+  Coordinator coord;
   std::unique_ptr<QueueRunner> qr0;
-  TF_EXPECT_OK(QueueRunner::New(queue_runner0, &qr0));
+  TF_EXPECT_OK(QueueRunner::New(queue_runner0, &coord, &qr0));
   TF_CHECK_OK(qr0->Start(session.get()));
   std::unique_ptr<QueueRunner> qr1;
-  TF_EXPECT_OK(QueueRunner::New(queue_runner1, &qr1));
+  TF_EXPECT_OK(QueueRunner::New(queue_runner1, &coord, &qr1));
   TF_CHECK_OK(qr1->Start(session.get()));
 
+  coord.RegisterRunner(std::move(qr0));
+  coord.RegisterRunner(std::move(qr1));
+
   std::vector<Tensor> dq;
   TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq));
   EXPECT_EQ(*dq[0].scalar<int>().data(), 10);
 
-  TF_EXPECT_OK(qr0->Stop(session.get()));
-  TF_EXPECT_OK(qr1->Stop(session.get()));
-
-  TF_EXPECT_OK(qr0->Join());
-  TF_EXPECT_OK(qr1->Join());
-}
-
-TEST(QueueRunnerTest, EmptyEnqueueOps) {
-  QueueRunnerDef queue_runner_def =
-      BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, "", {});
-
-  std::unique_ptr<QueueRunner> qr;
-  EXPECT_EQ(QueueRunner::New(queue_runner_def, &qr).code(),
-            Code::INVALID_ARGUMENT);
-}
+  TF_EXPECT_OK(coord.RequestStop());
+  TF_EXPECT_OK(coord.Join());
 }
 
 }  // namespace
@@ -23,6 +23,7 @@ py_library(
         "//tensorflow/contrib/framework:framework_py",
         "//tensorflow/contrib/graph_editor:graph_editor_py",
         "//tensorflow/contrib/grid_rnn:grid_rnn_py",
+        "//tensorflow/contrib/integrate:integrate_py",
        "//tensorflow/contrib/layers:layers_py",
        "//tensorflow/contrib/learn",
        "//tensorflow/contrib/linear_optimizer:sdca_ops_py",
@@ -28,6 +28,7 @@ from tensorflow.contrib import factorization
 from tensorflow.contrib import framework
 from tensorflow.contrib import graph_editor
 from tensorflow.contrib import grid_rnn
+from tensorflow.contrib import integrate
 from tensorflow.contrib import layers
 from tensorflow.contrib import learn
 from tensorflow.contrib import linear_optimizer
@@ -76,7 +76,7 @@ def build_split_apply_merge_model():
 
   # REINFORCE forward step
   route_selection = st.StochasticTensor(
-      distributions.Categorical, logits=logits)
+      distributions.Categorical(logits=logits))
 
   # Accessing route_selection as a Tensor below forces a sample of
   # the Categorical distribution based on its logits.
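The pattern above recurs throughout this change: `StochasticTensor` now takes a constructed `Distribution` instance instead of a distribution class plus keyword arguments. A minimal before/after sketch (the logits value is illustrative):

```python
import tensorflow as tf

st = tf.contrib.bayesflow.stochastic_tensor
distributions = tf.contrib.distributions

logits = tf.constant([1.0, 2.0, 3.0])

# Old convention, removed by this change: pass the class and kwargs and let
# StochasticTensor construct the distribution internally.
# route_selection = st.StochasticTensor(distributions.Categorical, logits=logits)

# New convention: construct the distribution first, then wrap the instance.
route_selection = st.StochasticTensor(distributions.Categorical(logits=logits))
```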
@@ -22,6 +22,7 @@ import tensorflow as tf
 
 st = tf.contrib.bayesflow.stochastic_tensor
 sge = tf.contrib.bayesflow.stochastic_gradient_estimators
+dists = tf.contrib.distributions
 
 
 class StochasticGradientEstimatorsTest(tf.test.TestCase):
@@ -31,7 +32,7 @@ class StochasticGradientEstimatorsTest(tf.test.TestCase):
     self._final_loss = tf.constant(3.2)
 
   def _testScoreFunction(self, loss_fn, expected):
-    x = st.BernoulliTensor(p=self._p, loss_fn=loss_fn)
+    x = st.StochasticTensor(dists.Bernoulli(p=self._p), loss_fn=loss_fn)
     sf = x.loss(self._final_loss)
     with self.test_session() as sess:
       sess.run(tf.initialize_all_variables())
@@ -62,8 +63,8 @@ class StochasticGradientEstimatorsTest(tf.test.TestCase):
   def testScoreFunctionWithMeanBaseline(self):
     ema_decay = 0.8
     num_steps = 6
-    x = st.BernoulliTensor(
-        p=self._p,
+    x = st.StochasticTensor(
+        dists.Bernoulli(p=self._p),
         loss_fn=sge.get_score_function_with_baseline(
             sge.get_mean_baseline(ema_decay)))
     sf = x.loss(self._final_loss)
@@ -98,12 +99,12 @@ class StochasticGradientEstimatorsTest(tf.test.TestCase):
 
   def testScoreFunctionWithMeanBaselineHasUniqueVarScope(self):
     ema_decay = 0.8
-    x = st.BernoulliTensor(
-        p=self._p,
+    x = st.StochasticTensor(
+        dists.Bernoulli(p=self._p),
         loss_fn=sge.get_score_function_with_baseline(
             sge.get_mean_baseline(ema_decay)))
-    y = st.BernoulliTensor(
-        p=self._p,
+    y = st.StochasticTensor(
+        dists.Bernoulli(p=self._p),
         loss_fn=sge.get_score_function_with_baseline(
             sge.get_mean_baseline(ema_decay)))
     sf_x = x.loss(self._final_loss)
@@ -39,9 +39,9 @@ class TestSurrogateLosses(tf.test.TestCase):
       mu = [0.0, 0.1, 0.2]
       sigma = tf.constant([1.1, 1.2, 1.3])
       with st.value_type(st.SampleAndReshapeValue()):
-        prior = st.StochasticTensor(distributions.Normal, mu=mu, sigma=sigma)
+        prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
         likelihood = st.StochasticTensor(
-            distributions.Normal, mu=prior, sigma=sigma)
+            distributions.Normal(mu=prior, sigma=sigma))
         self.assertTrue(prior.distribution.is_reparameterized)
         self.assertTrue(likelihood.distribution.is_reparameterized)
 
@@ -77,10 +77,9 @@ class TestSurrogateLosses(tf.test.TestCase):
       mu = tf.constant([0.0, 0.1, 0.2])
       sigma = tf.constant([1.1, 1.2, 1.3])
       with st.value_type(st.SampleAndReshapeValue()):
-        prior = st.StochasticTensor(NormalNotParam, mu=mu, sigma=sigma)
-        likelihood = st.StochasticTensor(
-            NormalNotParam, mu=prior, sigma=sigma)
-        prior_2 = st.StochasticTensor(NormalNotParam, mu=mu, sigma=sigma)
+        prior = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
+        likelihood = st.StochasticTensor(NormalNotParam(mu=prior, sigma=sigma))
+        prior_2 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
 
       loss = tf.square(tf.identity(likelihood) - mu)
       part_loss = tf.square(tf.identity(prior) - mu)
@@ -155,9 +154,7 @@ class TestSurrogateLosses(tf.test.TestCase):
      mu = tf.constant([0.0, 0.1, 0.2])
      sigma = tf.constant([1.1, 1.2, 1.3])
      with st.value_type(st.SampleAndReshapeValue()):
-        dt = st.StochasticTensor(NormalNotParam,
-                                 mu=mu,
-                                 sigma=sigma,
+        dt = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma),
                                  loss_fn=None)
         self.assertEqual(None, dt.loss(tf.constant([2.0])))
 
@@ -166,8 +163,8 @@ class TestSurrogateLosses(tf.test.TestCase):
      mu = tf.constant([0.0, 0.1, 0.2])
      sigma = tf.constant([1.1, 1.2, 1.3])
      with st.value_type(st.SampleAndReshapeValue()):
-        dt1 = st.StochasticTensor(NormalNotParam, mu=mu, sigma=sigma)
-        dt2 = st.StochasticTensor(NormalNotParam, mu=mu, sigma=sigma)
+        dt1 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
+        dt2 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
         loss = tf.square(tf.identity(dt1)) + 10. + dt2
 
         sl_all = sg.surrogate_loss([loss])
@@ -186,8 +183,8 @@ class TestSurrogateLosses(tf.test.TestCase):
 class StochasticDependenciesMapTest(tf.test.TestCase):
 
   def testBuildsMapOfUpstreamNodes(self):
-    dt1 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
-    dt2 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
+    dt1 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
+    dt2 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
     out1 = dt1.value() + 1.
     out2 = dt2.value() + 2.
     x = out1 + out2
@@ -197,11 +194,11 @@ class StochasticDependenciesMapTest(tf.test.TestCase):
     self.assertEqual(dep_map[dt2], set([x, y]))
 
   def testHandlesStackedStochasticNodes(self):
-    dt1 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
+    dt1 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
     out1 = dt1.value() + 1.
-    dt2 = st.StochasticTensor(distributions.Normal, mu=out1, sigma=1.)
+    dt2 = st.StochasticTensor(distributions.Normal(mu=out1, sigma=1.))
     x = dt2.value() + 2.
-    dt3 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
+    dt3 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
     y = dt3.value() * 3.
     dep_map = sg._stochastic_dependencies_map([x, y])
     self.assertEqual(dep_map[dt1], set([x]))
@@ -209,10 +206,10 @@ class StochasticDependenciesMapTest(tf.test.TestCase):
     self.assertEqual(dep_map[dt3], set([y]))
 
   def testTraversesControlInputs(self):
-    dt1 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
+    dt1 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
     logits = dt1.value() * 3.
-    dt2 = st.StochasticTensor(distributions.Bernoulli, logits=logits)
-    dt3 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
+    dt2 = st.StochasticTensor(distributions.Bernoulli(logits=logits))
+    dt3 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
     x = dt3.value()
     y = tf.ones((2, 2)) * 4.
     z = tf.ones((2, 2)) * 3.
@@ -35,19 +35,19 @@ class StochasticTensorTest(tf.test.TestCase):
       sigma2 = tf.constant([0.1, 0.2, 0.3])
 
       prior_default = st.StochasticTensor(
-          distributions.Normal, mu=mu, sigma=sigma)
+          distributions.Normal(mu=mu, sigma=sigma))
       self.assertTrue(
           isinstance(prior_default.value_type, st.SampleAndReshapeValue))
       prior_0 = st.StochasticTensor(
-          distributions.Normal, mu=mu, sigma=sigma,
+          distributions.Normal(mu=mu, sigma=sigma),
          dist_value_type=st.SampleAndReshapeValue())
       self.assertTrue(isinstance(prior_0.value_type, st.SampleAndReshapeValue))
 
       with st.value_type(st.SampleAndReshapeValue()):
-        prior = st.StochasticTensor(distributions.Normal, mu=mu, sigma=sigma)
+        prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
         self.assertTrue(isinstance(prior.value_type, st.SampleAndReshapeValue))
         likelihood = st.StochasticTensor(
-            distributions.Normal, mu=prior, sigma=sigma2)
+            distributions.Normal(mu=prior, sigma=sigma2))
         self.assertTrue(
             isinstance(likelihood.value_type, st.SampleAndReshapeValue))
 
@@ -77,7 +77,7 @@ class StochasticTensorTest(tf.test.TestCase):
       sigma = tf.constant([1.1, 1.2, 1.3])
 
       with st.value_type(st.MeanValue()):
-        prior = st.StochasticTensor(distributions.Normal, mu=mu, sigma=sigma)
+        prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
         self.assertTrue(isinstance(prior.value_type, st.MeanValue))
 
         prior_mean = prior.mean()
@@ -94,7 +94,8 @@ class StochasticTensorTest(tf.test.TestCase):
 
       with st.value_type(st.SampleAndReshapeValue()):
         prior_single = st.StochasticTensor(
-            distributions.Normal, mu=mu, sigma=sigma)
+            distributions.Normal(
+                mu=mu, sigma=sigma))
 
         prior_single_value = prior_single.value()
         self.assertEqual(prior_single_value.get_shape(), (2, 3))
@@ -104,7 +105,7 @@ class StochasticTensorTest(tf.test.TestCase):
 
       with st.value_type(st.SampleAndReshapeValue(n=2)):
         prior_double = st.StochasticTensor(
-            distributions.Normal, mu=mu, sigma=sigma)
+            distributions.Normal(mu=mu, sigma=sigma))
 
         prior_double_value = prior_double.value()
         self.assertEqual(prior_double_value.get_shape(), (4, 3))
@@ -119,7 +120,7 @@ class StochasticTensorTest(tf.test.TestCase):
 
       with st.value_type(st.SampleValue()):
         prior_single = st.StochasticTensor(
-            distributions.Normal, mu=mu, sigma=sigma)
+            distributions.Normal(mu=mu, sigma=sigma))
         self.assertTrue(isinstance(prior_single.value_type, st.SampleValue))
 
         prior_single_value = prior_single.value()
@@ -130,7 +131,7 @@ class StochasticTensorTest(tf.test.TestCase):
 
       with st.value_type(st.SampleValue(n=2)):
         prior_double = st.StochasticTensor(
-            distributions.Normal, mu=mu, sigma=sigma)
+            distributions.Normal(mu=mu, sigma=sigma))
 
         prior_double_value = prior_double.value()
         self.assertEqual(prior_double_value.get_shape(), (2, 2, 3))
@@ -143,9 +144,9 @@ class StochasticTensorTest(tf.test.TestCase):
       mu = [0.0, -1.0, 1.0]
       sigma = tf.constant([1.1, 1.2, 1.3])
       with st.value_type(st.MeanValue()):
-        prior = st.StochasticTensor(distributions.Normal, mu=mu, sigma=sigma)
+        prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
         entropy = prior.entropy()
-        deep_entropy = prior.entropy()
+        deep_entropy = prior.distribution.entropy()
         expected_deep_entropy = distributions.Normal(
             mu=mu, sigma=sigma).entropy()
         entropies = sess.run([entropy, deep_entropy, expected_deep_entropy])
@@ -159,17 +160,15 @@ class StochasticTensorTest(tf.test.TestCase):
 
       # With default
       with st.value_type(st.MeanValue(stop_gradient=True)):
-        dt = st.StochasticTensor(distributions.Normal, mu=mu, sigma=sigma)
+        dt = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
         loss = dt.loss([tf.constant(2.0)])
         self.assertTrue(loss is not None)
-        self.assertAllClose(dt.distribution.log_prob(mu).eval() * 2.0,
-                            loss.eval())
+        self.assertAllClose(
+            dt.distribution.log_prob(mu).eval() * 2.0, loss.eval())
 
       # With passed-in loss_fn.
       dt = st.StochasticTensor(
-          distributions.Normal,
-          mu=mu,
-          sigma=sigma,
+          distributions.Normal(mu=mu, sigma=sigma),
          dist_value_type=st.MeanValue(stop_gradient=True),
          loss_fn=sge.get_score_function_with_constant_baseline(
              baseline=tf.constant(8.0)))
@@ -204,7 +203,7 @@ class ObservedStochasticTensorTest(tf.test.TestCase):
       sigma = tf.constant([1.1, 1.2, 1.3])
       obs = tf.zeros((2, 3))
       z = st.ObservedStochasticTensor(
-          distributions.Normal, mu=mu, sigma=sigma, value=obs)
+          distributions.Normal(mu=mu, sigma=sigma), value=obs)
       [obs_val, z_val] = sess.run([obs, z.value()])
       self.assertAllEqual(obs_val, z_val)
 
@@ -216,13 +215,13 @@ class ObservedStochasticTensorTest(tf.test.TestCase):
     sigma = tf.placeholder(tf.float32)
     obs = tf.placeholder(tf.float32)
     z = st.ObservedStochasticTensor(
-        distributions.Normal, mu=mu, sigma=sigma, value=obs)
+        distributions.Normal(mu=mu, sigma=sigma), value=obs)
 
     mu2 = tf.placeholder(tf.float32, shape=[None])
     sigma2 = tf.placeholder(tf.float32, shape=[None])
     obs2 = tf.placeholder(tf.float32, shape=[None, None])
     z2 = st.ObservedStochasticTensor(
-        distributions.Normal, mu=mu2, sigma=sigma2, value=obs2)
+        distributions.Normal(mu=mu2, sigma=sigma2), value=obs2)
 
     coll = tf.get_collection(st.STOCHASTIC_TENSOR_COLLECTION)
     self.assertEqual(coll, [z, z2])
@@ -230,27 +229,19 @@ class ObservedStochasticTensorTest(tf.test.TestCase):
   def testConstructionErrors(self):
     mu = [0., 0.]
     sigma = [1., 1.]
-    self.assertRaises(ValueError, st.ObservedStochasticTensor,
-                      distributions.Normal, mu=mu, sigma=sigma,
-                      value=tf.zeros((3,)))
-    self.assertRaises(ValueError, st.ObservedStochasticTensor,
-                      distributions.Normal, mu=mu, sigma=sigma,
-                      value=tf.zeros((3, 1)))
-    self.assertRaises(ValueError, st.ObservedStochasticTensor,
-                      distributions.Normal, mu=mu, sigma=sigma,
-                      value=tf.zeros((1, 2), dtype=tf.int32))
-
-
-class AutomaticDistributionImportTest(tf.test.TestCase):
-
-  def testImportNormal(self):
-    self.assertTrue(hasattr(st, "NormalTensor"))
-    self.assertTrue(callable(st.NormalTensor))
-    norm = st.NormalTensor(mu=0.0, sigma=1.0)
-    self.assertEqual(type(norm).__name__, "NormalTensor")
-    self.assertTrue(isinstance(norm, st.NormalTensor))
-    self.assertTrue(isinstance(norm, st.StochasticTensor))
+    self.assertRaises(
+        ValueError,
+        st.ObservedStochasticTensor,
+        distributions.Normal(mu=mu, sigma=sigma),
+        value=tf.zeros((3,)))
+    self.assertRaises(
+        ValueError,
+        st.ObservedStochasticTensor,
+        distributions.Normal(mu=mu, sigma=sigma),
+        value=tf.zeros((3, 1)))
+    self.assertRaises(
+        ValueError,
+        st.ObservedStochasticTensor,
+        distributions.Normal(mu=mu, sigma=sigma),
+        value=tf.zeros(
+            (1, 2), dtype=tf.int32))
 
 if __name__ == "__main__":
   tf.test.main()
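`ObservedStochasticTensor` follows the same convention, with the observed value supplied via `value`; a minimal sketch mirroring the test above:

```python
import tensorflow as tf

st = tf.contrib.bayesflow.stochastic_tensor
distributions = tf.contrib.distributions

mu = tf.constant([0.0, 0.1, 0.2])
sigma = tf.constant([1.1, 1.2, 1.3])
obs = tf.zeros((2, 3))

# The distribution instance comes first; `value` replaces sampling.
z = st.ObservedStochasticTensor(
    distributions.Normal(mu=mu, sigma=sigma), value=obs)

with tf.Session() as sess:
  print(sess.run(z.value()))  # echoes `obs` rather than a fresh sample
```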
@@ -44,7 +44,7 @@ def mini_vae():
   x = [[-6., 3., 6.], [-8., 4., 8.]]
   prior = distributions.Normal(mu=0., sigma=1.)
   variational = st.StochasticTensor(
-      distributions.Normal, mu=inference_net(x, 1), sigma=1.)
+      distributions.Normal(mu=inference_net(x, 1), sigma=1.))
   vi.register_prior(variational, prior)
   px = distributions.Normal(mu=generative_net(variational, 3), sigma=1.)
   log_likelihood = tf.reduce_sum(px.log_prob(x), 1)
@@ -101,7 +101,7 @@ class VariationalInferenceTest(tf.test.TestCase):
 
     prior = distributions.Bernoulli(0.5)
     variational = st.StochasticTensor(
-        NormalNoEntropy, mu=inference_net(x, 1), sigma=1.)
+        NormalNoEntropy(mu=inference_net(x, 1), sigma=1.))
     vi.register_prior(variational, prior)
     px = distributions.Normal(mu=generative_net(variational, 3), sigma=1.)
     log_likelihood = tf.reduce_sum(px.log_prob(x), 1)
@@ -44,7 +44,6 @@ from __future__ import print_function
 import abc
 import collections
 import contextlib
-import inspect
 import threading
 
 import six
@@ -79,10 +78,6 @@ class BaseStochasticTensor(object):
   def graph(self):
     pass
 
-  @abc.abstractproperty
-  def input_dict(self):
-    pass
-
   @abc.abstractmethod
   def value(self, name=None):
     pass
@@ -120,6 +115,7 @@ class BaseStochasticTensor(object):
 # pylint: disable=protected-access
 ops.register_tensor_conversion_function(
     BaseStochasticTensor, BaseStochasticTensor._tensor_conversion_function)
 
 # pylint: enable=protected-access
 
+
@@ -223,8 +219,8 @@ class SampleAndReshapeValue(_StochasticValueType):
       st_value = st.value()
       assertEqual(st_value.get_shape(), (4, 3))
 
-      dt_value_val = sess.run([st_value])[0]  # or e.g. run([tf.identity(st)])[0]
-      assertEqual(dt_value_val.shape, (4, 3))
+      st_value_val = sess.run([st_value])[0]  # or e.g. run([tf.identity(st)])[0]
+      assertEqual(st_value_val.shape, (4, 3))
   ```
   """
 
@@ -312,17 +308,16 @@ class StochasticTensor(BaseStochasticTensor):
   """StochasticTensor is a BaseStochasticTensor backed by a distribution."""
 
   def __init__(self,
-               dist_cls,
-               name=None,
+               dist,
+               name="StochasticTensor",
                dist_value_type=None,
-               loss_fn=sge.score_function,
-               **dist_args):
+               loss_fn=sge.score_function):
     """Construct a `StochasticTensor`.
 
-    `StochasticTensor` will instantiate a distribution from `dist_cls` and
-    `dist_args` and its `value` method will return the same value each time
-    it is called. What `value` is returned is controlled by the
-    `dist_value_type` (defaults to `SampleAndReshapeValue`).
+    `StochasticTensor` is backed by the `dist` distribution and its `value`
+    method will return the same value each time it is called. What `value` is
+    returned is controlled by the `dist_value_type` (defaults to
+    `SampleAndReshapeValue`).
 
     Some distributions' sample functions are not differentiable (e.g. a sample
     from a discrete distribution like a Bernoulli) and so to differentiate
@@ -338,28 +333,25 @@ class StochasticTensor(BaseStochasticTensor):
     `MeanValueType` or if `loss_fn=None`.
 
     Args:
-      dist_cls: a `Distribution` class.
+      dist: an instance of `Distribution`.
       name: a name for this `StochasticTensor` and its ops.
      dist_value_type: a `_StochasticValueType`, which will determine what the
          `value` of this `StochasticTensor` will be. If not provided, the
          value type set with the `value_type` context manager will be used.
-      loss_fn: callable that takes `(st, st.value(), influenced_loss)`, where
+      loss_fn: callable that takes
+          `(st, st.value(), influenced_loss)`, where
          `st` is this `StochasticTensor`, and returns a `Tensor` loss. By
          default, `loss_fn` is the `score_function`, or more precisely, the
          integral of the score function, such that when the gradient is taken,
          the score function results. See the `stochastic_gradient_estimators`
          module for additional loss functions and baselines.
-      **dist_args: keyword arguments to be passed through to `dist_cls` on
-          construction.
 
     Raises:
-      TypeError: if `dist_cls` is not a `Distribution`.
+      TypeError: if `dist` is not an instance of `Distribution`.
      TypeError: if `loss_fn` is not `callable`.
     """
-    if not issubclass(dist_cls, distributions.Distribution):
-      raise TypeError("dist_cls must be a subclass of Distribution")
-    self._dist_cls = dist_cls
-    self._dist_args = dist_args
+    if not isinstance(dist, distributions.Distribution):
+      raise TypeError("dist must be an instance of Distribution")
     if dist_value_type is None:
      try:
        self._value_type = get_current_value_type()
@@ -371,24 +363,17 @@ class StochasticTensor(BaseStochasticTensor):
       with value_type(dist_value_type):
         self._value_type = get_current_value_type()
 
-    self._value_type.declare_inputs(self, dist_args)
-
     if loss_fn is not None and not callable(loss_fn):
       raise TypeError("loss_fn must be callable")
     self._loss_fn = loss_fn
 
-    with ops.name_scope(name, "StochasticTensor",
-                        dist_args.values()) as scope:
+    with ops.name_scope(name) as scope:
       self._name = scope
-      self._dist = dist_cls(**dist_args)
+      self._dist = dist
       self._value = self._create_value()
 
     super(StochasticTensor, self).__init__()
 
-  @property
-  def input_dict(self):
-    return self._dist_args
-
   @property
   def value_type(self):
     return self._value_type
@@ -397,9 +382,6 @@ class StochasticTensor(BaseStochasticTensor):
   def distribution(self):
     return self._dist
 
-  def clone(self, name=None, **dist_args):
-    return StochasticTensor(self._dist_cls, name=name, **dist_args)
-
   def _create_value(self):
     """Create the value Tensor based on the value type, store as self._value."""
 
@@ -494,33 +476,28 @@ class ObservedStochasticTensor(StochasticTensor):
   """A StochasticTensor with an observed value."""
 
   # pylint: disable=super-init-not-called
-  def __init__(self, dist_cls, value, name=None, **dist_args):
+  def __init__(self, dist, value, name=None):
     """Construct an `ObservedStochasticTensor`.
 
-    `ObservedStochasticTensor` will instantiate a distribution from `dist_cls`
-    and `dist_args` but use the provided value instead of sampling from the
-    distribution. The provided value argument must be appropriately shaped
-    to have come from the constructed distribution.
+    `ObservedStochasticTensor` is backed by distribution `dist` and uses the
+    provided value instead of using the current value type to draw a value from
+    the distribution. The provided value argument must be appropriately shaped
+    to have come from the distribution.
 
     Args:
-      dist_cls: a `Distribution` class.
+      dist: an instance of `Distribution`.
       value: a Tensor containing the observed value
       name: a name for this `ObservedStochasticTensor` and its ops.
-      **dist_args: keyword arguments to be passed through to `dist_cls` on
-          construction.
 
     Raises:
-      TypeError: if `dist_cls` is not a `Distribution`.
+      TypeError: if `dist` is not an instance of `Distribution`.
      ValueError: if `value` is not compatible with the distribution.
     """
-    if not issubclass(dist_cls, distributions.Distribution):
-      raise TypeError("dist_cls must be a subclass of Distribution")
-    self._dist_cls = dist_cls
-    self._dist_args = dist_args
-    with ops.name_scope(name, "ObservedStochasticTensor",
-                        list(dist_args.values()) + [value]) as scope:
+    if not isinstance(dist, distributions.Distribution):
+      raise TypeError("dist must be an instance of Distribution")
+    with ops.name_scope(name, "ObservedStochasticTensor", [value]) as scope:
       self._name = scope
-      self._dist = dist_cls(**dist_args)
+      self._dist = dist
       dist_shape = self._dist.get_batch_shape().concatenate(
           self._dist.get_event_shape())
       value = ops.convert_to_tensor(value)
@@ -538,7 +515,7 @@ class ObservedStochasticTensor(StochasticTensor):
               "sample from the distribution %s." % (value_shape, dist_shape))
     if value.dtype != self._dist.dtype:
       raise ValueError("Type of observed value (%s) does not match type of "
-                       "distribuiton (%s)." % (value.dtype, self._dist.dtype))
+                       "distribution (%s)." % (value.dtype, self._dist.dtype))
     self._value = array_ops.identity(value)
     # pylint: disable=non-parent-init-called
     BaseStochasticTensor.__init__(self)
@@ -557,39 +534,3 @@ __all__ = [
     "value_type",
     "get_current_value_type",
 ]
-
-_globals = globals()
-# pylint: disable=redefined-builtin
-__doc__ += "\n\n## Automatically Generated StochasticTensors\n\n"
-# pylint: enable=redefined-builtin
-for _name in sorted(dir(distributions)):
-  _candidate = getattr(distributions, _name)
-  if (inspect.isclass(_candidate)
-      and _candidate != distributions.Distribution
-      and issubclass(_candidate, distributions.Distribution)):
-    _local_name = "%sTensor" % _name
-
-    class _WrapperTensor(StochasticTensor):
-      _my_candidate = _candidate
-
-      def __init__(self, name=None, dist_value_type=None,
-                   loss_fn=sge.score_function, **dist_args):
-        StochasticTensor.__init__(
-            self,
-            dist_cls=self._my_candidate,
-            name=name,
-            dist_value_type=dist_value_type,
-            loss_fn=loss_fn, **dist_args)
-
-    _WrapperTensor.__name__ = _local_name
-    _WrapperTensor.__doc__ = (
-        "`%s` is a `StochasticTensor` backed by the distribution `%s`."""
-        % (_local_name, _name))
-    _globals[_local_name] = _WrapperTensor
-    del _WrapperTensor
-    del _candidate
-
-    __all__.append(_local_name)
-    __doc__ += "@@%s\n" % _local_name
-
-    del _local_name
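With the generator loop above deleted, the auto-created aliases such as `st.NormalTensor` no longer exist; the replacement spelling is the plain `StochasticTensor` wrapper around a distribution instance. A short sketch:

```python
import tensorflow as tf

st = tf.contrib.bayesflow.stochastic_tensor
distributions = tf.contrib.distributions

# Previously auto-generated, removed by this change:
# norm = st.NormalTensor(mu=0.0, sigma=1.0)

# Equivalent spelling after this change:
norm = st.StochasticTensor(distributions.Normal(mu=0.0, sigma=1.0))
```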
@@ -126,7 +126,7 @@ def get_stochastic_variable(getter,
 
   dist_kwargs = dist_kwargs or {}
   dist_kwargs.update(params)
-  sample = st.StochasticTensor(dist_cls, **dist_kwargs)
+  sample = st.StochasticTensor(dist_cls(**dist_kwargs))
 
   if prior is not None:
     if callable(prior):
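The one-line change above applies the same migration inside `get_stochastic_variable`: the distribution class is instantiated before being handed to `StochasticTensor`. A hedged sketch of the pattern, with `dist_cls` and `dist_kwargs` chosen purely for illustration:

```python
import tensorflow as tf

st = tf.contrib.bayesflow.stochastic_tensor

dist_cls = tf.contrib.distributions.Normal   # example class
dist_kwargs = {"mu": 0.0, "sigma": 1.0}      # example parameters

# Before: sample = st.StochasticTensor(dist_cls, **dist_kwargs)
# After: instantiate first, then wrap the instance.
sample = st.StochasticTensor(dist_cls(**dist_kwargs))
```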
2 tensorflow/contrib/cmake/external/gif.cmake vendored
@@ -1,7 +1,7 @@
 include (ExternalProject)
 
 set(gif_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/gif_archive/giflib-5.1.4/)
-set(gif_URL http://ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz)
+set(gif_URL http://cdimage.debian.org/mirror/xbmc.org/build-deps/sources/giflib-5.1.4.tar.gz)
 set(gif_HASH SHA256=34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1)
 set(gif_INSTALL ${CMAKE_BINARY_DIR}/gif/install)
 set(gif_BUILD ${CMAKE_BINARY_DIR}/gif/src/gif)
@@ -26,7 +26,7 @@ from setuptools import find_packages, setup, Command
 from setuptools.command.install import install as InstallCommandBase
 from setuptools.dist import Distribution
 
-_VERSION = '0.11.0rc1-cmake-experimental'
+_VERSION = '0.11.0rc2-cmake-experimental'
 
 REQUIRED_PACKAGES = [
     'numpy >= 1.11.0',
@@ -57,7 +57,6 @@ initialized with parameters that define the distributions.
 @@MultivariateNormalCholesky
 @@MultivariateNormalDiagPlusVDVT
 @@MultivariateNormalDiagWithSoftplusStDev
-@@matrix_diag_transform
 
 ### Other multivariate distributions
 
@@ -67,6 +66,10 @@ initialized with parameters that define the distributions.
 @@WishartCholesky
 @@WishartFull
 
+### Multivariate Utilities
+
+@@matrix_diag_transform
+
 ## Transformed distributions
 
 @@TransformedDistribution
@@ -86,7 +89,7 @@ representing the posterior or posterior predictive.
 @@normal_conjugates_known_sigma_posterior
 @@normal_conjugates_known_sigma_predictive
 
-## Kullback Leibler Divergence
+## Kullback-Leibler Divergence
 
 @@kl
 @@RegisterKL
@@ -25,7 +25,7 @@ import tensorflow as tf
 from tensorflow.contrib.distributions.python.ops import distribution_util
 
 
-class DistributionUtilTest(tf.test.TestCase):
+class AssertCloseTest(tf.test.TestCase):
 
   def testAssertCloseIntegerDtype(self):
     x = [1, 5, 10, 15, 20]
@@ -110,6 +110,9 @@ class DistributionUtilTest(tf.test.TestCase):
                       distribution_util.assert_integer_form(w)]):
         tf.identity(w).eval()
 
+
+class GetLogitsAndProbTest(tf.test.TestCase):
+
   def testGetLogitsAndProbImproperArguments(self):
     with self.test_session():
       with self.assertRaises(ValueError):
@@ -229,6 +232,9 @@ class DistributionUtilTest(tf.test.TestCase):
           p=p4, multidimensional=True, validate_args=False)
       prob.eval()
 
+
+class LogCombinationsTest(tf.test.TestCase):
+
   def testLogCombinationsBinomial(self):
     n = [2, 5, 12, 15]
     k = [1, 2, 4, 11]
@@ -252,6 +258,9 @@ class DistributionUtilTest(tf.test.TestCase):
       log_binom = distribution_util.log_combinations(n, counts)
       self.assertEqual([2, 2], log_binom.get_shape())
 
+
+class RotateTransposeTest(tf.test.TestCase):
+
   def _np_rotate_transpose(self, x, shift):
     if not isinstance(x, np.ndarray):
       x = np.array(x)
@@ -283,7 +292,10 @@ class DistributionUtilTest(tf.test.TestCase):
               sess.run(distribution_util.rotate_transpose(x, shift),
                        feed_dict={x: x_value, shift: shift_value}))
 
-  def testChooseVector(self):
+
+class PickVectorTest(tf.test.TestCase):
+
+  def testCorrectlyPicksVector(self):
     with self.test_session():
       x = np.arange(10, 12)
       y = np.arange(15, 18)
@@ -301,5 +313,51 @@ class DistributionUtilTest(tf.test.TestCase):
           tf.constant(False), x, y))  # No eval.
 
 
+class FillLowerTriangularTest(tf.test.TestCase):
+
+  def testCorrectlyMakes1x1LowerTril(self):
+    with self.test_session():
+      x = np.array([[1.], [2], [3]])
+      expected = np.array([[[1.]], [[2]], [[3]]])
+      actual = distribution_util.fill_lower_triangular(x)
+      self.assertAllEqual(expected.shape, actual.get_shape())
+      self.assertAllEqual(expected, actual.eval())
+
+  def testCorrectlyMakesNoBatchLowerTril(self):
+    with self.test_session():
+      x = tf.convert_to_tensor(np.arange(9, dtype=np.float32))
+      expected = np.array(
+          [[0., 0., 0.],
+           [1., 2., 0.],
+           [3., 4., 5.]])
+      actual = distribution_util.fill_lower_triangular(x)
+      self.assertAllEqual(expected.shape, actual.get_shape())
+      self.assertAllEqual(expected, actual.eval())
+      self.assertAllEqual(
+          np.concatenate([np.ones(6, dtype=np.float32),
+                          np.zeros(3, dtype=np.float32)]),
+          tf.gradients(distribution_util.fill_lower_triangular(x), x)[0].eval())
+
+  def testCorrectlyMakesBatchLowerTril(self):
+    with self.test_session():
+      x = np.reshape(np.arange(24), (2, 2, 6))
+      expected = np.array(
+          [[[[0., 0., 0.],
+             [1., 2., 0.],
+             [3., 4., 5.]],
+            [[6., 0., 0.],
+             [7., 8., 0.],
+             [9., 10., 11.]]],
+           [[[12., 0., 0.],
+             [13., 14., 0.],
+             [15., 16., 17.]],
+            [[18., 0., 0.],
+             [19., 20., 0.],
+             [21., 22., 23.]]]])
+      actual = distribution_util.fill_lower_triangular(x)
+      self.assertAllEqual(expected.shape, actual.get_shape())
+      self.assertAllEqual(expected, actual.eval())
+
+
 if __name__ == "__main__":
   tf.test.main()
@@ -20,11 +20,13 @@ from __future__ import print_function
 
 import functools
 import hashlib
+import math
 import numpy as np
 
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
@@ -376,7 +378,7 @@ def pick_vector(cond,
     TypeError: if `cond` is not a constant and
       `true_vector.dtype != false_vector.dtype`
   """
-  with ops.op_scope((cond, true_vector, false_vector), name):
+  with ops.name_scope(name, values=(cond, true_vector, false_vector)):
     cond = ops.convert_to_tensor(cond, name="cond")
     if cond.dtype != dtypes.bool:
       raise TypeError("%s.dtype=%s which is not %s" %
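The `op_scope` call above is migrated to `name_scope`, which takes the scope name first and the tensors through the `values` keyword. A small sketch of the same migration in isolation (tensor values illustrative):

```python
import tensorflow as tf
from tensorflow.python.framework import ops

x = tf.constant([1.0, 2.0])
y = tf.constant([3.0, 4.0])

# Old style, as removed above: tensors first, then the scope name.
# with ops.op_scope((x, y), "pick"):
#   z = x + y

# New style: name first, tensors via the values keyword.
with ops.name_scope("pick", values=(x, y)):
  z = x + y
```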
@@ -405,6 +407,105 @@ def gen_new_seed(seed, salt):
   return None
 
 
+def fill_lower_triangular(x, name="fill_lower_triangular"):
+  """Creates a (batch of) lower triangular matrix from a vector of inputs.
+
+  If `x.get_shape()` is `[b1, b2, ..., bK, d]` then the output shape is `[b1,
+  b2, ..., bK, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
+  `n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))`.
+
+  Note: This function is very slow; possibly 10x slower than zeroing out the
+  upper-triangular portion of a full matrix.
+
+  Example:
+
+  ```python
+  fill_lower_triangular([1, 2, 3, 4, 5, 6])
+  # Returns: [[1, 0, 0],
+  #           [2, 3, 0],
+  #           [4, 5, 6]]
+  ```
+
+  Args:
+    x: `Tensor` representing lower triangular elements.
+    name: `String`. The name to give this op.
+
+  Returns:
+    tril: `Tensor` with lower triangular elements filled from `x`.
+  """
+  with ops.name_scope(name, values=(x,)):
+    x = ops.convert_to_tensor(x, name="x")
+    if (x.get_shape().ndims is not None and
+        x.get_shape()[-1].value is not None):
+      d = x.get_shape()[-1].value
+      # d = n^2/2 + n/2 implies n is:
+      n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
+      final_shape = x.get_shape()[:-1].concatenate(
+          tensor_shape.TensorShape([n, n]))
+    else:
+      d = math_ops.cast(array_ops.shape(x)[-1], dtype=dtypes.float32)
+      # d = n^2/2 + n/2 implies n is:
+      n = math_ops.cast(0.5 * (math_ops.sqrt(1. + 8. * d) - 1.),
+                        dtype=dtypes.int32)
+      final_shape = x.get_shape()[:-1].concatenate(
+          tensor_shape.TensorShape([None, None]))
+
+    # Make ids for each batch dim.
+    if (x.get_shape().ndims is not None and
+        x.get_shape()[:-1].is_fully_defined()):
+      batch_shape = np.asarray(x.get_shape()[:-1].as_list(), dtype=np.int32)
+      m = np.prod(batch_shape)
+    else:
+      batch_shape = array_ops.shape(x)[:-1]
+      m = math_ops.reduce_prod(batch_shape)
+    batch_ids = math_ops.range(m)
+
+    # Flatten batch dims.
+    y = array_ops.reshape(x, [-1, d])
+
+    # Prepend a zero to each row; the zero is what gets gathered into the
+    # strictly upper triangular positions.
+    y = array_ops.pad(y, paddings=[[0, 0], [1, 0]])
+
+    def make_tril_ids(n):
+      """Internal helper to create vector of linear indices into y."""
+      cols = array_ops.reshape(array_ops.tile(math_ops.range(n), [n]), [n, n])
+      rows = array_ops.tile(
+          array_ops.expand_dims(math_ops.range(n), -1), [1, n])
+      pred = math_ops.greater(cols, rows)
+      tril_ids = array_ops.tile(array_ops.reshape(
+          math_ops.cumsum(math_ops.range(n)), [n, 1]), [1, n]) + cols
+      tril_ids = math_ops.select(pred,
+                                 array_ops.zeros([n, n], dtype=dtypes.int32),
+                                 tril_ids + 1)
+      tril_ids = array_ops.reshape(tril_ids, [-1])
+      return tril_ids
+    tril_ids = make_tril_ids(n)
+
+    # Assemble the ids into pairs.
+    idx = array_ops.pack([
+        array_ops.tile(array_ops.expand_dims(batch_ids, -1), [1, n*n]),
+        array_ops.tile([tril_ids], [m, 1])])
+    idx = array_ops.transpose(idx, [1, 2, 0])
+
+    if x.get_shape().ndims == 1:
+      # Prefer using gather because it has a gradient.
+      # We wrap the result in a list so downstream logic "just works."
+      y = [array_ops.gather(y[0, :], tril_ids)]
+    else:
+      y = array_ops.gather_nd(y, idx)
+    y = array_ops.reshape(y, array_ops.concat(0, [batch_shape, [n, n]]))
+
+    y.set_shape(y.get_shape().merge_with(final_shape))
+
+    return y
+
+
 class AppendDocstring(object):
   """Helper class to promote private subclass docstring to public counterpart.
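The semantics of `fill_lower_triangular` above are easier to see in a plain NumPy sketch (illustrative only; unbatched, unlike the TensorFlow version):

```python
import math
import numpy as np

def fill_lower_triangular_np(x):
  # Solve d = n(n+1)/2 for n, exactly as the TensorFlow version does.
  d = len(x)
  n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
  tril = np.zeros((n, n), dtype=np.asarray(x).dtype)
  # np.tril_indices walks the lower triangle row by row, matching the
  # row-major fill order of the op above.
  tril[np.tril_indices(n)] = x
  return tril

print(fill_lower_triangular_np([1, 2, 3, 4, 5, 6]))
# [[1 0 0]
#  [2 3 0]
#  [4 5 6]]
```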
@@ -571,8 +571,7 @@ class WALSModel(object):
       extras = size % num_shards
       assignments = tf.maximum(ids // (ids_per_shard + 1),
                                (ids - extras) // ids_per_shard)
-      new_ids = tf.select(assignments < extras,
-                          ids % (ids_per_shard + 1),
-                          (ids - extras) % ids_per_shard)
+      new_ids = tf.where(assignments < extras, ids % (ids_per_shard + 1),
+                         (ids - extras) % ids_per_shard)
       return assignments, new_ids
     return func
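The sharding arithmetic in this hunk distributes `size` ids over `num_shards` shards, giving each of the first `extras` shards one extra id. A worked NumPy sketch with hypothetical values:

```python
import numpy as np

size, num_shards = 10, 3
ids = np.arange(size)
ids_per_shard = size // num_shards  # 3
extras = size % num_shards          # 1: the first shard holds one extra id
assignments = np.maximum(ids // (ids_per_shard + 1),
                         (ids - extras) // ids_per_shard)
new_ids = np.where(assignments < extras,
                   ids % (ids_per_shard + 1),
                   (ids - extras) % ids_per_shard)
print(assignments)  # [0 0 0 0 1 1 1 2 2 2]
print(new_ids)      # [0 1 2 3 0 1 2 0 1 2]
```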
@@ -655,7 +654,7 @@ class WALSModel(object):
       update_op: An op that assigns the newly computed values to the row/column
         factors.
     """
-    assert isinstance(sp_input, ops.SparseTensor)
+    assert isinstance(sp_input, tf.SparseTensor)
 
     if update_row_factors:
       left = self._row_factors
@@ -18,8 +18,6 @@ py_library(
         "__init__.py",
         "python/framework/__init__.py",
         "python/framework/checkpoint_utils.py",
-        "python/framework/decorator_utils.py",
-        "python/framework/deprecation.py",
         "python/framework/experimental.py",
         "python/framework/tensor_util.py",
         "python/ops/__init__.py",
@@ -102,20 +100,6 @@ py_test(
     deps = ["//tensorflow:tensorflow_py"],
 )
 
-py_test(
-    name = "deprecation_test",
-    srcs = ["python/framework/deprecation_test.py"],
-    srcs_version = "PY2AND3",
-    deps = ["//tensorflow:tensorflow_py"],
-)
-
-py_test(
-    name = "decorator_utils_test",
-    srcs = ["python/framework/decorator_utils_test.py"],
-    srcs_version = "PY2AND3",
-    deps = ["//tensorflow:tensorflow_py"],
-)
-
 py_test(
     name = "experimental_test",
     srcs = ["python/framework/experimental_test.py"],
@@ -135,6 +119,7 @@ py_test(
     size = "small",
     srcs = ["python/ops/variables_test.py"],
     srcs_version = "PY2AND3",
+    tags = ["manual"],
     deps = ["//tensorflow:tensorflow_py"],
 )
 
@@ -19,10 +19,10 @@ from __future__ import division
 from __future__ import print_function
 
 # pylint: disable=wildcard-import
-from tensorflow.contrib.framework.python.framework import decorator_utils
 from tensorflow.contrib.framework.python.framework.checkpoint_utils import *
-from tensorflow.contrib.framework.python.framework.deprecation import deprecated
-from tensorflow.contrib.framework.python.framework.deprecation import deprecated_arg_values
-from tensorflow.contrib.framework.python.framework.deprecation import deprecated_args
 from tensorflow.contrib.framework.python.framework.experimental import experimental
 from tensorflow.contrib.framework.python.framework.tensor_util import *
+from tensorflow.python.util import decorator_utils
+from tensorflow.python.util.deprecation import deprecated
+from tensorflow.python.util.deprecation import deprecated_arg_values
+from tensorflow.python.util.deprecation import deprecated_args
@@ -20,8 +20,8 @@ from __future__ import print_function
 
 import functools
 
-from tensorflow.contrib.framework.python.framework import decorator_utils
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import decorator_utils
 
 
 def _add_experimental_function_notice_to_docstring(doc):
@@ -20,6 +20,7 @@ from __future__ import print_function
 import numpy as np
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
@@ -283,7 +284,7 @@ def is_tensor(x):
   Returns:
     `True` if `x` is a tensor, `False` if not.
   """
-  tensor_types = (ops.Tensor, ops.SparseTensor, variables.Variable)
+  tensor_types = (ops.Tensor, sparse_tensor.SparseTensor, variables.Variable)
   return isinstance(x, tensor_types)
 
 
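A quick illustration of what `is_tensor` accepts after this change, written against the public r0.11-era names (illustrative only; the exported alias is assumed to be `tf.contrib.framework.is_tensor`):

```python
import tensorflow as tf

is_tensor = tf.contrib.framework.is_tensor
print(is_tensor(tf.constant([1, 2])))                       # True: ops.Tensor
print(is_tensor(tf.Variable([1, 2])))                       # True: Variable
print(is_tensor(tf.SparseTensor([[0, 0]], [1.0], [2, 2])))  # True: SparseTensor
print(is_tensor([1, 2]))                                    # False: plain list
```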
@@ -303,7 +304,7 @@ def with_shape(expected_shape, tensor):
   Raises:
     ValueError: if tensor has an invalid shape.
   """
-  if isinstance(tensor, ops.SparseTensor):
+  if isinstance(tensor, sparse_tensor.SparseTensor):
     raise ValueError('SparseTensor not supported.')
 
   # Shape type must be 1D int32.
@@ -376,9 +377,9 @@ def convert_to_tensor_or_sparse_tensor(
   """
   if dtype is not None:
     dtype = dtypes.as_dtype(dtype)
-  if isinstance(value, ops.SparseTensorValue):
-    value = ops.SparseTensor.from_value(value)
-  if isinstance(value, ops.SparseTensor):
+  if isinstance(value, sparse_tensor.SparseTensorValue):
+    value = sparse_tensor.SparseTensor.from_value(value)
+  if isinstance(value, sparse_tensor.SparseTensor):
     if dtype and not dtype.is_compatible_with(value.dtype):
       raise RuntimeError(
           'Sparse dtype: requested = %s, actual = %s' % (
@@ -22,6 +22,7 @@ from __future__ import print_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
@@ -43,7 +44,7 @@ def _get_tensor_repr(t,
   if print_tensor_type:
     if isinstance(t, ops.Tensor):
       t_type_str = "Type: Tensor ({})".format(t.dtype.name)
-    elif isinstance(t, ops.SparseTensor):
+    elif isinstance(t, sparse_tensor.SparseTensor):
       t_type_str = "Type: SparseTensor ({})".format(t.dtype.name)
     elif isinstance(t, tensor_array_ops.TensorArray):
       t_type_str = "Type: TensorArray ({})".format(t.dtype.name)
@@ -51,7 +52,7 @@ def _get_tensor_repr(t,
     tensor_list.append(constant_op.constant(t_type_str))
 
   if print_shape:
-    if isinstance(t, ops.SparseTensor):
+    if isinstance(t, sparse_tensor.SparseTensor):
       tensor_list.append(constant_op.constant("Shape:"))
       tensor_list.append(t.shape)
     elif isinstance(t, ops.Tensor):
@@ -66,7 +67,7 @@ def _get_tensor_repr(t,
     tensor_list.append(constant_op.constant("First True in Boolean tensor at:"))
     tensor_list.append(math_ops.argmax(int_tensor, 0))
 
-  if isinstance(t, ops.SparseTensor):
+  if isinstance(t, sparse_tensor.SparseTensor):
     tensor_list.append(constant_op.constant("Sparse indices:"))
     tensor_list.append(t.indices)
     tensor_list.append(constant_op.constant("Sparse values:"))
@@ -137,13 +138,13 @@ def print_op(input_,
   if isinstance(input_, ops.Tensor):
     input_ = logging_ops.Print(input_, tensor_list, message, first_n, summarize,
                                name)
-  elif isinstance(input_, ops.SparseTensor):
+  elif isinstance(input_, sparse_tensor.SparseTensor):
     p = logging_ops.Print(
         constant_op.constant([]), tensor_list, message, first_n, summarize,
         name)
 
     with ops.control_dependencies([p]):
-      input_ = ops.SparseTensor(array_ops.identity(input_.indices),
-                                array_ops.identity(input_.values),
-                                array_ops.identity(input_.shape))
+      input_ = sparse_tensor.SparseTensor(array_ops.identity(input_.indices),
+                                          array_ops.identity(input_.values),
+                                          array_ops.identity(input_.shape))
   elif isinstance(input_, tensor_array_ops.TensorArray):
@@ -36,7 +36,7 @@ class LocalVariableTest(tf.test.TestCase):
       variables = tf.local_variables()
       self.assertEquals(2, len(variables))
       self.assertRaises(tf.OpError, sess.run, variables)
-      tf.initialize_variables(variables).run()
+      tf.variables_initializer(variables).run()
       self.assertAllEqual(set([value0, value1]), set(sess.run(variables)))
 
   def testLocalVariableNameAndShape(self):
@@ -51,7 +51,7 @@ class LocalVariableTest(tf.test.TestCase):
     with self.test_session():
       with tf.variable_scope('A'):
         a = tf.contrib.framework.local_variable(0)
-        self.assertFalse(a in tf.all_variables())
+        self.assertFalse(a in tf.global_variables())
         self.assertTrue(a in tf.local_variables())
 
   def testLocalVariableNotInVariablesToRestore(self):
@@ -82,7 +82,7 @@ class LocalVariableTest(tf.test.TestCase):
   def testInitializedVariableValue(self):
     with self.test_session() as sess:
       a = tf.contrib.framework.local_variable([0, 0, 0, 0, 0], name='a')
-      sess.run(tf.initialize_local_variables())
+      sess.run(tf.local_variables_initializer())
       self.assertAllEqual(a.eval(), [0]*5)
 
 
@@ -439,7 +439,7 @@ class ModelVariablesTest(tf.test.TestCase):
     with self.test_session():
       with tf.variable_scope('A'):
         a = tf.contrib.framework.model_variable('a', [5])
-        self.assertTrue(a in tf.all_variables())
+        self.assertTrue(a in tf.global_variables())
         self.assertTrue(a in tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
         self.assertFalse(a in tf.local_variables())
 
@@ -474,7 +474,7 @@ class ModelVariablesTest(tf.test.TestCase):
     with self.test_session() as sess:
       a = tf.contrib.framework.model_variable(
           'a', [5], initializer=tf.ones_initializer)
-      sess.run(tf.initialize_all_variables())
+      sess.run(tf.global_variables_initializer())
       self.assertAllEqual(a.eval(), [1]*5)
 
   def testDeviceFn(self):
@@ -667,7 +667,7 @@ class AssignFromValuesTest(tf.test.TestCase):
           var_names_to_values)
 
       # Initialize the variables.
-      sess.run(tf.initialize_all_variables())
+      sess.run(tf.global_variables_initializer())
 
       # Perform the assignment.
       sess.run(assign_op, feed_dict)
@@ -697,7 +697,7 @@ class AssignFromValuesTest(tf.test.TestCase):
           var_names_to_values)
 
       # Initialize the variables.
-      sess.run(tf.initialize_all_variables())
+      sess.run(tf.global_variables_initializer())
 
       # Perform the assignment.
       sess.run(assign_op, feed_dict)
@@ -725,7 +725,7 @@ class AssignFromValuesFnTest(tf.test.TestCase):
       init_fn = tf.contrib.framework.assign_from_values_fn(var_names_to_values)
 
       # Initialize the variables.
-      sess.run(tf.initialize_all_variables())
+      sess.run(tf.global_variables_initializer())
 
       # Perform the assignment.
       init_fn(sess)
@@ -754,7 +754,7 @@ class AssignFromValuesFnTest(tf.test.TestCase):
       init_fn = tf.contrib.framework.assign_from_values_fn(var_names_to_values)
 
       # Initialize the variables.
-      sess.run(tf.initialize_all_variables())
+      sess.run(tf.global_variables_initializer())
 
       # Perform the assignment.
       init_fn(sess)
@@ -786,7 +786,7 @@ class AssignFromCheckpointTest(tf.test.TestCase):
         var_value = var_names_to_values[var_name]
         var_list.append(tf.Variable(var_value, name=var_name))
       saver = tf.train.Saver(var_list)
-      init_op = tf.initialize_variables(var_list)
+      init_op = tf.variables_initializer(var_list)
       sess.run(init_op)
       # Save the initialized values in the file at 'checkpoint_dir'
       return saver.save(sess, checkpoint_dir, global_step=global_step)
@@ -808,7 +808,7 @@ class AssignFromCheckpointTest(tf.test.TestCase):
           model_path, vars_to_restore)
 
       # Initialize the variables.
-      sess.run(tf.initialize_all_variables())
+      sess.run(tf.global_variables_initializer())
 
       # Perform the assignment.
       sess.run(op, feed_dict)
@@ -859,7 +859,7 @@ class AssignFromCheckpointTest(tf.test.TestCase):
           vars_to_restore)
 
       # Initialize the variables.
-      sess.run(tf.initialize_all_variables())
+      sess.run(tf.global_variables_initializer())
 
       # Perform the assignment.
       sess.run(op, feed_dict)
@@ -890,7 +890,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
         var_value = var_names_to_values[var_name]
         var_list.append(tf.Variable(var_value, name=var_name))
       saver = tf.train.Saver(var_list)
-      init_op = tf.initialize_variables(var_list)
+      init_op = tf.variables_initializer(var_list)
       sess.run(init_op)
       # Save the initialized values in the file at 'checkpoint_dir'
       return saver.save(sess, checkpoint_dir, global_step=global_step)
@@ -912,7 +912,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
           model_path, vars_to_restore)
 
       # Initialize the variables.
-      sess.run(tf.initialize_all_variables())
+      sess.run(tf.global_variables_initializer())
 
       # Perform the assignment.
       init_fn(sess)
@@ -938,7 +938,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
           model_path, vars_to_restore)
 
       # Initialize the variables.
-      sess.run(tf.initialize_all_variables())
+      sess.run(tf.global_variables_initializer())
 
       # Perform the assignment.
       with self.assertRaises(tf.errors.InvalidArgumentError):
@@ -961,7 +961,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
           model_path, vars_to_restore, reshape_variables=True)
 
       # Initialize the variables.
-      sess.run(tf.initialize_all_variables())
+      sess.run(tf.global_variables_initializer())
 
       # Perform the assignment.
       init_fn(sess)
@@ -989,7 +989,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
           vars_to_restore)
 
       # Initialize the variables.
-      sess.run(tf.initialize_all_variables())
+      sess.run(tf.global_variables_initializer())
 
       # Perform the assignment.
       with self.assertRaises(tf.errors.NotFoundError):
@@ -1015,7 +1015,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
           ignore_missing_vars=True)
 
       # Initialize the variables.
-      sess.run(tf.initialize_all_variables())
+      sess.run(tf.global_variables_initializer())
 
       # Perform the assignment.
       init_fn(sess)
@@ -1044,7 +1044,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
           ignore_missing_vars=True)
 
       # Initialize the variables.
-      sess.run(tf.initialize_all_variables())
+      sess.run(tf.global_variables_initializer())
 
       # Perform the assignment.
       init_fn(sess)
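All of these test edits follow the same batch of TensorFlow initializer/collection renames. A minimal sketch of the old-to-new mapping, assuming the graph-mode `tf.Session` workflow used throughout this file:

```python
import tensorflow as tf

v = tf.Variable(tf.zeros([3]))

with tf.Session() as sess:
  # Old name                        ->  new name
  # tf.initialize_all_variables()   ->  tf.global_variables_initializer()
  # tf.initialize_local_variables() ->  tf.local_variables_initializer()
  # tf.initialize_variables([v])    ->  tf.variables_initializer([v])
  # tf.all_variables()              ->  tf.global_variables()
  sess.run(tf.global_variables_initializer())
  print(sess.run(v))  # [ 0.  0.  0.]
```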
tensorflow/contrib/integrate/BUILD (new file, 38 lines)
@@ -0,0 +1,38 @@
# Description:
#   Integration and ODE solvers for TensorFlow.

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

package(default_visibility = ["//tensorflow:__subpackages__"])

py_library(
    name = "integrate_py",
    srcs = [
        "__init__.py",
        "python/ops/odes.py",
    ],
    srcs_version = "PY2AND3",
)

py_test(
    name = "odes_test",
    srcs = ["python/ops/odes_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":integrate_py",
        "//tensorflow:tensorflow_py",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
)
tensorflow/contrib/integrate/README.md (new file, 9 lines)
@@ -0,0 +1,9 @@
# Integration and ODE solvers for TensorFlow

TensorFlow equivalents to the routines provided by `scipy.integrate`. Currently
contains a single function, `odeint`, for integrating ordinary differential
equations.

Maintainers:
- Stephan Hoyer (shoyer@google.com, github.com/shoyer)
- Marc Coram (mcoram@google.com, github.com/mcoram)
tensorflow/contrib/integrate/__init__.py (new file, 64 lines)
@@ -0,0 +1,64 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Integration and ODE solvers for TensorFlow.

## Example: Lorenz attractor

We can use `odeint` to solve the
[Lorenz system](https://en.wikipedia.org/wiki/Lorenz_system) of ordinary
differential equations, a prototypical example of chaotic dynamics:

```python
rho = 28.0
sigma = 10.0
beta = 8.0/3.0

def lorenz_equation(state, t):
  x, y, z = tf.unpack(state)
  dx = sigma * (y - x)
  dy = x * (rho - z) - y
  dz = x * y - beta * z
  return tf.pack([dx, dy, dz])

init_state = tf.constant([0, 2, 20], dtype=tf.float64)
t = np.linspace(0, 50, num=5000)
tensor_state, tensor_info = tf.contrib.integrate.odeint(
    lorenz_equation, init_state, t, full_output=True)

sess = tf.Session()
state, info = sess.run([tensor_state, tensor_info])
x, y, z = state.T
plt.plot(x, z)
```

<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../../images/lorenz_attractor.png" alt>
</div>

## Ops

@@odeint
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=wildcard-import
from tensorflow.contrib.integrate.python.ops.odes import *
from tensorflow.python.util.all_util import make_all

__all__ = make_all(__name__)
tensorflow/contrib/integrate/python/ops/odes.py (new file, 503 lines)
@@ -0,0 +1,503 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""ODE solvers for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops


_ButcherTableau = collections.namedtuple(
    '_ButcherTableau', 'alpha beta c_sol c_mid c_error')

# Parameters from Shampine (1986), section 4.
_DORMAND_PRINCE_TABLEAU = _ButcherTableau(
    alpha=[1/5, 3/10, 4/5, 8/9, 1., 1.],
    beta=[[1/5],
          [3/40, 9/40],
          [44/45, -56/15, 32/9],
          [19372/6561, -25360/2187, 64448/6561, -212/729],
          [9017/3168, -355/33, 46732/5247, 49/176, -5103/18656],
          [35/384, 0, 500/1113, 125/192, -2187/6784, 11/84]],
    c_sol=[35/384, 0, 500/1113, 125/192, -2187/6784, 11/84, 0],
    c_mid=[6025192743/30085553152 / 2, 0, 51252292925/65400821598 / 2,
           -2691868925/45128329728 / 2, 187940372067/1594534317056 / 2,
           -1776094331/19743644256 / 2, 11237099/235043384 / 2],
    c_error=[1951/21600 - 35/384,
             0,
             22642/50085 - 500/1113,
             451/720 - 125/192,
             -12231/42400 - -2187/6784,
             649/6300 - 11/84,
             1/60],
)
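The tableau can be sanity-checked numerically: each `beta` row should sum to its `alpha` node, the solution weights `c_sol` should sum to 1, and the error weights `c_error` (a difference of two order-consistent weight sets) should sum to 0. A throwaway check along these lines (scratch code, not part of the file):

```python
import numpy as np

tab = _DORMAND_PRINCE_TABLEAU
for alpha_i, beta_i in zip(tab.alpha, tab.beta):
  assert np.isclose(sum(beta_i), alpha_i)   # RK consistency condition
assert np.isclose(sum(tab.c_sol), 1.0)      # solution weights sum to 1
assert np.isclose(sum(tab.c_error), 0.0)    # error weights sum to 0
```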
def _possibly_nonzero(x):
  return isinstance(x, ops.Tensor) or x != 0


def _scaled_dot_product(scale, xs, ys, name=None):
  """Calculate a scaled, vector inner product between lists of Tensors."""
  with ops.name_scope(name, 'scaled_dot_product', [scale, xs, ys]) as scope:
    # Some of the parameters in our Butcher tableau include zeros. Using
    # _possibly_nonzero lets us avoid wasted computation.
    return math_ops.add_n([(scale * x) * y for x, y in zip(xs, ys)
                           if _possibly_nonzero(x) or _possibly_nonzero(y)],
                          name=scope)


def _dot_product(xs, ys, name=None):
  """Calculate the vector inner product between two lists of Tensors."""
  with ops.name_scope(name, 'dot_product', [xs, ys]) as scope:
    return math_ops.add_n([x * y for x, y in zip(xs, ys)], name=scope)


def _runge_kutta_step(func, y0, f0, t0, dt, tableau=_DORMAND_PRINCE_TABLEAU,
                      name=None):
  """Take an arbitrary Runge-Kutta step and estimate error.

  Args:
    func: Function to evaluate like `func(y, t)` to compute the time derivative
      of `y`.
    y0: Tensor initial value for the state.
    f0: Tensor initial value for the derivative, computed from `func(y0, t0)`.
    t0: float64 scalar Tensor giving the initial time.
    dt: float64 scalar Tensor giving the size of the desired time step.
    tableau: optional _ButcherTableau describing how to take the Runge-Kutta
      step.
    name: optional name for the operation.

  Returns:
    Tuple `(y1, f1, y1_error, k)` giving the estimated function value after
    the Runge-Kutta step at `t1 = t0 + dt`, the derivative of the state at `t1`,
    estimated error at `t1`, and a list of Runge-Kutta coefficients `k` used for
    calculating these terms.
  """
  with ops.name_scope(name, 'runge_kutta_step', [y0, f0, t0, dt]) as scope:
    y0 = ops.convert_to_tensor(y0, name='y0')
    f0 = ops.convert_to_tensor(f0, name='f0')
    t0 = ops.convert_to_tensor(t0, name='t0')
    dt = ops.convert_to_tensor(dt, name='dt')
    dt_cast = math_ops.cast(dt, y0.dtype)

    k = [f0]
    for alpha_i, beta_i in zip(tableau.alpha, tableau.beta):
      ti = t0 + alpha_i * dt
      yi = y0 + _scaled_dot_product(dt_cast, beta_i, k)
      k.append(func(yi, ti))

    if not (tableau.c_sol[-1] == 0 and tableau.c_sol[:-1] == tableau.beta[-1]):
      # This property (true for Dormand-Prince) lets us save a few FLOPs.
      yi = y0 + _scaled_dot_product(dt_cast, tableau.c_sol, k)

    y1 = array_ops.identity(yi, name='%s/y1' % scope)
    f1 = array_ops.identity(k[-1], name='%s/f1' % scope)
    y1_error = _scaled_dot_product(dt_cast, tableau.c_error, k,
                                   name='%s/y1_error' % scope)
    return (y1, f1, y1_error, k)
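Note the "first same as last" (FSAL) property checked above: for Dormand-Prince, `c_sol[-1] == 0` and `c_sol[:-1] == beta[-1]`, so the final stage is evaluated at `(y1, t1)` and `k[-1]` can seed the next step with no extra call to `func`. A hedged NumPy re-statement of the same stepping logic, for the scalar ODE `dy/dt = -y` (scratch code, not part of the file):

```python
import numpy as np

def rk_step_np(f, y0, f0, t0, dt, tab=_DORMAND_PRINCE_TABLEAU):
  # Scratch re-implementation of _runge_kutta_step for plain floats.
  k = [f0]
  for alpha_i, beta_i in zip(tab.alpha, tab.beta):
    yi = y0 + dt * sum(b * ki for b, ki in zip(beta_i, k))
    k.append(f(yi, t0 + alpha_i * dt))
  y1 = y0 + dt * sum(c * ki for c, ki in zip(tab.c_sol, k))
  y1_error = dt * sum(c * ki for c, ki in zip(tab.c_error, k))
  return y1, k[-1], y1_error  # FSAL: k[-1] is f(y1, t1)

f = lambda y, t: -y
y1, f1, err = rk_step_np(f, 1.0, f(1.0, 0.0), 0.0, 0.1)
print(y1, np.exp(-0.1))  # both ~0.904837; err is tiny
```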
def _interp_fit(y0, y1, y_mid, f0, f1, dt):
  """Fit coefficients for 4th order polynomial interpolation.

  Args:
    y0: function value at the start of the interval.
    y1: function value at the end of the interval.
    y_mid: function value at the mid-point of the interval.
    f0: derivative value at the start of the interval.
    f1: derivative value at the end of the interval.
    dt: width of the interval.

  Returns:
    List of coefficients `[a, b, c, d, e]` for interpolating with the polynomial
    `p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e` for values of `x`
    between 0 (start of interval) and 1 (end of interval).
  """
  # a, b, c, d, e = sympy.symbols('a b c d e')
  # x, dt, y0, y1, y_mid, f0, f1 = sympy.symbols('x dt y0 y1 y_mid f0 f1')
  # p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e
  # sympy.solve([p.subs(x, 0) - y0,
  #              p.subs(x, 1 / 2) - y_mid,
  #              p.subs(x, 1) - y1,
  #              (p.diff(x) / dt).subs(x, 0) - f0,
  #              (p.diff(x) / dt).subs(x, 1) - f1],
  #             [a, b, c, d, e])
  # {a: -2.0*dt*f0 + 2.0*dt*f1 - 8.0*y0 - 8.0*y1 + 16.0*y_mid,
  #  b: 5.0*dt*f0 - 3.0*dt*f1 + 18.0*y0 + 14.0*y1 - 32.0*y_mid,
  #  c: -4.0*dt*f0 + dt*f1 - 11.0*y0 - 5.0*y1 + 16.0*y_mid,
  #  d: dt*f0,
  #  e: y0}
  a = _dot_product([-2 * dt, 2 * dt, -8, -8, 16], [f0, f1, y0, y1, y_mid])
  b = _dot_product([5 * dt, -3 * dt, 18, 14, -32], [f0, f1, y0, y1, y_mid])
  c = _dot_product([-4 * dt, dt, -11, -5, 16], [f0, f1, y0, y1, y_mid])
  d = dt * f0
  e = y0
  return [a, b, c, d, e]
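The fitted quartic can be checked directly against its five defining constraints: `p(0)=y0`, `p(1/2)=y_mid`, `p(1)=y1`, `p'(0)/dt=f0`, `p'(1)/dt=f1`. A scratch NumPy verification with assumed values (the coefficient formulas restated outside TensorFlow):

```python
import numpy as np

y0, y1, y_mid, f0, f1, dt = 1.0, 2.0, 1.4, 0.3, 0.7, 0.5
vals = np.array([f0, f1, y0, y1, y_mid])
a = np.dot([-2 * dt, 2 * dt, -8, -8, 16], vals)
b = np.dot([5 * dt, -3 * dt, 18, 14, -32], vals)
c = np.dot([-4 * dt, dt, -11, -5, 16], vals)
d, e = dt * f0, y0

p = np.polynomial.polynomial.Polynomial([e, d, c, b, a])
assert np.isclose(p(0.0), y0) and np.isclose(p(1.0), y1)  # endpoint values
assert np.isclose(p(0.5), y_mid)                          # midpoint value
assert np.isclose(p.deriv()(0.0) / dt, f0)                # endpoint slopes
assert np.isclose(p.deriv()(1.0) / dt, f1)
```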
def _interp_fit_rk(y0, y1, k, dt, tableau=_DORMAND_PRINCE_TABLEAU):
  """Fit an interpolating polynomial to the results of a Runge-Kutta step."""
  with ops.name_scope('interp_fit_rk'):
    dt = math_ops.cast(dt, y0.dtype)
    y_mid = y0 + _scaled_dot_product(dt, tableau.c_mid, k)
    f0 = k[0]
    f1 = k[-1]
    return _interp_fit(y0, y1, y_mid, f0, f1, dt)


def _interp_evaluate(coefficients, t0, t1, t):
  """Evaluate polynomial interpolation at the given time point.

  Args:
    coefficients: list of Tensor coefficients as created by `interp_fit`.
    t0: scalar float64 Tensor giving the start of the interval.
    t1: scalar float64 Tensor giving the end of the interval.
    t: scalar float64 Tensor giving the desired interpolation point.

  Returns:
    Polynomial interpolation of the coefficients at time `t`.
  """
  with ops.name_scope('interp_evaluate'):
    t0 = ops.convert_to_tensor(t0)
    t1 = ops.convert_to_tensor(t1)
    t = ops.convert_to_tensor(t)

    dtype = coefficients[0].dtype

    assert_op = control_flow_ops.Assert(
        (t0 <= t) & (t <= t1),
        ['invalid interpolation, fails `t0 <= t <= t1`:', t0, t, t1])
    with ops.control_dependencies([assert_op]):
      x = math_ops.cast((t - t0) / (t1 - t0), dtype)

    xs = [constant_op.constant(1, dtype), x]
    for _ in range(2, len(coefficients)):
      xs.append(xs[-1] * x)

    return _dot_product(coefficients, reversed(xs))


def _optimal_step_size(last_step,
                       error_ratio,
                       safety=0.9,
                       ifactor=10.0,
                       dfactor=0.2,
                       order=5,
                       name=None):
  """Calculate the optimal size for the next Runge-Kutta step."""
  with ops.name_scope(
      name, 'optimal_step_size', [last_step, error_ratio]) as scope:
    error_ratio = math_ops.cast(error_ratio, last_step.dtype)
    exponent = math_ops.cast(1 / order, last_step.dtype)
    # This looks more complex than necessary, but importantly it keeps
    # error_ratio in the numerator so we can't divide by zero:
    factor = math_ops.maximum(
        1 / ifactor,
        math_ops.minimum(error_ratio ** exponent / safety, 1 / dfactor))
    return math_ops.div(last_step, factor, name=scope)
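Plugging numbers into the step-size rule makes its behaviour concrete: error ratios below 1 grow the next step (growth capped by `ifactor`), ratios above 1 shrink it (shrinkage capped by `dfactor`). A float-only sketch of the same clipping rule, with hypothetical inputs:

```python
def next_step(last_step, error_ratio,
              safety=0.9, ifactor=10.0, dfactor=0.2, order=5):
  factor = max(1.0 / ifactor,
               min(error_ratio ** (1.0 / order) / safety, 1.0 / dfactor))
  return last_step / factor

print(next_step(0.1, 1e-10))  # 1.0  : growth clipped at ifactor=10
print(next_step(0.1, 1.0))    # 0.09 : borderline step shrinks slightly
print(next_step(0.1, 1e10))   # 0.02 : shrinkage clipped at dfactor=0.2
```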
def _abs_square(x):
  if x.dtype.is_complex:
    return math_ops.square(math_ops.real(x)) + math_ops.square(math_ops.imag(x))
  else:
    return math_ops.square(x)


def _ta_append(tensor_array, value):
  """Append a value to the end of a tf.TensorArray."""
  return tensor_array.write(tensor_array.size(), value)


class _RungeKuttaState(collections.namedtuple(
    '_RungeKuttaState', 'y1, f1, t0, t1, dt, interp_coeff')):
  """Saved state of the Runge-Kutta solver.

  Attributes:
    y1: Tensor giving the function value at the end of the last time step.
    f1: Tensor giving derivative at the end of the last time step.
    t0: scalar float64 Tensor giving start of the last time step.
    t1: scalar float64 Tensor giving end of the last time step.
    dt: scalar float64 Tensor giving the size for the next time step.
    interp_coeff: list of Tensors giving coefficients for polynomial
      interpolation between `t0` and `t1`.
  """


class _History(collections.namedtuple(
    '_History', 'integrate_points, error_ratio')):
  """Saved integration history for use in `info_dict`.

  Attributes:
    integrate_points: tf.TensorArray storing integrating time points.
    error_ratio: tf.TensorArray storing computed error ratios at each
      integration step.
  """


def _dopri5(func,
            y0,
            t,
            rtol,
            atol,
            full_output=False,
            first_step=None,
            safety=0.9,
            ifactor=10.0,
            dfactor=0.2,
            max_num_steps=1000,
            name=None):
  """Solve an ODE for `odeint` using method='dopri5'."""

  if first_step is None:
    # at some point, we might want to switch to picking the step size
    # automatically
    first_step = 1.0

  with ops.name_scope(
      name, 'dopri5',
      [y0, t, rtol, atol, safety, ifactor, dfactor, max_num_steps]) as scope:

    first_step = ops.convert_to_tensor(first_step, dtype=t.dtype,
                                       name='first_step')
    safety = ops.convert_to_tensor(safety, dtype=t.dtype, name='safety')
    ifactor = ops.convert_to_tensor(ifactor, dtype=t.dtype, name='ifactor')
    dfactor = ops.convert_to_tensor(dfactor, dtype=t.dtype, name='dfactor')
    max_num_steps = ops.convert_to_tensor(max_num_steps, dtype=dtypes.int32,
                                          name='max_num_steps')

    def adaptive_runge_kutta_step(rk_state, history, n_steps):
      """Take an adaptive Runge-Kutta step to integrate the ODE."""
      y0, f0, _, t0, dt, interp_coeff = rk_state
      with ops.name_scope('assertions'):
        check_underflow = control_flow_ops.Assert(
            t0 + dt > t0, ['underflow in dt', dt])
        check_max_num_steps = control_flow_ops.Assert(
            n_steps < max_num_steps, ['max_num_steps exceeded'])
        check_numerics = control_flow_ops.Assert(
            math_ops.reduce_all(math_ops.is_finite(abs(y0))),
            ['non-finite values in state `y`', y0])
      with ops.control_dependencies(
          [check_underflow, check_max_num_steps, check_numerics]):
        y1, f1, y1_error, k = _runge_kutta_step(func, y0, f0, t0, dt)

      with ops.name_scope('error_ratio'):
        # We use the same approach as the dopri5 fortran code.
        error_tol = atol + rtol * math_ops.maximum(abs(y0), abs(y1))
        tensor_error_ratio = _abs_square(y1_error) / _abs_square(error_tol)
        # Could also use reduce_maximum here.
        error_ratio = math_ops.sqrt(math_ops.reduce_mean(tensor_error_ratio))
        accept_step = error_ratio <= 1

      with ops.name_scope('update/rk_state'):
        # If we don't accept the step, the _RungeKuttaState will be useless
        # (covering a time-interval of size 0), but that's OK, because in such
        # cases we always immediately take another Runge-Kutta step.
        y_next = control_flow_ops.cond(accept_step, lambda: y1, lambda: y0)
        f_next = control_flow_ops.cond(accept_step, lambda: f1, lambda: f0)
        t_next = control_flow_ops.cond(accept_step, lambda: t0 + dt, lambda: t0)
        interp_coeff = control_flow_ops.cond(
            accept_step,
            lambda: _interp_fit_rk(y0, y1, k, dt),
            lambda: interp_coeff)
        dt_next = _optimal_step_size(dt, error_ratio, safety, ifactor, dfactor)
        rk_state = _RungeKuttaState(
            y_next, f_next, t0, t_next, dt_next, interp_coeff)

      with ops.name_scope('update/history'):
        history = _History(_ta_append(history.integrate_points, t0 + dt),
                           _ta_append(history.error_ratio, error_ratio))
      return rk_state, history, n_steps + 1

    def interpolate(solution, history, rk_state, i):
      """Interpolate through the next time point, integrating as necessary."""
      with ops.name_scope('interpolate'):
        rk_state, history, _ = control_flow_ops.while_loop(
            lambda rk_state, *_: t[i] > rk_state.t1,
            adaptive_runge_kutta_step,
            (rk_state, history, 0),
            name='integrate_loop')
        y = _interp_evaluate(
            rk_state.interp_coeff, rk_state.t0, rk_state.t1, t[i])
        solution = solution.write(i, y)
        return solution, history, rk_state, i + 1

    assert_increasing = control_flow_ops.Assert(
        math_ops.reduce_all(t[1:] > t[:-1]),
        ['`t` must be monotonic increasing'])
    with ops.control_dependencies([assert_increasing]):
      num_times = array_ops.size(t)

    solution = tensor_array_ops.TensorArray(
        y0.dtype, size=num_times).write(0, y0)
    history = _History(
        integrate_points=tensor_array_ops.TensorArray(
            t.dtype, size=0, dynamic_size=True),
        error_ratio=tensor_array_ops.TensorArray(
            rtol.dtype, size=0, dynamic_size=True))
    rk_state = _RungeKuttaState(
        y0, func(y0, t[0]), t[0], t[0], first_step, interp_coeff=[y0] * 5)

    solution, history, _, _ = control_flow_ops.while_loop(
        lambda _, __, ___, i: i < num_times,
        interpolate,
        (solution, history, rk_state, 1),
        name='interpolate_loop')

    y = solution.pack(name=scope)
    y.set_shape(t.get_shape().concatenate(y0.get_shape()))
    if not full_output:
      return y
    else:
      integrate_points = history.integrate_points.pack()
      info_dict = {'num_func_evals': 6 * array_ops.size(integrate_points) + 1,
                   'integrate_points': integrate_points,
                   'error_ratio': history.error_ratio.pack()}
      return (y, info_dict)
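The accept/reject decision above reduces, per element, to the mixed tolerance `error_tol = atol + rtol * max(|y0|, |y1|)`; a step is kept only when the RMS of `y1_error / error_tol` is at most 1. With hypothetical scalar numbers:

```python
import numpy as np

atol, rtol = 1e-12, 1e-6
y0, y1, y1_error = 1.0, 0.905, 5e-7
error_tol = atol + rtol * max(abs(y0), abs(y1))
error_ratio = np.sqrt(np.mean(np.square(y1_error / error_tol)))
accept_step = error_ratio <= 1
print(error_ratio, accept_step)  # ~0.5, True: keep the step, grow dt
```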
def odeint(func,
           y0,
           t,
           rtol=1e-6,
           atol=1e-12,
           method=None,
           options=None,
           full_output=False,
           name=None):
  """Integrate a system of ordinary differential equations.

  Solves the initial value problem for a non-stiff system of first-order ODEs:

    ```
    dy/dt = func(y, t), y(t[0]) = y0
    ```

  where y is a Tensor of any shape.

  For example:

    ```
    # solve `dy/dt = -y`, corresponding to exponential decay
    tf.contrib.integrate.odeint(lambda y, _: -y, 1.0, [0, 1, 2])
    => [1, exp(-1), exp(-2)]
    ```

  Output dtypes and numerical precision are based on the dtypes of the inputs
  `y0` and `t`.

  Currently, implements 5th order Runge-Kutta with adaptive step size control
  and dense output, using the Dormand-Prince method. Similar to the 'dopri5'
  method of `scipy.integrate.ode` and MATLAB's `ode45`.

  Based on: Shampine, Lawrence F. (1986), "Some Practical Runge-Kutta Formulas",
  Mathematics of Computation, American Mathematical Society, 46 (173): 135-150,
  doi:10.2307/2008219

  Args:
    func: Function that maps a Tensor holding the state `y` and a scalar Tensor
      `t` into a Tensor of state derivatives with respect to time.
    y0: N-D Tensor giving starting value of `y` at time point `t[0]`. May
      have any floating point or complex dtype.
    t: 1-D Tensor holding a sequence of time points for which to solve for
      `y`. The initial time point should be the first element of this sequence,
      and each time must be larger than the previous time. May have any floating
      point dtype. If not provided as a Tensor, converted to a Tensor with
      float64 dtype.
    rtol: optional float64 Tensor specifying an upper bound on relative error,
      per element of `y`.
    atol: optional float64 Tensor specifying an upper bound on absolute error,
      per element of `y`.
    method: optional string indicating the integration method to use. Currently,
      the only valid option is `'dopri5'`.
    options: optional dict of configuring options for the indicated integration
      method. Can only be provided if a `method` is explicitly set. For
      `'dopri5'`, valid options include:
      * first_step: an initial guess for the size of the first integration
        (current default: 1.0, but may later be changed to use heuristics based
        on the gradient).
      * safety: safety factor for adaptive step control, generally a constant
        in the range 0.8-1 (default: 0.9).
      * ifactor: maximum factor by which the adaptive step may be increased
        (default: 10.0).
      * dfactor: maximum factor by which the adaptive step may be decreased
        (default: 0.2).
      * max_num_steps: integer maximum number of integrate steps between time
        points in `t` (default: 1000).
    full_output: optional boolean. If True, `odeint` returns a tuple
      `(y, info_dict)` describing the integration process.
    name: Optional name for this operation.

  Returns:
    y: (N+1)-D tensor, where the first dimension corresponds to different
      time points. Contains the solved value of y for each desired time point in
      `t`, with the initial value `y0` being the first element along the first
      dimension.
    info_dict: only if `full_output == True`. A dict with the following values:
      * num_func_evals: integer Tensor counting the number of function
        evaluations.
      * integrate_points: 1D float64 Tensor with the upper bound of each
        integration time step.
      * error_ratio: 1D float Tensor with the estimated ratio of the integration
        error to the error tolerance at each integration step. A ratio greater
        than 1 corresponds to rejected steps.

  Raises:
    ValueError: if an invalid `method` is provided, or if `options` is supplied
      without `method`.
    TypeError: if `t` or `y0` has an invalid dtype.
  """
  if method is not None and method != 'dopri5':
    raise ValueError('invalid method: %r' % method)

  if options is None:
    options = {}
  elif method is None:
    raise ValueError('cannot supply `options` without specifying `method`')

  with ops.name_scope(name, 'odeint', [y0, t, rtol, atol]) as scope:
    # TODO(shoyer): use nest.flatten (like tf.while_loop) to allow `y0` to be an
    # arbitrarily nested tuple. This will help performance and usability by
    # avoiding the need to pack/unpack in user functions.
    y0 = ops.convert_to_tensor(y0, name='y0')
    if not (y0.dtype.is_floating or y0.dtype.is_complex):
      raise TypeError('`y0` must have a floating point or complex floating '
                      'point dtype')

    t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t')
    if not t.dtype.is_floating:
      raise TypeError('`t` must have a floating point dtype')

    error_dtype = abs(y0).dtype
    rtol = ops.convert_to_tensor(rtol, dtype=error_dtype, name='rtol')
    atol = ops.convert_to_tensor(atol, dtype=error_dtype, name='atol')

    return _dopri5(func, y0, t,
                   rtol=rtol,
                   atol=atol,
                   full_output=full_output,
                   name=scope,
                   **options)
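A short usage sketch of `odeint` with and without `full_output` (graph-mode API of this era; the printed values are approximate):

```python
import numpy as np
import tensorflow as tf

# Exponential decay dy/dt = -y, solved at four requested time points.
y = tf.contrib.integrate.odeint(lambda y, _: -y, 1.0, [0., 1., 2., 3.])
y_full, info = tf.contrib.integrate.odeint(
    lambda y, _: -y, 1.0, np.linspace(0., 3., 31), full_output=True)

with tf.Session() as sess:
  print(sess.run(y))  # ~[1.0, 0.3679, 0.1353, 0.0498]
  num_evals = sess.run(info['num_func_evals'])
  print(num_evals)    # total ODE function evaluations for the solve
```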
tensorflow/contrib/integrate/python/ops/odes_test.py (new file, 232 lines)
@@ -0,0 +1,232 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for ODE solvers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from tensorflow.contrib.integrate.python.ops import odes


class OdeIntTest(tf.test.TestCase):

  def setUp(self):
    super(OdeIntTest, self).setUp()
    # simple defaults (solution is a sin-wave)
    matrix = tf.constant([[0, 1], [-1, 0]], dtype=tf.float64)
    self.func = lambda y, t: tf.matmul(matrix, y)
    self.y0 = np.array([[1.0], [0.0]])

  def test_odeint_exp(self):
    # Test odeint by an exponential function:
    #   dy / dt = y, y(0) = 1.0.
    # Its analytical solution is y = exp(t).
    func = lambda y, t: y
    y0 = tf.constant(1.0, dtype=tf.float64)
    t = np.linspace(0.0, 1.0, 11)
    y_solved = tf.contrib.integrate.odeint(func, y0, t)
    self.assertIn('odeint', y_solved.name)
    self.assertEqual(y_solved.get_shape(), tf.TensorShape([11]))
    with self.test_session() as sess:
      y_solved = sess.run(y_solved)
    y_true = np.exp(t)
    self.assertAllClose(y_true, y_solved)

  def test_odeint_complex(self):
    # Test a complex, linear ODE:
    #   dy / dt = k * y, y(0) = 1.0.
    # Its analytical solution is y = exp(k * t).
    k = 1j - 0.1
    func = lambda y, t: k * y
    t = np.linspace(0.0, 1.0, 11)
    y_solved = tf.contrib.integrate.odeint(func, 1.0 + 0.0j, t)
    with self.test_session() as sess:
      y_solved = sess.run(y_solved)
    y_true = np.exp(k * t)
    self.assertAllClose(y_true, y_solved)

  def test_odeint_riccati(self):
    # The Riccati equation is:
    #   dy / dt = (y - t) ** 2 + 1.0, y(0) = 0.5.
    # Its analytical solution is y = 1.0 / (2.0 - t) + t.
    func = lambda t, y: (y - t)**2 + 1.0
    t = np.linspace(0.0, 1.0, 11)
    y_solved = tf.contrib.integrate.odeint(func, np.float64(0.5), t)
    with self.test_session() as sess:
      y_solved = sess.run(y_solved)
    y_true = 1.0 / (2.0 - t) + t
    self.assertAllClose(y_true, y_solved)

  def test_odeint_2d_linear(self):
    # Solve the 2D linear differential equation:
    #   dy1 / dt = 3.0 * y1 + 4.0 * y2,
    #   dy2 / dt = -4.0 * y1 + 3.0 * y2,
    #   y1(0) = 0.0,
    #   y2(0) = 1.0.
    # Its analytical solution is
    #   y1 = sin(4.0 * t) * exp(3.0 * t),
    #   y2 = cos(4.0 * t) * exp(3.0 * t).
    matrix = tf.constant([[3.0, 4.0], [-4.0, 3.0]], dtype=tf.float64)
    func = lambda y, t: tf.matmul(matrix, y)

    y0 = tf.constant([[0.0], [1.0]], dtype=tf.float64)
    t = np.linspace(0.0, 1.0, 11)

    y_solved = tf.contrib.integrate.odeint(func, y0, t)
    with self.test_session() as sess:
      y_solved = sess.run(y_solved)

    y_true = np.zeros((len(t), 2, 1))
    y_true[:, 0, 0] = np.sin(4.0 * t) * np.exp(3.0 * t)
    y_true[:, 1, 0] = np.cos(4.0 * t) * np.exp(3.0 * t)
    self.assertAllClose(y_true, y_solved, atol=1e-5)

  def test_odeint_higher_rank(self):
    func = lambda y, t: y
    y0 = tf.constant(1.0, dtype=tf.float64)
    t = np.linspace(0.0, 1.0, 11)
    for shape in [(), (1,), (1, 1)]:
      expected_shape = (len(t),) + shape
      y_solved = tf.contrib.integrate.odeint(func, tf.reshape(y0, shape), t)
      self.assertEqual(y_solved.get_shape(), tf.TensorShape(expected_shape))
      with self.test_session() as sess:
        y_solved = sess.run(y_solved)
      self.assertEquals(y_solved.shape, expected_shape)

  def test_odeint_all_dtypes(self):
    func = lambda y, t: y
    t = np.linspace(0.0, 1.0, 11)
    for y0_dtype in [tf.float32, tf.float64, tf.complex64, tf.complex128]:
      for t_dtype in [tf.float32, tf.float64]:
        y0 = tf.cast(1.0, y0_dtype)
        y_solved = tf.contrib.integrate.odeint(func, y0, tf.cast(t, t_dtype))
        with self.test_session() as sess:
          y_solved = sess.run(y_solved)
        expected = np.asarray(np.exp(t))
        self.assertAllClose(y_solved, expected, rtol=1e-5)
        self.assertEqual(tf.as_dtype(y_solved.dtype), y0_dtype)

  def test_odeint_required_dtypes(self):
    with self.assertRaisesRegexp(TypeError, '`y0` must have a floating point'):
      tf.contrib.integrate.odeint(self.func, tf.cast(self.y0, tf.int32), [0, 1])

    with self.assertRaisesRegexp(TypeError, '`t` must have a floating point'):
      tf.contrib.integrate.odeint(self.func, self.y0, tf.cast([0, 1], tf.int32))

  def test_odeint_runtime_errors(self):
    with self.assertRaisesRegexp(
        ValueError, 'cannot supply `options` without'):
      tf.contrib.integrate.odeint(self.func, self.y0, [0, 1],
                                  options={'first_step': 1.0})

    y = tf.contrib.integrate.odeint(self.func, self.y0, [0, 1], method='dopri5',
                                    options={'max_num_steps': 0})
    with self.test_session() as sess:
      with self.assertRaisesRegexp(
          tf.errors.InvalidArgumentError, 'max_num_steps'):
        sess.run(y)

    y = tf.contrib.integrate.odeint(self.func, self.y0, [1, 0])
    with self.test_session() as sess:
      with self.assertRaisesRegexp(
          tf.errors.InvalidArgumentError, 'monotonic increasing'):
        sess.run(y)

  def test_odeint_different_times(self):
    # integrate steps should be independent of interpolation times
    times0 = np.linspace(0, 10, num=11, dtype=float)
    times1 = np.linspace(0, 10, num=101, dtype=float)

    with self.test_session() as sess:
      y_solved_0, info_0 = sess.run(
          tf.contrib.integrate.odeint(
|
||||||
|
self.func, self.y0, times0, full_output=True))
|
||||||
|
y_solved_1, info_1 = sess.run(
|
||||||
|
tf.contrib.integrate.odeint(
|
||||||
|
self.func, self.y0, times1, full_output=True))
|
||||||
|
|
||||||
|
self.assertAllClose(y_solved_0, y_solved_1[::10])
|
||||||
|
self.assertEqual(info_0['num_func_evals'], info_1['num_func_evals'])
|
||||||
|
self.assertAllEqual(info_0['integrate_points'], info_1['integrate_points'])
|
||||||
|
self.assertAllEqual(info_0['error_ratio'], info_1['error_ratio'])
|
||||||
|
|
||||||
|
def test_odeint_5th_order_accuracy(self):
|
||||||
|
t = [0, 20]
|
||||||
|
kwargs = dict(full_output=True,
|
||||||
|
method='dopri5',
|
||||||
|
options=dict(max_num_steps=2000))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
_, info_0 = sess.run(tf.contrib.integrate.odeint(
|
||||||
|
self.func, self.y0, t, rtol=0, atol=1e-6, **kwargs))
|
||||||
|
_, info_1 = sess.run(tf.contrib.integrate.odeint(
|
||||||
|
self.func, self.y0, t, rtol=0, atol=1e-9, **kwargs))
|
||||||
|
self.assertAllClose(info_0['integrate_points'].size * 1000 ** 0.2,
|
||||||
|
float(info_1['integrate_points'].size),
|
||||||
|
rtol=0.01)
|
||||||
|
|
||||||
|
|
||||||
|
class StepSizeTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def test_error_ratio_one(self):
|
||||||
|
new_step = odes._optimal_step_size(last_step=tf.constant(1.0),
|
||||||
|
error_ratio=tf.constant(1.0))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
new_step = sess.run(new_step)
|
||||||
|
self.assertAllClose(new_step, 0.9)
|
||||||
|
|
||||||
|
def test_ifactor(self):
|
||||||
|
new_step = odes._optimal_step_size(last_step=tf.constant(1.0),
|
||||||
|
error_ratio=tf.constant(0.0))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
new_step = sess.run(new_step)
|
||||||
|
self.assertAllClose(new_step, 10.0)
|
||||||
|
|
||||||
|
def test_dfactor(self):
|
||||||
|
new_step = odes._optimal_step_size(last_step=tf.constant(1.0),
|
||||||
|
error_ratio=tf.constant(1e6))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
new_step = sess.run(new_step)
|
||||||
|
self.assertAllClose(new_step, 0.2)
|
||||||
|
|
||||||
|
|
||||||
|
class InterpolationTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def test_5th_order_polynomial(self):
|
||||||
|
# this should be an exact fit
|
||||||
|
f = lambda x: x ** 4 + x ** 3 - 2 * x ** 2 + 4 * x + 5
|
||||||
|
f_prime = lambda x: 4 * x ** 3 + 3 * x ** 2 - 4 * x + 4
|
||||||
|
coeffs = odes._interp_fit(
|
||||||
|
f(0.0), f(10.0), f(5.0), f_prime(0.0), f_prime(10.0), 10.0)
|
||||||
|
times = np.linspace(0, 10, dtype=np.float32)
|
||||||
|
y_fit = tf.pack([odes._interp_evaluate(coeffs, 0.0, 10.0, t)
|
||||||
|
for t in times])
|
||||||
|
y_expected = f(times)
|
||||||
|
with self.test_session() as sess:
|
||||||
|
y_actual = sess.run(y_fit)
|
||||||
|
self.assertAllClose(y_expected, y_actual)
|
||||||
|
|
||||||
|
# attempt interpolation outside bounds
|
||||||
|
y_invalid = odes._interp_evaluate(coeffs, 0.0, 10.0, 100.0)
|
||||||
|
with self.test_session() as sess:
|
||||||
|
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||||
|
sess.run(y_invalid)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
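For orientation, the API under test integrates an initial value `y0` through the times in `t`. A minimal usage sketch (not part of the diff), assuming the same 0.11-era `tf.contrib.integrate.odeint` exercised by the tests above:

```python
import numpy as np
import tensorflow as tf

# dy/dt = -y with y(0) = 1; the analytical solution is exp(-t).
func = lambda y, t: -y
t = np.linspace(0.0, 2.0, 21)
y0 = tf.constant(1.0, dtype=tf.float64)
y = tf.contrib.integrate.odeint(func, y0, t)  # shape: (len(t),)

with tf.Session() as sess:
  print(sess.run(y))  # decays from 1.0 toward exp(-2) ~ 0.135
```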
@@ -22,6 +22,7 @@ from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import embedding_ops
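The hunks that follow all make the same mechanical change: `SparseTensor` now lives in the dedicated `tensorflow.python.framework.sparse_tensor` module rather than in `framework.ops`, so call sites swap the import and the qualifier. A before/after sketch with illustrative values:

```python
# Before (old import path):
#   from tensorflow.python.framework import ops
#   st = ops.SparseTensor(indices, values, shape)

# After (new dedicated module):
from tensorflow.python.framework import sparse_tensor

# A 2x3 sparse matrix with two nonzero entries (illustrative values).
st = sparse_tensor.SparseTensor(indices=[[0, 0], [1, 2]],
                                values=[1.0, 2.0],
                                shape=[2, 3])
```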
@@ -114,7 +115,8 @@ def safe_embedding_lookup_sparse(embedding_weights,
             array_ops.slice(original_shape, [0], [original_rank - 1])),
         array_ops.gather(original_shape, original_rank - 1)])
     if sparse_weights is not None:
-      sparse_weights = ops.SparseTensor(sparse_ids.indices,
-                                        sparse_weights.values, sparse_ids.shape)
+      sparse_weights = sparse_tensor.SparseTensor(
+          sparse_ids.indices,
+          sparse_weights.values, sparse_ids.shape)

     # Prune invalid ids and weights.
@@ -302,7 +304,7 @@ def hashed_embedding_lookup_sparse(params,
     params = list(params)
   if not isinstance(params, list):
     params = [params]
-  if not isinstance(sparse_values, ops.SparseTensor):
+  if not isinstance(sparse_values, sparse_tensor.SparseTensor):
     raise TypeError("sparse_values must be SparseTensor")

   with ops.name_scope(name, "hashed_sparse_embedding_lookup",
@@ -21,7 +21,7 @@ from __future__ import print_function
 from tensorflow.contrib.framework.python.ops import variables
 from tensorflow.contrib.layers.python.layers import embedding_ops as contrib_embedding_ops
 from tensorflow.contrib.layers.python.ops import sparse_ops
-from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope
@@ -74,14 +74,14 @@ def bow_encoder(ids,
         initializer=initializer, regularizer=regularizer,
         trainable=trainable)
     if sparse_lookup:
-      if isinstance(ids, ops.SparseTensor):
+      if isinstance(ids, sparse_tensor.SparseTensor):
         sparse_ids = ids
       else:
         sparse_ids = sparse_ops.dense_to_sparse_tensor(ids)
       return contrib_embedding_ops.safe_embedding_lookup_sparse(
           [embeddings], sparse_ids, combiner='mean', default_id=0)
     else:
-      if isinstance(ids, ops.SparseTensor):
+      if isinstance(ids, sparse_tensor.SparseTensor):
         raise TypeError('ids are expected to be dense Tensor, got: %s', ids)
       return math_ops.reduce_mean(
           embedding_ops.embedding_lookup(embeddings, ids),
@@ -76,13 +76,12 @@ import collections
 import math
 import six

-from tensorflow.contrib.framework.python.framework import deprecation
 from tensorflow.contrib.layers.python.layers import layers
 from tensorflow.contrib.layers.python.ops import bucketization_op
 from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
 from tensorflow.contrib.lookup import lookup_ops as contrib_lookup_ops
 from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor as sparse_tensor_py
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
@@ -90,6 +89,7 @@ from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops import string_ops
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import deprecation


 class _LinearEmbeddingLookupArguments(
@@ -390,7 +390,7 @@ class _SparseColumnIntegerized(_SparseColumn):
     sparse_id_values = math_ops.mod(columns_to_tensors[self.name].values,
                                     self.bucket_size,
                                     name="mod")
-    columns_to_tensors[self] = ops.SparseTensor(
+    columns_to_tensors[self] = sparse_tensor_py.SparseTensor(
         columns_to_tensors[self.name].indices, sparse_id_values,
         columns_to_tensors[self.name].shape)

@@ -464,7 +464,7 @@ class _SparseColumnHashed(_SparseColumn):

     sparse_id_values = string_ops.string_to_hash_bucket_fast(
         sparse_values, self.bucket_size, name="lookup")
-    columns_to_tensors[self] = ops.SparseTensor(
+    columns_to_tensors[self] = sparse_tensor_py.SparseTensor(
         sparse_tensor.indices, sparse_id_values, sparse_tensor.shape)

@@ -1452,7 +1452,8 @@ class _BucketizedColumn(_FeatureColumn, collections.namedtuple(

     indices = math_ops.to_int64(array_ops.transpose(array_ops.pack((i1, i2))))
     shape = math_ops.to_int64(array_ops.pack([batch_size, dimension]))
-    sparse_id_values = ops.SparseTensor(indices, bucket_indices, shape)
+    sparse_id_values = sparse_tensor_py.SparseTensor(
+        indices, bucket_indices, shape)

     return sparse_id_values
@@ -26,6 +26,7 @@ from tensorflow.contrib.layers.python.layers import feature_column as fc
 from tensorflow.contrib.layers.python.layers import layers
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor as sparse_tensor_py
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
@@ -362,7 +363,7 @@ def _create_joint_embedding_lookup(columns_to_tensors,
       values = t.values + prev_size
       prev_size += a.vocab_size
       sparse_tensors.append(
-          ops.SparseTensor(t.indices,
-                           values,
-                           t.shape))
+          sparse_tensor_py.SparseTensor(t.indices,
+                                        values,
+                                        t.shape))
   sparse_tensor = sparse_ops.sparse_concat(1, sparse_tensors)
@@ -695,7 +696,7 @@ def _log_variable(variable):

 def _infer_real_valued_column_for_tensor(name, tensor):
   """Creates a real_valued_column for given tensor and name."""
-  if isinstance(tensor, ops.SparseTensor):
+  if isinstance(tensor, sparse_tensor_py.SparseTensor):
     raise ValueError(
         'SparseTensor is not supported for auto detection. Please define '
         'corresponding FeatureColumn for tensor {} {}.', name, tensor)
@@ -21,6 +21,7 @@ from __future__ import print_function

 import itertools
 import os
+import tempfile

 import numpy as np
 import tensorflow as tf
@@ -609,7 +610,10 @@ class FeatureColumnTest(tf.test.TestCase):
         {embedding_col: input_tensor}, [embedding_col])

     save = tf.train.Saver()
-    checkpoint_path = os.path.join(self.get_temp_dir(), "model.ckpt")
+    ckpt_dir_prefix = os.path.join(
+        self.get_temp_dir(), "init_embedding_col_w_from_ckpt")
+    ckpt_dir = tempfile.mkdtemp(prefix=ckpt_dir_prefix)
+    checkpoint_path = os.path.join(ckpt_dir, "model.ckpt")

     with self.test_session() as sess:
       sess.run(tf.initialize_all_variables())
@@ -670,7 +674,10 @@ class FeatureColumnTest(tf.test.TestCase):
     assign_op = tf.assign(weight[0], weight[0] + 0.5)

     save = tf.train.Saver()
-    checkpoint_path = os.path.join(self.get_temp_dir(), "model.ckpt")
+    ckpt_dir_prefix = os.path.join(
+        self.get_temp_dir(), "init_crossed_col_w_from_ckpt")
+    ckpt_dir = tempfile.mkdtemp(prefix=ckpt_dir_prefix)
+    checkpoint_path = os.path.join(ckpt_dir, "model.ckpt")

     with self.test_session() as sess:
       sess.run(tf.initialize_all_variables())
@@ -31,6 +31,7 @@ from tensorflow.contrib.layers.python.layers import utils

 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import init_ops
@@ -1217,7 +1218,7 @@ def _inner_flatten(inputs, new_rank, output_collections=None, scope=None):
     TypeError: `inputs` is not a `Tensor` or `SparseTensor`.
   """
   with ops.name_scope(scope, 'InnerFlatten', [inputs, new_rank]) as sc:
-    if isinstance(inputs, ops.SparseTensor):
+    if isinstance(inputs, sparse_tensor.SparseTensor):
       flattened = _sparse_inner_flatten(inputs, new_rank)
     else:
       inputs = ops.convert_to_tensor(inputs)
@@ -258,10 +258,11 @@ def optimize_loss(loss,
         grad_values = gradient

       if grad_values is not None:
+        var_name = variable.name.replace(":", "_")
         if "gradients" in summaries:
-          summary.histogram("gradients/" + variable.name, grad_values)
+          summary.histogram("gradients/%s" % var_name, grad_values)
         if "gradient_norm" in summaries:
-          summary.scalar("gradient_norm/" + variable.name,
+          summary.scalar("gradient_norm/%s" % var_name,
                          clip_ops.global_norm([grad_values]))

     if clip_gradients is not None and "gradient_norm" in summaries:
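The `optimize_loss` change above sanitizes variable names before using them as summary tags: TF variable names end in an output-index suffix such as `:0`, which is awkward inside a tag name. A small sketch of the effect, using an illustrative variable name:

```python
# Illustrative only: a TF variable name like "dnn/hiddenlayer_0/weights:0"
# carries a ":0" output suffix; the hunk above strips it for summary tags.
variable_name = "dnn/hiddenlayer_0/weights:0"
var_name = variable_name.replace(":", "_")
print("gradients/%s" % var_name)  # -> gradients/dnn/hiddenlayer_0/weights_0
```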
@@ -58,7 +58,7 @@ class MultiClassTargetColumnTest(tf.test.TestCase):
       labels = tf.constant([[1.], [0.]])
       # logloss: z:label, x:logit
      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
-      self.assertAlmostEqual(.81326163,
+      self.assertAlmostEqual(0.81326175,
                              sess.run(target_column.loss(logits, labels, {})))

   def testBinaryClassificationWithWeights(self):
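The corrected constant above can be checked by hand: with logits `[1.0, 1.0]` and labels `[1, 0]`, the mean logistic loss is `0.5 * (softplus(-1) + softplus(1)) ≈ 0.8132617`; the float32 op in the test rounds this to `0.81326175`. A quick numpy verification sketch:

```python
import numpy as np

# Numerically stable softplus: log(1 + exp(x)) = max(x, 0) + log1p(exp(-|x|)).
softplus = lambda x: max(x, 0.0) + np.log1p(np.exp(-abs(x)))

loss = 0.5 * (softplus(-1.0) + softplus(1.0))
print(loss)  # 0.8132616875...; the test's float32 op rounds to 0.81326175
```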
@@ -23,6 +23,7 @@ from tensorflow.contrib.util import loader
 from tensorflow.python.framework import common_shapes
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import resource_loader

@@ -69,12 +70,14 @@ def sparse_feature_cross(inputs, hashed_output=False, num_buckets=0,
   """
   if not isinstance(inputs, list):
     raise TypeError("Inputs must be a list")
-  if not all(isinstance(i, ops.SparseTensor) or
+  if not all(isinstance(i, sparse_tensor.SparseTensor) or
              isinstance(i, ops.Tensor) for i in inputs):
     raise TypeError("All inputs must be SparseTensors")

-  sparse_inputs = [i for i in inputs if isinstance(i, ops.SparseTensor)]
-  dense_inputs = [i for i in inputs if not isinstance(i, ops.SparseTensor)]
+  sparse_inputs = [i for i in inputs
+                   if isinstance(i, sparse_tensor.SparseTensor)]
+  dense_inputs = [i for i in inputs
+                  if not isinstance(i, sparse_tensor.SparseTensor)]

   indices = [sp_input.indices for sp_input in sparse_inputs]
   values = [sp_input.values for sp_input in sparse_inputs]
@@ -117,7 +120,7 @@ def sparse_feature_cross(inputs, hashed_output=False, num_buckets=0,
           internal_type=internal_type,
           name=name))

-  return ops.SparseTensor(indices_out, values_out, shape_out)
+  return sparse_tensor.SparseTensor(indices_out, values_out, shape_out)


 ops.RegisterShape("SparseFeatureCross")(common_shapes.call_cpp_shape_fn)
@@ -20,6 +20,7 @@ from __future__ import print_function

 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops

@@ -78,4 +79,4 @@ def dense_to_sparse_tensor(dense_tensor, ignore_value=None):
       math_ops.mul(higher_dims, shape_multipliers), reduction_indices=[1])
   flat_indices = math_ops.add(flat_indices, offsets)
   values = array_ops.gather(flat_tensor, flat_indices)
-  return ops.SparseTensor(indices, values, dense_shape)
+  return sparse_tensor.SparseTensor(indices, values, dense_shape)
@@ -291,7 +291,9 @@ py_test(
     deps = [
         ":learn",
         "//tensorflow:tensorflow_py",
+        "//tensorflow/python:extra_py_tests_deps",
         "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:test_ops",
     ],
 )

@@ -22,11 +22,12 @@ from __future__ import print_function
 from tensorflow.contrib.layers import feature_column
 from tensorflow.contrib.learn.python.learn.dataframe import series as ss
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import parsing_ops


 def _to_feature_spec(tensor, default_value=None):
-  if isinstance(tensor, ops.SparseTensor):
+  if isinstance(tensor, sparse_tensor.SparseTensor):
     return parsing_ops.VarLenFeature(dtype=tensor.dtype)
   else:
     return parsing_ops.FixedLenFeature(shape=tensor.get_shape(),
@@ -20,7 +20,7 @@ from __future__ import print_function

 from tensorflow.contrib.learn.python.learn.dataframe import series
 from tensorflow.contrib.learn.python.learn.dataframe import transform
-from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import math_ops

 # Each entry is a mapping from registered_name to operation. Each operation is
@@ -55,8 +55,8 @@ class SeriesBinaryTransform(transform.TensorFlowTransform):

   def _apply_transform(self, input_tensors, **kwargs):
     # TODO(jamieas): consider supporting sparse inputs.
-    if isinstance(input_tensors[0], ops.SparseTensor) or isinstance(
-        input_tensors[1], ops.SparseTensor):
+    if isinstance(input_tensors[0], sparse_tensor.SparseTensor) or isinstance(
+        input_tensors[1], sparse_tensor.SparseTensor):
       raise TypeError("{} does not support SparseTensors".format(
           type(self).__name__))

@@ -89,8 +89,8 @@ class ScalarBinaryTransform(transform.TensorFlowTransform):

   def _apply_transform(self, input_tensors, **kwargs):
     input_tensor = input_tensors[0]
-    if isinstance(input_tensor, ops.SparseTensor):
-      result = ops.SparseTensor(input_tensor.indices,
-                                self._apply_op(input_tensor.values),
-                                input_tensor.shape)
+    if isinstance(input_tensor, sparse_tensor.SparseTensor):
+      result = sparse_tensor.SparseTensor(input_tensor.indices,
+                                          self._apply_op(input_tensor.values),
+                                          input_tensor.shape)
     else:
@@ -23,6 +23,7 @@ from tensorflow.contrib.learn.python.learn.dataframe import series
 from tensorflow.contrib.learn.python.learn.dataframe import transform
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor as sparse_tensor_py
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import math_ops
@@ -93,7 +94,7 @@ class BooleanMask(transform.TensorFlowTransform):
     if mask.get_shape().ndims > 1:
       mask = array_ops.squeeze(mask)

-    if isinstance(input_tensor, ops.SparseTensor):
+    if isinstance(input_tensor, sparse_tensor_py.SparseTensor):
       mask_fn = sparse_boolean_mask
     else:
       mask_fn = array_ops.boolean_mask
@@ -21,14 +21,14 @@ from __future__ import print_function

 from tensorflow.contrib.learn.python.learn.dataframe import series
 from tensorflow.contrib.learn.python.learn.dataframe import transform
-from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import sparse_ops


-def _negate_sparse(sparse_tensor):
-  return ops.SparseTensor(indices=sparse_tensor.indices,
-                          values=-sparse_tensor.values,
-                          shape=sparse_tensor.shape)
+def _negate_sparse(st):
+  return sparse_tensor.SparseTensor(indices=st.indices,
+                                    values=-st.values,
+                                    shape=st.shape)


 @series.Series.register_binary_op("__sub__")
@@ -51,8 +51,8 @@ class Difference(transform.TensorFlowTransform):
     return "output",

   def _apply_transform(self, input_tensors, **kwargs):
-    pair_sparsity = (isinstance(input_tensors[0], ops.SparseTensor),
-                     isinstance(input_tensors[1], ops.SparseTensor))
+    pair_sparsity = (isinstance(input_tensors[0], sparse_tensor.SparseTensor),
+                     isinstance(input_tensors[1], sparse_tensor.SparseTensor))

     if pair_sparsity == (False, False):
       result = input_tensors[0] - input_tensors[1]
@@ -24,7 +24,7 @@ import numpy as np
 from tensorflow.contrib.learn.python.learn.dataframe import transform

 from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops

@@ -82,4 +82,5 @@ class Sparsify(transform.TensorFlowTransform):
     shape = math_ops.cast(array_ops.shape(d), dtypes.int64)

     # pylint: disable=not-callable
-    return self.return_type(ops.SparseTensor(sparse_indices, values, shape))
+    return self.return_type(
+        sparse_tensor.SparseTensor(sparse_indices, values, shape))
@@ -21,7 +21,7 @@ from __future__ import print_function

 from tensorflow.contrib.learn.python.learn.dataframe import series
 from tensorflow.contrib.learn.python.learn.dataframe import transform
-from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import sparse_ops


@@ -45,8 +45,8 @@ class Sum(transform.TensorFlowTransform):
     return "output",

   def _apply_transform(self, input_tensors, **kwargs):
-    pair_sparsity = (isinstance(input_tensors[0], ops.SparseTensor),
-                     isinstance(input_tensors[1], ops.SparseTensor))
+    pair_sparsity = (isinstance(input_tensors[0], sparse_tensor.SparseTensor),
+                     isinstance(input_tensors[1], sparse_tensor.SparseTensor))

     if pair_sparsity == (False, False):
       result = input_tensors[0] + input_tensors[1]
@@ -57,6 +57,3 @@ class Sum(transform.TensorFlowTransform):

     # pylint: disable=not-callable
     return self.return_type(result)
-
-
-
@@ -21,7 +21,7 @@ from __future__ import print_function

 from tensorflow.contrib.learn.python.learn.dataframe import series
 from tensorflow.contrib.learn.python.learn.dataframe import transform
-from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import math_ops

 # Each entry is a mapping from registered_name to operation. Each operation is
@@ -83,8 +83,8 @@ def register_unary_op(registered_name, operation, ignore_dtype=None):

   def _apply_transform(self, input_tensors, **kwargs):
     input_tensor = input_tensors[0]
-    if isinstance(input_tensor, ops.SparseTensor):
-      result = ops.SparseTensor(input_tensor.indices,
-                                operation(input_tensor.values),
-                                input_tensor.shape)
+    if isinstance(input_tensor, sparse_tensor.SparseTensor):
+      result = sparse_tensor.SparseTensor(input_tensor.indices,
+                                          operation(input_tensor.values),
+                                          input_tensor.shape)
     else:
@@ -29,11 +29,11 @@ from tensorflow.contrib.learn.python.learn.estimators.estimator import Estimator
 from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input
 from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input_fn
 from tensorflow.contrib.learn.python.learn.estimators.estimator import ModeKeys
-from tensorflow.contrib.learn.python.learn.estimators.head import MetricKey
-from tensorflow.contrib.learn.python.learn.estimators.head import PredictionKey
 from tensorflow.contrib.learn.python.learn.estimators.linear import LinearClassifier
 from tensorflow.contrib.learn.python.learn.estimators.linear import LinearRegressor
 from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import LogisticRegressor
+from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey
+from tensorflow.contrib.learn.python.learn.estimators.prediction_key import PredictionKey
 from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestEstimator
 from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestLossHook
 from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig
@@ -19,7 +19,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import numpy as np
+import math
+import re
 import six

 from tensorflow.contrib import layers
@@ -27,13 +28,23 @@ from tensorflow.contrib.framework import deprecated
 from tensorflow.contrib.framework import deprecated_arg_values
 from tensorflow.contrib.framework.python.ops import variables as contrib_variables
 from tensorflow.contrib.layers.python.layers import feature_column_ops
+from tensorflow.contrib.layers.python.layers import optimizers
+from tensorflow.contrib.learn.python.learn import evaluable
+from tensorflow.contrib.learn.python.learn import session_run_hook
+from tensorflow.contrib.learn.python.learn import trainable
 from tensorflow.contrib.learn.python.learn.estimators import composable_model
 from tensorflow.contrib.learn.python.learn.estimators import estimator
 from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
+from tensorflow.contrib.learn.python.learn.estimators import prediction_key
+from tensorflow.contrib.learn.python.learn.utils import export
 from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import partitioned_variables
 from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope


 class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
@@ -307,7 +318,236 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
     return logits


-class DNNLinearCombinedClassifier(_DNNLinearCombinedBaseEstimator):
+_CENTERED_BIAS_WEIGHT = "centered_bias_weight"
+
+# The default learning rates are a historical artifact of the initial
+# implementation, but seem a reasonable choice.
+_DNN_LEARNING_RATE = 0.05
+_LINEAR_LEARNING_RATE = 0.2
+
+
+def _as_iterable(preds, output):
+  for pred in preds:
+    yield pred[output]
+
+
+def _get_feature_dict(features):
+  if isinstance(features, dict):
+    return features
+  return {"": features}
+
+
+def _get_optimizer(optimizer):
+  if callable(optimizer):
+    return optimizer()
+  else:
+    return optimizer
+
+
+def _linear_learning_rate(num_linear_feature_columns):
+  """Returns the default learning rate of the linear model.
+
+  The calculation is a historical artifact of this initial implementation, but
+  has proven a reasonable choice.
+
+  Args:
+    num_linear_feature_columns: The number of feature columns of the linear
+      model.
+
+  Returns:
+    A float.
+  """
+  default_learning_rate = 1. / math.sqrt(num_linear_feature_columns)
+  return min(_LINEAR_LEARNING_RATE, default_learning_rate)
+
+
+def _add_hidden_layer_summary(value, tag):
+  logging_ops.scalar_summary("%s:fraction_of_zero_values" % tag,
+                             nn.zero_fraction(value))
+  logging_ops.histogram_summary("%s:activation" % tag, value)
+
+
+def _dnn_linear_combined_model_fn(features, labels, mode, params):
+  """Deep Neural Net and Linear combined model_fn.
+
+  Args:
+    features: `Tensor` or dict of `Tensor` (depends on data passed to `fit`).
+    labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of dtype
+      `int32` or `int64` in the range `[0, n_classes)`.
+    mode: Defines whether this is training, evaluation or prediction.
+      See `ModeKeys`.
+    params: A dict of hyperparameters.
+      The following hyperparameters are expected:
+      * head: A `Head` instance.
+      * linear_feature_columns: An iterable containing all the feature columns
+          used by the Linear model.
+      * linear_optimizer: string, `Optimizer` object, or callable that defines
+          the optimizer to use for training the Linear model.
+      * joint_linear_weights: If True a single (possibly partitioned) variable
+          will be used to store the linear model weights. It's faster, but
+          requires all columns are sparse and have the 'sum' combiner.
+      * dnn_feature_columns: An iterable containing all the feature columns used
+          by the DNN model.
+      * dnn_optimizer: string, `Optimizer` object, or callable that defines the
+          optimizer to use for training the DNN model.
+      * dnn_hidden_units: List of hidden units per DNN layer.
+      * dnn_activation_fn: Activation function applied to each DNN layer. If
+          `None`, will use `tf.nn.relu`.
+      * dnn_dropout: When not `None`, the probability we will drop out a given
+          DNN coordinate.
+      * gradient_clip_norm: A float > 0. If provided, gradients are
+          clipped to their global norm with this clipping ratio.
+      * num_ps_replicas: The number of parameter server replicas.
+
+  Returns:
+    `estimator.ModelFnOps`
+
+  Raises:
+    ValueError: If both `linear_feature_columns` and `dnn_features_columns`
+      are empty at the same time.
+  """
+  head = params["head"]
+  linear_feature_columns = params.get("linear_feature_columns")
+  linear_optimizer = params.get("linear_optimizer")
+  joint_linear_weights = params.get("joint_linear_weights")
+  dnn_feature_columns = params.get("dnn_feature_columns")
+  dnn_optimizer = params.get("dnn_optimizer")
+  dnn_hidden_units = params.get("dnn_hidden_units")
+  dnn_activation_fn = params.get("dnn_activation_fn")
+  dnn_dropout = params.get("dnn_dropout")
+  gradient_clip_norm = params.get("gradient_clip_norm")
+  num_ps_replicas = params["num_ps_replicas"]
+
+  if not linear_feature_columns and not dnn_feature_columns:
+    raise ValueError(
+        "Either linear_feature_columns or dnn_feature_columns must be defined.")
+
+  features = _get_feature_dict(features)
+
+  # Build DNN Logits.
+  dnn_parent_scope = "dnn"
+
+  if not dnn_feature_columns:
+    dnn_logits = None
+  else:
+    input_layer_partitioner = (
+        partitioned_variables.min_max_variable_partitioner(
+            max_partitions=num_ps_replicas,
+            min_slice_size=64 << 20))
+    with variable_scope.variable_scope(
+        dnn_parent_scope + "/input_from_feature_columns",
+        values=features.values(),
+        partitioner=input_layer_partitioner) as scope:
+      net = layers.input_from_feature_columns(
+          columns_to_tensors=features,
+          feature_columns=dnn_feature_columns,
+          weight_collections=[dnn_parent_scope],
+          scope=scope)
+
+    hidden_layer_partitioner = (
+        partitioned_variables.min_max_variable_partitioner(
+            max_partitions=num_ps_replicas))
+    for layer_id, num_hidden_units in enumerate(dnn_hidden_units):
+      with variable_scope.variable_scope(
+          dnn_parent_scope + "/hiddenlayer_%d" % layer_id,
+          values=[net],
+          partitioner=hidden_layer_partitioner) as scope:
+        net = layers.fully_connected(
+            net,
+            num_hidden_units,
+            activation_fn=dnn_activation_fn,
+            variables_collections=[dnn_parent_scope],
+            scope=scope)
+        if dnn_dropout is not None and mode == estimator.ModeKeys.TRAIN:
+          net = layers.dropout(
+              net,
+              keep_prob=(1.0 - dnn_dropout))
+      # TODO(b/31209633): Consider adding summary before dropout.
+      _add_hidden_layer_summary(net, scope.name)
+
+    with variable_scope.variable_scope(
+        dnn_parent_scope + "/logits",
+        values=[net],
+        partitioner=hidden_layer_partitioner) as scope:
+      dnn_logits = layers.fully_connected(
+          net,
+          head.logits_dimension,
+          activation_fn=None,
+          variables_collections=[dnn_parent_scope],
+          scope=scope)
+    _add_hidden_layer_summary(dnn_logits, scope.name)
+
+  # Build Linear logits.
+  linear_parent_scope = "linear"
+
+  if not linear_feature_columns:
+    linear_logits = None
+  else:
+    linear_partitioner = partitioned_variables.min_max_variable_partitioner(
+        max_partitions=num_ps_replicas,
+        min_slice_size=64 << 20)
+    with variable_scope.variable_scope(
+        linear_parent_scope,
+        values=features.values(),
+        partitioner=linear_partitioner) as scope:
+      if joint_linear_weights:
+        linear_logits, _, _ = layers.joint_weighted_sum_from_feature_columns(
+            columns_to_tensors=features,
+            feature_columns=linear_feature_columns,
+            num_outputs=head.logits_dimension,
+            weight_collections=[linear_parent_scope],
+            scope=scope)
+      else:
+        linear_logits, _, _ = layers.weighted_sum_from_feature_columns(
+            columns_to_tensors=features,
+            feature_columns=linear_feature_columns,
+            num_outputs=head.logits_dimension,
+            weight_collections=[linear_parent_scope],
+            scope=scope)
+
+  # Combine logits and build full model.
+  if dnn_logits is not None and linear_logits is not None:
+    logits = dnn_logits + linear_logits
+  elif dnn_logits is not None:
+    logits = dnn_logits
+  else:
+    logits = linear_logits
+
+  def _make_training_op(training_loss):
+    """Training op for the DNN linear combined model."""
+    train_ops = []
+    if dnn_logits is not None:
+      train_ops.append(
+          optimizers.optimize_loss(
+              loss=training_loss,
+              global_step=contrib_variables.get_global_step(),
+              learning_rate=_DNN_LEARNING_RATE,
+              optimizer=_get_optimizer(dnn_optimizer),
+              clip_gradients=gradient_clip_norm,
+              variables=ops.get_collection(dnn_parent_scope),
+              name=dnn_parent_scope,
+              # Empty summaries, because head already logs "loss" summary.
+              summaries=[]))
+    if linear_logits is not None:
+      train_ops.append(
+          optimizers.optimize_loss(
+              loss=training_loss,
+              global_step=contrib_variables.get_global_step(),
+              learning_rate=_linear_learning_rate(len(linear_feature_columns)),
+              optimizer=_get_optimizer(linear_optimizer),
+              clip_gradients=gradient_clip_norm,
+              variables=ops.get_collection(linear_parent_scope),
+              name=linear_parent_scope,
+              # Empty summaries, because head already logs "loss" summary.
+              summaries=[]))
+
+    return control_flow_ops.group(*train_ops)
+
+  return head.head_ops(
+      features, labels, mode, _make_training_op, logits=logits)
+
+
+class DNNLinearCombinedClassifier(evaluable.Evaluable, trainable.Trainable):
   """A classifier for TensorFlow Linear and DNN joined training models.

   Example:
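Worked values for the `_linear_learning_rate` helper defined in the hunk above: the default rate `1/sqrt(n)` is capped at `_LINEAR_LEARNING_RATE = 0.2`, so small numbers of linear columns hit the cap while large numbers fall below it. A self-contained restatement for checking by hand:

```python
import math

_LINEAR_LEARNING_RATE = 0.2  # cap, as defined in the hunk above

def _linear_learning_rate(num_linear_feature_columns):
  return min(_LINEAR_LEARNING_RATE,
             1. / math.sqrt(num_linear_feature_columns))

print(_linear_learning_rate(16))   # 0.2  (1/sqrt(16) = 0.25, clipped)
print(_linear_learning_rate(25))   # 0.2  (1/sqrt(25) = 0.20, exactly at cap)
print(_linear_learning_rate(100))  # 0.1  (1/sqrt(100) = 0.10, below cap)
```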
@@ -423,30 +663,71 @@ class DNNLinearCombinedClassifier(_DNNLinearCombinedBaseEstimator):
       ValueError: If both `linear_feature_columns` and `dnn_features_columns`
         are empty at the same time.
     """

     if n_classes < 2:
       raise ValueError("n_classes should be greater than 1. Given: {}".format(
           n_classes))
+    self._linear_optimizer = linear_optimizer or "Ftrl"
+    linear_feature_columns = linear_feature_columns or []
+    dnn_feature_columns = dnn_feature_columns or []
+    self._feature_columns = linear_feature_columns + dnn_feature_columns
+    if not self._feature_columns:
+      raise ValueError("Either linear_feature_columns or dnn_feature_columns "
+                       "must be defined.")
+    self._dnn_hidden_units = dnn_hidden_units
+    self._enable_centered_bias = enable_centered_bias
+
     head = head_lib._multi_class_head(  # pylint: disable=protected-access
         n_classes=n_classes,
         weight_column_name=weight_column_name,
         enable_centered_bias=enable_centered_bias)
-    super(DNNLinearCombinedClassifier, self).__init__(
+    self._estimator = estimator.Estimator(
+        model_fn=_dnn_linear_combined_model_fn,
         model_dir=model_dir,
-        linear_feature_columns=linear_feature_columns,
-        linear_optimizer=linear_optimizer,
-        _joint_linear_weights=_joint_linear_weights,
-        dnn_feature_columns=dnn_feature_columns,
-        dnn_optimizer=dnn_optimizer,
-        dnn_hidden_units=dnn_hidden_units,
-        dnn_activation_fn=dnn_activation_fn,
-        dnn_dropout=dnn_dropout,
-        gradient_clip_norm=gradient_clip_norm,
-        head=head,
         config=config,
-        feature_engineering_fn=feature_engineering_fn,
-        default_prediction_key=head_lib.PredictionKey.CLASSES,
-        enable_centered_bias=enable_centered_bias)
+        params={
+            "head": head,
+            "linear_feature_columns": linear_feature_columns,
+            "linear_optimizer": self._linear_optimizer,
+            "joint_linear_weights": _joint_linear_weights,
+            "dnn_feature_columns": dnn_feature_columns,
+            "dnn_optimizer": dnn_optimizer or "Adagrad",
+            "dnn_hidden_units": dnn_hidden_units,
+            "dnn_activation_fn": dnn_activation_fn,
+            "dnn_dropout": dnn_dropout,
+            "gradient_clip_norm": gradient_clip_norm,
+            "num_ps_replicas": config.num_ps_replicas if config else 0,
+        },
+        feature_engineering_fn=feature_engineering_fn)
+
+  def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
+          monitors=None, max_steps=None):
+    """See trainable.Trainable."""
+    # TODO(roumposg): Remove when deprecated monitors are removed.
+    if monitors is not None:
+      deprecated_monitors = [
+          m for m in monitors
+          if not isinstance(m, session_run_hook.SessionRunHook)
+      ]
+      for monitor in deprecated_monitors:
+        monitor.set_estimator(self)
+        monitor._lock_estimator()  # pylint: disable=protected-access
+
+    result = self._estimator.fit(x=x, y=y, input_fn=input_fn, steps=steps,
+                                 batch_size=batch_size, monitors=monitors,
+                                 max_steps=max_steps)
+
+    if monitors is not None:
+      for monitor in deprecated_monitors:
+        monitor._unlock_estimator()  # pylint: disable=protected-access
+
+    return result
+
+  def evaluate(self, x=None, y=None, input_fn=None, feed_fn=None,
+               batch_size=None, steps=None, metrics=None, name=None):
+    """See evaluable.Evaluable."""
+    return self._estimator.evaluate(
+        x=x, y=y, input_fn=input_fn, feed_fn=feed_fn, batch_size=batch_size,
+        steps=steps, metrics=metrics, name=name)

   @deprecated_arg_values(
       estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
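After this change the classifier is a thin wrapper that delegates `fit`, `evaluate`, and `predict` to the internal `Estimator` built above. A hedged construction sketch, with hypothetical feature columns in the contrib.learn style of the time:

```python
from tensorflow.contrib import layers, learn

# Hypothetical columns for illustration only.
language = layers.sparse_column_with_hash_bucket("language", 100)
age = layers.real_valued_column("age")

clf = learn.DNNLinearCombinedClassifier(
    linear_feature_columns=[language],
    dnn_feature_columns=[layers.embedding_column(language, dimension=8), age],
    dnn_hidden_units=[64, 32])
# clf.fit(input_fn=train_input_fn, steps=100)    # train_input_fn: user-supplied
# clf.evaluate(input_fn=eval_input_fn, steps=1)  # eval_input_fn: user-supplied
```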
@@ -467,12 +748,16 @@ class DNNLinearCombinedClassifier(_DNNLinearCombinedBaseEstimator):
       Numpy array of predicted classes (or an iterable of predicted classes if
       as_iterable is True).
     """
-    predictions = self.predict_proba(
-        x=x, input_fn=input_fn, batch_size=batch_size, as_iterable=as_iterable)
+    key = prediction_key.PredictionKey.CLASSES
+    preds = self._estimator.predict(
+        x=x,
+        input_fn=input_fn,
+        batch_size=batch_size,
+        outputs=[key],
+        as_iterable=as_iterable)
     if as_iterable:
-      return (np.argmax(p, axis=0) for p in predictions)
-    else:
-      return np.argmax(predictions, axis=1)
+      return _as_iterable(preds, output=key)
+    return preds[key].reshape(-1)
 
   @deprecated_arg_values(
       estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
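Caller-facing behavior is meant to be unchanged: `predict` still yields class ids, now fetched directly as the single named output instead of being recomputed from probabilities with `argmax`. A hedged usage sketch; `classifier` and `predict_input_fn` are placeholders standing in for a fitted estimator and a standard input_fn like the ones in the tests later in this diff:

```python
# As an iterable (the new default): one class id per example.
for class_id in classifier.predict(input_fn=predict_input_fn,
                                   as_iterable=True):
  print(class_id)

# Or materialized at once as a numpy array of shape [batch_size].
classes = classifier.predict(input_fn=predict_input_fn, as_iterable=False)
```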
@@ -494,13 +779,133 @@ class DNNLinearCombinedClassifier(_DNNLinearCombinedBaseEstimator):
       Numpy array of predicted probabilities (or an iterable of predicted
       probabilities if as_iterable is True).
     """
-    return super(DNNLinearCombinedClassifier, self).predict(
-        x=x, input_fn=input_fn, batch_size=batch_size, as_iterable=as_iterable)
+    key = prediction_key.PredictionKey.PROBABILITIES
+    preds = self._estimator.predict(
+        x=x,
+        input_fn=input_fn,
+        batch_size=batch_size,
+        outputs=[key],
+        as_iterable=as_iterable)
+    if as_iterable:
+      return _as_iterable(preds, output=key)
+    return preds[key]
 
   def _get_predict_ops(self, features):
-    """See base class."""
-    return super(DNNLinearCombinedClassifier, self)._get_predict_ops(features)[
-        head_lib.PredictionKey.PROBABILITIES]
+    """See `Estimator` class."""
+    # pylint: disable=protected-access
+    return self._estimator._get_predict_ops(features)[
+        prediction_key.PredictionKey.PROBABILITIES]
+
+  def get_variable_names(self):
+    """Returns list of all variable names in this model.
+
+    Returns:
+      List of names.
+    """
+    return self._estimator.get_variable_names()
+
+  def get_variable_value(self, name):
+    """Returns value of the variable given by name.
+
+    Args:
+      name: string, name of the tensor.
+
+    Returns:
+      `Tensor` object.
+    """
+    return self._estimator.get_variable_value(name)
+
+  def export(self,
+             export_dir,
+             input_fn=None,
+             input_feature_key=None,
+             use_deprecated_input_fn=True,
+             signature_fn=None,
+             default_batch_size=1,
+             exports_to_keep=None):
+    """See BaseEstimator.export."""
+    def default_input_fn(unused_estimator, examples):
+      return layers.parse_feature_columns_from_examples(
+          examples, self._feature_columns)
+    self._estimator.export(
+        export_dir=export_dir,
+        input_fn=input_fn or default_input_fn,
+        input_feature_key=input_feature_key,
+        use_deprecated_input_fn=use_deprecated_input_fn,
+        signature_fn=(signature_fn or
+                      export.classification_signature_fn_with_prob),
+        prediction_key=prediction_key.PredictionKey.PROBABILITIES,
+        default_batch_size=default_batch_size,
+        exports_to_keep=exports_to_keep)
+
+  @property
+  def model_dir(self):
+    return self._estimator.model_dir
+
+  @property
+  @deprecated("2016-10-30",
+              "This method will be removed after the deprecation date. "
+              "To inspect variables, use get_variable_names() and "
+              "get_variable_value().")
+  def dnn_weights_(self):
+    hiddenlayer_weights = [
+        self.get_variable_value("dnn/hiddenlayer_%d/weights" % i)
+        for i, _ in enumerate(self._dnn_hidden_units)
+    ]
+    logits_weights = [self.get_variable_value("dnn/logits/weights")]
+    return hiddenlayer_weights + logits_weights
+
+  @property
+  @deprecated("2016-10-30",
+              "This method will be removed after the deprecation date. "
+              "To inspect variables, use get_variable_names() and "
+              "get_variable_value().")
+  def linear_weights_(self):
+    values = {}
+    if isinstance(self._linear_optimizer, str):
+      optimizer_name = self._linear_optimizer
+    else:
+      optimizer_name = self._linear_optimizer.get_name()
+    optimizer_regex = r".*/" + optimizer_name + r"(_\d)?$"
+    for name in self.get_variable_names():
+      if (name.startswith("linear/") and
+          name != "linear/bias_weight" and
+          name != "linear/learning_rate" and
+          not re.match(optimizer_regex, name)):
+        values[name] = self.get_variable_value(name)
+    if len(values) == 1:
+      return values[list(values.keys())[0]]
+    return values
+
+  @property
+  @deprecated("2016-10-30",
+              "This method will be removed after the deprecation date. "
+              "To inspect variables, use get_variable_names() and "
+              "get_variable_value().")
+  def dnn_bias_(self):
+    hiddenlayer_bias = [self.get_variable_value("dnn/hiddenlayer_%d/biases" % i)
+                        for i, _ in enumerate(self._dnn_hidden_units)]
+    logits_bias = [self.get_variable_value("dnn/logits/biases")]
+    if not self._enable_centered_bias:
+      return hiddenlayer_bias + logits_bias
+    centered_bias = [self.get_variable_value(_CENTERED_BIAS_WEIGHT)]
+    return hiddenlayer_bias + logits_bias + centered_bias
+
+  @property
+  @deprecated("2016-10-30",
+              "This method will be removed after the deprecation date. "
+              "To inspect variables, use get_variable_names() and "
+              "get_variable_value().")
+  def linear_bias_(self):
+    linear_bias = self.get_variable_value("linear/bias_weight")
+    if not self._enable_centered_bias:
+      return linear_bias
+    centered_bias = [self.get_variable_value(_CENTERED_BIAS_WEIGHT)]
+    return linear_bias + centered_bias
+
+  @property
+  def config(self):
+    return self._estimator.config
+
 
 class DNNLinearCombinedRegressor(_DNNLinearCombinedBaseEstimator):
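The `dnn_weights_`, `linear_weights_`, `dnn_bias_`, and `linear_bias_` properties added above arrive already deprecated; the sanctioned replacement is the generic variable API they delegate to. A sketch (the `dnn/...` and `linear/...` prefixes follow the scoping used in this file, but treat exact names as illustrative):

```python
# Illustrative only: enumerate model variables instead of using the
# deprecated weight/bias properties.
for name in classifier.get_variable_names():
  if name.startswith("dnn/") or name.startswith("linear/"):
    print(name, classifier.get_variable_value(name))
```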
@@ -642,12 +1047,11 @@ class DNNLinearCombinedRegressor(_DNNLinearCombinedBaseEstimator):
         head=head,
         config=config,
         feature_engineering_fn=feature_engineering_fn,
-        default_prediction_key=head_lib.PredictionKey.SCORES,
+        default_prediction_key=prediction_key.PredictionKey.SCORES,
         enable_centered_bias=enable_centered_bias)
 
   def _get_predict_ops(self, features):
     """See base class."""
-    return super(DNNLinearCombinedRegressor, self)._get_predict_ops(features)[
-        head_lib.PredictionKey.SCORES]
+    return super(
+        DNNLinearCombinedRegressor,
+        self)._get_predict_ops(features)[prediction_key.PredictionKey.SCORES]

@@ -27,6 +27,7 @@ import tensorflow as tf
 
 from tensorflow.contrib.learn.python.learn.estimators import _sklearn
 from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils
+from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec
 
 
 def _get_quantile_based_buckets(feature_values, num_buckets):
@@ -65,6 +66,15 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
     estimator_test_utils.assert_estimator_contract(
         self, tf.contrib.learn.DNNLinearCombinedClassifier)
 
+  def testNoFeatureColumns(self):
+    with self.assertRaisesRegexp(
+        ValueError,
+        'Either linear_feature_columns or dnn_feature_columns must be defined'):
+      tf.contrib.learn.DNNLinearCombinedClassifier(
+          linear_feature_columns=None,
+          dnn_feature_columns=None,
+          dnn_hidden_units=[3, 3])
+
   def testLogisticRegression_MatrixData(self):
     """Tests binary classification using matrix data as input."""
     iris = _prepare_iris_data_for_logistic_regression()
@@ -80,6 +90,7 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
 
     classifier.fit(input_fn=_iris_input_logistic_fn, steps=100)
     scores = classifier.evaluate(input_fn=_iris_input_logistic_fn, steps=100)
+    self.assertIn('auc', scores.keys())
     self.assertGreater(scores['accuracy'], 0.9)
 
   def testLogisticRegression_TensorData(self):
@@ -120,6 +131,7 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
 
     classifier.fit(input_fn=_input_fn, steps=100)
     scores = classifier.evaluate(input_fn=_input_fn, steps=100)
+    self.assertIn('auc', scores.keys())
     self.assertGreater(scores['accuracy'], 0.9)
 
   def testTrainWithPartitionedVariables(self):
@@ -397,9 +409,15 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
         input_fn=_input_fn,
         steps=100,
         metrics={
-            'my_accuracy': tf.contrib.metrics.streaming_accuracy,
-            ('my_precision', 'classes'): tf.contrib.metrics.streaming_precision,
-            ('my_metric', 'probabilities'): _my_metric_op
+            'my_accuracy': MetricSpec(
+                metric_fn=tf.contrib.metrics.streaming_accuracy,
+                prediction_key='classes'),
+            'my_precision': MetricSpec(
+                metric_fn=tf.contrib.metrics.streaming_precision,
+                prediction_key='classes'),
+            'my_metric': MetricSpec(
+                metric_fn=_my_metric_op,
+                prediction_key='probabilities')
         })
     self.assertTrue(
         set(['loss', 'my_accuracy', 'my_precision', 'my_metric'
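The test now builds custom metrics with `MetricSpec` objects instead of `(name, prediction_key)` tuple keys, which makes the prediction tensor each metric consumes explicit. The shape of the change, distilled:

```python
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec

# Before: tuple keys carried the prediction key implicitly.
old_metrics = {
    ('my_precision', 'classes'): tf.contrib.metrics.streaming_precision,
}

# After: MetricSpec names the metric_fn and its prediction_key explicitly.
new_metrics = {
    'my_precision': MetricSpec(
        metric_fn=tf.contrib.metrics.streaming_precision,
        prediction_key='classes'),
}
```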
@@ -412,7 +430,7 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
 
     # Test the case where the 2nd element of the key is neither "classes" nor
     # "probabilities".
-    with self.assertRaises(KeyError):
+    with self.assertRaisesRegexp(KeyError, 'bad_type'):
       classifier.evaluate(
           input_fn=_input_fn,
           steps=100,
@@ -428,6 +446,17 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
               tf.contrib.metrics.streaming_accuracy
           })
 
+    # Test the case where the prediction_key is neither "classes" nor
+    # "probabilities".
+    with self.assertRaisesRegexp(KeyError, 'bad_type'):
+      classifier.evaluate(
+          input_fn=_input_fn,
+          steps=100,
+          metrics={
+              'bad_name': MetricSpec(
+                  metric_fn=tf.contrib.metrics.streaming_auc,
+                  prediction_key='bad_type')})
+
   def testVariableQuery(self):
     """Tests bias is centered or not."""
     def _input_fn_train():
@@ -447,6 +476,39 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
     for name in var_names:
       classifier.get_variable_value(name)
 
+  def testExport(self):
+    """Tests export model for servo."""
+
+    def input_fn():
+      return {
+          'age': tf.constant([1]),
+          'language': tf.SparseTensor(values=['english'],
+                                      indices=[[0, 0]],
+                                      shape=[1, 1])
+      }, tf.constant([[1]])
+
+    language = tf.contrib.layers.sparse_column_with_hash_bucket('language', 100)
+
+    classifier = tf.contrib.learn.DNNLinearCombinedClassifier(
+        linear_feature_columns=[
+            tf.contrib.layers.real_valued_column('age'),
+            language,
+        ],
+        dnn_feature_columns=[
+            tf.contrib.layers.embedding_column(language, dimension=1),
+        ],
+        dnn_hidden_units=[3, 3])
+    classifier.fit(input_fn=input_fn, steps=100)
+
+    export_dir = tempfile.mkdtemp()
+    input_feature_key = 'examples'
+    def serving_input_fn():
+      features, targets = input_fn()
+      features[input_feature_key] = tf.placeholder(tf.string)
+      return features, targets
+    classifier.export(export_dir, serving_input_fn, input_feature_key,
+                      use_deprecated_input_fn=False)
+
   def testCenteredBias(self):
     """Tests bias is centered or not."""
     def _input_fn_train():
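Distilled from the new `testExport` above: with `use_deprecated_input_fn=False`, the serving input_fn must expose a string placeholder under the chosen `input_feature_key` alongside the regular feature tensors. A sketch, where `input_fn` stands for the same feature-building function used in training:

```python
import tempfile
import tensorflow as tf

export_dir = tempfile.mkdtemp()
input_feature_key = 'examples'

def serving_input_fn():
  # Reuse the training features, then add the serialized-example input.
  features, targets = input_fn()
  features[input_feature_key] = tf.placeholder(tf.string)
  return features, targets

classifier.export(export_dir, serving_input_fn, input_feature_key,
                  use_deprecated_input_fn=False)
```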
@@ -461,7 +523,7 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
         dnn_hidden_units=[3, 3],
         enable_centered_bias=True)
 
-    classifier.fit(input_fn=_input_fn_train, steps=500)
+    classifier.fit(input_fn=_input_fn_train, steps=1000)
     # logodds(0.75) = 1.09861228867
     self.assertAlmostEqual(
         1.0986,
@@ -483,7 +545,7 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
         enable_centered_bias=False)
 
     classifier.fit(input_fn=_input_fn_train, steps=500)
-    self.assertFalse('centered_bias_weight' in classifier.get_variable_names())
+    self.assertNotIn('centered_bias_weight', classifier.get_variable_names())
 
   def testLinearOnly(self):
     """Tests that linear-only instantiation works."""
@@ -822,6 +884,44 @@ class DNNLinearCombinedRegressorTest(tf.test.TestCase):
         metrics={('my_error', 'predictions'
                  ): tf.contrib.metrics.streaming_mean_squared_error})
 
+  def testExport(self):
+    """Tests export model for servo."""
+    labels = [1., 0., 0.2]
+    def _input_fn(num_epochs=None):
+      features = {
+          'age': tf.train.limit_epochs(tf.constant([[0.8], [0.15], [0.]]),
+                                       num_epochs=num_epochs),
+          'language': tf.SparseTensor(values=['en', 'fr', 'zh'],
+                                      indices=[[0, 0], [0, 1], [2, 0]],
+                                      shape=[3, 2])
+      }
+      return features, tf.constant(labels, dtype=tf.float32)
+
+    language_column = tf.contrib.layers.sparse_column_with_hash_bucket(
+        'language', hash_bucket_size=20)
+
+    regressor = tf.contrib.learn.DNNLinearCombinedRegressor(
+        linear_feature_columns=[
+            language_column,
+            tf.contrib.layers.real_valued_column('age')
+        ],
+        dnn_feature_columns=[
+            tf.contrib.layers.embedding_column(language_column, dimension=1),
+        ],
+        dnn_hidden_units=[3, 3],
+        config=tf.contrib.learn.RunConfig(tf_random_seed=1))
+
+    regressor.fit(input_fn=_input_fn, steps=100)
+
+    export_dir = tempfile.mkdtemp()
+    input_feature_key = 'examples'
+    def serving_input_fn():
+      features, targets = _input_fn()
+      features[input_feature_key] = tf.placeholder(tf.string)
+      return features, targets
+    regressor.export(export_dir, serving_input_fn, input_feature_key,
+                     use_deprecated_input_fn=False)
+
   def testTrainSaveLoad(self):
     """Tests regression with restarting training / evaluate."""
     def _input_fn(num_epochs=None):
@@ -45,6 +45,7 @@ from tensorflow.contrib.learn.python.learn import metric_spec
 from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
 from tensorflow.contrib.learn.python.learn import trainable
 from tensorflow.contrib.learn.python.learn.estimators import _sklearn as sklearn
+from tensorflow.contrib.learn.python.learn.estimators import metric_key
 from tensorflow.contrib.learn.python.learn.estimators import run_config
 from tensorflow.contrib.learn.python.learn.estimators import tensor_signature
 from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
@@ -1108,8 +1109,9 @@ class Estimator(BaseEstimator):
 
     result = _make_metrics_ops(all_metrics, features, labels,
                                model_fn_ops.predictions)
-    if 'loss' not in result:
-      result['loss'] = metrics_lib.streaming_mean(model_fn_ops.loss)
+    if metric_key.MetricKey.LOSS not in result:
+      result[metric_key.MetricKey.LOSS] = metrics_lib.streaming_mean(
+          model_fn_ops.loss)
     return result
 
   def _get_predict_ops(self, features):
@@ -24,9 +24,12 @@ from tensorflow.contrib import losses
 from tensorflow.contrib import metrics as metrics_lib
 from tensorflow.contrib.learn.python.learn import metric_spec
 from tensorflow.contrib.learn.python.learn.estimators import estimator
+from tensorflow.contrib.learn.python.learn.estimators import metric_key
+from tensorflow.contrib.learn.python.learn.estimators import prediction_key
 from tensorflow.contrib.session_bundle import exporter
 from tensorflow.python import summary
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
@@ -387,17 +390,17 @@ class _RegressionHead(_Head):
   def _logits_to_prediction(self, logits=None):
     predictions = {}
     if self.logits_dimension == 1:
-      predictions[PredictionKey.SCORES] = array_ops.squeeze(
+      predictions[prediction_key.PredictionKey.SCORES] = array_ops.squeeze(
           logits, squeeze_dims=[1])
     else:
-      predictions[PredictionKey.SCORES] = logits
+      predictions[prediction_key.PredictionKey.SCORES] = logits
     return predictions
 
   # pylint: disable=undefined-variable
   def _create_signature_fn(self):
     def _regression_signature_fn(examples, unused_features, predictions):
       if isinstance(predictions, dict):
-        score = predictions[PredictionKey.SCORES]
+        score = predictions[prediction_key.PredictionKey.SCORES]
       else:
         score = predictions
 
@@ -408,9 +411,10 @@ class _RegressionHead(_Head):
     return _regression_signature_fn
 
   def _default_metric(self):
-    return {_head_prefixed(self._head_name, MetricKey.LOSS):
-            _weighted_average_loss_metric_spec(self._eval_loss_fn,
-                                               PredictionKey.SCORES,
-                                               self._label_name,
-                                               self._weight_column_name)}
+    return {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
+            _weighted_average_loss_metric_spec(
+                self._eval_loss_fn,
+                prediction_key.PredictionKey.SCORES,
+                self._label_name,
+                self._weight_column_name)}
 
@@ -529,12 +533,16 @@ class _MultiClassHead(_Head):
     return self._logits_to_prediction(logits)
 
   def _logits_to_prediction(self, logits=None):
-    predictions = {PredictionKey.LOGITS: logits}
+    # pylint: disable=missing-docstring
+    predictions = {prediction_key.PredictionKey.LOGITS: logits}
     if self.logits_dimension == 1:
-      predictions[PredictionKey.LOGISTIC] = math_ops.sigmoid(logits)
+      predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid(
+          logits)
       logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
-    predictions[PredictionKey.PROBABILITIES] = nn.softmax(logits)
-    predictions[PredictionKey.CLASSES] = math_ops.argmax(logits, 1)
+    predictions[prediction_key.PredictionKey.PROBABILITIES] = nn.softmax(
+        logits)
+    predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax(
+        logits, 1)
+
     return predictions
 
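A quick numeric check of the zero-padding trick used above: concatenating a zero column in front of a single logit column makes the two-class softmax reproduce the sigmoid, so the binary case can share the multi-class code path. Verified in plain NumPy:

```python
import numpy as np

z = np.array([[0.3], [-1.2]])                      # one logit per example
padded = np.concatenate([np.zeros_like(z), z], 1)  # [0, z] per row
softmax = np.exp(padded) / np.exp(padded).sum(axis=1, keepdims=True)
sigmoid = 1.0 / (1.0 + np.exp(-z))

# softmax over [0, z] gives exp(z) / (1 + exp(z)) for the positive class,
# which is exactly sigmoid(z).
assert np.allclose(softmax[:, 1:], sigmoid)
```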
@@ -545,8 +553,9 @@ class _MultiClassHead(_Head):
       if isinstance(predictions, dict):
         default_signature = exporter.classification_signature(
             input_tensor=examples,
-            classes_tensor=predictions[PredictionKey.CLASSES],
-            scores_tensor=predictions[PredictionKey.PROBABILITIES])
+            classes_tensor=predictions[prediction_key.PredictionKey.CLASSES],
+            scores_tensor=predictions[
+                prediction_key.PredictionKey.PROBABILITIES])
       else:
         default_signature = exporter.classification_signature(
             input_tensor=examples,
@@ -557,44 +566,49 @@ class _MultiClassHead(_Head):
     return _classification_signature_fn
 
   def _default_metric(self):
-    metrics = {_head_prefixed(self._head_name, MetricKey.LOSS):
-               _weighted_average_loss_metric_spec(self._eval_loss_fn,
-                                                  PredictionKey.LOGITS,
-                                                  self._label_name,
-                                                  self._weight_column_name)}
+    metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
+               _weighted_average_loss_metric_spec(
+                   self._eval_loss_fn,
+                   prediction_key.PredictionKey.LOGITS,
+                   self._label_name,
+                   self._weight_column_name)}
 
     # TODO(b/29366811): This currently results in both an "accuracy" and an
     # "accuracy/threshold_0.500000_mean" metric for binary classification.
-    metrics[_head_prefixed(self._head_name, MetricKey.ACCURACY)] = (
-        metric_spec.MetricSpec(metrics_lib.streaming_accuracy,
-                               PredictionKey.CLASSES, self._label_name,
-                               self._weight_column_name))
-    if self.logits_dimension == 1:
-      def _add_binary_metric(metric_key, metric_fn):
-        metrics[_head_prefixed(self._head_name, metric_key)] = (
-            metric_spec.MetricSpec(metric_fn,
-                                   PredictionKey.LOGISTIC,
-                                   self._label_name,
-                                   self._weight_column_name))
-      _add_binary_metric(MetricKey.PREDICTION_MEAN, _predictions_streaming_mean)
-      _add_binary_metric(MetricKey.LABEL_MEAN, _labels_streaming_mean)
+    metrics[_head_prefixed(self._head_name, metric_key.MetricKey.ACCURACY)] = (
+        metric_spec.MetricSpec(metrics_lib.streaming_accuracy,
+                               prediction_key.PredictionKey.CLASSES,
+                               self._label_name,
+                               self._weight_column_name))
+    if self.logits_dimension == 1:
+      def _add_binary_metric(key, metric_fn):
+        metrics[_head_prefixed(self._head_name, key)] = (
+            metric_spec.MetricSpec(metric_fn,
+                                   prediction_key.PredictionKey.LOGISTIC,
+                                   self._label_name,
+                                   self._weight_column_name))
+      _add_binary_metric(
+          metric_key.MetricKey.PREDICTION_MEAN, _predictions_streaming_mean)
+      _add_binary_metric(
+          metric_key.MetricKey.LABEL_MEAN, _labels_streaming_mean)
 
       # Also include the streaming mean of the label as an accuracy baseline, as
       # a reminder to users.
-      _add_binary_metric(MetricKey.ACCURACY_BASELINE, _labels_streaming_mean)
+      _add_binary_metric(
+          metric_key.MetricKey.ACCURACY_BASELINE, _labels_streaming_mean)
 
-      _add_binary_metric(MetricKey.AUC, _streaming_auc)
+      _add_binary_metric(metric_key.MetricKey.AUC, _streaming_auc)
 
       for threshold in self._thresholds:
-        _add_binary_metric(MetricKey.ACCURACY_MEAN % threshold,
+        _add_binary_metric(metric_key.MetricKey.ACCURACY_MEAN % threshold,
                            _accuracy_at_threshold(threshold))
         # Precision for positive examples.
-        _add_binary_metric(MetricKey.PRECISION_MEAN % threshold,
+        _add_binary_metric(metric_key.MetricKey.PRECISION_MEAN % threshold,
                            _streaming_at_threshold(
                                metrics_lib.streaming_precision_at_thresholds,
                                threshold),)
         # Recall for positive examples.
-        _add_binary_metric(MetricKey.RECALL_MEAN % threshold,
+        _add_binary_metric(metric_key.MetricKey.RECALL_MEAN % threshold,
                            _streaming_at_threshold(
                                metrics_lib.streaming_recall_at_thresholds,
                                threshold))
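The `%f` templates referenced above expand once per configured threshold, producing one named metric per threshold:

```python
from tensorflow.contrib.learn.python.learn.estimators import metric_key

# "accuracy/threshold_%f_mean" % 0.5 -> "accuracy/threshold_0.500000_mean"
name = metric_key.MetricKey.ACCURACY_MEAN % 0.5
assert name == "accuracy/threshold_0.500000_mean"
```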
@@ -603,7 +617,7 @@ class _MultiClassHead(_Head):
 
 def _check_labels(labels, label_name):
   labels = labels[label_name] if isinstance(labels, dict) else labels
-  if isinstance(labels, ops.SparseTensor):
+  if isinstance(labels, sparse_tensor.SparseTensor):
     raise ValueError("SparseTensor is not supported as labels.")
   return labels
 
@@ -634,21 +648,24 @@ class _BinarySvmHead(_MultiClassHead):
 
   def _logits_to_prediction(self, logits=None):
     predictions = {}
-    predictions[PredictionKey.LOGITS] = logits
+    predictions[prediction_key.PredictionKey.LOGITS] = logits
     logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
-    predictions[PredictionKey.CLASSES] = math_ops.argmax(logits, 1)
+    predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax(
+        logits, 1)
+
     return predictions
 
   def _default_metric(self):
-    metrics = {_head_prefixed(self._head_name, MetricKey.LOSS):
-               _weighted_average_loss_metric_spec(self._eval_loss_fn,
-                                                  PredictionKey.LOGITS,
-                                                  self._label_name,
-                                                  self._weight_column_name)}
-    metrics[_head_prefixed(self._head_name, MetricKey.ACCURACY)] = (
+    metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
+               _weighted_average_loss_metric_spec(
+                   self._eval_loss_fn,
+                   prediction_key.PredictionKey.LOGITS,
+                   self._label_name,
+                   self._weight_column_name)}
+    metrics[_head_prefixed(self._head_name, metric_key.MetricKey.ACCURACY)] = (
         metric_spec.MetricSpec(metrics_lib.streaming_accuracy,
-                               PredictionKey.CLASSES, self._label_name,
+                               prediction_key.PredictionKey.CLASSES,
+                               self._label_name,
                                self._weight_column_name))
     # TODO(sibyl-vie3Poto): add more metrics relevant for svms.
     return metrics
@@ -673,12 +690,14 @@ class _MultiLabelHead(_MultiClassHead):
         thresholds=thresholds)
 
   def _logits_to_prediction(self, logits=None):
-    predictions = {PredictionKey.LOGITS: logits}
+    predictions = {prediction_key.PredictionKey.LOGITS: logits}
     if self.logits_dimension == 1:
-      predictions[PredictionKey.LOGISTIC] = math_ops.sigmoid(logits)
+      predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid(
+          logits)
       logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
-    predictions[PredictionKey.PROBABILITIES] = math_ops.sigmoid(logits)
-    predictions[PredictionKey.CLASSES] = math_ops.to_int64(
+    predictions[prediction_key.PredictionKey.PROBABILITIES] = math_ops.sigmoid(
+        logits)
+    predictions[prediction_key.PredictionKey.CLASSES] = math_ops.to_int64(
         math_ops.greater(logits, 0))
     return predictions
 
@@ -848,23 +867,3 @@ def _streaming_at_threshold(streaming_metrics_fn, threshold):
     return array_ops.squeeze(precision_tensor), update_op
 
   return _streaming_metrics
-
-
-class PredictionKey(object):
-  CLASSES = "classes"
-  PROBABILITIES = "probabilities"
-  LOGITS = "logits"
-  LOGISTIC = "logistic"
-  SCORES = "scores"
-
-
-class MetricKey(object):
-  LOSS = "loss"
-  AUC = "auc"
-  PREDICTION_MEAN = "labels/prediction_mean"
-  LABEL_MEAN = "labels/actual_label_mean"
-  ACCURACY = "accuracy"
-  ACCURACY_BASELINE = "accuracy/baseline_label_mean"
-  ACCURACY_MEAN = "accuracy/threshold_%f_mean"
-  PRECISION_MEAN = "precision/positive_threshold_%f_mean"
-  RECALL_MEAN = "recall/positive_threshold_%f_mean"
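The two enum classes deleted here are not gone: they reappear as the standalone `metric_key` and `prediction_key` modules added later in this diff, and every former `head_lib.PredictionKey`/`MetricKey` reference is rewritten accordingly:

```python
# New canonical imports after this change:
from tensorflow.contrib.learn.python.learn.estimators import metric_key
from tensorflow.contrib.learn.python.learn.estimators import prediction_key

# The string values themselves are unchanged.
assert prediction_key.PredictionKey.CLASSES == "classes"
assert metric_key.MetricKey.LOSS == "loss"
```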
@@ -74,7 +74,7 @@ class MultiClassModelHeadTest(tf.test.TestCase):
       model_fn_ops = head.head_ops({}, labels,
                                    tf.contrib.learn.ModeKeys.TRAIN,
                                    _noop_train_op, logits=logits)
-      self.assertAlmostEqual(.81326163, sess.run(model_fn_ops.loss))
+      self.assertAlmostEqual(0.81326175, sess.run(model_fn_ops.loss))
 
   def testErrorInSparseTensorLabels(self):
     head = head_lib._multi_class_head(n_classes=2)

@@ -32,6 +32,7 @@ from tensorflow.contrib.learn.python.learn import evaluable
 from tensorflow.contrib.learn.python.learn import trainable
 from tensorflow.contrib.learn.python.learn.estimators import estimator
 from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
+from tensorflow.contrib.learn.python.learn.estimators import prediction_key
 from tensorflow.contrib.learn.python.learn.utils import export
 from tensorflow.contrib.linear_optimizer.python import sdca_optimizer
 from tensorflow.python.framework import dtypes
@@ -267,21 +268,18 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
   Example:
 
   ```python
-  education = sparse_column_with_hash_bucket(column_name="education",
-                                             hash_bucket_size=1000)
-  occupation = sparse_column_with_hash_bucket(column_name="occupation",
-                                              hash_bucket_size=1000)
+  sparse_column_a = sparse_column_with_hash_bucket(...)
+  sparse_column_b = sparse_column_with_hash_bucket(...)
 
-  education_x_occupation = crossed_column(columns=[education, occupation],
-                                          hash_bucket_size=10000)
+  sparse_feature_a_x_sparse_feature_b = crossed_column(...)
 
   # Estimator using the default optimizer.
   estimator = LinearClassifier(
-      feature_columns=[occupation, education_x_occupation])
+      feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b])
 
   # Or estimator using the FTRL optimizer with regularization.
   estimator = LinearClassifier(
-      feature_columns=[occupation, education_x_occupation],
+      feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b],
       optimizer=tf.train.FtrlOptimizer(
         learning_rate=0.1,
         l1_regularization_strength=0.001
@@ -289,7 +287,7 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
 
   # Or estimator using the SDCAOptimizer.
   estimator = LinearClassifier(
-     feature_columns=[occupation, education_x_occupation],
+     feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b],
      optimizer=tf.contrib.linear_optimizer.SDCAOptimizer(
        example_id_column='example_id',
        num_loss_partitions=...,
@@ -465,13 +463,16 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
                          as_iterable=False)
   def predict(self, x=None, input_fn=None, batch_size=None, as_iterable=True):
     """Runs inference to determine the predicted class."""
-    preds = self._estimator.predict(x=x, input_fn=input_fn,
-                                    batch_size=batch_size,
-                                    outputs=[head_lib.PredictionKey.CLASSES],
-                                    as_iterable=as_iterable)
+    key = prediction_key.PredictionKey.CLASSES
+    preds = self._estimator.predict(
+        x=x,
+        input_fn=input_fn,
+        batch_size=batch_size,
+        outputs=[key],
+        as_iterable=as_iterable)
     if as_iterable:
-      return _as_iterable(preds, output=head_lib.PredictionKey.CLASSES)
-    return preds[head_lib.PredictionKey.CLASSES]
+      return _as_iterable(preds, output=key)
+    return preds[key]
 
   @deprecated_arg_values(
       estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
@@ -479,14 +480,16 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
   def predict_proba(self, x=None, input_fn=None, batch_size=None, outputs=None,
                     as_iterable=True):
     """Runs inference to determine the class probability predictions."""
-    preds = self._estimator.predict(x=x, input_fn=input_fn,
-                                    batch_size=batch_size,
-                                    outputs=[
-                                        head_lib.PredictionKey.PROBABILITIES],
-                                    as_iterable=as_iterable)
+    key = prediction_key.PredictionKey.PROBABILITIES
+    preds = self._estimator.predict(
+        x=x,
+        input_fn=input_fn,
+        batch_size=batch_size,
+        outputs=[key],
+        as_iterable=as_iterable)
     if as_iterable:
-      return _as_iterable(preds, output=head_lib.PredictionKey.PROBABILITIES)
-    return preds[head_lib.PredictionKey.PROBABILITIES]
+      return _as_iterable(preds, output=key)
+    return preds[key]
 
   def get_variable_names(self):
     return self._estimator.get_variable_names()
@@ -512,9 +515,9 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
         input_fn=input_fn or default_input_fn,
         input_feature_key=input_feature_key,
         use_deprecated_input_fn=use_deprecated_input_fn,
-        signature_fn=(
-            signature_fn or export.classification_signature_fn_with_prob),
-        prediction_key=head_lib.PredictionKey.PROBABILITIES,
+        signature_fn=(signature_fn or
+                      export.classification_signature_fn_with_prob),
+        prediction_key=prediction_key.PredictionKey.PROBABILITIES,
         default_batch_size=default_batch_size,
         exports_to_keep=exports_to_keep)
 
@@ -561,16 +564,13 @@ class LinearRegressor(evaluable.Evaluable, trainable.Trainable):
   Example:
 
   ```python
-  education = sparse_column_with_hash_bucket(column_name="education",
-                                             hash_bucket_size=1000)
-  occupation = sparse_column_with_hash_bucket(column_name="occupation",
-                                              hash_bucket_size=1000)
+  sparse_column_a = sparse_column_with_hash_bucket(...)
+  sparse_column_b = sparse_column_with_hash_bucket(...)
 
-  education_x_occupation = crossed_column(columns=[education, occupation],
-                                          hash_bucket_size=10000)
+  sparse_feature_a_x_sparse_feature_b = crossed_column(...)
 
   estimator = LinearRegressor(
-      feature_columns=[occupation, education_x_occupation])
+      feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b])
 
   # Input builders
   def input_fn_train: # returns x, y
@@ -731,13 +731,16 @@ class LinearRegressor(evaluable.Evaluable, trainable.Trainable):
                          as_iterable=False)
   def predict(self, x=None, input_fn=None, batch_size=None, as_iterable=True):
     """Runs inference to determine the predicted class."""
-    preds = self._estimator.predict(x=x, input_fn=input_fn,
-                                    batch_size=batch_size,
-                                    outputs=[head_lib.PredictionKey.SCORES],
-                                    as_iterable=as_iterable)
+    key = prediction_key.PredictionKey.SCORES
+    preds = self._estimator.predict(
+        x=x,
+        input_fn=input_fn,
+        batch_size=batch_size,
+        outputs=[key],
+        as_iterable=as_iterable)
     if as_iterable:
-      return _as_iterable(preds, output=head_lib.PredictionKey.SCORES)
-    return preds[head_lib.PredictionKey.SCORES]
+      return _as_iterable(preds, output=key)
+    return preds[key]
 
   def get_variable_names(self):
     return self._estimator.get_variable_names()
@@ -764,7 +767,7 @@ class LinearRegressor(evaluable.Evaluable, trainable.Trainable):
         input_feature_key=input_feature_key,
         use_deprecated_input_fn=use_deprecated_input_fn,
         signature_fn=(signature_fn or export.regression_signature_fn),
-        prediction_key=head_lib.PredictionKey.SCORES,
+        prediction_key=prediction_key.PredictionKey.SCORES,
         default_batch_size=default_batch_size,
         exports_to_keep=exports_to_keep)
 
@@ -0,0 +1,30 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Enum for metric keys."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+class MetricKey(object):
+  LOSS = "loss"
+  AUC = "auc"
+  PREDICTION_MEAN = "labels/prediction_mean"
+  LABEL_MEAN = "labels/actual_label_mean"
+  ACCURACY = "accuracy"
+  ACCURACY_BASELINE = "accuracy/baseline_label_mean"
+  ACCURACY_MEAN = "accuracy/threshold_%f_mean"
+  PRECISION_MEAN = "precision/positive_threshold_%f_mean"
+  RECALL_MEAN = "recall/positive_threshold_%f_mean"

@@ -0,0 +1,26 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Enum for model prediction keys."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+class PredictionKey(object):
+  CLASSES = "classes"
+  PROBABILITIES = "probabilities"
+  LOGITS = "logits"
+  LOGISTIC = "logistic"
+  SCORES = "scores"
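Typical consumption of the new enum modules, mirroring the estimator changes elsewhere in this diff; `est` and `x_data` are placeholders for any of the wrapped estimators and its input:

```python
from tensorflow.contrib.learn.python.learn.estimators import prediction_key

# Request one named output and index the returned dict with the same key.
key = prediction_key.PredictionKey.PROBABILITIES
preds = est.predict(x=x_data, outputs=[key], as_iterable=False)
probabilities = preds[key]
```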
@@ -30,6 +30,7 @@ from tensorflow.contrib.learn.python.learn import trainable
 from tensorflow.contrib.learn.python.learn.estimators import estimator
 from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
 from tensorflow.contrib.learn.python.learn.estimators import linear
+from tensorflow.contrib.learn.python.learn.estimators import prediction_key
 from tensorflow.contrib.linear_optimizer.python import sdca_optimizer
 
 
@@ -188,13 +189,16 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
                          as_iterable=False)
   def predict(self, x=None, input_fn=None, batch_size=None, as_iterable=True):
     """Runs inference to determine the predicted class."""
-    preds = self._estimator.predict(x=x, input_fn=input_fn,
-                                    batch_size=batch_size,
-                                    outputs=[head_lib.PredictionKey.CLASSES],
-                                    as_iterable=as_iterable)
+    key = prediction_key.PredictionKey.CLASSES
+    preds = self._estimator.predict(
+        x=x,
+        input_fn=input_fn,
+        batch_size=batch_size,
+        outputs=[key],
+        as_iterable=as_iterable)
     if as_iterable:
-      return _as_iterable(preds, output=head_lib.PredictionKey.CLASSES)
-    return preds[head_lib.PredictionKey.CLASSES]
+      return _as_iterable(preds, output=key)
+    return preds[key]
 
   @deprecated_arg_values(
       estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
@@ -202,14 +206,16 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
   def predict_proba(self, x=None, input_fn=None, batch_size=None, outputs=None,
                     as_iterable=True):
     """Runs inference to determine the class probability predictions."""
-    preds = self._estimator.predict(x=x, input_fn=input_fn,
-                                    batch_size=batch_size,
-                                    outputs=[
-                                        head_lib.PredictionKey.PROBABILITIES],
-                                    as_iterable=as_iterable)
+    key = prediction_key.PredictionKey.PROBABILITIES
+    preds = self._estimator.predict(
+        x=x,
+        input_fn=input_fn,
+        batch_size=batch_size,
+        outputs=[key],
+        as_iterable=as_iterable)
     if as_iterable:
-      return _as_iterable(preds, output=head_lib.PredictionKey.PROBABILITIES)
-    return preds[head_lib.PredictionKey.PROBABILITIES]
+      return _as_iterable(preds, output=key)
+    return preds[key]
   # pylint: enable=protected-access
 
   def get_variable_names(self):
@@ -22,7 +22,7 @@ from __future__ import print_function
 import collections
 
 from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
@@ -41,7 +41,7 @@ class TensorSignature(collections.namedtuple(
   """
 
   def __new__(cls, tensor):
-    if isinstance(tensor, ops.SparseTensor):
+    if isinstance(tensor, sparse_tensor.SparseTensor):
       return super(TensorSignature, cls).__new__(
           cls, dtype=tensor.values.dtype, shape=None, is_sparse=True)
     return super(TensorSignature, cls).__new__(
@@ -40,6 +40,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import resources
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import basic_session_run_hooks
@@ -77,7 +78,8 @@ def get_summary_writer(logdir):
 
 
 def _make_saver(graph, keep_checkpoint_max=5):
-  vars_to_save = graph.get_collection(ops.GraphKeys.VARIABLES)
+  vars_to_save = (graph.get_collection(ops.GraphKeys.VARIABLES) +
+                  graph.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS))
   if vars_to_save:
     return tf_saver.Saver(vars_to_save,
                           sharded=True,
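`_make_saver` now also sweeps the `SAVEABLE_OBJECTS` collection, so non-variable state (for example the mutable hash tables exercised at the end of this diff) gets checkpointed alongside variables. A hedged sketch of how an object opts in; `my_saveable` is hypothetical and stands for any `SaveableObject`-compatible instance:

```python
import tensorflow as tf

# Hypothetical: my_saveable must implement the SaveableObject interface.
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, my_saveable)
# Savers built from this graph (as in _make_saver above) now include it.
```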
@ -846,9 +848,11 @@ def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None):
|
|||||||
raise ValueError('feed_dicts is invalid: %s.' % feed_dicts)
|
raise ValueError('feed_dicts is invalid: %s.' % feed_dicts)
|
||||||
|
|
||||||
graph = contrib_ops.get_graph_from_inputs(output_dict.values())
|
graph = contrib_ops.get_graph_from_inputs(output_dict.values())
|
||||||
|
|
||||||
with graph.as_default() as g:
|
with graph.as_default() as g:
|
||||||
with tf_session.Session('') as session:
|
with tf_session.Session('') as session:
|
||||||
|
session.run(
|
||||||
|
resources.initialize_resources(resources.shared_resources() +
|
||||||
|
resources.local_resources()))
|
||||||
if restore_checkpoint_path:
|
if restore_checkpoint_path:
|
||||||
_restore_from_checkpoint(session, g, restore_checkpoint_path)
|
_restore_from_checkpoint(session, g, restore_checkpoint_path)
|
||||||
else:
|
else:
|
||||||
|
@ -28,6 +28,8 @@ from tensorflow.contrib.learn.python import learn
|
|||||||
from tensorflow.contrib.learn.python.learn.monitors import BaseMonitor
|
from tensorflow.contrib.learn.python.learn.monitors import BaseMonitor
|
||||||
from tensorflow.python.framework import meta_graph
|
from tensorflow.python.framework import meta_graph
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_ops
|
||||||
|
from tensorflow.python.ops import resources
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
|
|
||||||
|
|
||||||
@ -194,6 +196,19 @@ class GraphActionsTest(tf.test.TestCase):
|
|||||||
pass
|
pass
|
||||||
self.assertTrue(request_stop.called)
|
self.assertTrue(request_stop.called)
|
||||||
|
|
||||||
|
def test_run_feeds_iter_calls_resources_init(self):
|
||||||
|
with tf.Graph().as_default() as g:
|
||||||
|
in0, _, _ = self._build_inference_graph()
|
||||||
|
handle = test_ops.stub_resource_handle_op(container='a', shared_name='b')
|
||||||
|
resources.register_resource(
|
||||||
|
handle=handle,
|
||||||
|
create_op=test_ops.resource_create_op(handle),
|
||||||
|
is_initialized_op=test_ops.resource_initialized_op(handle))
|
||||||
|
|
||||||
|
for _ in learn.graph_actions.run_feeds_iter({'in0': in0},
|
||||||
|
feed_dicts=[{}]):
|
||||||
|
self.assertTrue(test_ops.resource_initialized_op(handle).eval())
|
||||||
|
|
||||||
def test_infer_different_default_graph(self):
|
def test_infer_different_default_graph(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
self._assert_ckpt(self._output_dir, False)
|
self._assert_ckpt(self._output_dir, False)
|
||||||
@@ -24,6 +24,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import io_ops
 from tensorflow.python.ops import math_ops
@@ -645,7 +646,7 @@ def queue_parsed_features(parsed_features,
   # directly.
   for key in sorted(parsed_features.keys()):
     tensor = parsed_features[key]
-    if isinstance(tensor, ops.SparseTensor):
+    if isinstance(tensor, sparse_tensor.SparseTensor):
       tensors_mapping.append((key, True))
       tensors_to_enqueue.extend([tensor.indices, tensor.values, tensor.shape])
     else:
@@ -704,7 +705,7 @@ def queue_parsed_features(parsed_features,
   for key, is_sparse_tensor in tensors_mapping:
     if is_sparse_tensor:
       # Three tensors are (indices, values, shape).
-      dequeued_parsed_features[key] = ops.SparseTensor(
+      dequeued_parsed_features[key] = sparse_tensor.SparseTensor(
           dequeued_tensors[index], dequeued_tensors[index + 1],
           dequeued_tensors[index + 2])
       index += 3
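The pattern above ships a `SparseTensor` through a queue as its three component tensors and reassembles it after dequeue. A minimal sketch of that round trip with hypothetical values, assuming the 0.11-era API in which the dense-shape component is named `shape`:

```python
import tensorflow as tf

st = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], shape=[3, 4])
# Decompose into three plain tensors that a queue can carry...
components = [st.indices, st.values, st.shape]
# ...and rebuild on the other side, exactly as queue_parsed_features does.
rebuilt = tf.SparseTensor(components[0], components[1], components[2])
```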
@@ -542,7 +542,8 @@ class CheckpointSaverTest(tf.test.TestCase):
     self.assertEqual(1, tf.contrib.framework.load_variable(
         self.model_dir, self.global_step.name))
 
-  def test_save_secs_saves_periodically(self):
+  # TODO(gunan): Reenable this test after b/32446874 is fixed.
+  def disabled_test_save_secs_saves_periodically(self):
     with self.graph.as_default():
       monitor = learn.monitors.CheckpointSaver(
           self.model_dir, save_secs=2, scaffold=self.scaffold)
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_data_flow_ops
@@ -166,7 +167,7 @@ class InitializableLookupTableBase(LookupInterface):
     name = "%s_lookup_table_find" % self._name
 
     key_tensor = keys
-    if isinstance(keys, ops.SparseTensor):
+    if isinstance(keys, sparse_tensor.SparseTensor):
       key_tensor = keys.values
 
     if keys.dtype != self._key_dtype:
@@ -181,8 +182,8 @@ class InitializableLookupTableBase(LookupInterface):
     # pylint: enable=protected-access
 
     values.set_shape(key_tensor.get_shape())
-    if isinstance(keys, ops.SparseTensor):
-      return ops.SparseTensor(keys.indices, values, keys.shape)
+    if isinstance(keys, sparse_tensor.SparseTensor):
+      return sparse_tensor.SparseTensor(keys.indices, values, keys.shape)
     else:
       return values
 
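The lookup path above only maps the `values` of a sparse key tensor and then rebuilds a `SparseTensor` with the original indices and shape, which is what makes sparse lookups cheap. The same idea by hand, with hypothetical table contents and assuming the contrib lookup API of this era:

```python
import tensorflow as tf

keys = tf.SparseTensor(indices=[[0, 0], [1, 1]],
                       values=tf.constant(['a', 'b']), shape=[2, 2])
table = tf.contrib.lookup.HashTable(
    tf.contrib.lookup.KeyValueTensorInitializer(['a', 'b'], [0, 1]), -1)
# Only the values are looked up; indices and shape pass through unchanged.
ids = tf.SparseTensor(keys.indices, table.lookup(keys.values), keys.shape)
```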
@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function
 
 import os
+import tempfile
 import numpy as np
 import six
 import tensorflow as tf
@@ -296,7 +297,8 @@ class MutableHashTableOpTest(tf.test.TestCase):
     self.assertAllEqual([0, 1, 2], sorted_values)
 
   def testSaveRestore(self):
-    save_path = os.path.join(self.get_temp_dir(), "hash")
+    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
+    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
 
     with self.test_session(graph=tf.Graph()) as sess:
       v0 = tf.Variable(10.0, name="v0")
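The change above gives every test run a freshly created directory, so repeated runs cannot collide on a stale checkpoint file. `tempfile.mkdtemp(prefix=...)` creates a new, uniquely named directory on each call; a short standalone sketch with a hypothetical base path:

```python
import os
import tempfile

base = '/tmp/save_restore'                  # stand-in for get_temp_dir()
save_dir = tempfile.mkdtemp(prefix=base)    # e.g. /tmp/save_restoreab12cd
save_path = os.path.join(save_dir, 'hash')  # unique per invocation
```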
@ -867,7 +869,8 @@ class MutableDenseHashTableOpTest(tf.test.TestCase):
|
|||||||
[100, 0], [100, 0], [100, 0]], pairs)
|
[100, 0], [100, 0], [100, 0]], pairs)
|
||||||
|
|
||||||
def testSaveRestore(self):
|
def testSaveRestore(self):
|
||||||
save_path = os.path.join(self.get_temp_dir(), "hash")
|
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
|
||||||
|
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
|
||||||
|
|
||||||
with self.test_session(graph=tf.Graph()) as sess:
|
with self.test_session(graph=tf.Graph()) as sess:
|
||||||
default_value = -1
|
default_value = -1
|
||||||
@ -922,7 +925,8 @@ class MutableDenseHashTableOpTest(tf.test.TestCase):
|
|||||||
self.assertAllEqual([-1, 0, 1, 2, -1], output.eval())
|
self.assertAllEqual([-1, 0, 1, 2, -1], output.eval())
|
||||||
|
|
||||||
def testVectorSaveRestore(self):
|
def testVectorSaveRestore(self):
|
||||||
save_path = os.path.join(self.get_temp_dir(), "hash")
|
save_dir = os.path.join(self.get_temp_dir(), "vector_save_restore")
|
||||||
|
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
|
||||||
|
|
||||||
with self.test_session(graph=tf.Graph()) as sess:
|
with self.test_session(graph=tf.Graph()) as sess:
|
||||||
empty_key = tf.constant([11, 13], tf.int64)
|
empty_key = tf.constant([11, 13], tf.int64)
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
### TensorFlow Makefile
|
### TensorFlow Makefile
|
||||||
|
|
||||||
The recommended way to build TensorFlow from source is using the Bazel
|
The recommended way to build TensorFlow from source is using the Bazel
|
||||||
open-source build system. Sometimes this isn't possible.
|
open-source build system. Sometimes this isn't possible. For example,
|
||||||
|
if you are building for iOS, you currently need to use the Makefile.
|
||||||
|
|
||||||
- The build system may not have the RAM or processing power to support Bazel.
|
- The build system may not have the RAM or processing power to support Bazel.
|
||||||
- Bazel or its dependencies may not be available.
|
- Bazel or its dependencies may not be available.
|
||||||
|
@ -43,6 +43,13 @@ tensorflow/core/kernels/sequence_ops.cc
|
|||||||
tensorflow/core/kernels/sendrecv_ops.cc
|
tensorflow/core/kernels/sendrecv_ops.cc
|
||||||
tensorflow/core/kernels/scatter_op.cc
|
tensorflow/core/kernels/scatter_op.cc
|
||||||
tensorflow/core/kernels/scatter_functor.cc
|
tensorflow/core/kernels/scatter_functor.cc
|
||||||
|
tensorflow/core/kernels/scatter_nd_op_cpu_impl_0.cc
|
||||||
|
tensorflow/core/kernels/scatter_nd_op_cpu_impl_1.cc
|
||||||
|
tensorflow/core/kernels/scatter_nd_op_cpu_impl_2.cc
|
||||||
|
tensorflow/core/kernels/scatter_nd_op_cpu_impl_3.cc
|
||||||
|
tensorflow/core/kernels/scatter_nd_op_cpu_impl_4.cc
|
||||||
|
tensorflow/core/kernels/scatter_nd_op_cpu_impl_5.cc
|
||||||
|
tensorflow/core/kernels/scatter_nd_op.cc
|
||||||
tensorflow/core/kernels/save_restore_tensor.cc
|
tensorflow/core/kernels/save_restore_tensor.cc
|
||||||
tensorflow/core/kernels/save_restore_v2_ops.cc
|
tensorflow/core/kernels/save_restore_v2_ops.cc
|
||||||
tensorflow/core/kernels/save_op.cc
|
tensorflow/core/kernels/save_op.cc
|
||||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
|||||||
from tensorflow.contrib.framework import tensor_util
|
from tensorflow.contrib.framework import tensor_util
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import sparse_ops
|
from tensorflow.python.ops import sparse_ops
|
||||||
@ -102,7 +103,7 @@ def confusion_matrix(predictions, labels, num_classes=None, dtype=dtypes.int32,
|
|||||||
indices = array_ops.transpose(array_ops.pack([predictions, labels]))
|
indices = array_ops.transpose(array_ops.pack([predictions, labels]))
|
||||||
values = (array_ops.ones_like(predictions, dtype)
|
values = (array_ops.ones_like(predictions, dtype)
|
||||||
if weights is None else weights)
|
if weights is None else weights)
|
||||||
cm_sparse = ops.SparseTensor(
|
cm_sparse = sparse_tensor.SparseTensor(
|
||||||
indices=indices, values=values, shape=math_ops.to_int64(shape))
|
indices=indices, values=values, shape=math_ops.to_int64(shape))
|
||||||
zero_matrix = array_ops.zeros(math_ops.to_int32(shape), dtype)
|
zero_matrix = array_ops.zeros(math_ops.to_int32(shape), dtype)
|
||||||
|
|
||||||
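The confusion matrix built above is just a sparse scatter-add: each (prediction, label) pair becomes an index, its weight becomes the value, and duplicates accumulate onto a zero matrix. A minimal numpy rendering of the same computation with hypothetical inputs:

```python
import numpy as np

predictions = np.array([1, 2, 4])
labels = np.array([2, 2, 4])
num_classes = 5

cm = np.zeros((num_classes, num_classes), dtype=np.int32)
# Equivalent of SparseTensor(indices=(prediction, label) pairs, values=1)
# added onto a zero matrix; repeated index pairs accumulate.
np.add.at(cm, (predictions, labels), 1)
```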
@@ -29,6 +29,7 @@ from tensorflow.contrib.metrics.python.ops import confusion_matrix_ops
 from tensorflow.contrib.metrics.python.ops import set_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
@@ -762,7 +763,12 @@ def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
   computes the area under a discretized curve of precision versus recall values
   (computed using the aforementioned variables). The `num_thresholds` variable
   controls the degree of discretization with larger numbers of thresholds more
-  closely approximating the true AUC.
+  closely approximating the true AUC. The quality of the approximation may vary
+  dramatically depending on `num_thresholds`.
+
+  For best results, `predictions` should be distributed approximately uniformly
+  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
+  approximation may be poor if this is not the case.
 
   For estimation of the metric over a stream of data, the function creates an
   `update_op` operation that updates these variables and returns the `auc`.
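A quick way to see why `num_thresholds` matters: the streaming AUC is a trapezoidal sum over a fixed grid of thresholds, so a coarse grid collapses distinct operating points, and scores peaked near 0 or 1 leave most grid points unused. A small numpy sketch with hypothetical scores (not the library's implementation, just the same discretization idea):

```python
import numpy as np

labels = np.array([0, 0, 1, 1])
scores = np.array([0.1, 0.4, 0.35, 0.8])  # reasonably spread over [0, 1]

def discretized_auc(labels, scores, num_thresholds):
  thresholds = np.linspace(0.0, 1.0, num_thresholds)
  tpr = [(scores[labels == 1] >= t).mean() for t in thresholds]
  fpr = [(scores[labels == 0] >= t).mean() for t in thresholds]
  return np.trapz(tpr[::-1], fpr[::-1])  # integrate along increasing FPR

# More thresholds -> closer to the exact ROC AUC of 0.75 for this data.
print(discretized_auc(labels, scores, 5), discretized_auc(labels, scores, 200))
```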
@ -1601,7 +1607,8 @@ def num_relevant(labels, k):
|
|||||||
raise ValueError('Invalid k=%s.' % k)
|
raise ValueError('Invalid k=%s.' % k)
|
||||||
with ops.name_scope(None, 'num_relevant', (labels,)) as scope:
|
with ops.name_scope(None, 'num_relevant', (labels,)) as scope:
|
||||||
# For SparseTensor, calculate separate count for each row.
|
# For SparseTensor, calculate separate count for each row.
|
||||||
if isinstance(labels, (ops.SparseTensor, ops.SparseTensorValue)):
|
if isinstance(
|
||||||
|
labels, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
|
||||||
labels_sizes = set_ops.set_size(labels)
|
labels_sizes = set_ops.set_size(labels)
|
||||||
return math_ops.minimum(labels_sizes, k, name=scope)
|
return math_ops.minimum(labels_sizes, k, name=scope)
|
||||||
|
|
||||||
@ -1637,9 +1644,9 @@ def expand_and_tile(tensor, multiple, dim=0, name=None):
|
|||||||
with ops.name_scope(
|
with ops.name_scope(
|
||||||
name, 'expand_and_tile', (tensor, multiple, dim)) as scope:
|
name, 'expand_and_tile', (tensor, multiple, dim)) as scope:
|
||||||
# Sparse.
|
# Sparse.
|
||||||
if isinstance(tensor, ops.SparseTensorValue):
|
if isinstance(tensor, sparse_tensor.SparseTensorValue):
|
||||||
tensor = ops.SparseTensor.from_value(tensor)
|
tensor = sparse_tensor.SparseTensor.from_value(tensor)
|
||||||
if isinstance(tensor, ops.SparseTensor):
|
if isinstance(tensor, sparse_tensor.SparseTensor):
|
||||||
if dim < 0:
|
if dim < 0:
|
||||||
expand_dims = array_ops.reshape(
|
expand_dims = array_ops.reshape(
|
||||||
array_ops.size(tensor.shape) + dim, [1])
|
array_ops.size(tensor.shape) + dim, [1])
|
||||||
@ -1871,7 +1878,8 @@ def _select_class_id(ids, selected_id):
|
|||||||
`SparseTensor` of same dimensions as `ids`. This contains only the entries
|
`SparseTensor` of same dimensions as `ids`. This contains only the entries
|
||||||
equal to `selected_id`.
|
equal to `selected_id`.
|
||||||
"""
|
"""
|
||||||
if isinstance(ids, (ops.SparseTensor, ops.SparseTensorValue)):
|
if isinstance(
|
||||||
|
ids, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
|
||||||
return sparse_ops.sparse_retain(
|
return sparse_ops.sparse_retain(
|
||||||
ids, math_ops.equal(ids.values, selected_id))
|
ids, math_ops.equal(ids.values, selected_id))
|
||||||
|
|
||||||
@ -1888,7 +1896,7 @@ def _select_class_id(ids, selected_id):
|
|||||||
filled_selected_id = array_ops.fill(
|
filled_selected_id = array_ops.fill(
|
||||||
filled_selected_id_shape, math_ops.to_int64(selected_id))
|
filled_selected_id_shape, math_ops.to_int64(selected_id))
|
||||||
result = set_ops.set_intersection(filled_selected_id, ids)
|
result = set_ops.set_intersection(filled_selected_id, ids)
|
||||||
return ops.SparseTensor(
|
return sparse_tensor.SparseTensor(
|
||||||
indices=result.indices, values=result.values, shape=ids_shape)
|
indices=result.indices, values=result.values, shape=ids_shape)
|
||||||
|
|
||||||
|
|
||||||
@@ -23,6 +23,7 @@ from tensorflow.contrib.util import loader
 from tensorflow.python.framework import common_shapes
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.platform import resource_loader
 
 
@@ -54,7 +55,7 @@ def set_size(a, validate_indices=True):
     TypeError: If `a` is an invalid types.
   """
   a = tensor_util.convert_to_tensor_or_sparse_tensor(a, name="a")
-  if not isinstance(a, ops.SparseTensor):
+  if not isinstance(a, sparse_tensor.SparseTensor):
     raise TypeError("Expected `SparseTensor`, got %s." % a)
   if a.values.dtype.base_dtype not in _VALID_DTYPES:
     raise TypeError("Invalid dtype %s." % a.values.dtype)
@@ -106,22 +107,22 @@ def _set_operation(a, b, set_operation, validate_indices=True):
   if b.dtype.base_dtype != a.dtype.base_dtype:
     raise TypeError("Types don't match, %s vs %s." % (a.dtype, b.dtype))
   # pylint: disable=protected-access
-  if isinstance(a, ops.SparseTensor):
-    if isinstance(b, ops.SparseTensor):
+  if isinstance(a, sparse_tensor.SparseTensor):
+    if isinstance(b, sparse_tensor.SparseTensor):
       indices, values, shape = _set_ops.sparse_to_sparse_set_operation(
           a.indices, a.values, a.shape, b.indices, b.values, b.shape,
          set_operation, validate_indices)
     else:
      raise ValueError("Sparse,Dense is not supported, but Dense,Sparse is. "
                       "Please flip the order of your inputs.")
-  elif isinstance(b, ops.SparseTensor):
+  elif isinstance(b, sparse_tensor.SparseTensor):
     indices, values, shape = _set_ops.dense_to_sparse_set_operation(
         a, b.indices, b.values, b.shape, set_operation, validate_indices)
   else:
     indices, values, shape = _set_ops.dense_to_dense_set_operation(
         a, b, set_operation, validate_indices)
   # pylint: enable=protected-access
-  return ops.SparseTensor(indices, values, shape)
+  return sparse_tensor.SparseTensor(indices, values, shape)
 
 
 def set_intersection(a, b, validate_indices=True):
@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function
 
 import os.path
+import tempfile
 
 import six
 import tensorflow as tf
@@ -40,7 +41,9 @@ class MovingAverageOptimizerTest(tf.test.TestCase):
            tf.train.GradientDescentOptimizer(learning_rate=2.0),
            average_decay=0.5,
            sequential_update=sequential_update)
-        save_path = os.path.join(self.get_temp_dir(), 'model')
+        save_dir = tempfile.mkdtemp(
+            prefix=os.path.join(self.get_temp_dir(), 'run_1'))
+        save_path = os.path.join(save_dir, 'model')
         update = opt.apply_gradients(
             list(six.moves.zip([grads0, grads1], [var0, var1])))
         train_saver = opt.swapping_saver()
@@ -39,7 +39,7 @@ INCLUDES := \
 -I/usr/local/include \
 -I. \
 -I$(DOWNLOADSDIR) \
--I$(DOWNLOADSDIR)/eigen-latest/ \
+-I$(DOWNLOADSDIR)/eigen/ \
 -I$(PROTOGENDIR) \
 -I$(PBTGENDIR)
 LIBS := \
@@ -39,7 +39,7 @@ INCLUDES := \
 -I/usr/local/include \
 -I. \
 -I$(DOWNLOADSDIR) \
--I$(DOWNLOADSDIR)/eigen-latest/ \
+-I$(DOWNLOADSDIR)/eigen/ \
 -I$(PROTOGENDIR) \
 -I$(PBTGENDIR)
 LIBS := \
@@ -46,6 +46,7 @@ py_test(
     name = "learning_test",
     srcs = ["python/slim/learning_test.py"],
     srcs_version = "PY2AND3",
+    tags = ["manual"],
     deps = [
         "//tensorflow:tensorflow_py",
         "//tensorflow/contrib/slim",
@@ -27,7 +27,7 @@ import abc
 
 from tensorflow.contrib.slim.python.slim.data import data_decoder
 from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import image_ops
@@ -189,11 +189,11 @@ class Tensor(ItemHandler):
       shape_dims = []
       for k in self._shape_keys:
         shape_dim = keys_to_tensors[k]
-        if isinstance(shape_dim, ops.SparseTensor):
+        if isinstance(shape_dim, sparse_tensor.SparseTensor):
           shape_dim = sparse_ops.sparse_tensor_to_dense(shape_dim)
         shape_dims.append(shape_dim)
       shape = array_ops.reshape(array_ops.pack(shape_dims), [-1])
-    if isinstance(tensor, ops.SparseTensor):
+    if isinstance(tensor, sparse_tensor.SparseTensor):
       if shape is not None:
         tensor = sparse_ops.sparse_reshape(tensor, shape)
       tensor = sparse_ops.sparse_tensor_to_dense(tensor, self._default_value)
@@ -241,7 +241,7 @@ class SparseTensor(ItemHandler):
     values = keys_to_tensors[self._values_key]
     if self._shape_key:
       shape = keys_to_tensors[self._shape_key]
-      if isinstance(shape, ops.SparseTensor):
+      if isinstance(shape, sparse_tensor.SparseTensor):
         shape = sparse_ops.sparse_tensor_to_dense(shape)
     elif self._shape:
       shape = self._shape
@@ -255,7 +255,7 @@ class SparseTensor(ItemHandler):
     new_indices = array_ops.concat(1, [indices_columns_to_preserve,
                                        array_ops.reshape(ids, [-1, 1])])
 
-    tensor = ops.SparseTensor(new_indices, values.values, shape)
+    tensor = sparse_tensor.SparseTensor(new_indices, values.values, shape)
     if self._densify:
       tensor = sparse_ops.sparse_tensor_to_dense(tensor, self._default_value)
     return tensor
@@ -132,11 +132,9 @@ REGISTER_OP("FinishedNodes")
     .Attr("num_split_after_samples: int")
     .Attr("min_split_samples: int")
     .Attr("dominate_fraction: float = 0.99")
-    // TODO(thomaswc): Test out bootstrap on several datasets, confirm it
-    // works well, make it the default.
     .Attr(
         "dominate_method:"
-        " {'none', 'hoeffding', 'bootstrap', 'chebyshev'} = 'hoeffding'")
+        " {'none', 'hoeffding', 'bootstrap', 'chebyshev'} = 'bootstrap'")
     .Attr("random_seed: int = 0")
     .Input("leaves: int32")
     .Input("node_to_accumulator: int32")
@@ -26,7 +26,7 @@ namespace tensorflow {
 
 TEST(TrainingOpsTest, UpdateFertileSlots_ShapeFn) {
   ShapeInferenceTestOp op("UpdateFertileSlots");
-  INFER_OK(op, "?;?;?;?;?;?;?", "[2,?];[2,?];[?];[?]");
+  INFER_OK(op, "?;?;?;?;?;?;?;?", "[2,?];[2,?];[?];[?]");
 }
 
 TEST(TrainingOpsTest, ScatterAddNdim_ShapeFn) {
@@ -55,23 +55,29 @@ T Sum(Tensor counts) {
 // is stored in index 0, individual feature types start at index 1.
 DataColumnTypes FeatureSpec(int32 input_feature, const Tensor& spec);
 
-// Given an Eigen::Tensor type, calculate the Gini impurity, which we use
-// to determine the best split (lowest) and which nodes to allocate first
-// (highest).
-template<typename T>
-float WeightedGiniImpurity(const T& counts) {
+// Given an Eigen::Tensor type, calculate the Gini impurity.
+template <typename T>
+float RawWeightedGiniImpurity(const T& counts) {
   // Our split score is the Gini impurity times the number of examples
   // seen by the leaf. If c(i) denotes the i-th class count and c = sum_i c(i)
   // then
  //   score = c * (1 - sum_i ( c(i) / c )^2 )
  //         = c - sum_i c(i)^2 / c
-  const auto smoothed = counts + counts.constant(1.0f);
-  const auto sum = smoothed.sum();
-  const auto sum2 = smoothed.square().sum();
+  const auto sum = counts.sum();
+  const auto sum2 = counts.square().sum();
   Eigen::Tensor<float, 0, Eigen::RowMajor> ret = sum - (sum2 / sum);
   return ret(0);
 }
 
+// Given an Eigen::Tensor type, calculate the smoothed Gini impurity, which we
+// use to determine the best split (lowest) and which nodes to allocate first
+// (highest).
+template <typename T>
+float WeightedGiniImpurity(const T& counts) {
+  const auto smoothed = counts + counts.constant(1.0f);
+  return RawWeightedGiniImpurity(smoothed);
+}
+
 template<typename T1, typename T2>
 float WeightedVariance(const T1& sums, const T2& squares, float count) {
   const auto e_x = sums / count;
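Worked example of the split score that the comment derives. For raw class counts c = (4, 2): c = 6 and score = 6 - (4² + 2²)/6 = 6 - 20/6 ≈ 2.67. The smoothed variant first adds one to every count, giving (5, 3): 8 - 34/8 = 3.75. The same arithmetic as a short Python check that mirrors the C++ above:

```python
import numpy as np

def raw_weighted_gini(counts):
  # score = c - sum_i c(i)^2 / c; lower means purer.
  c = counts.sum()
  return c - np.square(counts).sum() / c

def weighted_gini(counts):
  # Add-one smoothing before scoring, as WeightedGiniImpurity does.
  return raw_weighted_gini(counts + 1.0)

counts = np.array([4.0, 2.0])
print(raw_weighted_gini(counts))                 # 2.666...
print(weighted_gini(counts))                     # 3.75
print(raw_weighted_gini(np.array([6.0, 0.0])))   # 0.0 -> a pure node
```

Splitting the raw computation out of the smoothed one matters below: only the raw score is exactly zero for a pure node, which the kernel uses as a purity test.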
@@ -48,6 +48,7 @@ REGISTER_OP("UpdateFertileSlots")
     .Input("accumulator_sums: float")
     .Input("node_to_accumulator: int32")
     .Input("stale_leaves: int32")
+    .Input("node_sums: float")
     .Output("node_to_accumulator_map_updates: int32")
     .Output("accumulator_to_node_map_updates: int32")
     .Output("accumulators_cleared: int32")
@@ -84,6 +85,8 @@ node_to_accumulator: `node_to_accumulator[i]` is the accumulator slot used by
   fertile node i, or -1 if node i isn't fertile.
 stale_leaves:= A 1-d int32 tensor containing the indices of all leaves that
   have stopped accumulating statistics because they are too old.
+node_sums: `node_sums[n][c]` records how many
+  training examples have class c and have ended up in node n.
 node_to_accumulator_map_updates:= A 2-d int32 tensor describing the changes
   that need to be applied to the node_to_accumulator map. Intended to be used
   with
@@ -121,6 +124,7 @@ class UpdateFertileSlots : public OpKernel {
     const Tensor& accumulator_sums = context->input(4);
     const Tensor& node_to_accumulator = context->input(5);
     const Tensor& stale_leaves = context->input(6);
+    const Tensor& node_sums = context->input(7);
 
     OP_REQUIRES(context, finished.shape().dims() == 1,
                 errors::InvalidArgument(
@@ -204,6 +208,8 @@ class UpdateFertileSlots : public OpKernel {
         non_fertile_leaves, non_fertile_leaf_scores, eot, num_new_leaves,
         static_cast<int32>(accumulator_sums.shape().dim_size(1)), &leaf_heap);
 
+    const auto sums = node_sums.unaligned_flat<float>();
+    const int32 num_columns = node_sums.shape().dim_size(1);
     // Allocate leaves.
     std::unique_ptr<HeapValuesType> values(
         leaf_heap.Extract());
@@ -218,6 +224,18 @@ class UpdateFertileSlots : public OpKernel {
         VLOG(1) << "No allocators left.";
         break;
       }
+      // For classification, don't make a node fertile until it is unpure.
+      if (!regression_) {
+        // Add 1 here because index 0 contains the sum of the weights across
+        // classes.
+        Eigen::array<int, 1> offsets = {node.first * num_columns + 1};
+        Eigen::array<int, 1> extents = {num_columns - 1};
+        const auto node_counts = sums.slice(offsets, extents);
+        // TODO(thomaswc): Implement a faster check for pure nodes.
+        if (tensorforest::RawWeightedGiniImpurity(node_counts) == 0) {
+          continue;
+        }
+      }
       VLOG(1) << "setting node " << node.first << " to accumulator "
               << accumulator;
       ++num_accumulators_allocated;
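The new guard above skips column 0 of each node's row (the per-node weight total) and only allocates a fertile slot when the remaining class counts are impure, i.e. when the raw weighted Gini is non-zero. The same check in Python terms, with hypothetical counts mirroring the C++ slice:

```python
import numpy as np

node_sums = np.array([7.0, 5.0, 2.0])  # [weight total, class 0, class 1]
class_counts = node_sums[1:]           # drop index 0, the weight total
c = class_counts.sum()
raw_gini = c - np.square(class_counts).sum() / c
if raw_gini == 0:
  pass  # pure node: no split can improve it, so skip allocation
```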
@@ -25,6 +25,7 @@ from tensorflow.python.framework import common_shapes
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import load_library
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import sparse_ops
@@ -77,7 +78,7 @@ def _ParseSparse(data):
     ValueError: If data contains non-string Tensors.
   """
   for k in sorted(data.keys()):
-    if not isinstance(data[k], ops.SparseTensor):
+    if not isinstance(data[k], sparse_tensor.SparseTensor):
       raise NotImplementedError(
           'Features should be either all sparse or all dense. Use a '
           'feature engineering function to convert some of them.')
@@ -133,7 +134,7 @@ def ParseDataTensorOrDict(data):
   # If there's at least one sparse tensor, everything has to be sparse.
   is_sparse = False
   for v in data.values():
-    if isinstance(v, ops.SparseTensor):
+    if isinstance(v, sparse_tensor.SparseTensor):
       is_sparse = True
       break
   if is_sparse:
@@ -161,11 +162,11 @@ def ParseLabelTensorOrDict(labels):
   """
   if isinstance(labels, dict):
     return math_ops.to_float(array_ops.concat(
-        1, [sparse_ops.sparse_tensor_to_dense(labels[
-            k], default_value=-1) if isinstance(labels, ops.SparseTensor) else
-            labels[k] for k in sorted(labels.keys())]))
+        1, [sparse_ops.sparse_tensor_to_dense(labels[k], default_value=-1)
+            if isinstance(labels, sparse_tensor.SparseTensor)
+            else labels[k] for k in sorted(labels.keys())]))
   else:
-    if isinstance(labels, ops.SparseTensor):
+    if isinstance(labels, sparse_tensor.SparseTensor):
       return math_ops.to_float(sparse_ops.sparse_tensor_to_dense(
           labels, default_value=-1))
     else:
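What the reformatted comprehension does: sparse label tensors are densified with -1 as the filler value before the float cast, while dense ones pass through untouched. A minimal sketch of the densify step with hypothetical labels, assuming the 0.11-era public API:

```python
import tensorflow as tf

sparse_labels = tf.SparseTensor(indices=[[0, 0], [2, 0]],
                                values=[3, 7], shape=[3, 1])
# Missing positions become -1, matching ParseLabelTensorOrDict above.
dense = tf.sparse_tensor_to_dense(sparse_labels, default_value=-1)
labels = tf.to_float(dense)
```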
@@ -40,6 +40,8 @@ class UpdateFertileSlotsTest(test_util.TensorFlowTestCase):
     self.total_counts = [[80., 40., 40.]]
     self.ops = training_ops.Load()
     self.stale_leaves = []
+    self.node_sums = [[3, 1, 2], [4, 2, 2], [5, 2, 3], [6, 1, 5], [7, 5, 2],
+                      [8, 4, 4], [9, 7, 2]]
 
   def testSimple(self):
     with self.test_session():
@@ -47,7 +49,7 @@ class UpdateFertileSlotsTest(test_util.TensorFlowTestCase):
       accumulators_allocated) = self.ops.update_fertile_slots(
           self.finished, self.non_fertile_leaves, self.non_fertile_leaf_scores,
           self.end_of_tree, self.total_counts, self.node_map,
-          self.stale_leaves)
+          self.stale_leaves, self.node_sums)
 
       self.assertAllEqual([[2, 4], [-1, 0]], n2a_map_updates.eval())
       self.assertAllEqual([[0], [4]], a2n_map_updates.eval())
@@ -60,13 +62,27 @@ class UpdateFertileSlotsTest(test_util.TensorFlowTestCase):
       accumulators_allocated) = self.ops.update_fertile_slots(
           [], self.non_fertile_leaves, self.non_fertile_leaf_scores,
           self.end_of_tree, self.total_counts, self.node_map,
-          self.stale_leaves)
+          self.stale_leaves, self.node_sums)
 
       self.assertAllEqual((2, 0), n2a_map_updates.eval().shape)
       self.assertAllEqual((2, 0), a2n_map_updates.eval().shape)
       self.assertAllEqual([], accumulators_cleared.eval())
       self.assertAllEqual([], accumulators_allocated.eval())
 
+  def testPureCounts(self):
+    with self.test_session():
+      self.node_sums[4] = [10, 0, 10]
+      (n2a_map_updates, a2n_map_updates, accumulators_cleared,
+       accumulators_allocated) = self.ops.update_fertile_slots(
+           self.finished, self.non_fertile_leaves, self.non_fertile_leaf_scores,
+           self.end_of_tree, self.total_counts, self.node_map,
+           self.stale_leaves, self.node_sums)
+
+      self.assertAllEqual([[2, 3], [-1, 0]], n2a_map_updates.eval())
+      self.assertAllEqual([[0], [3]], a2n_map_updates.eval())
+      self.assertAllEqual([], accumulators_cleared.eval())
+      self.assertAllEqual([0], accumulators_allocated.eval())
+
   def testBadInput(self):
     del self.non_fertile_leaf_scores[-1]
     with self.test_session():
@@ -76,7 +92,7 @@ class UpdateFertileSlotsTest(test_util.TensorFlowTestCase):
       (n2a_map_updates, _, _, _) = self.ops.update_fertile_slots(
           self.finished, self.non_fertile_leaves,
           self.non_fertile_leaf_scores, self.end_of_tree, self.total_counts,
-          self.node_map, self.stale_leaves)
+          self.node_map, self.stale_leaves, self.node_sums)
       self.assertAllEqual((2, 0), n2a_map_updates.eval().shape)
 
 
@@ -29,6 +29,7 @@ from tensorflow.contrib.tensor_forest.python.ops import training_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import init_ops
@@ -629,7 +630,7 @@ class RandomTreeGraphs(object):
     sparse_indices = []
     sparse_values = []
     sparse_shape = []
-    if isinstance(input_data, ops.SparseTensor):
+    if isinstance(input_data, sparse_tensor.SparseTensor):
       sparse_indices = input_data.indices
       sparse_values = input_data.values
       sparse_shape = input_data.shape
@@ -780,6 +781,7 @@ class RandomTreeGraphs(object):
             self.variables.accumulator_sums,
             self.variables.node_to_accumulator_map,
             stale,
+            self.variables.node_sums,
             regression=self.params.regression))
 
     # Ensure end_of_tree doesn't get updated until UpdateFertileSlots has
@@ -881,7 +883,7 @@ class RandomTreeGraphs(object):
     sparse_indices = []
     sparse_values = []
     sparse_shape = []
-    if isinstance(input_data, ops.SparseTensor):
+    if isinstance(input_data, sparse_tensor.SparseTensor):
       sparse_indices = input_data.indices
       sparse_values = input_data.values
       sparse_shape = input_data.shape
@@ -15,9 +15,16 @@ py_library(
         "python/training/resample.py",
         "python/training/sampling_ops.py",
         "python/training/sequence_queueing_state_saver.py",
+        "python/training/training.py",
     ],
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/python:framework",
+        "//tensorflow/python:ops",
+        "//tensorflow/python:platform",
+        "//tensorflow/python:training",
+    ],
 )
 
 py_test(
@@ -37,6 +44,7 @@ py_test(
     size = "medium",
     srcs = ["python/training/batch_sequences_with_states_test.py"],
     srcs_version = "PY2AND3",
+    tags = ["manual"],
     deps = [
         ":training_py",
         "//tensorflow:tensorflow_py",
@@ -73,7 +81,10 @@ py_test(
     size = "small",
     srcs = ["python/training/sampling_ops_threading_test.py"],
     srcs_version = "PY2AND3",
-    tags = ["notsan"],
+    tags = [
+        "manual",
+        "notsan",
+    ],
     deps = [
         ":training_py",
         "//tensorflow:tensorflow_py",
@@ -86,6 +97,20 @@ py_test(
     size = "medium",
     srcs = ["python/training/bucket_ops_test.py"],
     srcs_version = "PY2AND3",
+    tags = ["manual"],
+    deps = [
+        ":training_py",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+    ],
+)
+
+py_test(
+    name = "training_test",
+    size = "large",
+    srcs = ["python/training/training_test.py"],
+    shard_count = 3,
+    srcs_version = "PY2AND3",
     deps = [
         ":training_py",
         "//tensorflow:tensorflow_py",
@@ -70,6 +70,11 @@ from tensorflow.contrib.training.python.training.bucket_ops import *
 from tensorflow.contrib.training.python.training.resample import *
 from tensorflow.contrib.training.python.training.sampling_ops import *
 from tensorflow.contrib.training.python.training.sequence_queueing_state_saver import *
+from tensorflow.contrib.training.python.training.training import add_gradients_summaries
+from tensorflow.contrib.training.python.training.training import clip_gradient_norms
+from tensorflow.contrib.training.python.training.training import create_train_op
+from tensorflow.contrib.training.python.training.training import multiply_gradients
+from tensorflow.contrib.training.python.training.training import train
 from tensorflow.python.util.all_util import make_all
 
 __all__ = make_all(__name__)
|
316
tensorflow/contrib/training/python/training/training.py
Normal file
316
tensorflow/contrib/training/python/training/training.py
Normal file
@ -0,0 +1,316 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Contains various routines and helper functions for training models.
|
||||||
|
|
||||||
|
TODO(nsilberman): Port documentation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.contrib.framework.python.ops import variables
|
||||||
|
from tensorflow.python import summary
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import clip_ops
|
||||||
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import variables as tf_variables
|
||||||
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
from tensorflow.python.training import basic_session_run_hooks
|
||||||
|
from tensorflow.python.training import monitored_session
|
||||||
|
from tensorflow.python.training import optimizer as tf_optimizer
|
||||||
|
|
||||||
|
# TODO(nsilberman): move add_gradients_summaries, clip_gradient_norms and
|
||||||
|
# multiply_gradients into contrib/summaries and contrib/optimizers.py
|
||||||
|
__all__ = [
|
||||||
|
'add_gradients_summaries',
|
||||||
|
'clip_gradient_norms',
|
||||||
|
'create_train_op',
|
||||||
|
'multiply_gradients',
|
||||||
|
'train',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def add_gradients_summaries(grads_and_vars):
|
||||||
|
"""Add summaries to gradients.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grads_and_vars: A list of gradient to variable pairs (tuples).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The list of created summaries.
|
||||||
|
"""
|
||||||
|
summaries = []
|
||||||
|
for grad, var in grads_and_vars:
|
||||||
|
if grad is not None:
|
||||||
|
if isinstance(grad, ops.IndexedSlices):
|
||||||
|
grad_values = grad.values
|
||||||
|
else:
|
||||||
|
grad_values = grad
|
||||||
|
summaries.append(summary.histogram_summary(
|
||||||
|
var.op.name + ':gradient', grad_values))
|
||||||
|
summaries.append(summary.histogram_summary(
|
||||||
|
var.op.name + ':gradient_norm', clip_ops.global_norm([grad_values])))
|
||||||
|
else:
|
||||||
|
logging.info('Var %s has no gradient', var.op.name)
|
||||||
|
|
||||||
|
return summaries
|
||||||
|
|
||||||
|
|
||||||
|
def clip_gradient_norms(gradients_to_variables, max_norm):
|
||||||
|
"""Clips the gradients by the given value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gradients_to_variables: A list of gradient to variable pairs (tuples).
|
||||||
|
max_norm: the maximum norm value.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of clipped gradient to variable pairs.
|
||||||
|
"""
|
||||||
|
clipped_grads_and_vars = []
|
||||||
|
for grad, var in gradients_to_variables:
|
||||||
|
if grad is not None:
|
||||||
|
if isinstance(grad, ops.IndexedSlices):
|
||||||
|
tmp = clip_ops.clip_by_norm(grad.values, max_norm)
|
||||||
|
grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape)
|
||||||
|
else:
|
||||||
|
grad = clip_ops.clip_by_norm(grad, max_norm)
|
||||||
|
clipped_grads_and_vars.append((grad, var))
|
||||||
|
return clipped_grads_and_vars
|
||||||
|
|
||||||
|
|
||||||
|
def multiply_gradients(grads_and_vars, gradient_multipliers):
|
||||||
|
"""Multiply specified gradients.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grads_and_vars: A list of gradient to variable pairs (tuples).
|
||||||
|
gradient_multipliers: A map from either `Variables` or `Variable` op names
|
||||||
|
to the coefficient by which the associated gradient should be scaled.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The updated list of gradient to variable pairs.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If `grads_and_vars` is not a list or if `gradient_multipliers`
|
||||||
|
is empty or None or if `gradient_multipliers` is not a dictionary.
|
||||||
|
"""
|
||||||
|
if not isinstance(grads_and_vars, list):
|
||||||
|
raise ValueError('`grads_and_vars` must be a list.')
|
||||||
|
if not gradient_multipliers:
|
||||||
|
raise ValueError('`gradient_multipliers` is empty.')
|
||||||
|
if not isinstance(gradient_multipliers, dict):
|
||||||
|
raise ValueError('`gradient_multipliers` must be a dict.')
|
||||||
|
|
||||||
|
multiplied_grads_and_vars = []
|
||||||
|
for grad, var in grads_and_vars:
|
||||||
|
if var in gradient_multipliers or var.op.name in gradient_multipliers:
|
||||||
|
key = var if var in gradient_multipliers else var.op.name
|
||||||
|
if grad is None:
|
||||||
|
raise ValueError('Requested multiple of `None` gradient.')
|
||||||
|
|
||||||
|
if isinstance(grad, ops.IndexedSlices):
|
||||||
|
tmp = grad.values * constant_op.constant(
|
||||||
|
gradient_multipliers[key], dtype=grad.dtype)
|
||||||
|
grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape)
|
||||||
|
else:
|
||||||
|
grad *= constant_op.constant(
|
||||||
|
gradient_multipliers[key], dtype=grad.dtype)
|
||||||
|
multiplied_grads_and_vars.append((grad, var))
|
||||||
|
return multiplied_grads_and_vars
|
||||||
|
|
||||||
|
|
||||||
|
def create_train_op(total_loss,
|
||||||
|
optimizer,
|
||||||
|
global_step=None,
|
||||||
|
update_ops=None,
|
||||||
|
variables_to_train=None,
|
||||||
|
transform_grads_fn=None,
|
||||||
|
summarize_gradients=False,
|
||||||
|
gate_gradients=tf_optimizer.Optimizer.GATE_OP,
|
||||||
|
aggregation_method=None,
|
||||||
|
colocate_gradients_with_ops=False):
|
||||||
|
"""Creates an `Operation` that evaluates the gradients and returns the loss.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
total_loss: A `Tensor` representing the total loss.
|
||||||
|
optimizer: A tf.Optimizer to use for computing the gradients.
|
||||||
|
global_step: A `Tensor` representing the global step variable. If left as
|
||||||
|
`None`, then slim.variables.global_step() is used.
|
||||||
|
update_ops: An optional list of updates to execute. If `update_ops` is
|
||||||
|
`None`, then the update ops are set to the contents of the
|
||||||
|
`tf.GraphKeys.UPDATE_OPS` collection. If `update_ops` is not `None`, but
|
||||||
|
it doesn't contain all of the update ops in `tf.GraphKeys.UPDATE_OPS`,
|
||||||
|
a warning will be displayed.
|
||||||
|
variables_to_train: an optional list of variables to train. If None, it will
|
||||||
|
default to all tf.trainable_variables().
|
||||||
|
transform_grads_fn: A function which takes a single argument, a list of
|
||||||
|
gradient to variable pairs (tuples), performs any requested gradient
|
||||||
|
updates, such as gradient clipping or multipliers, and returns the updated
|
||||||
|
list.
|
||||||
|
summarize_gradients: Whether or not add summaries for each gradient.
|
||||||
|
gate_gradients: How to gate the computation of gradients. See tf.Optimizer.
|
||||||
|
aggregation_method: Specifies the method used to combine gradient terms.
|
||||||
|
Valid values are defined in the class `AggregationMethod`.
|
||||||
|
colocate_gradients_with_ops: Whether or not to try colocating the gradients
|
||||||
|
with the ops that generated them.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Tensor` that when evaluated, computes the gradients and returns the total
|
||||||
|
loss value.
|
||||||
|
"""
|
||||||
|
if global_step is None:
|
||||||
|
global_step = variables.get_or_create_global_step()
|
||||||
|
|
||||||
|
# Update ops use GraphKeys.UPDATE_OPS collection if update_ops is None.
|
||||||
|
global_update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
|
||||||
|
if update_ops is None:
|
||||||
|
update_ops = global_update_ops
|
||||||
|
else:
|
||||||
|
update_ops = set(update_ops)
|
||||||
|
if not global_update_ops.issubset(update_ops):
|
||||||
|
logging.warning('update_ops in create_train_op does not contain all the '
|
||||||
|
' update_ops in GraphKeys.UPDATE_OPS')
|
||||||
|
|
||||||
|
# Make sure update_ops are computed before total_loss.
|
||||||
|
if update_ops:
|
||||||
|
with ops.control_dependencies(update_ops):
|
||||||
|
barrier = control_flow_ops.no_op(name='update_barrier')
|
||||||
|
total_loss = control_flow_ops.with_dependencies([barrier], total_loss)
|
||||||
|
|
||||||
|
if variables_to_train is None:
|
||||||
|
# Default to tf.trainable_variables()
|
||||||
|
variables_to_train = tf_variables.trainable_variables()
|
||||||
|
else:
|
||||||
|
# Make sure that variables_to_train are in tf.trainable_variables()
|
||||||
|
for v in variables_to_train:
|
||||||
|
assert v in tf_variables.trainable_variables()
|
||||||
|
|
||||||
|
assert variables_to_train
|
||||||
|
|
||||||
|
# Create the gradients. Note that apply_gradients adds the gradient
|
||||||
|
# computation to the current graph.
|
||||||
|
grads = optimizer.compute_gradients(
|
||||||
|
total_loss,
|
||||||
|
variables_to_train,
|
||||||
|
gate_gradients=gate_gradients,
|
||||||
|
aggregation_method=aggregation_method,
|
||||||
|
colocate_gradients_with_ops=colocate_gradients_with_ops)
|
||||||
|
|
||||||
|
if transform_grads_fn:
|
||||||
|
grads = transform_grads_fn(grads)
|
||||||
|
|
||||||
|
# Summarize gradients.
|
||||||
|
if summarize_gradients:
|
||||||
|
with ops.name_scope('summarize_grads'):
|
||||||
|
add_gradients_summaries(grads)
|
||||||
|
|
||||||
|
# Create gradient updates.
|
||||||
|
grad_updates = optimizer.apply_gradients(grads, global_step=global_step)
|
||||||
|
|
||||||
|
with ops.name_scope('train_op'):
|
||||||
|
# Make sure total_loss is valid.
|
||||||
|
total_loss = array_ops.check_numerics(total_loss,
|
||||||
|
'LossTensor is inf or nan')
|
||||||
|
|
||||||
|
# Ensure the train_tensor computes grad_updates.
|
||||||
|
return control_flow_ops.with_dependencies([grad_updates], total_loss)
|
||||||
|
|
||||||
|
|
||||||
|
def train(
    train_op,
    logdir,
    master='',
    is_chief=True,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=600,
    save_summaries_steps=100,
    config=None):
  """Runs the training loop.

  Args:
    train_op: A `Tensor` that, when executed, will apply the gradients and
      return the loss value.
    logdir: The directory where the graph and checkpoints are saved.
    master: The URL of the master.
    is_chief: Specifies whether or not the training is being run by the primary
      replica during replica training.
    scaffold: A `tf.train.Scaffold` instance.
    hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
      training loop.
    chief_only_hooks: List of `tf.train.SessionRunHook` instances which are run
      inside the training loop for the chief trainer only.
    save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
      using a default checkpoint saver. If `save_checkpoint_secs` is set to
      `None`, then the default checkpoint saver isn't used.
    save_summaries_steps: The frequency, in number of global steps, that the
      summaries are written to disk using a default summary saver. If
      `save_summaries_steps` is set to `None`, then the default summary saver
      isn't used.
    config: An instance of `tf.ConfigProto`.

  Returns:
    the value of the loss function after training.

  Raises:
    ValueError: if `logdir` is `None` and either `save_checkpoint_secs` or
      `save_summaries_steps` is not `None`.
  """
  # TODO(nsilberman): move this logic into monitored_session.py
  scaffold = scaffold or monitored_session.Scaffold()

  hooks = hooks or []

  if is_chief:
    session_creator = monitored_session.ChiefSessionCreator(
        scaffold=scaffold,
        checkpoint_dir=logdir,
        master=master,
        config=config)

    if chief_only_hooks:
      hooks.extend(chief_only_hooks)

    hooks.append(basic_session_run_hooks.StepCounterHook(
        output_dir=logdir))

    if save_summaries_steps:
      if logdir is None:
        raise ValueError(
            'logdir cannot be None when save_summaries_steps is not None')
      hooks.append(basic_session_run_hooks.SummarySaverHook(
          scaffold=scaffold,
          save_steps=save_summaries_steps,
          output_dir=logdir))

    if save_checkpoint_secs:
      if logdir is None:
        raise ValueError(
            'logdir cannot be None when save_checkpoint_secs is not None')
      hooks.append(basic_session_run_hooks.CheckpointSaverHook(
          logdir, save_secs=save_checkpoint_secs, scaffold=scaffold))
  else:
    session_creator = monitored_session.WorkerSessionCreator(
        scaffold=scaffold, master=master, config=config)

  with monitored_session.MonitoredSession(
      session_creator=session_creator, hooks=hooks) as session:
    loss = None
    while not session.should_stop():
      loss = session.run(train_op)
  return loss
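For reference, a minimal end-to-end use of the two functions above might look
as follows; a sketch, assuming a simple classifier plus existing `inputs` and
`labels` tensors, an illustrative logdir path, and an arbitrary step count:

predictions = tf.contrib.layers.fully_connected(
    inputs, 1, activation_fn=tf.sigmoid)
tf.contrib.losses.log_loss(predictions, labels)
total_loss = tf.contrib.losses.get_total_loss()

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
train_op = tf.contrib.training.create_train_op(total_loss, optimizer)

# Runs until the StopAtStepHook fires, checkpointing and summarizing with the
# default savers, and returns the final loss value.
final_loss = tf.contrib.training.train(
    train_op, logdir='/tmp/train_logs',
    hooks=[tf.train.StopAtStepHook(num_steps=1000)])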
514
tensorflow/contrib/training/python/training/training_test.py
Normal file
@ -0,0 +1,514 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.contrib.training.training."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np
import tensorflow as tf


def logistic_classifier(inputs):
  return tf.contrib.layers.fully_connected(
      inputs, 1, activation_fn=tf.sigmoid)


def batchnorm_classifier(inputs):
  inputs = tf.contrib.layers.batch_norm(inputs, decay=0.1)
  return tf.contrib.layers.fully_connected(inputs, 1, activation_fn=tf.sigmoid)
class CreateTrainOpTest(tf.test.TestCase):

  def setUp(self):
    np.random.seed(0)

    # Create an easy training set:
    self._inputs = np.random.rand(16, 4).astype(np.float32)
    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)

  def testUseUpdateOps(self):
    with tf.Graph().as_default():
      tf.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      expected_mean = np.mean(self._inputs, axis=0)
      expected_var = np.var(self._inputs, axis=0)

      tf_predictions = batchnorm_classifier(tf_inputs)
      tf.contrib.losses.log_loss(tf_predictions, tf_labels)
      total_loss = tf.contrib.losses.get_total_loss()
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)

      train_op = tf.contrib.training.create_train_op(total_loss, optimizer)

      moving_mean = tf.contrib.framework.get_variables_by_name('moving_mean')[0]
      moving_variance = tf.contrib.framework.get_variables_by_name(
          'moving_variance')[0]

      with tf.Session() as sess:
        # Initialize all variables.
        sess.run(tf.initialize_all_variables())
        mean, variance = sess.run([moving_mean, moving_variance])
        # After initialization moving_mean == 0 and moving_variance == 1.
        self.assertAllClose(mean, [0] * 4)
        self.assertAllClose(variance, [1] * 4)

        for _ in range(10):
          sess.run([train_op])
        mean = moving_mean.eval()
        variance = moving_variance.eval()
        # After 10 updates with decay 0.1 moving_mean == expected_mean and
        # moving_variance == expected_var.
        self.assertAllClose(mean, expected_mean)
        self.assertAllClose(variance, expected_var)

  def testEmptyUpdateOps(self):
    with tf.Graph().as_default():
      tf.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      tf_predictions = batchnorm_classifier(tf_inputs)
      tf.contrib.losses.log_loss(tf_predictions, tf_labels)
      total_loss = tf.contrib.losses.get_total_loss()
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)

      train_op = tf.contrib.training.create_train_op(
          total_loss, optimizer, update_ops=[])

      moving_mean = tf.contrib.framework.get_variables_by_name('moving_mean')[0]
      moving_variance = tf.contrib.framework.get_variables_by_name(
          'moving_variance')[0]

      with tf.Session() as sess:
        # Initialize all variables.
        sess.run(tf.initialize_all_variables())
        mean, variance = sess.run([moving_mean, moving_variance])
        # After initialization moving_mean == 0 and moving_variance == 1.
        self.assertAllClose(mean, [0] * 4)
        self.assertAllClose(variance, [1] * 4)

        for _ in range(10):
          sess.run([train_op])
        mean = moving_mean.eval()
        variance = moving_variance.eval()

        # Since we skip update_ops the moving_vars are not updated.
        self.assertAllClose(mean, [0] * 4)
        self.assertAllClose(variance, [1] * 4)
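Taken together, the two tests above pin down the `update_ops` contract that
testUseUpdateOps exercises: by default `create_train_op` runs the ops in the
`tf.GraphKeys.UPDATE_OPS` collection (where `batch_norm` registers its
moving-average updates), while `update_ops=[]` disables them. A sketch of
passing the collection explicitly, assuming `total_loss` and `optimizer` as in
the tests:

# Equivalent to the default behavior, spelled out.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
train_op = tf.contrib.training.create_train_op(
    total_loss, optimizer, update_ops=update_ops)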
class TrainBNClassifierTest(tf.test.TestCase):

  def setUp(self):
    # Create an easy training set:
    np.random.seed(0)

    self._inputs = np.zeros((16, 4))
    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
    self._logdir = os.path.join(self.get_temp_dir(), 'tmp_bnlogs/')

    for i in range(16):
      j = int(2 * self._labels[i] + np.random.randint(0, 2))
      self._inputs[i, j] = 1

  def testTrainWithNoInitAssignCanAchieveZeroLoss(self):
    g = tf.Graph()
    with g.as_default():
      tf.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      tf_predictions = batchnorm_classifier(tf_inputs)
      tf.contrib.losses.log_loss(tf_predictions, tf_labels)
      total_loss = tf.contrib.losses.get_total_loss()

      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)

      train_op = tf.contrib.training.create_train_op(
          total_loss, optimizer)

      loss = tf.contrib.training.train(
          train_op, self._logdir, hooks=[
              tf.train.StopAtStepHook(num_steps=300)
          ])
      self.assertLess(loss, .1)
class TrainTest(tf.test.TestCase):

  def setUp(self):
    # Create an easy training set:
    np.random.seed(0)

    self._inputs = np.zeros((16, 4))
    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)

    for i in range(16):
      j = int(2 * self._labels[i] + np.random.randint(0, 2))
      self._inputs[i, j] = 1
  def testCanAchieveZeroLoss(self):
    logdir = os.path.join(self.get_temp_dir(), 'can_achieve_zero_loss')

    with tf.Graph().as_default():
      tf.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      tf_predictions = logistic_classifier(tf_inputs)
      tf.contrib.losses.log_loss(tf_predictions, tf_labels)
      total_loss = tf.contrib.losses.get_total_loss()

      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)

      train_op = tf.contrib.training.create_train_op(total_loss, optimizer)

      loss = tf.contrib.training.train(
          train_op, logdir, hooks=[
              tf.train.StopAtStepHook(num_steps=300)
          ])
      self.assertIsNotNone(loss)
      self.assertLess(loss, .015)
  def testTrainWithLocalVariable(self):
    logdir = os.path.join(self.get_temp_dir(), 'train_with_local_variable')

    with tf.Graph().as_default():
      tf.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      local_multiplier = tf.contrib.framework.local_variable(1.0)

      tf_predictions = logistic_classifier(tf_inputs) * local_multiplier
      tf.contrib.losses.log_loss(tf_predictions, tf_labels)
      total_loss = tf.contrib.losses.get_total_loss()

      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)

      train_op = tf.contrib.training.create_train_op(
          total_loss, optimizer)

      loss = tf.contrib.training.train(
          train_op, logdir, hooks=[
              tf.train.StopAtStepHook(num_steps=300)
          ])
      self.assertIsNotNone(loss)
      self.assertLess(loss, .015)
  def testResumeTrainAchievesRoughlyTheSameLoss(self):
    number_of_steps = [300, 1, 5]
    logdir = os.path.join(self.get_temp_dir(), 'resume_train_same_loss')

    for i in range(len(number_of_steps)):
      with tf.Graph().as_default():
        tf.set_random_seed(i)
        tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
        tf_labels = tf.constant(self._labels, dtype=tf.float32)

        tf_predictions = logistic_classifier(tf_inputs)
        tf.contrib.losses.log_loss(tf_predictions, tf_labels)
        total_loss = tf.contrib.losses.get_total_loss()

        optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)

        train_op = tf.contrib.training.create_train_op(
            total_loss, optimizer)

        saver = tf.train.Saver()

        loss = tf.contrib.training.train(
            train_op, logdir, hooks=[
                tf.train.StopAtStepHook(num_steps=number_of_steps[i]),
                tf.train.CheckpointSaverHook(
                    logdir, save_steps=50, saver=saver),
            ])
        self.assertIsNotNone(loss)
        self.assertLess(loss, .015)
  def create_train_op(self, learning_rate=1.0, gradient_multiplier=1.0):
    tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
    tf_labels = tf.constant(self._labels, dtype=tf.float32)

    tf_predictions = logistic_classifier(tf_inputs)
    tf.contrib.losses.log_loss(tf_predictions, tf_labels)
    total_loss = tf.contrib.losses.get_total_loss()

    optimizer = tf.train.GradientDescentOptimizer(
        learning_rate=learning_rate)

    def transform_grads_fn(grads):
      if gradient_multiplier != 1.0:
        variables = tf.trainable_variables()
        gradient_multipliers = {var: gradient_multiplier for var in variables}

        with tf.name_scope('multiply_grads'):
          return tf.contrib.training.multiply_gradients(
              grads, gradient_multipliers)
      else:
        return grads

    return tf.contrib.training.create_train_op(
        total_loss, optimizer, transform_grads_fn=transform_grads_fn)
  def testTrainWithInitFromCheckpoint(self):
    logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs1/')
    logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs2/')

    if tf.gfile.Exists(logdir1):  # For running on jenkins.
      tf.gfile.DeleteRecursively(logdir1)
    if tf.gfile.Exists(logdir2):  # For running on jenkins.
      tf.gfile.DeleteRecursively(logdir2)

    # First, train the model one step (make sure the error is high).
    with tf.Graph().as_default():
      tf.set_random_seed(0)
      train_op = self.create_train_op()
      saver = tf.train.Saver()
      loss = tf.contrib.training.train(
          train_op, logdir1, hooks=[
              tf.train.CheckpointSaverHook(logdir1, save_steps=1, saver=saver),
              tf.train.StopAtStepHook(num_steps=1),
          ], save_checkpoint_secs=None)
      self.assertGreater(loss, .5)

    # Next, train the model to convergence.
    with tf.Graph().as_default():
      tf.set_random_seed(1)
      train_op = self.create_train_op()
      saver = tf.train.Saver()
      loss = tf.contrib.training.train(
          train_op, logdir1, hooks=[
              tf.train.CheckpointSaverHook(logdir1, save_steps=1, saver=saver),
              tf.train.StopAtStepHook(num_steps=300),
          ], save_checkpoint_secs=None)
      self.assertIsNotNone(loss)
      self.assertLess(loss, .02)

    # Finally, advance the model a single step and validate that the loss is
    # still low.
    with tf.Graph().as_default():
      tf.set_random_seed(2)
      train_op = self.create_train_op()

      model_variables = tf.all_variables()
      model_path = os.path.join(logdir1, 'model.ckpt-300')

      assign_fn = tf.contrib.framework.assign_from_checkpoint_fn(
          model_path, model_variables)

      def init_fn(_, session):
        assign_fn(session)

      loss = tf.contrib.training.train(
          train_op,
          logdir2,
          scaffold=tf.train.Scaffold(init_fn=init_fn),
          hooks=[tf.train.StopAtStepHook(num_steps=1)])

      self.assertIsNotNone(loss)
      self.assertLess(loss, .02)
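The warm-start recipe in the final block above distills to two lines: build an
assign function from the checkpoint, then hand it to the session scaffold. A
sketch, assuming `model_path` names an existing checkpoint:

assign_fn = tf.contrib.framework.assign_from_checkpoint_fn(
    model_path, tf.all_variables())
# Scaffold calls init_fn(scaffold, session) once the session is created.
scaffold = tf.train.Scaffold(
    init_fn=lambda _, session: assign_fn(session))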
  def ModelLoss(self):
    tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
    tf_labels = tf.constant(self._labels, dtype=tf.float32)

    tf_predictions = logistic_classifier(tf_inputs)
    tf.contrib.losses.log_loss(tf_predictions, tf_labels)
    return tf.contrib.losses.get_total_loss()
  def testTrainAllVarsHasLowerLossThanTrainSubsetOfVars(self):
    logdir = os.path.join(self.get_temp_dir(), 'tmp_logs3/')
    if tf.gfile.Exists(logdir):  # For running on jenkins.
      tf.gfile.DeleteRecursively(logdir)

    # First, train only the weights of the model.
    with tf.Graph().as_default():
      tf.set_random_seed(0)
      total_loss = self.ModelLoss()
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
      weights = tf.contrib.framework.get_variables_by_name('weights')

      train_op = tf.contrib.training.create_train_op(
          total_loss,
          optimizer,
          variables_to_train=weights)

      saver = tf.train.Saver()
      loss = tf.contrib.training.train(
          train_op, logdir, hooks=[
              tf.train.CheckpointSaverHook(logdir, save_steps=1, saver=saver),
              tf.train.StopAtStepHook(num_steps=200),
          ])
      self.assertGreater(loss, .015)
      self.assertLess(loss, .05)

    # Next, train only the biases of the model.
    with tf.Graph().as_default():
      tf.set_random_seed(1)
      total_loss = self.ModelLoss()
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
      biases = tf.contrib.framework.get_variables_by_name('biases')

      train_op = tf.contrib.training.create_train_op(
          total_loss,
          optimizer,
          variables_to_train=biases)

      saver = tf.train.Saver()
      loss = tf.contrib.training.train(
          train_op, logdir, hooks=[
              tf.train.CheckpointSaverHook(logdir, save_steps=1, saver=saver),
              tf.train.StopAtStepHook(num_steps=300),
          ])
      self.assertGreater(loss, .015)
      self.assertLess(loss, .05)

    # Finally, train both weights and biases to get a lower loss.
    with tf.Graph().as_default():
      tf.set_random_seed(2)
      total_loss = self.ModelLoss()
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)

      train_op = tf.contrib.training.create_train_op(total_loss, optimizer)
      saver = tf.train.Saver()
      loss = tf.contrib.training.train(
          train_op, logdir, hooks=[
              tf.train.CheckpointSaverHook(logdir, save_steps=1, saver=saver),
              tf.train.StopAtStepHook(num_steps=400),
          ])
      self.assertIsNotNone(loss)
      self.assertLess(loss, .015)
  def testTrainingSubsetsOfVariablesOnlyUpdatesThoseVariables(self):
    # Create train ops for the full model and for subsets of its variables.
    with tf.Graph().as_default():
      tf.set_random_seed(0)
      total_loss = self.ModelLoss()
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
      weights, biases = tf.contrib.framework.get_variables()

      train_op = tf.contrib.training.create_train_op(total_loss, optimizer)
      train_weights = tf.contrib.training.create_train_op(
          total_loss, optimizer, variables_to_train=[weights])
      train_biases = tf.contrib.training.create_train_op(
          total_loss, optimizer, variables_to_train=[biases])

      with tf.Session() as sess:
        # Initialize the variables.
        sess.run(tf.initialize_all_variables())

        # Get the initial weights and biases values.
        weights_values, biases_values = sess.run([weights, biases])
        self.assertGreater(np.linalg.norm(weights_values), 0)
        self.assertAlmostEqual(np.linalg.norm(biases_values), 0)

        # Update weights and biases.
        loss = sess.run(train_op)
        self.assertGreater(loss, .5)
        new_weights, new_biases = sess.run([weights, biases])

        # Check that the weights and biases have been updated.
        self.assertGreater(np.linalg.norm(weights_values - new_weights), 0)
        self.assertGreater(np.linalg.norm(biases_values - new_biases), 0)

        weights_values, biases_values = new_weights, new_biases

        # Update only weights.
        loss = sess.run(train_weights)
        self.assertGreater(loss, .5)
        new_weights, new_biases = sess.run([weights, biases])

        # Check that the weights have been updated, but the biases have not.
        self.assertGreater(np.linalg.norm(weights_values - new_weights), 0)
        self.assertAlmostEqual(np.linalg.norm(biases_values - new_biases), 0)
        weights_values = new_weights

        # Update only biases.
        loss = sess.run(train_biases)
        self.assertGreater(loss, .5)
        new_weights, new_biases = sess.run([weights, biases])

        # Check that the biases have been updated, but the weights have not.
        self.assertAlmostEqual(np.linalg.norm(weights_values - new_weights), 0)
        self.assertGreater(np.linalg.norm(biases_values - new_biases), 0)
  def testTrainWithAlteredGradients(self):
    # Use the same learning rate but different gradient multipliers
    # to train two models. The model trained with the equivalently
    # larger learning rate (i.e., learning_rate * gradient_multiplier)
    # should have a smaller training loss.
    logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs6/')
    logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs7/')

    if tf.gfile.Exists(logdir1):
      tf.gfile.DeleteRecursively(logdir1)
    if tf.gfile.Exists(logdir2):
      tf.gfile.DeleteRecursively(logdir2)

    multipliers = [1., 1000.]
    number_of_steps = 10
    losses = []
    learning_rate = 0.001

    # First, train the model with the equivalently smaller learning rate.
    with tf.Graph().as_default():
      tf.set_random_seed(0)
      train_op = self.create_train_op(
          learning_rate=learning_rate,
          gradient_multiplier=multipliers[0])

      saver = tf.train.Saver()

      loss = tf.contrib.training.train(
          train_op, logdir1, hooks=[
              tf.train.StopAtStepHook(num_steps=number_of_steps),
              tf.train.CheckpointSaverHook(logdir1, save_steps=50, saver=saver),
          ])

      losses.append(loss)
      self.assertGreater(loss, .5)

    # Second, train the model with the equivalently larger learning rate.
    with tf.Graph().as_default():
      tf.set_random_seed(0)
      train_op = self.create_train_op(
          learning_rate=learning_rate,
          gradient_multiplier=multipliers[1])
      saver = tf.train.Saver()

      loss = tf.contrib.training.train(
          train_op, logdir2, hooks=[
              tf.train.StopAtStepHook(num_steps=number_of_steps),
              tf.train.CheckpointSaverHook(logdir2, save_steps=50, saver=saver),
          ])

      losses.append(loss)
      self.assertIsNotNone(loss)
      self.assertLess(loss, .5)

    # The loss of the model trained with the larger effective learning rate
    # should be smaller.
    self.assertGreater(losses[0], losses[1])
if __name__ == '__main__':
  tf.test.main()
Some files were not shown because too many files have changed in this diff