Merge remote-tracking branch 'upstream/master'
This commit is contained in:
commit
1f094bccdb
11
.mention-bot
Normal file
11
.mention-bot
Normal file
@ -0,0 +1,11 @@
|
||||
{
|
||||
"maxReviewers": 2,
|
||||
"numFilesToCheck": 10, // Number of files to check against, default is 5
|
||||
"userBlacklist": ["tensorflower-gardener"], // users in this list will never be mentioned by mention-bot
|
||||
"requiredOrgs": ["tensorflow"], // mention-bot will only mention user who are a member of one of these organizations
|
||||
"skipAlreadyAssignedPR": true, // mention-bot will ignore already assigned PR's
|
||||
"skipAlreadyMentionedPR": true, // mention-bot will ignore if there is already existing an exact mention
|
||||
"skipTitle": "Branch", // mention-bot will ignore PR that includes text in the title,
|
||||
"delayed": true, // mention-bot will wait to comment until specified time in `delayedUntil` value
|
||||
"delayedUntil": "10m",
|
||||
}
|
@ -33,10 +33,10 @@ and discussion.**
|
||||
|
||||
People who are a little more adventurous can also try our nightly binaries:
|
||||
|
||||
* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
|
||||
* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
|
||||
* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac1-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac1-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac1-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac1-slave/))
|
||||
* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/))
|
||||
* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc2-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc2-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc2-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
|
||||
* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc2-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc2-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc2-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
|
||||
* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac1-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc2-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac1-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac1-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc2-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac1-slave/))
|
||||
* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc2-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc2-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/))
|
||||
* [Android](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-android/TF_BUILD_CONTAINER_TYPE=ANDROID,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=NO_PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=android-slave/lastSuccessfulBuild/artifact/bazel-out/local_linux/bin/tensorflow/examples/android/tensorflow_demo.apk) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-android/TF_BUILD_CONTAINER_TYPE=ANDROID,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=NO_PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=android-slave/))
|
||||
|
||||
#### *Try your first TensorFlow program*
|
||||
|
@ -10,6 +10,7 @@ BUS_ANY was used.
|
||||
|
||||
## Major Features and Improvements
|
||||
|
||||
* CUDA 8 support.
|
||||
* cuDNN 5 support.
|
||||
* HDFS Support.
|
||||
* Adds Fused LSTM support via cuDNN 5 in `tensorflow/contrib/cudnn_rnn`.
|
||||
|
19
WORKSPACE
19
WORKSPACE
@ -153,8 +153,8 @@ new_http_archive(
|
||||
new_http_archive(
|
||||
name = "iron_iconset_svg",
|
||||
build_file = "bower.BUILD",
|
||||
url = "https://github.com/polymerelements/iron-iconset-svg/archive/v1.0.10.tar.gz",
|
||||
strip_prefix = "iron-iconset-svg-1.0.10",
|
||||
url = "https://github.com/polymerelements/iron-iconset-svg/archive/v1.1.0.tar.gz",
|
||||
strip_prefix = "iron-iconset-svg-1.1.0",
|
||||
)
|
||||
|
||||
new_http_archive(
|
||||
@ -188,8 +188,8 @@ new_http_archive(
|
||||
new_http_archive(
|
||||
name = "iron_overlay_behavior",
|
||||
build_file = "bower.BUILD",
|
||||
url = "https://github.com/polymerelements/iron-overlay-behavior/archive/v1.9.0.tar.gz",
|
||||
strip_prefix = "iron-overlay-behavior-1.9.0",
|
||||
url = "https://github.com/polymerelements/iron-overlay-behavior/archive/v1.10.1.tar.gz",
|
||||
strip_prefix = "iron-overlay-behavior-1.10.1",
|
||||
)
|
||||
|
||||
new_http_archive(
|
||||
@ -206,6 +206,13 @@ new_http_archive(
|
||||
strip_prefix = "iron-resizable-behavior-1.0.3",
|
||||
)
|
||||
|
||||
new_http_archive(
|
||||
name = "iron_scroll_target_behavior",
|
||||
build_file = "bower.BUILD",
|
||||
url = "https://github.com/polymerelements/iron-scroll-target-behavior/archive/v1.0.3.tar.gz",
|
||||
strip_prefix = "iron-scroll-target-behavior-1.0.3",
|
||||
)
|
||||
|
||||
new_http_archive(
|
||||
name = "iron_selector",
|
||||
build_file = "bower.BUILD",
|
||||
@ -291,8 +298,8 @@ new_http_archive(
|
||||
new_http_archive(
|
||||
name = "paper_icon_button",
|
||||
build_file = "bower.BUILD",
|
||||
url = "https://github.com/polymerelements/paper-icon-button/archive/v1.1.2.tar.gz",
|
||||
strip_prefix = "paper-icon-button-1.1.2",
|
||||
url = "https://github.com/polymerelements/paper-icon-button/archive/v1.1.3.tar.gz",
|
||||
strip_prefix = "paper-icon-button-1.1.3",
|
||||
)
|
||||
|
||||
new_http_archive(
|
||||
|
@ -209,6 +209,7 @@ filegroup(
|
||||
name = "iron_overlay_behavior",
|
||||
srcs = [
|
||||
"index.html",
|
||||
"iron-focusables-helper.html",
|
||||
"iron-overlay-backdrop.html",
|
||||
"iron-overlay-behavior.html",
|
||||
"iron-overlay-manager.html",
|
||||
@ -232,6 +233,14 @@ filegroup(
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "iron_scroll_target_behavior",
|
||||
srcs = [
|
||||
"index.html",
|
||||
"iron-scroll-target-behavior.html",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "iron_selector",
|
||||
srcs = [
|
||||
|
@ -62,8 +62,6 @@ cc_library(
|
||||
# This define (mostly) guarantees we don't link any problematic
|
||||
# code. We use it, but we do not rely on it, as evidenced above.
|
||||
"EIGEN_MPL2_ONLY",
|
||||
# TODO(jart): Use EIGEN_USE_NONBLOCKING_THREAD_POOL but first add an
|
||||
# eigen_initialize.cc file and alwayslink=1.
|
||||
],
|
||||
includes = ["."],
|
||||
visibility = ["//visibility:public"],
|
||||
|
@ -105,6 +105,7 @@ filegroup(
|
||||
"//tensorflow/contrib/framework:all_files",
|
||||
"//tensorflow/contrib/graph_editor:all_files",
|
||||
"//tensorflow/contrib/grid_rnn:all_files",
|
||||
"//tensorflow/contrib/integrate:all_files",
|
||||
"//tensorflow/contrib/layers:all_files",
|
||||
"//tensorflow/contrib/layers/kernels:all_files",
|
||||
"//tensorflow/contrib/learn:all_files",
|
||||
@ -148,7 +149,6 @@ filegroup(
|
||||
"//tensorflow/examples/image_retraining:all_files",
|
||||
"//tensorflow/examples/label_image:all_files",
|
||||
"//tensorflow/examples/learn:all_files",
|
||||
"//tensorflow/examples/skflow:all_files",
|
||||
"//tensorflow/examples/tutorials/estimators:all_files",
|
||||
"//tensorflow/examples/tutorials/mnist:all_files",
|
||||
"//tensorflow/examples/tutorials/word2vec:all_files",
|
||||
|
@ -264,6 +264,36 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "nn_grad",
|
||||
srcs = ["gradients/nn_grad.cc"],
|
||||
deps = [
|
||||
":cc_ops",
|
||||
":grad_op_registry",
|
||||
":ops",
|
||||
":scope",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "gradients_nn_grad_test",
|
||||
srcs = ["gradients/nn_grad_test.cc"],
|
||||
deps = [
|
||||
":cc_ops",
|
||||
":grad_op_registry",
|
||||
":grad_testutil",
|
||||
":gradient_checker",
|
||||
":nn_grad",
|
||||
":testutil",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_wrappers_cc(
|
||||
name = "cc_ops",
|
||||
op_lib_names = [
|
||||
@ -411,6 +441,7 @@ cc_library(
|
||||
srcs = ["training/queue_runner.cc"],
|
||||
hdrs = ["training/queue_runner.h"],
|
||||
deps = [
|
||||
":coordinator",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
@ -425,6 +456,7 @@ tf_cc_test(
|
||||
name = "queue_runner_test",
|
||||
srcs = ["training/queue_runner_test.cc"],
|
||||
deps = [
|
||||
"coordinator",
|
||||
":cc_ops",
|
||||
":queue_runner",
|
||||
":scope",
|
||||
@ -439,3 +471,37 @@ tf_cc_test(
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "coordinator",
|
||||
srcs = ["training/coordinator.cc"],
|
||||
hdrs = ["training/coordinator.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensorflow",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "coordinator_test",
|
||||
srcs = ["training/coordinator_test.cc"],
|
||||
deps = [
|
||||
":cc_ops",
|
||||
":coordinator",
|
||||
":queue_runner",
|
||||
":scope",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
@ -110,20 +110,15 @@ Status ComputeNumericJacobianTranspose(const Scope& scope, const ops::Output& x,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
Status ComputeGradientError(const Scope& scope, const ops::Output& x,
|
||||
const TensorShape& x_shape, const ops::Output& y,
|
||||
const TensorShape& y_shape, T* max_error) {
|
||||
Status ComputeGradientErrorInternal(const Scope& scope, const ops::Output& x,
|
||||
const TensorShape& x_shape,
|
||||
const ops::Output& y,
|
||||
const TensorShape& y_shape, Tensor* x_data,
|
||||
T* max_error) {
|
||||
const int64 x_size = x_shape.num_elements();
|
||||
const int64 y_size = y_shape.num_elements();
|
||||
|
||||
// Initialize 'x_data' to random values.
|
||||
Tensor x_data(x.type(), x_shape);
|
||||
auto x_data_flat = x_data.flat<T>();
|
||||
x_data_flat.setRandom();
|
||||
|
||||
// Initialize theoretical Jacobian to zeros.
|
||||
Tensor jacobian_t(x.type(), {x_size, y_size});
|
||||
auto jacobian_t_flat = jacobian_t.flat<T>();
|
||||
@ -131,7 +126,7 @@ Status ComputeGradientError(const Scope& scope, const ops::Output& x,
|
||||
|
||||
// Compute theoretical Jacobian.
|
||||
TF_RETURN_IF_ERROR(ComputeTheoreticalJacobianTranspose<T>(
|
||||
scope, x, x_shape, x_data, y, y_shape, &jacobian_t));
|
||||
scope, x, x_shape, *x_data, y, y_shape, &jacobian_t));
|
||||
|
||||
// Initialize numeric Jacobian to zeros.
|
||||
Tensor jacobian_n(x.type(), {x_size, y_size});
|
||||
@ -140,7 +135,7 @@ Status ComputeGradientError(const Scope& scope, const ops::Output& x,
|
||||
|
||||
// Compute numeric Jacobian.
|
||||
TF_RETURN_IF_ERROR(ComputeNumericJacobianTranspose<T>(
|
||||
scope, x, x_shape, y, y_shape, 1e-3, &x_data, &jacobian_n));
|
||||
scope, x, x_shape, y, y_shape, 1e-3, x_data, &jacobian_n));
|
||||
|
||||
// Compute the maximum error between theoretical and numeric Jacobians.
|
||||
*max_error = 0.0;
|
||||
@ -154,10 +149,39 @@ Status ComputeGradientError(const Scope& scope, const ops::Output& x,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
Status ComputeGradientError(const Scope& scope, const ops::Output& x,
|
||||
const TensorShape& x_shape, const ops::Output& y,
|
||||
const TensorShape& y_shape, T* max_error) {
|
||||
// Initialize 'x_data' to random values.
|
||||
Tensor x_data(x.type(), x_shape);
|
||||
auto x_data_flat = x_data.flat<T>();
|
||||
x_data_flat.setRandom();
|
||||
// Compute gradient error.
|
||||
return ComputeGradientErrorInternal(scope, x, x_shape, y, y_shape, &x_data,
|
||||
max_error);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status ComputeGradientError(const Scope& scope, const ops::Output& x,
|
||||
const Tensor& x_init_value, const ops::Output& y,
|
||||
const TensorShape& y_shape, T* max_error) {
|
||||
// Initialize 'x_data' from 'x_init_value'.
|
||||
Tensor x_data(x_init_value);
|
||||
// Compute gradient error.
|
||||
return ComputeGradientErrorInternal(scope, x, x_data.shape(), y, y_shape,
|
||||
&x_data, max_error);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_GRAD_ERR_TYPE(T) \
|
||||
template Status ComputeGradientError<T>( \
|
||||
const Scope& scope, const ops::Output& x, const TensorShape& x_shape, \
|
||||
const ops::Output& y, const TensorShape& y_shape, T* max_error)
|
||||
const ops::Output& y, const TensorShape& y_shape, T* max_error); \
|
||||
template Status ComputeGradientError<T>( \
|
||||
const Scope& scope, const ops::Output& x, const Tensor& x_init_value, \
|
||||
const ops::Output& y, const TensorShape& y_shape, T* max_error);
|
||||
|
||||
INSTANTIATE_GRAD_ERR_TYPE(float);
|
||||
INSTANTIATE_GRAD_ERR_TYPE(double);
|
||||
|
@ -30,6 +30,12 @@ Status ComputeGradientError(const Scope& scope, const ops::Output& x,
|
||||
const TensorShape& x_shape, const ops::Output& y,
|
||||
const TensorShape& y_shape, T* max_error);
|
||||
|
||||
// Overload of ComputeGradientError which takes an initial value for 'x'.
|
||||
template <typename T>
|
||||
Status ComputeGradientError(const Scope& scope, const ops::Output& x,
|
||||
const Tensor& x_init_value, const ops::Output& y,
|
||||
const TensorShape& y_shape, T* max_error);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_
|
||||
|
77
tensorflow/cc/gradients/nn_grad.cc
Normal file
77
tensorflow/cc/gradients/nn_grad.cc
Normal file
@ -0,0 +1,77 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/ops/nn_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
|
||||
#include "tensorflow/cc/framework/grad_op_registry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ops {
|
||||
namespace {
|
||||
|
||||
Status SoftmaxGrad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
// Softmax gradient function.
|
||||
// p = softmax(x) maps from [batch, n] to [batch, m]
|
||||
// dp/dx = [dp0/dx0 ... dp0/dxn-1 ]
|
||||
// [ ... ... ]
|
||||
// [dpm-1/dx0 ... dpm-1/dxn-1]
|
||||
// dL/dx = dp/dx * dL/dy
|
||||
//
|
||||
// Using alternative formula:
|
||||
// dL/dx = dL/dy * y - sum(dL/dy * y) * y
|
||||
// = (dL/dy - sum(dL/dy * y)) * y
|
||||
auto y = op.output(0);
|
||||
auto dyy = Mul(scope, grad_inputs[0], y);
|
||||
auto sum = Reshape(scope, Sum(scope, dyy, {1}), {-1, 1});
|
||||
auto sub = Sub(scope, grad_inputs[0], sum);
|
||||
auto dx = Mul(scope, sub, y);
|
||||
grad_outputs->push_back(dx);
|
||||
return scope.status();
|
||||
}
|
||||
REGISTER_GRADIENT_OP("Softmax", SoftmaxGrad);
|
||||
|
||||
Status ReluGradHelper(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
auto dx = ReluGrad(scope, grad_inputs[0], op.input(0));
|
||||
grad_outputs->push_back(dx);
|
||||
return scope.status();
|
||||
}
|
||||
REGISTER_GRADIENT_OP("Relu", ReluGradHelper);
|
||||
|
||||
Status Relu6GradHelper(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
auto dx = Relu6Grad(scope, grad_inputs[0], op.input(0));
|
||||
grad_outputs->push_back(dx);
|
||||
return scope.status();
|
||||
}
|
||||
REGISTER_GRADIENT_OP("Relu6", Relu6GradHelper);
|
||||
|
||||
Status EluGradHelper(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
auto dx = EluGrad(scope, grad_inputs[0], op.output(0));
|
||||
grad_outputs->push_back(dx);
|
||||
return scope.status();
|
||||
}
|
||||
REGISTER_GRADIENT_OP("Elu", EluGradHelper);
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
91
tensorflow/cc/gradients/nn_grad_test.cc
Normal file
91
tensorflow/cc/gradients/nn_grad_test.cc
Normal file
@ -0,0 +1,91 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/framework/grad_op_registry.h"
|
||||
#include "tensorflow/cc/framework/gradient_checker.h"
|
||||
#include "tensorflow/cc/framework/testutil.h"
|
||||
#include "tensorflow/cc/gradients/grad_testutil.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/random/random.h"
|
||||
|
||||
namespace tensorflow {
|
||||
using namespace ops; // NOLINT(build/namespaces)
|
||||
|
||||
namespace {
|
||||
|
||||
class NNGradTest : public ::testing::Test {
|
||||
protected:
|
||||
NNGradTest() : scope_(Scope::NewRootScope()) {}
|
||||
|
||||
void RunTest(const Output& x, const TensorShape& x_shape, const Output& y,
|
||||
const TensorShape& y_shape) {
|
||||
float max_error;
|
||||
TF_ASSERT_OK(
|
||||
ComputeGradientError(scope_, x, x_shape, y, y_shape, &max_error));
|
||||
EXPECT_LT(max_error, 1e-4);
|
||||
}
|
||||
|
||||
void RunTest(const Output& x, const Tensor& x_init_value, const Output& y,
|
||||
const TensorShape& y_shape) {
|
||||
float max_error;
|
||||
TF_ASSERT_OK(
|
||||
ComputeGradientError(scope_, x, x_init_value, y, y_shape, &max_error));
|
||||
EXPECT_LT(max_error, 1e-4);
|
||||
}
|
||||
|
||||
Scope scope_;
|
||||
};
|
||||
|
||||
TEST_F(NNGradTest, SoftmaxGrad) {
|
||||
TensorShape shape({32, 10});
|
||||
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
|
||||
auto y = Softmax(scope_, x);
|
||||
RunTest(x, shape, y, shape);
|
||||
}
|
||||
|
||||
TEST_F(NNGradTest, ReluGrad) {
|
||||
TensorShape shape({5, 2});
|
||||
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
|
||||
auto y = Relu(scope_, x);
|
||||
// Avoid input values where ReLU gradient is not well defined (around zero).
|
||||
Tensor x_init_value = test::AsTensor<float>(
|
||||
{-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9}, {5, 2});
|
||||
RunTest(x, x_init_value, y, shape);
|
||||
}
|
||||
|
||||
TEST_F(NNGradTest, Relu6Grad) {
|
||||
TensorShape shape({5, 2});
|
||||
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
|
||||
auto y = Relu6(scope_, x);
|
||||
// Avoid input values where ReLU gradient is not well defined (around zero
|
||||
// and six).
|
||||
Tensor x_init_value = test::AsTensor<float>(
|
||||
{-0.9, -0.7, -0.5, -0.3, -0.1, 6.1, 6.3, 6.5, 6.7, 6.9}, {5, 2});
|
||||
RunTest(x, x_init_value, y, shape);
|
||||
}
|
||||
|
||||
TEST_F(NNGradTest, EluGrad) {
|
||||
TensorShape shape({5, 2});
|
||||
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
|
||||
auto y = Elu(scope_, x);
|
||||
Tensor x_init_value = test::AsTensor<float>(
|
||||
{-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9}, {5, 2});
|
||||
RunTest(x, x_init_value, y, shape);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
90
tensorflow/cc/training/coordinator.cc
Normal file
90
tensorflow/cc/training/coordinator.cc
Normal file
@ -0,0 +1,90 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/training/coordinator.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Coordinator::Coordinator() : Coordinator(std::vector<error::Code>()) {}
|
||||
|
||||
Coordinator::Coordinator(const std::vector<error::Code>& clean_stop_errors)
|
||||
: should_stop_(false) {
|
||||
if (clean_stop_errors.empty()) {
|
||||
clean_stop_errors_.insert(error::OUT_OF_RANGE);
|
||||
} else {
|
||||
for (const auto& code : clean_stop_errors) {
|
||||
clean_stop_errors_.insert(static_cast<int>(code));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Coordinator::~Coordinator() {
|
||||
RequestStop();
|
||||
Join();
|
||||
}
|
||||
|
||||
Status Coordinator::RegisterRunner(std::unique_ptr<RunnerInterface> runner) {
|
||||
runners_.push_back(std::move(runner));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Coordinator::RequestStop() {
|
||||
mutex_lock l(mu_);
|
||||
if (should_stop_) {
|
||||
return Status(error::FAILED_PRECONDITION,
|
||||
"The Coordinator is not running.");
|
||||
}
|
||||
should_stop_ = true;
|
||||
wait_for_stop_.notify_all();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool Coordinator::ShouldStop() {
|
||||
mutex_lock l(mu_);
|
||||
return should_stop_;
|
||||
}
|
||||
|
||||
Status Coordinator::Join() {
|
||||
// TODO(yuefengz): deal with unexpected calls to Join().
|
||||
// TODO(yuefengz): deal with stragglers.
|
||||
for (const auto& t : runners_) {
|
||||
ReportStatus(t->Join());
|
||||
}
|
||||
runners_.clear();
|
||||
return status_;
|
||||
}
|
||||
|
||||
void Coordinator::ReportStatus(const Status& status) {
|
||||
mutex_lock l(status_lock_);
|
||||
if (status.ok() || !status_.ok() ||
|
||||
clean_stop_errors_.count(static_cast<int>(status.code())) > 0) {
|
||||
return;
|
||||
}
|
||||
status_ = status;
|
||||
}
|
||||
|
||||
Status Coordinator::GetStatus() {
|
||||
mutex_lock l(status_lock_);
|
||||
return status_;
|
||||
}
|
||||
|
||||
void Coordinator::WaitForStop() {
|
||||
mutex_lock l(mu_);
|
||||
while (!should_stop_) {
|
||||
wait_for_stop_.wait(l);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
109
tensorflow/cc/training/coordinator.h
Normal file
109
tensorflow/cc/training/coordinator.h
Normal file
@ -0,0 +1,109 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CC_TRAINING_COORDINATOR_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CC_TRAINING_COORDINATOR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/lib/core/error_codes.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// The abstract interface for runners which must implement the Join function.
|
||||
class RunnerInterface {
|
||||
public:
|
||||
virtual ~RunnerInterface() {}
|
||||
virtual Status Join() = 0;
|
||||
};
|
||||
|
||||
// Coordinator class manages the termination of a collection of QueueRunners.
|
||||
// Without a coordinator, QueueRunners have to be joined in a specific order;
|
||||
// otherwise the QueueRunner::Join() could sometimes hang. The
|
||||
// Coordinator::RequestStop() plays the key role which notifies all running
|
||||
// threads under a coordinator to stop. This function could be called by any
|
||||
// thread or any client.
|
||||
// Usage, in the client:
|
||||
// Coordinator coord;
|
||||
// std::unique_ptr<QueueRunner> qr(&coord, ...);
|
||||
// qr.Start(session);
|
||||
// coord.RegisterRunner(std::move(qr));
|
||||
// // do some work
|
||||
// TF_CHECK_OK(coord.Join());
|
||||
// In each thread of QueueRunner, the coordinator needs to be used as:
|
||||
// void Run() {
|
||||
// while (!coord->ShouldStop()) {
|
||||
// // do some work
|
||||
// if (error) {
|
||||
// coord->RequestStop();
|
||||
// coord->ReportStatus(error_status);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
class Coordinator {
|
||||
public:
|
||||
Coordinator();
|
||||
|
||||
// Constructor with a list of error codes which would not be taken as errors
|
||||
// in status reporting.
|
||||
Coordinator(const std::vector<error::Code>& clean_stop_errors);
|
||||
|
||||
// In the destructor, RequestStop() and Join() would be called.
|
||||
~Coordinator();
|
||||
|
||||
// Registers a runner, i.e. a unit of running threads which is usually a
|
||||
// QueueRunner. It takes the ownership of runner to avoid lifecycle-related
|
||||
// problems. Note, the coordinator would not start these threads; they are
|
||||
// supposed to be in running state when they are registered here.
|
||||
Status RegisterRunner(std::unique_ptr<RunnerInterface> runner);
|
||||
|
||||
// Requests all running threads to stop.
|
||||
Status RequestStop();
|
||||
|
||||
// Returns true if its RequestStop() has been called.
|
||||
bool ShouldStop();
|
||||
|
||||
// Joins all threads, returns OK or the first reported and unexpected status.
|
||||
Status Join();
|
||||
|
||||
// Reports status to the coordinator. This is usually called by threads.
|
||||
void ReportStatus(const Status& status);
|
||||
|
||||
// Returns the latest status.
|
||||
Status GetStatus();
|
||||
|
||||
// Returns immediately if the coordinator is stopped or blocks until
|
||||
// RequestStop() is called.
|
||||
void WaitForStop();
|
||||
|
||||
private:
|
||||
std::vector<std::unique_ptr<RunnerInterface>> runners_;
|
||||
std::unordered_set<int> clean_stop_errors_;
|
||||
mutex mu_;
|
||||
bool should_stop_ GUARDED_BY(mu_);
|
||||
mutex status_lock_;
|
||||
Status status_;
|
||||
condition_variable wait_for_stop_;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(Coordinator);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CC_TRAINING_COORDINATOR_H_
|
183
tensorflow/cc/training/coordinator_test.cc
Normal file
183
tensorflow/cc/training/coordinator_test.cc
Normal file
@ -0,0 +1,183 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/training/coordinator.h"
|
||||
|
||||
#include "tensorflow/cc/training/queue_runner.h"
|
||||
#include "tensorflow/core/lib/core/blocking_counter.h"
|
||||
#include "tensorflow/core/lib/core/error_codes.pb.h"
|
||||
#include "tensorflow/core/lib/core/notification.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
using error::Code;
|
||||
|
||||
void WaitForStopThread(Coordinator* coord, bool* stopped, Notification* done) {
|
||||
coord->WaitForStop();
|
||||
*stopped = true;
|
||||
done->Notify();
|
||||
}
|
||||
|
||||
TEST(CoordinatorTest, TestStopAndWaitOnStop) {
|
||||
Coordinator coord;
|
||||
EXPECT_EQ(coord.ShouldStop(), false);
|
||||
|
||||
bool stopped = false;
|
||||
Notification done;
|
||||
Env::Default()->SchedClosure(
|
||||
std::bind(&WaitForStopThread, &coord, &stopped, &done));
|
||||
Env::Default()->SleepForMicroseconds(10000000);
|
||||
EXPECT_EQ(stopped, false);
|
||||
|
||||
coord.RequestStop();
|
||||
done.WaitForNotification();
|
||||
EXPECT_EQ(stopped, true);
|
||||
EXPECT_EQ(coord.ShouldStop(), true);
|
||||
}
|
||||
|
||||
class MockQueueRunner : public RunnerInterface {
|
||||
public:
|
||||
MockQueueRunner(Coordinator* coord) {
|
||||
coord_ = coord;
|
||||
join_counter_ = nullptr;
|
||||
thread_pool_.reset(new thread::ThreadPool(Env::Default(), "test-pool", 10));
|
||||
}
|
||||
|
||||
MockQueueRunner(Coordinator* coord, int* join_counter)
|
||||
: MockQueueRunner(coord) {
|
||||
join_counter_ = join_counter;
|
||||
}
|
||||
|
||||
void StartCounting(std::atomic<int>* counter, int until) {
|
||||
thread_pool_->Schedule(
|
||||
std::bind(&MockQueueRunner::CountThread, this, counter, until));
|
||||
}
|
||||
|
||||
void StartSettingStatus(const Status& status, BlockingCounter* counter) {
|
||||
thread_pool_->Schedule(
|
||||
std::bind(&MockQueueRunner::SetStatusThread, this, status, counter));
|
||||
}
|
||||
|
||||
Status Join() {
|
||||
if (join_counter_ != nullptr) {
|
||||
(*join_counter_)++;
|
||||
}
|
||||
thread_pool_.reset();
|
||||
return status_;
|
||||
}
|
||||
|
||||
Status GetStatus() { return status_; }
|
||||
|
||||
void SetStatus(const Status& status) { status_ = status; }
|
||||
|
||||
private:
|
||||
void CountThread(std::atomic<int>* counter, int until) {
|
||||
while (!coord_->ShouldStop() && counter->load() < until) {
|
||||
(*counter)++;
|
||||
Env::Default()->SleepForMicroseconds(100000);
|
||||
}
|
||||
coord_->RequestStop();
|
||||
}
|
||||
void SetStatusThread(const Status& status, BlockingCounter* counter) {
|
||||
Env::Default()->SleepForMicroseconds(100000);
|
||||
SetStatus(status);
|
||||
counter->DecrementCount();
|
||||
}
|
||||
std::unique_ptr<thread::ThreadPool> thread_pool_;
|
||||
Status status_;
|
||||
Coordinator* coord_;
|
||||
int* join_counter_;
|
||||
};
|
||||
|
||||
TEST(CoordinatorTest, TestRealStop) {
|
||||
std::atomic<int> counter(0);
|
||||
Coordinator coord;
|
||||
|
||||
std::unique_ptr<MockQueueRunner> qr1(new MockQueueRunner(&coord));
|
||||
qr1->StartCounting(&counter, 100);
|
||||
coord.RegisterRunner(std::move(qr1));
|
||||
|
||||
std::unique_ptr<MockQueueRunner> qr2(new MockQueueRunner(&coord));
|
||||
qr2->StartCounting(&counter, 100);
|
||||
coord.RegisterRunner(std::move(qr2));
|
||||
|
||||
// Wait until the counting has started
|
||||
while (counter.load() == 0)
|
||||
;
|
||||
coord.RequestStop();
|
||||
|
||||
int temp_counter = counter.load();
|
||||
Env::Default()->SleepForMicroseconds(10000000);
|
||||
EXPECT_EQ(temp_counter, counter.load());
|
||||
TF_EXPECT_OK(coord.Join());
|
||||
}
|
||||
|
||||
TEST(CoordinatorTest, TestRequestStop) {
|
||||
Coordinator coord;
|
||||
std::atomic<int> counter(0);
|
||||
std::unique_ptr<MockQueueRunner> qr;
|
||||
for (int i = 0; i < 10; i++) {
|
||||
qr.reset(new MockQueueRunner(&coord));
|
||||
qr->StartCounting(&counter, 10);
|
||||
coord.RegisterRunner(std::move(qr));
|
||||
}
|
||||
|
||||
coord.WaitForStop();
|
||||
EXPECT_EQ(coord.ShouldStop(), true);
|
||||
EXPECT_EQ(counter.load(), 10);
|
||||
TF_EXPECT_OK(coord.Join());
|
||||
}
|
||||
|
||||
TEST(CoordinatorTest, TestJoin) {
|
||||
Coordinator coord;
|
||||
int join_counter = 0;
|
||||
std::unique_ptr<MockQueueRunner> qr1(
|
||||
new MockQueueRunner(&coord, &join_counter));
|
||||
coord.RegisterRunner(std::move(qr1));
|
||||
std::unique_ptr<MockQueueRunner> qr2(
|
||||
new MockQueueRunner(&coord, &join_counter));
|
||||
coord.RegisterRunner(std::move(qr2));
|
||||
|
||||
TF_EXPECT_OK(coord.Join());
|
||||
EXPECT_EQ(join_counter, 2);
|
||||
}
|
||||
|
||||
TEST(CoordinatorTest, StatusReporting) {
|
||||
Coordinator coord({Code::CANCELLED, Code::OUT_OF_RANGE});
|
||||
BlockingCounter counter(3);
|
||||
|
||||
std::unique_ptr<MockQueueRunner> qr1(new MockQueueRunner(&coord));
|
||||
qr1->StartSettingStatus(Status(Code::CANCELLED, ""), &counter);
|
||||
coord.RegisterRunner(std::move(qr1));
|
||||
|
||||
std::unique_ptr<MockQueueRunner> qr2(new MockQueueRunner(&coord));
|
||||
qr2->StartSettingStatus(Status(Code::INVALID_ARGUMENT, ""), &counter);
|
||||
coord.RegisterRunner(std::move(qr2));
|
||||
|
||||
std::unique_ptr<MockQueueRunner> qr3(new MockQueueRunner(&coord));
|
||||
qr3->StartSettingStatus(Status(Code::OUT_OF_RANGE, ""), &counter);
|
||||
coord.RegisterRunner(std::move(qr3));
|
||||
|
||||
counter.Wait();
|
||||
EXPECT_EQ(coord.Join().code(), Code::INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -25,6 +25,14 @@ Status QueueRunner::New(const QueueRunnerDef& queue_runner_def,
|
||||
return (*result)->Init(queue_runner_def);
|
||||
}
|
||||
|
||||
Status QueueRunner::New(const QueueRunnerDef& queue_runner_def,
|
||||
Coordinator* coord,
|
||||
std::unique_ptr<QueueRunner>* result) {
|
||||
result->reset(new QueueRunner());
|
||||
(*result)->coord_ = coord;
|
||||
return (*result)->Init(queue_runner_def);
|
||||
}
|
||||
|
||||
Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) {
|
||||
queue_name_ = queue_runner_def.queue_name();
|
||||
enqueue_op_names_.clear();
|
||||
@ -46,8 +54,8 @@ Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) {
|
||||
}
|
||||
|
||||
thread_pool_.reset(new thread::ThreadPool(
|
||||
Env::Default(), SanitizeThreadSuffix(queue_name_), runs_));
|
||||
should_stop_ = false;
|
||||
Env::Default(), SanitizeThreadSuffix(queue_name_), runs_ + 1));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -57,63 +65,108 @@ QueueRunner::~QueueRunner() {
|
||||
Join();
|
||||
}
|
||||
|
||||
Status QueueRunner::Start(Session* sess) {
|
||||
Status QueueRunner::Start(Session* sess) { return Start(sess, 0); }
|
||||
|
||||
Status QueueRunner::Start(Session* sess, int wait_for) {
|
||||
counter_.reset(new BlockingCounter(runs_));
|
||||
for (const string& enqueue_op : enqueue_op_names_) {
|
||||
thread_pool_->Schedule(
|
||||
std::bind(&QueueRunner::Run, this, sess, enqueue_op));
|
||||
}
|
||||
if (coord_) {
|
||||
thread_pool_->Schedule(std::bind(&QueueRunner::Stop, this, sess));
|
||||
}
|
||||
// Wait for up to 'wait_for' milliseconds.
|
||||
if (wait_for > 0) {
|
||||
if (!counter_->WaitFor(std::chrono::milliseconds(wait_for))) {
|
||||
return Status(error::DEADLINE_EXCEEDED,
|
||||
"Queues not fed before the timeout");
|
||||
}
|
||||
// Check the status of the queue runner as well as the result of the enqueue
|
||||
// operations.
|
||||
mutex_lock l(mu_);
|
||||
if (!enqueue_status_.ok()) {
|
||||
return enqueue_status_;
|
||||
} else {
|
||||
return status_;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status QueueRunner::Stop(Session* sess) {
|
||||
should_stop_ = true;
|
||||
void QueueRunner::Stop(Session* sess) {
|
||||
if (cancel_op_name_.empty()) {
|
||||
return Status::OK();
|
||||
return;
|
||||
} else {
|
||||
return sess->Run({}, {}, {cancel_op_name_}, nullptr);
|
||||
CHECK(coord_ != nullptr);
|
||||
coord_->WaitForStop();
|
||||
UpdateStatus(sess->Run({}, {}, {cancel_op_name_}, nullptr));
|
||||
}
|
||||
}
|
||||
|
||||
Status QueueRunner::Join() {
|
||||
thread_pool_.reset();
|
||||
mutex_lock l(mu_);
|
||||
return status_;
|
||||
}
|
||||
|
||||
void QueueRunner::UpdateStatus(const Status& status) {
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (!status_.ok() || status.ok() ||
|
||||
queue_closed_exception_types_.count(static_cast<int>(status.code())) >
|
||||
0) {
|
||||
return;
|
||||
}
|
||||
status_ = status;
|
||||
}
|
||||
if (coord_) {
|
||||
coord_->ReportStatus(status);
|
||||
}
|
||||
}
|
||||
|
||||
void QueueRunner::Run(Session* sess, const string& enqueue_op) {
|
||||
bool decremented = false;
|
||||
while (!should_stop_.load()) {
|
||||
bool first_iteration = true;
|
||||
while (true) {
|
||||
if (coord_ && coord_->ShouldStop()) {
|
||||
break;
|
||||
}
|
||||
auto status = sess->Run({}, {}, {enqueue_op}, nullptr);
|
||||
if (first_iteration) {
|
||||
if (!status.ok()) {
|
||||
mutex_lock l(mu_);
|
||||
enqueue_status_ = status;
|
||||
}
|
||||
counter_->DecrementCount();
|
||||
first_iteration = false;
|
||||
}
|
||||
if (status.ok()) {
|
||||
continue;
|
||||
} else if (queue_closed_exception_types_.count(
|
||||
static_cast<int>(status.code())) > 0) {
|
||||
mutex_lock l(mu_);
|
||||
runs_--;
|
||||
decremented = true;
|
||||
should_stop_ = true;
|
||||
|
||||
// If all enqueue ops have finished, run the close op.
|
||||
if (runs_ == 0 && !close_op_name_.empty()) {
|
||||
auto s = sess->Run({}, {}, {close_op_name_}, nullptr);
|
||||
if (!s.ok() && status_.ok() &&
|
||||
queue_closed_exception_types_.count(static_cast<int>(s.code())) ==
|
||||
0) {
|
||||
status_ = s;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
should_stop_ = true;
|
||||
// Only record the first failure status.
|
||||
if (status_.ok()) {
|
||||
status_ = status;
|
||||
}
|
||||
runs_--;
|
||||
decremented = true;
|
||||
}
|
||||
// Stop the queue runner immediately to propagate the error to
|
||||
// subsequent queues.
|
||||
Stop(sess);
|
||||
|
||||
// If all enqueue ops have finished, run the close op.
|
||||
if (runs_ == 0) {
|
||||
if (!close_op_name_.empty()) {
|
||||
auto s = sess->Run({}, {}, {close_op_name_}, nullptr);
|
||||
UpdateStatus(status);
|
||||
}
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
UpdateStatus(status);
|
||||
if (coord_) {
|
||||
coord_->RequestStop();
|
||||
}
|
||||
break;
|
||||
}
|
||||
first_iteration = false;
|
||||
}
|
||||
|
||||
if (!decremented) {
|
||||
|
@ -21,6 +21,8 @@ limitations under the License.
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/cc/training/coordinator.h"
|
||||
#include "tensorflow/core/lib/core/blocking_counter.h"
|
||||
#include "tensorflow/core/lib/core/error_codes.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
@ -32,7 +34,7 @@ namespace tensorflow {
|
||||
|
||||
// QueueRunner class imitates the behavior of the python version of QueueRunner
|
||||
// which creates a thread for each enqueue op, runs close op on completion.
|
||||
class QueueRunner {
|
||||
class QueueRunner : public RunnerInterface {
|
||||
public:
|
||||
// Creates a new QueueRunner from proto.
|
||||
// TODO(yuefengz): we may want to initialize from queues and ops in the
|
||||
@ -40,24 +42,29 @@ class QueueRunner {
|
||||
static Status New(const QueueRunnerDef& queue_runner_def,
|
||||
std::unique_ptr<QueueRunner>* result);
|
||||
|
||||
// Creates a new QueueRunner with a coordinator, see coordinator.h for usage.
|
||||
static Status New(const QueueRunnerDef& queue_runner_def, Coordinator* coord,
|
||||
std::unique_ptr<QueueRunner>* result);
|
||||
|
||||
// The destructor would join all the threads.
|
||||
~QueueRunner();
|
||||
|
||||
// Starts the queue runner with the given session.
|
||||
Status Start(Session* sess);
|
||||
|
||||
// Requests to stop and runs the cancel op.
|
||||
Status Stop(Session* sess);
|
||||
// Starts the queue runner with the given session, and wait for up to the
|
||||
// specified time (in milliseconds) for the queues to start to fill up.
|
||||
Status Start(Session* sess, int wait_for);
|
||||
|
||||
// Joins all the threads. Returns okay if all threads run successfully;
|
||||
// otherwise returns the first captured failure status.
|
||||
Status Join();
|
||||
Status Join() final;
|
||||
|
||||
// Returns the lastest status.
|
||||
Status GetStatus();
|
||||
|
||||
private:
|
||||
QueueRunner() {}
|
||||
QueueRunner() : coord_(nullptr) {}
|
||||
|
||||
// Initializes the instance with the QueueRunnerDef proto.
|
||||
Status Init(const QueueRunnerDef& queue_runner_def);
|
||||
@ -65,6 +72,14 @@ class QueueRunner {
|
||||
// The Run function for each thread.
|
||||
void Run(Session* sess, const string& enqueue_op);
|
||||
|
||||
// Requests to stop and runs the cancel op. It would be called in a separate
|
||||
// thread when coordinator is set.
|
||||
void Stop(Session* sess);
|
||||
|
||||
// Updates the internal status; it only keeps OK or the first unexpected error
|
||||
// status.
|
||||
void UpdateStatus(const Status& status);
|
||||
|
||||
string queue_name_;
|
||||
std::vector<string> enqueue_op_names_;
|
||||
string close_op_name_;
|
||||
@ -73,12 +88,15 @@ class QueueRunner {
|
||||
std::unordered_set<int> queue_closed_exception_types_;
|
||||
|
||||
std::unique_ptr<thread::ThreadPool> thread_pool_;
|
||||
std::atomic<bool> should_stop_;
|
||||
condition_variable wait_to_close_;
|
||||
mutex mu_;
|
||||
// TODO(yuefengz): implement c++ coordinator.
|
||||
int runs_ = 0;
|
||||
Status status_;
|
||||
Status status_ GUARDED_BY(mu_);
|
||||
Status enqueue_status_ GUARDED_BY(mu_);
|
||||
std::unique_ptr<BlockingCounter> counter_;
|
||||
|
||||
Coordinator* coord_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/cc/framework/scope.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/cc/training/coordinator.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
@ -111,7 +112,7 @@ TEST(QueueRunnerTest, BasicTest) {
|
||||
auto session = BuildSessionAndInitVariable(graph_def);
|
||||
|
||||
QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
|
||||
kQueueName, {kCountUpToOpName, kCountUpToOpName}, kSquareOpName, "", {});
|
||||
kQueueName, {kCountUpToOpName}, kSquareOpName, "", {});
|
||||
|
||||
std::unique_ptr<QueueRunner> qr;
|
||||
TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
|
||||
@ -164,7 +165,8 @@ GraphDef BuildDoubleQueueGraph() {
|
||||
auto close0 = QueueClose(root.WithOpName(kCloseOp0), q0);
|
||||
auto cancel0 = QueueClose(root.WithOpName(kCancelOp0), q0,
|
||||
QueueClose::CancelPendingEnqueues(true));
|
||||
auto q1 = FIFOQueue(root.WithOpName(kQueueName1), {DataType::DT_INT32});
|
||||
auto q1 = FIFOQueue(root.WithOpName(kQueueName1), {DataType::DT_INT32},
|
||||
FIFOQueue::Capacity(3));
|
||||
auto dequeue0 =
|
||||
QueueDequeue(root.WithOpName(kDequeueOp0), q0, {DataType::DT_INT32});
|
||||
auto enqueue1 = QueueEnqueue(root.WithOpName(kEnqueueOp1), q1, {dequeue0[0]});
|
||||
@ -252,34 +254,34 @@ TEST(QueueRunnerTest, SessionCloseCancelPendingEnqueue) {
|
||||
EXPECT_EQ(join_succeeded, true);
|
||||
}
|
||||
|
||||
TEST(QueueRunnerTest, Stop) {
|
||||
auto graph_def = BuildDoubleQueueGraph();
|
||||
TEST(QueueRunnerTest, EmptyEnqueueOps) {
|
||||
QueueRunnerDef queue_runner_def =
|
||||
BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, "", {});
|
||||
|
||||
std::unique_ptr<QueueRunner> qr;
|
||||
EXPECT_EQ(QueueRunner::New(queue_runner_def, &qr).code(),
|
||||
Code::INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
TEST(QueueRunnerTest, StartTimeout) {
|
||||
GraphDef graph_def = BuildDoubleQueueGraph();
|
||||
SessionOptions options;
|
||||
std::unique_ptr<Session> session(NewSession(options));
|
||||
TF_CHECK_OK(session->Create(graph_def));
|
||||
|
||||
QueueRunnerDef queue_runner_def =
|
||||
BuildQueueRunnerDef(kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1,
|
||||
{Code::OUT_OF_RANGE, Code::CANCELLED});
|
||||
QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
|
||||
kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {});
|
||||
|
||||
std::unique_ptr<QueueRunner> qr;
|
||||
TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
|
||||
TF_CHECK_OK(qr->Start(session.get()));
|
||||
|
||||
TF_EXPECT_OK(qr->Stop(session.get()));
|
||||
|
||||
TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));
|
||||
|
||||
EXPECT_EQ(session->Run({}, {kDequeueOp1}, {}, nullptr).code(),
|
||||
Code::OUT_OF_RANGE);
|
||||
|
||||
// qr is already stopped
|
||||
TF_EXPECT_OK(qr->Join());
|
||||
// This will timeout since queue0 is not fed and queue1 is fetching data from
|
||||
// queue0.
|
||||
EXPECT_EQ(qr->Start(session.get(), 1).code(), Code::DEADLINE_EXCEEDED);
|
||||
session->Close();
|
||||
}
|
||||
|
||||
TEST(QueueRunnerTest, StopTwoQueues) {
|
||||
TEST(QueueRunnerTest, TestCoordinatorStop) {
|
||||
auto graph_def = BuildDoubleQueueGraph();
|
||||
|
||||
SessionOptions options;
|
||||
std::unique_ptr<Session> session(NewSession(options));
|
||||
TF_CHECK_OK(session->Create(graph_def));
|
||||
@ -290,31 +292,24 @@ TEST(QueueRunnerTest, StopTwoQueues) {
|
||||
QueueRunnerDef queue_runner1 =
|
||||
BuildQueueRunnerDef(kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1,
|
||||
{Code::OUT_OF_RANGE, Code::CANCELLED});
|
||||
|
||||
Coordinator coord;
|
||||
std::unique_ptr<QueueRunner> qr0;
|
||||
TF_EXPECT_OK(QueueRunner::New(queue_runner0, &qr0));
|
||||
TF_EXPECT_OK(QueueRunner::New(queue_runner0, &coord, &qr0));
|
||||
TF_CHECK_OK(qr0->Start(session.get()));
|
||||
std::unique_ptr<QueueRunner> qr1;
|
||||
TF_EXPECT_OK(QueueRunner::New(queue_runner1, &qr1));
|
||||
TF_EXPECT_OK(QueueRunner::New(queue_runner1, &coord, &qr1));
|
||||
TF_CHECK_OK(qr1->Start(session.get()));
|
||||
|
||||
coord.RegisterRunner(std::move(qr0));
|
||||
coord.RegisterRunner(std::move(qr1));
|
||||
|
||||
std::vector<Tensor> dq;
|
||||
TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq));
|
||||
EXPECT_EQ(*dq[0].scalar<int>().data(), 10);
|
||||
|
||||
TF_EXPECT_OK(qr0->Stop(session.get()));
|
||||
TF_EXPECT_OK(qr1->Stop(session.get()));
|
||||
|
||||
TF_EXPECT_OK(qr0->Join());
|
||||
TF_EXPECT_OK(qr1->Join());
|
||||
}
|
||||
|
||||
TEST(QueueRunnerTest, EmptyEnqueueOps) {
|
||||
QueueRunnerDef queue_runner_def =
|
||||
BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, "", {});
|
||||
|
||||
std::unique_ptr<QueueRunner> qr;
|
||||
EXPECT_EQ(QueueRunner::New(queue_runner_def, &qr).code(),
|
||||
Code::INVALID_ARGUMENT);
|
||||
TF_EXPECT_OK(coord.RequestStop());
|
||||
TF_EXPECT_OK(coord.Join());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -23,6 +23,7 @@ py_library(
|
||||
"//tensorflow/contrib/framework:framework_py",
|
||||
"//tensorflow/contrib/graph_editor:graph_editor_py",
|
||||
"//tensorflow/contrib/grid_rnn:grid_rnn_py",
|
||||
"//tensorflow/contrib/integrate:integrate_py",
|
||||
"//tensorflow/contrib/layers:layers_py",
|
||||
"//tensorflow/contrib/learn",
|
||||
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
|
||||
|
@ -28,6 +28,7 @@ from tensorflow.contrib import factorization
|
||||
from tensorflow.contrib import framework
|
||||
from tensorflow.contrib import graph_editor
|
||||
from tensorflow.contrib import grid_rnn
|
||||
from tensorflow.contrib import integrate
|
||||
from tensorflow.contrib import layers
|
||||
from tensorflow.contrib import learn
|
||||
from tensorflow.contrib import linear_optimizer
|
||||
|
@ -76,7 +76,7 @@ def build_split_apply_merge_model():
|
||||
|
||||
# REINFORCE forward step
|
||||
route_selection = st.StochasticTensor(
|
||||
distributions.Categorical, logits=logits)
|
||||
distributions.Categorical(logits=logits))
|
||||
|
||||
# Accessing route_selection as a Tensor below forces a sample of
|
||||
# the Categorical distribution based on its logits.
|
||||
|
@ -22,6 +22,7 @@ import tensorflow as tf
|
||||
|
||||
st = tf.contrib.bayesflow.stochastic_tensor
|
||||
sge = tf.contrib.bayesflow.stochastic_gradient_estimators
|
||||
dists = tf.contrib.distributions
|
||||
|
||||
|
||||
class StochasticGradientEstimatorsTest(tf.test.TestCase):
|
||||
@ -31,7 +32,7 @@ class StochasticGradientEstimatorsTest(tf.test.TestCase):
|
||||
self._final_loss = tf.constant(3.2)
|
||||
|
||||
def _testScoreFunction(self, loss_fn, expected):
|
||||
x = st.BernoulliTensor(p=self._p, loss_fn=loss_fn)
|
||||
x = st.StochasticTensor(dists.Bernoulli(p=self._p), loss_fn=loss_fn)
|
||||
sf = x.loss(self._final_loss)
|
||||
with self.test_session() as sess:
|
||||
sess.run(tf.initialize_all_variables())
|
||||
@ -62,8 +63,8 @@ class StochasticGradientEstimatorsTest(tf.test.TestCase):
|
||||
def testScoreFunctionWithMeanBaseline(self):
|
||||
ema_decay = 0.8
|
||||
num_steps = 6
|
||||
x = st.BernoulliTensor(
|
||||
p=self._p,
|
||||
x = st.StochasticTensor(
|
||||
dists.Bernoulli(p=self._p),
|
||||
loss_fn=sge.get_score_function_with_baseline(
|
||||
sge.get_mean_baseline(ema_decay)))
|
||||
sf = x.loss(self._final_loss)
|
||||
@ -98,12 +99,12 @@ class StochasticGradientEstimatorsTest(tf.test.TestCase):
|
||||
|
||||
def testScoreFunctionWithMeanBaselineHasUniqueVarScope(self):
|
||||
ema_decay = 0.8
|
||||
x = st.BernoulliTensor(
|
||||
p=self._p,
|
||||
x = st.StochasticTensor(
|
||||
dists.Bernoulli(p=self._p),
|
||||
loss_fn=sge.get_score_function_with_baseline(
|
||||
sge.get_mean_baseline(ema_decay)))
|
||||
y = st.BernoulliTensor(
|
||||
p=self._p,
|
||||
y = st.StochasticTensor(
|
||||
dists.Bernoulli(p=self._p),
|
||||
loss_fn=sge.get_score_function_with_baseline(
|
||||
sge.get_mean_baseline(ema_decay)))
|
||||
sf_x = x.loss(self._final_loss)
|
||||
|
@ -39,9 +39,9 @@ class TestSurrogateLosses(tf.test.TestCase):
|
||||
mu = [0.0, 0.1, 0.2]
|
||||
sigma = tf.constant([1.1, 1.2, 1.3])
|
||||
with st.value_type(st.SampleAndReshapeValue()):
|
||||
prior = st.StochasticTensor(distributions.Normal, mu=mu, sigma=sigma)
|
||||
prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
|
||||
likelihood = st.StochasticTensor(
|
||||
distributions.Normal, mu=prior, sigma=sigma)
|
||||
distributions.Normal(mu=prior, sigma=sigma))
|
||||
self.assertTrue(prior.distribution.is_reparameterized)
|
||||
self.assertTrue(likelihood.distribution.is_reparameterized)
|
||||
|
||||
@ -77,10 +77,9 @@ class TestSurrogateLosses(tf.test.TestCase):
|
||||
mu = tf.constant([0.0, 0.1, 0.2])
|
||||
sigma = tf.constant([1.1, 1.2, 1.3])
|
||||
with st.value_type(st.SampleAndReshapeValue()):
|
||||
prior = st.StochasticTensor(NormalNotParam, mu=mu, sigma=sigma)
|
||||
likelihood = st.StochasticTensor(
|
||||
NormalNotParam, mu=prior, sigma=sigma)
|
||||
prior_2 = st.StochasticTensor(NormalNotParam, mu=mu, sigma=sigma)
|
||||
prior = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
|
||||
likelihood = st.StochasticTensor(NormalNotParam(mu=prior, sigma=sigma))
|
||||
prior_2 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
|
||||
|
||||
loss = tf.square(tf.identity(likelihood) - mu)
|
||||
part_loss = tf.square(tf.identity(prior) - mu)
|
||||
@ -155,9 +154,7 @@ class TestSurrogateLosses(tf.test.TestCase):
|
||||
mu = tf.constant([0.0, 0.1, 0.2])
|
||||
sigma = tf.constant([1.1, 1.2, 1.3])
|
||||
with st.value_type(st.SampleAndReshapeValue()):
|
||||
dt = st.StochasticTensor(NormalNotParam,
|
||||
mu=mu,
|
||||
sigma=sigma,
|
||||
dt = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma),
|
||||
loss_fn=None)
|
||||
self.assertEqual(None, dt.loss(tf.constant([2.0])))
|
||||
|
||||
@ -166,8 +163,8 @@ class TestSurrogateLosses(tf.test.TestCase):
|
||||
mu = tf.constant([0.0, 0.1, 0.2])
|
||||
sigma = tf.constant([1.1, 1.2, 1.3])
|
||||
with st.value_type(st.SampleAndReshapeValue()):
|
||||
dt1 = st.StochasticTensor(NormalNotParam, mu=mu, sigma=sigma)
|
||||
dt2 = st.StochasticTensor(NormalNotParam, mu=mu, sigma=sigma)
|
||||
dt1 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
|
||||
dt2 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
|
||||
loss = tf.square(tf.identity(dt1)) + 10. + dt2
|
||||
|
||||
sl_all = sg.surrogate_loss([loss])
|
||||
@ -186,8 +183,8 @@ class TestSurrogateLosses(tf.test.TestCase):
|
||||
class StochasticDependenciesMapTest(tf.test.TestCase):
|
||||
|
||||
def testBuildsMapOfUpstreamNodes(self):
|
||||
dt1 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
|
||||
dt2 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
|
||||
dt1 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
|
||||
dt2 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
|
||||
out1 = dt1.value() + 1.
|
||||
out2 = dt2.value() + 2.
|
||||
x = out1 + out2
|
||||
@ -197,11 +194,11 @@ class StochasticDependenciesMapTest(tf.test.TestCase):
|
||||
self.assertEqual(dep_map[dt2], set([x, y]))
|
||||
|
||||
def testHandlesStackedStochasticNodes(self):
|
||||
dt1 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
|
||||
dt1 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
|
||||
out1 = dt1.value() + 1.
|
||||
dt2 = st.StochasticTensor(distributions.Normal, mu=out1, sigma=1.)
|
||||
dt2 = st.StochasticTensor(distributions.Normal(mu=out1, sigma=1.))
|
||||
x = dt2.value() + 2.
|
||||
dt3 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
|
||||
dt3 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
|
||||
y = dt3.value() * 3.
|
||||
dep_map = sg._stochastic_dependencies_map([x, y])
|
||||
self.assertEqual(dep_map[dt1], set([x]))
|
||||
@ -209,10 +206,10 @@ class StochasticDependenciesMapTest(tf.test.TestCase):
|
||||
self.assertEqual(dep_map[dt3], set([y]))
|
||||
|
||||
def testTraversesControlInputs(self):
|
||||
dt1 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
|
||||
dt1 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
|
||||
logits = dt1.value() * 3.
|
||||
dt2 = st.StochasticTensor(distributions.Bernoulli, logits=logits)
|
||||
dt3 = st.StochasticTensor(distributions.Normal, mu=0., sigma=1.)
|
||||
dt2 = st.StochasticTensor(distributions.Bernoulli(logits=logits))
|
||||
dt3 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
|
||||
x = dt3.value()
|
||||
y = tf.ones((2, 2)) * 4.
|
||||
z = tf.ones((2, 2)) * 3.
|
||||
|
@ -35,19 +35,19 @@ class StochasticTensorTest(tf.test.TestCase):
|
||||
sigma2 = tf.constant([0.1, 0.2, 0.3])
|
||||
|
||||
prior_default = st.StochasticTensor(
|
||||
distributions.Normal, mu=mu, sigma=sigma)
|
||||
distributions.Normal(mu=mu, sigma=sigma))
|
||||
self.assertTrue(
|
||||
isinstance(prior_default.value_type, st.SampleAndReshapeValue))
|
||||
prior_0 = st.StochasticTensor(
|
||||
distributions.Normal, mu=mu, sigma=sigma,
|
||||
distributions.Normal(mu=mu, sigma=sigma),
|
||||
dist_value_type=st.SampleAndReshapeValue())
|
||||
self.assertTrue(isinstance(prior_0.value_type, st.SampleAndReshapeValue))
|
||||
|
||||
with st.value_type(st.SampleAndReshapeValue()):
|
||||
prior = st.StochasticTensor(distributions.Normal, mu=mu, sigma=sigma)
|
||||
prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
|
||||
self.assertTrue(isinstance(prior.value_type, st.SampleAndReshapeValue))
|
||||
likelihood = st.StochasticTensor(
|
||||
distributions.Normal, mu=prior, sigma=sigma2)
|
||||
distributions.Normal(mu=prior, sigma=sigma2))
|
||||
self.assertTrue(
|
||||
isinstance(likelihood.value_type, st.SampleAndReshapeValue))
|
||||
|
||||
@ -77,7 +77,7 @@ class StochasticTensorTest(tf.test.TestCase):
|
||||
sigma = tf.constant([1.1, 1.2, 1.3])
|
||||
|
||||
with st.value_type(st.MeanValue()):
|
||||
prior = st.StochasticTensor(distributions.Normal, mu=mu, sigma=sigma)
|
||||
prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
|
||||
self.assertTrue(isinstance(prior.value_type, st.MeanValue))
|
||||
|
||||
prior_mean = prior.mean()
|
||||
@ -94,7 +94,8 @@ class StochasticTensorTest(tf.test.TestCase):
|
||||
|
||||
with st.value_type(st.SampleAndReshapeValue()):
|
||||
prior_single = st.StochasticTensor(
|
||||
distributions.Normal, mu=mu, sigma=sigma)
|
||||
distributions.Normal(
|
||||
mu=mu, sigma=sigma))
|
||||
|
||||
prior_single_value = prior_single.value()
|
||||
self.assertEqual(prior_single_value.get_shape(), (2, 3))
|
||||
@ -104,7 +105,7 @@ class StochasticTensorTest(tf.test.TestCase):
|
||||
|
||||
with st.value_type(st.SampleAndReshapeValue(n=2)):
|
||||
prior_double = st.StochasticTensor(
|
||||
distributions.Normal, mu=mu, sigma=sigma)
|
||||
distributions.Normal(mu=mu, sigma=sigma))
|
||||
|
||||
prior_double_value = prior_double.value()
|
||||
self.assertEqual(prior_double_value.get_shape(), (4, 3))
|
||||
@ -119,7 +120,7 @@ class StochasticTensorTest(tf.test.TestCase):
|
||||
|
||||
with st.value_type(st.SampleValue()):
|
||||
prior_single = st.StochasticTensor(
|
||||
distributions.Normal, mu=mu, sigma=sigma)
|
||||
distributions.Normal(mu=mu, sigma=sigma))
|
||||
self.assertTrue(isinstance(prior_single.value_type, st.SampleValue))
|
||||
|
||||
prior_single_value = prior_single.value()
|
||||
@ -130,7 +131,7 @@ class StochasticTensorTest(tf.test.TestCase):
|
||||
|
||||
with st.value_type(st.SampleValue(n=2)):
|
||||
prior_double = st.StochasticTensor(
|
||||
distributions.Normal, mu=mu, sigma=sigma)
|
||||
distributions.Normal(mu=mu, sigma=sigma))
|
||||
|
||||
prior_double_value = prior_double.value()
|
||||
self.assertEqual(prior_double_value.get_shape(), (2, 2, 3))
|
||||
@ -143,9 +144,9 @@ class StochasticTensorTest(tf.test.TestCase):
|
||||
mu = [0.0, -1.0, 1.0]
|
||||
sigma = tf.constant([1.1, 1.2, 1.3])
|
||||
with st.value_type(st.MeanValue()):
|
||||
prior = st.StochasticTensor(distributions.Normal, mu=mu, sigma=sigma)
|
||||
prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
|
||||
entropy = prior.entropy()
|
||||
deep_entropy = prior.entropy()
|
||||
deep_entropy = prior.distribution.entropy()
|
||||
expected_deep_entropy = distributions.Normal(
|
||||
mu=mu, sigma=sigma).entropy()
|
||||
entropies = sess.run([entropy, deep_entropy, expected_deep_entropy])
|
||||
@ -159,17 +160,15 @@ class StochasticTensorTest(tf.test.TestCase):
|
||||
|
||||
# With default
|
||||
with st.value_type(st.MeanValue(stop_gradient=True)):
|
||||
dt = st.StochasticTensor(distributions.Normal, mu=mu, sigma=sigma)
|
||||
dt = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
|
||||
loss = dt.loss([tf.constant(2.0)])
|
||||
self.assertTrue(loss is not None)
|
||||
self.assertAllClose(dt.distribution.log_prob(mu).eval() * 2.0,
|
||||
loss.eval())
|
||||
self.assertAllClose(
|
||||
dt.distribution.log_prob(mu).eval() * 2.0, loss.eval())
|
||||
|
||||
# With passed-in loss_fn.
|
||||
dt = st.StochasticTensor(
|
||||
distributions.Normal,
|
||||
mu=mu,
|
||||
sigma=sigma,
|
||||
distributions.Normal(mu=mu, sigma=sigma),
|
||||
dist_value_type=st.MeanValue(stop_gradient=True),
|
||||
loss_fn=sge.get_score_function_with_constant_baseline(
|
||||
baseline=tf.constant(8.0)))
|
||||
@ -204,7 +203,7 @@ class ObservedStochasticTensorTest(tf.test.TestCase):
|
||||
sigma = tf.constant([1.1, 1.2, 1.3])
|
||||
obs = tf.zeros((2, 3))
|
||||
z = st.ObservedStochasticTensor(
|
||||
distributions.Normal, mu=mu, sigma=sigma, value=obs)
|
||||
distributions.Normal(mu=mu, sigma=sigma), value=obs)
|
||||
[obs_val, z_val] = sess.run([obs, z.value()])
|
||||
self.assertAllEqual(obs_val, z_val)
|
||||
|
||||
@ -216,13 +215,13 @@ class ObservedStochasticTensorTest(tf.test.TestCase):
|
||||
sigma = tf.placeholder(tf.float32)
|
||||
obs = tf.placeholder(tf.float32)
|
||||
z = st.ObservedStochasticTensor(
|
||||
distributions.Normal, mu=mu, sigma=sigma, value=obs)
|
||||
distributions.Normal(mu=mu, sigma=sigma), value=obs)
|
||||
|
||||
mu2 = tf.placeholder(tf.float32, shape=[None])
|
||||
sigma2 = tf.placeholder(tf.float32, shape=[None])
|
||||
obs2 = tf.placeholder(tf.float32, shape=[None, None])
|
||||
z2 = st.ObservedStochasticTensor(
|
||||
distributions.Normal, mu=mu2, sigma=sigma2, value=obs2)
|
||||
distributions.Normal(mu=mu2, sigma=sigma2), value=obs2)
|
||||
|
||||
coll = tf.get_collection(st.STOCHASTIC_TENSOR_COLLECTION)
|
||||
self.assertEqual(coll, [z, z2])
|
||||
@ -230,27 +229,19 @@ class ObservedStochasticTensorTest(tf.test.TestCase):
|
||||
def testConstructionErrors(self):
|
||||
mu = [0., 0.]
|
||||
sigma = [1., 1.]
|
||||
self.assertRaises(ValueError, st.ObservedStochasticTensor,
|
||||
distributions.Normal, mu=mu, sigma=sigma,
|
||||
value=tf.zeros((3,)))
|
||||
self.assertRaises(ValueError, st.ObservedStochasticTensor,
|
||||
distributions.Normal, mu=mu, sigma=sigma,
|
||||
value=tf.zeros((3, 1)))
|
||||
self.assertRaises(ValueError, st.ObservedStochasticTensor,
|
||||
distributions.Normal, mu=mu, sigma=sigma,
|
||||
value=tf.zeros((1, 2), dtype=tf.int32))
|
||||
|
||||
|
||||
class AutomaticDistributionImportTest(tf.test.TestCase):
|
||||
|
||||
def testImportNormal(self):
|
||||
self.assertTrue(hasattr(st, "NormalTensor"))
|
||||
self.assertTrue(callable(st.NormalTensor))
|
||||
norm = st.NormalTensor(mu=0.0, sigma=1.0)
|
||||
self.assertEqual(type(norm).__name__, "NormalTensor")
|
||||
self.assertTrue(isinstance(norm, st.NormalTensor))
|
||||
self.assertTrue(isinstance(norm, st.StochasticTensor))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
st.ObservedStochasticTensor,
|
||||
distributions.Normal(mu=mu, sigma=sigma),
|
||||
value=tf.zeros((3,)))
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
st.ObservedStochasticTensor,
|
||||
distributions.Normal(mu=mu, sigma=sigma),
|
||||
value=tf.zeros((3, 1)))
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
st.ObservedStochasticTensor,
|
||||
distributions.Normal(mu=mu, sigma=sigma),
|
||||
value=tf.zeros(
|
||||
(1, 2), dtype=tf.int32))
|
||||
|
@ -44,7 +44,7 @@ def mini_vae():
|
||||
x = [[-6., 3., 6.], [-8., 4., 8.]]
|
||||
prior = distributions.Normal(mu=0., sigma=1.)
|
||||
variational = st.StochasticTensor(
|
||||
distributions.Normal, mu=inference_net(x, 1), sigma=1.)
|
||||
distributions.Normal(mu=inference_net(x, 1), sigma=1.))
|
||||
vi.register_prior(variational, prior)
|
||||
px = distributions.Normal(mu=generative_net(variational, 3), sigma=1.)
|
||||
log_likelihood = tf.reduce_sum(px.log_prob(x), 1)
|
||||
@ -101,7 +101,7 @@ class VariationalInferenceTest(tf.test.TestCase):
|
||||
|
||||
prior = distributions.Bernoulli(0.5)
|
||||
variational = st.StochasticTensor(
|
||||
NormalNoEntropy, mu=inference_net(x, 1), sigma=1.)
|
||||
NormalNoEntropy(mu=inference_net(x, 1), sigma=1.))
|
||||
vi.register_prior(variational, prior)
|
||||
px = distributions.Normal(mu=generative_net(variational, 3), sigma=1.)
|
||||
log_likelihood = tf.reduce_sum(px.log_prob(x), 1)
|
||||
|
@ -44,7 +44,6 @@ from __future__ import print_function
|
||||
import abc
|
||||
import collections
|
||||
import contextlib
|
||||
import inspect
|
||||
import threading
|
||||
|
||||
import six
|
||||
@ -79,10 +78,6 @@ class BaseStochasticTensor(object):
|
||||
def graph(self):
|
||||
pass
|
||||
|
||||
@abc.abstractproperty
|
||||
def input_dict(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def value(self, name=None):
|
||||
pass
|
||||
@ -120,6 +115,7 @@ class BaseStochasticTensor(object):
|
||||
# pylint: disable=protected-access
|
||||
ops.register_tensor_conversion_function(
|
||||
BaseStochasticTensor, BaseStochasticTensor._tensor_conversion_function)
|
||||
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
@ -223,8 +219,8 @@ class SampleAndReshapeValue(_StochasticValueType):
|
||||
st_value = st.value()
|
||||
assertEqual(st_value.get_shape(), (4, 3))
|
||||
|
||||
dt_value_val = sess.run([st_value])[0] # or e.g. run([tf.identity(st)])[0]
|
||||
assertEqual(dt_value_val.shape, (4, 3))
|
||||
st_value_val = sess.run([st_value])[0] # or e.g. run([tf.identity(st)])[0]
|
||||
assertEqual(st_value_val.shape, (4, 3))
|
||||
```
|
||||
"""
|
||||
|
||||
@ -312,17 +308,16 @@ class StochasticTensor(BaseStochasticTensor):
|
||||
"""StochasticTensor is a BaseStochasticTensor backed by a distribution."""
|
||||
|
||||
def __init__(self,
|
||||
dist_cls,
|
||||
name=None,
|
||||
dist,
|
||||
name="StochasticTensor",
|
||||
dist_value_type=None,
|
||||
loss_fn=sge.score_function,
|
||||
**dist_args):
|
||||
loss_fn=sge.score_function):
|
||||
"""Construct a `StochasticTensor`.
|
||||
|
||||
`StochasticTensor` will instantiate a distribution from `dist_cls` and
|
||||
`dist_args` and its `value` method will return the same value each time
|
||||
it is called. What `value` is returned is controlled by the
|
||||
`dist_value_type` (defaults to `SampleAndReshapeValue`).
|
||||
`StochasticTensor` is backed by the `dist` distribution and its `value`
|
||||
method will return the same value each time it is called. What `value` is
|
||||
returned is controlled by the `dist_value_type` (defaults to
|
||||
`SampleAndReshapeValue`).
|
||||
|
||||
Some distributions' sample functions are not differentiable (e.g. a sample
|
||||
from a discrete distribution like a Bernoulli) and so to differentiate
|
||||
@ -338,28 +333,25 @@ class StochasticTensor(BaseStochasticTensor):
|
||||
`MeanValueType` or if `loss_fn=None`.
|
||||
|
||||
Args:
|
||||
dist_cls: a `Distribution` class.
|
||||
dist: an instance of `Distribution`.
|
||||
name: a name for this `StochasticTensor` and its ops.
|
||||
dist_value_type: a `_StochasticValueType`, which will determine what the
|
||||
`value` of this `StochasticTensor` will be. If not provided, the
|
||||
value type set with the `value_type` context manager will be used.
|
||||
loss_fn: callable that takes `(st, st.value(), influenced_loss)`, where
|
||||
loss_fn: callable that takes
|
||||
`(st, st.value(), influenced_loss)`, where
|
||||
`st` is this `StochasticTensor`, and returns a `Tensor` loss. By
|
||||
default, `loss_fn` is the `score_function`, or more precisely, the
|
||||
integral of the score function, such that when the gradient is taken,
|
||||
the score function results. See the `stochastic_gradient_estimators`
|
||||
module for additional loss functions and baselines.
|
||||
**dist_args: keyword arguments to be passed through to `dist_cls` on
|
||||
construction.
|
||||
|
||||
Raises:
|
||||
TypeError: if `dist_cls` is not a `Distribution`.
|
||||
TypeError: if `dist` is not an instance of `Distribution`.
|
||||
TypeError: if `loss_fn` is not `callable`.
|
||||
"""
|
||||
if not issubclass(dist_cls, distributions.Distribution):
|
||||
raise TypeError("dist_cls must be a subclass of Distribution")
|
||||
self._dist_cls = dist_cls
|
||||
self._dist_args = dist_args
|
||||
if not isinstance(dist, distributions.Distribution):
|
||||
raise TypeError("dist must be an instance of Distribution")
|
||||
if dist_value_type is None:
|
||||
try:
|
||||
self._value_type = get_current_value_type()
|
||||
@ -371,24 +363,17 @@ class StochasticTensor(BaseStochasticTensor):
|
||||
with value_type(dist_value_type):
|
||||
self._value_type = get_current_value_type()
|
||||
|
||||
self._value_type.declare_inputs(self, dist_args)
|
||||
|
||||
if loss_fn is not None and not callable(loss_fn):
|
||||
raise TypeError("loss_fn must be callable")
|
||||
self._loss_fn = loss_fn
|
||||
|
||||
with ops.name_scope(name, "StochasticTensor",
|
||||
dist_args.values()) as scope:
|
||||
with ops.name_scope(name) as scope:
|
||||
self._name = scope
|
||||
self._dist = dist_cls(**dist_args)
|
||||
self._dist = dist
|
||||
self._value = self._create_value()
|
||||
|
||||
super(StochasticTensor, self).__init__()
|
||||
|
||||
@property
|
||||
def input_dict(self):
|
||||
return self._dist_args
|
||||
|
||||
@property
|
||||
def value_type(self):
|
||||
return self._value_type
|
||||
@ -397,9 +382,6 @@ class StochasticTensor(BaseStochasticTensor):
|
||||
def distribution(self):
|
||||
return self._dist
|
||||
|
||||
def clone(self, name=None, **dist_args):
|
||||
return StochasticTensor(self._dist_cls, name=name, **dist_args)
|
||||
|
||||
def _create_value(self):
|
||||
"""Create the value Tensor based on the value type, store as self._value."""
|
||||
|
||||
@ -494,33 +476,28 @@ class ObservedStochasticTensor(StochasticTensor):
|
||||
"""A StochasticTensor with an observed value."""
|
||||
|
||||
# pylint: disable=super-init-not-called
|
||||
def __init__(self, dist_cls, value, name=None, **dist_args):
|
||||
def __init__(self, dist, value, name=None):
|
||||
"""Construct an `ObservedStochasticTensor`.
|
||||
|
||||
`ObservedStochasticTensor` will instantiate a distribution from `dist_cls`
|
||||
and `dist_args` but use the provided value instead of sampling from the
|
||||
distribution. The provided value argument must be appropriately shaped
|
||||
to have come from the constructed distribution.
|
||||
`ObservedStochasticTensor` is backed by distribution `dist` and uses the
|
||||
provided value instead of using the current value type to draw a value from
|
||||
the distribution. The provided value argument must be appropriately shaped
|
||||
to have come from the distribution.
|
||||
|
||||
Args:
|
||||
dist_cls: a `Distribution` class.
|
||||
dist: an instance of `Distribution`.
|
||||
value: a Tensor containing the observed value
|
||||
name: a name for this `ObservedStochasticTensor` and its ops.
|
||||
**dist_args: keyword arguments to be passed through to `dist_cls` on
|
||||
construction.
|
||||
|
||||
Raises:
|
||||
TypeError: if `dist_cls` is not a `Distribution`.
|
||||
TypeError: if `dist` is not an instance of `Distribution`.
|
||||
ValueError: if `value` is not compatible with the distribution.
|
||||
"""
|
||||
if not issubclass(dist_cls, distributions.Distribution):
|
||||
raise TypeError("dist_cls must be a subclass of Distribution")
|
||||
self._dist_cls = dist_cls
|
||||
self._dist_args = dist_args
|
||||
with ops.name_scope(name, "ObservedStochasticTensor",
|
||||
list(dist_args.values()) + [value]) as scope:
|
||||
if not isinstance(dist, distributions.Distribution):
|
||||
raise TypeError("dist must be an instance of Distribution")
|
||||
with ops.name_scope(name, "ObservedStochasticTensor", [value]) as scope:
|
||||
self._name = scope
|
||||
self._dist = dist_cls(**dist_args)
|
||||
self._dist = dist
|
||||
dist_shape = self._dist.get_batch_shape().concatenate(
|
||||
self._dist.get_event_shape())
|
||||
value = ops.convert_to_tensor(value)
|
||||
@ -538,7 +515,7 @@ class ObservedStochasticTensor(StochasticTensor):
|
||||
"sample from the distribution %s." % (value_shape, dist_shape))
|
||||
if value.dtype != self._dist.dtype:
|
||||
raise ValueError("Type of observed value (%s) does not match type of "
|
||||
"distribuiton (%s)." % (value.dtype, self._dist.dtype))
|
||||
"distribution (%s)." % (value.dtype, self._dist.dtype))
|
||||
self._value = array_ops.identity(value)
|
||||
# pylint: disable=non-parent-init-called
|
||||
BaseStochasticTensor.__init__(self)
|
||||
@ -557,39 +534,3 @@ __all__ = [
|
||||
"value_type",
|
||||
"get_current_value_type",
|
||||
]
|
||||
|
||||
_globals = globals()
|
||||
# pylint: disable=redefined-builtin
|
||||
__doc__ += "\n\n## Automatically Generated StochasticTensors\n\n"
|
||||
# pylint: enable=redefined-builtin
|
||||
for _name in sorted(dir(distributions)):
|
||||
_candidate = getattr(distributions, _name)
|
||||
if (inspect.isclass(_candidate)
|
||||
and _candidate != distributions.Distribution
|
||||
and issubclass(_candidate, distributions.Distribution)):
|
||||
_local_name = "%sTensor" % _name
|
||||
|
||||
class _WrapperTensor(StochasticTensor):
|
||||
_my_candidate = _candidate
|
||||
|
||||
def __init__(self, name=None, dist_value_type=None,
|
||||
loss_fn=sge.score_function, **dist_args):
|
||||
StochasticTensor.__init__(
|
||||
self,
|
||||
dist_cls=self._my_candidate,
|
||||
name=name,
|
||||
dist_value_type=dist_value_type,
|
||||
loss_fn=loss_fn, **dist_args)
|
||||
|
||||
_WrapperTensor.__name__ = _local_name
|
||||
_WrapperTensor.__doc__ = (
|
||||
"`%s` is a `StochasticTensor` backed by the distribution `%s`."""
|
||||
% (_local_name, _name))
|
||||
_globals[_local_name] = _WrapperTensor
|
||||
del _WrapperTensor
|
||||
del _candidate
|
||||
|
||||
__all__.append(_local_name)
|
||||
__doc__ += "@@%s\n" % _local_name
|
||||
|
||||
del _local_name
|
||||
|
@ -126,7 +126,7 @@ def get_stochastic_variable(getter,
|
||||
|
||||
dist_kwargs = dist_kwargs or {}
|
||||
dist_kwargs.update(params)
|
||||
sample = st.StochasticTensor(dist_cls, **dist_kwargs)
|
||||
sample = st.StochasticTensor(dist_cls(**dist_kwargs))
|
||||
|
||||
if prior is not None:
|
||||
if callable(prior):
|
||||
|
2
tensorflow/contrib/cmake/external/gif.cmake
vendored
2
tensorflow/contrib/cmake/external/gif.cmake
vendored
@ -1,7 +1,7 @@
|
||||
include (ExternalProject)
|
||||
|
||||
set(gif_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/gif_archive/giflib-5.1.4/)
|
||||
set(gif_URL http://ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz)
|
||||
set(gif_URL http://cdimage.debian.org/mirror/xbmc.org/build-deps/sources/giflib-5.1.4.tar.gz)
|
||||
set(gif_HASH SHA256=34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1)
|
||||
set(gif_INSTALL ${CMAKE_BINARY_DIR}/gif/install)
|
||||
set(gif_BUILD ${CMAKE_BINARY_DIR}/gif/src/gif)
|
||||
|
@ -26,7 +26,7 @@ from setuptools import find_packages, setup, Command
|
||||
from setuptools.command.install import install as InstallCommandBase
|
||||
from setuptools.dist import Distribution
|
||||
|
||||
_VERSION = '0.11.0rc1-cmake-experimental'
|
||||
_VERSION = '0.11.0rc2-cmake-experimental'
|
||||
|
||||
REQUIRED_PACKAGES = [
|
||||
'numpy >= 1.11.0',
|
||||
|
@ -57,7 +57,6 @@ initialized with parameters that define the distributions.
|
||||
@@MultivariateNormalCholesky
|
||||
@@MultivariateNormalDiagPlusVDVT
|
||||
@@MultivariateNormalDiagWithSoftplusStDev
|
||||
@@matrix_diag_transform
|
||||
|
||||
### Other multivariate distributions
|
||||
|
||||
@ -67,6 +66,10 @@ initialized with parameters that define the distributions.
|
||||
@@WishartCholesky
|
||||
@@WishartFull
|
||||
|
||||
### Multivariate Utilities
|
||||
|
||||
@@matrix_diag_transform
|
||||
|
||||
## Transformed distributions
|
||||
|
||||
@@TransformedDistribution
|
||||
@ -86,7 +89,7 @@ representing the posterior or posterior predictive.
|
||||
@@normal_conjugates_known_sigma_posterior
|
||||
@@normal_conjugates_known_sigma_predictive
|
||||
|
||||
## Kullback Leibler Divergence
|
||||
## Kullback-Leibler Divergence
|
||||
|
||||
@@kl
|
||||
@@RegisterKL
|
||||
|
@ -25,7 +25,7 @@ import tensorflow as tf
|
||||
from tensorflow.contrib.distributions.python.ops import distribution_util
|
||||
|
||||
|
||||
class DistributionUtilTest(tf.test.TestCase):
|
||||
class AssertCloseTest(tf.test.TestCase):
|
||||
|
||||
def testAssertCloseIntegerDtype(self):
|
||||
x = [1, 5, 10, 15, 20]
|
||||
@ -110,6 +110,9 @@ class DistributionUtilTest(tf.test.TestCase):
|
||||
distribution_util.assert_integer_form(w)]):
|
||||
tf.identity(w).eval()
|
||||
|
||||
|
||||
class GetLogitsAndProbTest(tf.test.TestCase):
|
||||
|
||||
def testGetLogitsAndProbImproperArguments(self):
|
||||
with self.test_session():
|
||||
with self.assertRaises(ValueError):
|
||||
@ -229,6 +232,9 @@ class DistributionUtilTest(tf.test.TestCase):
|
||||
p=p4, multidimensional=True, validate_args=False)
|
||||
prob.eval()
|
||||
|
||||
|
||||
class LogCombinationsTest(tf.test.TestCase):
|
||||
|
||||
def testLogCombinationsBinomial(self):
|
||||
n = [2, 5, 12, 15]
|
||||
k = [1, 2, 4, 11]
|
||||
@ -252,6 +258,9 @@ class DistributionUtilTest(tf.test.TestCase):
|
||||
log_binom = distribution_util.log_combinations(n, counts)
|
||||
self.assertEqual([2, 2], log_binom.get_shape())
|
||||
|
||||
|
||||
class RotateTransposeTest(tf.test.TestCase):
|
||||
|
||||
def _np_rotate_transpose(self, x, shift):
|
||||
if not isinstance(x, np.ndarray):
|
||||
x = np.array(x)
|
||||
@ -283,7 +292,10 @@ class DistributionUtilTest(tf.test.TestCase):
|
||||
sess.run(distribution_util.rotate_transpose(x, shift),
|
||||
feed_dict={x: x_value, shift: shift_value}))
|
||||
|
||||
def testChooseVector(self):
|
||||
|
||||
class PickVectorTest(tf.test.TestCase):
|
||||
|
||||
def testCorrectlyPicksVector(self):
|
||||
with self.test_session():
|
||||
x = np.arange(10, 12)
|
||||
y = np.arange(15, 18)
|
||||
@ -301,5 +313,51 @@ class DistributionUtilTest(tf.test.TestCase):
|
||||
tf.constant(False), x, y)) # No eval.
|
||||
|
||||
|
||||
class FillLowerTriangularTest(tf.test.TestCase):
|
||||
|
||||
def testCorrectlyMakes1x1LowerTril(self):
|
||||
with self.test_session():
|
||||
x = np.array([[1.], [2], [3]])
|
||||
expected = np.array([[[1.]], [[2]], [[3]]])
|
||||
actual = distribution_util.fill_lower_triangular(x)
|
||||
self.assertAllEqual(expected.shape, actual.get_shape())
|
||||
self.assertAllEqual(expected, actual.eval())
|
||||
|
||||
def testCorrectlyMakesNoBatchLowerTril(self):
|
||||
with self.test_session():
|
||||
x = tf.convert_to_tensor(np.arange(9, dtype=np.float32))
|
||||
expected = np.array(
|
||||
[[0., 0., 0.],
|
||||
[1., 2., 0.],
|
||||
[3., 4., 5.]])
|
||||
actual = distribution_util.fill_lower_triangular(x)
|
||||
self.assertAllEqual(expected.shape, actual.get_shape())
|
||||
self.assertAllEqual(expected, actual.eval())
|
||||
self.assertAllEqual(
|
||||
np.concatenate([np.ones(6, dtype=np.float32),
|
||||
np.zeros(3, dtype=np.float32)]),
|
||||
tf.gradients(distribution_util.fill_lower_triangular(x), x)[0].eval())
|
||||
|
||||
def testCorrectlyMakesBatchLowerTril(self):
|
||||
with self.test_session():
|
||||
x = np.reshape(np.arange(24), (2, 2, 6))
|
||||
expected = np.array(
|
||||
[[[[0., 0., 0.],
|
||||
[1., 2., 0.],
|
||||
[3., 4., 5.]],
|
||||
[[6., 0., 0.],
|
||||
[7., 8., 0.],
|
||||
[9., 10., 11.]]],
|
||||
[[[12., 0., 0.],
|
||||
[13., 14., 0.],
|
||||
[15., 16., 17.]],
|
||||
[[18., 0., 0.],
|
||||
[19., 20., 0.],
|
||||
[21., 22., 23.]]]])
|
||||
actual = distribution_util.fill_lower_triangular(x)
|
||||
self.assertAllEqual(expected.shape, actual.get_shape())
|
||||
self.assertAllEqual(expected, actual.eval())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
|
@ -20,11 +20,13 @@ from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
@ -376,7 +378,7 @@ def pick_vector(cond,
|
||||
TypeError: if `cond` is not a constant and
|
||||
`true_vector.dtype != false_vector.dtype`
|
||||
"""
|
||||
with ops.op_scope((cond, true_vector, false_vector), name):
|
||||
with ops.name_scope(name, values=(cond, true_vector, false_vector)):
|
||||
cond = ops.convert_to_tensor(cond, name="cond")
|
||||
if cond.dtype != dtypes.bool:
|
||||
raise TypeError("%s.dtype=%s which is not %s" %
|
||||
@ -405,6 +407,105 @@ def gen_new_seed(seed, salt):
|
||||
return None
|
||||
|
||||
|
||||
def fill_lower_triangular(x, name="fill_lower_triangular"):
|
||||
"""Creates a (batch of) lower triangular matrix from a vector of inputs.
|
||||
|
||||
If `x.get_shape()` is `[b1, b2, ..., bK, d]` then the output shape is `[b1,
|
||||
b2, ..., bK, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
|
||||
`n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))`.
|
||||
|
||||
Note: This function is very slow; possibly 10x slower than zero-ing out the
|
||||
upper-triangular portion of a full matrix.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
fill_lower_triangular([1, 2, 3, 4, 5, 6])
|
||||
# Returns: [[1, 0, 0],
|
||||
# [2, 3, 0],
|
||||
# [4, 5, 6]]
|
||||
```
|
||||
|
||||
Args:
|
||||
x: `Tensor` representing lower triangular elements.
|
||||
name: `String`. The name to give this op.
|
||||
|
||||
Returns:
|
||||
tril: `Tensor` with lower triangular elements filled from `x`.
|
||||
"""
|
||||
with ops.name_scope(name, values=(x,)):
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
if (x.get_shape().ndims is not None and
|
||||
x.get_shape()[-1].value is not None):
|
||||
d = x.get_shape()[-1].value
|
||||
# d = n^2/2 + n/2 implies n is:
|
||||
n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
|
||||
final_shape = x.get_shape()[:-1].concatenate(
|
||||
tensor_shape.TensorShape([n, n]))
|
||||
else:
|
||||
d = math_ops.cast(array_ops.shape(x)[-1], dtype=dtypes.float32)
|
||||
# d = n^2/2 + n/2 implies n is:
|
||||
n = math_ops.cast(0.5 * (dtypes.sqrt(1. + 8. * d) - 1.),
|
||||
dtype=dtypes.int32)
|
||||
final_shape = x.get_shape()[:-1].concatenate(
|
||||
tensor_shape.TensorShape([None, None]))
|
||||
|
||||
# Make ids for each batch dim.
|
||||
if (x.get_shape().ndims is not None and
|
||||
x.get_shape()[:-1].is_fully_defined()):
|
||||
batch_shape = np.asarray(x.get_shape()[:-1].as_list(), dtype=np.int32)
|
||||
m = np.prod(batch_shape)
|
||||
else:
|
||||
batch_shape = array_ops.shape(x)[:-1]
|
||||
m = array_ops.reduce_prod(batch_shape)
|
||||
|
||||
# Flatten batch dims.
|
||||
y = array_ops.reshape(x, [-1, d])
|
||||
|
||||
# Prepend a zero to each row.
|
||||
y = array_ops.pad(y, paddings=[[0, 0], [1, 0]])
|
||||
|
||||
# Make ids for each batch dim.
|
||||
if x.get_shape()[:-1].is_fully_defined():
|
||||
m = np.asarray(np.prod(x.get_shape()[:-1].as_list()), dtype=np.int32)
|
||||
else:
|
||||
m = array_ops.reduce_prod(array_ops.shape(x)[:-1])
|
||||
batch_ids = math_ops.range(m)
|
||||
|
||||
def make_tril_ids(n):
|
||||
"""Internal helper to create vector of linear indices into y."""
|
||||
cols = array_ops.reshape(array_ops.tile(math_ops.range(n), [n]), [n, n])
|
||||
rows = array_ops.tile(
|
||||
array_ops.expand_dims(math_ops.range(n), -1), [1, n])
|
||||
pred = math_ops.greater(cols, rows)
|
||||
tril_ids = array_ops.tile(array_ops.reshape(
|
||||
math_ops.cumsum(math_ops.range(n)), [n, 1]), [1, n]) + cols
|
||||
tril_ids = math_ops.select(pred,
|
||||
array_ops.zeros([n, n], dtype=dtypes.int32),
|
||||
tril_ids + 1)
|
||||
tril_ids = array_ops.reshape(tril_ids, [-1])
|
||||
return tril_ids
|
||||
tril_ids = make_tril_ids(n)
|
||||
|
||||
# Assemble the ids into pairs.
|
||||
idx = array_ops.pack([
|
||||
array_ops.tile(array_ops.expand_dims(batch_ids, -1), [1, n*n]),
|
||||
array_ops.tile([tril_ids], [m, 1])])
|
||||
idx = array_ops.transpose(idx, [1, 2, 0])
|
||||
|
||||
if x.get_shape().ndims == 1:
|
||||
# Prefer using gather because it has a gradient.
|
||||
# We wrap the result in a list so downstream logic "just works."
|
||||
y = [array_ops.gather(y[0, :], tril_ids)]
|
||||
else:
|
||||
y = array_ops.gather_nd(y, idx)
|
||||
y = array_ops.reshape(y, array_ops.concat(0, [batch_shape, [n, n]]))
|
||||
|
||||
y.set_shape(y.get_shape().merge_with(final_shape))
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class AppendDocstring(object):
|
||||
"""Helper class to promote private subclass docstring to public counterpart.
|
||||
|
||||
|
@ -571,9 +571,8 @@ class WALSModel(object):
|
||||
extras = size % num_shards
|
||||
assignments = tf.maximum(ids // (ids_per_shard + 1),
|
||||
(ids - extras) // ids_per_shard)
|
||||
new_ids = tf.select(assignments < extras,
|
||||
ids % (ids_per_shard + 1),
|
||||
(ids - extras) % ids_per_shard)
|
||||
new_ids = tf.where(assignments < extras, ids % (ids_per_shard + 1),
|
||||
(ids - extras) % ids_per_shard)
|
||||
return assignments, new_ids
|
||||
return func
|
||||
|
||||
@ -655,7 +654,7 @@ class WALSModel(object):
|
||||
update_op: An op that assigns the newly computed values to the row/column
|
||||
factors.
|
||||
"""
|
||||
assert isinstance(sp_input, ops.SparseTensor)
|
||||
assert isinstance(sp_input, tf.SparseTensor)
|
||||
|
||||
if update_row_factors:
|
||||
left = self._row_factors
|
||||
|
@ -18,8 +18,6 @@ py_library(
|
||||
"__init__.py",
|
||||
"python/framework/__init__.py",
|
||||
"python/framework/checkpoint_utils.py",
|
||||
"python/framework/decorator_utils.py",
|
||||
"python/framework/deprecation.py",
|
||||
"python/framework/experimental.py",
|
||||
"python/framework/tensor_util.py",
|
||||
"python/ops/__init__.py",
|
||||
@ -102,20 +100,6 @@ py_test(
|
||||
deps = ["//tensorflow:tensorflow_py"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "deprecation_test",
|
||||
srcs = ["python/framework/deprecation_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = ["//tensorflow:tensorflow_py"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "decorator_utils_test",
|
||||
srcs = ["python/framework/decorator_utils_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = ["//tensorflow:tensorflow_py"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "experimental_test",
|
||||
srcs = ["python/framework/experimental_test.py"],
|
||||
@ -135,6 +119,7 @@ py_test(
|
||||
size = "small",
|
||||
srcs = ["python/ops/variables_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["manual"],
|
||||
deps = ["//tensorflow:tensorflow_py"],
|
||||
)
|
||||
|
||||
|
@ -19,10 +19,10 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.contrib.framework.python.framework import decorator_utils
|
||||
from tensorflow.contrib.framework.python.framework.checkpoint_utils import *
|
||||
from tensorflow.contrib.framework.python.framework.deprecation import deprecated
|
||||
from tensorflow.contrib.framework.python.framework.deprecation import deprecated_arg_values
|
||||
from tensorflow.contrib.framework.python.framework.deprecation import deprecated_args
|
||||
from tensorflow.contrib.framework.python.framework.experimental import experimental
|
||||
from tensorflow.contrib.framework.python.framework.tensor_util import *
|
||||
from tensorflow.python.util import decorator_utils
|
||||
from tensorflow.python.util.deprecation import deprecated
|
||||
from tensorflow.python.util.deprecation import deprecated_arg_values
|
||||
from tensorflow.python.util.deprecation import deprecated_args
|
||||
|
@ -20,8 +20,8 @@ from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
from tensorflow.contrib.framework.python.framework import decorator_utils
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import decorator_utils
|
||||
|
||||
|
||||
def _add_experimental_function_notice_to_docstring(doc):
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -283,7 +284,7 @@ def is_tensor(x):
|
||||
Returns:
|
||||
`True` if `x` is a tensor, `False` if not.
|
||||
"""
|
||||
tensor_types = (ops.Tensor, ops.SparseTensor, variables.Variable)
|
||||
tensor_types = (ops.Tensor, sparse_tensor.SparseTensor, variables.Variable)
|
||||
return isinstance(x, tensor_types)
|
||||
|
||||
|
||||
@ -303,7 +304,7 @@ def with_shape(expected_shape, tensor):
|
||||
Raises:
|
||||
ValueError: if tensor has an invalid shape.
|
||||
"""
|
||||
if isinstance(tensor, ops.SparseTensor):
|
||||
if isinstance(tensor, sparse_tensor.SparseTensor):
|
||||
raise ValueError('SparseTensor not supported.')
|
||||
|
||||
# Shape type must be 1D int32.
|
||||
@ -376,9 +377,9 @@ def convert_to_tensor_or_sparse_tensor(
|
||||
"""
|
||||
if dtype is not None:
|
||||
dtype = dtypes.as_dtype(dtype)
|
||||
if isinstance(value, ops.SparseTensorValue):
|
||||
value = ops.SparseTensor.from_value(value)
|
||||
if isinstance(value, ops.SparseTensor):
|
||||
if isinstance(value, sparse_tensor.SparseTensorValue):
|
||||
value = sparse_tensor.SparseTensor.from_value(value)
|
||||
if isinstance(value, sparse_tensor.SparseTensor):
|
||||
if dtype and not dtype.is_compatible_with(value.dtype):
|
||||
raise RuntimeError(
|
||||
'Sparse dtype: requested = %s, actual = %s' % (
|
||||
|
@ -22,6 +22,7 @@ from __future__ import print_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import logging_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -43,7 +44,7 @@ def _get_tensor_repr(t,
|
||||
if print_tensor_type:
|
||||
if isinstance(t, ops.Tensor):
|
||||
t_type_str = "Type: Tensor ({})".format(t.dtype.name)
|
||||
elif isinstance(t, ops.SparseTensor):
|
||||
elif isinstance(t, sparse_tensor.SparseTensor):
|
||||
t_type_str = "Type: SparseTensor ({})".format(t.dtype.name)
|
||||
elif isinstance(t, tensor_array_ops.TensorArray):
|
||||
t_type_str = "Type: TensorArray ({})".format(t.dtype.name)
|
||||
@ -51,7 +52,7 @@ def _get_tensor_repr(t,
|
||||
tensor_list.append(constant_op.constant(t_type_str))
|
||||
|
||||
if print_shape:
|
||||
if isinstance(t, ops.SparseTensor):
|
||||
if isinstance(t, sparse_tensor.SparseTensor):
|
||||
tensor_list.append(constant_op.constant("Shape:"))
|
||||
tensor_list.append(t.shape)
|
||||
elif isinstance(t, ops.Tensor):
|
||||
@ -66,7 +67,7 @@ def _get_tensor_repr(t,
|
||||
tensor_list.append(constant_op.constant("First True in Boolean tensor at:"))
|
||||
tensor_list.append(math_ops.argmax(int_tensor, 0))
|
||||
|
||||
if isinstance(t, ops.SparseTensor):
|
||||
if isinstance(t, sparse_tensor.SparseTensor):
|
||||
tensor_list.append(constant_op.constant("Sparse indices:"))
|
||||
tensor_list.append(t.indices)
|
||||
tensor_list.append(constant_op.constant("Sparse values:"))
|
||||
@ -137,15 +138,15 @@ def print_op(input_,
|
||||
if isinstance(input_, ops.Tensor):
|
||||
input_ = logging_ops.Print(input_, tensor_list, message, first_n, summarize,
|
||||
name)
|
||||
elif isinstance(input_, ops.SparseTensor):
|
||||
elif isinstance(input_, sparse_tensor.SparseTensor):
|
||||
p = logging_ops.Print(
|
||||
constant_op.constant([]), tensor_list, message, first_n, summarize,
|
||||
name)
|
||||
|
||||
with ops.control_dependencies([p]):
|
||||
input_ = ops.SparseTensor(array_ops.identity(input_.indices),
|
||||
array_ops.identity(input_.values),
|
||||
array_ops.identity(input_.shape))
|
||||
input_ = sparse_tensor.SparseTensor(array_ops.identity(input_.indices),
|
||||
array_ops.identity(input_.values),
|
||||
array_ops.identity(input_.shape))
|
||||
elif isinstance(input_, tensor_array_ops.TensorArray):
|
||||
p = logging_ops.Print(
|
||||
constant_op.constant([]), tensor_list, message, first_n, summarize,
|
||||
|
@ -36,7 +36,7 @@ class LocalVariableTest(tf.test.TestCase):
|
||||
variables = tf.local_variables()
|
||||
self.assertEquals(2, len(variables))
|
||||
self.assertRaises(tf.OpError, sess.run, variables)
|
||||
tf.initialize_variables(variables).run()
|
||||
tf.variables_initializer(variables).run()
|
||||
self.assertAllEqual(set([value0, value1]), set(sess.run(variables)))
|
||||
|
||||
def testLocalVariableNameAndShape(self):
|
||||
@ -51,7 +51,7 @@ class LocalVariableTest(tf.test.TestCase):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.local_variable(0)
|
||||
self.assertFalse(a in tf.all_variables())
|
||||
self.assertFalse(a in tf.global_variables())
|
||||
self.assertTrue(a in tf.local_variables())
|
||||
|
||||
def testLocalVariableNotInVariablesToRestore(self):
|
||||
@ -82,7 +82,7 @@ class LocalVariableTest(tf.test.TestCase):
|
||||
def testInitializedVariableValue(self):
|
||||
with self.test_session() as sess:
|
||||
a = tf.contrib.framework.local_variable([0, 0, 0, 0, 0], name='a')
|
||||
sess.run(tf.initialize_local_variables())
|
||||
sess.run(tf.local_variables_initializer())
|
||||
self.assertAllEqual(a.eval(), [0]*5)
|
||||
|
||||
|
||||
@ -439,7 +439,7 @@ class ModelVariablesTest(tf.test.TestCase):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.model_variable('a', [5])
|
||||
self.assertTrue(a in tf.all_variables())
|
||||
self.assertTrue(a in tf.global_variables())
|
||||
self.assertTrue(a in tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
|
||||
self.assertFalse(a in tf.local_variables())
|
||||
|
||||
@ -474,7 +474,7 @@ class ModelVariablesTest(tf.test.TestCase):
|
||||
with self.test_session() as sess:
|
||||
a = tf.contrib.framework.model_variable(
|
||||
'a', [5], initializer=tf.ones_initializer)
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
self.assertAllEqual(a.eval(), [1]*5)
|
||||
|
||||
def testDeviceFn(self):
|
||||
@ -667,7 +667,7 @@ class AssignFromValuesTest(tf.test.TestCase):
|
||||
var_names_to_values)
|
||||
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Perform the assignment.
|
||||
sess.run(assign_op, feed_dict)
|
||||
@ -697,7 +697,7 @@ class AssignFromValuesTest(tf.test.TestCase):
|
||||
var_names_to_values)
|
||||
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Perform the assignment.
|
||||
sess.run(assign_op, feed_dict)
|
||||
@ -725,7 +725,7 @@ class AssignFromValuesFnTest(tf.test.TestCase):
|
||||
init_fn = tf.contrib.framework.assign_from_values_fn(var_names_to_values)
|
||||
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Perform the assignment.
|
||||
init_fn(sess)
|
||||
@ -754,7 +754,7 @@ class AssignFromValuesFnTest(tf.test.TestCase):
|
||||
init_fn = tf.contrib.framework.assign_from_values_fn(var_names_to_values)
|
||||
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Perform the assignment.
|
||||
init_fn(sess)
|
||||
@ -786,7 +786,7 @@ class AssignFromCheckpointTest(tf.test.TestCase):
|
||||
var_value = var_names_to_values[var_name]
|
||||
var_list.append(tf.Variable(var_value, name=var_name))
|
||||
saver = tf.train.Saver(var_list)
|
||||
init_op = tf.initialize_variables(var_list)
|
||||
init_op = tf.variables_initializer(var_list)
|
||||
sess.run(init_op)
|
||||
# Save the initialized values in the file at 'checkpoint_dir'
|
||||
return saver.save(sess, checkpoint_dir, global_step=global_step)
|
||||
@ -808,7 +808,7 @@ class AssignFromCheckpointTest(tf.test.TestCase):
|
||||
model_path, vars_to_restore)
|
||||
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Perform the assignment.
|
||||
sess.run(op, feed_dict)
|
||||
@ -859,7 +859,7 @@ class AssignFromCheckpointTest(tf.test.TestCase):
|
||||
vars_to_restore)
|
||||
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Perform the assignment.
|
||||
sess.run(op, feed_dict)
|
||||
@ -890,7 +890,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
|
||||
var_value = var_names_to_values[var_name]
|
||||
var_list.append(tf.Variable(var_value, name=var_name))
|
||||
saver = tf.train.Saver(var_list)
|
||||
init_op = tf.initialize_variables(var_list)
|
||||
init_op = tf.variables_initializer(var_list)
|
||||
sess.run(init_op)
|
||||
# Save the initialized values in the file at 'checkpoint_dir'
|
||||
return saver.save(sess, checkpoint_dir, global_step=global_step)
|
||||
@ -912,7 +912,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
|
||||
model_path, vars_to_restore)
|
||||
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Perform the assignment.
|
||||
init_fn(sess)
|
||||
@ -938,7 +938,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
|
||||
model_path, vars_to_restore)
|
||||
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Perform the assignment.
|
||||
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||
@ -961,7 +961,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
|
||||
model_path, vars_to_restore, reshape_variables=True)
|
||||
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Perform the assignment.
|
||||
init_fn(sess)
|
||||
@ -989,7 +989,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
|
||||
vars_to_restore)
|
||||
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Perform the assignment.
|
||||
with self.assertRaises(tf.errors.NotFoundError):
|
||||
@ -1015,7 +1015,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
|
||||
ignore_missing_vars=True)
|
||||
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Perform the assignment.
|
||||
init_fn(sess)
|
||||
@ -1044,7 +1044,7 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
|
||||
ignore_missing_vars=True)
|
||||
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Perform the assignment.
|
||||
init_fn(sess)
|
||||
|
38
tensorflow/contrib/integrate/BUILD
Normal file
38
tensorflow/contrib/integrate/BUILD
Normal file
@ -0,0 +1,38 @@
|
||||
# Description:
|
||||
# Integration and ODE solvers for TensorFlow.
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
package(default_visibility = ["//tensorflow:__subpackages__"])
|
||||
|
||||
py_library(
|
||||
name = "integrate_py",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"python/ops/odes.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "odes_test",
|
||||
srcs = ["python/ops/odes_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":integrate_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
)
|
9
tensorflow/contrib/integrate/README.md
Normal file
9
tensorflow/contrib/integrate/README.md
Normal file
@ -0,0 +1,9 @@
|
||||
# Integration and ODE solvers for TensorFlow
|
||||
|
||||
TensorFlow equivalents to the routines provided by `scipy.integrate`. Currently
|
||||
contains a single function, `odeint`, for integrating ordinary differential
|
||||
equations.
|
||||
|
||||
Maintainers:
|
||||
- Stephan Hoyer (shoyer@google.com, github.com/shoyer)
|
||||
- Marc Coram (mcoram@google.com, github.com/mcoram)
|
64
tensorflow/contrib/integrate/__init__.py
Normal file
64
tensorflow/contrib/integrate/__init__.py
Normal file
@ -0,0 +1,64 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Integration and ODE solvers for TensorFlow.
|
||||
|
||||
## Example: Lorenz attractor
|
||||
|
||||
We can use `odeint` to solve the
|
||||
[Lorentz system](https://en.wikipedia.org/wiki/Lorenz_system) of ordinary
|
||||
differential equations, a prototypical example of chaotic dynamics:
|
||||
|
||||
```python
|
||||
rho = 28.0
|
||||
sigma = 10.0
|
||||
beta = 8.0/3.0
|
||||
|
||||
def lorenz_equation(state, t):
|
||||
x, y, z = tf.unpack(state)
|
||||
dx = sigma * (y - x)
|
||||
dy = x * (rho - z) - y
|
||||
dz = x * y - beta * z
|
||||
return tf.pack([dx, dy, dz])
|
||||
|
||||
init_state = tf.constant([0, 2, 20], dtype=tf.float64)
|
||||
t = np.linspace(0, 50, num=5000)
|
||||
tensor_state, tensor_info = tf.contrib.integrate.odeint(
|
||||
lorenz_equation, init_state, t, full_output=True)
|
||||
|
||||
sess = tf.Session()
|
||||
state, info = sess.run([tensor_state, tensor_info])
|
||||
x, y, z = state.T
|
||||
plt.plot(x, z)
|
||||
```
|
||||
|
||||
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||
<img style="width:100%" src="../../images/lorenz_attractor.png" alt>
|
||||
</div>
|
||||
|
||||
## Ops
|
||||
|
||||
@@odeint
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.contrib.integrate.python.ops.odes import *
|
||||
from tensorflow.python.util.all_util import make_all
|
||||
|
||||
__all__ = make_all(__name__)
|
503
tensorflow/contrib/integrate/python/ops/odes.py
Normal file
503
tensorflow/contrib/integrate/python/ops/odes.py
Normal file
@ -0,0 +1,503 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""ODE solvers for TensorFlow."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
|
||||
|
||||
_ButcherTableau = collections.namedtuple(
|
||||
'_ButcherTableau', 'alpha beta c_sol c_mid c_error')
|
||||
|
||||
# Parameters from Shampine (1986), section 4.
|
||||
_DORMAND_PRINCE_TABLEAU = _ButcherTableau(
|
||||
alpha=[1/5, 3/10, 4/5, 8/9, 1., 1.],
|
||||
beta=[[1/5],
|
||||
[3/40, 9/40],
|
||||
[44/45, -56/15, 32/9],
|
||||
[19372/6561, -25360/2187, 64448/6561, -212/729],
|
||||
[9017/3168, -355/33, 46732/5247, 49/176, -5103/18656],
|
||||
[35/384, 0, 500/1113, 125/192, -2187/6784, 11/84]],
|
||||
c_sol=[35/384, 0, 500/1113, 125/192, -2187/6784, 11/84, 0],
|
||||
c_mid=[6025192743/30085553152 / 2, 0, 51252292925/65400821598 / 2,
|
||||
-2691868925/45128329728 / 2, 187940372067/1594534317056 / 2,
|
||||
-1776094331/19743644256 / 2, 11237099/235043384 / 2],
|
||||
c_error=[1951/21600 - 35/384,
|
||||
0,
|
||||
22642/50085 - 500/1113,
|
||||
451/720 - 125/192,
|
||||
-12231/42400 - -2187/6784,
|
||||
649/6300 - 11/84,
|
||||
1/60],
|
||||
)
|
||||
|
||||
|
||||
def _possibly_nonzero(x):
|
||||
return isinstance(x, ops.Tensor) or x != 0
|
||||
|
||||
|
||||
def _scaled_dot_product(scale, xs, ys, name=None):
|
||||
"""Calculate a scaled, vector inner product between lists of Tensors."""
|
||||
with ops.name_scope(name, 'scaled_dot_product', [scale, xs, ys]) as scope:
|
||||
# Some of the parameters in our Butcher tableau include zeros. Using
|
||||
# _possibly_nonzero lets us avoid wasted computation.
|
||||
return math_ops.add_n([(scale * x) * y for x, y in zip(xs, ys)
|
||||
if _possibly_nonzero(x) or _possibly_nonzero(y)],
|
||||
name=scope)
|
||||
|
||||
|
||||
def _dot_product(xs, ys, name=None):
|
||||
"""Calculate the vector inner product between two lists of Tensors."""
|
||||
with ops.name_scope(name, 'dot_product', [xs, ys]) as scope:
|
||||
return math_ops.add_n([x * y for x, y in zip(xs, ys)], name=scope)
|
||||
|
||||
|
||||
def _runge_kutta_step(func, y0, f0, t0, dt, tableau=_DORMAND_PRINCE_TABLEAU,
|
||||
name=None):
|
||||
"""Take an arbitrary Runge-Kutta step and estimate error.
|
||||
|
||||
Args:
|
||||
func: Function to evaluate like `func(y, t)` to compute the time derivative
|
||||
of `y`.
|
||||
y0: Tensor initial value for the state.
|
||||
f0: Tensor initial value for the derivative, computed from `func(y0, t0)`.
|
||||
t0: float64 scalar Tensor giving the initial time.
|
||||
dt: float64 scalar Tensor giving the size of the desired time step.
|
||||
tableau: optional _ButcherTableau describing how to take the Runge-Kutta
|
||||
step.
|
||||
name: optional name for the operation.
|
||||
|
||||
Returns:
|
||||
Tuple `(y1, f1, y1_error, k)` giving the estimated function value after
|
||||
the Runge-Kutta step at `t1 = t0 + dt`, the derivative of the state at `t1`,
|
||||
estimated error at `t1`, and a list of Runge-Kutta coefficients `k` used for
|
||||
calculating these terms.
|
||||
"""
|
||||
with ops.name_scope(name, 'runge_kutta_step', [y0, f0, t0, dt]) as scope:
|
||||
y0 = ops.convert_to_tensor(y0, name='y0')
|
||||
f0 = ops.convert_to_tensor(f0, name='f0')
|
||||
t0 = ops.convert_to_tensor(t0, name='t0')
|
||||
dt = ops.convert_to_tensor(dt, name='dt')
|
||||
dt_cast = math_ops.cast(dt, y0.dtype)
|
||||
|
||||
k = [f0]
|
||||
for alpha_i, beta_i in zip(tableau.alpha, tableau.beta):
|
||||
ti = t0 + alpha_i * dt
|
||||
yi = y0 + _scaled_dot_product(dt_cast, beta_i, k)
|
||||
k.append(func(yi, ti))
|
||||
|
||||
if not (tableau.c_sol[-1] == 0 and tableau.c_sol == tableau.beta[-1]):
|
||||
# This property (true for Dormand-Prince) lets us save a few FLOPs.
|
||||
yi = y0 + _scaled_dot_product(dt_cast, tableau.c_sol, k)
|
||||
|
||||
y1 = array_ops.identity(yi, name='%s/y1' % scope)
|
||||
f1 = array_ops.identity(k[-1], name='%s/f1' % scope)
|
||||
y1_error = _scaled_dot_product(dt_cast, tableau.c_error, k,
|
||||
name='%s/y1_error' % scope)
|
||||
return (y1, f1, y1_error, k)
|
||||
|
||||
|
||||
def _interp_fit(y0, y1, y_mid, f0, f1, dt):
|
||||
"""Fit coefficients for 4th order polynomial interpolation.
|
||||
|
||||
Args:
|
||||
y0: function value at the start of the interval.
|
||||
y1: function value at the end of the interval.
|
||||
y_mid: function value at the mid-point of the interval.
|
||||
f0: derivative value at the start of the interval.
|
||||
f1: derivative value at the end of the interval.
|
||||
dt: width of the interval.
|
||||
|
||||
Returns:
|
||||
List of coefficients `[a, b, c, d, e]` for interpolating with the polynomial
|
||||
`p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e` for values of `x`
|
||||
between 0 (start of interval) and 1 (end of interval).
|
||||
"""
|
||||
# a, b, c, d, e = sympy.symbols('a b c d e')
|
||||
# x, dt, y0, y1, y_mid, f0, f1 = sympy.symbols('x dt y0 y1 y_mid f0 f1')
|
||||
# p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e
|
||||
# sympy.solve([p.subs(x, 0) - y0,
|
||||
# p.subs(x, 1 / 2) - y_mid,
|
||||
# p.subs(x, 1) - y1,
|
||||
# (p.diff(x) / dt).subs(x, 0) - f0,
|
||||
# (p.diff(x) / dt).subs(x, 1) - f1],
|
||||
# [a, b, c, d, e])
|
||||
# {a: -2.0*dt*f0 + 2.0*dt*f1 - 8.0*y0 - 8.0*y1 + 16.0*y_mid,
|
||||
# b: 5.0*dt*f0 - 3.0*dt*f1 + 18.0*y0 + 14.0*y1 - 32.0*y_mid,
|
||||
# c: -4.0*dt*f0 + dt*f1 - 11.0*y0 - 5.0*y1 + 16.0*y_mid,
|
||||
# d: dt*f0,
|
||||
# e: y0}
|
||||
a = _dot_product([-2 * dt, 2 * dt, -8, -8, 16], [f0, f1, y0, y1, y_mid])
|
||||
b = _dot_product([5 * dt, -3 * dt, 18, 14, -32], [f0, f1, y0, y1, y_mid])
|
||||
c = _dot_product([-4 * dt, dt, -11, -5, 16], [f0, f1, y0, y1, y_mid])
|
||||
d = dt * f0
|
||||
e = y0
|
||||
return [a, b, c, d, e]
|
||||
|
||||
|
||||
def _interp_fit_rk(y0, y1, k, dt, tableau=_DORMAND_PRINCE_TABLEAU):
|
||||
"""Fit an interpolating polynomial to the results of a Runge-Kutta step."""
|
||||
with ops.name_scope('interp_fit_rk'):
|
||||
dt = math_ops.cast(dt, y0.dtype)
|
||||
y_mid = y0 + _scaled_dot_product(dt, tableau.c_mid, k)
|
||||
f0 = k[0]
|
||||
f1 = k[-1]
|
||||
return _interp_fit(y0, y1, y_mid, f0, f1, dt)
|
||||
|
||||
|
||||
def _interp_evaluate(coefficients, t0, t1, t):
|
||||
"""Evaluate polynomial interpolation at the given time point.
|
||||
|
||||
Args:
|
||||
coefficients: list of Tensor coefficients as created by `interp_fit`.
|
||||
t0: scalar float64 Tensor giving the start of the interval.
|
||||
t1: scalar float64 Tensor giving the end of the interval.
|
||||
t: scalar float64 Tensor giving the desired interpolation point.
|
||||
|
||||
Returns:
|
||||
Polynomial interpolation of the coefficients at time `t`.
|
||||
"""
|
||||
with ops.name_scope('interp_evaluate'):
|
||||
t0 = ops.convert_to_tensor(t0)
|
||||
t1 = ops.convert_to_tensor(t1)
|
||||
t = ops.convert_to_tensor(t)
|
||||
|
||||
dtype = coefficients[0].dtype
|
||||
|
||||
assert_op = control_flow_ops.Assert(
|
||||
(t0 <= t) & (t <= t1),
|
||||
['invalid interpolation, fails `t0 <= t <= t1`:', t0, t, t1])
|
||||
with ops.control_dependencies([assert_op]):
|
||||
x = math_ops.cast((t - t0) / (t1 - t0), dtype)
|
||||
|
||||
xs = [constant_op.constant(1, dtype), x]
|
||||
for _ in range(2, len(coefficients)):
|
||||
xs.append(xs[-1] * x)
|
||||
|
||||
return _dot_product(coefficients, reversed(xs))
|
||||
|
||||
|
||||
def _optimal_step_size(last_step,
|
||||
error_ratio,
|
||||
safety=0.9,
|
||||
ifactor=10.0,
|
||||
dfactor=0.2,
|
||||
order=5,
|
||||
name=None):
|
||||
"""Calculate the optimal size for the next Runge-Kutta step."""
|
||||
with ops.name_scope(
|
||||
name, 'optimal_step_size', [last_step, error_ratio]) as scope:
|
||||
error_ratio = math_ops.cast(error_ratio, last_step.dtype)
|
||||
exponent = math_ops.cast(1 / order, last_step.dtype)
|
||||
# this looks more complex than necessary, but importantly it keeps
|
||||
# error_ratio in the numerator so we can't divide by zero:
|
||||
factor = math_ops.maximum(
|
||||
1 / ifactor,
|
||||
math_ops.minimum(error_ratio ** exponent / safety, 1 / dfactor))
|
||||
return math_ops.div(last_step, factor, name=scope)
|
||||
|
||||
|
||||
def _abs_square(x):
|
||||
if x.dtype.is_complex:
|
||||
return math_ops.square(math_ops.real(x)) + math_ops.square(math_ops.imag(x))
|
||||
else:
|
||||
return math_ops.square(x)
|
||||
|
||||
|
||||
def _ta_append(tensor_array, value):
|
||||
"""Append a value to the end of a tf.TensorArray."""
|
||||
return tensor_array.write(tensor_array.size(), value)
|
||||
|
||||
|
||||
class _RungeKuttaState(collections.namedtuple(
|
||||
'_RungeKuttaState', 'y1, f1, t0, t1, dt, interp_coeff')):
|
||||
"""Saved state of the Runge Kutta solver.
|
||||
|
||||
Attributes:
|
||||
y1: Tensor giving the function value at the end of the last time step.
|
||||
f1: Tensor giving derivative at the end of the last time step.
|
||||
t0: scalar float64 Tensor giving start of the last time step.
|
||||
t1: scalar float64 Tensor giving end of the last time step.
|
||||
dt: scalar float64 Tensor giving the size for the next time step.
|
||||
interp_coef: list of Tensors giving coefficients for polynomial
|
||||
interpolation between `t0` and `t1`.
|
||||
"""
|
||||
|
||||
|
||||
class _History(collections.namedtuple(
|
||||
'_History', 'integrate_points, error_ratio')):
|
||||
"""Saved integration history for use in `info_dict`.
|
||||
|
||||
Attributes:
|
||||
integrate_points: tf.TensorArray storing integrating time points.
|
||||
error_ratio: tf.TensorArray storing computed error ratios at each
|
||||
integration step.
|
||||
"""
|
||||
|
||||
|
||||
def _dopri5(func,
|
||||
y0,
|
||||
t,
|
||||
rtol,
|
||||
atol,
|
||||
full_output=False,
|
||||
first_step=None,
|
||||
safety=0.9,
|
||||
ifactor=10.0,
|
||||
dfactor=0.2,
|
||||
max_num_steps=1000,
|
||||
name=None):
|
||||
"""Solve an ODE for `odeint` using method='dopri5'."""
|
||||
|
||||
if first_step is None:
|
||||
# at some point, we might want to switch to picking the step size
|
||||
# automatically
|
||||
first_step = 1.0
|
||||
|
||||
with ops.name_scope(
|
||||
name, 'dopri5',
|
||||
[y0, t, rtol, atol, safety, ifactor, dfactor, max_num_steps]) as scope:
|
||||
|
||||
first_step = ops.convert_to_tensor(first_step, dtype=t.dtype,
|
||||
name='first_step')
|
||||
safety = ops.convert_to_tensor(safety, dtype=t.dtype, name='safety')
|
||||
ifactor = ops.convert_to_tensor(ifactor, dtype=t.dtype, name='ifactor')
|
||||
dfactor = ops.convert_to_tensor(dfactor, dtype=t.dtype, name='dfactor')
|
||||
max_num_steps = ops.convert_to_tensor(max_num_steps, dtype=dtypes.int32,
|
||||
name='max_num_steps')
|
||||
|
||||
def adaptive_runge_kutta_step(rk_state, history, n_steps):
|
||||
"""Take an adaptive Runge-Kutta step to integrate the ODE."""
|
||||
y0, f0, _, t0, dt, interp_coeff = rk_state
|
||||
with ops.name_scope('assertions'):
|
||||
check_underflow = control_flow_ops.Assert(
|
||||
t0 + dt > t0, ['underflow in dt', dt])
|
||||
check_max_num_steps = control_flow_ops.Assert(
|
||||
n_steps < max_num_steps, ['max_num_steps exceeded'])
|
||||
check_numerics = control_flow_ops.Assert(
|
||||
math_ops.reduce_all(math_ops.is_finite(abs(y0))),
|
||||
['non-finite values in state `y`', y0])
|
||||
with ops.control_dependencies(
|
||||
[check_underflow, check_max_num_steps, check_numerics]):
|
||||
y1, f1, y1_error, k = _runge_kutta_step(func, y0, f0, t0, dt)
|
||||
|
||||
with ops.name_scope('error_ratio'):
|
||||
# We use the same approach as the dopri5 fortran code.
|
||||
error_tol = atol + rtol * math_ops.maximum(abs(y0), abs(y1))
|
||||
tensor_error_ratio = _abs_square(y1_error) / _abs_square(error_tol)
|
||||
# Could also use reduce_maximum here.
|
||||
error_ratio = math_ops.sqrt(math_ops.reduce_mean(tensor_error_ratio))
|
||||
accept_step = error_ratio <= 1
|
||||
|
||||
with ops.name_scope('update/rk_state'):
|
||||
# If we don't accept the step, the _RungeKuttaState will be useless
|
||||
# (covering a time-interval of size 0), but that's OK, because in such
|
||||
# cases we always immediately take another Runge-Kutta step.
|
||||
y_next = control_flow_ops.cond(accept_step, lambda: y1, lambda: y0)
|
||||
f_next = control_flow_ops.cond(accept_step, lambda: f1, lambda: f0)
|
||||
t_next = control_flow_ops.cond(accept_step, lambda: t0 + dt, lambda: t0)
|
||||
interp_coeff = control_flow_ops.cond(
|
||||
accept_step,
|
||||
lambda: _interp_fit_rk(y0, y1, k, dt),
|
||||
lambda: interp_coeff)
|
||||
dt_next = _optimal_step_size(dt, error_ratio, safety, ifactor, dfactor)
|
||||
rk_state = _RungeKuttaState(
|
||||
y_next, f_next, t0, t_next, dt_next, interp_coeff)
|
||||
|
||||
with ops.name_scope('update/history'):
|
||||
history = _History(_ta_append(history.integrate_points, t0 + dt),
|
||||
_ta_append(history.error_ratio, error_ratio))
|
||||
return rk_state, history, n_steps + 1
|
||||
|
||||
def interpolate(solution, history, rk_state, i):
|
||||
"""Interpolate through the next time point, integrating as necessary."""
|
||||
with ops.name_scope('interpolate'):
|
||||
rk_state, history, _ = control_flow_ops.while_loop(
|
||||
lambda rk_state, *_: t[i] > rk_state.t1,
|
||||
adaptive_runge_kutta_step,
|
||||
(rk_state, history, 0),
|
||||
name='integrate_loop')
|
||||
y = _interp_evaluate(
|
||||
rk_state.interp_coeff, rk_state.t0, rk_state.t1, t[i])
|
||||
solution = solution.write(i, y)
|
||||
return solution, history, rk_state, i + 1
|
||||
|
||||
assert_increasing = control_flow_ops.Assert(
|
||||
math_ops.reduce_all(t[1:] > t[:-1]),
|
||||
['`t` must be monotonic increasing'])
|
||||
with ops.control_dependencies([assert_increasing]):
|
||||
num_times = array_ops.size(t)
|
||||
|
||||
solution = tensor_array_ops.TensorArray(
|
||||
y0.dtype, size=num_times).write(0, y0)
|
||||
history = _History(
|
||||
integrate_points=tensor_array_ops.TensorArray(
|
||||
t.dtype, size=0, dynamic_size=True),
|
||||
error_ratio=tensor_array_ops.TensorArray(
|
||||
rtol.dtype, size=0, dynamic_size=True))
|
||||
rk_state = _RungeKuttaState(
|
||||
y0, func(y0, t[0]), t[0], t[0], first_step, interp_coeff=[y0] * 5)
|
||||
|
||||
solution, history, _, _ = control_flow_ops.while_loop(
|
||||
lambda _, __, ___, i: i < num_times,
|
||||
interpolate,
|
||||
(solution, history, rk_state, 1),
|
||||
name='interpolate_loop')
|
||||
|
||||
y = solution.pack(name=scope)
|
||||
y.set_shape(t.get_shape().concatenate(y0.get_shape()))
|
||||
if not full_output:
|
||||
return y
|
||||
else:
|
||||
integrate_points = history.integrate_points.pack()
|
||||
info_dict = {'num_func_evals': 6 * array_ops.size(integrate_points) + 1,
|
||||
'integrate_points': integrate_points,
|
||||
'error_ratio': history.error_ratio.pack()}
|
||||
return (y, info_dict)
|
||||
|
||||
|
||||
def odeint(func,
|
||||
y0,
|
||||
t,
|
||||
rtol=1e-6,
|
||||
atol=1e-12,
|
||||
method=None,
|
||||
options=None,
|
||||
full_output=False,
|
||||
name=None):
|
||||
"""Integrate a system of ordinary differential equations.
|
||||
|
||||
Solves the initial value problem for a non-stiff system of first order ode-s:
|
||||
|
||||
```
|
||||
dy/dt = func(y, t), y(t[0]) = y0
|
||||
```
|
||||
|
||||
where y is a Tensor of any shape.
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
# solve `dy/dt = -y`, corresponding to exponential decay
|
||||
tf.contrib.integrate.odeint(lambda y, _: -y, 1.0, [0, 1, 2])
|
||||
=> [1, exp(-1), exp(-2)]
|
||||
```
|
||||
|
||||
Output dtypes and numerical precision are based on the dtypes of the inputs
|
||||
`y0` and `t`.
|
||||
|
||||
Currently, implements 5th order Runge-Kutta with adaptive step size control
|
||||
and dense output, using the Dormand-Prince method. Similar to the 'dopri5'
|
||||
method of `scipy.integrate.ode` and MATLAB's `ode45`.
|
||||
|
||||
Based on: Shampine, Lawrence F. (1986), "Some Practical Runge-Kutta Formulas",
|
||||
Mathematics of Computation, American Mathematical Society, 46 (173): 135-150,
|
||||
doi:10.2307/2008219
|
||||
|
||||
Args:
|
||||
func: Function that maps a Tensor holding the state `y` and a scalar Tensor
|
||||
`t` into a Tensor of state derivatives with respect to time.
|
||||
y0: N-D Tensor giving starting value of `y` at time point `t[0]`. May
|
||||
have any floating point or complex dtype.
|
||||
t: 1-D Tensor holding a sequence of time points for which to solve for
|
||||
`y`. The initial time point should be the first element of this sequence,
|
||||
and each time must be larger than the previous time. May have any floating
|
||||
point dtype. If not provided as a Tensor, converted to a Tensor with
|
||||
float64 dtype.
|
||||
rtol: optional float64 Tensor specifying an upper bound on relative error,
|
||||
per element of `y`.
|
||||
atol: optional float64 Tensor specifying an upper bound on absolute error,
|
||||
per element of `y`.
|
||||
method: optional string indicating the integration method to use. Currently,
|
||||
the only valid option is `'dopri5'`.
|
||||
options: optional dict of configuring options for the indicated integration
|
||||
method. Can only be provided if a `method` is explicitly set. For
|
||||
`'dopri5'`, valid options include:
|
||||
* first_step: an initial guess for the size of the first integration
|
||||
(current default: 1.0, but may later be changed to use heuristics based
|
||||
on the gradient).
|
||||
* safety: safety factor for adaptive step control, generally a constant
|
||||
in the range 0.8-1 (default: 0.9).
|
||||
* ifactor: maximum factor by which the adaptive step may be increased
|
||||
(default: 10.0).
|
||||
* dfactor: maximum factor by which the adpative step may be decreased
|
||||
(default: 0.2).
|
||||
* max_num_steps: integer maximum number of integrate steps between time
|
||||
points in `t` (default: 1000).
|
||||
full_output: optional boolean. If True, `odeint` returns a tuple
|
||||
`(y, info_dict)` describing the integration process.
|
||||
name: Optional name for this operation.
|
||||
|
||||
Returns:
|
||||
y: (N+1)-D tensor, where the first dimension corresponds to different
|
||||
time points. Contains the solved value of y for each desired time point in
|
||||
`t`, with the initial value `y0` being the first element along the first
|
||||
dimension.
|
||||
info_dict: only if `full_output == True`. A dict with the following values:
|
||||
* num_func_evals: integer Tensor counting the number of function
|
||||
evaluations.
|
||||
* integrate_points: 1D float64 Tensor with the upper bound of each
|
||||
integration time step.
|
||||
* error_ratio: 1D float Tensor with the estimated ratio of the integration
|
||||
error to the error tolerance at each integration step. An ratio greater
|
||||
than 1 corresponds to rejected steps.
|
||||
|
||||
Raises:
|
||||
ValueError: if an invalid `method` is provided.
|
||||
TypeError: if `options` is supplied without `method`, or if `t` or `y0` has
|
||||
an invalid dtype.
|
||||
"""
|
||||
if method is not None and method != 'dopri5':
|
||||
raise ValueError('invalid method: %r' % method)
|
||||
|
||||
if options is None:
|
||||
options = {}
|
||||
elif method is None:
|
||||
raise ValueError('cannot supply `options` without specifying `method`')
|
||||
|
||||
with ops.name_scope(name, 'odeint', [y0, t, rtol, atol]) as scope:
|
||||
# TODO(shoyer): use nest.flatten (like tf.while_loop) to allow `y0` to be an
|
||||
# arbitrarily nested tuple. This will help performance and usability by
|
||||
# avoiding the need to pack/unpack in user functions.
|
||||
y0 = ops.convert_to_tensor(y0, name='y0')
|
||||
if not (y0.dtype.is_floating or y0.dtype.is_complex):
|
||||
raise TypeError('`y0` must have a floating point or complex floating '
|
||||
'point dtype')
|
||||
|
||||
t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t')
|
||||
if not t.dtype.is_floating:
|
||||
raise TypeError('`t` must have a floating point dtype')
|
||||
|
||||
error_dtype = abs(y0).dtype
|
||||
rtol = ops.convert_to_tensor(rtol, dtype=error_dtype, name='rtol')
|
||||
atol = ops.convert_to_tensor(atol, dtype=error_dtype, name='atol')
|
||||
|
||||
return _dopri5(func, y0, t,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
full_output=full_output,
|
||||
name=scope,
|
||||
**options)
|
232
tensorflow/contrib/integrate/python/ops/odes_test.py
Normal file
232
tensorflow/contrib/integrate/python/ops/odes_test.py
Normal file
@ -0,0 +1,232 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for ODE solvers."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.integrate.python.ops import odes
|
||||
|
||||
|
||||
class OdeIntTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(OdeIntTest, self).setUp()
|
||||
# simple defaults (solution is a sin-wave)
|
||||
matrix = tf.constant([[0, 1], [-1, 0]], dtype=tf.float64)
|
||||
self.func = lambda y, t: tf.matmul(matrix, y)
|
||||
self.y0 = np.array([[1.0], [0.0]])
|
||||
|
||||
def test_odeint_exp(self):
|
||||
# Test odeint by an exponential function:
|
||||
# dy / dt = y, y(0) = 1.0.
|
||||
# Its analytical solution is y = exp(t).
|
||||
func = lambda y, t: y
|
||||
y0 = tf.constant(1.0, dtype=tf.float64)
|
||||
t = np.linspace(0.0, 1.0, 11)
|
||||
y_solved = tf.contrib.integrate.odeint(func, y0, t)
|
||||
self.assertIn('odeint', y_solved.name)
|
||||
self.assertEqual(y_solved.get_shape(), tf.TensorShape([11]))
|
||||
with self.test_session() as sess:
|
||||
y_solved = sess.run(y_solved)
|
||||
y_true = np.exp(t)
|
||||
self.assertAllClose(y_true, y_solved)
|
||||
|
||||
def test_odeint_complex(self):
|
||||
# Test a complex, linear ODE:
|
||||
# dy / dt = k * y, y(0) = 1.0.
|
||||
# Its analytical solution is y = exp(k * t).
|
||||
k = 1j - 0.1
|
||||
func = lambda y, t: k * y
|
||||
t = np.linspace(0.0, 1.0, 11)
|
||||
y_solved = tf.contrib.integrate.odeint(func, 1.0 + 0.0j, t)
|
||||
with self.test_session() as sess:
|
||||
y_solved = sess.run(y_solved)
|
||||
y_true = np.exp(k * t)
|
||||
self.assertAllClose(y_true, y_solved)
|
||||
|
||||
def test_odeint_riccati(self):
|
||||
# The Ricatti equation is:
|
||||
# dy / dt = (y - t) ** 2 + 1.0, y(0) = 0.5.
|
||||
# Its analytical solution is y = 1.0 / (2.0 - t) + t.
|
||||
func = lambda t, y: (y - t)**2 + 1.0
|
||||
t = np.linspace(0.0, 1.0, 11)
|
||||
y_solved = tf.contrib.integrate.odeint(func, np.float64(0.5), t)
|
||||
with self.test_session() as sess:
|
||||
y_solved = sess.run(y_solved)
|
||||
y_true = 1.0 / (2.0 - t) + t
|
||||
self.assertAllClose(y_true, y_solved)
|
||||
|
||||
def test_odeint_2d_linear(self):
|
||||
# Solve the 2D linear differential equation:
|
||||
# dy1 / dt = 3.0 * y1 + 4.0 * y2,
|
||||
# dy2 / dt = -4.0 * y1 + 3.0 * y2,
|
||||
# y1(0) = 0.0,
|
||||
# y2(0) = 1.0.
|
||||
# Its analytical solution is
|
||||
# y1 = sin(4.0 * t) * exp(3.0 * t),
|
||||
# y2 = cos(4.0 * t) * exp(3.0 * t).
|
||||
matrix = tf.constant([[3.0, 4.0], [-4.0, 3.0]], dtype=tf.float64)
|
||||
func = lambda y, t: tf.matmul(matrix, y)
|
||||
|
||||
y0 = tf.constant([[0.0], [1.0]], dtype=tf.float64)
|
||||
t = np.linspace(0.0, 1.0, 11)
|
||||
|
||||
y_solved = tf.contrib.integrate.odeint(func, y0, t)
|
||||
with self.test_session() as sess:
|
||||
y_solved = sess.run(y_solved)
|
||||
|
||||
y_true = np.zeros((len(t), 2, 1))
|
||||
y_true[:, 0, 0] = np.sin(4.0 * t) * np.exp(3.0 * t)
|
||||
y_true[:, 1, 0] = np.cos(4.0 * t) * np.exp(3.0 * t)
|
||||
self.assertAllClose(y_true, y_solved, atol=1e-5)
|
||||
|
||||
def test_odeint_higher_rank(self):
|
||||
func = lambda y, t: y
|
||||
y0 = tf.constant(1.0, dtype=tf.float64)
|
||||
t = np.linspace(0.0, 1.0, 11)
|
||||
for shape in [(), (1,), (1, 1)]:
|
||||
expected_shape = (len(t),) + shape
|
||||
y_solved = tf.contrib.integrate.odeint(func, tf.reshape(y0, shape), t)
|
||||
self.assertEqual(y_solved.get_shape(), tf.TensorShape(expected_shape))
|
||||
with self.test_session() as sess:
|
||||
y_solved = sess.run(y_solved)
|
||||
self.assertEquals(y_solved.shape, expected_shape)
|
||||
|
||||
def test_odeint_all_dtypes(self):
|
||||
func = lambda y, t: y
|
||||
t = np.linspace(0.0, 1.0, 11)
|
||||
for y0_dtype in [tf.float32, tf.float64, tf.complex64, tf.complex128]:
|
||||
for t_dtype in [tf.float32, tf.float64]:
|
||||
y0 = tf.cast(1.0, y0_dtype)
|
||||
y_solved = tf.contrib.integrate.odeint(func, y0, tf.cast(t, t_dtype))
|
||||
with self.test_session() as sess:
|
||||
y_solved = sess.run(y_solved)
|
||||
expected = np.asarray(np.exp(t))
|
||||
self.assertAllClose(y_solved, expected, rtol=1e-5)
|
||||
self.assertEqual(tf.as_dtype(y_solved.dtype), y0_dtype)
|
||||
|
||||
def test_odeint_required_dtypes(self):
|
||||
with self.assertRaisesRegexp(TypeError, '`y0` must have a floating point'):
|
||||
tf.contrib.integrate.odeint(self.func, tf.cast(self.y0, tf.int32), [0, 1])
|
||||
|
||||
with self.assertRaisesRegexp(TypeError, '`t` must have a floating point'):
|
||||
tf.contrib.integrate.odeint(self.func, self.y0, tf.cast([0, 1], tf.int32))
|
||||
|
||||
def test_odeint_runtime_errors(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'cannot supply `options` without'):
|
||||
tf.contrib.integrate.odeint(self.func, self.y0, [0, 1],
|
||||
options={'first_step': 1.0})
|
||||
|
||||
y = tf.contrib.integrate.odeint(self.func, self.y0, [0, 1], method='dopri5',
|
||||
options={'max_num_steps': 0})
|
||||
with self.test_session() as sess:
|
||||
with self.assertRaisesRegexp(
|
||||
tf.errors.InvalidArgumentError, 'max_num_steps'):
|
||||
sess.run(y)
|
||||
|
||||
y = tf.contrib.integrate.odeint(self.func, self.y0, [1, 0])
|
||||
with self.test_session() as sess:
|
||||
with self.assertRaisesRegexp(
|
||||
tf.errors.InvalidArgumentError, 'monotonic increasing'):
|
||||
sess.run(y)
|
||||
|
||||
def test_odeint_different_times(self):
|
||||
# integrate steps should be independent of interpolation times
|
||||
times0 = np.linspace(0, 10, num=11, dtype=float)
|
||||
times1 = np.linspace(0, 10, num=101, dtype=float)
|
||||
|
||||
with self.test_session() as sess:
|
||||
y_solved_0, info_0 = sess.run(
|
||||
tf.contrib.integrate.odeint(
|
||||
self.func, self.y0, times0, full_output=True))
|
||||
y_solved_1, info_1 = sess.run(
|
||||
tf.contrib.integrate.odeint(
|
||||
self.func, self.y0, times1, full_output=True))
|
||||
|
||||
self.assertAllClose(y_solved_0, y_solved_1[::10])
|
||||
self.assertEqual(info_0['num_func_evals'], info_1['num_func_evals'])
|
||||
self.assertAllEqual(info_0['integrate_points'], info_1['integrate_points'])
|
||||
self.assertAllEqual(info_0['error_ratio'], info_1['error_ratio'])
|
||||
|
||||
def test_odeint_5th_order_accuracy(self):
|
||||
t = [0, 20]
|
||||
kwargs = dict(full_output=True,
|
||||
method='dopri5',
|
||||
options=dict(max_num_steps=2000))
|
||||
with self.test_session() as sess:
|
||||
_, info_0 = sess.run(tf.contrib.integrate.odeint(
|
||||
self.func, self.y0, t, rtol=0, atol=1e-6, **kwargs))
|
||||
_, info_1 = sess.run(tf.contrib.integrate.odeint(
|
||||
self.func, self.y0, t, rtol=0, atol=1e-9, **kwargs))
|
||||
self.assertAllClose(info_0['integrate_points'].size * 1000 ** 0.2,
|
||||
float(info_1['integrate_points'].size),
|
||||
rtol=0.01)
|
||||
|
||||
|
||||
class StepSizeTest(tf.test.TestCase):
|
||||
|
||||
def test_error_ratio_one(self):
|
||||
new_step = odes._optimal_step_size(last_step=tf.constant(1.0),
|
||||
error_ratio=tf.constant(1.0))
|
||||
with self.test_session() as sess:
|
||||
new_step = sess.run(new_step)
|
||||
self.assertAllClose(new_step, 0.9)
|
||||
|
||||
def test_ifactor(self):
|
||||
new_step = odes._optimal_step_size(last_step=tf.constant(1.0),
|
||||
error_ratio=tf.constant(0.0))
|
||||
with self.test_session() as sess:
|
||||
new_step = sess.run(new_step)
|
||||
self.assertAllClose(new_step, 10.0)
|
||||
|
||||
def test_dfactor(self):
|
||||
new_step = odes._optimal_step_size(last_step=tf.constant(1.0),
|
||||
error_ratio=tf.constant(1e6))
|
||||
with self.test_session() as sess:
|
||||
new_step = sess.run(new_step)
|
||||
self.assertAllClose(new_step, 0.2)
|
||||
|
||||
|
||||
class InterpolationTest(tf.test.TestCase):
|
||||
|
||||
def test_5th_order_polynomial(self):
|
||||
# this should be an exact fit
|
||||
f = lambda x: x ** 4 + x ** 3 - 2 * x ** 2 + 4 * x + 5
|
||||
f_prime = lambda x: 4 * x ** 3 + 3 * x ** 2 - 4 * x + 4
|
||||
coeffs = odes._interp_fit(
|
||||
f(0.0), f(10.0), f(5.0), f_prime(0.0), f_prime(10.0), 10.0)
|
||||
times = np.linspace(0, 10, dtype=np.float32)
|
||||
y_fit = tf.pack([odes._interp_evaluate(coeffs, 0.0, 10.0, t)
|
||||
for t in times])
|
||||
y_expected = f(times)
|
||||
with self.test_session() as sess:
|
||||
y_actual = sess.run(y_fit)
|
||||
self.assertAllClose(y_expected, y_actual)
|
||||
|
||||
# attempt interpolation outside bounds
|
||||
y_invalid = odes._interp_evaluate(coeffs, 0.0, 10.0, 100.0)
|
||||
with self.test_session() as sess:
|
||||
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||
sess.run(y_invalid)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
@ -22,6 +22,7 @@ from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import embedding_ops
|
||||
@ -114,8 +115,9 @@ def safe_embedding_lookup_sparse(embedding_weights,
|
||||
array_ops.slice(original_shape, [0], [original_rank - 1])),
|
||||
array_ops.gather(original_shape, original_rank - 1)])
|
||||
if sparse_weights is not None:
|
||||
sparse_weights = ops.SparseTensor(sparse_ids.indices,
|
||||
sparse_weights.values, sparse_ids.shape)
|
||||
sparse_weights = sparse_tensor.SparseTensor(
|
||||
sparse_ids.indices,
|
||||
sparse_weights.values, sparse_ids.shape)
|
||||
|
||||
# Prune invalid ids and weights.
|
||||
sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
|
||||
@ -302,7 +304,7 @@ def hashed_embedding_lookup_sparse(params,
|
||||
params = list(params)
|
||||
if not isinstance(params, list):
|
||||
params = [params]
|
||||
if not isinstance(sparse_values, ops.SparseTensor):
|
||||
if not isinstance(sparse_values, sparse_tensor.SparseTensor):
|
||||
raise TypeError("sparse_values must be SparseTensor")
|
||||
|
||||
with ops.name_scope(name, "hashed_sparse_embedding_lookup",
|
||||
|
@ -21,7 +21,7 @@ from __future__ import print_function
|
||||
from tensorflow.contrib.framework.python.ops import variables
|
||||
from tensorflow.contrib.layers.python.layers import embedding_ops as contrib_embedding_ops
|
||||
from tensorflow.contrib.layers.python.ops import sparse_ops
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import embedding_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
@ -74,14 +74,14 @@ def bow_encoder(ids,
|
||||
initializer=initializer, regularizer=regularizer,
|
||||
trainable=trainable)
|
||||
if sparse_lookup:
|
||||
if isinstance(ids, ops.SparseTensor):
|
||||
if isinstance(ids, sparse_tensor.SparseTensor):
|
||||
sparse_ids = ids
|
||||
else:
|
||||
sparse_ids = sparse_ops.dense_to_sparse_tensor(ids)
|
||||
return contrib_embedding_ops.safe_embedding_lookup_sparse(
|
||||
[embeddings], sparse_ids, combiner='mean', default_id=0)
|
||||
else:
|
||||
if isinstance(ids, ops.SparseTensor):
|
||||
if isinstance(ids, sparse_tensor.SparseTensor):
|
||||
raise TypeError('ids are expected to be dense Tensor, got: %s', ids)
|
||||
return math_ops.reduce_mean(
|
||||
embedding_ops.embedding_lookup(embeddings, ids),
|
||||
|
@ -76,13 +76,12 @@ import collections
|
||||
import math
|
||||
import six
|
||||
|
||||
from tensorflow.contrib.framework.python.framework import deprecation
|
||||
from tensorflow.contrib.layers.python.layers import layers
|
||||
from tensorflow.contrib.layers.python.ops import bucketization_op
|
||||
from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
|
||||
from tensorflow.contrib.lookup import lookup_ops as contrib_lookup_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor as sparse_tensor_py
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -90,6 +89,7 @@ from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import deprecation
|
||||
|
||||
|
||||
class _LinearEmbeddingLookupArguments(
|
||||
@ -390,7 +390,7 @@ class _SparseColumnIntegerized(_SparseColumn):
|
||||
sparse_id_values = math_ops.mod(columns_to_tensors[self.name].values,
|
||||
self.bucket_size,
|
||||
name="mod")
|
||||
columns_to_tensors[self] = ops.SparseTensor(
|
||||
columns_to_tensors[self] = sparse_tensor_py.SparseTensor(
|
||||
columns_to_tensors[self.name].indices, sparse_id_values,
|
||||
columns_to_tensors[self.name].shape)
|
||||
|
||||
@ -464,7 +464,7 @@ class _SparseColumnHashed(_SparseColumn):
|
||||
|
||||
sparse_id_values = string_ops.string_to_hash_bucket_fast(
|
||||
sparse_values, self.bucket_size, name="lookup")
|
||||
columns_to_tensors[self] = ops.SparseTensor(
|
||||
columns_to_tensors[self] = sparse_tensor_py.SparseTensor(
|
||||
sparse_tensor.indices, sparse_id_values, sparse_tensor.shape)
|
||||
|
||||
|
||||
@ -1452,7 +1452,8 @@ class _BucketizedColumn(_FeatureColumn, collections.namedtuple(
|
||||
|
||||
indices = math_ops.to_int64(array_ops.transpose(array_ops.pack((i1, i2))))
|
||||
shape = math_ops.to_int64(array_ops.pack([batch_size, dimension]))
|
||||
sparse_id_values = ops.SparseTensor(indices, bucket_indices, shape)
|
||||
sparse_id_values = sparse_tensor_py.SparseTensor(
|
||||
indices, bucket_indices, shape)
|
||||
|
||||
return sparse_id_values
|
||||
|
||||
|
@ -26,6 +26,7 @@ from tensorflow.contrib.layers.python.layers import feature_column as fc
|
||||
from tensorflow.contrib.layers.python.layers import layers
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor as sparse_tensor_py
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -362,9 +363,9 @@ def _create_joint_embedding_lookup(columns_to_tensors,
|
||||
values = t.values + prev_size
|
||||
prev_size += a.vocab_size
|
||||
sparse_tensors.append(
|
||||
ops.SparseTensor(t.indices,
|
||||
values,
|
||||
t.shape))
|
||||
sparse_tensor_py.SparseTensor(t.indices,
|
||||
values,
|
||||
t.shape))
|
||||
sparse_tensor = sparse_ops.sparse_concat(1, sparse_tensors)
|
||||
with variable_scope.variable_scope(
|
||||
None, default_name='linear_weights', values=columns_to_tensors.values()):
|
||||
@ -695,7 +696,7 @@ def _log_variable(variable):
|
||||
|
||||
def _infer_real_valued_column_for_tensor(name, tensor):
|
||||
"""Creates a real_valued_column for given tensor and name."""
|
||||
if isinstance(tensor, ops.SparseTensor):
|
||||
if isinstance(tensor, sparse_tensor_py.SparseTensor):
|
||||
raise ValueError(
|
||||
'SparseTensor is not supported for auto detection. Please define '
|
||||
'corresponding FeatureColumn for tensor {} {}.', name, tensor)
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
@ -609,7 +610,10 @@ class FeatureColumnTest(tf.test.TestCase):
|
||||
{embedding_col: input_tensor}, [embedding_col])
|
||||
|
||||
save = tf.train.Saver()
|
||||
checkpoint_path = os.path.join(self.get_temp_dir(), "model.ckpt")
|
||||
ckpt_dir_prefix = os.path.join(
|
||||
self.get_temp_dir(), "init_embedding_col_w_from_ckpt")
|
||||
ckpt_dir = tempfile.mkdtemp(prefix=ckpt_dir_prefix)
|
||||
checkpoint_path = os.path.join(ckpt_dir, "model.ckpt")
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(tf.initialize_all_variables())
|
||||
@ -670,7 +674,10 @@ class FeatureColumnTest(tf.test.TestCase):
|
||||
assign_op = tf.assign(weight[0], weight[0] + 0.5)
|
||||
|
||||
save = tf.train.Saver()
|
||||
checkpoint_path = os.path.join(self.get_temp_dir(), "model.ckpt")
|
||||
ckpt_dir_prefix = os.path.join(
|
||||
self.get_temp_dir(), "init_crossed_col_w_from_ckpt")
|
||||
ckpt_dir = tempfile.mkdtemp(prefix=ckpt_dir_prefix)
|
||||
checkpoint_path = os.path.join(ckpt_dir, "model.ckpt")
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(tf.initialize_all_variables())
|
||||
|
@ -31,6 +31,7 @@ from tensorflow.contrib.layers.python.layers import utils
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
@ -1217,7 +1218,7 @@ def _inner_flatten(inputs, new_rank, output_collections=None, scope=None):
|
||||
TypeError: `inputs` is not a `Tensor` or `SparseTensor`.
|
||||
"""
|
||||
with ops.name_scope(scope, 'InnerFlatten', [inputs, new_rank]) as sc:
|
||||
if isinstance(inputs, ops.SparseTensor):
|
||||
if isinstance(inputs, sparse_tensor.SparseTensor):
|
||||
flattened = _sparse_inner_flatten(inputs, new_rank)
|
||||
else:
|
||||
inputs = ops.convert_to_tensor(inputs)
|
||||
|
@ -258,10 +258,11 @@ def optimize_loss(loss,
|
||||
grad_values = gradient
|
||||
|
||||
if grad_values is not None:
|
||||
var_name = variable.name.replace(":", "_")
|
||||
if "gradients" in summaries:
|
||||
summary.histogram("gradients/" + variable.name, grad_values)
|
||||
summary.histogram("gradients/%s" % var_name, grad_values)
|
||||
if "gradient_norm" in summaries:
|
||||
summary.scalar("gradient_norm/" + variable.name,
|
||||
summary.scalar("gradient_norm/%s" % var_name,
|
||||
clip_ops.global_norm([grad_values]))
|
||||
|
||||
if clip_gradients is not None and "gradient_norm" in summaries:
|
||||
|
@ -58,7 +58,7 @@ class MultiClassTargetColumnTest(tf.test.TestCase):
|
||||
labels = tf.constant([[1.], [0.]])
|
||||
# logloss: z:label, x:logit
|
||||
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
||||
self.assertAlmostEqual(.81326163,
|
||||
self.assertAlmostEqual(0.81326175,
|
||||
sess.run(target_column.loss(logits, labels, {})))
|
||||
|
||||
def testBinaryClassificationWithWeights(self):
|
||||
|
@ -23,6 +23,7 @@ from tensorflow.contrib.util import loader
|
||||
from tensorflow.python.framework import common_shapes
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import resource_loader
|
||||
|
||||
@ -69,12 +70,14 @@ def sparse_feature_cross(inputs, hashed_output=False, num_buckets=0,
|
||||
"""
|
||||
if not isinstance(inputs, list):
|
||||
raise TypeError("Inputs must be a list")
|
||||
if not all(isinstance(i, ops.SparseTensor) or
|
||||
if not all(isinstance(i, sparse_tensor.SparseTensor) or
|
||||
isinstance(i, ops.Tensor) for i in inputs):
|
||||
raise TypeError("All inputs must be SparseTensors")
|
||||
|
||||
sparse_inputs = [i for i in inputs if isinstance(i, ops.SparseTensor)]
|
||||
dense_inputs = [i for i in inputs if not isinstance(i, ops.SparseTensor)]
|
||||
sparse_inputs = [i for i in inputs
|
||||
if isinstance(i, sparse_tensor.SparseTensor)]
|
||||
dense_inputs = [i for i in inputs
|
||||
if not isinstance(i, sparse_tensor.SparseTensor)]
|
||||
|
||||
indices = [sp_input.indices for sp_input in sparse_inputs]
|
||||
values = [sp_input.values for sp_input in sparse_inputs]
|
||||
@ -117,7 +120,7 @@ def sparse_feature_cross(inputs, hashed_output=False, num_buckets=0,
|
||||
internal_type=internal_type,
|
||||
name=name))
|
||||
|
||||
return ops.SparseTensor(indices_out, values_out, shape_out)
|
||||
return sparse_tensor.SparseTensor(indices_out, values_out, shape_out)
|
||||
|
||||
|
||||
ops.RegisterShape("SparseFeatureCross")(common_shapes.call_cpp_shape_fn)
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
@ -78,4 +79,4 @@ def dense_to_sparse_tensor(dense_tensor, ignore_value=None):
|
||||
math_ops.mul(higher_dims, shape_multipliers), reduction_indices=[1])
|
||||
flat_indices = math_ops.add(flat_indices, offsets)
|
||||
values = array_ops.gather(flat_tensor, flat_indices)
|
||||
return ops.SparseTensor(indices, values, dense_shape)
|
||||
return sparse_tensor.SparseTensor(indices, values, dense_shape)
|
||||
|
@ -291,7 +291,9 @@ py_test(
|
||||
deps = [
|
||||
":learn",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:extra_py_tests_deps",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:test_ops",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -22,11 +22,12 @@ from __future__ import print_function
|
||||
from tensorflow.contrib.layers import feature_column
|
||||
from tensorflow.contrib.learn.python.learn.dataframe import series as ss
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
|
||||
|
||||
def _to_feature_spec(tensor, default_value=None):
|
||||
if isinstance(tensor, ops.SparseTensor):
|
||||
if isinstance(tensor, sparse_tensor.SparseTensor):
|
||||
return parsing_ops.VarLenFeature(dtype=tensor.dtype)
|
||||
else:
|
||||
return parsing_ops.FixedLenFeature(shape=tensor.get_shape(),
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.dataframe import series
|
||||
from tensorflow.contrib.learn.python.learn.dataframe import transform
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
# Each entry is a mapping from registered_name to operation. Each operation is
|
||||
@ -55,8 +55,8 @@ class SeriesBinaryTransform(transform.TensorFlowTransform):
|
||||
|
||||
def _apply_transform(self, input_tensors, **kwargs):
|
||||
# TODO(jamieas): consider supporting sparse inputs.
|
||||
if isinstance(input_tensors[0], ops.SparseTensor) or isinstance(
|
||||
input_tensors[1], ops.SparseTensor):
|
||||
if isinstance(input_tensors[0], sparse_tensor.SparseTensor) or isinstance(
|
||||
input_tensors[1], sparse_tensor.SparseTensor):
|
||||
raise TypeError("{} does not support SparseTensors".format(
|
||||
type(self).__name__))
|
||||
|
||||
@ -89,10 +89,10 @@ class ScalarBinaryTransform(transform.TensorFlowTransform):
|
||||
|
||||
def _apply_transform(self, input_tensors, **kwargs):
|
||||
input_tensor = input_tensors[0]
|
||||
if isinstance(input_tensor, ops.SparseTensor):
|
||||
result = ops.SparseTensor(input_tensor.indices,
|
||||
self._apply_op(input_tensor.values),
|
||||
input_tensor.shape)
|
||||
if isinstance(input_tensor, sparse_tensor.SparseTensor):
|
||||
result = sparse_tensor.SparseTensor(input_tensor.indices,
|
||||
self._apply_op(input_tensor.values),
|
||||
input_tensor.shape)
|
||||
else:
|
||||
result = self._apply_op(input_tensor)
|
||||
|
||||
|
@ -23,6 +23,7 @@ from tensorflow.contrib.learn.python.learn.dataframe import series
|
||||
from tensorflow.contrib.learn.python.learn.dataframe import transform
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor as sparse_tensor_py
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import functional_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -93,7 +94,7 @@ class BooleanMask(transform.TensorFlowTransform):
|
||||
if mask.get_shape().ndims > 1:
|
||||
mask = array_ops.squeeze(mask)
|
||||
|
||||
if isinstance(input_tensor, ops.SparseTensor):
|
||||
if isinstance(input_tensor, sparse_tensor_py.SparseTensor):
|
||||
mask_fn = sparse_boolean_mask
|
||||
else:
|
||||
mask_fn = array_ops.boolean_mask
|
||||
|
@ -21,14 +21,14 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.dataframe import series
|
||||
from tensorflow.contrib.learn.python.learn.dataframe import transform
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
|
||||
|
||||
def _negate_sparse(sparse_tensor):
|
||||
return ops.SparseTensor(indices=sparse_tensor.indices,
|
||||
values=-sparse_tensor.values,
|
||||
shape=sparse_tensor.shape)
|
||||
def _negate_sparse(st):
|
||||
return sparse_tensor.SparseTensor(indices=st.indices,
|
||||
values=-st.values,
|
||||
shape=st.shape)
|
||||
|
||||
|
||||
@series.Series.register_binary_op("__sub__")
|
||||
@ -51,8 +51,8 @@ class Difference(transform.TensorFlowTransform):
|
||||
return "output",
|
||||
|
||||
def _apply_transform(self, input_tensors, **kwargs):
|
||||
pair_sparsity = (isinstance(input_tensors[0], ops.SparseTensor),
|
||||
isinstance(input_tensors[1], ops.SparseTensor))
|
||||
pair_sparsity = (isinstance(input_tensors[0], sparse_tensor.SparseTensor),
|
||||
isinstance(input_tensors[1], sparse_tensor.SparseTensor))
|
||||
|
||||
if pair_sparsity == (False, False):
|
||||
result = input_tensors[0] - input_tensors[1]
|
||||
|
@ -24,7 +24,7 @@ import numpy as np
|
||||
from tensorflow.contrib.learn.python.learn.dataframe import transform
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
@ -82,4 +82,5 @@ class Sparsify(transform.TensorFlowTransform):
|
||||
shape = math_ops.cast(array_ops.shape(d), dtypes.int64)
|
||||
|
||||
# pylint: disable=not-callable
|
||||
return self.return_type(ops.SparseTensor(sparse_indices, values, shape))
|
||||
return self.return_type(
|
||||
sparse_tensor.SparseTensor(sparse_indices, values, shape))
|
||||
|
@ -21,7 +21,7 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.dataframe import series
|
||||
from tensorflow.contrib.learn.python.learn.dataframe import transform
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
|
||||
|
||||
@ -45,8 +45,8 @@ class Sum(transform.TensorFlowTransform):
|
||||
return "output",
|
||||
|
||||
def _apply_transform(self, input_tensors, **kwargs):
|
||||
pair_sparsity = (isinstance(input_tensors[0], ops.SparseTensor),
|
||||
isinstance(input_tensors[1], ops.SparseTensor))
|
||||
pair_sparsity = (isinstance(input_tensors[0], sparse_tensor.SparseTensor),
|
||||
isinstance(input_tensors[1], sparse_tensor.SparseTensor))
|
||||
|
||||
if pair_sparsity == (False, False):
|
||||
result = input_tensors[0] + input_tensors[1]
|
||||
@ -57,6 +57,3 @@ class Sum(transform.TensorFlowTransform):
|
||||
|
||||
# pylint: disable=not-callable
|
||||
return self.return_type(result)
|
||||
|
||||
|
||||
|
||||
|
@ -21,7 +21,7 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.dataframe import series
|
||||
from tensorflow.contrib.learn.python.learn.dataframe import transform
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
# Each entry is a mapping from registered_name to operation. Each operation is
|
||||
@ -83,10 +83,10 @@ def register_unary_op(registered_name, operation, ignore_dtype=None):
|
||||
|
||||
def _apply_transform(self, input_tensors, **kwargs):
|
||||
input_tensor = input_tensors[0]
|
||||
if isinstance(input_tensor, ops.SparseTensor):
|
||||
result = ops.SparseTensor(input_tensor.indices,
|
||||
operation(input_tensor.values),
|
||||
input_tensor.shape)
|
||||
if isinstance(input_tensor, sparse_tensor.SparseTensor):
|
||||
result = sparse_tensor.SparseTensor(input_tensor.indices,
|
||||
operation(input_tensor.values),
|
||||
input_tensor.shape)
|
||||
else:
|
||||
result = operation(input_tensor)
|
||||
# pylint: disable=not-callable
|
||||
|
@ -29,11 +29,11 @@ from tensorflow.contrib.learn.python.learn.estimators.estimator import Estimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input
|
||||
from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input_fn
|
||||
from tensorflow.contrib.learn.python.learn.estimators.estimator import ModeKeys
|
||||
from tensorflow.contrib.learn.python.learn.estimators.head import MetricKey
|
||||
from tensorflow.contrib.learn.python.learn.estimators.head import PredictionKey
|
||||
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearClassifier
|
||||
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearRegressor
|
||||
from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import LogisticRegressor
|
||||
from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey
|
||||
from tensorflow.contrib.learn.python.learn.estimators.prediction_key import PredictionKey
|
||||
from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestEstimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestLossHook
|
||||
from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig
|
||||
|
@ -19,7 +19,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
import re
|
||||
import six
|
||||
|
||||
from tensorflow.contrib import layers
|
||||
@ -27,13 +28,23 @@ from tensorflow.contrib.framework import deprecated
|
||||
from tensorflow.contrib.framework import deprecated_arg_values
|
||||
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
|
||||
from tensorflow.contrib.layers.python.layers import feature_column_ops
|
||||
from tensorflow.contrib.layers.python.layers import optimizers
|
||||
from tensorflow.contrib.learn.python.learn import evaluable
|
||||
from tensorflow.contrib.learn.python.learn import session_run_hook
|
||||
from tensorflow.contrib.learn.python.learn import trainable
|
||||
from tensorflow.contrib.learn.python.learn.estimators import composable_model
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
|
||||
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
|
||||
from tensorflow.contrib.learn.python.learn.utils import export
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import logging_ops
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import partitioned_variables
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
|
||||
|
||||
class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
|
||||
@ -307,7 +318,236 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
|
||||
return logits
|
||||
|
||||
|
||||
class DNNLinearCombinedClassifier(_DNNLinearCombinedBaseEstimator):
|
||||
_CENTERED_BIAS_WEIGHT = "centered_bias_weight"
|
||||
|
||||
# The default learning rates are a historical artifact of the initial
|
||||
# implementation, but seem a reasonable choice.
|
||||
_DNN_LEARNING_RATE = 0.05
|
||||
_LINEAR_LEARNING_RATE = 0.2
|
||||
|
||||
|
||||
def _as_iterable(preds, output):
|
||||
for pred in preds:
|
||||
yield pred[output]
|
||||
|
||||
|
||||
def _get_feature_dict(features):
|
||||
if isinstance(features, dict):
|
||||
return features
|
||||
return {"": features}
|
||||
|
||||
|
||||
def _get_optimizer(optimizer):
|
||||
if callable(optimizer):
|
||||
return optimizer()
|
||||
else:
|
||||
return optimizer
|
||||
|
||||
|
||||
def _linear_learning_rate(num_linear_feature_columns):
|
||||
"""Returns the default learning rate of the linear model.
|
||||
|
||||
The calculation is a historical artifact of this initial implementation, but
|
||||
has proven a reasonable choice.
|
||||
|
||||
Args:
|
||||
num_linear_feature_columns: The number of feature columns of the linear
|
||||
model.
|
||||
|
||||
Returns:
|
||||
A float.
|
||||
"""
|
||||
default_learning_rate = 1. / math.sqrt(num_linear_feature_columns)
|
||||
return min(_LINEAR_LEARNING_RATE, default_learning_rate)
|
||||
|
||||
|
||||
def _add_hidden_layer_summary(value, tag):
|
||||
logging_ops.scalar_summary("%s:fraction_of_zero_values" % tag,
|
||||
nn.zero_fraction(value))
|
||||
logging_ops.histogram_summary("%s:activation" % tag, value)
|
||||
|
||||
|
||||
def _dnn_linear_combined_model_fn(features, labels, mode, params):
|
||||
"""Deep Neural Net and Linear combined model_fn.
|
||||
|
||||
Args:
|
||||
features: `Tensor` or dict of `Tensor` (depends on data passed to `fit`).
|
||||
labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of dtype
|
||||
`int32` or `int64` in the range `[0, n_classes)`.
|
||||
mode: Defines whether this is training, evaluation or prediction.
|
||||
See `ModeKeys`.
|
||||
params: A dict of hyperparameters.
|
||||
The following hyperparameters are expected:
|
||||
* head: A `Head` instance.
|
||||
* linear_feature_columns: An iterable containing all the feature columns
|
||||
used by the Linear model.
|
||||
* linear_optimizer: string, `Optimizer` object, or callable that defines
|
||||
the optimizer to use for training the Linear model.
|
||||
* joint_linear_weights: If True a single (possibly partitioned) variable
|
||||
will be used to store the linear model weights. It's faster, but
|
||||
requires all columns are sparse and have the 'sum' combiner.
|
||||
* dnn_feature_columns: An iterable containing all the feature columns used
|
||||
by the DNN model.
|
||||
* dnn_optimizer: string, `Optimizer` object, or callable that defines the
|
||||
optimizer to use for training the DNN model.
|
||||
* dnn_hidden_units: List of hidden units per DNN layer.
|
||||
* dnn_activation_fn: Activation function applied to each DNN layer. If
|
||||
`None`, will use `tf.nn.relu`.
|
||||
* dnn_dropout: When not `None`, the probability we will drop out a given
|
||||
DNN coordinate.
|
||||
* gradient_clip_norm: A float > 0. If provided, gradients are
|
||||
clipped to their global norm with this clipping ratio.
|
||||
* num_ps_replicas: The number of parameter server replicas.
|
||||
|
||||
Returns:
|
||||
`estimator.ModelFnOps`
|
||||
|
||||
Raises:
|
||||
ValueError: If both `linear_feature_columns` and `dnn_features_columns`
|
||||
are empty at the same time.
|
||||
"""
|
||||
head = params["head"]
|
||||
linear_feature_columns = params.get("linear_feature_columns")
|
||||
linear_optimizer = params.get("linear_optimizer")
|
||||
joint_linear_weights = params.get("joint_linear_weights")
|
||||
dnn_feature_columns = params.get("dnn_feature_columns")
|
||||
dnn_optimizer = params.get("dnn_optimizer")
|
||||
dnn_hidden_units = params.get("dnn_hidden_units")
|
||||
dnn_activation_fn = params.get("dnn_activation_fn")
|
||||
dnn_dropout = params.get("dnn_dropout")
|
||||
gradient_clip_norm = params.get("gradient_clip_norm")
|
||||
num_ps_replicas = params["num_ps_replicas"]
|
||||
|
||||
if not linear_feature_columns and not dnn_feature_columns:
|
||||
raise ValueError(
|
||||
"Either linear_feature_columns or dnn_feature_columns must be defined.")
|
||||
|
||||
features = _get_feature_dict(features)
|
||||
|
||||
# Build DNN Logits.
|
||||
dnn_parent_scope = "dnn"
|
||||
|
||||
if not dnn_feature_columns:
|
||||
dnn_logits = None
|
||||
else:
|
||||
input_layer_partitioner = (
|
||||
partitioned_variables.min_max_variable_partitioner(
|
||||
max_partitions=num_ps_replicas,
|
||||
min_slice_size=64 << 20))
|
||||
with variable_scope.variable_scope(
|
||||
dnn_parent_scope + "/input_from_feature_columns",
|
||||
values=features.values(),
|
||||
partitioner=input_layer_partitioner) as scope:
|
||||
net = layers.input_from_feature_columns(
|
||||
columns_to_tensors=features,
|
||||
feature_columns=dnn_feature_columns,
|
||||
weight_collections=[dnn_parent_scope],
|
||||
scope=scope)
|
||||
|
||||
hidden_layer_partitioner = (
|
||||
partitioned_variables.min_max_variable_partitioner(
|
||||
max_partitions=num_ps_replicas))
|
||||
for layer_id, num_hidden_units in enumerate(dnn_hidden_units):
|
||||
with variable_scope.variable_scope(
|
||||
dnn_parent_scope + "/hiddenlayer_%d" % layer_id,
|
||||
values=[net],
|
||||
partitioner=hidden_layer_partitioner) as scope:
|
||||
net = layers.fully_connected(
|
||||
net,
|
||||
num_hidden_units,
|
||||
activation_fn=dnn_activation_fn,
|
||||
variables_collections=[dnn_parent_scope],
|
||||
scope=scope)
|
||||
if dnn_dropout is not None and mode == estimator.ModeKeys.TRAIN:
|
||||
net = layers.dropout(
|
||||
net,
|
||||
keep_prob=(1.0 - dnn_dropout))
|
||||
# TODO(b/31209633): Consider adding summary before dropout.
|
||||
_add_hidden_layer_summary(net, scope.name)
|
||||
|
||||
with variable_scope.variable_scope(
|
||||
dnn_parent_scope + "/logits",
|
||||
values=[net],
|
||||
partitioner=hidden_layer_partitioner) as scope:
|
||||
dnn_logits = layers.fully_connected(
|
||||
net,
|
||||
head.logits_dimension,
|
||||
activation_fn=None,
|
||||
variables_collections=[dnn_parent_scope],
|
||||
scope=scope)
|
||||
_add_hidden_layer_summary(dnn_logits, scope.name)
|
||||
|
||||
# Build Linear logits.
|
||||
linear_parent_scope = "linear"
|
||||
|
||||
if not linear_feature_columns:
|
||||
linear_logits = None
|
||||
else:
|
||||
linear_partitioner = partitioned_variables.min_max_variable_partitioner(
|
||||
max_partitions=num_ps_replicas,
|
||||
min_slice_size=64 << 20)
|
||||
with variable_scope.variable_scope(
|
||||
linear_parent_scope,
|
||||
values=features.values(),
|
||||
partitioner=linear_partitioner) as scope:
|
||||
if joint_linear_weights:
|
||||
linear_logits, _, _ = layers.joint_weighted_sum_from_feature_columns(
|
||||
columns_to_tensors=features,
|
||||
feature_columns=linear_feature_columns,
|
||||
num_outputs=head.logits_dimension,
|
||||
weight_collections=[linear_parent_scope],
|
||||
scope=scope)
|
||||
else:
|
||||
linear_logits, _, _ = layers.weighted_sum_from_feature_columns(
|
||||
columns_to_tensors=features,
|
||||
feature_columns=linear_feature_columns,
|
||||
num_outputs=head.logits_dimension,
|
||||
weight_collections=[linear_parent_scope],
|
||||
scope=scope)
|
||||
|
||||
# Combine logits and build full model.
|
||||
if dnn_logits is not None and linear_logits is not None:
|
||||
logits = dnn_logits + linear_logits
|
||||
elif dnn_logits is not None:
|
||||
logits = dnn_logits
|
||||
else:
|
||||
logits = linear_logits
|
||||
|
||||
def _make_training_op(training_loss):
|
||||
"""Training op for the DNN linear combined model."""
|
||||
train_ops = []
|
||||
if dnn_logits is not None:
|
||||
train_ops.append(
|
||||
optimizers.optimize_loss(
|
||||
loss=training_loss,
|
||||
global_step=contrib_variables.get_global_step(),
|
||||
learning_rate=_DNN_LEARNING_RATE,
|
||||
optimizer=_get_optimizer(dnn_optimizer),
|
||||
clip_gradients=gradient_clip_norm,
|
||||
variables=ops.get_collection(dnn_parent_scope),
|
||||
name=dnn_parent_scope,
|
||||
# Empty summaries, because head already logs "loss" summary.
|
||||
summaries=[]))
|
||||
if linear_logits is not None:
|
||||
train_ops.append(
|
||||
optimizers.optimize_loss(
|
||||
loss=training_loss,
|
||||
global_step=contrib_variables.get_global_step(),
|
||||
learning_rate=_linear_learning_rate(len(linear_feature_columns)),
|
||||
optimizer=_get_optimizer(linear_optimizer),
|
||||
clip_gradients=gradient_clip_norm,
|
||||
variables=ops.get_collection(linear_parent_scope),
|
||||
name=linear_parent_scope,
|
||||
# Empty summaries, because head already logs "loss" summary.
|
||||
summaries=[]))
|
||||
|
||||
return control_flow_ops.group(*train_ops)
|
||||
|
||||
return head.head_ops(
|
||||
features, labels, mode, _make_training_op, logits=logits)
|
||||
|
||||
|
||||
class DNNLinearCombinedClassifier(evaluable.Evaluable, trainable.Trainable):
|
||||
"""A classifier for TensorFlow Linear and DNN joined training models.
|
||||
|
||||
Example:
|
||||
@ -423,30 +663,71 @@ class DNNLinearCombinedClassifier(_DNNLinearCombinedBaseEstimator):
|
||||
ValueError: If both `linear_feature_columns` and `dnn_features_columns`
|
||||
are empty at the same time.
|
||||
"""
|
||||
|
||||
if n_classes < 2:
|
||||
raise ValueError("n_classes should be greater than 1. Given: {}".format(
|
||||
n_classes))
|
||||
self._linear_optimizer = linear_optimizer or "Ftrl"
|
||||
linear_feature_columns = linear_feature_columns or []
|
||||
dnn_feature_columns = dnn_feature_columns or []
|
||||
self._feature_columns = linear_feature_columns + dnn_feature_columns
|
||||
if not self._feature_columns:
|
||||
raise ValueError("Either linear_feature_columns or dnn_feature_columns "
|
||||
"must be defined.")
|
||||
self._dnn_hidden_units = dnn_hidden_units
|
||||
self._enable_centered_bias = enable_centered_bias
|
||||
|
||||
head = head_lib._multi_class_head( # pylint: disable=protected-access
|
||||
n_classes=n_classes,
|
||||
weight_column_name=weight_column_name,
|
||||
enable_centered_bias=enable_centered_bias)
|
||||
super(DNNLinearCombinedClassifier, self).__init__(
|
||||
self._estimator = estimator.Estimator(
|
||||
model_fn=_dnn_linear_combined_model_fn,
|
||||
model_dir=model_dir,
|
||||
linear_feature_columns=linear_feature_columns,
|
||||
linear_optimizer=linear_optimizer,
|
||||
_joint_linear_weights=_joint_linear_weights,
|
||||
dnn_feature_columns=dnn_feature_columns,
|
||||
dnn_optimizer=dnn_optimizer,
|
||||
dnn_hidden_units=dnn_hidden_units,
|
||||
dnn_activation_fn=dnn_activation_fn,
|
||||
dnn_dropout=dnn_dropout,
|
||||
gradient_clip_norm=gradient_clip_norm,
|
||||
head=head,
|
||||
config=config,
|
||||
feature_engineering_fn=feature_engineering_fn,
|
||||
default_prediction_key=head_lib.PredictionKey.CLASSES,
|
||||
enable_centered_bias=enable_centered_bias)
|
||||
params={
|
||||
"head": head,
|
||||
"linear_feature_columns": linear_feature_columns,
|
||||
"linear_optimizer": self._linear_optimizer,
|
||||
"joint_linear_weights": _joint_linear_weights,
|
||||
"dnn_feature_columns": dnn_feature_columns,
|
||||
"dnn_optimizer": dnn_optimizer or "Adagrad",
|
||||
"dnn_hidden_units": dnn_hidden_units,
|
||||
"dnn_activation_fn": dnn_activation_fn,
|
||||
"dnn_dropout": dnn_dropout,
|
||||
"gradient_clip_norm": gradient_clip_norm,
|
||||
"num_ps_replicas": config.num_ps_replicas if config else 0,
|
||||
},
|
||||
feature_engineering_fn=feature_engineering_fn)
|
||||
|
||||
def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
|
||||
monitors=None, max_steps=None):
|
||||
"""See trainable.Trainable."""
|
||||
# TODO(roumposg): Remove when deprecated monitors are removed.
|
||||
if monitors is not None:
|
||||
deprecated_monitors = [
|
||||
m for m in monitors
|
||||
if not isinstance(m, session_run_hook.SessionRunHook)
|
||||
]
|
||||
for monitor in deprecated_monitors:
|
||||
monitor.set_estimator(self)
|
||||
monitor._lock_estimator() # pylint: disable=protected-access
|
||||
|
||||
result = self._estimator.fit(x=x, y=y, input_fn=input_fn, steps=steps,
|
||||
batch_size=batch_size, monitors=monitors,
|
||||
max_steps=max_steps)
|
||||
|
||||
if monitors is not None:
|
||||
for monitor in deprecated_monitors:
|
||||
monitor._unlock_estimator() # pylint: disable=protected-access
|
||||
|
||||
return result
|
||||
|
||||
def evaluate(self, x=None, y=None, input_fn=None, feed_fn=None,
|
||||
batch_size=None, steps=None, metrics=None, name=None):
|
||||
"""See evaluable.Evaluable."""
|
||||
return self._estimator.evaluate(
|
||||
x=x, y=y, input_fn=input_fn, feed_fn=feed_fn, batch_size=batch_size,
|
||||
steps=steps, metrics=metrics, name=name)
|
||||
|
||||
@deprecated_arg_values(
|
||||
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
|
||||
@ -467,12 +748,16 @@ class DNNLinearCombinedClassifier(_DNNLinearCombinedBaseEstimator):
|
||||
Numpy array of predicted classes (or an iterable of predicted classes if
|
||||
as_iterable is True).
|
||||
"""
|
||||
predictions = self.predict_proba(
|
||||
x=x, input_fn=input_fn, batch_size=batch_size, as_iterable=as_iterable)
|
||||
key = prediction_key.PredictionKey.CLASSES
|
||||
preds = self._estimator.predict(
|
||||
x=x,
|
||||
input_fn=input_fn,
|
||||
batch_size=batch_size,
|
||||
outputs=[key],
|
||||
as_iterable=as_iterable)
|
||||
if as_iterable:
|
||||
return (np.argmax(p, axis=0) for p in predictions)
|
||||
else:
|
||||
return np.argmax(predictions, axis=1)
|
||||
return _as_iterable(preds, output=key)
|
||||
return preds[key].reshape(-1)
|
||||
|
||||
@deprecated_arg_values(
|
||||
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
|
||||
@ -494,13 +779,133 @@ class DNNLinearCombinedClassifier(_DNNLinearCombinedBaseEstimator):
|
||||
Numpy array of predicted probabilities (or an iterable of predicted
|
||||
probabilities if as_iterable is True).
|
||||
"""
|
||||
return super(DNNLinearCombinedClassifier, self).predict(
|
||||
x=x, input_fn=input_fn, batch_size=batch_size, as_iterable=as_iterable)
|
||||
key = prediction_key.PredictionKey.PROBABILITIES
|
||||
preds = self._estimator.predict(
|
||||
x=x,
|
||||
input_fn=input_fn,
|
||||
batch_size=batch_size,
|
||||
outputs=[key],
|
||||
as_iterable=as_iterable)
|
||||
if as_iterable:
|
||||
return _as_iterable(preds, output=key)
|
||||
return preds[key]
|
||||
|
||||
def _get_predict_ops(self, features):
|
||||
"""See base class."""
|
||||
return super(DNNLinearCombinedClassifier, self)._get_predict_ops(features)[
|
||||
head_lib.PredictionKey.PROBABILITIES]
|
||||
"""See `Estimator` class."""
|
||||
# pylint: disable=protected-access
|
||||
return self._estimator._get_predict_ops(features)[
|
||||
prediction_key.PredictionKey.PROBABILITIES]
|
||||
|
||||
def get_variable_names(self):
|
||||
"""Returns list of all variable names in this model.
|
||||
|
||||
Returns:
|
||||
List of names.
|
||||
"""
|
||||
return self._estimator.get_variable_names()
|
||||
|
||||
def get_variable_value(self, name):
|
||||
"""Returns value of the variable given by name.
|
||||
|
||||
Args:
|
||||
name: string, name of the tensor.
|
||||
|
||||
Returns:
|
||||
`Tensor` object.
|
||||
"""
|
||||
return self._estimator.get_variable_value(name)
|
||||
|
||||
def export(self,
|
||||
export_dir,
|
||||
input_fn=None,
|
||||
input_feature_key=None,
|
||||
use_deprecated_input_fn=True,
|
||||
signature_fn=None,
|
||||
default_batch_size=1,
|
||||
exports_to_keep=None):
|
||||
"""See BasEstimator.export."""
|
||||
def default_input_fn(unused_estimator, examples):
|
||||
return layers.parse_feature_columns_from_examples(
|
||||
examples, self._feature_columns)
|
||||
self._estimator.export(
|
||||
export_dir=export_dir,
|
||||
input_fn=input_fn or default_input_fn,
|
||||
input_feature_key=input_feature_key,
|
||||
use_deprecated_input_fn=use_deprecated_input_fn,
|
||||
signature_fn=(signature_fn or
|
||||
export.classification_signature_fn_with_prob),
|
||||
prediction_key=prediction_key.PredictionKey.PROBABILITIES,
|
||||
default_batch_size=default_batch_size,
|
||||
exports_to_keep=exports_to_keep)
|
||||
|
||||
@property
|
||||
def model_dir(self):
|
||||
return self._estimator.model_dir
|
||||
|
||||
@property
|
||||
@deprecated("2016-10-30",
|
||||
"This method will be removed after the deprecation date. "
|
||||
"To inspect variables, use get_variable_names() and "
|
||||
"get_variable_value().")
|
||||
def dnn_weights_(self):
|
||||
hiddenlayer_weights = [
|
||||
self.get_variable_value("dnn/hiddenlayer_%d/weights" % i)
|
||||
for i, _ in enumerate(self._dnn_hidden_units)
|
||||
]
|
||||
logits_weights = [self.get_variable_value("dnn/logits/weights")]
|
||||
return hiddenlayer_weights + logits_weights
|
||||
|
||||
@property
|
||||
@deprecated("2016-10-30",
|
||||
"This method will be removed after the deprecation date. "
|
||||
"To inspect variables, use get_variable_names() and "
|
||||
"get_variable_value().")
|
||||
def linear_weights_(self):
|
||||
values = {}
|
||||
if isinstance(self._linear_optimizer, str):
|
||||
optimizer_name = self._linear_optimizer
|
||||
else:
|
||||
optimizer_name = self._linear_optimizer.get_name()
|
||||
optimizer_regex = r".*/"+optimizer_name + r"(_\d)?$"
|
||||
for name in self.get_variable_names():
|
||||
if (name.startswith("linear/") and
|
||||
name != "linear/bias_weight" and
|
||||
name != "linear/learning_rate" and
|
||||
not re.match(optimizer_regex, name)):
|
||||
values[name] = self.get_variable_value(name)
|
||||
if len(values) == 1:
|
||||
return values[list(values.keys())[0]]
|
||||
return values
|
||||
|
||||
@property
|
||||
@deprecated("2016-10-30",
|
||||
"This method will be removed after the deprecation date. "
|
||||
"To inspect variables, use get_variable_names() and "
|
||||
"get_variable_value().")
|
||||
def dnn_bias_(self):
|
||||
hiddenlayer_bias = [self.get_variable_value("dnn/hiddenlayer_%d/biases" % i)
|
||||
for i, _ in enumerate(self._dnn_hidden_units)]
|
||||
logits_bias = [self.get_variable_value("dnn/logits/biases")]
|
||||
if not self._enable_centered_bias:
|
||||
return hiddenlayer_bias + logits_bias
|
||||
centered_bias = [self.get_variable_value(_CENTERED_BIAS_WEIGHT)]
|
||||
return hiddenlayer_bias + logits_bias + centered_bias
|
||||
|
||||
@property
|
||||
@deprecated("2016-10-30",
|
||||
"This method will be removed after the deprecation date. "
|
||||
"To inspect variables, use get_variable_names() and "
|
||||
"get_variable_value().")
|
||||
def linear_bias_(self):
|
||||
linear_bias = self.get_variable_value("linear/bias_weight")
|
||||
if not self._enable_centered_bias:
|
||||
return linear_bias
|
||||
centered_bias = [self.get_variable_value(_CENTERED_BIAS_WEIGHT)]
|
||||
return linear_bias + centered_bias
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return self._estimator.config
|
||||
|
||||
|
||||
class DNNLinearCombinedRegressor(_DNNLinearCombinedBaseEstimator):
|
||||
@ -642,12 +1047,11 @@ class DNNLinearCombinedRegressor(_DNNLinearCombinedBaseEstimator):
|
||||
head=head,
|
||||
config=config,
|
||||
feature_engineering_fn=feature_engineering_fn,
|
||||
default_prediction_key=head_lib.PredictionKey.SCORES,
|
||||
default_prediction_key=prediction_key.PredictionKey.SCORES,
|
||||
enable_centered_bias=enable_centered_bias)
|
||||
|
||||
def _get_predict_ops(self, features):
|
||||
"""See base class."""
|
||||
return super(DNNLinearCombinedRegressor, self)._get_predict_ops(features)[
|
||||
head_lib.PredictionKey.SCORES]
|
||||
|
||||
|
||||
return super(
|
||||
DNNLinearCombinedRegressor,
|
||||
self)._get_predict_ops(features)[prediction_key.PredictionKey.SCORES]
|
||||
|
@ -27,6 +27,7 @@ import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.estimators import _sklearn
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils
|
||||
from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec
|
||||
|
||||
|
||||
def _get_quantile_based_buckets(feature_values, num_buckets):
|
||||
@ -65,6 +66,15 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
|
||||
estimator_test_utils.assert_estimator_contract(
|
||||
self, tf.contrib.learn.DNNLinearCombinedClassifier)
|
||||
|
||||
def testNoFeatureColumns(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
'Either linear_feature_columns or dnn_feature_columns must be defined'):
|
||||
tf.contrib.learn.DNNLinearCombinedClassifier(
|
||||
linear_feature_columns=None,
|
||||
dnn_feature_columns=None,
|
||||
dnn_hidden_units=[3, 3])
|
||||
|
||||
def testLogisticRegression_MatrixData(self):
|
||||
"""Tests binary classification using matrix data as input."""
|
||||
iris = _prepare_iris_data_for_logistic_regression()
|
||||
@ -80,6 +90,7 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
|
||||
|
||||
classifier.fit(input_fn=_iris_input_logistic_fn, steps=100)
|
||||
scores = classifier.evaluate(input_fn=_iris_input_logistic_fn, steps=100)
|
||||
self.assertIn('auc', scores.keys())
|
||||
self.assertGreater(scores['accuracy'], 0.9)
|
||||
|
||||
def testLogisticRegression_TensorData(self):
|
||||
@ -120,6 +131,7 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
|
||||
|
||||
classifier.fit(input_fn=_input_fn, steps=100)
|
||||
scores = classifier.evaluate(input_fn=_input_fn, steps=100)
|
||||
self.assertIn('auc', scores.keys())
|
||||
self.assertGreater(scores['accuracy'], 0.9)
|
||||
|
||||
def testTrainWithPartitionedVariables(self):
|
||||
@ -397,9 +409,15 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
|
||||
input_fn=_input_fn,
|
||||
steps=100,
|
||||
metrics={
|
||||
'my_accuracy': tf.contrib.metrics.streaming_accuracy,
|
||||
('my_precision', 'classes'): tf.contrib.metrics.streaming_precision,
|
||||
('my_metric', 'probabilities'): _my_metric_op
|
||||
'my_accuracy': MetricSpec(
|
||||
metric_fn=tf.contrib.metrics.streaming_accuracy,
|
||||
prediction_key='classes'),
|
||||
'my_precision': MetricSpec(
|
||||
metric_fn=tf.contrib.metrics.streaming_precision,
|
||||
prediction_key='classes'),
|
||||
'my_metric': MetricSpec(
|
||||
metric_fn=_my_metric_op,
|
||||
prediction_key='probabilities')
|
||||
})
|
||||
self.assertTrue(
|
||||
set(['loss', 'my_accuracy', 'my_precision', 'my_metric'
|
||||
@ -412,7 +430,7 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
|
||||
|
||||
# Test the case where the 2nd element of the key is neither "classes" nor
|
||||
# "probabilities".
|
||||
with self.assertRaises(KeyError):
|
||||
with self.assertRaisesRegexp(KeyError, 'bad_type'):
|
||||
classifier.evaluate(
|
||||
input_fn=_input_fn,
|
||||
steps=100,
|
||||
@ -428,6 +446,17 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
|
||||
tf.contrib.metrics.streaming_accuracy
|
||||
})
|
||||
|
||||
# Test the case where the prediction_key is neither "classes" nor
|
||||
# "probabilities".
|
||||
with self.assertRaisesRegexp(KeyError, 'bad_type'):
|
||||
classifier.evaluate(
|
||||
input_fn=_input_fn,
|
||||
steps=100,
|
||||
metrics={
|
||||
'bad_name': MetricSpec(
|
||||
metric_fn=tf.contrib.metrics.streaming_auc,
|
||||
prediction_key='bad_type')})
|
||||
|
||||
def testVariableQuery(self):
|
||||
"""Tests bias is centered or not."""
|
||||
def _input_fn_train():
|
||||
@ -447,6 +476,39 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
|
||||
for name in var_names:
|
||||
classifier.get_variable_value(name)
|
||||
|
||||
def testExport(self):
|
||||
"""Tests export model for servo."""
|
||||
|
||||
def input_fn():
|
||||
return {
|
||||
'age': tf.constant([1]),
|
||||
'language': tf.SparseTensor(values=['english'],
|
||||
indices=[[0, 0]],
|
||||
shape=[1, 1])
|
||||
}, tf.constant([[1]])
|
||||
|
||||
language = tf.contrib.layers.sparse_column_with_hash_bucket('language', 100)
|
||||
|
||||
classifier = tf.contrib.learn.DNNLinearCombinedClassifier(
|
||||
linear_feature_columns=[
|
||||
tf.contrib.layers.real_valued_column('age'),
|
||||
language,
|
||||
],
|
||||
dnn_feature_columns=[
|
||||
tf.contrib.layers.embedding_column(language, dimension=1),
|
||||
],
|
||||
dnn_hidden_units=[3, 3])
|
||||
classifier.fit(input_fn=input_fn, steps=100)
|
||||
|
||||
export_dir = tempfile.mkdtemp()
|
||||
input_feature_key = 'examples'
|
||||
def serving_input_fn():
|
||||
features, targets = input_fn()
|
||||
features[input_feature_key] = tf.placeholder(tf.string)
|
||||
return features, targets
|
||||
classifier.export(export_dir, serving_input_fn, input_feature_key,
|
||||
use_deprecated_input_fn=False)
|
||||
|
||||
def testCenteredBias(self):
|
||||
"""Tests bias is centered or not."""
|
||||
def _input_fn_train():
|
||||
@ -461,7 +523,7 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
|
||||
dnn_hidden_units=[3, 3],
|
||||
enable_centered_bias=True)
|
||||
|
||||
classifier.fit(input_fn=_input_fn_train, steps=500)
|
||||
classifier.fit(input_fn=_input_fn_train, steps=1000)
|
||||
# logodds(0.75) = 1.09861228867
|
||||
self.assertAlmostEqual(
|
||||
1.0986,
|
||||
@ -483,7 +545,7 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
|
||||
enable_centered_bias=False)
|
||||
|
||||
classifier.fit(input_fn=_input_fn_train, steps=500)
|
||||
self.assertFalse('centered_bias_weight' in classifier.get_variable_names())
|
||||
self.assertNotIn('centered_bias_weight', classifier.get_variable_names())
|
||||
|
||||
def testLinearOnly(self):
|
||||
"""Tests that linear-only instantiation works."""
|
||||
@ -822,6 +884,44 @@ class DNNLinearCombinedRegressorTest(tf.test.TestCase):
|
||||
metrics={('my_error', 'predictions'
|
||||
): tf.contrib.metrics.streaming_mean_squared_error})
|
||||
|
||||
def testExport(self):
|
||||
"""Tests export model for servo."""
|
||||
labels = [1., 0., 0.2]
|
||||
def _input_fn(num_epochs=None):
|
||||
features = {
|
||||
'age': tf.train.limit_epochs(tf.constant([[0.8], [0.15], [0.]]),
|
||||
num_epochs=num_epochs),
|
||||
'language': tf.SparseTensor(values=['en', 'fr', 'zh'],
|
||||
indices=[[0, 0], [0, 1], [2, 0]],
|
||||
shape=[3, 2])
|
||||
}
|
||||
return features, tf.constant(labels, dtype=tf.float32)
|
||||
|
||||
language_column = tf.contrib.layers.sparse_column_with_hash_bucket(
|
||||
'language', hash_bucket_size=20)
|
||||
|
||||
regressor = tf.contrib.learn.DNNLinearCombinedRegressor(
|
||||
linear_feature_columns=[
|
||||
language_column,
|
||||
tf.contrib.layers.real_valued_column('age')
|
||||
],
|
||||
dnn_feature_columns=[
|
||||
tf.contrib.layers.embedding_column(language_column, dimension=1),
|
||||
],
|
||||
dnn_hidden_units=[3, 3],
|
||||
config=tf.contrib.learn.RunConfig(tf_random_seed=1))
|
||||
|
||||
regressor.fit(input_fn=_input_fn, steps=100)
|
||||
|
||||
export_dir = tempfile.mkdtemp()
|
||||
input_feature_key = 'examples'
|
||||
def serving_input_fn():
|
||||
features, targets = _input_fn()
|
||||
features[input_feature_key] = tf.placeholder(tf.string)
|
||||
return features, targets
|
||||
regressor.export(export_dir, serving_input_fn, input_feature_key,
|
||||
use_deprecated_input_fn=False)
|
||||
|
||||
def testTrainSaveLoad(self):
|
||||
"""Tests regression with restarting training / evaluate."""
|
||||
def _input_fn(num_epochs=None):
|
||||
@ -1009,7 +1109,7 @@ class FeatureEngineeringFunctionTest(tf.test.TestCase):
|
||||
config=tf.contrib.learn.RunConfig(tf_random_seed=1))
|
||||
estimator_without_fe_fn.fit(input_fn=input_fn, steps=100)
|
||||
|
||||
# predictions = y
|
||||
# predictions = y
|
||||
prediction_with_fe_fn = next(
|
||||
estimator_with_fe_fn.predict(input_fn=input_fn, as_iterable=True))
|
||||
self.assertAlmostEqual(1000., prediction_with_fe_fn, delta=1.0)
|
||||
|
@ -45,6 +45,7 @@ from tensorflow.contrib.learn.python.learn import metric_spec
|
||||
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
|
||||
from tensorflow.contrib.learn.python.learn import trainable
|
||||
from tensorflow.contrib.learn.python.learn.estimators import _sklearn as sklearn
|
||||
from tensorflow.contrib.learn.python.learn.estimators import metric_key
|
||||
from tensorflow.contrib.learn.python.learn.estimators import run_config
|
||||
from tensorflow.contrib.learn.python.learn.estimators import tensor_signature
|
||||
from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
|
||||
@ -1108,8 +1109,9 @@ class Estimator(BaseEstimator):
|
||||
|
||||
result = _make_metrics_ops(all_metrics, features, labels,
|
||||
model_fn_ops.predictions)
|
||||
if 'loss' not in result:
|
||||
result['loss'] = metrics_lib.streaming_mean(model_fn_ops.loss)
|
||||
if metric_key.MetricKey.LOSS not in result:
|
||||
result[metric_key.MetricKey.LOSS] = metrics_lib.streaming_mean(
|
||||
model_fn_ops.loss)
|
||||
return result
|
||||
|
||||
def _get_predict_ops(self, features):
|
||||
|
@ -24,9 +24,12 @@ from tensorflow.contrib import losses
|
||||
from tensorflow.contrib import metrics as metrics_lib
|
||||
from tensorflow.contrib.learn.python.learn import metric_spec
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators import metric_key
|
||||
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
|
||||
from tensorflow.contrib.session_bundle import exporter
|
||||
from tensorflow.python import summary
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -387,17 +390,17 @@ class _RegressionHead(_Head):
|
||||
def _logits_to_prediction(self, logits=None):
|
||||
predictions = {}
|
||||
if self.logits_dimension == 1:
|
||||
predictions[PredictionKey.SCORES] = array_ops.squeeze(
|
||||
predictions[prediction_key.PredictionKey.SCORES] = array_ops.squeeze(
|
||||
logits, squeeze_dims=[1])
|
||||
else:
|
||||
predictions[PredictionKey.SCORES] = logits
|
||||
predictions[prediction_key.PredictionKey.SCORES] = logits
|
||||
return predictions
|
||||
|
||||
# pylint: disable=undefined-variable
|
||||
def _create_signature_fn(self):
|
||||
def _regression_signature_fn(examples, unused_features, predictions):
|
||||
if isinstance(predictions, dict):
|
||||
score = predictions[PredictionKey.SCORES]
|
||||
score = predictions[prediction_key.PredictionKey.SCORES]
|
||||
else:
|
||||
score = predictions
|
||||
|
||||
@ -408,11 +411,12 @@ class _RegressionHead(_Head):
|
||||
return _regression_signature_fn
|
||||
|
||||
def _default_metric(self):
|
||||
return {_head_prefixed(self._head_name, MetricKey.LOSS):
|
||||
_weighted_average_loss_metric_spec(self._eval_loss_fn,
|
||||
PredictionKey.SCORES,
|
||||
self._label_name,
|
||||
self._weight_column_name)}
|
||||
return {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
|
||||
_weighted_average_loss_metric_spec(
|
||||
self._eval_loss_fn,
|
||||
prediction_key.PredictionKey.SCORES,
|
||||
self._label_name,
|
||||
self._weight_column_name)}
|
||||
|
||||
|
||||
class _MultiClassHead(_Head):
|
||||
@ -529,12 +533,16 @@ class _MultiClassHead(_Head):
|
||||
return self._logits_to_prediction(logits)
|
||||
|
||||
def _logits_to_prediction(self, logits=None):
|
||||
predictions = {PredictionKey.LOGITS: logits}
|
||||
# pylint: disable=missing-docstring
|
||||
predictions = {prediction_key.PredictionKey.LOGITS: logits}
|
||||
if self.logits_dimension == 1:
|
||||
predictions[PredictionKey.LOGISTIC] = math_ops.sigmoid(logits)
|
||||
predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid(
|
||||
logits)
|
||||
logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
|
||||
predictions[PredictionKey.PROBABILITIES] = nn.softmax(logits)
|
||||
predictions[PredictionKey.CLASSES] = math_ops.argmax(logits, 1)
|
||||
predictions[prediction_key.PredictionKey.PROBABILITIES] = nn.softmax(
|
||||
logits)
|
||||
predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax(
|
||||
logits, 1)
|
||||
|
||||
return predictions
|
||||
|
||||
@ -545,8 +553,9 @@ class _MultiClassHead(_Head):
|
||||
if isinstance(predictions, dict):
|
||||
default_signature = exporter.classification_signature(
|
||||
input_tensor=examples,
|
||||
classes_tensor=predictions[PredictionKey.CLASSES],
|
||||
scores_tensor=predictions[PredictionKey.PROBABILITIES])
|
||||
classes_tensor=predictions[prediction_key.PredictionKey.CLASSES],
|
||||
scores_tensor=predictions[
|
||||
prediction_key.PredictionKey.PROBABILITIES])
|
||||
else:
|
||||
default_signature = exporter.classification_signature(
|
||||
input_tensor=examples,
|
||||
@ -557,44 +566,49 @@ class _MultiClassHead(_Head):
|
||||
return _classification_signature_fn
|
||||
|
||||
def _default_metric(self):
|
||||
metrics = {_head_prefixed(self._head_name, MetricKey.LOSS):
|
||||
_weighted_average_loss_metric_spec(self._eval_loss_fn,
|
||||
PredictionKey.LOGITS,
|
||||
self._label_name,
|
||||
self._weight_column_name)}
|
||||
metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
|
||||
_weighted_average_loss_metric_spec(
|
||||
self._eval_loss_fn,
|
||||
prediction_key.PredictionKey.LOGITS,
|
||||
self._label_name,
|
||||
self._weight_column_name)}
|
||||
|
||||
# TODO(b/29366811): This currently results in both an "accuracy" and an
|
||||
# "accuracy/threshold_0.500000_mean" metric for binary classification.
|
||||
metrics[_head_prefixed(self._head_name, MetricKey.ACCURACY)] = (
|
||||
metrics[_head_prefixed(self._head_name, metric_key.MetricKey.ACCURACY)] = (
|
||||
metric_spec.MetricSpec(metrics_lib.streaming_accuracy,
|
||||
PredictionKey.CLASSES, self._label_name,
|
||||
prediction_key.PredictionKey.CLASSES,
|
||||
self._label_name,
|
||||
self._weight_column_name))
|
||||
if self.logits_dimension == 1:
|
||||
def _add_binary_metric(metric_key, metric_fn):
|
||||
metrics[_head_prefixed(self._head_name, metric_key)] = (
|
||||
def _add_binary_metric(key, metric_fn):
|
||||
metrics[_head_prefixed(self._head_name, key)] = (
|
||||
metric_spec.MetricSpec(metric_fn,
|
||||
PredictionKey.LOGISTIC,
|
||||
prediction_key.PredictionKey.LOGISTIC,
|
||||
self._label_name,
|
||||
self._weight_column_name))
|
||||
_add_binary_metric(MetricKey.PREDICTION_MEAN, _predictions_streaming_mean)
|
||||
_add_binary_metric(MetricKey.LABEL_MEAN, _labels_streaming_mean)
|
||||
_add_binary_metric(
|
||||
metric_key.MetricKey.PREDICTION_MEAN, _predictions_streaming_mean)
|
||||
_add_binary_metric(
|
||||
metric_key.MetricKey.LABEL_MEAN, _labels_streaming_mean)
|
||||
|
||||
# Also include the streaming mean of the label as an accuracy baseline, as
|
||||
# a reminder to users.
|
||||
_add_binary_metric(MetricKey.ACCURACY_BASELINE, _labels_streaming_mean)
|
||||
_add_binary_metric(
|
||||
metric_key.MetricKey.ACCURACY_BASELINE, _labels_streaming_mean)
|
||||
|
||||
_add_binary_metric(MetricKey.AUC, _streaming_auc)
|
||||
_add_binary_metric(metric_key.MetricKey.AUC, _streaming_auc)
|
||||
|
||||
for threshold in self._thresholds:
|
||||
_add_binary_metric(MetricKey.ACCURACY_MEAN % threshold,
|
||||
_add_binary_metric(metric_key.MetricKey.ACCURACY_MEAN % threshold,
|
||||
_accuracy_at_threshold(threshold))
|
||||
# Precision for positive examples.
|
||||
_add_binary_metric(MetricKey.PRECISION_MEAN % threshold,
|
||||
_add_binary_metric(metric_key.MetricKey.PRECISION_MEAN % threshold,
|
||||
_streaming_at_threshold(
|
||||
metrics_lib.streaming_precision_at_thresholds,
|
||||
threshold),)
|
||||
# Recall for positive examples.
|
||||
_add_binary_metric(MetricKey.RECALL_MEAN % threshold,
|
||||
_add_binary_metric(metric_key.MetricKey.RECALL_MEAN % threshold,
|
||||
_streaming_at_threshold(
|
||||
metrics_lib.streaming_recall_at_thresholds,
|
||||
threshold))
|
||||
@ -603,7 +617,7 @@ class _MultiClassHead(_Head):
|
||||
|
||||
def _check_labels(labels, label_name):
|
||||
labels = labels[label_name] if isinstance(labels, dict) else labels
|
||||
if isinstance(labels, ops.SparseTensor):
|
||||
if isinstance(labels, sparse_tensor.SparseTensor):
|
||||
raise ValueError("SparseTensor is not supported as labels.")
|
||||
return labels
|
||||
|
||||
@ -634,21 +648,24 @@ class _BinarySvmHead(_MultiClassHead):
|
||||
|
||||
def _logits_to_prediction(self, logits=None):
|
||||
predictions = {}
|
||||
predictions[PredictionKey.LOGITS] = logits
|
||||
predictions[prediction_key.PredictionKey.LOGITS] = logits
|
||||
logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
|
||||
predictions[PredictionKey.CLASSES] = math_ops.argmax(logits, 1)
|
||||
predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax(
|
||||
logits, 1)
|
||||
|
||||
return predictions
|
||||
|
||||
def _default_metric(self):
|
||||
metrics = {_head_prefixed(self._head_name, MetricKey.LOSS):
|
||||
_weighted_average_loss_metric_spec(self._eval_loss_fn,
|
||||
PredictionKey.LOGITS,
|
||||
self._label_name,
|
||||
self._weight_column_name)}
|
||||
metrics[_head_prefixed(self._head_name, MetricKey.ACCURACY)] = (
|
||||
metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
|
||||
_weighted_average_loss_metric_spec(
|
||||
self._eval_loss_fn,
|
||||
prediction_key.PredictionKey.LOGITS,
|
||||
self._label_name,
|
||||
self._weight_column_name)}
|
||||
metrics[_head_prefixed(self._head_name, metric_key.MetricKey.ACCURACY)] = (
|
||||
metric_spec.MetricSpec(metrics_lib.streaming_accuracy,
|
||||
PredictionKey.CLASSES, self._label_name,
|
||||
prediction_key.PredictionKey.CLASSES,
|
||||
self._label_name,
|
||||
self._weight_column_name))
|
||||
# TODO(sibyl-vie3Poto): add more metrics relevant for svms.
|
||||
return metrics
|
||||
@ -673,12 +690,14 @@ class _MultiLabelHead(_MultiClassHead):
|
||||
thresholds=thresholds)
|
||||
|
||||
def _logits_to_prediction(self, logits=None):
|
||||
predictions = {PredictionKey.LOGITS: logits}
|
||||
predictions = {prediction_key.PredictionKey.LOGITS: logits}
|
||||
if self.logits_dimension == 1:
|
||||
predictions[PredictionKey.LOGISTIC] = math_ops.sigmoid(logits)
|
||||
predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid(
|
||||
logits)
|
||||
logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
|
||||
predictions[PredictionKey.PROBABILITIES] = math_ops.sigmoid(logits)
|
||||
predictions[PredictionKey.CLASSES] = math_ops.to_int64(
|
||||
predictions[prediction_key.PredictionKey.PROBABILITIES] = math_ops.sigmoid(
|
||||
logits)
|
||||
predictions[prediction_key.PredictionKey.CLASSES] = math_ops.to_int64(
|
||||
math_ops.greater(logits, 0))
|
||||
return predictions
|
||||
|
||||
@ -848,23 +867,3 @@ def _streaming_at_threshold(streaming_metrics_fn, threshold):
|
||||
return array_ops.squeeze(precision_tensor), update_op
|
||||
|
||||
return _streaming_metrics
|
||||
|
||||
|
||||
class PredictionKey(object):
|
||||
CLASSES = "classes"
|
||||
PROBABILITIES = "probabilities"
|
||||
LOGITS = "logits"
|
||||
LOGISTIC = "logistic"
|
||||
SCORES = "scores"
|
||||
|
||||
|
||||
class MetricKey(object):
|
||||
LOSS = "loss"
|
||||
AUC = "auc"
|
||||
PREDICTION_MEAN = "labels/prediction_mean"
|
||||
LABEL_MEAN = "labels/actual_label_mean"
|
||||
ACCURACY = "accuracy"
|
||||
ACCURACY_BASELINE = "accuracy/baseline_label_mean"
|
||||
ACCURACY_MEAN = "accuracy/threshold_%f_mean"
|
||||
PRECISION_MEAN = "precision/positive_threshold_%f_mean"
|
||||
RECALL_MEAN = "recall/positive_threshold_%f_mean"
|
||||
|
@ -74,7 +74,7 @@ class MultiClassModelHeadTest(tf.test.TestCase):
|
||||
model_fn_ops = head.head_ops({}, labels,
|
||||
tf.contrib.learn.ModeKeys.TRAIN,
|
||||
_noop_train_op, logits=logits)
|
||||
self.assertAlmostEqual(.81326163, sess.run(model_fn_ops.loss))
|
||||
self.assertAlmostEqual(0.81326175, sess.run(model_fn_ops.loss))
|
||||
|
||||
def testErrorInSparseTensorLabels(self):
|
||||
head = head_lib._multi_class_head(n_classes=2)
|
||||
|
@ -32,6 +32,7 @@ from tensorflow.contrib.learn.python.learn import evaluable
|
||||
from tensorflow.contrib.learn.python.learn import trainable
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
|
||||
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
|
||||
from tensorflow.contrib.learn.python.learn.utils import export
|
||||
from tensorflow.contrib.linear_optimizer.python import sdca_optimizer
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -267,21 +268,18 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
|
||||
Example:
|
||||
|
||||
```python
|
||||
education = sparse_column_with_hash_bucket(column_name="education",
|
||||
hash_bucket_size=1000)
|
||||
occupation = sparse_column_with_hash_bucket(column_name="occupation",
|
||||
hash_bucket_size=1000)
|
||||
sparse_column_a = sparse_column_with_hash_bucket(...)
|
||||
sparse_column_b = sparse_column_with_hash_bucket(...)
|
||||
|
||||
education_x_occupation = crossed_column(columns=[education, occupation],
|
||||
hash_bucket_size=10000)
|
||||
sparse_feature_a_x_sparse_feature_b = crossed_column(...)
|
||||
|
||||
# Estimator using the default optimizer.
|
||||
estimator = LinearClassifier(
|
||||
feature_columns=[occupation, education_x_occupation])
|
||||
feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b])
|
||||
|
||||
# Or estimator using the FTRL optimizer with regularization.
|
||||
estimator = LinearClassifier(
|
||||
feature_columns=[occupation, education_x_occupation],
|
||||
feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b],
|
||||
optimizer=tf.train.FtrlOptimizer(
|
||||
learning_rate=0.1,
|
||||
l1_regularization_strength=0.001
|
||||
@ -289,7 +287,7 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
|
||||
|
||||
# Or estimator using the SDCAOptimizer.
|
||||
estimator = LinearClassifier(
|
||||
feature_columns=[occupation, education_x_occupation],
|
||||
feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b],
|
||||
optimizer=tf.contrib.linear_optimizer.SDCAOptimizer(
|
||||
example_id_column='example_id',
|
||||
num_loss_partitions=...,
|
||||
@ -465,13 +463,16 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
|
||||
as_iterable=False)
|
||||
def predict(self, x=None, input_fn=None, batch_size=None, as_iterable=True):
|
||||
"""Runs inference to determine the predicted class."""
|
||||
preds = self._estimator.predict(x=x, input_fn=input_fn,
|
||||
batch_size=batch_size,
|
||||
outputs=[head_lib.PredictionKey.CLASSES],
|
||||
as_iterable=as_iterable)
|
||||
key = prediction_key.PredictionKey.CLASSES
|
||||
preds = self._estimator.predict(
|
||||
x=x,
|
||||
input_fn=input_fn,
|
||||
batch_size=batch_size,
|
||||
outputs=[key],
|
||||
as_iterable=as_iterable)
|
||||
if as_iterable:
|
||||
return _as_iterable(preds, output=head_lib.PredictionKey.CLASSES)
|
||||
return preds[head_lib.PredictionKey.CLASSES]
|
||||
return _as_iterable(preds, output=key)
|
||||
return preds[key]
|
||||
|
||||
@deprecated_arg_values(
|
||||
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
|
||||
@ -479,14 +480,16 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
|
||||
def predict_proba(self, x=None, input_fn=None, batch_size=None, outputs=None,
|
||||
as_iterable=True):
|
||||
"""Runs inference to determine the class probability predictions."""
|
||||
preds = self._estimator.predict(x=x, input_fn=input_fn,
|
||||
batch_size=batch_size,
|
||||
outputs=[
|
||||
head_lib.PredictionKey.PROBABILITIES],
|
||||
as_iterable=as_iterable)
|
||||
key = prediction_key.PredictionKey.PROBABILITIES
|
||||
preds = self._estimator.predict(
|
||||
x=x,
|
||||
input_fn=input_fn,
|
||||
batch_size=batch_size,
|
||||
outputs=[key],
|
||||
as_iterable=as_iterable)
|
||||
if as_iterable:
|
||||
return _as_iterable(preds, output=head_lib.PredictionKey.PROBABILITIES)
|
||||
return preds[head_lib.PredictionKey.PROBABILITIES]
|
||||
return _as_iterable(preds, output=key)
|
||||
return preds[key]
|
||||
|
||||
def get_variable_names(self):
|
||||
return self._estimator.get_variable_names()
|
||||
@ -512,9 +515,9 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
|
||||
input_fn=input_fn or default_input_fn,
|
||||
input_feature_key=input_feature_key,
|
||||
use_deprecated_input_fn=use_deprecated_input_fn,
|
||||
signature_fn=(
|
||||
signature_fn or export.classification_signature_fn_with_prob),
|
||||
prediction_key=head_lib.PredictionKey.PROBABILITIES,
|
||||
signature_fn=(signature_fn or
|
||||
export.classification_signature_fn_with_prob),
|
||||
prediction_key=prediction_key.PredictionKey.PROBABILITIES,
|
||||
default_batch_size=default_batch_size,
|
||||
exports_to_keep=exports_to_keep)
|
||||
|
||||
@ -561,16 +564,13 @@ class LinearRegressor(evaluable.Evaluable, trainable.Trainable):
|
||||
Example:
|
||||
|
||||
```python
|
||||
education = sparse_column_with_hash_bucket(column_name="education",
|
||||
hash_bucket_size=1000)
|
||||
occupation = sparse_column_with_hash_bucket(column_name="occupation",
|
||||
hash_bucket_size=1000)
|
||||
sparse_column_a = sparse_column_with_hash_bucket(...)
|
||||
sparse_column_b = sparse_column_with_hash_bucket(...)
|
||||
|
||||
education_x_occupation = crossed_column(columns=[education, occupation],
|
||||
hash_bucket_size=10000)
|
||||
sparse_feature_a_x_sparse_feature_b = crossed_column(...)
|
||||
|
||||
estimator = LinearRegressor(
|
||||
feature_columns=[occupation, education_x_occupation])
|
||||
feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b])
|
||||
|
||||
# Input builders
|
||||
def input_fn_train: # returns x, y
|
||||
@ -731,13 +731,16 @@ class LinearRegressor(evaluable.Evaluable, trainable.Trainable):
|
||||
as_iterable=False)
|
||||
def predict(self, x=None, input_fn=None, batch_size=None, as_iterable=True):
|
||||
"""Runs inference to determine the predicted class."""
|
||||
preds = self._estimator.predict(x=x, input_fn=input_fn,
|
||||
batch_size=batch_size,
|
||||
outputs=[head_lib.PredictionKey.SCORES],
|
||||
as_iterable=as_iterable)
|
||||
key = prediction_key.PredictionKey.SCORES
|
||||
preds = self._estimator.predict(
|
||||
x=x,
|
||||
input_fn=input_fn,
|
||||
batch_size=batch_size,
|
||||
outputs=[key],
|
||||
as_iterable=as_iterable)
|
||||
if as_iterable:
|
||||
return _as_iterable(preds, output=head_lib.PredictionKey.SCORES)
|
||||
return preds[head_lib.PredictionKey.SCORES]
|
||||
return _as_iterable(preds, output=key)
|
||||
return preds[key]
|
||||
|
||||
def get_variable_names(self):
|
||||
return self._estimator.get_variable_names()
|
||||
@ -764,7 +767,7 @@ class LinearRegressor(evaluable.Evaluable, trainable.Trainable):
|
||||
input_feature_key=input_feature_key,
|
||||
use_deprecated_input_fn=use_deprecated_input_fn,
|
||||
signature_fn=(signature_fn or export.regression_signature_fn),
|
||||
prediction_key=head_lib.PredictionKey.SCORES,
|
||||
prediction_key=prediction_key.PredictionKey.SCORES,
|
||||
default_batch_size=default_batch_size,
|
||||
exports_to_keep=exports_to_keep)
|
||||
|
||||
|
@ -0,0 +1,30 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Enum for metric keys."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
class MetricKey(object):
|
||||
LOSS = "loss"
|
||||
AUC = "auc"
|
||||
PREDICTION_MEAN = "labels/prediction_mean"
|
||||
LABEL_MEAN = "labels/actual_label_mean"
|
||||
ACCURACY = "accuracy"
|
||||
ACCURACY_BASELINE = "accuracy/baseline_label_mean"
|
||||
ACCURACY_MEAN = "accuracy/threshold_%f_mean"
|
||||
PRECISION_MEAN = "precision/positive_threshold_%f_mean"
|
||||
RECALL_MEAN = "recall/positive_threshold_%f_mean"
|
@ -0,0 +1,26 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Enum for model prediction keys."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
class PredictionKey(object):
|
||||
CLASSES = "classes"
|
||||
PROBABILITIES = "probabilities"
|
||||
LOGITS = "logits"
|
||||
LOGISTIC = "logistic"
|
||||
SCORES = "scores"
|
@ -30,6 +30,7 @@ from tensorflow.contrib.learn.python.learn import trainable
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
|
||||
from tensorflow.contrib.learn.python.learn.estimators import linear
|
||||
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
|
||||
from tensorflow.contrib.linear_optimizer.python import sdca_optimizer
|
||||
|
||||
|
||||
@ -188,13 +189,16 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
|
||||
as_iterable=False)
|
||||
def predict(self, x=None, input_fn=None, batch_size=None, as_iterable=True):
|
||||
"""Runs inference to determine the predicted class."""
|
||||
preds = self._estimator.predict(x=x, input_fn=input_fn,
|
||||
batch_size=batch_size,
|
||||
outputs=[head_lib.PredictionKey.CLASSES],
|
||||
as_iterable=as_iterable)
|
||||
key = prediction_key.PredictionKey.CLASSES
|
||||
preds = self._estimator.predict(
|
||||
x=x,
|
||||
input_fn=input_fn,
|
||||
batch_size=batch_size,
|
||||
outputs=[key],
|
||||
as_iterable=as_iterable)
|
||||
if as_iterable:
|
||||
return _as_iterable(preds, output=head_lib.PredictionKey.CLASSES)
|
||||
return preds[head_lib.PredictionKey.CLASSES]
|
||||
return _as_iterable(preds, output=key)
|
||||
return preds[key]
|
||||
|
||||
@deprecated_arg_values(
|
||||
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
|
||||
@ -202,14 +206,16 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
|
||||
def predict_proba(self, x=None, input_fn=None, batch_size=None, outputs=None,
|
||||
as_iterable=True):
|
||||
"""Runs inference to determine the class probability predictions."""
|
||||
preds = self._estimator.predict(x=x, input_fn=input_fn,
|
||||
batch_size=batch_size,
|
||||
outputs=[
|
||||
head_lib.PredictionKey.PROBABILITIES],
|
||||
as_iterable=as_iterable)
|
||||
key = prediction_key.PredictionKey.PROBABILITIES
|
||||
preds = self._estimator.predict(
|
||||
x=x,
|
||||
input_fn=input_fn,
|
||||
batch_size=batch_size,
|
||||
outputs=[key],
|
||||
as_iterable=as_iterable)
|
||||
if as_iterable:
|
||||
return _as_iterable(preds, output=head_lib.PredictionKey.PROBABILITIES)
|
||||
return preds[head_lib.PredictionKey.PROBABILITIES]
|
||||
return _as_iterable(preds, output=key)
|
||||
return preds[key]
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def get_variable_names(self):
|
||||
|
@ -22,7 +22,7 @@ from __future__ import print_function
|
||||
import collections
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -41,7 +41,7 @@ class TensorSignature(collections.namedtuple(
|
||||
"""
|
||||
|
||||
def __new__(cls, tensor):
|
||||
if isinstance(tensor, ops.SparseTensor):
|
||||
if isinstance(tensor, sparse_tensor.SparseTensor):
|
||||
return super(TensorSignature, cls).__new__(
|
||||
cls, dtype=tensor.values.dtype, shape=None, is_sparse=True)
|
||||
return super(TensorSignature, cls).__new__(
|
||||
|
@ -40,6 +40,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import data_flow_ops
|
||||
from tensorflow.python.ops import logging_ops
|
||||
from tensorflow.python.ops import resources
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import basic_session_run_hooks
|
||||
@ -77,7 +78,8 @@ def get_summary_writer(logdir):
|
||||
|
||||
|
||||
def _make_saver(graph, keep_checkpoint_max=5):
|
||||
vars_to_save = graph.get_collection(ops.GraphKeys.VARIABLES)
|
||||
vars_to_save = (graph.get_collection(ops.GraphKeys.VARIABLES) +
|
||||
graph.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS))
|
||||
if vars_to_save:
|
||||
return tf_saver.Saver(vars_to_save,
|
||||
sharded=True,
|
||||
@ -846,9 +848,11 @@ def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None):
|
||||
raise ValueError('feed_dicts is invalid: %s.' % feed_dicts)
|
||||
|
||||
graph = contrib_ops.get_graph_from_inputs(output_dict.values())
|
||||
|
||||
with graph.as_default() as g:
|
||||
with tf_session.Session('') as session:
|
||||
session.run(
|
||||
resources.initialize_resources(resources.shared_resources() +
|
||||
resources.local_resources()))
|
||||
if restore_checkpoint_path:
|
||||
_restore_from_checkpoint(session, g, restore_checkpoint_path)
|
||||
else:
|
||||
|
@ -28,6 +28,8 @@ from tensorflow.contrib.learn.python import learn
|
||||
from tensorflow.contrib.learn.python.learn.monitors import BaseMonitor
|
||||
from tensorflow.python.framework import meta_graph
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_ops
|
||||
from tensorflow.python.ops import resources
|
||||
from tensorflow.python.ops import variables
|
||||
|
||||
|
||||
@ -194,6 +196,19 @@ class GraphActionsTest(tf.test.TestCase):
|
||||
pass
|
||||
self.assertTrue(request_stop.called)
|
||||
|
||||
def test_run_feeds_iter_calls_resources_init(self):
|
||||
with tf.Graph().as_default() as g:
|
||||
in0, _, _ = self._build_inference_graph()
|
||||
handle = test_ops.stub_resource_handle_op(container='a', shared_name='b')
|
||||
resources.register_resource(
|
||||
handle=handle,
|
||||
create_op=test_ops.resource_create_op(handle),
|
||||
is_initialized_op=test_ops.resource_initialized_op(handle))
|
||||
|
||||
for _ in learn.graph_actions.run_feeds_iter({'in0': in0},
|
||||
feed_dicts=[{}]):
|
||||
self.assertTrue(test_ops.resource_initialized_op(handle).eval())
|
||||
|
||||
def test_infer_different_default_graph(self):
|
||||
with self.test_session():
|
||||
self._assert_ckpt(self._output_dir, False)
|
||||
|
@ -24,6 +24,7 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import data_flow_ops
|
||||
from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -645,7 +646,7 @@ def queue_parsed_features(parsed_features,
|
||||
# directly.
|
||||
for key in sorted(parsed_features.keys()):
|
||||
tensor = parsed_features[key]
|
||||
if isinstance(tensor, ops.SparseTensor):
|
||||
if isinstance(tensor, sparse_tensor.SparseTensor):
|
||||
tensors_mapping.append((key, True))
|
||||
tensors_to_enqueue.extend([tensor.indices, tensor.values, tensor.shape])
|
||||
else:
|
||||
@ -704,7 +705,7 @@ def queue_parsed_features(parsed_features,
|
||||
for key, is_sparse_tensor in tensors_mapping:
|
||||
if is_sparse_tensor:
|
||||
# Three tensors are (indices, values, shape).
|
||||
dequeued_parsed_features[key] = ops.SparseTensor(
|
||||
dequeued_parsed_features[key] = sparse_tensor.SparseTensor(
|
||||
dequeued_tensors[index], dequeued_tensors[index + 1],
|
||||
dequeued_tensors[index + 2])
|
||||
index += 3
|
||||
|
@ -542,7 +542,8 @@ class CheckpointSaverTest(tf.test.TestCase):
|
||||
self.assertEqual(1, tf.contrib.framework.load_variable(
|
||||
self.model_dir, self.global_step.name))
|
||||
|
||||
def test_save_secs_saves_periodically(self):
|
||||
# TODO(gunan): Reenable this test after b/32446874 is fixed.
|
||||
def disabled_test_save_secs_saves_periodically(self):
|
||||
with self.graph.as_default():
|
||||
monitor = learn.monitors.CheckpointSaver(
|
||||
self.model_dir, save_secs=2, scaffold=self.scaffold)
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_data_flow_ops
|
||||
@ -166,7 +167,7 @@ class InitializableLookupTableBase(LookupInterface):
|
||||
name = "%s_lookup_table_find" % self._name
|
||||
|
||||
key_tensor = keys
|
||||
if isinstance(keys, ops.SparseTensor):
|
||||
if isinstance(keys, sparse_tensor.SparseTensor):
|
||||
key_tensor = keys.values
|
||||
|
||||
if keys.dtype != self._key_dtype:
|
||||
@ -181,8 +182,8 @@ class InitializableLookupTableBase(LookupInterface):
|
||||
# pylint: enable=protected-access
|
||||
|
||||
values.set_shape(key_tensor.get_shape())
|
||||
if isinstance(keys, ops.SparseTensor):
|
||||
return ops.SparseTensor(keys.indices, values, keys.shape)
|
||||
if isinstance(keys, sparse_tensor.SparseTensor):
|
||||
return sparse_tensor.SparseTensor(keys.indices, values, keys.shape)
|
||||
else:
|
||||
return values
|
||||
|
||||
|
@ -18,6 +18,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import numpy as np
|
||||
import six
|
||||
import tensorflow as tf
|
||||
@ -296,7 +297,8 @@ class MutableHashTableOpTest(tf.test.TestCase):
|
||||
self.assertAllEqual([0, 1, 2], sorted_values)
|
||||
|
||||
def testSaveRestore(self):
|
||||
save_path = os.path.join(self.get_temp_dir(), "hash")
|
||||
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
|
||||
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
|
||||
|
||||
with self.test_session(graph=tf.Graph()) as sess:
|
||||
v0 = tf.Variable(10.0, name="v0")
|
||||
@ -867,7 +869,8 @@ class MutableDenseHashTableOpTest(tf.test.TestCase):
|
||||
[100, 0], [100, 0], [100, 0]], pairs)
|
||||
|
||||
def testSaveRestore(self):
|
||||
save_path = os.path.join(self.get_temp_dir(), "hash")
|
||||
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
|
||||
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
|
||||
|
||||
with self.test_session(graph=tf.Graph()) as sess:
|
||||
default_value = -1
|
||||
@ -922,7 +925,8 @@ class MutableDenseHashTableOpTest(tf.test.TestCase):
|
||||
self.assertAllEqual([-1, 0, 1, 2, -1], output.eval())
|
||||
|
||||
def testVectorSaveRestore(self):
|
||||
save_path = os.path.join(self.get_temp_dir(), "hash")
|
||||
save_dir = os.path.join(self.get_temp_dir(), "vector_save_restore")
|
||||
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
|
||||
|
||||
with self.test_session(graph=tf.Graph()) as sess:
|
||||
empty_key = tf.constant([11, 13], tf.int64)
|
||||
|
@ -1,7 +1,8 @@
|
||||
### TensorFlow Makefile
|
||||
|
||||
The recommended way to build TensorFlow from source is using the Bazel
|
||||
open-source build system. Sometimes this isn't possible.
|
||||
open-source build system. Sometimes this isn't possible. For example,
|
||||
if you are building for iOS, you currently need to use the Makefile.
|
||||
|
||||
- The build system may not have the RAM or processing power to support Bazel.
|
||||
- Bazel or its dependencies may not be available.
|
||||
|
@ -43,6 +43,13 @@ tensorflow/core/kernels/sequence_ops.cc
|
||||
tensorflow/core/kernels/sendrecv_ops.cc
|
||||
tensorflow/core/kernels/scatter_op.cc
|
||||
tensorflow/core/kernels/scatter_functor.cc
|
||||
tensorflow/core/kernels/scatter_nd_op_cpu_impl_0.cc
|
||||
tensorflow/core/kernels/scatter_nd_op_cpu_impl_1.cc
|
||||
tensorflow/core/kernels/scatter_nd_op_cpu_impl_2.cc
|
||||
tensorflow/core/kernels/scatter_nd_op_cpu_impl_3.cc
|
||||
tensorflow/core/kernels/scatter_nd_op_cpu_impl_4.cc
|
||||
tensorflow/core/kernels/scatter_nd_op_cpu_impl_5.cc
|
||||
tensorflow/core/kernels/scatter_nd_op.cc
|
||||
tensorflow/core/kernels/save_restore_tensor.cc
|
||||
tensorflow/core/kernels/save_restore_v2_ops.cc
|
||||
tensorflow/core/kernels/save_op.cc
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
from tensorflow.contrib.framework import tensor_util
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
@ -102,7 +103,7 @@ def confusion_matrix(predictions, labels, num_classes=None, dtype=dtypes.int32,
|
||||
indices = array_ops.transpose(array_ops.pack([predictions, labels]))
|
||||
values = (array_ops.ones_like(predictions, dtype)
|
||||
if weights is None else weights)
|
||||
cm_sparse = ops.SparseTensor(
|
||||
cm_sparse = sparse_tensor.SparseTensor(
|
||||
indices=indices, values=values, shape=math_ops.to_int64(shape))
|
||||
zero_matrix = array_ops.zeros(math_ops.to_int32(shape), dtype)
|
||||
|
||||
|
@ -29,6 +29,7 @@ from tensorflow.contrib.metrics.python.ops import confusion_matrix_ops
|
||||
from tensorflow.contrib.metrics.python.ops import set_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -762,7 +763,12 @@ def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
|
||||
computes the area under a discretized curve of precision versus recall values
|
||||
(computed using the aforementioned variables). The `num_thresholds` variable
|
||||
controls the degree of discretization with larger numbers of thresholds more
|
||||
closely approximating the true AUC.
|
||||
closely approximating the true AUC. The quality of the approximation may vary
|
||||
dramatically depending on `num_thresholds`.
|
||||
|
||||
For best results, `predictions` should be distributed approximately uniformly
|
||||
in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
|
||||
approximation may be poor if this is not the case.
|
||||
|
||||
For estimation of the metric over a stream of data, the function creates an
|
||||
`update_op` operation that updates these variables and returns the `auc`.
|
||||
@ -1601,7 +1607,8 @@ def num_relevant(labels, k):
|
||||
raise ValueError('Invalid k=%s.' % k)
|
||||
with ops.name_scope(None, 'num_relevant', (labels,)) as scope:
|
||||
# For SparseTensor, calculate separate count for each row.
|
||||
if isinstance(labels, (ops.SparseTensor, ops.SparseTensorValue)):
|
||||
if isinstance(
|
||||
labels, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
|
||||
labels_sizes = set_ops.set_size(labels)
|
||||
return math_ops.minimum(labels_sizes, k, name=scope)
|
||||
|
||||
@ -1637,9 +1644,9 @@ def expand_and_tile(tensor, multiple, dim=0, name=None):
|
||||
with ops.name_scope(
|
||||
name, 'expand_and_tile', (tensor, multiple, dim)) as scope:
|
||||
# Sparse.
|
||||
if isinstance(tensor, ops.SparseTensorValue):
|
||||
tensor = ops.SparseTensor.from_value(tensor)
|
||||
if isinstance(tensor, ops.SparseTensor):
|
||||
if isinstance(tensor, sparse_tensor.SparseTensorValue):
|
||||
tensor = sparse_tensor.SparseTensor.from_value(tensor)
|
||||
if isinstance(tensor, sparse_tensor.SparseTensor):
|
||||
if dim < 0:
|
||||
expand_dims = array_ops.reshape(
|
||||
array_ops.size(tensor.shape) + dim, [1])
|
||||
@ -1871,7 +1878,8 @@ def _select_class_id(ids, selected_id):
|
||||
`SparseTensor` of same dimensions as `ids`. This contains only the entries
|
||||
equal to `selected_id`.
|
||||
"""
|
||||
if isinstance(ids, (ops.SparseTensor, ops.SparseTensorValue)):
|
||||
if isinstance(
|
||||
ids, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
|
||||
return sparse_ops.sparse_retain(
|
||||
ids, math_ops.equal(ids.values, selected_id))
|
||||
|
||||
@ -1888,7 +1896,7 @@ def _select_class_id(ids, selected_id):
|
||||
filled_selected_id = array_ops.fill(
|
||||
filled_selected_id_shape, math_ops.to_int64(selected_id))
|
||||
result = set_ops.set_intersection(filled_selected_id, ids)
|
||||
return ops.SparseTensor(
|
||||
return sparse_tensor.SparseTensor(
|
||||
indices=result.indices, values=result.values, shape=ids_shape)
|
||||
|
||||
|
||||
|
@ -23,6 +23,7 @@ from tensorflow.contrib.util import loader
|
||||
from tensorflow.python.framework import common_shapes
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.platform import resource_loader
|
||||
|
||||
|
||||
@ -54,7 +55,7 @@ def set_size(a, validate_indices=True):
|
||||
TypeError: If `a` is an invalid types.
|
||||
"""
|
||||
a = tensor_util.convert_to_tensor_or_sparse_tensor(a, name="a")
|
||||
if not isinstance(a, ops.SparseTensor):
|
||||
if not isinstance(a, sparse_tensor.SparseTensor):
|
||||
raise TypeError("Expected `SparseTensor`, got %s." % a)
|
||||
if a.values.dtype.base_dtype not in _VALID_DTYPES:
|
||||
raise TypeError("Invalid dtype %s." % a.values.dtype)
|
||||
@ -106,22 +107,22 @@ def _set_operation(a, b, set_operation, validate_indices=True):
|
||||
if b.dtype.base_dtype != a.dtype.base_dtype:
|
||||
raise TypeError("Types don't match, %s vs %s." % (a.dtype, b.dtype))
|
||||
# pylint: disable=protected-access
|
||||
if isinstance(a, ops.SparseTensor):
|
||||
if isinstance(b, ops.SparseTensor):
|
||||
if isinstance(a, sparse_tensor.SparseTensor):
|
||||
if isinstance(b, sparse_tensor.SparseTensor):
|
||||
indices, values, shape = _set_ops.sparse_to_sparse_set_operation(
|
||||
a.indices, a.values, a.shape, b.indices, b.values, b.shape,
|
||||
set_operation, validate_indices)
|
||||
else:
|
||||
raise ValueError("Sparse,Dense is not supported, but Dense,Sparse is. "
|
||||
"Please flip the order of your inputs.")
|
||||
elif isinstance(b, ops.SparseTensor):
|
||||
elif isinstance(b, sparse_tensor.SparseTensor):
|
||||
indices, values, shape = _set_ops.dense_to_sparse_set_operation(
|
||||
a, b.indices, b.values, b.shape, set_operation, validate_indices)
|
||||
else:
|
||||
indices, values, shape = _set_ops.dense_to_dense_set_operation(
|
||||
a, b, set_operation, validate_indices)
|
||||
# pylint: enable=protected-access
|
||||
return ops.SparseTensor(indices, values, shape)
|
||||
return sparse_tensor.SparseTensor(indices, values, shape)
|
||||
|
||||
|
||||
def set_intersection(a, b, validate_indices=True):
|
||||
|
@ -18,6 +18,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os.path
|
||||
import tempfile
|
||||
|
||||
import six
|
||||
import tensorflow as tf
|
||||
@ -40,7 +41,9 @@ class MovingAverageOptimizerTest(tf.test.TestCase):
|
||||
tf.train.GradientDescentOptimizer(learning_rate=2.0),
|
||||
average_decay=0.5,
|
||||
sequential_update=sequential_update)
|
||||
save_path = os.path.join(self.get_temp_dir(), 'model')
|
||||
save_dir = tempfile.mkdtemp(
|
||||
prefix=os.path.join(self.get_temp_dir(), 'run_1'))
|
||||
save_path = os.path.join(save_dir, 'model')
|
||||
update = opt.apply_gradients(
|
||||
list(six.moves.zip([grads0, grads1], [var0, var1])))
|
||||
train_saver = opt.swapping_saver()
|
||||
|
@ -39,7 +39,7 @@ INCLUDES := \
|
||||
-I/usr/local/include \
|
||||
-I. \
|
||||
-I$(DOWNLOADSDIR) \
|
||||
-I$(DOWNLOADSDIR)/eigen-latest/ \
|
||||
-I$(DOWNLOADSDIR)/eigen/ \
|
||||
-I$(PROTOGENDIR) \
|
||||
-I$(PBTGENDIR)
|
||||
LIBS := \
|
||||
|
@ -39,7 +39,7 @@ INCLUDES := \
|
||||
-I/usr/local/include \
|
||||
-I. \
|
||||
-I$(DOWNLOADSDIR) \
|
||||
-I$(DOWNLOADSDIR)/eigen-latest/ \
|
||||
-I$(DOWNLOADSDIR)/eigen/ \
|
||||
-I$(PROTOGENDIR) \
|
||||
-I$(PBTGENDIR)
|
||||
LIBS := \
|
||||
|
@ -46,6 +46,7 @@ py_test(
|
||||
name = "learning_test",
|
||||
srcs = ["python/slim/learning_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["manual"],
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/contrib/slim",
|
||||
|
@ -27,7 +27,7 @@ import abc
|
||||
|
||||
from tensorflow.contrib.slim.python.slim.data import data_decoder
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import image_ops
|
||||
@ -189,11 +189,11 @@ class Tensor(ItemHandler):
|
||||
shape_dims = []
|
||||
for k in self._shape_keys:
|
||||
shape_dim = keys_to_tensors[k]
|
||||
if isinstance(shape_dim, ops.SparseTensor):
|
||||
if isinstance(shape_dim, sparse_tensor.SparseTensor):
|
||||
shape_dim = sparse_ops.sparse_tensor_to_dense(shape_dim)
|
||||
shape_dims.append(shape_dim)
|
||||
shape = array_ops.reshape(array_ops.pack(shape_dims), [-1])
|
||||
if isinstance(tensor, ops.SparseTensor):
|
||||
if isinstance(tensor, sparse_tensor.SparseTensor):
|
||||
if shape is not None:
|
||||
tensor = sparse_ops.sparse_reshape(tensor, shape)
|
||||
tensor = sparse_ops.sparse_tensor_to_dense(tensor, self._default_value)
|
||||
@ -241,7 +241,7 @@ class SparseTensor(ItemHandler):
|
||||
values = keys_to_tensors[self._values_key]
|
||||
if self._shape_key:
|
||||
shape = keys_to_tensors[self._shape_key]
|
||||
if isinstance(shape, ops.SparseTensor):
|
||||
if isinstance(shape, sparse_tensor.SparseTensor):
|
||||
shape = sparse_ops.sparse_tensor_to_dense(shape)
|
||||
elif self._shape:
|
||||
shape = self._shape
|
||||
@ -255,7 +255,7 @@ class SparseTensor(ItemHandler):
|
||||
new_indices = array_ops.concat(1, [indices_columns_to_preserve,
|
||||
array_ops.reshape(ids, [-1, 1])])
|
||||
|
||||
tensor = ops.SparseTensor(new_indices, values.values, shape)
|
||||
tensor = sparse_tensor.SparseTensor(new_indices, values.values, shape)
|
||||
if self._densify:
|
||||
tensor = sparse_ops.sparse_tensor_to_dense(tensor, self._default_value)
|
||||
return tensor
|
||||
|
@ -132,11 +132,9 @@ REGISTER_OP("FinishedNodes")
|
||||
.Attr("num_split_after_samples: int")
|
||||
.Attr("min_split_samples: int")
|
||||
.Attr("dominate_fraction: float = 0.99")
|
||||
// TODO(thomaswc): Test out bootstrap on several datasets, confirm it
|
||||
// works well, make it the default.
|
||||
.Attr(
|
||||
"dominate_method:"
|
||||
" {'none', 'hoeffding', 'bootstrap', 'chebyshev'} = 'hoeffding'")
|
||||
" {'none', 'hoeffding', 'bootstrap', 'chebyshev'} = 'bootstrap'")
|
||||
.Attr("random_seed: int = 0")
|
||||
.Input("leaves: int32")
|
||||
.Input("node_to_accumulator: int32")
|
||||
|
@ -26,7 +26,7 @@ namespace tensorflow {
|
||||
|
||||
TEST(TrainingOpsTest, UpdateFertileSlots_ShapeFn) {
|
||||
ShapeInferenceTestOp op("UpdateFertileSlots");
|
||||
INFER_OK(op, "?;?;?;?;?;?;?", "[2,?];[2,?];[?];[?]");
|
||||
INFER_OK(op, "?;?;?;?;?;?;?;?", "[2,?];[2,?];[?];[?]");
|
||||
}
|
||||
|
||||
TEST(TrainingOpsTest, ScatterAddNdim_ShapeFn) {
|
||||
|
@ -55,23 +55,29 @@ T Sum(Tensor counts) {
|
||||
// is stored in index 0, individual feature types start at index 1.
|
||||
DataColumnTypes FeatureSpec(int32 input_feature, const Tensor& spec);
|
||||
|
||||
// Given an Eigen::Tensor type, calculate the Gini impurity, which we use
|
||||
// to determine the best split (lowest) and which nodes to allocate first
|
||||
// (highest).
|
||||
template<typename T>
|
||||
float WeightedGiniImpurity(const T& counts) {
|
||||
// Given an Eigen::Tensor type, calculate the Gini impurity.
|
||||
template <typename T>
|
||||
float RawWeightedGiniImpurity(const T& counts) {
|
||||
// Our split score is the Gini impurity times the number of examples
|
||||
// seen by the leaf. If c(i) denotes the i-th class count and c = sum_i c(i)
|
||||
// then
|
||||
// score = c * (1 - sum_i ( c(i) / c )^2 )
|
||||
// = c - sum_i c(i)^2 / c
|
||||
const auto smoothed = counts + counts.constant(1.0f);
|
||||
const auto sum = smoothed.sum();
|
||||
const auto sum2 = smoothed.square().sum();
|
||||
const auto sum = counts.sum();
|
||||
const auto sum2 = counts.square().sum();
|
||||
Eigen::Tensor<float, 0, Eigen::RowMajor> ret = sum - (sum2 / sum);
|
||||
return ret(0);
|
||||
}
|
||||
|
||||
// Given an Eigen::Tensor type, calculate the smoothed Gini impurity, which we
|
||||
// use to determine the best split (lowest) and which nodes to allocate first
|
||||
// (highest).
|
||||
template <typename T>
|
||||
float WeightedGiniImpurity(const T& counts) {
|
||||
const auto smoothed = counts + counts.constant(1.0f);
|
||||
return RawWeightedGiniImpurity(smoothed);
|
||||
}
|
||||
|
||||
template<typename T1, typename T2>
|
||||
float WeightedVariance(const T1& sums, const T2& squares, float count) {
|
||||
const auto e_x = sums / count;
|
||||
|
@ -48,6 +48,7 @@ REGISTER_OP("UpdateFertileSlots")
|
||||
.Input("accumulator_sums: float")
|
||||
.Input("node_to_accumulator: int32")
|
||||
.Input("stale_leaves: int32")
|
||||
.Input("node_sums: float")
|
||||
.Output("node_to_accumulator_map_updates: int32")
|
||||
.Output("accumulator_to_node_map_updates: int32")
|
||||
.Output("accumulators_cleared: int32")
|
||||
@ -84,6 +85,8 @@ node_to_accumulator: `node_to_accumulator[i]` is the accumulator slot used by
|
||||
fertile node i, or -1 if node i isn't fertile.
|
||||
stale_leaves:= A 1-d int32 tensor containing the indices of all leaves that
|
||||
have stopped accumulating statistics because they are too old.
|
||||
node_sums: `node_sums[n][c]` records how many
|
||||
training examples have class c and have ended up in node n.
|
||||
node_to_accumulator_map_updates:= A 2-d int32 tensor describing the changes
|
||||
that need to be applied to the node_to_accumulator map. Intended to be used
|
||||
with
|
||||
@ -121,6 +124,7 @@ class UpdateFertileSlots : public OpKernel {
|
||||
const Tensor& accumulator_sums = context->input(4);
|
||||
const Tensor& node_to_accumulator = context->input(5);
|
||||
const Tensor& stale_leaves = context->input(6);
|
||||
const Tensor& node_sums = context->input(7);
|
||||
|
||||
OP_REQUIRES(context, finished.shape().dims() == 1,
|
||||
errors::InvalidArgument(
|
||||
@ -204,6 +208,8 @@ class UpdateFertileSlots : public OpKernel {
|
||||
non_fertile_leaves, non_fertile_leaf_scores, eot, num_new_leaves,
|
||||
static_cast<int32>(accumulator_sums.shape().dim_size(1)), &leaf_heap);
|
||||
|
||||
const auto sums = node_sums.unaligned_flat<float>();
|
||||
const int32 num_columns = node_sums.shape().dim_size(1);
|
||||
// Allocate leaves.
|
||||
std::unique_ptr<HeapValuesType> values(
|
||||
leaf_heap.Extract());
|
||||
@ -218,6 +224,18 @@ class UpdateFertileSlots : public OpKernel {
|
||||
VLOG(1) << "No allocators left.";
|
||||
break;
|
||||
}
|
||||
// For classification, don't make a node fertile until it is unpure.
|
||||
if (!regression_) {
|
||||
// Add 1 here because index 0 contains the sum of the weights across
|
||||
// classes.
|
||||
Eigen::array<int, 1> offsets = {node.first * num_columns + 1};
|
||||
Eigen::array<int, 1> extents = {num_columns - 1};
|
||||
const auto node_counts = sums.slice(offsets, extents);
|
||||
// TODO(thomaswc): Implement a faster check for pure nodes.
|
||||
if (tensorforest::RawWeightedGiniImpurity(node_counts) == 0) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
VLOG(1) << "setting node " << node.first << " to accumulator "
|
||||
<< accumulator;
|
||||
++num_accumulators_allocated;
|
||||
|
@ -25,6 +25,7 @@ from tensorflow.python.framework import common_shapes
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import load_library
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
@ -77,7 +78,7 @@ def _ParseSparse(data):
|
||||
ValueError: If data contains non-string Tensors.
|
||||
"""
|
||||
for k in sorted(data.keys()):
|
||||
if not isinstance(data[k], ops.SparseTensor):
|
||||
if not isinstance(data[k], sparse_tensor.SparseTensor):
|
||||
raise NotImplementedError(
|
||||
'Features should be either all sparse or all dense. Use a '
|
||||
'feature engineering function to convert some of them.')
|
||||
@ -133,7 +134,7 @@ def ParseDataTensorOrDict(data):
|
||||
# If there's at least one sparse tensor, everything has to be sparse.
|
||||
is_sparse = False
|
||||
for v in data.values():
|
||||
if isinstance(v, ops.SparseTensor):
|
||||
if isinstance(v, sparse_tensor.SparseTensor):
|
||||
is_sparse = True
|
||||
break
|
||||
if is_sparse:
|
||||
@ -161,11 +162,11 @@ def ParseLabelTensorOrDict(labels):
|
||||
"""
|
||||
if isinstance(labels, dict):
|
||||
return math_ops.to_float(array_ops.concat(
|
||||
1, [sparse_ops.sparse_tensor_to_dense(labels[
|
||||
k], default_value=-1) if isinstance(labels, ops.SparseTensor) else
|
||||
labels[k] for k in sorted(labels.keys())]))
|
||||
1, [sparse_ops.sparse_tensor_to_dense(labels[k], default_value=-1)
|
||||
if isinstance(labels, sparse_tensor.SparseTensor)
|
||||
else labels[k] for k in sorted(labels.keys())]))
|
||||
else:
|
||||
if isinstance(labels, ops.SparseTensor):
|
||||
if isinstance(labels, sparse_tensor.SparseTensor):
|
||||
return math_ops.to_float(sparse_ops.sparse_tensor_to_dense(
|
||||
labels, default_value=-1))
|
||||
else:
|
||||
|
@ -40,6 +40,8 @@ class UpdateFertileSlotsTest(test_util.TensorFlowTestCase):
|
||||
self.total_counts = [[80., 40., 40.]]
|
||||
self.ops = training_ops.Load()
|
||||
self.stale_leaves = []
|
||||
self.node_sums = [[3, 1, 2], [4, 2, 2], [5, 2, 3], [6, 1, 5], [7, 5, 2],
|
||||
[8, 4, 4], [9, 7, 2]]
|
||||
|
||||
def testSimple(self):
|
||||
with self.test_session():
|
||||
@ -47,7 +49,7 @@ class UpdateFertileSlotsTest(test_util.TensorFlowTestCase):
|
||||
accumulators_allocated) = self.ops.update_fertile_slots(
|
||||
self.finished, self.non_fertile_leaves, self.non_fertile_leaf_scores,
|
||||
self.end_of_tree, self.total_counts, self.node_map,
|
||||
self.stale_leaves)
|
||||
self.stale_leaves, self.node_sums)
|
||||
|
||||
self.assertAllEqual([[2, 4], [-1, 0]], n2a_map_updates.eval())
|
||||
self.assertAllEqual([[0], [4]], a2n_map_updates.eval())
|
||||
@ -60,13 +62,27 @@ class UpdateFertileSlotsTest(test_util.TensorFlowTestCase):
|
||||
accumulators_allocated) = self.ops.update_fertile_slots(
|
||||
[], self.non_fertile_leaves, self.non_fertile_leaf_scores,
|
||||
self.end_of_tree, self.total_counts, self.node_map,
|
||||
self.stale_leaves)
|
||||
self.stale_leaves, self.node_sums)
|
||||
|
||||
self.assertAllEqual((2, 0), n2a_map_updates.eval().shape)
|
||||
self.assertAllEqual((2, 0), a2n_map_updates.eval().shape)
|
||||
self.assertAllEqual([], accumulators_cleared.eval())
|
||||
self.assertAllEqual([], accumulators_allocated.eval())
|
||||
|
||||
def testPureCounts(self):
|
||||
with self.test_session():
|
||||
self.node_sums[4] = [10, 0, 10]
|
||||
(n2a_map_updates, a2n_map_updates, accumulators_cleared,
|
||||
accumulators_allocated) = self.ops.update_fertile_slots(
|
||||
self.finished, self.non_fertile_leaves, self.non_fertile_leaf_scores,
|
||||
self.end_of_tree, self.total_counts, self.node_map,
|
||||
self.stale_leaves, self.node_sums)
|
||||
|
||||
self.assertAllEqual([[2, 3], [-1, 0]], n2a_map_updates.eval())
|
||||
self.assertAllEqual([[0], [3]], a2n_map_updates.eval())
|
||||
self.assertAllEqual([], accumulators_cleared.eval())
|
||||
self.assertAllEqual([0], accumulators_allocated.eval())
|
||||
|
||||
def testBadInput(self):
|
||||
del self.non_fertile_leaf_scores[-1]
|
||||
with self.test_session():
|
||||
@ -76,7 +92,7 @@ class UpdateFertileSlotsTest(test_util.TensorFlowTestCase):
|
||||
(n2a_map_updates, _, _, _) = self.ops.update_fertile_slots(
|
||||
self.finished, self.non_fertile_leaves,
|
||||
self.non_fertile_leaf_scores, self.end_of_tree, self.total_counts,
|
||||
self.node_map, self.stale_leaves)
|
||||
self.node_map, self.stale_leaves, self.node_sums)
|
||||
self.assertAllEqual((2, 0), n2a_map_updates.eval().shape)
|
||||
|
||||
|
||||
|
@ -29,6 +29,7 @@ from tensorflow.contrib.tensor_forest.python.ops import training_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
@ -629,7 +630,7 @@ class RandomTreeGraphs(object):
|
||||
sparse_indices = []
|
||||
sparse_values = []
|
||||
sparse_shape = []
|
||||
if isinstance(input_data, ops.SparseTensor):
|
||||
if isinstance(input_data, sparse_tensor.SparseTensor):
|
||||
sparse_indices = input_data.indices
|
||||
sparse_values = input_data.values
|
||||
sparse_shape = input_data.shape
|
||||
@ -780,6 +781,7 @@ class RandomTreeGraphs(object):
|
||||
self.variables.accumulator_sums,
|
||||
self.variables.node_to_accumulator_map,
|
||||
stale,
|
||||
self.variables.node_sums,
|
||||
regression=self.params.regression))
|
||||
|
||||
# Ensure end_of_tree doesn't get updated until UpdateFertileSlots has
|
||||
@ -881,7 +883,7 @@ class RandomTreeGraphs(object):
|
||||
sparse_indices = []
|
||||
sparse_values = []
|
||||
sparse_shape = []
|
||||
if isinstance(input_data, ops.SparseTensor):
|
||||
if isinstance(input_data, sparse_tensor.SparseTensor):
|
||||
sparse_indices = input_data.indices
|
||||
sparse_values = input_data.values
|
||||
sparse_shape = input_data.shape
|
||||
|
@ -15,9 +15,16 @@ py_library(
|
||||
"python/training/resample.py",
|
||||
"python/training/sampling_ops.py",
|
||||
"python/training/sequence_queueing_state_saver.py",
|
||||
"python/training/training.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
@ -37,6 +44,7 @@ py_test(
|
||||
size = "medium",
|
||||
srcs = ["python/training/batch_sequences_with_states_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["manual"],
|
||||
deps = [
|
||||
":training_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
@ -73,7 +81,10 @@ py_test(
|
||||
size = "small",
|
||||
srcs = ["python/training/sampling_ops_threading_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["notsan"],
|
||||
tags = [
|
||||
"manual",
|
||||
"notsan",
|
||||
],
|
||||
deps = [
|
||||
":training_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
@ -86,6 +97,20 @@ py_test(
|
||||
size = "medium",
|
||||
srcs = ["python/training/bucket_ops_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["manual"],
|
||||
deps = [
|
||||
":training_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "training_test",
|
||||
size = "large",
|
||||
srcs = ["python/training/training_test.py"],
|
||||
shard_count = 3,
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":training_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
|
@ -70,6 +70,11 @@ from tensorflow.contrib.training.python.training.bucket_ops import *
|
||||
from tensorflow.contrib.training.python.training.resample import *
|
||||
from tensorflow.contrib.training.python.training.sampling_ops import *
|
||||
from tensorflow.contrib.training.python.training.sequence_queueing_state_saver import *
|
||||
from tensorflow.contrib.training.python.training.training import add_gradients_summaries
|
||||
from tensorflow.contrib.training.python.training.training import clip_gradient_norms
|
||||
from tensorflow.contrib.training.python.training.training import create_train_op
|
||||
from tensorflow.contrib.training.python.training.training import multiply_gradients
|
||||
from tensorflow.contrib.training.python.training.training import train
|
||||
from tensorflow.python.util.all_util import make_all
|
||||
|
||||
__all__ = make_all(__name__)
|
||||
|
316
tensorflow/contrib/training/python/training/training.py
Normal file
316
tensorflow/contrib/training/python/training/training.py
Normal file
@ -0,0 +1,316 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Contains various routines and helper functions for training models.
|
||||
|
||||
TODO(nsilberman): Port documentation.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.framework.python.ops import variables
|
||||
from tensorflow.python import summary
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import clip_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import variables as tf_variables
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import basic_session_run_hooks
|
||||
from tensorflow.python.training import monitored_session
|
||||
from tensorflow.python.training import optimizer as tf_optimizer
|
||||
|
||||
# TODO(nsilberman): move add_gradients_summaries, clip_gradient_norms and
|
||||
# multiply_gradients into contrib/summaries and contrib/optimizers.py
|
||||
__all__ = [
|
||||
'add_gradients_summaries',
|
||||
'clip_gradient_norms',
|
||||
'create_train_op',
|
||||
'multiply_gradients',
|
||||
'train',
|
||||
]
|
||||
|
||||
|
||||
def add_gradients_summaries(grads_and_vars):
|
||||
"""Add summaries to gradients.
|
||||
|
||||
Args:
|
||||
grads_and_vars: A list of gradient to variable pairs (tuples).
|
||||
|
||||
Returns:
|
||||
The list of created summaries.
|
||||
"""
|
||||
summaries = []
|
||||
for grad, var in grads_and_vars:
|
||||
if grad is not None:
|
||||
if isinstance(grad, ops.IndexedSlices):
|
||||
grad_values = grad.values
|
||||
else:
|
||||
grad_values = grad
|
||||
summaries.append(summary.histogram_summary(
|
||||
var.op.name + ':gradient', grad_values))
|
||||
summaries.append(summary.histogram_summary(
|
||||
var.op.name + ':gradient_norm', clip_ops.global_norm([grad_values])))
|
||||
else:
|
||||
logging.info('Var %s has no gradient', var.op.name)
|
||||
|
||||
return summaries
|
||||
|
||||
|
||||
def clip_gradient_norms(gradients_to_variables, max_norm):
|
||||
"""Clips the gradients by the given value.
|
||||
|
||||
Args:
|
||||
gradients_to_variables: A list of gradient to variable pairs (tuples).
|
||||
max_norm: the maximum norm value.
|
||||
|
||||
Returns:
|
||||
A list of clipped gradient to variable pairs.
|
||||
"""
|
||||
clipped_grads_and_vars = []
|
||||
for grad, var in gradients_to_variables:
|
||||
if grad is not None:
|
||||
if isinstance(grad, ops.IndexedSlices):
|
||||
tmp = clip_ops.clip_by_norm(grad.values, max_norm)
|
||||
grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape)
|
||||
else:
|
||||
grad = clip_ops.clip_by_norm(grad, max_norm)
|
||||
clipped_grads_and_vars.append((grad, var))
|
||||
return clipped_grads_and_vars
|
||||
|
||||
|
||||
def multiply_gradients(grads_and_vars, gradient_multipliers):
|
||||
"""Multiply specified gradients.
|
||||
|
||||
Args:
|
||||
grads_and_vars: A list of gradient to variable pairs (tuples).
|
||||
gradient_multipliers: A map from either `Variables` or `Variable` op names
|
||||
to the coefficient by which the associated gradient should be scaled.
|
||||
|
||||
Returns:
|
||||
The updated list of gradient to variable pairs.
|
||||
|
||||
Raises:
|
||||
ValueError: If `grads_and_vars` is not a list or if `gradient_multipliers`
|
||||
is empty or None or if `gradient_multipliers` is not a dictionary.
|
||||
"""
|
||||
if not isinstance(grads_and_vars, list):
|
||||
raise ValueError('`grads_and_vars` must be a list.')
|
||||
if not gradient_multipliers:
|
||||
raise ValueError('`gradient_multipliers` is empty.')
|
||||
if not isinstance(gradient_multipliers, dict):
|
||||
raise ValueError('`gradient_multipliers` must be a dict.')
|
||||
|
||||
multiplied_grads_and_vars = []
|
||||
for grad, var in grads_and_vars:
|
||||
if var in gradient_multipliers or var.op.name in gradient_multipliers:
|
||||
key = var if var in gradient_multipliers else var.op.name
|
||||
if grad is None:
|
||||
raise ValueError('Requested multiple of `None` gradient.')
|
||||
|
||||
if isinstance(grad, ops.IndexedSlices):
|
||||
tmp = grad.values * constant_op.constant(
|
||||
gradient_multipliers[key], dtype=grad.dtype)
|
||||
grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape)
|
||||
else:
|
||||
grad *= constant_op.constant(
|
||||
gradient_multipliers[key], dtype=grad.dtype)
|
||||
multiplied_grads_and_vars.append((grad, var))
|
||||
return multiplied_grads_and_vars
|
||||
|
||||
|
||||
def create_train_op(total_loss,
|
||||
optimizer,
|
||||
global_step=None,
|
||||
update_ops=None,
|
||||
variables_to_train=None,
|
||||
transform_grads_fn=None,
|
||||
summarize_gradients=False,
|
||||
gate_gradients=tf_optimizer.Optimizer.GATE_OP,
|
||||
aggregation_method=None,
|
||||
colocate_gradients_with_ops=False):
|
||||
"""Creates an `Operation` that evaluates the gradients and returns the loss.
|
||||
|
||||
Args:
|
||||
total_loss: A `Tensor` representing the total loss.
|
||||
optimizer: A tf.Optimizer to use for computing the gradients.
|
||||
global_step: A `Tensor` representing the global step variable. If left as
|
||||
`None`, then slim.variables.global_step() is used.
|
||||
update_ops: An optional list of updates to execute. If `update_ops` is
|
||||
`None`, then the update ops are set to the contents of the
|
||||
`tf.GraphKeys.UPDATE_OPS` collection. If `update_ops` is not `None`, but
|
||||
it doesn't contain all of the update ops in `tf.GraphKeys.UPDATE_OPS`,
|
||||
a warning will be displayed.
|
||||
variables_to_train: an optional list of variables to train. If None, it will
|
||||
default to all tf.trainable_variables().
|
||||
transform_grads_fn: A function which takes a single argument, a list of
|
||||
gradient to variable pairs (tuples), performs any requested gradient
|
||||
updates, such as gradient clipping or multipliers, and returns the updated
|
||||
list.
|
||||
summarize_gradients: Whether or not add summaries for each gradient.
|
||||
gate_gradients: How to gate the computation of gradients. See tf.Optimizer.
|
||||
aggregation_method: Specifies the method used to combine gradient terms.
|
||||
Valid values are defined in the class `AggregationMethod`.
|
||||
colocate_gradients_with_ops: Whether or not to try colocating the gradients
|
||||
with the ops that generated them.
|
||||
|
||||
Returns:
|
||||
A `Tensor` that when evaluated, computes the gradients and returns the total
|
||||
loss value.
|
||||
"""
|
||||
if global_step is None:
|
||||
global_step = variables.get_or_create_global_step()
|
||||
|
||||
# Update ops use GraphKeys.UPDATE_OPS collection if update_ops is None.
|
||||
global_update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
|
||||
if update_ops is None:
|
||||
update_ops = global_update_ops
|
||||
else:
|
||||
update_ops = set(update_ops)
|
||||
if not global_update_ops.issubset(update_ops):
|
||||
logging.warning('update_ops in create_train_op does not contain all the '
|
||||
' update_ops in GraphKeys.UPDATE_OPS')
|
||||
|
||||
# Make sure update_ops are computed before total_loss.
|
||||
if update_ops:
|
||||
with ops.control_dependencies(update_ops):
|
||||
barrier = control_flow_ops.no_op(name='update_barrier')
|
||||
total_loss = control_flow_ops.with_dependencies([barrier], total_loss)
|
||||
|
||||
if variables_to_train is None:
|
||||
# Default to tf.trainable_variables()
|
||||
variables_to_train = tf_variables.trainable_variables()
|
||||
else:
|
||||
# Make sure that variables_to_train are in tf.trainable_variables()
|
||||
for v in variables_to_train:
|
||||
assert v in tf_variables.trainable_variables()
|
||||
|
||||
assert variables_to_train
|
||||
|
||||
# Create the gradients. Note that apply_gradients adds the gradient
|
||||
# computation to the current graph.
|
||||
grads = optimizer.compute_gradients(
|
||||
total_loss,
|
||||
variables_to_train,
|
||||
gate_gradients=gate_gradients,
|
||||
aggregation_method=aggregation_method,
|
||||
colocate_gradients_with_ops=colocate_gradients_with_ops)
|
||||
|
||||
if transform_grads_fn:
|
||||
grads = transform_grads_fn(grads)
|
||||
|
||||
# Summarize gradients.
|
||||
if summarize_gradients:
|
||||
with ops.name_scope('summarize_grads'):
|
||||
add_gradients_summaries(grads)
|
||||
|
||||
# Create gradient updates.
|
||||
grad_updates = optimizer.apply_gradients(grads, global_step=global_step)
|
||||
|
||||
with ops.name_scope('train_op'):
|
||||
# Make sure total_loss is valid.
|
||||
total_loss = array_ops.check_numerics(total_loss,
|
||||
'LossTensor is inf or nan')
|
||||
|
||||
# Ensure the train_tensor computes grad_updates.
|
||||
return control_flow_ops.with_dependencies([grad_updates], total_loss)
|
||||
|
||||
|
||||
def train(
|
||||
train_op,
|
||||
logdir,
|
||||
master='',
|
||||
is_chief=True,
|
||||
scaffold=None,
|
||||
hooks=None,
|
||||
chief_only_hooks=None,
|
||||
save_checkpoint_secs=600,
|
||||
save_summaries_steps=100,
|
||||
config=None):
|
||||
"""Runs the training loop.
|
||||
|
||||
Args:
|
||||
train_op: A `Tensor` that, when executed, will apply the gradients and
|
||||
return the loss value.
|
||||
logdir: The directory where the graph and checkpoints are saved.
|
||||
master: The URL of the master.
|
||||
is_chief: Specifies whether or not the training is being run by the primary
|
||||
replica during replica training.
|
||||
scaffold: An tf.train.Scaffold instance.
|
||||
hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
|
||||
training loop.
|
||||
chief_only_hooks: List of `tf.train.SessionRunHook` instances which are run
|
||||
inside the training loop for the chief trainer only.
|
||||
save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
|
||||
using a default checkpoint saver. If `save_checkpoint_secs` is set to
|
||||
`None`, then the default checkpoint saver isn't used.
|
||||
save_summaries_steps: The frequency, in number of global steps, that the
|
||||
summaries are written to disk using a default summary saver. If
|
||||
`save_summaries_steps` is set to `None`, then the default summary saver
|
||||
isn't used.
|
||||
config: An instance of `tf.ConfigProto`.
|
||||
|
||||
Returns:
|
||||
the value of the loss function after training.
|
||||
|
||||
Raises:
|
||||
ValueError: if `logdir` is `None` and either `save_checkpoint_secs` or
|
||||
`save_summaries_steps` are `None.
|
||||
"""
|
||||
# TODO(nsilberman): move this logic into monitored_session.py
|
||||
scaffold = scaffold or monitored_session.Scaffold()
|
||||
|
||||
hooks = hooks or []
|
||||
|
||||
if is_chief:
|
||||
session_creator = monitored_session.ChiefSessionCreator(
|
||||
scaffold=scaffold,
|
||||
checkpoint_dir=logdir,
|
||||
master=master,
|
||||
config=config)
|
||||
|
||||
if chief_only_hooks:
|
||||
hooks.extend(chief_only_hooks)
|
||||
|
||||
hooks.append(basic_session_run_hooks.StepCounterHook(
|
||||
output_dir=logdir))
|
||||
|
||||
if save_summaries_steps:
|
||||
if logdir is None:
|
||||
raise ValueError(
|
||||
'logdir cannot be None when save_summaries_steps is None')
|
||||
hooks.append(basic_session_run_hooks.SummarySaverHook(
|
||||
scaffold=scaffold,
|
||||
save_steps=save_summaries_steps,
|
||||
output_dir=logdir))
|
||||
|
||||
if save_checkpoint_secs:
|
||||
if logdir is None:
|
||||
raise ValueError(
|
||||
'logdir cannot be None when save_checkpoint_secs is None')
|
||||
hooks.append(basic_session_run_hooks.CheckpointSaverHook(
|
||||
logdir, save_secs=save_checkpoint_secs, scaffold=scaffold))
|
||||
else:
|
||||
session_creator = monitored_session.WorkerSessionCreator(
|
||||
scaffold=scaffold, master=master, config=config)
|
||||
|
||||
with monitored_session.MonitoredSession(
|
||||
session_creator=session_creator, hooks=hooks) as session:
|
||||
loss = None
|
||||
while not session.should_stop():
|
||||
loss = session.run(train_op)
|
||||
return loss
|
514
tensorflow/contrib/training/python/training/training_test.py
Normal file
514
tensorflow/contrib/training/python/training/training_test.py
Normal file
@ -0,0 +1,514 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tf.contrib.training.training."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def logistic_classifier(inputs):
|
||||
return tf.contrib.layers.fully_connected(
|
||||
inputs, 1, activation_fn=tf.sigmoid)
|
||||
|
||||
|
||||
def batchnorm_classifier(inputs):
|
||||
inputs = tf.contrib.layers.batch_norm(inputs, decay=0.1)
|
||||
return tf.contrib.layers.fully_connected(inputs, 1, activation_fn=tf.sigmoid)
|
||||
|
||||
|
||||
class CreateTrainOpTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
np.random.seed(0)
|
||||
|
||||
# Create an easy training set:
|
||||
self._inputs = np.random.rand(16, 4).astype(np.float32)
|
||||
self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
|
||||
|
||||
def testUseUpdateOps(self):
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
expected_mean = np.mean(self._inputs, axis=(0))
|
||||
expected_var = np.var(self._inputs, axis=(0))
|
||||
|
||||
tf_predictions = batchnorm_classifier(tf_inputs)
|
||||
tf.contrib.losses.log_loss(tf_predictions, tf_labels)
|
||||
total_loss = tf.contrib.losses.get_total_loss()
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(total_loss, optimizer)
|
||||
|
||||
moving_mean = tf.contrib.framework.get_variables_by_name('moving_mean')[0]
|
||||
moving_variance = tf.contrib.framework.get_variables_by_name(
|
||||
'moving_variance')[0]
|
||||
|
||||
with tf.Session() as sess:
|
||||
# Initialize all variables
|
||||
sess.run(tf.initialize_all_variables())
|
||||
mean, variance = sess.run([moving_mean, moving_variance])
|
||||
# After initialization moving_mean == 0 and moving_variance == 1.
|
||||
self.assertAllClose(mean, [0] * 4)
|
||||
self.assertAllClose(variance, [1] * 4)
|
||||
|
||||
for _ in range(10):
|
||||
sess.run([train_op])
|
||||
mean = moving_mean.eval()
|
||||
variance = moving_variance.eval()
|
||||
# After 10 updates with decay 0.1 moving_mean == expected_mean and
|
||||
# moving_variance == expected_var.
|
||||
self.assertAllClose(mean, expected_mean)
|
||||
self.assertAllClose(variance, expected_var)
|
||||
|
||||
def testEmptyUpdateOps(self):
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
tf_predictions = batchnorm_classifier(tf_inputs)
|
||||
tf.contrib.losses.log_loss(tf_predictions, tf_labels)
|
||||
total_loss = tf.contrib.losses.get_total_loss()
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(
|
||||
total_loss, optimizer, update_ops=[])
|
||||
|
||||
moving_mean = tf.contrib.framework.get_variables_by_name('moving_mean')[0]
|
||||
moving_variance = tf.contrib.framework.get_variables_by_name(
|
||||
'moving_variance')[0]
|
||||
|
||||
with tf.Session() as sess:
|
||||
# Initialize all variables
|
||||
sess.run(tf.initialize_all_variables())
|
||||
mean, variance = sess.run([moving_mean, moving_variance])
|
||||
# After initialization moving_mean == 0 and moving_variance == 1.
|
||||
self.assertAllClose(mean, [0] * 4)
|
||||
self.assertAllClose(variance, [1] * 4)
|
||||
|
||||
for _ in range(10):
|
||||
sess.run([train_op])
|
||||
mean = moving_mean.eval()
|
||||
variance = moving_variance.eval()
|
||||
|
||||
# Since we skip update_ops the moving_vars are not updated.
|
||||
self.assertAllClose(mean, [0] * 4)
|
||||
self.assertAllClose(variance, [1] * 4)
|
||||
|
||||
|
||||
class TrainBNClassifierTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Create an easy training set:
|
||||
np.random.seed(0)
|
||||
|
||||
self._inputs = np.zeros((16, 4))
|
||||
self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
|
||||
self._logdir = os.path.join(self.get_temp_dir(), 'tmp_bnlogs/')
|
||||
|
||||
for i in range(16):
|
||||
j = int(2 * self._labels[i] + np.random.randint(0, 2))
|
||||
self._inputs[i, j] = 1
|
||||
|
||||
def testTrainWithNoInitAssignCanAchieveZeroLoss(self):
|
||||
g = tf.Graph()
|
||||
with g.as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
tf_predictions = batchnorm_classifier(tf_inputs)
|
||||
tf.contrib.losses.log_loss(tf_predictions, tf_labels)
|
||||
total_loss = tf.contrib.losses.get_total_loss()
|
||||
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(
|
||||
total_loss, optimizer)
|
||||
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, self._logdir, hooks=[
|
||||
tf.train.StopAtStepHook(num_steps=300)
|
||||
])
|
||||
self.assertLess(loss, .1)
|
||||
|
||||
|
||||
class TrainTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Create an easy training set:
|
||||
np.random.seed(0)
|
||||
|
||||
self._inputs = np.zeros((16, 4))
|
||||
self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
|
||||
|
||||
for i in range(16):
|
||||
j = int(2 * self._labels[i] + np.random.randint(0, 2))
|
||||
self._inputs[i, j] = 1
|
||||
|
||||
def testCanAchieveZeroLoss(self):
|
||||
logdir = os.path.join(self.get_temp_dir(), 'can_achieve_zero_loss')
|
||||
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
tf_predictions = logistic_classifier(tf_inputs)
|
||||
tf.contrib.losses.log_loss(tf_predictions, tf_labels)
|
||||
total_loss = tf.contrib.losses.get_total_loss()
|
||||
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(total_loss, optimizer)
|
||||
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir, hooks=[
|
||||
tf.train.StopAtStepHook(num_steps=300)
|
||||
])
|
||||
self.assertIsNotNone(loss)
|
||||
self.assertLess(loss, .015)
|
||||
|
||||
def testTrainWithLocalVariable(self):
|
||||
logdir = os.path.join(self.get_temp_dir(), 'train_with_local_variable')
|
||||
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
local_multiplier = tf.contrib.framework.local_variable(1.0)
|
||||
|
||||
tf_predictions = logistic_classifier(tf_inputs) * local_multiplier
|
||||
tf.contrib.losses.log_loss(tf_predictions, tf_labels)
|
||||
total_loss = tf.contrib.losses.get_total_loss()
|
||||
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(
|
||||
total_loss, optimizer)
|
||||
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir, hooks=[
|
||||
tf.train.StopAtStepHook(num_steps=300)
|
||||
])
|
||||
self.assertIsNotNone(loss)
|
||||
self.assertLess(loss, .015)
|
||||
|
||||
def testResumeTrainAchievesRoughlyTheSameLoss(self):
|
||||
number_of_steps = [300, 1, 5]
|
||||
logdir = os.path.join(self.get_temp_dir(), 'resume_train_same_loss')
|
||||
|
||||
for i in range(len(number_of_steps)):
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(i)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
tf_predictions = logistic_classifier(tf_inputs)
|
||||
tf.contrib.losses.log_loss(tf_predictions, tf_labels)
|
||||
total_loss = tf.contrib.losses.get_total_loss()
|
||||
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(
|
||||
total_loss, optimizer)
|
||||
|
||||
saver = tf.train.Saver()
|
||||
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir, hooks=[
|
||||
tf.train.StopAtStepHook(num_steps=number_of_steps[i]),
|
||||
tf.train.CheckpointSaverHook(
|
||||
logdir, save_steps=50, saver=saver),
|
||||
])
|
||||
self.assertIsNotNone(loss)
|
||||
self.assertLess(loss, .015)
|
||||
|
||||
def create_train_op(self, learning_rate=1.0, gradient_multiplier=1.0):
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
tf_predictions = logistic_classifier(tf_inputs)
|
||||
tf.contrib.losses.log_loss(tf_predictions, tf_labels)
|
||||
total_loss = tf.contrib.losses.get_total_loss()
|
||||
|
||||
optimizer = tf.train.GradientDescentOptimizer(
|
||||
learning_rate=learning_rate)
|
||||
|
||||
def transform_grads_fn(grads):
|
||||
if gradient_multiplier != 1.0:
|
||||
variables = tf.trainable_variables()
|
||||
gradient_multipliers = {var: gradient_multiplier for var in variables}
|
||||
|
||||
with tf.name_scope('multiply_grads'):
|
||||
return tf.contrib.training.multiply_gradients(
|
||||
grads, gradient_multipliers)
|
||||
else:
|
||||
return grads
|
||||
|
||||
return tf.contrib.training.create_train_op(
|
||||
total_loss, optimizer, transform_grads_fn=transform_grads_fn)
|
||||
|
||||
def testTrainWithInitFromCheckpoint(self):
|
||||
logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs1/')
|
||||
logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs2/')
|
||||
|
||||
if tf.gfile.Exists(logdir1): # For running on jenkins.
|
||||
tf.gfile.DeleteRecursively(logdir1)
|
||||
if tf.gfile.Exists(logdir2): # For running on jenkins.
|
||||
tf.gfile.DeleteRecursively(logdir2)
|
||||
|
||||
# First, train the model one step (make sure the error is high).
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
train_op = self.create_train_op()
|
||||
saver = tf.train.Saver()
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir1, hooks=[
|
||||
tf.train.CheckpointSaverHook(logdir1, save_steps=1, saver=saver),
|
||||
tf.train.StopAtStepHook(num_steps=1),
|
||||
], save_checkpoint_secs=None)
|
||||
self.assertGreater(loss, .5)
|
||||
|
||||
# Next, train the model to convergence.
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(1)
|
||||
train_op = self.create_train_op()
|
||||
saver = tf.train.Saver()
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir1, hooks=[
|
||||
tf.train.CheckpointSaverHook(logdir1, save_steps=1, saver=saver),
|
||||
tf.train.StopAtStepHook(num_steps=300),
|
||||
], save_checkpoint_secs=None)
|
||||
self.assertIsNotNone(loss)
|
||||
self.assertLess(loss, .02)
|
||||
|
||||
# Finally, advance the model a single step and validate that the loss is
|
||||
# still low.
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(2)
|
||||
train_op = self.create_train_op()
|
||||
|
||||
model_variables = tf.all_variables()
|
||||
model_path = os.path.join(logdir1, 'model.ckpt-300')
|
||||
|
||||
assign_fn = tf.contrib.framework.assign_from_checkpoint_fn(
|
||||
model_path, model_variables)
|
||||
def init_fn(_, session):
|
||||
assign_fn(session)
|
||||
|
||||
loss = tf.contrib.training.train(
|
||||
train_op,
|
||||
logdir2,
|
||||
scaffold=tf.train.Scaffold(init_fn=init_fn),
|
||||
hooks=[tf.train.StopAtStepHook(num_steps=1)])
|
||||
|
||||
self.assertIsNotNone(loss)
|
||||
self.assertLess(loss, .02)
|
||||
|
||||
def ModelLoss(self):
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
tf_predictions = logistic_classifier(tf_inputs)
|
||||
tf.contrib.losses.log_loss(tf_predictions, tf_labels)
|
||||
return tf.contrib.losses.get_total_loss()
|
||||
|
||||
def testTrainAllVarsHasLowerLossThanTrainSubsetOfVars(self):
|
||||
logdir = os.path.join(self.get_temp_dir(), 'tmp_logs3/')
|
||||
if tf.gfile.Exists(logdir): # For running on jenkins.
|
||||
tf.gfile.DeleteRecursively(logdir)
|
||||
|
||||
# First, train only the weights of the model.
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
total_loss = self.ModelLoss()
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
weights = tf.contrib.framework.get_variables_by_name('weights')
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(
|
||||
total_loss,
|
||||
optimizer,
|
||||
variables_to_train=weights)
|
||||
|
||||
saver = tf.train.Saver()
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir, hooks=[
|
||||
tf.train.CheckpointSaverHook(logdir, save_steps=1, saver=saver),
|
||||
tf.train.StopAtStepHook(num_steps=200),
|
||||
])
|
||||
self.assertGreater(loss, .015)
|
||||
self.assertLess(loss, .05)
|
||||
|
||||
# Next, train the biases of the model.
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(1)
|
||||
total_loss = self.ModelLoss()
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
biases = tf.contrib.framework.get_variables_by_name('biases')
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(
|
||||
total_loss,
|
||||
optimizer,
|
||||
variables_to_train=biases)
|
||||
|
||||
saver = tf.train.Saver()
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir, hooks=[
|
||||
tf.train.CheckpointSaverHook(logdir, save_steps=1, saver=saver),
|
||||
tf.train.StopAtStepHook(num_steps=300),
|
||||
])
|
||||
self.assertGreater(loss, .015)
|
||||
self.assertLess(loss, .05)
|
||||
|
||||
# Finally, train both weights and bias to get lower loss.
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(2)
|
||||
total_loss = self.ModelLoss()
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(total_loss, optimizer)
|
||||
saver = tf.train.Saver()
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir, hooks=[
|
||||
tf.train.CheckpointSaverHook(logdir, save_steps=1, saver=saver),
|
||||
tf.train.StopAtStepHook(num_steps=400),
|
||||
])
|
||||
self.assertIsNotNone(loss)
|
||||
self.assertLess(loss, .015)
|
||||
|
||||
def testTrainingSubsetsOfVariablesOnlyUpdatesThoseVariables(self):
|
||||
# First, train only the weights of the model.
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
total_loss = self.ModelLoss()
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
weights, biases = tf.contrib.framework.get_variables()
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(total_loss, optimizer)
|
||||
train_weights = tf.contrib.training.create_train_op(
|
||||
total_loss, optimizer, variables_to_train=[weights])
|
||||
train_biases = tf.contrib.training.create_train_op(
|
||||
total_loss, optimizer, variables_to_train=[biases])
|
||||
|
||||
with tf.Session() as sess:
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
|
||||
# Get the intial weights and biases values.
|
||||
weights_values, biases_values = sess.run([weights, biases])
|
||||
self.assertGreater(np.linalg.norm(weights_values), 0)
|
||||
self.assertAlmostEqual(np.linalg.norm(biases_values), 0)
|
||||
|
||||
# Update weights and biases.
|
||||
loss = sess.run(train_op)
|
||||
self.assertGreater(loss, .5)
|
||||
new_weights, new_biases = sess.run([weights, biases])
|
||||
|
||||
# Check that the weights and biases have been updated.
|
||||
self.assertGreater(np.linalg.norm(weights_values - new_weights), 0)
|
||||
self.assertGreater(np.linalg.norm(biases_values - new_biases), 0)
|
||||
|
||||
weights_values, biases_values = new_weights, new_biases
|
||||
|
||||
# Update only weights.
|
||||
loss = sess.run(train_weights)
|
||||
self.assertGreater(loss, .5)
|
||||
new_weights, new_biases = sess.run([weights, biases])
|
||||
|
||||
# Check that the weights have been updated, but biases have not.
|
||||
self.assertGreater(np.linalg.norm(weights_values - new_weights), 0)
|
||||
self.assertAlmostEqual(np.linalg.norm(biases_values - new_biases), 0)
|
||||
weights_values = new_weights
|
||||
|
||||
# Update only biases.
|
||||
loss = sess.run(train_biases)
|
||||
self.assertGreater(loss, .5)
|
||||
new_weights, new_biases = sess.run([weights, biases])
|
||||
|
||||
# Check that the biases have been updated, but weights have not.
|
||||
self.assertAlmostEqual(np.linalg.norm(weights_values - new_weights), 0)
|
||||
self.assertGreater(np.linalg.norm(biases_values - new_biases), 0)
|
||||
|
||||
def testTrainWithAlteredGradients(self):
|
||||
# Use the same learning rate but different gradient multipliers
|
||||
# to train two models. Model with equivalently larger learning
|
||||
# rate (i.e., learning_rate * gradient_multiplier) has smaller
|
||||
# training loss.
|
||||
logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs6/')
|
||||
logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs7/')
|
||||
|
||||
if tf.gfile.Exists(logdir1):
|
||||
tf.gfile.DeleteRecursively(logdir1)
|
||||
if tf.gfile.Exists(logdir2):
|
||||
tf.gfile.DeleteRecursively(logdir2)
|
||||
|
||||
multipliers = [1., 1000.]
|
||||
number_of_steps = 10
|
||||
losses = []
|
||||
learning_rate = 0.001
|
||||
|
||||
# First, train the model with equivalently smaller learning rate.
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
train_op = self.create_train_op(
|
||||
learning_rate=learning_rate,
|
||||
gradient_multiplier=multipliers[0])
|
||||
|
||||
saver = tf.train.Saver()
|
||||
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir1, hooks=[
|
||||
tf.train.StopAtStepHook(num_steps=number_of_steps),
|
||||
tf.train.CheckpointSaverHook(logdir1, save_steps=50, saver=saver),
|
||||
])
|
||||
|
||||
losses.append(loss)
|
||||
self.assertGreater(loss, .5)
|
||||
|
||||
# Second, train the model with equivalently larger learning rate.
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
train_op = self.create_train_op(
|
||||
learning_rate=learning_rate,
|
||||
gradient_multiplier=multipliers[1])
|
||||
saver = tf.train.Saver()
|
||||
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir2, hooks=[
|
||||
tf.train.StopAtStepHook(num_steps=number_of_steps),
|
||||
tf.train.CheckpointSaverHook(logdir2, save_steps=50, saver=saver),
|
||||
])
|
||||
|
||||
losses.append(loss)
|
||||
self.assertIsNotNone(loss)
|
||||
self.assertLess(loss, .5)
|
||||
|
||||
# The loss of the model trained with larger learning rate should
|
||||
# be smaller.
|
||||
self.assertGreater(losses[0], losses[1])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user