Merge pull request #1 from tensorflow/master

Bring in changes from tf
Yifei Feng 2016-08-08 10:00:04 -07:00 committed by GitHub
commit ee221cb625
634 changed files with 31985 additions and 8026 deletions


@ -33,10 +33,10 @@ and discussion.**
People who are a little more adventurous can also try our nightly binaries:
* Linux CPU-only: [Python 2](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.9.0-cp27-none-linux_x86_64.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/)) / [Python 3.4](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](http://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
* Linux GPU: [Python 2](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.9.0-cp27-none-linux_x86_64.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/140/artifact/pip_test/whl/tensorflow-0.8.0-cp35-cp35m-linux_x86_64.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
* Mac CPU-only: [Python 2](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac1-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.9.0-py2-none-any.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac1-slave/)) / [Python 3](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac1-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.9.0-py3-none-any.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac1-slave/))
* Mac GPU: [Python 2](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.9.0-py2-none-any.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.9.0-py3-none-any.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/))
* Linux CPU-only: [Python 2](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-cp27-none-linux_x86_64.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/)) / [Python 3.4](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-cp34-cp34m-linux_x86_64.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](http://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-cp35-cp35m-linux_x86_64.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
* Linux GPU: [Python 2](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-cp27-none-linux_x86_64.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-cp34-cp34m-linux_x86_64.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/140/artifact/pip_test/whl/tensorflow-0.8.0-cp35-cp35m-linux_x86_64.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
* Mac CPU-only: [Python 2](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac1-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-py2-none-any.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac1-slave/)) / [Python 3](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac1-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-py3-none-any.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac1-slave/))
* Mac GPU: [Python 2](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-py2-none-any.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-py3-none-any.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/))
* [Android](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-android/TF_BUILD_CONTAINER_TYPE=ANDROID,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=NO_PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=android-slave/lastSuccessfulBuild/artifact/bazel-out/local_linux/bin/tensorflow/examples/android/tensorflow_demo.apk) ([build history](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-android/TF_BUILD_CONTAINER_TYPE=ANDROID,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=NO_PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=android-slave/))
#### *Try your first TensorFlow program*


@ -1,16 +1,40 @@
# Changes Since Last Release
## Features and Improvements
* Connectionist Temporal Classification ops are now "official" (see, e.g.,
`tf.nn.ctc_loss`)
* Preliminary graph-construction C API, for use by language bindings.
* Major revision to the graph-construction C++ API. A scoping mechanism makes
op naming and the specification of control dependencies more consistent. C++
values can be used directly as operands, making op construction more concise.
# Release 0.10.0
## Breaking Changes to the API
* `env.h`: the `New*File()` functions now return the file through a
`std::unique_ptr` out-argument instead of a raw pointer, removing the old
raw-pointer returns (a minimal sketch follows below).
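A minimal sketch of the new pattern (assuming the post-change `Env::NewWritableFile` signature; the path and the `WriteGreeting` helper are illustrative only):

```c++
#include <memory>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"

// Hypothetical helper: ownership of the file now lives in a
// std::unique_ptr out-argument, so nothing leaks on error paths.
tensorflow::Status WriteGreeting(tensorflow::Env* env) {
  std::unique_ptr<tensorflow::WritableFile> file;
  TF_RETURN_IF_ERROR(env->NewWritableFile("/tmp/greeting.txt", &file));
  TF_RETURN_IF_ERROR(file->Append("hello\n"));
  return file->Close();
}
```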
## Major Features and Improvements
* Added support for C++ shape inference
* Added graph-construction C API
* Major revision to the graph-construction C++ API
* Support makefile build for iOS
* Added Mac GPU support
* Full version of TF-Slim available as `tf.contrib.slim`
* Added k-Means clustering and WALS matrix factorization
## Bug Fixes and Other Changes
* Allow gradient computation for scalar values.
* Performance improvements for gRPC
* Improved support for fp16
* New high-level ops in tf.contrib.{layers,metrics}
* New features for TensorBoard, such as shape display, exponential smoothing
* Faster and more stable Google Cloud Storage (GCS) filesystem support
* Support for zlib compression and decompression for TFRecordReader and TFRecordWriter
* Support for reading (animated) GIFs
* Improved support for SparseTensor
* Added support for more probability distributions (Dirichlet, Beta, Bernoulli, etc.)
* Added Python interfaces to reset resource containers.
* Many bugfixes and performance improvements
* Many documentation fixes
## Thanks to our Contributors
This release contains contributions from many people at Google, as well as:
Alex Rothberg, Andrew Royer, Austin Marshall, @BlackCoal, Bob Adolf, Brian Diesel, Charles-Emmanuel Dias, @chemelnucfin, Chris Lesniewski, Daeyun Shin, Daniel Rodriguez, Danijar Hafner, Darcy Liu, Kristinn R. Thórisson, Daniel Castro, Dmitry Savintsev, Kashif Rasul, Dylan Paiton, Emmanuel T. Odeke, Ernest Grzybowski, Gavin Sherry, Gideon Dresdner, Gregory King, Harold Cooper, @heinzbeinz, Henry Saputra, Huarong Huo, Huazuo Gao, Igor Babuschkin, Igor Macedo Quintanilha, Ivan Ukhov, James Fysh, Jan Wilken Dörrie, Jihun Choi, Johnny Lim, Jonathan Raiman, Justin Francis, @lilac, Li Yi, Marc Khoury, Marco Marchesi, Max Melnick, Micael Carvalho, @mikowals, Mostafa Gazar, Nico Galoppo, Nishant Agrawal, Petr Janda, Yuncheng Li, @raix852, Robert Rose, @Robin-des-Bois, Rohit Girdhar, Sam Abrahams, satok16, Sergey Kishchenko, Sharkd Tu, @shotat, Siddharth Agrawal, Simon Denel, @sono-bfio, SunYeop Lee, Thijs Vogels, @tobegit3hub, @Undo1, Wang Yang, Wenjian Huang, Yaroslav Bulatov, Yuan Tang, Yunfeng Wang, Ziming Dong
We are also grateful to all who filed issues or helped resolve them, asked and
answered questions, and were part of inspiring discussions.
# Release 0.9.0
@ -55,7 +79,7 @@
This release contains contributions from many people at Google, as well as:
Aaron Schumacher, Aidan Dang, Akihiko ITOH, Aki Sukegawa, Arbit Chen, Aziz Alto, Danijar Hafner, Erik Erwitt, Fabrizio Milo, Felix Maximilian Möller, Henry Saputra, Sung Kim, Igor Babuschkin, Jan Zikes, Jeremy Barnes, Jesper Steen Møller, Johannes Mayer, Justin Harris, Kashif Rasul, Kevin Robinson, Loo Rong Jie, Lucas Moura, Łukasz Bieniasz-Krzywiec, Mario Cho, Maxim Grechkin, Michael Heilman, Mostafa Rahmani, Mourad Mourafiq, @ninotoshi, Orion Reblitz-Richardson, Yuncheng Li, @raoqiyu, Robert DiPietro, Sam Abrahams, Sebastian Raschka, Siddharth Agrawal, @snakecharmer1024, Stephen Roller, Sung Kim, SunYeop Lee, Thijs Vogels, Till Hoffmann, Victor Melo, Ville Kallioniemi, Waleed Abdulla, Wenjian Huang, Yaroslav Bulatov, Yeison Rodriguez, Yuan (Terry) Tang, Yuxin Wu, @zhongzyd, Ziming Dong, Zohar Jackson
Aaron Schumacher, Aidan Dang, Akihiko ITOH, Aki Sukegawa, Arbit Chen, Aziz Alto, Danijar Hafner, Erik Erwitt, Fabrizio Milo, Felix Maximilian Möller, Henry Saputra, Sung Kim, Igor Babuschkin, Jan Zikes, Jeremy Barnes, Jesper Steen Møller, Johannes Mayer, Justin Harris, Kashif Rasul, Kevin Robinson, Loo Rong Jie, Lucas Moura, Łukasz Bieniasz-Krzywiec, Mario Cho, Maxim Grechkin, Michael Heilman, Mostafa Rahmani, Mourad Mourafiq, @ninotoshi, Orion Reblitz-Richardson, Yuncheng Li, @raoqiyu, Robert DiPietro, Sam Abrahams, Sebastian Raschka, Siddharth Agrawal, @snakecharmer1024, Stephen Roller, Sung Kim, SunYeop Lee, Thijs Vogels, Till Hoffmann, Victor Melo, Ville Kallioniemi, Waleed Abdulla, Wenjian Huang, Yaroslav Bulatov, Yeison Rodriguez, Yuan Tang, Yuxin Wu, @zhongzyd, Ziming Dong, Zohar Jackson
We are also grateful to all who filed issues or helped resolve them, asked and
answered questions, and were part of inspiring discussions.
@ -97,7 +121,7 @@ answered questions, and were part of inspiring discussions.
This release contains contributions from many people at Google, as well as:
Abhinav Upadhyay, Aggelos Avgerinos, Alan Wu, Alexander G. de G. Matthews, Aleksandr Yahnev, @amchercashin, Andy Kitchen, Aurelien Geron, Awni Hannun, @BanditCat, Bas Veeling, Cameron Chen, @cg31, Cheng-Lung Sung, Christopher Bonnett, Dan Becker, Dan Van Boxel, Daniel Golden, Danijar Hafner, Danny Goodman, Dave Decker, David Dao, David Kretch, Dongjoon Hyun, Dustin Dorroh, @e-lin, Eurico Doirado, Erik Erwitt, Fabrizio Milo, @gaohuazuo, Iblis Lin, Igor Babuschkin, Isaac Hodes, Isaac Turner, Iván Vallés, J Yegerlehner, Jack Zhang, James Wexler, Jan Zikes, Jay Young, Jeff Hodges, @jmtatsch, Johnny Lim, Jonas Meinertz Hansen, Kanit Wongsuphasawat, Kashif Rasul, Ken Shirriff, Kenneth Mitchner, Kenta Yonekura, Konrad Magnusson, Konstantin Lopuhin, @lahwran, @lekaha, @liyongsea, Lucas Adams, @makseq, Mandeep Singh, @manipopopo, Mark Amery, Memo Akten, Michael Heilman, Michael Peteuil, Nathan Daly, Nicolas Fauchereau, @ninotoshi, Olav Nymoen, @panmari, @papelita1234, Pedro Lopes, Pranav Sailesh Mani, RJ Ryan, Rob Culliton, Robert DiPietro, @ronrest, Sam Abrahams, Sarath Shekkizhar, Scott Graham, Sebastian Raschka, Sung Kim, Surya Bhupatiraju, Syed Ahmed, Till Hoffmann, @timsl, @urimend, @vesnica, Vlad Frolov, Vlad Zagorodniy, Wei-Ting Kuo, Wenjian Huang, William Dmitri Breaden Madden, Wladimir Schmidt, Yuwen Yan, Yuxin Wu, Yuya Kusakabe, @zhongzyd, @znah.
Abhinav Upadhyay, Aggelos Avgerinos, Alan Wu, Alexander G. de G. Matthews, Aleksandr Yahnev, @amchercashin, Andy Kitchen, Aurelien Geron, Awni Hannun, @BanditCat, Bas Veeling, Cameron Chen, @cg31, Cheng-Lung Sung, Christopher Bonnett, Dan Becker, Dan Van Boxel, Daniel Golden, Danijar Hafner, Danny Goodman, Dave Decker, David Dao, David Kretch, Dongjoon Hyun, Dustin Dorroh, @e-lin, Eurico Doirado, Erik Erwitt, Fabrizio Milo, @gaohuazuo, Iblis Lin, Igor Babuschkin, Isaac Hodes, Isaac Turner, Iván Vallés, J Yegerlehner, Jack Zhang, James Wexler, Jan Zikes, Jay Young, Jeff Hodges, @jmtatsch, Johnny Lim, Jonas Meinertz Hansen, Kanit Wongsuphasawat, Kashif Rasul, Ken Shirriff, Kenneth Mitchner, Kenta Yonekura, Konrad Magnusson, Konstantin Lopuhin, @lahwran, @lekaha, @liyongsea, Lucas Adams, @makseq, Mandeep Singh, @manipopopo, Mark Amery, Memo Akten, Michael Heilman, Michael Peteuil, Nathan Daly, Nicolas Fauchereau, @ninotoshi, Olav Nymoen, @panmari, @papelita1234, Pedro Lopes, Pranav Sailesh Mani, RJ Ryan, Rob Culliton, Robert DiPietro, @ronrest, Sam Abrahams, Sarath Shekkizhar, Scott Graham, Sebastian Raschka, Sung Kim, Surya Bhupatiraju, Syed Ahmed, Till Hoffmann, @timsl, @urimend, @vesnica, Vlad Frolov, Vlad Zagorodniy, Wei-Ting Kuo, Wenjian Huang, William Dmitri Breaden Madden, Wladimir Schmidt, Yuan Tang, Yuwen Yan, Yuxin Wu, Yuya Kusakabe, @zhongzyd, @znah.
We are also grateful to all who filed issues or helped resolve them, asked and
answered questions, and were part of inspiring discussions.


@ -37,7 +37,10 @@ config_setting(
package_group(
name = "internal",
packages = ["//tensorflow/..."],
packages = [
"//learning/vis/...",
"//tensorflow/...",
],
)
sh_binary(
@ -71,6 +74,7 @@ filegroup(
name = "all_opensource_files",
data = [
":all_files",
"//tensorflow/c:all_files",
"//tensorflow/cc:all_files",
"//tensorflow/contrib:all_files",
"//tensorflow/contrib/copy_graph:all_files",
@ -103,6 +107,7 @@ filegroup(
"//tensorflow/contrib/testing:all_files",
"//tensorflow/contrib/util:all_files",
"//tensorflow/core:all_files",
"//tensorflow/core/debug:all_files",
"//tensorflow/core/distributed_runtime:all_files",
"//tensorflow/core/distributed_runtime/rpc:all_files",
"//tensorflow/core/kernels:all_files",

tensorflow/c/BUILD (new file, 95 lines)

@ -0,0 +1,95 @@
# Description:
# C API for TensorFlow, for use by client language bindings.
licenses(["notice"]) # Apache 2.0
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
"tf_cuda_library",
)
# For platform specific build config
load(
"//tensorflow/core:platform/default/build_config.bzl",
"tf_kernel_tests_linkstatic",
)
# -----------------------------------------------------------------------------
# Public targets
tf_cuda_library(
name = "c_api",
srcs = ["c_api.cc"],
hdrs = ["c_api.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
)
tf_cuda_library(
name = "tf_status_helper",
srcs = ["tf_status_helper.cc"],
hdrs = ["tf_status_helper.h"],
visibility = ["//visibility:public"],
deps = [
":c_api",
"//tensorflow/core:lib",
],
)
tf_cuda_library(
name = "checkpoint_reader",
srcs = ["checkpoint_reader.cc"],
hdrs = ["checkpoint_reader.h"],
visibility = ["//visibility:public"],
deps = [
":tf_status_helper",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
)
# -----------------------------------------------------------------------------
# Tests
tf_cc_test(
name = "c_api_test",
size = "small",
linkopts = select({
"//tensorflow:darwin": ["-headerpad_max_install_names"],
"//conditions:default": [],
}),
linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":c_api",
"//tensorflow/core:direct_session",
"//tensorflow/core:framework",
"//tensorflow/core:proto_text",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:array",
"//tensorflow/core/kernels:math",
"//third_party/eigen3",
],
)
# -----------------------------------------------------------------------------
# Google-internal targets.
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)


@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/public/tensor_c_api.h"
#include "tensorflow/c/c_api.h"
#include <memory>
#include <vector>
@ -482,7 +482,6 @@ static void TF_Run_Helper(
result = session->PRun(handle, input_pairs, output_tensor_names, &outputs);
}
if (!result.ok()) {
LOG(ERROR) << result.error_message();
status->status = result;
return;
}


@ -13,9 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// TODO(jeff,sanjay): Rename to tensorflow/public/c_api.h
#ifndef TENSORFLOW_PUBLIC_TENSOR_C_API_H_
#define TENSORFLOW_PUBLIC_TENSOR_C_API_H_
#ifndef TENSORFLOW_C_C_API_H_
#define TENSORFLOW_C_C_API_H_
#include <stddef.h>
#include <stdint.h>
@ -699,4 +698,4 @@ extern TF_Buffer TF_GetOpList(TF_Library* lib_handle);
} /* end extern "C" */
#endif
#endif // TENSORFLOW_PUBLIC_TENSOR_C_API_H_
#endif // TENSORFLOW_C_C_API_H_


@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/public/tensor_c_api.h"
#include "tensorflow/c/c_api.h"
#include <vector>
#include "tensorflow/core/framework/graph.pb_text.h"


@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/checkpoint_reader.h"
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/env.h"


@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_UTIL_CHECKPOINT_READER_H
#define TENSORFLOW_CORE_UTIL_CHECKPOINT_READER_H
#ifndef TENSORFLOW_C_CHECKPOINT_READER_H
#define TENSORFLOW_C_CHECKPOINT_READER_H
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/tensor_slice_reader.h"
#include "tensorflow/core/util/tf_status_helper.h"
namespace tensorflow {
@ -60,4 +60,4 @@ class CheckpointReader {
} // namespace checkpoint
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_CHECKPOINT_READER_H
#endif // TENSORFLOW_C_CHECKPOINT_READER_H


@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/tf_status_helper.h"
#include "tensorflow/c/tf_status_helper.h"
namespace tensorflow {


@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_UTIL_TF_STATUS_HELPER_H
#define TENSORFLOW_CORE_UTIL_TF_STATUS_HELPER_H
#ifndef TENSORFLOW_C_TF_STATUS_HELPER_H
#define TENSORFLOW_C_TF_STATUS_HELPER_H
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/public/tensor_c_api.h"
namespace tensorflow {
@ -26,4 +26,4 @@ void Set_TF_Status_from_Status(TF_Status* tf_status, const Status& status);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_TF_STATUS_HELPER_H
#endif // TENSORFLOW_C_TF_STATUS_HELPER_H
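A minimal usage sketch (an assumption of typical use, not code from this commit; `TF_NewStatus`/`TF_DeleteStatus` come from `c_api.h`, and `ReportAcrossCBoundary` is a hypothetical name):

```c++
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/lib/core/errors.h"

// Hypothetical example: surface a C++ Status to a client language binding.
void ReportAcrossCBoundary() {
  tensorflow::Status s = tensorflow::errors::NotFound("no such checkpoint");
  TF_Status* tf_status = TF_NewStatus();
  // Copies the error code and message into the C-API status object.
  tensorflow::Set_TF_Status_from_Status(tf_status, s);
  // ... pass tf_status across the C boundary ...
  TF_DeleteStatus(tf_status);
}
```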


@ -73,6 +73,46 @@ tf_cc_test(
],
)
cc_library(
name = "grad_op_registry",
srcs = ["framework/grad_op_registry.cc"],
hdrs = ["framework/grad_op_registry.h"],
deps = [
":ops",
":scope",
],
)
cc_library(
name = "math_grad",
srcs = ["gradients/math_grad.cc"],
deps = [
":cc_ops",
":grad_op_registry",
":ops",
":scope",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
],
)
tf_cc_test(
name = "gradients/math_grad_test",
deps = [
":cc_ops",
":grad_op_registry",
":math_grad",
"//tensorflow/core:all_kernels",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib_internal",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_gen_op_wrappers_cc(
name = "cc_ops",
op_lib_names = [


@ -0,0 +1,42 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/framework/grad_op_registry.h"
namespace tensorflow {
namespace ops {
// static
GradOpRegistry* GradOpRegistry::Global() {
static GradOpRegistry* grad_op_registry = new GradOpRegistry;
return grad_op_registry;
}
bool GradOpRegistry::Register(const string& op, GradFunc func) {
CHECK(registry_.insert({op, func}).second) << "Existing gradient for " << op;
return true;
}
Status GradOpRegistry::Lookup(const string& op, GradFunc* func) {
auto iter = registry_.find(op);
if (iter == registry_.end()) {
return errors::NotFound("No gradient defined for op: ", op);
}
*func = iter->second;
return Status::OK();
}
} // end namespace ops
} // namespace tensorflow


@ -0,0 +1,75 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
#include <unordered_map>
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
namespace tensorflow {
namespace ops {
// GradFunc is the signature for all gradient functions in GradOpRegistry.
// Implementations should add operations to compute the gradient outputs of 'op'
// (returned in 'grad_outputs') using 'scope' and 'grad_inputs'.
typedef Status (*GradFunc)(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs);
// GradOpRegistry maintains a static registry of gradient functions.
// Gradient functions are indexed in the registry by the forward op name (i.e.
// "MatMul" -> MatMulGrad func).
class GradOpRegistry {
public:
// Registers 'func' as the gradient function for 'op'.
// Returns true if registration was successful; check-fails otherwise.
bool Register(const string& op, GradFunc func);
// Sets 'func' to the gradient function for 'op' and returns Status OK if
// the gradient function for 'op' exists in the registry; returns an error
// status otherwise.
// Note that 'func' can be null for ops that have registered no-gradient with
// the registry.
Status Lookup(const string& op, GradFunc* func);
// Returns a pointer to the global gradient function registry.
static GradOpRegistry* Global();
private:
std::unordered_map<string, GradFunc> registry_;
};
} // namespace ops
// Macros used to define gradient functions for ops.
#define REGISTER_GRADIENT_OP(name, fn) \
REGISTER_GRADIENT_OP_UNIQ_HELPER(__COUNTER__, name, fn)
#define REGISTER_NO_GRADIENT_OP(name) \
REGISTER_GRADIENT_OP_UNIQ_HELPER(__COUNTER__, name, nullptr)
#define REGISTER_GRADIENT_OP_UNIQ_HELPER(ctr, name, fn) \
REGISTER_GRADIENT_OP_UNIQ(ctr, name, fn)
#define REGISTER_GRADIENT_OP_UNIQ(ctr, name, fn) \
static bool unused_ret_val_##ctr = \
::tensorflow::ops::GradOpRegistry::Global()->Register(name, fn)
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
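As a usage sketch (the `Square` gradient here is hypothetical and not part of this commit; it assumes `Const`/`Mul` from the scope-based ops used elsewhere in this change):

```c++
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/ops/standard_ops.h"

namespace tensorflow {
namespace ops {
namespace {

// d(x^2)/dx = 2 * x, chained with the incoming gradient.
Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  auto dx = Mul(scope, Mul(scope, Const(scope, 2.0f), op.input(0)),
                grad_inputs[0]);
  grad_outputs->push_back(dx);
  return Status::OK();
}
// Indexed by the forward op name, just like MatMul/MatMulGrad below.
REGISTER_GRADIENT_OP("Square", SquareGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow
```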


@ -18,6 +18,44 @@ limitations under the License.
namespace tensorflow {
namespace ops {
Operation::Operation(Node* n) : inputs_(GetInputs(n)), node_(n) {}
Output Operation::input(int i) const {
CHECK_NOTNULL(node_);
CHECK_GE(i, 0);
CHECK_LT(i, node_->num_inputs());
// Handle the case where the input was unknown at the time this
// Operation was constructed.
if (inputs_[i].first == nullptr && inputs_[i].second == -1) {
for (const Edge* e : node_->in_edges()) {
if (e->IsControlEdge()) continue;
if (e->dst_input() == i) {
return Output(e->src(), e->src_output());
}
}
}
return Output(inputs_[i].first, inputs_[i].second);
}
Output Operation::output(int i) const {
CHECK_NOTNULL(node_);
CHECK_GE(i, 0);
CHECK_LT(i, node_->num_outputs());
return Output(node_, i);
}
Operation::Inputs Operation::GetInputs(Node* node) {
Operation::Inputs inputs;
if (node != nullptr) {
inputs.resize(node->num_inputs(), {nullptr, -1});
for (const Edge* e : node->in_edges()) {
if (e->IsControlEdge()) continue;
inputs[e->dst_input()] = std::make_pair(e->src(), e->src_output());
}
}
return inputs;
}
Input::Initializer::Initializer(
const std::initializer_list<Input::Initializer>& v) {
if (v.size() < 1) {


@ -27,17 +27,29 @@ limitations under the License.
namespace tensorflow {
namespace ops {
class Output;
// Represents a node in the computation graph.
class Operation {
public:
Operation() : node_(nullptr) {}
explicit Operation(Node* n) : node_(n) {}
explicit Operation(Node* n);
int num_inputs() const { return node_->num_inputs(); }
DataType input_type(int o) const { return node_->input_type(o); }
Output input(int i) const;
int num_outputs() const { return node_->num_outputs(); }
DataType output_type(int o) const { return node_->output_type(o); }
Output output(int i) const;
Node* node() const { return node_; }
private:
typedef std::vector<std::pair<Node*, int64>> Inputs;
static Inputs GetInputs(Node* node);
Inputs inputs_;
Node* node_;
};
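For illustration (a hypothetical fragment, assuming `n` is a valid `Node*` in an already-constructed graph), the new accessors wrap a node's endpoints as `Output` handles:

```c++
Operation op(n);                  // caches the node's input endpoints
Output in0 = op.input(0);         // the Output feeding input 0
Output out0 = op.output(0);       // endpoint (n, 0)
DataType dt = op.output_type(0);  // forwarded from the underlying Node
```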
@ -81,7 +93,7 @@ class Input {
tensor = t;
}
explicit Initializer(const Tensor& t) : tensor(t) {}
Initializer(const Tensor& t) : tensor(t) {} // NOLINT(runtime/explicit)
// Construct from a scalar value and an explicit shape
template <typename T, typename = typename std::enable_if<


@ -0,0 +1,91 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/framework/grad_op_registry.h"
namespace tensorflow {
namespace ops {
namespace {
// TODO(andydavis) Move this to a more appropriate file.
REGISTER_NO_GRADIENT_OP("Const");
// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
Status MatMulGradHelper(const Scope& scope, const Output& x0, const bool adj_x0,
const Output& x1, const bool adj_x1, const Output& y0,
const bool adj_y0, const Output& y1, const bool adj_y1,
std::vector<Output>* grad_outputs) {
auto dx =
MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
grad_outputs->push_back(dx);
auto dy =
MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
grad_outputs->push_back(dy);
return Status::OK();
}
// Common MatMulGrad helper used to read and check node attr state, and to
// determine the proper MatMul products for gradients based on input matrix
// transposition combinations.
// TODO(andydavis) Re-use this function for BatchMatMulGrad.
Status MatMulGradCommon(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
const string& attr_adj_x, const string& attr_adj_y,
std::vector<Output>* grad_outputs) {
DataType dtype;
TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), "T", &dtype));
if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
return errors::Unimplemented(
"MatMul gradient for complex data type is not supported yet.");
}
bool ta;
bool tb;
TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_x, &ta));
TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_y, &tb));
if (!ta && !tb) {
return MatMulGradHelper(scope, grad_inputs[0], false, op.input(1), true,
op.input(0), true, grad_inputs[0], false,
grad_outputs);
} else if (!ta && tb) {
return MatMulGradHelper(scope, grad_inputs[0], false, op.input(1), false,
grad_inputs[0], true, op.input(0), false,
grad_outputs);
} else if (ta && !tb) {
return MatMulGradHelper(scope, op.input(1), false, grad_inputs[0], true,
op.input(0), false, grad_inputs[0], false,
grad_outputs);
}
return MatMulGradHelper(scope, op.input(1), true, grad_inputs[0], true,
grad_inputs[0], true, op.input(0), true,
grad_outputs);
}
Status MatMulGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
return MatMulGradCommon(scope, op, grad_inputs, "transpose_a", "transpose_b",
grad_outputs);
}
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);
} // anonymous namespace
} // namespace ops
} // namespace tensorflow
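For reference, writing $Z$ for the forward product and $dZ$ for the incoming gradient, the four branches above compute (a restatement of the code, matching the tests below):

$$
\begin{aligned}
Z &= XY\colon & dX &= dZ\,Y^{\top}, & dY &= X^{\top}dZ \\
Z &= X^{\top}Y\colon & dX &= Y\,dZ^{\top}, & dY &= X\,dZ \\
Z &= XY^{\top}\colon & dX &= dZ\,Y, & dY &= dZ^{\top}X \\
Z &= X^{\top}Y^{\top}\colon & dX &= Y^{\top}dZ^{\top}, & dY &= dZ^{\top}X^{\top}
\end{aligned}
$$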


@ -0,0 +1,183 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
namespace tensorflow {
using namespace ops; // NOLINT(build/namespaces)
namespace {
// TODO(andydavis) Test gradient function against numeric gradients output.
// TODO(andydavis) As more gradients are added move common test functions
// to a testutil library.
class MathGradTest : public ::testing::Test {
protected:
MathGradTest() : root_(Scope::NewRootScope()) {}
void ComputeMatMulGrad(const Output& x, const bool t_x, const Output& y,
const bool t_y, const Output& dz,
std::vector<Tensor>* out) {
// Compute forward MatMul: z = MatMul(x, y).
auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
TF_EXPECT_OK(root_.status());
CHECK_NOTNULL(z.node());
std::vector<Output> grad_outputs;
// Call MatMulGrad which populates 'grad_outputs'.
CallGradFunction(Operation(z.node()), {dz}, &grad_outputs);
EXPECT_EQ(2, grad_outputs.size());
// Run graph and return MatMul gradient tensors for 'dx' and 'dy' in 'out'.
GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out);
}
void CallGradFunction(const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
GradFunc grad_fn;
TF_EXPECT_OK(GradOpRegistry::Global()->Lookup(op.node()->name(), &grad_fn));
TF_EXPECT_OK(grad_fn(root_, op, grad_inputs, grad_outputs));
TF_EXPECT_OK(root_.status());
}
Tensor ComputeMatMul(const Output& x, const bool t_x, const Output& y,
const bool t_y) {
auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
TF_EXPECT_OK(root_.status());
Tensor out;
GetTensor(root_, z, &out);
return out;
}
void RandMatMulGradData(const bool tx, const bool ty,
std::vector<Tensor>* data) {
// z = MatMul(x, y)
const int m = Rand();
const int k = Rand();
const int n = Rand();
// x.shape = [m, k]
const TensorShape x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
data->emplace_back(DT_FLOAT, x_shape);
RandTensor(&data->back());
// y.shape = [k, n]
const TensorShape y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
data->emplace_back(DT_FLOAT, y_shape);
RandTensor(&data->back());
// z.shape = [m, n]
data->emplace_back(DT_FLOAT, TensorShape({m, n}));
RandTensor(&data->back());
}
void RandTensor(Tensor* t) {
test::FillFn<float>(
t, [this](const int i) { return static_cast<float>(Rand()); });
}
int Rand() { return 1 + (random::New64() % 10); }
// TODO(andydavis) Move 'GetTensors/GetTensor' to some testutil class.
// Note: they should be moved to a general/non-grad specific testutil class.
void GetTensors(const Scope& scope, OutputList tensors,
std::vector<Tensor>* out) {
SessionOptions options;
std::unique_ptr<Session> session(NewSession(options));
GraphDef def;
scope.graph()->ToGraphDef(&def);
graph::SetDefaultDevice("/cpu:0", &def);
TF_CHECK_OK(session->Create(def));
std::vector<string> names;
for (const auto& t : tensors) {
names.push_back(strings::StrCat(t.node()->name(), ":", t.index()));
}
TF_CHECK_OK(session->Run({}, names, {}, out));
TF_CHECK_OK(session->Close());
}
void GetTensor(const Scope& scope, Output tensor, Tensor* out) {
std::vector<Tensor> outputs;
GetTensors(scope, {tensor}, &outputs);
*out = outputs[0];
}
Scope root_;
};
TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
std::vector<Tensor> data;
RandMatMulGradData(false, false, &data);
auto x = Const(root_, data[0]);
auto y = Const(root_, data[1]);
auto dz = Const(root_, data[2]);
std::vector<Tensor> grad_outputs;
ComputeMatMulGrad(x, false, y, false, dz, &grad_outputs);
test::ExpectClose(grad_outputs[0], ComputeMatMul(dz, false, y, true));
test::ExpectClose(grad_outputs[1], ComputeMatMul(x, true, dz, false));
}
TEST_F(MathGradTest, MatMulGrad_TransposeX) {
std::vector<Tensor> data;
RandMatMulGradData(true, false, &data);
auto x = Const(root_, data[0]);
auto y = Const(root_, data[1]);
auto dz = Const(root_, data[2]);
std::vector<Tensor> grad_outputs;
ComputeMatMulGrad(x, true, y, false, dz, &grad_outputs);
test::ExpectClose(grad_outputs[0], ComputeMatMul(y, false, dz, true));
test::ExpectClose(grad_outputs[1], ComputeMatMul(x, false, dz, false));
}
TEST_F(MathGradTest, MatMulGrad_TransposeY) {
std::vector<Tensor> data;
RandMatMulGradData(false, true, &data);
auto x = Const(root_, data[0]);
auto y = Const(root_, data[1]);
auto dz = Const(root_, data[2]);
std::vector<Tensor> grad_outputs;
ComputeMatMulGrad(x, false, y, true, dz, &grad_outputs);
test::ExpectClose(grad_outputs[0], ComputeMatMul(dz, false, y, false));
test::ExpectClose(grad_outputs[1], ComputeMatMul(dz, true, x, false));
}
TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
std::vector<Tensor> data;
RandMatMulGradData(true, true, &data);
auto x = Const(root_, data[0]);
auto y = Const(root_, data[1]);
auto dz = Const(root_, data[2]);
std::vector<Tensor> grad_outputs;
ComputeMatMulGrad(x, true, y, true, dz, &grad_outputs);
test::ExpectClose(grad_outputs[0], ComputeMatMul(y, true, dz, true));
test::ExpectClose(grad_outputs[1], ComputeMatMul(dz, true, x, true));
}
} // namespace
} // namespace tensorflow


@ -39,6 +39,7 @@ set (DOWNLOAD_LOCATION "${CMAKE_CURRENT_BINARY_DIR}/downloads"
mark_as_advanced(DOWNLOAD_LOCATION)
# External dependencies
include(gif)
include(png)
include(jpeg)
include(re2)


@ -0,0 +1,38 @@
include (ExternalProject)
set(gif_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/gif_archive)
set(gif_URL http://ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz)
set(gif_HASH SHA256=34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1)
set(gif_INSTALL ${CMAKE_BINARY_DIR}/gif/install)
set(gif_STATIC_LIBRARIES ${gif_INSTALL}/lib/libgif.a)
set(gif_HEADERS
"${gif_INSTALL}/include/gif_lib.h"
)
ExternalProject_Add(gif
PREFIX gif
URL ${gif_URL}
URL_HASH ${gif_HASH}
INSTALL_DIR ${gif_INSTALL}
DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
BUILD_COMMAND $(MAKE)
INSTALL_COMMAND $(MAKE) install
CONFIGURE_COMMAND
${CMAKE_CURRENT_BINARY_DIR}/gif/src/gif/configure
--prefix=${gif_INSTALL}
--enable-shared=yes
)
# put gif includes in the directory where they are expected
add_custom_target(gif_create_destination_dir
COMMAND ${CMAKE_COMMAND} -E make_directory ${gif_INCLUDE_DIR}/giflib-5.1.4/lib
DEPENDS gif)
add_custom_target(gif_copy_headers_to_destination
DEPENDS gif_create_destination_dir)
foreach(header_file ${gif_HEADERS})
add_custom_command(TARGET gif_copy_headers_to_destination PRE_BUILD
COMMAND ${CMAKE_COMMAND} -E copy ${header_file} ${gif_INCLUDE_DIR}/giflib-5.1.4/lib/)
endforeach()


@ -1,10 +1,39 @@
########################################################
# tf_cc_framework library
########################################################
set(tf_cc_framework_srcs
"${tensorflow_source_dir}/tensorflow/cc/framework/ops.h"
"${tensorflow_source_dir}/tensorflow/cc/framework/ops.cc"
"${tensorflow_source_dir}/tensorflow/cc/framework/scope.h"
"${tensorflow_source_dir}/tensorflow/cc/framework/scope.cc"
)
add_library(tf_cc_framework OBJECT ${tf_cc_framework_srcs})
add_dependencies(tf_cc_framework tf_core_framework)
target_include_directories(tf_cc_framework PRIVATE
${tensorflow_source_dir}
${eigen_INCLUDE_DIRS}
)
target_compile_options(tf_cc_framework PRIVATE
-fno-exceptions
-DEIGEN_AVOID_STL_ARRAY
)
# C++11
target_compile_features(tf_cc_framework PRIVATE
cxx_rvalue_references
)
########################################################
# tf_cc_op_gen_main library
########################################################
set(tf_cc_op_gen_main_srcs
"${tensorflow_source_dir}/tensorflow/cc/ops/cc_op_gen.cc"
"${tensorflow_source_dir}/tensorflow/cc/ops/cc_op_gen_main.cc"
"${tensorflow_source_dir}/tensorflow/cc/ops/cc_op_gen.h"
"${tensorflow_source_dir}/tensorflow/cc/framework/cc_op_gen.cc"
"${tensorflow_source_dir}/tensorflow/cc/framework/cc_op_gen_main.cc"
"${tensorflow_source_dir}/tensorflow/cc/framework/cc_op_gen.h"
)
add_library(tf_cc_op_gen_main OBJECT ${tf_cc_op_gen_main_srcs})
@ -120,6 +149,7 @@ foreach(tf_cc_op_lib_name ${tf_cc_op_lib_names})
${PROTOBUF_LIBRARIES}
tf_protos_cc
re2_lib
${gif_STATIC_LIBRARIES}
${jpeg_STATIC_LIBRARIES}
${png_STATIC_LIBRARIES}
${ZLIB_LIBRARIES}


@ -4,8 +4,17 @@
file(GLOB tf_core_direct_session_srcs
"${tensorflow_source_dir}/tensorflow/core/common_runtime/direct_session.cc"
"${tensorflow_source_dir}/tensorflow/core/common_runtime/direct_session.h"
"${tensorflow_source_dir}/tensorflow/core/debug/*.h"
"${tensorflow_source_dir}/tensorflow/core/debug/*.cc"
)
file(GLOB_RECURSE tf_core_direct_session_test_srcs
"${tensorflow_source_dir}/tensorflow/core/debug/*test*.h"
"${tensorflow_source_dir}/tensorflow/core/debug/*test*.cc"
)
list(REMOVE_ITEM tf_core_direct_session_srcs ${tf_core_direct_session_test_srcs})
add_library(tf_core_direct_session OBJECT ${tf_core_direct_session_srcs})
add_dependencies(tf_core_direct_session tf_core_cpu)


@ -150,6 +150,7 @@ list(REMOVE_ITEM tf_core_lib_srcs ${tf_core_lib_test_srcs})
add_library(tf_core_lib OBJECT ${tf_core_lib_srcs})
target_include_directories(tf_core_lib PUBLIC
${tensorflow_source_dir}
${gif_INCLUDE_DIR}
${jpeg_INCLUDE_DIR}
${png_INCLUDE_DIR}
${eigen_INCLUDE_DIRS}
@ -168,6 +169,7 @@ target_compile_features(tf_core_lib PRIVATE
)
add_dependencies(tf_core_lib
gif_copy_headers_to_destination
jpeg_copy_headers_to_destination
png_copy_headers_to_destination
re2_copy_headers_to_destination


@ -71,7 +71,7 @@ target_include_directories(tf_models_word2vec_kernels PRIVATE
${re2_INCLUDES}
)
add_dependencies(tf_models_word2vec_ops
add_dependencies(tf_models_word2vec_kernels
tf_core_cpu
)


@ -22,6 +22,7 @@ target_link_libraries(${proto_text} PUBLIC
${PROTOBUF_LIBRARIES}
# tf_protos_cc
# re2_lib
${gif_STATIC_LIBRARIES}
${jpeg_STATIC_LIBRARIES}
${png_STATIC_LIBRARIES}
${ZLIB_LIBRARIES}


@ -23,6 +23,7 @@ add_executable(tf_tutorials_example_trainer
$<TARGET_OBJECTS:tf_core_cpu>
$<TARGET_OBJECTS:tf_core_framework>
$<TARGET_OBJECTS:tf_core_kernels>
$<TARGET_OBJECTS:tf_cc_framework>
$<TARGET_OBJECTS:tf_cc_ops>
$<TARGET_OBJECTS:tf_core_ops>
$<TARGET_OBJECTS:tf_core_direct_session>
@ -40,6 +41,7 @@ target_link_libraries(tf_tutorials_example_trainer PUBLIC
re2_lib
${boringssl_STATIC_LIBRARIES}
${farmhash_STATIC_LIBRARIES}
${gif_STATIC_LIBRARIES}
${jpeg_STATIC_LIBRARIES}
${jsoncpp_STATIC_LIBRARIES}
${png_STATIC_LIBRARIES}


@ -54,6 +54,29 @@ cuda_py_tests(
],
)
cuda_py_tests(
name = "operator_pd_identity_test",
size = "small",
srcs = ["python/kernel_tests/operator_pd_identity_test.py"],
additional_deps = [
":distributions_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
cuda_py_tests(
name = "operator_pd_vdvt_update_test",
size = "large",
srcs = ["python/kernel_tests/operator_pd_vdvt_update_test.py"],
additional_deps = [
":distributions_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
tags = ["notap"], # http://b/30441813
)
py_library(
name = "distributions_py",
srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
@ -76,7 +99,16 @@ cuda_py_tests(
srcs = ["python/kernel_tests/beta_test.py"],
additional_deps = [
":distributions_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
cuda_py_tests(
name = "binomial_test",
size = "small",
srcs = ["python/kernel_tests/binomial_test.py"],
additional_deps = [
":distributions_py",
"//tensorflow/python:platform_test",
],
tags = ["notsan"],
@ -156,9 +188,8 @@ cuda_py_tests(
)
cuda_py_tests(
name = "kullback_leibler_test",
size = "small",
srcs = ["python/kernel_tests/kullback_leibler_test.py"],
name = "laplace_test",
srcs = ["python/kernel_tests/laplace_test.py"],
additional_deps = [
":distributions_py",
"//tensorflow/python:framework_test_lib",
@ -167,13 +198,14 @@ cuda_py_tests(
)
cuda_py_tests(
name = "laplace_test",
srcs = ["python/kernel_tests/laplace_test.py"],
name = "multinomial_test",
srcs = ["python/kernel_tests/multinomial_test.py"],
additional_deps = [
":distributions_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
tags = ["notsan"],
)
cuda_py_tests(
@ -216,6 +248,15 @@ cuda_py_tests(
srcs = ["python/kernel_tests/uniform_test.py"],
additional_deps = [
":distributions_py",
"//tensorflow/python:framework_test_lib",
],
)
cuda_py_tests(
name = "kullback_leibler_test",
size = "small",
srcs = ["python/kernel_tests/kullback_leibler_test.py"],
additional_deps = [
"//tensorflow/python:platform_test",
],
)
@ -240,6 +281,28 @@ cuda_py_tests(
],
)
cuda_py_tests(
name = "shape_test",
size = "small",
srcs = ["python/kernel_tests/shape_test.py"],
additional_deps = [
":distributions_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
cuda_py_tests(
name = "bijector_test",
size = "small",
srcs = ["python/kernel_tests/bijector_test.py"],
additional_deps = [
":distributions_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
filegroup(
name = "all_files",
srcs = glob(


@ -25,6 +25,7 @@ initialized with parameters that define the distributions.
### Univariate (scalar) distributions
@@Binomial
@@Bernoulli
@@Beta
@@Categorical
@ -50,6 +51,7 @@ initialized with parameters that define the distributions.
@@Dirichlet
@@DirichletMultinomial
@@Multinomial
### Transformed distributions
@ -79,6 +81,7 @@ from __future__ import print_function
from tensorflow.contrib.distributions.python.ops.bernoulli import *
from tensorflow.contrib.distributions.python.ops.beta import *
from tensorflow.contrib.distributions.python.ops.binomial import *
from tensorflow.contrib.distributions.python.ops.categorical import *
from tensorflow.contrib.distributions.python.ops.chi2 import *
from tensorflow.contrib.distributions.python.ops.dirichlet import *
@ -89,6 +92,7 @@ from tensorflow.contrib.distributions.python.ops.gamma import *
from tensorflow.contrib.distributions.python.ops.inverse_gamma import *
from tensorflow.contrib.distributions.python.ops.kullback_leibler import *
from tensorflow.contrib.distributions.python.ops.laplace import *
from tensorflow.contrib.distributions.python.ops.multinomial import *
from tensorflow.contrib.distributions.python.ops.mvn import *
from tensorflow.contrib.distributions.python.ops.normal import *
from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import *


@ -57,10 +57,17 @@ class BernoulliTest(tf.test.TestCase):
self.assertAllClose(scipy.special.logit(p), dist.logits.eval())
def testInvalidP(self):
invalid_ps = [1.01, -0.01, 2., -3.]
invalid_ps = [1.01, 2.]
for p in invalid_ps:
with self.test_session():
with self.assertRaisesOpError("x <= y"):
with self.assertRaisesOpError("p has components greater than 1"):
dist = tf.contrib.distributions.Bernoulli(p=p)
dist.p.eval()
invalid_ps = [-0.01, -3.]
for p in invalid_ps:
with self.test_session():
with self.assertRaisesOpError("Condition x >= 0"):
dist = tf.contrib.distributions.Bernoulli(p=p)
dist.p.eval()


@ -0,0 +1,67 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Bijector."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import tensorflow as tf
from tensorflow.contrib.distributions.python.ops.bijector import _Exp # pylint: disable=line-too-long
from tensorflow.contrib.distributions.python.ops.bijector import _Identity # pylint: disable=line-too-long
from tensorflow.contrib.distributions.python.ops.shape import _ShapeUtil # pylint: disable=line-too-long
class IdentityBijectorTest(tf.test.TestCase):
"""Tests the correctness of the Y = g(X) = X transformation."""
def testBijector(self):
with self.test_session():
bijector = _Identity(_ShapeUtil(batch_ndims=1, event_ndims=1))
self.assertEqual(bijector.name, 'Identity')
x = [[[0.], [1]]]
self.assertAllEqual(bijector.forward(x).eval(), x)
self.assertAllEqual(bijector.inverse(x).eval(), x)
self.assertAllEqual(bijector.inverse_log_det_jacobian(x).eval(),
[[0., 0]])
rev, jac = bijector.inverse_and_inverse_log_det_jacobian(x)
self.assertAllEqual(rev.eval(), x)
self.assertAllEqual(jac.eval(), [[0., 0]])
class ExpBijectorTest(tf.test.TestCase):
"""Tests the correctness of the Y = g(X) = exp(X) transformation."""
def testBijector(self):
with self.test_session():
bijector = _Exp(_ShapeUtil(batch_ndims=1, event_ndims=1))
self.assertEqual(bijector.name, 'Exp')
x = [[[1.], [2]]]
self.assertAllClose(bijector.forward(x).eval(),
[[[math.exp(1.)], [math.exp(2.)]]])
self.assertAllClose(bijector.inverse(x).eval(),
[[[math.log(1.)], [math.log(2.)]]])
self.assertAllClose(bijector.inverse_log_det_jacobian(x).eval(),
[[0., -math.log(2.)]])
rev, jac = bijector.inverse_and_inverse_log_det_jacobian(x)
self.assertAllClose(rev.eval(), [[[math.log(1.)], [math.log(2.)]]])
self.assertAllClose(jac.eval(), [[0., -math.log(2.)]])
if __name__ == '__main__':
tf.test.main()


@ -0,0 +1,173 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from scipy import stats
import tensorflow as tf
class BinomialTest(tf.test.TestCase):
def testSimpleShapes(self):
with self.test_session():
p = np.float32(np.random.beta(1, 1))
binom = tf.contrib.distributions.Binomial(n=1., p=p)
self.assertAllEqual([], binom.event_shape().eval())
self.assertAllEqual([], binom.batch_shape().eval())
self.assertEqual(tf.TensorShape([]), binom.get_event_shape())
self.assertEqual(tf.TensorShape([]), binom.get_batch_shape())
def testComplexShapes(self):
with self.test_session():
p = np.random.beta(1, 1, size=(3, 2)).astype(np.float32)
n = [[3., 2], [4, 5], [6, 7]]
binom = tf.contrib.distributions.Binomial(n=n, p=p)
self.assertAllEqual([], binom.event_shape().eval())
self.assertAllEqual([3, 2], binom.batch_shape().eval())
self.assertEqual(tf.TensorShape([]), binom.get_event_shape())
self.assertEqual(tf.TensorShape([3, 2]), binom.get_batch_shape())
def testNProperty(self):
p = [[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]]
n = [[3.], [4]]
with self.test_session():
binom = tf.contrib.distributions.Binomial(n=n, p=p)
self.assertEqual((2, 1), binom.n.get_shape())
self.assertAllClose(n, binom.n.eval())
def testPProperty(self):
p = [[0.1, 0.2, 0.7]]
with self.test_session():
binom = tf.contrib.distributions.Binomial(n=3., p=p)
self.assertEqual((1, 3), binom.p.get_shape())
self.assertEqual((1, 3), binom.logits.get_shape())
self.assertAllClose(p, binom.p.eval())
def testLogitsProperty(self):
logits = [[0., 9., -0.5]]
with self.test_session():
binom = tf.contrib.distributions.Binomial(n=3., logits=logits)
self.assertEqual((1, 3), binom.p.get_shape())
self.assertEqual((1, 3), binom.logits.get_shape())
self.assertAllClose(logits, binom.logits.eval())
def testPmfNandCountsAgree(self):
p = [[0.1, 0.2, 0.7]]
n = [[5.]]
with self.test_session():
binom = tf.contrib.distributions.Binomial(n=n, p=p)
binom.pmf([2., 3, 2]).eval()
binom.pmf([3., 1, 2]).eval()
with self.assertRaisesOpError('Condition x >= 0.*'):
binom.pmf([-1., 4, 2]).eval()
with self.assertRaisesOpError('Condition x <= y.*'):
binom.pmf([7., 3, 0]).eval()
def testPmf_non_integer_counts(self):
p = [[0.1, 0.2, 0.7]]
n = [[5.]]
with self.test_session():
# No errors with integer n.
binom = tf.contrib.distributions.Binomial(n=n, p=p)
binom.pmf([2., 3, 2]).eval()
binom.pmf([3., 1, 2]).eval()
# Both equality and integer checking fail.
with self.assertRaisesOpError('Condition x == y.*'):
binom.pmf([1.0, 2.5, 1.5]).eval()
binom = tf.contrib.distributions.Binomial(n=n, p=p, validate_args=False)
binom.pmf([1., 2., 3.]).eval()
# Non-integer arguments work.
binom.pmf([1.0, 2.5, 1.5]).eval()
def testPmfBothZeroBatches(self):
with self.test_session():
# Both zero-batches. No broadcast
p = 0.5
counts = 1.
pmf = tf.contrib.distributions.Binomial(n=1., p=p).pmf(counts)
self.assertAllClose(0.5, pmf.eval())
self.assertEqual((), pmf.get_shape())
def testPmfBothZeroBatchesNontrivialN(self):
with self.test_session():
# Both zero-batches. No broadcast
p = 0.1
counts = 3.
binom = tf.contrib.distributions.Binomial(n=5., p=p)
pmf = binom.pmf(counts)
self.assertAllClose(stats.binom.pmf(counts, n=5., p=p), pmf.eval())
self.assertEqual((), pmf.get_shape())
def testPmfPStretchedInBroadcastWhenSameRank(self):
with self.test_session():
p = [[0.1, 0.9]]
counts = [[1., 2.]]
pmf = tf.contrib.distributions.Binomial(n=3., p=p).pmf(counts)
self.assertAllClose(stats.binom.pmf(counts, n=3., p=p), pmf.eval())
self.assertEqual((1, 2), pmf.get_shape())
def testPmfPStretchedInBroadcastWhenLowerRank(self):
with self.test_session():
p = [0.1, 0.4]
counts = [[1.], [0.]]
pmf = tf.contrib.distributions.Binomial(n=1., p=p).pmf(counts)
self.assertAllClose([[0.1, 0.4], [0.9, 0.6]], pmf.eval())
self.assertEqual((2, 2), pmf.get_shape())
def testBinomialMean(self):
with self.test_session():
n = 5.
p = [0.1, 0.2, 0.7]
binom = tf.contrib.distributions.Binomial(n=n, p=p)
expected_means = stats.binom.mean(n, p)
self.assertEqual((3,), binom.mean().get_shape())
self.assertAllClose(expected_means, binom.mean().eval())
def testBinomialVariance(self):
with self.test_session():
n = 5.
p = [0.1, 0.2, 0.7]
binom = tf.contrib.distributions.Binomial(n=n, p=p)
expected_variances = stats.binom.var(n, p)
self.assertEqual((3,), binom.variance().get_shape())
self.assertAllClose(expected_variances, binom.variance().eval())
def testBinomialMode(self):
with self.test_session():
n = 5.
p = [0.1, 0.2, 0.7]
binom = tf.contrib.distributions.Binomial(n=n, p=p)
expected_modes = [0., 1, 4]
self.assertEqual((3,), binom.mode().get_shape())
self.assertAllClose(expected_modes, binom.mode().eval())
def testBinomialMultipleMode(self):
with self.test_session():
n = 9.
p = [0.1, 0.2, 0.7]
binom = tf.contrib.distributions.Binomial(n=n, p=p)
# For the case where (n + 1) * p is an integer, the modes are:
# (n + 1) * p and (n + 1) * p - 1. In this case, we get back
# the larger of the two modes.
expected_modes = [1., 2, 7]
self.assertEqual((3,), binom.mode().get_shape())
self.assertAllClose(expected_modes, binom.mode().eval())
if __name__ == '__main__':
tf.test.main()
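As a quick cross-check of the bookkeeping in `testBinomialMultipleMode`, the tie at `(n + 1) * p` can be confirmed directly with `scipy.stats`; a standalone sketch using the same `n = 9` and one of the `p` components from the test above:

```python
import numpy as np
from scipy import stats

n, p = 9., 0.2
# When (n + 1) * p is an integer, both (n + 1) * p and (n + 1) * p - 1
# maximize the pmf, and the test expects the larger of the two.
k_hi = (n + 1) * p      # 2.0
k_lo = k_hi - 1.        # 1.0
assert np.isclose(stats.binom.pmf(k_hi, n=n, p=p),
                  stats.binom.pmf(k_lo, n=n, p=p))
```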

View File

@ -61,14 +61,14 @@ class DirichletMultinomialTest(tf.test.TestCase):
n = [[5.]]
with self.test_session():
dist = tf.contrib.distributions.DirichletMultinomial(n, alpha)
dist.pmf([2, 3, 0]).eval()
dist.pmf([3, 0, 2]).eval()
dist.pmf([2., 3, 0]).eval()
dist.pmf([3., 0, 2]).eval()
with self.assertRaisesOpError('Condition x >= 0.*'):
dist.pmf([-1, 4, 2]).eval()
with self.assertRaisesOpError('Condition x == y.*'):
dist.pmf([3, 3, 0]).eval()
dist.pmf([-1., 4, 2]).eval()
with self.assertRaisesOpError('counts do not sum to n'):
dist.pmf([3., 3, 0]).eval()
def testPmfArbitraryCounts(self):
def testPmf_non_integer_counts(self):
alpha = [[1., 2, 3]]
n = [[5.]]
with self.test_session():
@ -80,8 +80,10 @@ class DirichletMultinomialTest(tf.test.TestCase):
with self.assertRaisesOpError('Condition x == y.*'):
dist.pmf([1.0, 2.5, 1.5]).eval()
dist = tf.contrib.distributions.DirichletMultinomial(
n, alpha, allow_arbitrary_counts=True)
dist.pmf(np.array([1.0, 2.5, 1.5])).eval()
n, alpha, validate_args=False)
dist.pmf([1., 2., 3.]).eval()
# Non-integer arguments work.
dist.pmf([1.0, 2.5, 1.5]).eval()
def testPmfBothZeroBatches(self):
# The probability of one vote falling into class k is the mean for class
@ -90,7 +92,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
# Both zero-batches. No broadcast
alpha = [1., 2]
counts = [1., 0]
dist = tf.contrib.distributions.DirichletMultinomial(1, alpha)
dist = tf.contrib.distributions.DirichletMultinomial(1., alpha)
pmf = dist.pmf(counts)
self.assertAllClose(1 / 3., pmf.eval())
self.assertEqual((), pmf.get_shape())
@ -102,7 +104,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
# Both zero-batches. No broadcast
alpha = [1., 2]
counts = [3., 2]
dist = tf.contrib.distributions.DirichletMultinomial(5, alpha)
dist = tf.contrib.distributions.DirichletMultinomial(5., alpha)
pmf = dist.pmf(counts)
self.assertAllClose(1 / 7., pmf.eval())
self.assertEqual((), pmf.get_shape())
@ -113,7 +115,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
with self.test_session():
alpha = [1., 2]
counts = [3., 2]
n = np.full([4, 3], 5.)
n = np.full([4, 3], 5., dtype=np.float32)
dist = tf.contrib.distributions.DirichletMultinomial(n, alpha)
pmf = dist.pmf(counts)
self.assertAllClose([[1 / 7., 1 / 7., 1 / 7.]] * 4, pmf.eval())
@ -125,7 +127,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
with self.test_session():
alpha = [[1., 2]]
counts = [[1., 0], [0., 1]]
dist = tf.contrib.distributions.DirichletMultinomial([1], alpha)
dist = tf.contrib.distributions.DirichletMultinomial([1.], alpha)
pmf = dist.pmf(counts)
self.assertAllClose([1 / 3., 2 / 3.], pmf.eval())
self.assertEqual((2,), pmf.get_shape())
@ -231,12 +233,12 @@ class DirichletMultinomialTest(tf.test.TestCase):
def testVariance_n_alpha_broadcast(self):
alpha_v = [1., 2, 3]
alpha_0 = np.sum(alpha_v)
alpha_0 = 6.
# Shape [4, 3]
alpha = np.array(4 * [alpha_v])
alpha = np.array(4 * [alpha_v], dtype=np.float32)
# Shape [4, 1]
ns = np.array([[2.], [3.], [4.], [5.]])
ns = np.array([[2.], [3.], [4.], [5.]], dtype=np.float32)
variance_entry = lambda a, a_sum: a / a_sum * (1 - a / a_sum)
covariance_entry = lambda a, b, a_sum: -a * b / a_sum**2
@ -250,7 +252,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
covariance_entry(alpha_v[1], alpha_v[2], alpha_0)],
[covariance_entry(alpha_v[2], alpha_v[0], alpha_0),
covariance_entry(alpha_v[2], alpha_v[1], alpha_0),
variance_entry(alpha_v[2], alpha_0)]]])
variance_entry(alpha_v[2], alpha_0)]]], dtype=np.float32)
with self.test_session():
# ns is shape [4, 1], and alpha is shape [4, 3].
@ -263,11 +265,11 @@ class DirichletMultinomialTest(tf.test.TestCase):
self.assertAllClose(expected_variance, variance.eval())
def testVariance_multidimensional(self):
alpha = np.random.rand(3, 5, 4)
alpha2 = np.random.rand(6, 3, 3)
# Ensure n > 0.
ns = np.random.geometric(p=0.8, size=[3, 5, 1]) + 1
ns2 = np.random.geometric(p=0.8, size=[6, 1, 1]) + 1
alpha = np.random.rand(3, 5, 4).astype(np.float32)
alpha2 = np.random.rand(6, 3, 3).astype(np.float32)
ns = np.random.randint(low=1, high=11, size=[3, 5, 1]).astype(np.float32)
ns2 = np.random.randint(low=1, high=11, size=[6, 1, 1]).astype(np.float32)
with self.test_session():
dist = tf.contrib.distributions.DirichletMultinomial(ns, alpha)
@ -297,7 +299,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
# One (three sided) coin flip. Prob[coin 3] = 0.8.
# Note that since it was one flip, value of tau didn't matter.
counts = [0, 0, 1]
counts = [0., 0, 1]
with self.test_session():
dist = tf.contrib.distributions.DirichletMultinomial(1., alpha)
pmf = dist.pmf(counts)
@ -305,9 +307,9 @@ class DirichletMultinomialTest(tf.test.TestCase):
self.assertEqual((), pmf.get_shape())
# Two (three sided) coin flips. Prob[coin 3] = 0.8.
counts = [0, 0, 2]
counts = [0., 0, 2]
with self.test_session():
dist = tf.contrib.distributions.DirichletMultinomial(2, alpha)
dist = tf.contrib.distributions.DirichletMultinomial(2., alpha)
pmf = dist.pmf(counts)
self.assertAllClose(0.8**2, pmf.eval(), atol=1e-2)
self.assertEqual((), pmf.get_shape())
@ -315,7 +317,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
# Three (three sided) coin flips.
counts = [1., 0, 2]
with self.test_session():
dist = tf.contrib.distributions.DirichletMultinomial(3, alpha)
dist = tf.contrib.distributions.DirichletMultinomial(3., alpha)
pmf = dist.pmf(counts)
self.assertAllClose(3 * 0.1 * 0.8 * 0.8, pmf.eval(), atol=1e-2)
self.assertEqual((), pmf.get_shape())
@ -336,10 +338,10 @@ class DirichletMultinomialTest(tf.test.TestCase):
self.assertEqual((), pmf.get_shape())
# If there are two draws, it is much more likely that they are the same.
counts_same = [2, 0]
counts_same = [2., 0]
counts_different = [1, 1.]
with self.test_session():
dist = tf.contrib.distributions.DirichletMultinomial(2, alpha)
dist = tf.contrib.distributions.DirichletMultinomial(2., alpha)
pmf_same = dist.pmf(counts_same)
pmf_different = dist.pmf(counts_different)
self.assertLess(5 * pmf_different.eval(), pmf_same.eval())
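For readers wanting to sanity-check the "two draws are more likely to match" assertion, the standard Dirichlet-multinomial closed form can be evaluated in a few lines of numpy; the `alpha` values here are illustrative, not the ones used in the (truncated) test above:

```python
import numpy as np
from scipy.special import gammaln

def dm_log_pmf(counts, alpha):
    # Closed-form Dirichlet-multinomial log-pmf: the multinomial
    # coefficient plus the ratio of Dirichlet normalizers.
    counts = np.asarray(counts, dtype=float)
    alpha = np.asarray(alpha, dtype=float)
    n, a0 = counts.sum(), alpha.sum()
    return (gammaln(n + 1) - gammaln(counts + 1).sum()
            + gammaln(a0) - gammaln(a0 + n)
            + gammaln(alpha + counts).sum() - gammaln(alpha).sum())

# With a small (sparse-favoring) alpha, two draws landing in the same
# class are much more likely than a split, as the test asserts.
assert dm_log_pmf([2., 0.], [0.1, 0.1]) > dm_log_pmf([1., 1.], [0.1, 0.1])
```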

View File

@ -0,0 +1,226 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
class MultinomialTest(tf.test.TestCase):
def testSimpleShapes(self):
with self.test_session():
p = [.1, .3, .6]
dist = tf.contrib.distributions.Multinomial(n=1., p=p)
self.assertEqual(3, dist.event_shape().eval())
self.assertAllEqual([], dist.batch_shape().eval())
self.assertEqual(tf.TensorShape([3]), dist.get_event_shape())
self.assertEqual(tf.TensorShape([]), dist.get_batch_shape())
def testComplexShapes(self):
with self.test_session():
p = 0.5 * np.ones([3, 2, 2], dtype=np.float32)
n = [[3., 2], [4, 5], [6, 7]]
dist = tf.contrib.distributions.Multinomial(n=n, p=p)
self.assertEqual(2, dist.event_shape().eval())
self.assertAllEqual([3, 2], dist.batch_shape().eval())
self.assertEqual(tf.TensorShape([2]), dist.get_event_shape())
self.assertEqual(tf.TensorShape([3, 2]), dist.get_batch_shape())
def testNProperty(self):
p = [[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]]
n = [[3.], [4]]
with self.test_session():
dist = tf.contrib.distributions.Multinomial(n=n, p=p)
self.assertEqual((2, 1), dist.n.get_shape())
self.assertAllClose(n, dist.n.eval())
def testPProperty(self):
p = [[0.1, 0.2, 0.7]]
with self.test_session():
dist = tf.contrib.distributions.Multinomial(n=3., p=p)
self.assertEqual((1, 3), dist.p.get_shape())
self.assertEqual((1, 3), dist.logits.get_shape())
self.assertAllClose(p, dist.p.eval())
def testLogitsProperty(self):
logits = [[0., 9., -0.5]]
with self.test_session():
multinom = tf.contrib.distributions.Multinomial(n=3., logits=logits)
self.assertEqual((1, 3), multinom.p.get_shape())
self.assertEqual((1, 3), multinom.logits.get_shape())
self.assertAllClose(logits, multinom.logits.eval())
def testPmfNandCountsAgree(self):
p = [[0.1, 0.2, 0.7]]
n = [[5.]]
with self.test_session():
dist = tf.contrib.distributions.Multinomial(n=n, p=p)
dist.pmf([2., 3, 0]).eval()
dist.pmf([3., 0, 2]).eval()
with self.assertRaisesOpError('Condition x >= 0.*'):
dist.pmf([-1., 4, 2]).eval()
with self.assertRaisesOpError('counts do not sum to n'):
dist.pmf([3., 3, 0]).eval()
def testPmf_non_integer_counts(self):
p = [[0.1, 0.2, 0.7]]
n = [[5.]]
with self.test_session():
# No errors with integer-valued counts.
multinom = tf.contrib.distributions.Multinomial(n=n, p=p)
multinom.pmf([2., 1, 2]).eval()
multinom.pmf([3., 0, 2]).eval()
# Counts don't sum to n.
with self.assertRaisesOpError('counts do not sum to n'):
multinom.pmf([2., 3, 2]).eval()
# Counts are non-integers.
with self.assertRaisesOpError('Condition x == y.*'):
multinom.pmf([1.0, 2.5, 1.5]).eval()
multinom = tf.contrib.distributions.Multinomial(
n=n, p=p, validate_args=False)
multinom.pmf([1., 2., 2.]).eval()
# Non-integer arguments work.
multinom.pmf([1.0, 2.5, 1.5]).eval()
def testPmfBothZeroBatches(self):
with self.test_session():
# Both zero-batches. No broadcast
p = [0.5, 0.5]
counts = [1., 0]
pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
self.assertAllClose(0.5, pmf.eval())
self.assertEqual((), pmf.get_shape())
def testPmfBothZeroBatchesNontrivialN(self):
with self.test_session():
# Both zero-batches. No broadcast
p = [0.1, 0.9]
counts = [3., 2]
dist = tf.contrib.distributions.Multinomial(n=5., p=p)
pmf = dist.pmf(counts)
# 5 choose 3 = 5 choose 2 = 10. 10 * (.9)^2 * (.1)^3 = 81/10000.
self.assertAllClose(81./10000, pmf.eval())
self.assertEqual((), pmf.get_shape())
def testPmfPStretchedInBroadcastWhenSameRank(self):
with self.test_session():
p = [[0.1, 0.9]]
counts = [[1., 0], [0, 1]]
pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
self.assertAllClose([0.1, 0.9], pmf.eval())
self.assertEqual((2,), pmf.get_shape())
def testPmfPStretchedInBroadcastWhenLowerRank(self):
with self.test_session():
p = [0.1, 0.9]
counts = [[1., 0], [0, 1]]
pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
self.assertAllClose([0.1, 0.9], pmf.eval())
self.assertEqual((2,), pmf.get_shape())
def testPmfCountsStretchedInBroadcastWhenSameRank(self):
with self.test_session():
p = [[0.1, 0.9], [0.7, 0.3]]
counts = [[1., 0]]
pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
self.assertAllClose(pmf.eval(), [0.1, 0.7])
self.assertEqual((2,), pmf.get_shape())
def testPmfCountsStretchedInBroadcastWhenLowerRank(self):
with self.test_session():
p = [[0.1, 0.9], [0.7, 0.3]]
counts = [1., 0]
pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
self.assertAllClose(pmf.eval(), [0.1, 0.7])
self.assertEqual(pmf.get_shape(), (2,))
def testPmfShapeCountsStretched_N(self):
with self.test_session():
# [2, 2, 2]
p = [[[0.1, 0.9], [0.1, 0.9]], [[0.7, 0.3], [0.7, 0.3]]]
# [2, 2]
n = [[3., 3], [3, 3]]
# [2]
counts = [2., 1]
pmf = tf.contrib.distributions.Multinomial(n=n, p=p).pmf(counts)
pmf.eval()
self.assertEqual(pmf.get_shape(), (2, 2))
def testPmfShapeCountsPStretched_N(self):
with self.test_session():
p = [0.1, 0.9]
counts = [3., 2]
n = np.full([4, 3], 5., dtype=np.float32)
pmf = tf.contrib.distributions.Multinomial(n=n, p=p).pmf(counts)
pmf.eval()
self.assertEqual((4, 3), pmf.get_shape())
def testMultinomialMean(self):
with self.test_session():
n = 5.
p = [0.1, 0.2, 0.7]
dist = tf.contrib.distributions.Multinomial(n=n, p=p)
expected_means = 5 * np.array(p, dtype=np.float32)
self.assertEqual((3,), dist.mean().get_shape())
self.assertAllClose(expected_means, dist.mean().eval())
def testMultinomialVariance(self):
with self.test_session():
n = 5.
p = [0.1, 0.2, 0.7]
dist = tf.contrib.distributions.Multinomial(n=n, p=p)
expected_variances = [
[9./20, -1/10, -7/20], [-1/10, 4/5, -7/10], [-7/20, -7/10, 21/20]]
self.assertEqual((3, 3), dist.variance().get_shape())
self.assertAllClose(expected_variances, dist.variance().eval())
def testMultinomialVariance_batch(self):
with self.test_session():
# Shape [2]
n = [5.] * 2
# Shape [4, 1, 2]
p = [[[0.1, 0.9]], [[0.1, 0.9]]] * 2
dist = tf.contrib.distributions.Multinomial(n=n, p=p)
# Shape [2, 2]
inner_var = [[9./20, -9/20], [-9/20, 9/20]]
# Shape [4, 2, 2, 2]
expected_variances = [[inner_var, inner_var]] * 4
self.assertEqual((4, 2, 2, 2), dist.variance().get_shape())
self.assertAllClose(expected_variances, dist.variance().eval())
def testVariance_multidimensional(self):
# Shape [3, 5, 4]
p = np.random.dirichlet([.25, .25, .25, .25], [3, 5]).astype(np.float32)
# Shape [6, 3, 3]
p2 = np.random.dirichlet([.3, .3, .4], [6, 3]).astype(np.float32)
ns = np.random.randint(low=1, high=11, size=[3, 5]).astype(np.float32)
ns2 = np.random.randint(low=1, high=11, size=[6, 1]).astype(np.float32)
with self.test_session():
dist = tf.contrib.distributions.Multinomial(ns, p)
dist2 = tf.contrib.distributions.Multinomial(ns2, p2)
variance = dist.variance()
variance2 = dist2.variance()
self.assertEqual((3, 5, 4, 4), variance.get_shape())
self.assertEqual((6, 3, 3, 3), variance2.get_shape())
if __name__ == '__main__':
tf.test.main()
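The hard-coded matrices in the variance tests come from the standard multinomial covariance, `Cov = n * (diag(p) - p p^T)`; a plain-numpy sketch reproducing the matrix in `testMultinomialVariance`:

```python
import numpy as np

n = 5.
p = np.array([0.1, 0.2, 0.7])
# Cov[X_i, X_j] = n * (p_i * delta_ij - p_i * p_j)
cov = n * (np.diag(p) - np.outer(p, p))
expected = np.array([[9. / 20, -1. / 10, -7. / 20],
                     [-1. / 10, 4. / 5, -7. / 10],
                     [-7. / 20, -7. / 10, 21. / 20]])
assert np.allclose(cov, expected)
```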

View File

@ -117,6 +117,61 @@ class MultivariateNormalDiagTest(tf.test.TestCase):
self.assertAllClose(cov_mat, np.cov(samps.T), atol=0.1)
class MultivariateNormalDiagPlusVDVTTest(tf.test.TestCase):
"""Well tested because this is a simple override of the base class."""
def setUp(self):
self._rng = np.random.RandomState(42)
def testMean(self):
mu = [-1.0, 1.0]
diag_large = [1.0, 5.0]
v = [[2.0], [3.0]]
diag_small = [3.0]
with self.test_session():
dist = distributions.MultivariateNormalDiagPlusVDVT(
mu, diag_large, v, diag_small=diag_small)
self.assertAllEqual(mu, dist.mean().eval())
def testNonmatchingMuAndSigmaDimensionFailsStatic(self):
mu = self._rng.rand(2)
# With this diag_large and v, the covariance is 3 x 3
diag_large = self._rng.rand(3)
v = self._rng.rand(3, 2) # v works with diag_large.
with self.test_session():
with self.assertRaisesRegexp(ValueError, "shape.*should match"):
distributions.MultivariateNormalDiagPlusVDVT(
mu, diag_large, v)
def testNonmatchingMuDiagDimensionsFailsDynamic(self):
mu = self._rng.rand(2)
# With this diag_large and v, the covariance is 3 x 3
diag_large = self._rng.rand(3)
v = self._rng.rand(3, 2) # v works with diag_large.
with self.test_session():
mu_ph = tf.placeholder(tf.float32, name="mu_ph")
v_ph = tf.placeholder(tf.float32, name="v_ph")
diag_ph = tf.placeholder(tf.float32, name="diag_ph")
dist = distributions.MultivariateNormalDiagPlusVDVT(
mu_ph, diag_ph, v_ph)
with self.assertRaisesOpError("mu.*cov.*shape"):
dist.mean().eval(feed_dict={mu_ph: mu, diag_ph: diag_large, v_ph: v})
def testSample(self):
mu = [-1.0, 1.0]
diag_large = [1.0, 0.5]
v = [[0.2], [0.3]]
with self.test_session():
dist = distributions.MultivariateNormalDiagPlusVDVT(mu, diag_large, v)
samps = dist.sample_n(1000, seed=0).eval()
cov_mat = dist.sigma.eval()
self.assertAllClose(mu, samps.mean(axis=0), atol=0.1)
self.assertAllClose(cov_mat, np.cov(samps.T), atol=0.1)
class MultivariateNormalCholeskyTest(tf.test.TestCase):
def setUp(self):
@ -314,5 +369,87 @@ class MultivariateNormalCholeskyTest(tf.test.TestCase):
self.assertEqual((3, 5), tuple(mvn.batch_shape().eval()))
class MultivariateNormalFullTest(tf.test.TestCase):
def setUp(self):
self._rng = np.random.RandomState(42)
def _random_mu_and_sigma(self, batch_shape, event_shape):
# This ensures sigma is positive definite.
mat_shape = batch_shape + event_shape + event_shape
mat = self._rng.randn(*mat_shape)
sigma = tf.batch_matmul(mat, mat, adj_y=True).eval()
mu_shape = batch_shape + event_shape
mu = self._rng.randn(*mu_shape)
return mu, sigma
def testKLNonBatch(self):
batch_shape = ()
event_shape = (2,)
with self.test_session():
mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
mvn_a = distributions.MultivariateNormalFull(mu_a, sigma_a)
mvn_b = distributions.MultivariateNormalFull(mu_b, sigma_b)
kl = distributions.kl(mvn_a, mvn_b)
self.assertEqual(batch_shape, kl.get_shape())
kl_v = kl.eval()
expected_kl = _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b)
self.assertAllClose(expected_kl, kl_v)
def testKLBatch(self):
batch_shape = (2,)
event_shape = (3,)
with self.test_session():
mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
mvn_a = distributions.MultivariateNormalFull(mu_a, sigma_a)
mvn_b = distributions.MultivariateNormalFull(mu_b, sigma_b)
kl = distributions.kl(mvn_a, mvn_b)
self.assertEqual(batch_shape, kl.get_shape())
kl_v = kl.eval()
expected_kl_0 = _compute_non_batch_kl(
mu_a[0, :], sigma_a[0, :, :], mu_b[0, :], sigma_b[0, :, :])
expected_kl_1 = _compute_non_batch_kl(
mu_a[1, :], sigma_a[1, :, :], mu_b[1, :], sigma_b[1, :, :])
self.assertAllClose(expected_kl_0, kl_v[0])
self.assertAllClose(expected_kl_1, kl_v[1])
def testKLTwoIdenticalDistributionsIsZero(self):
batch_shape = (2,)
event_shape = (3,)
with self.test_session():
mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
mvn_a = distributions.MultivariateNormalFull(mu_a, sigma_a)
# Should be zero since KL(p || p) = 0.
kl = distributions.kl(mvn_a, mvn_a)
self.assertEqual(batch_shape, kl.get_shape())
kl_v = kl.eval()
self.assertAllClose(np.zeros(*batch_shape), kl_v)
def _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b):
"""Non-batch KL for N(mu_a, sigma_a), N(mu_b, sigma_b)."""
# Check using numpy operations
# This mostly repeats the tensorflow code _kl_mvn_mvn(), but in numpy.
# So it is important to also check that KL(mvn, mvn) = 0.
sigma_b_inv = np.linalg.inv(sigma_b)
t = np.trace(sigma_b_inv.dot(sigma_a))
q = (mu_b - mu_a).dot(sigma_b_inv).dot(mu_b - mu_a)
k = mu_a.shape[0]
l = np.log(np.linalg.det(sigma_b) / np.linalg.det(sigma_a))
return 0.5 * (t + q - k + l)
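For reference, the closed form computed term by term above is the standard KL divergence between two multivariate normals, with `t`, `q`, `k`, `l` matching the four terms:

```latex
KL\big(\mathcal{N}(\mu_a,\Sigma_a)\,\|\,\mathcal{N}(\mu_b,\Sigma_b)\big)
  = \tfrac{1}{2}\Big(\underbrace{\mathrm{tr}(\Sigma_b^{-1}\Sigma_a)}_{t}
  + \underbrace{(\mu_b-\mu_a)^\top\Sigma_b^{-1}(\mu_b-\mu_a)}_{q}
  - k
  + \underbrace{\ln\tfrac{\det\Sigma_b}{\det\Sigma_a}}_{l}\Big)
```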
if __name__ == "__main__":
tf.test.main()

View File

@ -17,14 +17,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import numpy as np
import six
import tensorflow as tf
from tensorflow.contrib.distributions.python.ops import operator_pd_diag
from tensorflow.contrib.distributions.python.ops import operator_test_util
class OperatorPDSqrtDiagTest(operator_test_util.OperatorPDDerivedClassTest):
@six.add_metaclass(abc.ABCMeta)
class OperatorPDDiagBaseTest(object):
def setUp(self):
self._rng = np.random.RandomState(42)
@ -32,8 +35,14 @@ class OperatorPDSqrtDiagTest(operator_test_util.OperatorPDDerivedClassTest):
def _random_pd_diag(self, diag_shape):
return self._rng.rand(*diag_shape) + 0.1
@abc.abstractmethod
def _diag_to_matrix(self, diag):
return tf.batch_matrix_diag(diag**2).eval()
pass
@abc.abstractproperty
def operator_class(self):
# Return the operator class that this tests.
pass
def _build_operator_and_mat(self, batch_shape, k, dtype=np.float64):
# Create a diagonal matrix explicitly.
@ -46,7 +55,7 @@ class OperatorPDSqrtDiagTest(operator_test_util.OperatorPDDerivedClassTest):
# The diag is the square root.
diag = self._random_pd_diag(diag_shape).astype(dtype)
mat = self._diag_to_matrix(diag).astype(dtype)
operator = operator_pd_diag.OperatorPDSqrtDiag(diag)
operator = self.operator_class(diag)
return operator, mat
@ -66,5 +75,29 @@ class OperatorPDSqrtDiagTest(operator_test_util.OperatorPDDerivedClassTest):
operator.to_dense().eval() # Should not raise
class OperatorPDDiagTest(
OperatorPDDiagBaseTest, operator_test_util.OperatorPDDerivedClassTest):
"""Most tests done in the base classes."""
def _diag_to_matrix(self, diag):
return tf.batch_matrix_diag(diag).eval()
@property
def operator_class(self):
return operator_pd_diag.OperatorPDDiag
class OperatorPDSqrtDiagTest(
OperatorPDDiagBaseTest, operator_test_util.OperatorPDDerivedClassTest):
"""Most tests done in the base classes."""
def _diag_to_matrix(self, diag):
return tf.batch_matrix_diag(diag**2).eval()
@property
def operator_class(self):
return operator_pd_diag.OperatorPDSqrtDiag
if __name__ == '__main__':
tf.test.main()
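A minimal numpy sketch of the distinction the two subclasses above encode, assuming the usual `A = S S^T` square-root convention:

```python
import numpy as np

d = np.array([0.5, 1.0, 2.0])  # a positive diagonal
# OperatorPDDiag represents A = diag(d) directly, while OperatorPDSqrtDiag
# treats diag(d) as the square root S, so the dense matrix it represents
# is A = S S^T = diag(d**2) -- hence the two _diag_to_matrix overrides.
s = np.diag(d)
assert np.allclose(s.dot(s.T), np.diag(d ** 2))
```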

View File

@ -0,0 +1,115 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.contrib.distributions.python.ops import operator_pd_identity
from tensorflow.contrib.distributions.python.ops import operator_test_util
distributions = tf.contrib.distributions
class OperatorPDIdentityTest(operator_test_util.OperatorPDDerivedClassTest):
"""Most tests done in the base class."""
def _build_operator_and_mat(self, batch_shape, k, dtype=np.float64):
# Build an identity matrix with right shape and dtype.
# Build an operator that should act the same way.
batch_shape = list(batch_shape)
diag_shape = batch_shape + [k]
matrix_shape = batch_shape + [k, k]
diag = tf.ones(diag_shape, dtype=dtype)
identity_matrix = tf.batch_matrix_diag(diag)
operator = operator_pd_identity.OperatorPDIdentity(matrix_shape, dtype)
return operator, identity_matrix.eval()
def test_bad_dtype_args_raise(self):
dtype = np.float32
batch_shape = [2, 3]
k = 4
with self.test_session():
operator, _ = self._build_operator_and_mat(batch_shape, k, dtype=dtype)
x_good_shape = batch_shape + [k, 5]
x_good = self._rng.randn(*x_good_shape).astype(dtype)
x_bad = x_good.astype(np.float64)
operator.matmul(x_good).eval() # Should not raise.
with self.assertRaisesRegexp(TypeError, 'dtype'):
operator.matmul(x_bad)
with self.assertRaisesRegexp(TypeError, 'dtype'):
operator.solve(x_bad)
with self.assertRaisesRegexp(TypeError, 'dtype'):
operator.sqrt_solve(x_bad)
def test_bad_rank_args_raise(self):
# Prepend a singleton dimension, changing the rank of 'x', but not the size.
dtype = np.float32
batch_shape = [2, 3]
k = 4
with self.test_session():
operator, _ = self._build_operator_and_mat(batch_shape, k, dtype=dtype)
x_good_shape = batch_shape + [k, 5]
x_good = self._rng.randn(*x_good_shape).astype(dtype)
x_bad = x_good.reshape(1, 2, 3, 4, 5)
operator.matmul(x_good).eval() # Should not raise.
with self.assertRaisesRegexp(ValueError, 'tensor rank'):
operator.matmul(x_bad)
with self.assertRaisesRegexp(ValueError, 'tensor rank'):
operator.solve(x_bad)
with self.assertRaisesRegexp(ValueError, 'tensor rank'):
operator.sqrt_solve(x_bad)
def test_incompatible_shape_args_raise(self):
# Test shapes that are the same rank but incompatible for matrix
# multiplication.
dtype = np.float32
batch_shape = [2, 3]
k = 4
with self.test_session():
operator, _ = self._build_operator_and_mat(batch_shape, k, dtype=dtype)
x_good_shape = batch_shape + [k, 5]
x_good = self._rng.randn(*x_good_shape).astype(dtype)
x_bad_shape = batch_shape + [5, k]
x_bad = x_good.reshape(*x_bad_shape)
operator.matmul(x_good).eval() # Should not raise.
with self.assertRaisesRegexp(ValueError, 'Incompatible'):
operator.matmul(x_bad)
with self.assertRaisesRegexp(ValueError, 'Incompatible'):
operator.solve(x_bad)
with self.assertRaisesRegexp(ValueError, 'Incompatible'):
operator.sqrt_solve(x_bad)
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,273 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.contrib.distributions.python.ops import operator_pd_full
from tensorflow.contrib.distributions.python.ops import operator_pd_vdvt_update
from tensorflow.contrib.distributions.python.ops import operator_test_util
distributions = tf.contrib.distributions
class OperatorPDSqrtVDVTUpdateTest(
operator_test_util.OperatorPDDerivedClassTest):
"""Most tests done in the base class."""
_diag_is_none = False
def setUp(self):
self._rng = np.random.RandomState(42)
def _random_pd_matrix(self, shape):
# With probability 1 this is positive definite.
sqrt = self._rng.randn(*shape)
mat = tf.batch_matmul(sqrt, sqrt, adj_y=True)
return mat.eval()
def _random_v_and_diag(self, mat_shape, v_matrix_rank):
# Get the necessary elements to make the sqrt update.
mat_shape = list(mat_shape)
batch_shape = mat_shape[:-2]
diag_shape = mat_shape[:-2] + [v_matrix_rank]
k = mat_shape[-1]
assert k == mat_shape[-2], 'Must be a square matrix'
v_shape = batch_shape + [k, v_matrix_rank]
v = self._rng.randn(*v_shape) # anything goes with "v"!
if self._diag_is_none:
diag = None
else:
diag = self._rng.rand(*diag_shape) + 0.1 # Positive diag!
return v, diag
def _updated_mat(self, mat, v, diag):
# Get dense matrix defined by its square root, which is an update of `mat`:
# A = (mat + v D v^T) (mat + v D v^T)^T
# D is the diagonal matrix with `diag` on the diagonal.
# If diag is None, then it defaults to the identity matrix, so DV^T = V^T
if diag is None:
diag_vt = tf.batch_matrix_transpose(v)
else:
diag_mat = tf.batch_matrix_diag(diag)
diag_vt = tf.batch_matmul(diag_mat, v, adj_y=True)
v_diag_vt = tf.batch_matmul(v, diag_vt)
sqrt = mat + v_diag_vt
a = tf.batch_matmul(sqrt, sqrt, adj_y=True)
return a.eval()
def _build_operator_and_mat(self, batch_shape, k, dtype=np.float64):
"""This method is called by base class, enabling many standard tests."""
# Create a matrix then explicitly update it with v and diag.
# Create an OperatorPDSqrtVDVTUpdate from the matrix and v and diag
# The operator should have the same behavior.
#
# The low-rank matrix V will have rank k // 2, unless k is 1, in which
# case its rank will also be 1.
if k == 1:
v_matrix_rank = k
else:
v_matrix_rank = k // 2
mat_shape = list(batch_shape) + [k, k]
mat = self._random_pd_matrix(mat_shape)
v, diag = self._random_v_and_diag(mat_shape, v_matrix_rank)
# Set dtypes
mat = mat.astype(dtype)
v = v.astype(dtype)
if diag is not None:
diag = diag.astype(dtype)
# The matrix: (mat + v*diag*v^T) * (mat + v*diag*v^T)^T
# Our final updated operator should behave like this.
updated_mat = self._updated_mat(mat, v, diag)
# Represents the matrix: `mat`, before updating.
# This is the Operator that we will update.
o_made_with_mat = operator_pd_full.OperatorPDFull(mat)
# Represents the matrix: (mat + v*diag*v^T) * (mat + v*diag*v^T)^T,
# achieved by updating the operator "o_made_with_mat".
# This is the operator we're testing.
operator = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
o_made_with_mat, v, diag)
return operator, updated_mat
def test_to_dense_placeholder(self):
# Test simple functionality when the inputs are placeholders.
mat_shape = [3, 3]
v_matrix_rank = 2
with self.test_session():
# Make an OperatorPDFull with a matrix placeholder.
mat_ph = tf.placeholder(tf.float64, name='mat_ph')
mat = self._random_pd_matrix(mat_shape)
o_made_with_mat = operator_pd_full.OperatorPDFull(mat_ph)
# Make the placeholders and arrays for the updated operator.
v_ph = tf.placeholder(tf.float64, name='v_ph')
v, diag = self._random_v_and_diag(mat_shape, v_matrix_rank)
if self._diag_is_none:
diag_ph = None
feed_dict = {v_ph: v, mat_ph: mat}
else:
diag_ph = tf.placeholder(tf.float64, name='diag_ph')
feed_dict = {v_ph: v, diag_ph: diag, mat_ph: mat}
# Make the OperatorPDSqrtVDVTUpdate with v and diag placeholders.
operator = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
o_made_with_mat, v_ph, diag=diag_ph)
# Should not fail
operator.to_dense().eval(feed_dict=feed_dict)
operator.log_det().eval(feed_dict=feed_dict)
def test_operator_not_subclass_of_operator_pd_raises(self):
# We enforce that `operator` is an `OperatorPDBase`.
with self.test_session():
v, diag = self._random_v_and_diag((3, 3), 2)
operator_m = 'I am not a subclass of OperatorPDBase'
with self.assertRaisesRegexp(TypeError, 'not instance'):
operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(operator_m, v, diag)
def test_non_pos_def_diag_raises(self):
if self._diag_is_none:
return
# We enforce that the diag is positive definite.
with self.test_session():
matrix_shape = (3, 3)
v_rank = 2
v, diag = self._random_v_and_diag(matrix_shape, v_rank)
mat = self._random_pd_matrix(matrix_shape)
diag[0] = 0.0
operator_m = operator_pd_full.OperatorPDFull(mat)
operator = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
operator_m, v, diag)
with self.assertRaisesOpError('positive'):
operator.to_dense().eval()
def test_non_pos_def_diag_doesnt_raise_if_verify_pd_false(self):
# We enforce that the diag is positive definite.
if self._diag_is_none:
return
with self.test_session():
matrix_shape = (3, 3)
v_rank = 2
v, diag = self._random_v_and_diag(matrix_shape, v_rank)
mat = self._random_pd_matrix(matrix_shape)
diag[0] = 0.0
operator_m = operator_pd_full.OperatorPDFull(mat)
operator = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
operator_m, v, diag, verify_pd=False)
operator.to_dense().eval() # Should not raise.
def test_event_shape_mismatch_v_and_diag_raises_static(self):
v = self._rng.rand(4, 3, 2)
diag = self._rng.rand(4, 1) # Should be shape (4, 2,) to match v.
with self.test_session():
mat = self._random_pd_matrix((4, 3, 3)) # mat and v match
operator_m = operator_pd_full.OperatorPDFull(mat)
with self.assertRaisesRegexp(ValueError, 'diag.*v.*last dimension'):
operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(operator_m, v, diag)
def test_batch_shape_mismatch_v_and_diag_raises_static(self):
v = self._rng.rand(4, 3, 2)
diag = self._rng.rand(5, 1) # Should be shape (4, 2,) to match v.
with self.test_session():
mat = self._random_pd_matrix((4, 3, 3)) # mat and v match
operator_m = operator_pd_full.OperatorPDFull(mat)
with self.assertRaisesRegexp(ValueError, 'diag.*batch shape'):
operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(operator_m, v, diag)
def test_tensor_rank_shape_mismatch_v_and_diag_raises_static(self):
v = self._rng.rand(1, 2, 2, 2)
diag = self._rng.rand(5, 1) # Should have rank 1 less than v.
with self.test_session():
mat = self._random_pd_matrix((1, 2, 2, 2)) # mat and v match
operator_m = operator_pd_full.OperatorPDFull(mat)
with self.assertRaisesRegexp(ValueError, 'diag.*rank'):
operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(operator_m, v, diag)
def test_event_shape_mismatch_v_and_diag_raises_dynamic(self):
with self.test_session():
v = self._rng.rand(4, 3, 2)
diag = self._rng.rand(4, 1) # Should be shape (4, 2,) to match v.
mat = self._random_pd_matrix((4, 3, 3)) # mat and v match
v_ph = tf.placeholder(tf.float32, name='v_ph')
diag_ph = tf.placeholder(tf.float32, name='diag_ph')
mat_ph = tf.placeholder(tf.float32, name='mat_ph')
operator_m = operator_pd_full.OperatorPDFull(mat_ph)
updated = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
operator_m, v_ph, diag_ph)
with self.assertRaisesOpError('x == y'):
updated.to_dense().eval(feed_dict={v_ph: v, diag_ph: diag, mat_ph: mat})
def test_batch_shape_mismatch_v_and_diag_raises_dynamic(self):
with self.test_session():
v = self._rng.rand(4, 3, 2)
diag = self._rng.rand(5, 1) # Should be shape (4, 2,) to match v.
mat = self._random_pd_matrix((4, 3, 3)) # mat and v match
v_ph = tf.placeholder(tf.float32, name='v_ph')
diag_ph = tf.placeholder(tf.float32, name='diag_ph')
mat_ph = tf.placeholder(tf.float32, name='mat_ph')
operator_m = operator_pd_full.OperatorPDFull(mat_ph)
updated = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
operator_m, v_ph, diag_ph)
with self.assertRaisesOpError('x == y'):
updated.to_dense().eval(feed_dict={v_ph: v, diag_ph: diag, mat_ph: mat})
def test_tensor_rank_shape_mismatch_v_and_diag_raises_dynamic(self):
with self.test_session():
v = self._rng.rand(2, 2, 2, 2)
diag = self._rng.rand(2, 2) # Should have rank 1 less than v.
mat = self._random_pd_matrix((2, 2, 2, 2)) # mat and v match
v_ph = tf.placeholder(tf.float32, name='v_ph')
diag_ph = tf.placeholder(tf.float32, name='diag_ph')
mat_ph = tf.placeholder(tf.float32, name='mat_ph')
operator_m = operator_pd_full.OperatorPDFull(mat_ph)
updated = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
operator_m, v_ph, diag_ph)
with self.assertRaisesOpError('rank'):
updated.to_dense().eval(feed_dict={v_ph: v, diag_ph: diag, mat_ph: mat})
class OperatorPDSqrtVDVTUpdateNoneDiagTest(OperatorPDSqrtVDVTUpdateTest):
_diag_is_none = True
if __name__ == '__main__':
tf.test.main()
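The update the operator under test represents can be restated in plain numpy, mirroring `_updated_mat` above; a sketch with illustrative shapes, not the test's exact inputs:

```python
import numpy as np

rng = np.random.RandomState(0)
k, r = 4, 2
m = rng.randn(k, k)       # base square root (played by OperatorPDFull here)
v = rng.randn(k, r)       # low-rank factor
d = rng.rand(r) + 0.1     # positive diagonal
# The operator represents A = (M + V D V^T)(M + V D V^T)^T.
sqrt = m + v.dot(np.diag(d)).dot(v.T)
a = sqrt.dot(sqrt.T)
assert np.all(np.linalg.eigvalsh(a) > 0)  # A is positive definite
```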

View File

@ -0,0 +1,165 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ShapeUtil."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.contrib.distributions.python.ops.shape import _ShapeUtil # pylint: disable=line-too-long
class ShapeUtilTest(tf.test.TestCase):
def testShapeUtilGetNdims(self):
with self.test_session():
shaper = _ShapeUtil(batch_ndims=0, event_ndims=0)
x = 1
self.assertEqual(shaper.get_sample_ndims(x), 0)
self.assertEqual(shaper.batch_ndims, 0)
self.assertEqual(shaper.event_ndims, 0)
shaper = _ShapeUtil(batch_ndims=1, event_ndims=1)
x = [[[0., 1, 2], [3, 4, 5]]]
self.assertAllEqual(shaper.get_ndims(x), 3)
self.assertEqual(shaper.get_sample_ndims(x), 1)
self.assertEqual(shaper.batch_ndims, 1)
self.assertEqual(shaper.event_ndims, 1)
x += [[[6, 7, 8], [9, 10, 11]]]
self.assertAllEqual(shaper.get_ndims(x), 3)
self.assertEqual(shaper.get_sample_ndims(x), 1)
self.assertEqual(shaper.batch_ndims, 1)
self.assertEqual(shaper.event_ndims, 1)
# Test ndims functions work, even despite unfed Tensors.
y = tf.placeholder(tf.float32, shape=(1024, None, 1024))
self.assertAllEqual(shaper.get_ndims(y), 3)
self.assertEqual(shaper.get_sample_ndims(y), 1)
self.assertEqual(shaper.batch_ndims, 1)
self.assertEqual(shaper.event_ndims, 1)
with self.assertRaises(ValueError):
y = tf.placeholder(tf.float32)
shaper.get_ndims(y)
def testShapeUtilGetDims(self):
with self.test_session():
shaper = _ShapeUtil(batch_ndims=0, event_ndims=0)
with self.assertRaises(ValueError):
y = tf.placeholder(tf.float32)
shaper.get_sample_dims(y)
with self.assertRaises(ValueError):
y = tf.placeholder(tf.float32)
shaper.get_batch_dims(y)
with self.assertRaises(ValueError):
y = tf.placeholder(tf.float32)
shaper.get_event_dims(y)
shaper = _ShapeUtil(batch_ndims=0, event_ndims=0)
x = 1
self.assertAllEqual(shaper.get_sample_dims(x), [])
self.assertAllEqual(shaper.get_batch_dims(x), [])
self.assertAllEqual(shaper.get_event_dims(x), [])
self.assertAllEqual(shaper.get_dims(x, sample=False), [])
shaper = _ShapeUtil(batch_ndims=1, event_ndims=2)
x = [[[[0., 1], [2, 4]]]]
self.assertAllEqual(shaper.get_sample_dims(x), [0])
self.assertAllEqual(shaper.get_batch_dims(x), [1])
self.assertAllEqual(shaper.get_event_dims(x), [2, 3])
self.assertAllEqual(shaper.get_dims(x, sample=False), [1, 2, 3])
x += x
self.assertAllEqual(shaper.get_sample_dims(x), [0])
self.assertAllEqual(shaper.get_batch_dims(x), [1])
self.assertAllEqual(shaper.get_event_dims(x), [2, 3])
self.assertAllEqual(shaper.get_dims(x, sample=False), [1, 2, 3])
# Test dims functions work, despite unfed Tensors.
y = tf.placeholder(tf.float32, shape=(1024, None, 5, 5))
self.assertAllEqual(shaper.get_sample_dims(y), [0])
self.assertAllEqual(shaper.get_batch_dims(y), [1])
self.assertAllEqual(shaper.get_event_dims(y), [2, 3])
def testShapeUtilGetShape(self):
with self.test_session() as sess:
shaper = _ShapeUtil(batch_ndims=0, event_ndims=0)
with self.assertRaises(ValueError):
y = tf.placeholder(tf.float32)
shaper.get_sample_shape(y)
with self.assertRaises(ValueError):
y = tf.placeholder(tf.float32)
shaper.get_batch_shape(y)
with self.assertRaises(ValueError):
y = tf.placeholder(tf.float32)
shaper.get_event_shape(y)
shaper = _ShapeUtil(batch_ndims=0, event_ndims=0)
x = 1
self.assertAllEqual(shaper.get_sample_shape(x), [])
self.assertAllEqual(shaper.get_batch_shape(x), [])
self.assertAllEqual(shaper.get_event_shape(x), [])
self.assertAllEqual(shaper.get_shape(x, batch=False), [])
shaper = _ShapeUtil(batch_ndims=1, event_ndims=1)
x = [[[0., 1, 2], [3, 4, 5]]]
self.assertAllEqual(shaper.get_sample_shape(x), [1])
self.assertAllEqual(shaper.get_batch_shape(x), [2])
self.assertAllEqual(shaper.get_event_shape(x), [3])
self.assertAllEqual(shaper.get_shape(x, batch=False), [1, 3])
x += [[[6, 7, 8], [9, 10, 11]]]
self.assertAllEqual(shaper.get_sample_shape(x), [2])
self.assertAllEqual(shaper.get_batch_shape(x), [2])
self.assertAllEqual(shaper.get_event_shape(x), [3])
self.assertAllEqual(shaper.get_shape(x, batch=False), [2, 3])
shaper = _ShapeUtil(batch_ndims=0, event_ndims=1)
x = tf.ones((3, 2))
self.assertAllEqual(shaper.get_shape(x, sample=False), (2,))
def feed_eval(fun, build_shape=(None, None, 2), graph_shape=(3, 4, 2)):
"""Helper to use a deferred-shape tensor eval'ed at graph runtime."""
y = tf.placeholder(tf.int32, shape=build_shape)
y_value = np.ones(graph_shape, dtype=y.dtype.as_numpy_dtype())
return sess.run(fun(y),
feed_dict={y: y_value})
shaper = _ShapeUtil(batch_ndims=1, event_ndims=1)
self.assertAllEqual(feed_eval(shaper.get_sample_shape), [3])
self.assertAllEqual(feed_eval(shaper.get_batch_shape), [4])
self.assertAllEqual(feed_eval(shaper.get_event_shape), [2])
self.assertAllEqual(
feed_eval(lambda y: shaper.get_shape(y, batch=False)),
[3, 2])
shaper = _ShapeUtil(batch_ndims=0, event_ndims=1)
self.assertAllEqual(
feed_eval(lambda y: shaper.get_shape(y, batch=False),
(None, None),
(3, 2)),
[3, 2])
self.assertAllEqual(
feed_eval(lambda y: shaper.get_shape(y, sample=False),
(None, None),
(3, 2)),
[2])
if __name__ == "__main__":
tf.test.main()
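The invariant all of these shape tests exercise can be stated in a few lines of plain numpy; this is a sketch of the convention, not of `_ShapeUtil` itself:

```python
import numpy as np

# Dimensions partition left-to-right as [sample | batch | event]; with
# batch_ndims=1 and event_ndims=2, a rank-4 tensor splits as below.
x = np.ones((1024, 7, 5, 5))
batch_ndims, event_ndims = 1, 2
sample_ndims = x.ndim - batch_ndims - event_ndims          # 1
assert x.shape[:sample_ndims] == (1024,)                   # sample shape
assert x.shape[sample_ndims:-event_ndims] == (7,)          # batch shape
assert x.shape[-event_ndims:] == (5, 5)                    # event shape
```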

View File

@ -19,13 +19,13 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import distribution
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.contrib.distributions.python.ops import kullback_leibler # pylint: disable=line-too-long
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
@ -36,10 +36,6 @@ class Bernoulli(distribution.Distribution):
The Bernoulli distribution is parameterized by p, the probability of a
positive event.
Note, the following methods of the base class aren't implemented:
* cdf
* log_cdf
"""
def __init__(self,
@ -62,10 +58,10 @@ class Bernoulli(distribution.Distribution):
dtype: dtype for samples.
validate_args: Whether to assert that `0 <= p <= 1`. If `validate_args`
is `False`, `log_pmf` may return NaNs.
allow_nan_stats: Boolean, default False. If False, raise an exception if
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
If True, batch members with valid parameters leading to undefined
statistics will return NaN for this statistic.
allow_nan_stats: Boolean, default `False`. If `False`, raise an
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
batch member. If `True`, batch members with valid parameters leading to
undefined statistics will return NaN for this statistic.
name: A name for this distribution.
Raises:
@ -75,25 +71,8 @@ class Bernoulli(distribution.Distribution):
self._name = name
self._dtype = dtype
self._validate_args = validate_args
check_op = check_ops.assert_less_equal
if p is None and logits is None:
raise ValueError("Must pass p or logits.")
elif p is not None and logits is not None:
raise ValueError("Must pass either p or logits, not both.")
elif p is None:
with ops.op_scope([logits], name):
self._logits = array_ops.identity(logits, name="logits")
with ops.name_scope(name):
with ops.name_scope("p"):
self._p = math_ops.sigmoid(self._logits)
elif logits is None:
with ops.name_scope(name):
with ops.name_scope("p"):
with ops.control_dependencies([check_op(p, 1.), check_op(0., p)] if
validate_args else []):
self._p = array_ops.identity(p)
with ops.name_scope("logits"):
self._logits = math_ops.log(self._p) - math_ops.log(1. - self._p)
self._logits, self._p = distribution_util.get_logits_and_prob(
name=name, logits=logits, p=p, validate_args=validate_args)
with ops.name_scope(name):
with ops.name_scope("q"):
self._q = 1. - self._p
@ -180,8 +159,12 @@ class Bernoulli(distribution.Distribution):
event = ops.convert_to_tensor(event, name="event")
event = math_ops.cast(event, self.logits.dtype)
logits = self.logits
if ((event.get_shape().ndims is not None) or
(logits.get_shape().ndims is not None) or
# sigmoid_cross_entropy_with_logits doesn't broadcast shape,
# so we do this here.
# TODO(b/30637701): Check dynamic shape, and don't broadcast if the
# dynamic shapes are the same.
if (not event.get_shape().is_fully_defined() or
not logits.get_shape().is_fully_defined() or
event.get_shape() != logits.get_shape()):
logits = array_ops.ones_like(event) * logits
event = array_ops.ones_like(logits) * event
@ -202,8 +185,7 @@ class Bernoulli(distribution.Distribution):
with ops.name_scope(self.name):
with ops.op_scope([self.p, n], name):
n = ops.convert_to_tensor(n, name="n")
new_shape = array_ops.concat(
0, [array_ops.expand_dims(n, 0), self.batch_shape()])
new_shape = array_ops.concat(0, ([n], self.batch_shape()))
uniform = random_ops.random_uniform(
new_shape, seed=seed, dtype=dtypes.float32)
sample = math_ops.less(uniform, self.p)
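The sampling identity this relies on is easy to restate outside TensorFlow: a Uniform(0, 1) draw falls below `p` with probability exactly `p`. A numpy sketch, with an illustrative sample size and tolerance:

```python
import numpy as np

rng = np.random.RandomState(0)
p, n = 0.3, 100000
samples = rng.uniform(size=n) < p   # the `uniform < p` trick used above
assert abs(samples.mean() - p) < 0.01
```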

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
"""The Beta distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@ -95,6 +96,7 @@ class Beta(distribution.Distribution):
x = [.2, .3, .9]
dist.pdf(x) # Shape [2]
```
"""
def __init__(self, a, b, validate_args=True, allow_nan_stats=False,
@ -102,20 +104,20 @@ class Beta(distribution.Distribution):
"""Initialize a batch of Beta distributions.
Args:
a: Positive `float` or `double` tensor with shape broadcastable to
a: Positive floating point tensor with shape broadcastable to
`[N1,..., Nm]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
different Beta distributions. This also defines the
dtype of the distribution.
b: Positive `float` or `double` tensor with shape broadcastable to
b: Positive floating point tensor with shape broadcastable to
`[N1,..., Nm]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
different Beta distributions.
validate_args: Whether to assert valid values for parameters `a` and `b`,
and `x` in `prob` and `log_prob`. If False, correct behavior is not
and `x` in `prob` and `log_prob`. If `False`, correct behavior is not
guaranteed.
allow_nan_stats: Boolean, default False. If False, raise an exception if
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
If True, batch members with valid parameters leading to undefined
statistics will return NaN for this statistic.
allow_nan_stats: Boolean, default `False`. If `False`, raise an
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
batch member. If `True`, batch members with valid parameters leading to
undefined statistics will return NaN for this statistic.
name: The name to prefix Ops created by this distribution class.
Examples:
@ -127,6 +129,7 @@ class Beta(distribution.Distribution):
# Define a 2-batch.
dist = Beta([1.0, 2.0], [4.0, 5.0])
```
"""
with ops.op_scope([a, b], name):
with ops.control_dependencies([
@ -276,8 +279,14 @@ class Beta(distribution.Distribution):
array_ops.ones_like(a_b_sum, dtype=self.dtype)))
else:
return control_flow_ops.with_dependencies([
check_ops.assert_less(one, a),
check_ops.assert_less(one, b)], mode)
check_ops.assert_less(
one, a,
message="mode not defined for components of a <= 1"
),
check_ops.assert_less(
one, b,
message="mode not defined for components of b <= 1"
)], mode)
def entropy(self, name="entropy"):
"""Entropy of the distribution in nats."""
@ -306,7 +315,7 @@ class Beta(distribution.Distribution):
"""`Log(P[counts])`, computed for every batch member.
Args:
x: Non-negative `float` or `double`, tensor whose shape can
x: Non-negative floating point tensor whose shape can
be broadcast with `self.a` and `self.b`. For fixed leading
dimensions, the last dimension represents counts for the corresponding
Beta distribution in `self.a` and `self.b`. `x` is only legal if
@ -334,7 +343,7 @@ class Beta(distribution.Distribution):
"""`P[x]`, computed for every batch member.
Args:
x: Non-negative `float`, `double` tensor whose shape can
x: Non-negative floating point tensor whose shape can
be broadcast with `self.a` and `self.b`. For fixed leading
dimensions, the last dimension represents x for the corresponding Beta
distribution in `self.a` and `self.b`. `x` is only legal if is

View File

@ -0,0 +1,350 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An API for reversible (bijective) transformations of random variables."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
class _Bijector(object):
"""An interface for transforming random variable(s).
A bijector is characterized by three operations:
1) Forward Evaluation
Useful for turning one random outcome into another random outcome from a
different distribution.
2) Inverse Evaluation
Useful for "reversing" a transformation to compute one probability in terms
of another.
3) (log o det o Jacobian o inverse)(x)
"The log of the determinant of the matrix of all first-order partial
derivatives of the inverse function."
Useful for inverting a transformation to compute one probability in terms
of another. Geometrically, the det(Jacobian) is the volume of the
transformation and is used to scale the probability.
By convention, transformations of random variables are named in terms of the
forward transformation. The forward transformation creates samples, the
inverse is useful for computing probabilities.
Example transformations:
"Exponential"
```
Y = g(X) = exp(X)
X ~ Normal(0, 1) # Univariate.
```
Implies:
```
g^{-1}(Y) = log(Y)
|Jacobian(g^{-1})(y)| = 1 / y
Y ~ LogNormal(0, 1), i.e.,
prob(Y=y) = |Jacobian(g^{-1})(y)| * prob(X=g^{-1}(y))
= (1 / y) Normal(log(y); 0, 1)
```
"ShiftAndScale"
```
Y = g(X) = sqrtSigma * X + mu
X ~ MultivariateNormal(0, I_d)
```
Implies:
```
g^{-1}(Y) = inv(sqrtSigma) * (Y - mu)
|Jacobian(g^{-1})(y)| = det(inv(sqrtSigma))
Y ~ MultivariateNormal(mu, sqrtSigma), i.e.,
prob(Y=y) = |Jacobian(g^{-1})(y)| * prob(X=g^{-1}(y))
= det(sqrtSigma)^(-d) *
MultivariateNormal(inv(sqrtSigma) * (y - mu); 0, I_d)
```
Example use:
Basic properties:
```python
x = ... # A tensor.
# Evaluate forward transformation.
fwd_x = my_bijector.forward(x)
x != my_bijector.forward(fwd_x) # Not equal because g(x) != g(g(x)).
x == my_bijector.inverse(fwd_x)
```
Computing a log-likelihood:
```python
def transformed_log_pdf(bijector, log_pdf, x):
return (bijector.inverse_log_det_jacobian(x) +
log_pdf(bijector.inverse(x)))
```
Transforming a random outcome:
```python
def transformed_sample(bijector, x):
return bijector.forward(x)
```
"""
# TODO(b/30476956): Try to remove constructor dependence on shape util.
def __init__(self, shaper=None, name=None):
"""Constructs Bijector.
A bijector transforms random variables into new random variables. Managing
shape is typically an important piece of this, so a Bijector is usually
composed with a ShapeUtil. The ShapeUtil object handles input shape checks as
well as reshaping/transposing for easier linear algebra operations.
Example:
```python
# Create the Y = g(X) = X transform which operates on 4-Tensors of vectors.
identity = Identity(ShapeUtil(batch_ndims=4, event_ndims=1))
# Create the Y = g(X) = exp(X) transform which operates on matrices.
exp = Exp(ShapeUtil(batch_ndims=0, event_ndims=2))
```
See Bijector subclass doc for more details and examples.
Args:
shaper: object used for managing and manipulating shape, typically an
instance of ShapeUtil.
name: The name to give Ops created by the initializer.
"""
self._shaper = shaper
self._name = name or type(self).__name__
@property
def shaper(self):
"""Returns shape object used to manage shape constraints."""
return self._shaper
@property
def name(self):
"""Returns the string name of this bijector."""
return self._name
def forward(self, x, name='forward'):
"""Returns the forward bijector evaluation, i.e., X = g(Y).
Args:
x: `Tensor`. The input to the "forward" evaluation.
name: The name to give this op.
Returns:
`Tensor`.
"""
with ops.name_scope(self.name):
with ops.op_scope([x], name):
x = ops.convert_to_tensor(x)
return self._forward(x)
def inverse(self, x, name='inverse'):
"""Returns the inverse bijector evaluation, i.e., X = g^{-1}(Y).
Args:
x: `Tensor`. The input to the "inverse" evaluation.
name: The name to give this op.
Returns:
`Tensor`.
"""
with ops.name_scope(self.name):
with ops.op_scope([x], name):
x = ops.convert_to_tensor(x)
try:
return self._inverse(x)
except NotImplementedError:
return self._inverse_and_inverse_log_det_jacobian(x)[0]
def inverse_log_det_jacobian(self, x, name='inverse_log_det_jacobian'):
"""Returns the (log o det o Jacobian o inverse)(x).
Mathematically, returns: log(det(dg^{-1}(y)/dy)), the log-determinant of the Jacobian of the inverse map evaluated at y.
Args:
x: `Tensor`. The input to the "inverse" Jacobian evaluation.
name: The name to give this op.
Returns:
`Tensor`.
"""
with ops.name_scope(self.name):
with ops.op_scope([x], name):
x = ops.convert_to_tensor(x)
try:
return self._inverse_log_det_jacobian(x)
except NotImplementedError:
return self._inverse_and_inverse_log_det_jacobian(x)[1]
def inverse_and_inverse_log_det_jacobian(
self, x, name='inverse_and_inverse_log_det_jacobian'):
"""Returns both the inverse evaluation and inverse_log_det_jacobian.
Enables possibly more efficient calculation when both inverse and
corresponding Jacobian are needed.
See `inverse()`, `inverse_log_det_jacobian()` for more details.
Args:
x: `Tensor`. The input to the "inverse" Jacobian evaluation.
name: The name to give this op.
Returns:
Two `Tensor` items: the inverse and the inverse_log_det_jacobian.
"""
with ops.name_scope(self.name):
with ops.op_scope([x], name):
x = ops.convert_to_tensor(x)
try:
return self._inverse_and_inverse_log_det_jacobian(x)
except NotImplementedError:
return self._inverse(x), self._inverse_log_det_jacobian(x)
# Subclass interface.
def _forward(self, x):
"""Subclass implementation of forward().
Args:
x: `Tensor`. The input to the "forward" evaluation.
Raises:
`NotImplementedError`: if subclass implementation not provided
Returns:
`Tensor`.
"""
raise NotImplementedError('_forward not implemented')
def _inverse(self, x):
"""Subclass implementation of inverse().
Args:
x: `Tensor`. The input to the "inverse" evaluation.
Raises:
`NotImplementedError`: if subclass implementation not provided
Returns:
`Tensor`.
"""
raise NotImplementedError('_inverse not implemented')
def _inverse_log_det_jacobian(self, x):
"""Subclass implementation of inverse_log_det_jacobian().
Args:
x: `Tensor`. The input to the "inverse" Jacobian evaluation.
Raises:
`NotImplementedError`: if subclass implementation not provided
Returns:
`Tensor`.
"""
raise NotImplementedError('_inverse_log_det_jacobian not implemented')
def _inverse_and_inverse_log_det_jacobian(self, x):
"""Subclass implementation of inverse_and_inverse_log_det_jacobian().
Args:
x: `Tensor`. The input to the "inverse" evaluation.
Returns:
List of two `Tensor` items, inverse and inverse_log_det_jacobian.
"""
raise NotImplementedError(
'_inverse_and_inverse_log_det_jacobian not implemented')
class _Identity(_Bijector):
"""Bijector which computes Y = g(X) = X.
Example Use:
```python
# Create the Y=g(X)=X transform which works only on Tensors with 1 batch
# ndim and 1 event ndim (i.e., vector of vectors).
identity = Identity(ShapeUtil(batch_ndims=1, event_ndims=1))
x = [[1., 2],
[3, 4]]
x == identity.forward(x) == identity.inverse(x)
```
"""
# TODO(b/30476956): Try to remove constructor dependence on shape util.
def __init__(self, shaper=None, name='Identity'):
super(_Identity, self).__init__(shaper, name)
def _forward(self, x):
return x
def _inverse(self, x):
return x
def _inverse_log_det_jacobian(self, x):
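# The identity map has Jacobian I, so the log-det-Jacobian is 0 for every
# sample/batch entry; only the correct result shape needs to be produced.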
result_shape = self.shaper.get_shape(
x, sample=True, batch=True, event=False)
return array_ops.zeros(result_shape, dtype=x.dtype)
class _Exp(_Bijector):
"""Bijector which computes Y = g(X) = exp(X).
Example Use:
```python
# Create the Y=g(X)=exp(X) transform which works only on Tensors with 1
# batch ndim and 2 event ndims (i.e., vector of matrices).
exp = Exp(ShapeUtil(batch_ndims=1, event_ndims=2))
x = [[[1., 2],
[3, 4]],
[[5, 6],
[7, 8]]]
exp(x) == exp.forward(x)
log(x) == exp.inverse(x)
```
"""
# TODO(b/30476956): Try to remove constructor dependence on shape util.
def __init__(self, shaper=None, name='Exp'):
super(_Exp, self).__init__(shaper, name)
def _forward(self, x):
return math_ops.exp(x)
def _inverse(self, x):
return math_ops.log(x)
def _inverse_log_det_jacobian(self, x):
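# g^{-1}(x) = log(x) has Jacobian diag(1/x), so the log-det is
# -sum(log(x)) over the event dimensions.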
d = self.shaper.get_event_dims(x)
return -math_ops.reduce_sum(math_ops.log(x), d)
def _inverse_and_inverse_log_det_jacobian(self, x):
y = math_ops.log(x)
d = self.shaper.get_event_dims(x)
return y, -math_ops.reduce_sum(y, d)
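As a minimal sketch of the subclass contract above, a hypothetical constant-scale bijector might look like the following (names are illustrative; `_Bijector`, `math_ops`, `array_ops`, and the `shaper` contract are assumed from this file):

```python
# A minimal sketch, assuming the _Bijector base class and the imports of
# the surrounding file. Implements the hypothetical Y = g(X) = 2 * X.
class _DoubleScale(_Bijector):
  """Bijector which computes Y = g(X) = 2 * X."""

  def __init__(self, shaper=None, name='DoubleScale'):
    super(_DoubleScale, self).__init__(shaper, name)

  def _forward(self, x):
    return 2. * x

  def _inverse(self, x):
    return x / 2.

  def _inverse_log_det_jacobian(self, x):
    # dX/dY = 1/2 per event dimension, so the log-det is the constant
    # -log(2) summed over the event dimensions of x.
    d = self.shaper.get_event_dims(x)
    return math_ops.reduce_sum(-math_ops.log(2.) * array_ops.ones_like(x), d)
```

With `_forward` provided, `forward()` works immediately; `inverse()` and `inverse_log_det_jacobian()` fall back to `_inverse_and_inverse_log_det_jacobian` only when their dedicated hooks raise `NotImplementedError`.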

View File

@ -0,0 +1,340 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Binomial distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=line-too-long
from tensorflow.contrib.distributions.python.ops import distribution
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
# pylint: enable=line-too-long
class Binomial(distribution.Distribution):
"""Binomial distribution.
This distribution is parameterized by `p`, the probability of success in each
trial, and `n`, the total number of trials.
#### Mathematical details
The Binomial is a distribution over the number of successes in `n` independent
trials, with each trial having the same probability of success `p`.
The probability mass function (pmf):
```pmf(k) = n! / (k! * (n - k)!) * (p)^k * (1 - p)^(n - k)```
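For example, with `n = 5` and `p = 0.5`:
```pmf(2) = 5! / (2! * 3!) * 0.5^2 * 0.5^3 = 10 / 32 = 0.3125```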
#### Examples
Create a single distribution, corresponding to 5 coin flips.
```python
dist = Binomial(n=5., p=.5)
```
Create a single distribution (using logits), corresponding to 5 coin flips.
```python
dist = Binomial(n=5., logits=0.)
```
Create 3 distributions, with the third distribution most likely to have
successes.
```python
p = [.2, .3, .8]
# n will be broadcast to [4., 4., 4.], to match p.
dist = Binomial(n=4., p=p)
```
The distribution functions can be evaluated on counts.
```python
# counts same shape as p.
counts = [1., 2, 3]
dist.prob(counts) # Shape [3]
# p will be broadcast to [[.2, .3, .8], [.2, .3, .8]] to match counts.
counts = [[1., 2, 1], [2, 2, 4]]
dist.prob(counts) # Shape [2, 3]
# p will be broadcast to shape [5, 7, 3] to match counts.
counts = [[...]] # Shape [5, 7, 3]
dist.prob(counts) # Shape [5, 7, 3]
```
"""
def __init__(self,
n,
logits=None,
p=None,
validate_args=True,
allow_nan_stats=False,
name="Binomial"):
"""Initialize a batch of Binomial distributions.
Args:
n: Non-negative floating point tensor with shape broadcastable to
`[N1,..., Nm]` with `m >= 0` and the same dtype as `p` or `logits`.
Defines this as a batch of `N1 x ... x Nm` different Binomial
distributions. Its components should be equal to integer values.
logits: Floating point tensor representing the log-odds of a
positive event with shape broadcastable to `[N1,..., Nm]` `m >= 0`, and
the same dtype as `n`. Each entry represents logits for the probability
of success for independent Binomial distributions.
p: Positive floating point tensor with shape broadcastable to
`[N1,..., Nm]` `m >= 0`, `p in [0, 1]`. Each entry represents the
probability of success for independent Binomial distributions.
validate_args: Whether to assert valid values for parameters `n` and `p`,
and `x` in `prob` and `log_prob`. If `False`, correct behavior is not
guaranteed.
allow_nan_stats: Boolean, default `False`. If `False`, raise an
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
batch member. If `True`, batch members with valid parameters leading to
undefined statistics will return NaN for this statistic.
name: The name to prefix Ops created by this distribution class.
Examples:
```python
# Define 1-batch of a binomial distribution.
dist = Binomial(n=2., p=.9)
# Define a 2-batch.
dist = Binomial(n=[4., 5], p=[.1, .3])
```
"""
self._logits, self._p = distribution_util.get_logits_and_prob(
name=name, logits=logits, p=p, validate_args=validate_args)
with ops.op_scope([n], name):
with ops.control_dependencies([
check_ops.assert_non_negative(
n, message="n has negative components."),
distribution_util.assert_integer_form(
n, message="n has non-integer components."
)] if validate_args else []):
self._n = array_ops.identity(n, name="convert_n")
self._name = name
self._validate_args = validate_args
self._allow_nan_stats = allow_nan_stats
self._mean = self._n * self._p
self._get_batch_shape = self._mean.get_shape()
self._get_event_shape = tensor_shape.TensorShape([])
@property
def name(self):
"""Name to prepend to all ops."""
return self._name
@property
def dtype(self):
"""dtype of samples from this distribution."""
return self._p.dtype
@property
def validate_args(self):
"""Boolean describing behavior on invalid input."""
return self._validate_args
@property
def allow_nan_stats(self):
"""Boolean describing behavior when a stat is undefined for batch member."""
return self._allow_nan_stats
def batch_shape(self, name="batch_shape"):
"""Batch dimensions of this instance as a 1-D int32 `Tensor`.
The product of the dimensions of the `batch_shape` is the number of
independent distributions of this kind the instance represents.
Args:
name: name to give to the op
Returns:
`Tensor` `batch_shape`
"""
return array_ops.shape(self._mean)
def get_batch_shape(self):
"""`TensorShape` available at graph construction time.
Same meaning as `batch_shape`. May be only partially defined.
Returns:
batch shape
"""
return self._get_batch_shape
def event_shape(self, name="event_shape"):
"""Shape of a sample from a single distribution as a 1-D int32 `Tensor`.
Args:
name: name to give to the op
Returns:
`Tensor` `event_shape`
"""
with ops.name_scope(self.name):
with ops.op_scope([], name):
return constant_op.constant([], name=name, dtype=dtypes.int32)
def get_event_shape(self):
"""`TensorShape` available at graph construction time.
Same meaning as `event_shape`. May be only partially defined.
Returns:
event shape
"""
return self._get_event_shape
@property
def n(self):
"""Number of trials."""
return self._n
@property
def logits(self):
"""Log-odds."""
return self._logits
@property
def p(self):
"""Probability of success."""
return self._p
def mean(self, name="mean"):
"""Mean of the distribution."""
with ops.name_scope(self.name):
return array_ops.identity(self._mean, name=name)
def variance(self, name="variance"):
"""Variance of the distribution."""
with ops.name_scope(self.name):
with ops.op_scope([self._n, self._p], name):
return self._n * self._p * (1 - self._p)
def std(self, name="std"):
"""Standard deviation of the distribution."""
with ops.name_scope(self.name):
with ops.op_scope([self._n, self._p], name):
return math_ops.sqrt(self.variance())
def mode(self, name="mode"):
"""Mode of the distribution.
Note that when `(n + 1) * p` is an integer, there are actually two modes.
Namely, `(n + 1) * p` and `(n + 1) * p - 1` are both modes. Here we return
only the larger of the two modes.
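For example, with `n = 3` and `p = 0.5`, `(n + 1) * p = 2`, so `1` and `2`
are both modes and this method returns `2`.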
Args:
name: The name for this op.
Returns:
The mode of the Binomial distribution.
"""
with ops.name_scope(self.name):
with ops.op_scope([self._n, self._p], name):
return math_ops.floor((self._n + 1) * self._p)
def log_prob(self, counts, name="log_prob"):
"""`Log(P[counts])`, computed for every batch member.
For each batch member of counts `k`, `P[counts]` is the probability that
after sampling `n` draws from this Binomial distribution, the number of
successes is `k`. Note that different sequences of draws can result in the
same counts, thus the probability includes a combinatorial coefficient.
Args:
counts: Non-negative tensor with dtype `dtype` and whose shape can be
broadcast with `self.p` and `self.n`. `counts` is only legal if it is
less than or equal to `n` and its components are equal to integer
values.
name: Name to give this Op, defaults to "log_prob".
Returns:
Log probabilities for each record, shape `[N1,...,Nm]`.
"""
n = self._n
p = self._p
with ops.name_scope(self.name):
with ops.op_scope([self._n, self._p, counts], name):
counts = self._check_counts(counts)
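# log_prob(k) = [k * log(p) + (n - k) * log(1 - p)] + log(n choose k);
# the two pieces are computed separately below.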
prob_prob = counts * math_ops.log(p) + (
n - counts) * math_ops.log(1 - p)
combinations = math_ops.lgamma(n + 1) - math_ops.lgamma(
counts + 1) - math_ops.lgamma(n - counts + 1)
log_prob = prob_prob + combinations
return log_prob
def prob(self, counts, name="prob"):
"""`P[counts]`, computed for every batch member.
For each batch member of counts `k`, `P[counts]` is the probability that
after sampling `n` draws from this Binomial distribution, the number of
successes is `k`. Note that different sequences of draws can result in the
same counts, thus the probability includes a combinatorial coefficient.
Args:
counts: Non-negative tensor with dtype `dtype` and whose shape can be
broadcast with `self.p` and `self.n`. `counts` is only legal if it is
less than or equal to `n` and its components are equal to integer
values.
name: Name to give this Op, defaults to "prob".
Returns:
Probabilities for each record, shape `[N1,...,Nm]`.
"""
return super(Binomial, self).prob(counts, name=name)
@property
def is_continuous(self):
return False
@property
def is_reparameterized(self):
return False
def _check_counts(self, counts):
"""Check counts for proper shape, values, then return tensor version."""
counts = ops.convert_to_tensor(counts, name="counts_before_deps")
if not self.validate_args:
return counts
return control_flow_ops.with_dependencies([
check_ops.assert_non_negative(
counts, message="counts has negative components."),
check_ops.assert_less_equal(
counts, self._n, message="counts are not less than or equal to n."),
distribution_util.assert_integer_form(
counts, message="counts have non-integer components.")], counts)

View File

@ -34,11 +34,6 @@ class Categorical(distribution.Distribution):
The categorical distribution is parameterized by the log-probabilities
of a set of classes.
Note, the following methods of the base class aren't implemented:
* mean
* cdf
* log_cdf
"""
def __init__(
@ -57,10 +52,10 @@ class Categorical(distribution.Distribution):
indexes into the classes.
dtype: The type of the event samples (default: int32).
validate_args: Unused in this distribution.
allow_nan_stats: Boolean, default False. If False, raise an exception if
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
If True, batch members with valid parameters leading to undefined
statistics will return NaN for this statistic.
allow_nan_stats: Boolean, default `False`. If `False`, raise an
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
batch member. If `True`, batch members with valid parameters leading to
undefined statistics will return NaN for this statistic.
name: A name for this distribution (optional).
"""
self._allow_nan_stats = allow_nan_stats
@ -177,8 +172,7 @@ class Categorical(distribution.Distribution):
samples = math_ops.cast(samples, self._dtype)
ret = array_ops.reshape(
array_ops.transpose(samples),
array_ops.concat(
0, [array_ops.expand_dims(n, 0), self.batch_shape()]))
array_ops.concat(0, ([n], self.batch_shape())))
ret.set_shape(tensor_shape.vector(tensor_util.constant_value(n))
.concatenate(self.get_batch_shape()))
return ret

View File

@ -42,15 +42,15 @@ class Chi2(gamma.Gamma):
"""Construct Chi2 distributions with parameter `df`.
Args:
df: `float` or `double` tensor, the degrees of freedom of the
df: Floating point tensor, the degrees of freedom of the
distribution(s). `df` must contain only positive values.
validate_args: Whether to assert that `df > 0`, and that `x > 0` in the
methods `prob(x)` and `log_prob(x)`. If `validate_args` is False
methods `prob(x)` and `log_prob(x)`. If `validate_args` is `False`
and the inputs are invalid, correct behavior is not guaranteed.
allow_nan_stats: Boolean, default False. If False, raise an exception if
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
If True, batch members with valid parameters leading to undefined
statistics will return NaN for this statistic.
allow_nan_stats: Boolean, default `False`. If `False`, raise an
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
batch member. If `True`, batch members with valid parameters leading to
undefined statistics will return NaN for this statistic.
name: The name to prepend to all ops created by this distribution.
"""
# Even though all stats of chi2 are defined for valid parameters, this is

View File

@ -19,9 +19,8 @@ from __future__ import print_function
# pylint: disable=line-too-long
import numpy as np
from tensorflow.contrib.distributions.python.ops import distribution
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@ -29,7 +28,6 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
@ -37,24 +35,6 @@ from tensorflow.python.ops import special_math_ops
# pylint: enable=line-too-long
def _assert_close(x, y, data=None, summarize=None, name=None):
if x.dtype.is_integer:
return check_ops.assert_equal(
x, y, data=data, summarize=summarize, name=name)
with ops.op_scope([x, y, data], name, "assert_close"):
x = ops.convert_to_tensor(x, name="x")
y = ops.convert_to_tensor(y, name="y")
tol = np.finfo(x.dtype.as_numpy_dtype).resolution
if data is None:
data = [
"Condition x ~= y did not hold element-wise: x = ", x.name, x, "y = ",
y.name, y
]
condition = math_ops.reduce_all(math_ops.less_equal(math_ops.abs(x-y), tol))
return logging_ops.Assert(condition, data, summarize=summarize)
class Dirichlet(distribution.Distribution):
"""Dirichlet distribution.
@ -117,6 +97,7 @@ class Dirichlet(distribution.Distribution):
x = [.2, .3, .5]
dist.prob(x) # Shape [2]
```
"""
def __init__(self,
@ -127,16 +108,16 @@ class Dirichlet(distribution.Distribution):
"""Initialize a batch of Dirichlet distributions.
Args:
alpha: Positive `float` or `double` tensor with shape broadcastable to
alpha: Positive floating point tensor with shape broadcastable to
`[N1,..., Nm, k]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
different `k` class Dirichlet distributions.
validate_args: Whether to assert valid values for parameters `alpha` and
`x` in `prob` and `log_prob`. If False, correct behavior is not
`x` in `prob` and `log_prob`. If `False`, correct behavior is not
guaranteed.
allow_nan_stats: Boolean, default False. If False, raise an exception if
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
If True, batch members with valid parameters leading to undefined
statistics will return NaN for this statistic.
allow_nan_stats: Boolean, default `False`. If `False`, raise an
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
batch member. If `True`, batch members with valid parameters leading to
undefined statistics will return NaN for this statistic.
name: The name to prefix Ops created by this distribution class.
Examples:
@ -149,6 +130,7 @@ class Dirichlet(distribution.Distribution):
# Define a 2-batch of 3-class distributions.
dist = Dirichlet([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
```
"""
with ops.op_scope([alpha], name):
alpha = ops.convert_to_tensor(alpha, name="alpha_before_deps")
@ -302,7 +284,9 @@ class Dirichlet(distribution.Distribution):
array_ops.ones_like(self._alpha, dtype=self.dtype)))
else:
return control_flow_ops.with_dependencies([
check_ops.assert_less(one, self._alpha)
check_ops.assert_less(
one, self._alpha,
message="mode not defined for components of alpha <= 1")
], mode)
def entropy(self, name="entropy"):
@ -334,7 +318,7 @@ class Dirichlet(distribution.Distribution):
"""`Log(P[counts])`, computed for every batch member.
Args:
x: Non-negative `float` or `double`, tensor whose shape can
x: Non-negative tensor with dtype `dtype` and whose shape can
be broadcast with `self.alpha`. For fixed leading dimensions, the last
dimension represents counts for the corresponding Dirichlet distribution
in `self.alpha`. `x` is only legal if it sums up to one.
@ -359,7 +343,7 @@ class Dirichlet(distribution.Distribution):
"""`P[x]`, computed for every batch member.
Args:
x: Non-negative `float`, `double` tensor whose shape can
x: Non-negative tensor with dtype `dtype` and whose shape can
be broadcast with `self.alpha`. For fixed leading dimensions, the last
dimension represents x for the corresponding Dirichlet distribution in
`self.alpha` and `self.beta`. `x` is only legal if it sums up to one.
@ -407,7 +391,8 @@ class Dirichlet(distribution.Distribution):
x = ops.convert_to_tensor(x, name="x_before_deps")
candidate_one = math_ops.reduce_sum(x, reduction_indices=[-1])
one = constant_op.constant(1., self.dtype)
dependencies = [check_ops.assert_positive(x), check_ops.assert_less(x, one),
_assert_close(one, candidate_one)
dependencies = [check_ops.assert_positive(x), check_ops.assert_less(
x, one, message="x has components greater than or equal to 1"),
distribution_util.assert_close(one, candidate_one)
] if self.validate_args else []
return control_flow_ops.with_dependencies(dependencies, x)

View File

@ -13,13 +13,15 @@
# limitations under the License.
# ==============================================================================
"""The Dirichlet Multinomial distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=line-too-long
from tensorflow.contrib.distributions.python.ops import distribution # pylint: disable=line-too-long
from tensorflow.contrib.distributions.python.ops import distribution
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
@ -30,34 +32,6 @@ from tensorflow.python.ops import special_math_ops
# pylint: enable=line-too-long
def _assert_integer_form(x):
"""Check x for integer components (or floats that are equal to integers)."""
x = ops.convert_to_tensor(x, name='x')
casted_x = math_ops.to_int64(x)
return check_ops.assert_equal(x, math_ops.cast(
math_ops.round(casted_x), x.dtype))
def _log_combinations(n, counts, name='log_combinations'):
"""Log number of ways counts could have come in."""
# First a bit about the number of ways counts could have come in:
# E.g. if counts = [1, 2], then this is 3 choose 2.
# In general, this is (sum counts)! / sum(counts!)
# The sum should be along the last dimension of counts. This is the
# "distribution" dimension. Here n a priori represents the sum of counts.
with ops.op_scope([counts], name):
# To compute factorials, use the fact that Gamma(n + 1) = n!
# Compute two terms, each a sum over counts. Compute each for each
# batch member.
# Log Gamma((sum counts) + 1) = Log((sum counts)!)
total_permutations = math_ops.lgamma(n + 1)
# sum(Log Gamma(counts + 1)) = Log sum(counts!)
counts_factorial = math_ops.lgamma(counts + 1)
redundant_permutations = math_ops.reduce_sum(counts_factorial,
reduction_indices=[-1])
return total_permutations - redundant_permutations
class DirichletMultinomial(distribution.Distribution):
"""DirichletMultinomial mixture distribution.
@ -126,38 +100,35 @@ class DirichletMultinomial(distribution.Distribution):
counts = [2, 1, 0]
dist.pmf(counts) # Shape [2]
```
"""
# TODO(b/27419586) Change docstring for dtype of alpha once int allowed.
def __init__(self,
n,
alpha,
allow_arbitrary_counts=False,
validate_args=True,
allow_nan_stats=False,
name='DirichletMultinomial'):
name="DirichletMultinomial"):
"""Initialize a batch of DirichletMultinomial distributions.
Args:
n: Non-negative `float` or `double` tensor with shape
broadcastable to `[N1,..., Nm]` with `m >= 0`. Defines this as a batch
of `N1 x ... x Nm` different Dirichlet multinomial distributions. Its
components should be equal to integral values.
alpha: Positive `float` or `double` tensor with shape broadcastable to
`[N1,..., Nm, k]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
different `k` class Dirichlet multinomial distributions.
allow_arbitrary_counts: Boolean. This represents whether the pmf/cdf
allows for the `counts` tensor to be non-integral values.
The pmf/cdf are functions that can be evaluated at non-integral values,
but are only a distribution over non-negative integers. If
`validate_args` is `False`, this assertion is turned off.
n: Non-negative floating point tensor, whose dtype is the same as
`alpha`. The shape is broadcastable to `[N1,..., Nm]` with `m >= 0`.
Defines this as a batch of `N1 x ... x Nm` different Dirichlet
multinomial distributions. Its components should be equal to integer
values.
alpha: Positive floating point tensor, whose dtype is the same as
`n` with shape broadcastable to `[N1,..., Nm, k]` `m >= 0`. Defines
this as a batch of `N1 x ... x Nm` different `k` class Dirichlet
multinomial distributions.
validate_args: Whether to assert valid values for parameters `alpha` and
`n`, and `x` in `prob` and `log_prob`. If False, correct behavior is
`n`, and `x` in `prob` and `log_prob`. If `False`, correct behavior is
not guaranteed.
allow_nan_stats: Boolean, default False. If False, raise an exception if
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
If True, batch members with valid parameters leading to undefined
statistics will return NaN for this statistic.
allow_nan_stats: Boolean, default `False`. If `False`, raise an
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
batch member. If `True`, batch members with valid parameters leading to
undefined statistics will return NaN for this statistic.
name: The name to prefix Ops created by this distribution class.
Examples:
@ -170,11 +141,11 @@ class DirichletMultinomial(distribution.Distribution):
# Define a 2-batch of 3-class distributions.
dist = DirichletMultinomial([3., 4], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
```
"""
self._allow_nan_stats = allow_nan_stats
self._validate_args = validate_args
self._name = name
self._allow_arbitrary_counts = allow_arbitrary_counts
with ops.op_scope([n, alpha], name):
# Broadcasting works because:
# * The broadcasting convention is to prepend dimensions of size [1], and
@ -186,8 +157,7 @@ class DirichletMultinomial(distribution.Distribution):
# * All calls involving `counts` eventually require a broadcast between
# `counts` and alpha.
self._alpha = self._check_alpha(alpha)
n = self._check_n(n)
self._n = math_ops.cast(n, self._alpha.dtype)
self._n = self._check_n(n)
self._alpha_sum = math_ops.reduce_sum(
self._alpha, reduction_indices=[-1], keep_dims=False)
@ -227,7 +197,7 @@ class DirichletMultinomial(distribution.Distribution):
"""dtype of samples from this distribution."""
return self._alpha.dtype
def mean(self, name='mean'):
def mean(self, name="mean"):
"""Class means for every batch member."""
alpha = self._alpha
alpha_sum = self._alpha_sum
@ -237,7 +207,7 @@ class DirichletMultinomial(distribution.Distribution):
mean_no_n = alpha / array_ops.expand_dims(alpha_sum, -1)
return array_ops.expand_dims(n, -1) * mean_no_n
def variance(self, name='mean'):
def variance(self, name="mean"):
"""Class variances for every batch member.
The variance for each batch member is defined as the following:
@ -279,7 +249,7 @@ class DirichletMultinomial(distribution.Distribution):
variance *= array_ops.expand_dims(shared_factor, -1)
return variance
def batch_shape(self, name='batch_shape'):
def batch_shape(self, name="batch_shape"):
"""Batch dimensions of this instance as a 1-D int32 `Tensor`.
The product of the dimensions of the `batch_shape` is the number of
@ -305,7 +275,7 @@ class DirichletMultinomial(distribution.Distribution):
"""
return self._get_batch_shape
def event_shape(self, name='event_shape'):
def event_shape(self, name="event_shape"):
"""Shape of a sample from a single distribution as a 1-D int32 `Tensor`.
Args:
@ -328,15 +298,15 @@ class DirichletMultinomial(distribution.Distribution):
"""
return self._get_event_shape
def cdf(self, x, name='cdf'):
def cdf(self, x, name="cdf"):
raise NotImplementedError(
'DirichletMultinomial does not have a well-defined cdf.')
"DirichletMultinomial does not have a well-defined cdf.")
def log_cdf(self, x, name='log_cdf'):
def log_cdf(self, x, name="log_cdf"):
raise NotImplementedError(
'DirichletMultinomial does not have a well-defined cdf.')
"DirichletMultinomial does not have a well-defined cdf.")
def log_prob(self, counts, name='log_prob'):
def log_prob(self, counts, name="log_prob"):
"""`Log(P[counts])`, computed for every batch member.
For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability
@ -346,12 +316,11 @@ class DirichletMultinomial(distribution.Distribution):
probability includes a combinatorial coefficient.
Args:
counts: Non-negative `float` or `double` tensor whose shape can
be broadcast with `self.alpha`. For fixed leading dimensions, the last
counts: Non-negative tensor with dtype `dtype` and whose shape can be
broadcast with `self.alpha`. For fixed leading dimensions, the last
dimension represents counts for the corresponding Dirichlet Multinomial
distribution in `self.alpha`. `counts` is only legal if it sums up to
`n` and its components are equal to integral values. The second
condition is relaxed if `allow_arbitrary_counts` is set.
`n` and its components are equal to integer values.
name: Name to give this Op, defaults to "log_prob".
Returns:
@ -362,25 +331,14 @@ class DirichletMultinomial(distribution.Distribution):
with ops.name_scope(self.name):
with ops.op_scope([n, alpha, counts], name):
counts = self._check_counts(counts)
# Use the same dtype as alpha for computations.
counts = math_ops.cast(counts, self.dtype)
ordered_prob = (special_math_ops.lbeta(alpha + counts) -
special_math_ops.lbeta(alpha))
log_prob = ordered_prob + _log_combinations(n, counts)
# If alpha = counts = [[]], ordered_prob carries the right shape, which
# is []. However, since reduce_sum([[]]) = [0], log_combinations = [0],
# which is not correct. Luckily, [] + [0] = [], so the sum is fine, but
# shape must be inferred from ordered_prob. We must also make this
# broadcastable with n, so this is multiplied by n to ensure the shape
# is correctly inferred.
# Note also that tf.constant([]).get_shape() =
# TensorShape([Dimension(0)])
broadcasted_tensor = ordered_prob * n
log_prob.set_shape(broadcasted_tensor.get_shape())
log_prob = ordered_prob + distribution_util.log_combinations(
n, counts)
return log_prob
def prob(self, counts, name='prob'):
def prob(self, counts, name="prob"):
"""`P[counts]`, computed for every batch member.
For each batch of counts `[c_1,...,c_k]`, `P[counts]` is the probability
@ -390,12 +348,11 @@ class DirichletMultinomial(distribution.Distribution):
probability includes a combinatorial coefficient.
Args:
counts: Non-negative `float`, `double` tensor whose shape can
be broadcast with `self.alpha`. For fixed leading dimensions, the last
counts: Non-negative tensor with dtype `dtype` and whose shape can be
broadcast with `self.alpha`. For fixed leading dimensions, the last
dimension represents counts for the corresponding Dirichlet Multinomial
distribution in `self.alpha`. `counts` is only legal if it sums up to
`n` and its components are equal to integral values. The second
condition is relaxed if `allow_arbitrary_counts` is set.
`n` and its components are equal to integer values.
name: Name to give this Op, defaults to "prob".
Returns:
@ -405,21 +362,21 @@ class DirichletMultinomial(distribution.Distribution):
def _check_counts(self, counts):
"""Check counts for proper shape, values, then return tensor version."""
counts = ops.convert_to_tensor(counts, name='counts')
counts = ops.convert_to_tensor(counts, name="counts")
if not self.validate_args:
return counts
candidate_n = math_ops.reduce_sum(counts, reduction_indices=[-1])
dependencies = [check_ops.assert_non_negative(counts),
check_ops.assert_equal(self._n,
math_ops.cast(candidate_n,
self._n.dtype))]
if not self._allow_arbitrary_counts:
dependencies += [_assert_integer_form(counts)]
return control_flow_ops.with_dependencies(dependencies, counts)
return control_flow_ops.with_dependencies([
check_ops.assert_non_negative(counts),
check_ops.assert_equal(
self._n, candidate_n,
message="counts do not sum to n"
),
distribution_util.assert_integer_form(counts)], counts)
def _check_alpha(self, alpha):
alpha = ops.convert_to_tensor(alpha, name='alpha')
alpha = ops.convert_to_tensor(alpha, name="alpha")
if not self.validate_args:
return alpha
return control_flow_ops.with_dependencies(
@ -427,11 +384,12 @@ class DirichletMultinomial(distribution.Distribution):
check_ops.assert_positive(alpha)], alpha)
def _check_n(self, n):
n = ops.convert_to_tensor(n, name='n')
n = ops.convert_to_tensor(n, name="n")
if not self.validate_args:
return n
return control_flow_ops.with_dependencies(
[check_ops.assert_non_negative(n), _assert_integer_form(n)], n)
[check_ops.assert_non_negative(n),
distribution_util.assert_integer_form(n)], n)
@property
def is_continuous(self):

View File

@ -0,0 +1,177 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for probability distributions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
def assert_close(
x, y, data=None, summarize=None, message=None, name="assert_close"):
"""Assert that that x and y are within machine epsilon of each other.
Args:
x: Numeric `Tensor`
y: Numeric `Tensor`
data: The tensors to print out if the condition is `False`. Defaults to
error message and first few entries of `x` and `y`.
summarize: Print this many entries of each tensor.
message: A string to prefix to the default message.
name: A name for this operation (optional).
Returns:
Op raising `InvalidArgumentError` if |x - y| > machine epsilon.
"""
message = message or ""
x = ops.convert_to_tensor(x, name="x")
y = ops.convert_to_tensor(y, name="y")
if x.dtype.is_integer:
return check_ops.assert_equal(
x, y, data=data, summarize=summarize, message=message, name=name)
with ops.op_scope([x, y, data], name, "assert_close"):
tol = np.finfo(x.dtype.as_numpy_dtype).resolution
if data is None:
data = [
message,
"Condition x ~= y did not hold element-wise: x = ", x.name, x, "y = ",
y.name, y
]
condition = math_ops.reduce_all(math_ops.less_equal(math_ops.abs(x-y), tol))
return logging_ops.Assert(
condition, data, summarize=summarize)
def assert_integer_form(
x, data=None, summarize=None, message=None, name="assert_integer_form"):
"""Assert that x has integer components (or floats equal to integers).
Args:
x: Numeric `Tensor`
data: The tensors to print out if the condition is `False`. Defaults to
error message and first few entries of `x`.
summarize: Print this many entries of each tensor.
message: A string to prefix to the default message.
name: A name for this operation (optional).
Returns:
Op raising `InvalidArgumentError` if round(x) != x.
"""
message = message or "x has non-integer components"
x = ops.convert_to_tensor(x, name="x")
casted_x = math_ops.to_int64(x)
return check_ops.assert_equal(
x, math_ops.cast(math_ops.round(casted_x), x.dtype),
data=data, summarize=summarize, message=message, name=name)
def get_logits_and_prob(
logits=None, p=None, multidimensional=False, validate_args=True, name=None):
"""Converts logits to probabilities and vice-versa, and returns both.
Args:
logits: Numeric `Tensor` representing log-odds.
p: Numeric `Tensor` representing probabilities.
multidimensional: Given `p` a [N1, N2, ... k] dimensional tensor,
whether the last dimension represents the probabilities of `k` classes.
This will additionally assert that the values in the last dimension
sum to one. If `False`, will instead assert that each value is in
`[0, 1]`.
validate_args: Whether to assert `0 <= p <= 1` if multidimensional is
`False`, otherwise that the last dimension of `p` sums to one.
name: A name for this operation (optional).
Returns:
Tuple with `logits` and `p`. If `p` has an entry that is `0` or `1`, then
the corresponding entry in the returned logits will be `-Inf` and `Inf`
respectively.
Raises:
ValueError: if neither `p` nor `logits` were passed in, or both were.
"""
if p is None and logits is None:
raise ValueError("Must pass p or logits.")
elif p is not None and logits is not None:
raise ValueError("Must pass either p or logits, not both.")
elif p is None:
with ops.op_scope([logits], name):
logits = array_ops.identity(logits, name="logits")
with ops.name_scope(name):
with ops.name_scope("p"):
p = math_ops.sigmoid(logits)
elif logits is None:
with ops.name_scope(name):
with ops.name_scope("p"):
p = array_ops.identity(p)
if validate_args:
one = constant_op.constant(1., p.dtype)
dependencies = [check_ops.assert_non_negative(p)]
if multidimensional:
dependencies += [assert_close(
math_ops.reduce_sum(p, reduction_indices=[-1]),
one, message="p does not sum to 1.")]
else:
dependencies += [check_ops.assert_less_equal(
p, one, message="p has components greater than 1.")]
p = control_flow_ops.with_dependencies(dependencies, p)
with ops.name_scope("logits"):
logits = math_ops.log(p) - math_ops.log(1. - p)
return (logits, p)
def log_combinations(n, counts, name="log_combinations"):
"""Multinomial coefficient.
Given `n` and `counts`, where `counts` has last dimension `k`, we compute
the multinomial coefficient as:
```n! / prod_i n_i!```
where `i` runs over all `k` classes.
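For example, with `counts = [1, 2]` (so `n = 3`), the coefficient is
`3! / (1! * 2!) = 3`.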
Args:
n: Numeric `Tensor` broadcastable with `counts`. This represents `n`
outcomes.
counts: Numeric `Tensor` broadcastable with `n`. This represents counts
in `k` classes, where `k` is the last dimension of the tensor.
name: A name for this operation (optional).
Returns:
`Tensor` representing the multinomial coefficient between `n` and `counts`.
"""
# First a bit about the number of ways counts could have come in:
# E.g. if counts = [1, 2], then this is 3 choose 2.
# In general, this is (sum counts)! / prod(counts!)
# The sum should be along the last dimension of counts. This is the
# "distribution" dimension. Here n a priori represents the sum of counts.
with ops.op_scope([n, counts], name):
total_permutations = math_ops.lgamma(n + 1)
counts_factorial = math_ops.lgamma(counts + 1)
redundant_permutations = math_ops.reduce_sum(counts_factorial,
reduction_indices=[-1])
return total_permutations - redundant_permutations
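A minimal sketch exercising the two helpers above; the import path is assumed from how this module is referenced elsewhere in this change:

```python
import tensorflow as tf
from tensorflow.contrib.distributions.python.ops import distribution_util

# logits = log(p) - log(1 - p), elementwise; with multidimensional=True the
# values in the last dimension of p are asserted to sum to one.
logits, p = distribution_util.get_logits_and_prob(
    p=[0.25, 0.5, 0.25], multidimensional=True, name="example")

# log(4! / (1! * 2! * 1!)) = log(12).
log_coeff = distribution_util.log_combinations(
    n=4., counts=tf.constant([1., 2., 1.]))

with tf.Session() as sess:
  print(sess.run([logits, p, log_coeff]))
```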

View File

@ -46,15 +46,15 @@ class Exponential(gamma.Gamma):
"""Construct Exponential distribution with parameter `lam`.
Args:
lam: `float` or `double` tensor, the rate of the distribution(s).
lam: Floating point tensor, the rate of the distribution(s).
`lam` must contain only positive values.
validate_args: Whether to assert that `lam > 0`, and that `x > 0` in the
methods `prob(x)` and `log_prob(x)`. If `validate_args` is False
methods `prob(x)` and `log_prob(x)`. If `validate_args` is `False`
and the inputs are invalid, correct behavior is not guaranteed.
allow_nan_stats: Boolean, default False. If False, raise an exception if
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
If True, batch members with valid parameters leading to undefined
statistics will return NaN for this statistic.
allow_nan_stats: Boolean, default `False`. If `False`, raise an
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
batch member. If `True`, batch members with valid parameters leading to
undefined statistics will return NaN for this statistic.
name: The name to prepend to all ops created by this distribution.
"""
# Even though all statistics of the Exponential are defined for valid inputs, this is not
@ -95,13 +95,13 @@ class Exponential(gamma.Gamma):
broadcast_shape = self._lam.get_shape()
with ops.op_scope([self.lam, n], name, "ExponentialSample"):
n = ops.convert_to_tensor(n, name="n")
shape = array_ops.concat(
0, [array_ops.pack([n]), array_ops.shape(self._lam)])
shape = array_ops.concat(0, ([n], array_ops.shape(self._lam)))
# Sample uniformly-at-random from the open-interval (0, 1).
sampled = random_ops.random_uniform(
shape, minval=np.nextafter(
self.dtype.as_numpy_dtype(0.), self.dtype.as_numpy_dtype(1.)),
maxval=constant_op.constant(1.0, dtype=self.dtype),
seed=seed,
dtype=self.dtype)
n_val = tensor_util.constant_value(n)

View File

@ -69,19 +69,19 @@ class Gamma(distribution.Distribution):
broadcasting (e.g. `alpha + beta` is a valid operation).
Args:
alpha: `float` or `double` tensor, the shape params of the
alpha: Floating point tensor, the shape params of the
distribution(s).
alpha must contain only positive values.
beta: `float` or `double` tensor, the inverse scale params of the
beta: Floating point tensor, the inverse scale params of the
distribution(s).
beta must contain only positive values.
validate_args: Whether to assert that `a > 0, b > 0`, and that `x > 0` in
the methods `prob(x)` and `log_prob(x)`. If `validate_args` is False
the methods `prob(x)` and `log_prob(x)`. If `validate_args` is `False`
and the inputs are invalid, correct behavior is not guaranteed.
allow_nan_stats: Boolean, default False. If False, raise an exception if
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
If True, batch members with valid parameters leading to undefined
statistics will return NaN for this statistic.
allow_nan_stats: Boolean, default `False`. If `False`, raise an
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
batch member. If `True`, batch members with valid parameters leading to
undefined statistics will return NaN for this statistic.
name: The name to prepend to all ops created by this distribution.
Raises:
@ -213,9 +213,12 @@ class Gamma(distribution.Distribution):
nan = np.nan * self._ones()
return math_ops.select(alpha_ge_1, mode_if_defined, nan)
else:
one = ops.convert_to_tensor(1.0, dtype=self.dtype)
one = constant_op.constant(1.0, dtype=self.dtype)
return control_flow_ops.with_dependencies(
[check_ops.assert_less(one, alpha)], mode_if_defined)
[check_ops.assert_less(
one, alpha,
message="mode not defined for components of alpha <= 1"
)], mode_if_defined)
def variance(self, name="variance"):
"""Variance of each batch member."""

View File

@ -69,18 +69,18 @@ class InverseGamma(distribution.Distribution):
broadcasting (e.g. `alpha + beta` is a valid operation).
Args:
alpha: `float` or `double` tensor, the shape params of the
alpha: Floating point tensor, the shape params of the
distribution(s).
alpha must contain only positive values.
beta: `float` or `double` tensor, the scale params of the distribution(s).
beta: Floating point tensor, the scale params of the distribution(s).
beta must contain only positive values.
validate_args: Whether to assert that `a > 0, b > 0`, and that `x > 0` in
the methods `prob(x)` and `log_prob(x)`. If `validate_args` is False
the methods `prob(x)` and `log_prob(x)`. If `validate_args` is `False`
and the inputs are invalid, correct behavior is not guaranteed.
allow_nan_stats: Boolean, default False. If False, raise an exception if
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
If True, batch members with valid parameters leading to undefined
statistics will return NaN for this statistic.
allow_nan_stats: Boolean, default `False`. If `False`, raise an
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
batch member. If `True`, batch members with valid parameters leading to
undefined statistics will return NaN for this statistic.
name: The name to prepend to all ops created by this distribution.
Raises:
@ -206,9 +206,12 @@ class InverseGamma(distribution.Distribution):
nan = np.nan * self._ones()
return math_ops.select(alpha_gt_1, mean_if_defined, nan)
else:
one = ops.convert_to_tensor(1.0, dtype=self.dtype)
one = constant_op.constant(1.0, dtype=self.dtype)
return control_flow_ops.with_dependencies(
[check_ops.assert_less(one, alpha)], mean_if_defined)
[check_ops.assert_less(
one, alpha,
message="mean not defined for components of alpha <= 1")],
mean_if_defined)
def mode(self, name="mode"):
"""Mode of each batch member.
@ -250,9 +253,12 @@ class InverseGamma(distribution.Distribution):
nan = np.nan * self._ones()
return math_ops.select(alpha_gt_2, var_if_defined, nan)
else:
two = ops.convert_to_tensor(2.0, dtype=self.dtype)
two = constant_op.constant(2.0, dtype=self.dtype)
return control_flow_ops.with_dependencies(
[check_ops.assert_less(two, alpha)], var_if_defined)
[check_ops.assert_less(
two, alpha,
message="variance not defined for components of alpha <= 2")],
var_if_defined)
def log_prob(self, x, name="log_prob"):
"""Log prob of observations in `x` under these InverseGamma distribution(s).

View File

@ -34,9 +34,9 @@ def kl(dist_a, dist_b, allow_nan=False, name=None):
Args:
dist_a: instance of distributions.Distribution.
dist_b: instance of distributions.Distribution.
allow_nan: If False (default), a runtime error is raised
allow_nan: If `False` (default), a runtime error is raised
if the KL returns NaN values for any batch entry of the given
distributions. If True, the KL may return a NaN for the given entry.
distributions. If `True`, the KL may return a NaN for the given entry.
name: (optional) Name scope to use for created operations.
Returns:

View File

@ -60,17 +60,17 @@ class Laplace(distribution.Distribution):
broadcasting (e.g., `loc / scale` is a valid operation).
Args:
loc: `float` or `double` tensor which characterizes the location (center)
loc: Floating point tensor which characterizes the location (center)
of the distribution.
scale: `float` or `double`, positive-valued tensor which characterzes the
spread of the distribution.
scale: Positive floating point tensor which characterizes the spread of
the distribution.
validate_args: Whether to validate input with asserts. If `validate_args`
is `False`, and the inputs are invalid, correct behavior is not
guaranteed.
allow_nan_stats: Boolean, default False. If False, raise an exception if
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
If True, batch members with valid parameters leading to undefined
statistics will return NaN for this statistic.
allow_nan_stats: Boolean, default `False`. If `False`, raise an
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
batch member. If `True`, batch members with valid parameters leading to
undefined statistics will return NaN for this statistic.
name: The name to give Ops created by the initializer.
Raises:
@ -294,8 +294,7 @@ class Laplace(distribution.Distribution):
with ops.op_scope([self._loc, self._scale, n], name):
n = ops.convert_to_tensor(n)
n_val = tensor_util.constant_value(n)
shape = array_ops.concat(
0, [array_ops.pack([n]), self.batch_shape()])
shape = array_ops.concat(0, ([n], self.batch_shape()))
# Sample uniformly-at-random from the open-interval (-1, 1).
uniform_samples = random_ops.random_uniform(
shape=shape,

View File

@ -0,0 +1,343 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Multinomial distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=line-too-long
from tensorflow.contrib.distributions.python.ops import distribution
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
# pylint: enable=line-too-long
class Multinomial(distribution.Distribution):
"""Multinomial distribution.
This distribution is parameterized by a vector `p` of probability
parameters for `k` classes and `n`, the total number of draws.
#### Mathematical details
The Multinomial is a distribution over k-class count data, meaning
for each k-tuple of non-negative integer `counts = [n_1,...,n_k]`, we have a
probability of these draws being made from the distribution. The distribution
has hyperparameters `p = (p_1,...,p_k)`, and probability mass
function (pmf):
```pmf(counts) = n! / (n_1!...n_k!) * (p_1)^n_1*(p_2)^n_2*...(p_k)^n_k```
where `n = sum_j n_j` and `n!` denotes `n` factorial.
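For example, with `n = 4`, `p = [.2, .3, .5]` and `counts = [1, 2, 1]`:
```pmf(counts) = 4! / (1! * 2! * 1!) * .2^1 * .3^2 * .5^1 = 12 * 0.009 = 0.108```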
#### Examples
Create a 3-class distribution, with the 3rd class most likely to be drawn,
using logits.
```python
logits = [-50., -43, 0]
dist = Multinomial(n=4., logits=logits)
```
Create a 3-class distribution, with the 3rd class most likely to be drawn.
```python
p = [.2, .3, .5]
dist = Multinomial(n=4., p=p)
```
The distribution functions can be evaluated on counts.
```python
# counts same shape as p.
counts = [1., 0, 3]
dist.prob(counts) # Shape []
# p will be broadcast to [[.2, .3, .5], [.2, .3, .5]] to match counts.
counts = [[1., 2, 1], [2, 2, 0]]
dist.prob(counts) # Shape [2]
# p will be broadcast to shape [5, 7, 3] to match counts.
counts = [[...]] # Shape [5, 7, 3]
dist.prob(counts) # Shape [5, 7]
```
Create a 2-batch of 3-class distributions.
```python
p = [[.1, .2, .7], [.3, .3, .4]] # Shape [2, 3]
dist = Multinomial(n=[4., 5], p=p)
counts = [[2., 1, 1], [3, 1, 1]]
dist.prob(counts) # Shape [2]
```
"""
def __init__(self,
n,
logits=None,
p=None,
validate_args=True,
allow_nan_stats=False,
name="Multinomial"):
"""Initialize a batch of Multinomial distributions.
Args:
n: Non-negative floating point tensor with shape broadcastable to
`[N1,..., Nm]` with `m >= 0`. Defines this as a batch of
`N1 x ... x Nm` different Multinomial distributions. Its components
should be equal to integer values.
logits: Floating point tensor representing the log-odds of a
positive event with shape broadcastable to `[N1,..., Nm, k], m >= 0`,
and the same dtype as `n`. Defines this as a batch of `N1 x ... x Nm`
different `k` class Multinomial distributions.
p: Positive floating point tensor with shape broadcastable to
`[N1,..., Nm, k]` `m >= 0` and same dtype as `n`. Defines this as
a batch of `N1 x ... x Nm` different `k` class Multinomial
distributions. `p`'s components in the last portion of its shape should
sum up to 1.
validate_args: Whether to assert valid values for parameters `n` and `p`,
and `x` in `prob` and `log_prob`. If `False`, correct behavior is not
guaranteed.
allow_nan_stats: Boolean, default `False`. If `False`, raise an
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
batch member. If `True`, batch members with valid parameters leading to
undefined statistics will return NaN for this statistic.
name: The name to prefix Ops created by this distribution class.
Examples:
```python
# Define 1-batch of 2-class multinomial distribution,
# also known as a Binomial distribution.
dist = Multinomial(n=2., p=[.1, .9])
# Define a 2-batch of 3-class distributions.
dist = Multinomial(n=[4., 5], p=[[.1, .3, .6], [.4, .05, .55]])
```
"""
self._logits, self._p = distribution_util.get_logits_and_prob(
name=name, logits=logits, p=p, validate_args=validate_args,
multidimensional=True)
with ops.op_scope([n, self._p], name):
with ops.control_dependencies([
check_ops.assert_non_negative(
n, message="n has negative components."),
distribution_util.assert_integer_form(
n, message="n has non-integer components."
)] if validate_args else []):
self._n = array_ops.identity(n, name="convert_n")
self._name = name
self._validate_args = validate_args
self._allow_nan_stats = allow_nan_stats
self._mean = array_ops.expand_dims(n, -1) * self._p
# Only used for inferring shape.
self._broadcast_shape = math_ops.reduce_sum(self._mean,
reduction_indices=[-1],
keep_dims=False)
self._get_batch_shape = self._broadcast_shape.get_shape()
self._get_event_shape = (
self._mean.get_shape().with_rank_at_least(1)[-1:])
@property
def n(self):
"""Number of trials."""
return self._n
@property
def p(self):
"""Event probabilities."""
return self._p
@property
def logits(self):
"""Log-odds."""
return self._logits
@property
def name(self):
"""Name to prepend to all ops."""
return self._name
@property
def dtype(self):
"""dtype of samples from this distribution."""
return self._p.dtype
@property
def validate_args(self):
"""Boolean describing behavior on invalid input."""
return self._validate_args
@property
def allow_nan_stats(self):
"""Boolean describing behavior when a stat is undefined for batch member."""
return self._allow_nan_stats
def batch_shape(self, name="batch_shape"):
"""Batch dimensions of this instance as a 1-D int32 `Tensor`.
The product of the dimensions of the `batch_shape` is the number of
independent distributions of this kind the instance represents.
Args:
name: name to give to the op
Returns:
`Tensor` `batch_shape`
"""
with ops.name_scope(self.name):
with ops.op_scope([self._broadcast_shape], name):
return array_ops.shape(self._broadcast_shape)
def get_batch_shape(self):
"""`TensorShape` available at graph construction time.
Same meaning as `batch_shape`. May be only partially defined.
Returns:
batch shape
"""
return self._get_batch_shape
def event_shape(self, name="event_shape"):
"""Shape of a sample from a single distribution as a 1-D int32 `Tensor`.
Args:
name: name to give to the op
Returns:
`Tensor` `event_shape`
"""
with ops.name_scope(self.name):
with ops.op_scope([self._mean], name):
return array_ops.gather(array_ops.shape(self._mean),
[array_ops.rank(self._mean) - 1])
def get_event_shape(self):
"""`TensorShape` available at graph construction time.
Same meaning as `event_shape`. May be only partially defined.
Returns:
event shape
"""
return self._get_event_shape
def mean(self, name="mean"):
"""Mean of the distribution."""
with ops.name_scope(self.name):
return array_ops.identity(self._mean, name=name)
def variance(self, name="variance"):
"""Variance of the distribution."""
with ops.name_scope(self.name):
with ops.op_scope([self._n, self._p, self._mean], name):
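# Covariance is n * (diag(p) - p p^T): the (negated) outer-product term
# is built with batch_matmul below, then the diagonal n * p is added back.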
p = array_ops.expand_dims(
self._p * array_ops.expand_dims(
array_ops.ones_like(self._n), -1), -1)
variance = -math_ops.batch_matmul(
array_ops.expand_dims(self._mean, -1), p, adj_y=True)
variance += array_ops.batch_matrix_diag(self._mean)
return variance
def log_prob(self, counts, name="log_prob"):
"""`Log(P[counts])`, computed for every batch member.
For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability
that after sampling `n` draws from this Multinomial distribution, the
number of draws falling in class `j` is `n_j`. Note that different
sequences of draws can result in the same counts, thus the probability
includes a combinatorial coefficient.
Args:
counts: Non-negative tensor with dtype `dtype` and whose shape can
be broadcast with `self.p` and `self.n`. For fixed leading dimensions,
the last dimension represents counts for the corresponding Multinomial
distribution in `self.p`. `counts` is only legal if it sums up to `n`
and its components are equal to integer values.
name: Name to give this Op, defaults to "log_prob".
Returns:
Log probabilities for each record, shape `[N1,...,Nm]`.
"""
n = self._n
p = self._p
with ops.name_scope(self.name):
with ops.op_scope([n, p, counts], name):
counts = self._check_counts(counts)
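# log_prob = sum_j counts_j * log(p_j) + log(n! / prod_j counts_j!), with
# the combinatorial term computed by distribution_util.log_combinations.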
prob_prob = math_ops.reduce_sum(counts * math_ops.log(self._p),
reduction_indices=[-1])
log_prob = prob_prob + distribution_util.log_combinations(
n, counts)
return log_prob
def prob(self, counts, name="prob"):
"""`P[counts]`, computed for every batch member.
For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability
that after sampling `n` draws from this Multinomial distribution, the
number of draws falling in class `j` is `n_j`. Note that different
sequences of draws can result in the same counts, thus the probability
includes a combinatorial coefficient.
Args:
counts: Non-negative tensor with dtype `dtype` and whose shape can
be broadcast with `self.p` and `self.n`. For fixed leading dimensions,
the last dimension represents counts for the corresponding Multinomial
distribution in `self.p`. `counts` is only legal if it sums up to `n`
and its components are equal to integer values.
name: Name to give this Op, defaults to "prob".
Returns:
Probabilities for each record, shape `[N1,...,Nm]`.
"""
return super(Multinomial, self).prob(counts, name=name)
@property
def is_continuous(self):
return False
@property
def is_reparameterized(self):
return False
def _check_counts(self, counts):
"""Check counts for proper shape, values, then return tensor version."""
counts = ops.convert_to_tensor(counts, name="counts_before_deps")
candidate_n = math_ops.reduce_sum(counts, reduction_indices=[-1])
if not self.validate_args:
return counts
return control_flow_ops.with_dependencies([
check_ops.assert_non_negative(
counts, message="counts has negative components."),
check_ops.assert_equal(
self._n, candidate_n, message="counts do not sum to n."),
distribution_util.assert_integer_form(
counts, message="counts have non-integer components.")], counts)

View File

@ -21,9 +21,11 @@ from __future__ import print_function
import math
from tensorflow.contrib.distributions.python.ops import distribution
from tensorflow.contrib.distributions.python.ops import kullback_leibler
from tensorflow.contrib.distributions.python.ops import operator_pd_cholesky
from tensorflow.contrib.distributions.python.ops import operator_pd_diag
from tensorflow.contrib.distributions.python.ops import operator_pd_full
from tensorflow.contrib.distributions.python.ops import operator_pd_vdvt_update
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
@ -40,6 +42,7 @@ __all__ = [
"MultivariateNormalDiag",
"MultivariateNormalCholesky",
"MultivariateNormalFull",
"MultivariateNormalDiagPlusVDVT",
]
@ -52,14 +55,13 @@ class MultivariateNormalOperatorPD(distribution.Distribution):
#### Mathematical details
The PDF of this distribution is:
With `C` the covariance matrix represented by the operator, the PDF of this
distribution is:
```
f(x) = (2*pi)^(-k/2) |det(sigma)|^(-1/2) exp(-1/2*(x-mu)^*.sigma^{-1}.(x-mu))
f(x) = (2 pi)^(-k/2) |det(C)|^(-1/2) exp(-1/2 (x - mu)^T C^{-1} (x - mu))
```
where `.` denotes the inner product on `R^k` and `^*` denotes transpose.
#### Examples
A single multi-variate Gaussian distribution is defined by a vector of means
@ -103,16 +105,16 @@ class MultivariateNormalOperatorPD(distribution.Distribution):
which determines the covariance.
Args:
mu: `float` or `double` tensor with shape `[N1,...,Nb, k]`, `b >= 0`.
cov: `float` or `double` instance of `OperatorPDBase` with same `dtype`
as `mu` and shape `[N1,...,Nb, k, k]`.
mu: Floating point tensor with shape `[N1,...,Nb, k]`, `b >= 0`.
cov: Instance of `OperatorPDBase` with same `dtype` as `mu` and shape
`[N1,...,Nb, k, k]`.
validate_args: Whether to validate input with asserts. If `validate_args`
is `False`, and the inputs are invalid, correct behavior is not
guaranteed.
allow_nan_stats: Boolean, default False. If False, raise an exception if
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
If True, batch members with valid parameters leading to undefined
statistics will return NaN for this statistic.
allow_nan_stats: `Boolean`, default `False`. If `False`, raise an
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading to
undefined statistics will return NaN for this statistic.
name: The name to give Ops created by the initializer.
Raises:
@ -148,7 +150,7 @@ class MultivariateNormalOperatorPD(distribution.Distribution):
else:
return mu
# Static checks could not be run, so possibly do dyamic checks.
# Static checks could not be run, so possibly do dynamic checks.
if not self.validate_args:
return mu
else:
@ -170,12 +172,12 @@ class MultivariateNormalOperatorPD(distribution.Distribution):
@property
def validate_args(self):
"""Boolean describing behavior on invalid input."""
"""`Boolean` describing behavior on invalid input."""
return self._validate_args
@property
def allow_nan_stats(self):
"""Boolean describing behavior when a stat is undefined for batch member."""
"""`Boolean` describing behavior when stats are undefined."""
return self._allow_nan_stats
@property
@ -417,7 +419,7 @@ class MultivariateNormalDiag(MultivariateNormalOperatorPD):
determined by `diag_stdev`: `C_{ii} = diag_stdev[i]**2`.
```
f(x) = (2*pi)^(-k/2) |det(C)|^(-1/2) exp(-1/2 * (x - mu)^T C^{-1} (x - mu))
f(x) = (2 pi)^(-k/2) |det(C)|^(-1/2) exp(-1/2 (x - mu)^T C^{-1} (x - mu))
```
#### Examples
@ -464,17 +466,17 @@ class MultivariateNormalDiag(MultivariateNormalOperatorPD):
The mean of `X_i` is `mu[i]`, and the standard deviation is `diag_stdev[i]`.
Args:
mu: Rank `N + 1` `float` or `double` tensor with shape `[N1,...,Nb, k]`,
mu: Rank `N + 1` floating point tensor with shape `[N1,...,Nb, k]`,
`b >= 0`.
diag_stdev: Rank `N + 1` `Tensor` with same `dtype` and shape as `mu`,
representing the standard deviations.
representing the standard deviations. Must be positive.
validate_args: Whether to validate input with asserts. If `validate_args`
is `False`,
and the inputs are invalid, correct behavior is not guaranteed.
allow_nan_stats: Boolean, default False. If False, raise an exception if
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
If True, batch members with valid parameters leading to undefined
statistics will return NaN for this statistic.
allow_nan_stats: `Boolean`, default `False`. If `False`, raise an
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading to
undefined statistics will return NaN for this statistic.
name: The name to give Ops created by the initializer.
Raises:
@ -487,6 +489,125 @@ class MultivariateNormalDiag(MultivariateNormalOperatorPD):
name=name)
class MultivariateNormalDiagPlusVDVT(MultivariateNormalOperatorPD):
"""The multivariate normal distribution on `R^k`.
Every batch member of this distribution is defined by a mean and a lightweight
covariance matrix `C`.
#### Mathematical details
The PDF of this distribution in terms of the mean `mu` and covariance `C` is:
```
f(x) = (2 pi)^(-k/2) |det(C)|^(-1/2) exp(-1/2 (x - mu)^T C^{-1} (x - mu))
```
For every batch member, this distribution represents `k` random variables
`(X_1,...,X_k)`, with mean `E[X_i] = mu[i]`, and covariance matrix
`C_{ij} := E[(X_i - mu[i])(X_j - mu[j])]`
The user initializes this class by providing the mean `mu`, and a lightweight
definition of `C`:
```
C = SS^T = SS = (M + V D V^T) (M + V D V^T)
M is diagonal (k x k)
V is shape (k x r), typically r << k
D is diagonal (r x r), optional (defaults to identity).
```
This allows for `O(kr + r^3)` pdf evaluation and determinant, and `O(kr)`
sampling and storage (per batch member).
#### Examples
A single multi-variate Gaussian distribution is defined by a vector of means
of length `k`, and square root of the covariance `S = M + V D V^T`. Extra
leading dimensions, if provided, allow for batches.
```python
# Initialize a single 3-variate Gaussian with covariance square root
# S = M + V D V^T, where V D V^T is a matrix-rank 2 update.
mu = [1, 2, 3.]
diag_large = [1.1, 2.2, 3.3]
v = ... # shape 3 x 2
diag_small = [4., 5.]
dist = tf.contrib.distributions.MultivariateNormalDiagPlusVDVT(
mu, diag_large, v, diag_small=diag_small)
# Evaluate this on an observation in R^3, returning a scalar.
dist.pdf([-1, 0, 1])
# Initialize a batch of two 3-variate Gaussians. This time, don't provide
# diag_small. This means S = M + V V^T.
mu = [[1, 2, 3], [11, 22, 33]] # shape 2 x 3
diag_large = ... # shape 2 x 3
v = ... # shape 2 x 3 x 1, a matrix-rank 1 update.
dist = tf.contrib.distributions.MultivariateNormalDiagPlusVDVT(
mu, diag_large, v)
# Evaluate this on two observations, each in R^3, returning a length-two
# tensor.
x = [[-1, 0, 1], [-11, 0, 11]] # Shape 2 x 3.
dist.pdf(x)
```
"""
def __init__(
self,
mu,
diag_large,
v,
diag_small=None,
validate_args=True,
allow_nan_stats=False,
name="MultivariateNormalDiagPlusVDVT"):
"""Multivariate Normal distributions on `R^k`.
For every batch member, this distribution represents `k` random variables
`(X_1,...,X_k)`, with mean `E[X_i] = mu[i]`, and covariance matrix
`C_{ij} := E[(X_i - mu[i])(X_j - mu[j])]`
The user initializes this class by providing the mean `mu`, and a
lightweight definition of `C`:
```
C = SS^T = SS = (M + V D V^T) (M + V D V^T)
M is diagonal (k x k)
V is shape (k x r), typically r << k
D is diagonal (r x r), optional (defaults to identity).
```
Args:
mu: Rank `n + 1` floating point tensor with shape `[N1,...,Nn, k]`,
`n >= 0`. The means.
diag_large: Optional rank `n + 1` floating point tensor, shape
`[N1,...,Nn, k]` `n >= 0`. Defines the diagonal matrix `M`.
v: Rank `n + 1` floating point tensor, shape `[N1,...,Nn, k, r]`
`n >= 0`. Defines the matrix `V`.
diag_small: Rank `n + 1` floating point tensor, shape
`[N1,...,Nn, k]` `n >= 0`. Defines the diagonal matrix `D`. Default
is `None`, which means `D` will be the identity matrix.
validate_args: Whether to validate input with asserts. If `validate_args`
is `False`,
and the inputs are invalid, correct behavior is not guaranteed.
allow_nan_stats: `Boolean`, default `False`. If `False`, raise an
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading to
undefined statistics will return NaN for this statistic.
name: The name to give Ops created by the initializer.
"""
m = operator_pd_diag.OperatorPDDiag(diag_large, verify_pd=validate_args)
cov = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
m, v, diag=diag_small, verify_pd=validate_args,
verify_shapes=validate_args)
super(MultivariateNormalDiagPlusVDVT, self).__init__(
mu, cov, allow_nan_stats=allow_nan_stats, validate_args=validate_args,
name=name)
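# A quick dense sanity check of the MultivariateNormalDiagPlusVDVT
# parameterization (illustrative values; the class never materializes these
# dense matrices):
#   import numpy as np
#   diag_large = np.array([1.1, 2.2, 3.3])        # M = diag(diag_large)
#   v = np.array([[1., 0.], [0., 1.], [1., 1.]])  # V, shape (k=3, r=2)
#   diag_small = np.array([4., 5.])               # D = diag(diag_small)
#   s = np.diag(diag_large) + v.dot(np.diag(diag_small)).dot(v.T)
#   c = s.dot(s.T)  # C = S S^T, positive definite since both diags are > 0.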
class MultivariateNormalCholesky(MultivariateNormalOperatorPD):
"""The multivariate normal distribution on `R^k`.
@ -496,14 +617,14 @@ class MultivariateNormalCholesky(MultivariateNormalOperatorPD):
#### Mathematical details
The PDF of this distribution is:
The Cholesky factor `chol` defines the covariance matrix: `C = chol chol^T`.
The PDF of this distribution is then:
```
f(x) = (2*pi)^(-k/2) |det(sigma)|^(-1/2) exp(-1/2*(x-mu)^*.sigma^{-1}.(x-mu))
f(x) = (2 pi)^(-k/2) |det(C)|^(-1/2) exp(-1/2 (x - mu)^T C^{-1} (x - mu))
```
where `.` denotes the inner product on `R^k` and `^*` denotes transpose.
#### Examples
A single multi-variate Gaussian distribution is defined by a vector of means
@ -546,20 +667,21 @@ class MultivariateNormalCholesky(MultivariateNormalOperatorPD):
"""Multivariate Normal distributions on `R^k`.
User must provide means `mu` and `chol` which holds the (batch) Cholesky
factors `S`, such that the covariance of each batch member is `S S^*`.
factors, such that the covariance of each batch member is `chol chol^T`.
Args:
mu: `(N+1)-D` `float` or `double` tensor with shape `[N1,...,Nb, k]`,
mu: `(N+1)-D` floating point tensor with shape `[N1,...,Nb, k]`,
`b >= 0`.
chol: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape
`[N1,...,Nb, k, k]`.
`[N1,...,Nb, k, k]`. The upper triangular part is ignored (treated as
though it is zero), and the diagonal must be positive.
validate_args: Whether to validate input with asserts. If `validate_args`
is `False`,
and the inputs are invalid, correct behavior is not guaranteed.
allow_nan_stats: Boolean, default False. If False, raise an exception if
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
If True, batch members with valid parameters leading to undefined
statistics will return NaN for this statistic.
is `False`, and the inputs are invalid, correct behavior is not
guaranteed.
allow_nan_stats: `Boolean`, default `False`. If `False`, raise an
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading to
undefined statistics will return NaN for this statistic.
name: The name to give Ops created by the initializer.
Raises:
@ -582,14 +704,12 @@ class MultivariateNormalFull(MultivariateNormalOperatorPD):
#### Mathematical details
The PDF of this distribution is:
With `C = sigma`, the PDF of this distribution is:
```
f(x) = (2*pi)^(-k/2) |det(sigma)|^(-1/2) exp(-1/2*(x-mu)^*.sigma^{-1}.(x-mu))
f(x) = (2 pi)^(-k/2) |det(C)|^(-1/2) exp(-1/2 (x - mu)^T C^{-1} (x - mu))
```
where `.` denotes the inner product on `R^k` and `^*` denotes transpose.
#### Examples
A single multi-variate Gaussian distribution is defined by a vector of means
@ -630,17 +750,17 @@ class MultivariateNormalFull(MultivariateNormalOperatorPD):
User must provide means `mu` and `sigma`, the mean and covariance.
Args:
mu: `(N+1)-D` `float` or `double` tensor with shape `[N1,...,Nb, k]`,
mu: `(N+1)-D` floating point tensor with shape `[N1,...,Nb, k]`,
`b >= 0`.
sigma: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape
`[N1,...,Nb, k, k]`.
`[N1,...,Nb, k, k]`. Each batch member must be positive definite.
validate_args: Whether to validate input with asserts. If `validate_args`
is `False`, and the inputs are invalid, correct behavior is not
guaranteed.
allow_nan_stats: Boolean, default False. If False, raise an exception if
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
If True, batch members with valid parameters leading to undefined
statistics will return NaN for this statistic.
allow_nan_stats: `Boolean`, default `False`. If `False`, raise an
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading to
undefined statistics will return NaN for this statistic.
name: The name to give Ops created by the initializer.
Raises:
@ -653,3 +773,72 @@ class MultivariateNormalFull(MultivariateNormalOperatorPD):
allow_nan_stats=allow_nan_stats,
validate_args=validate_args,
name=name)
def _kl_mvn_mvn_brute_force(mvn_a, mvn_b, name=None):
"""Batched KL divergence `KL(mvn_a || mvn_b)` for multivariate normals.
  With `X`, `Y` the multivariate normals `mvn_a`, `mvn_b` in `R^k`, with means
  `mu_a`, `mu_b` and covariances `C_a`, `C_b` respectively,
  ```
  KL(X || Y) = 0.5 * ( T + Q - k + L ),
  T := trace(C_b^{-1} C_a),
  Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a),
  L := Log[Det(C_b)] - Log[Det(C_a)]
  ```
This `Op` computes the trace by solving `C_b^{-1} C_a`. Although efficient
methods for solving systems with `C_b` may be available, a dense version of
(the square root of) `C_a` is used, so performance is `O(B s k^2)` where `B`
is the batch size, and `s` is the cost of solving `C_b x = y` for vectors `x`
and `y`.
Args:
mvn_a: Instance of subclass of `MultivariateNormalOperatorPD`.
mvn_b: Instance of subclass of `MultivariateNormalOperatorPD`.
name: (optional) name to use for created ops. Default "kl_mvn_mvn".
Returns:
Batchwise `KL(mvn_a || mvn_b)`.
"""
# Access the "private" OperatorPD that each mvn is built from.
cov_a = mvn_a._cov # pylint: disable=protected-access
cov_b = mvn_b._cov # pylint: disable=protected-access
mu_a = mvn_a.mu
mu_b = mvn_b.mu
inputs = [mu_a, mu_b] + cov_a.inputs + cov_b.inputs
with ops.op_scope(inputs, name, "kl_mvn_mvn"):
# If Ca = AA', Cb = BB', then
# tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A']
# = tr[inv(B) A A' inv(B)']
# = tr[(inv(B) A) (inv(B) A)']
# = sum_{ik} (inv(B) A)_{ik}^2
# The second equality follows from the cyclic permutation property.
b_inv_a = cov_b.sqrt_solve(cov_a.sqrt_to_dense())
t = math_ops.reduce_sum(
math_ops.square(b_inv_a),
reduction_indices=[-1, -2])
q = cov_b.inv_quadratic_form_on_vectors(mu_b - mu_a)
k = math_ops.cast(cov_a.vector_space_dimension(), mvn_a.dtype)
one_half_l = cov_b.sqrt_log_det() - cov_a.sqrt_log_det()
return 0.5 * (t + q - k) + one_half_l
# Register KL divergences.
kl_classes = [
MultivariateNormalFull,
MultivariateNormalCholesky,
MultivariateNormalDiag,
MultivariateNormalDiagPlusVDVT,
]
for mvn_aa in kl_classes:
  # Register the (mvn_aa, mvn_aa) pair here, and skip it in the inner loop
  # below, so the same pair is not registered twice.
kullback_leibler.RegisterKL(mvn_aa, mvn_aa)(_kl_mvn_mvn_brute_force)
for mvn_bb in kl_classes:
if mvn_bb != mvn_aa:
kullback_leibler.RegisterKL(mvn_aa, mvn_bb)(_kl_mvn_mvn_brute_force)
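As a numerical sanity check of `_kl_mvn_mvn_brute_force` (a standalone NumPy
sketch, not part of the library): for diagonal covariances the formula should
reduce to the sum of per-component univariate Normal KL divergences.
```python
import numpy as np

# Illustrative diagonal-covariance Gaussians in R^3.
mu_a, var_a = np.array([0., 1., 2.]), np.array([1.0, 2.0, 0.5])
mu_b, var_b = np.array([1., 1., 0.]), np.array([2.0, 1.0, 1.0])
k = 3

t = np.sum(var_a / var_b)               # T = trace(C_b^{-1} C_a)
q = np.sum((mu_b - mu_a) ** 2 / var_b)  # Q, the quadratic form
l = np.sum(np.log(var_b) - np.log(var_a))
kl_mvn = 0.5 * (t + q - k + l)

# KL(N(m1, s1^2) || N(m2, s2^2)) summed over independent components.
kl_sum = np.sum(np.log(np.sqrt(var_b / var_a))
                + (var_a + (mu_a - mu_b) ** 2) / (2 * var_b) - 0.5)
assert np.isclose(kl_mvn, kl_sum)
```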

View File

@ -92,15 +92,15 @@ class Normal(distribution.Distribution):
broadcasting (e.g. `mu + sigma` is a valid operation).
Args:
mu: `float` or `double` tensor, the means of the distribution(s).
sigma: `float` or `double` tensor, the stddevs of the distribution(s).
mu: Floating point tensor, the means of the distribution(s).
sigma: Floating point tensor, the stddevs of the distribution(s).
sigma must contain only positive values.
validate_args: Whether to assert that `sigma > 0`. If `validate_args` is
False, correct output is not guaranteed when input is invalid.
allow_nan_stats: Boolean, default False. If False, raise an exception if
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
If True, batch members with valid parameters leading to undefined
statistics will return NaN for this statistic.
`False`, correct output is not guaranteed when input is invalid.
allow_nan_stats: Boolean, default `False`. If `False`, raise an
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
batch member. If `True`, batch members with valid parameters leading to
undefined statistics will return NaN for this statistic.
name: The name to give Ops created by the initializer.
Raises:
@ -321,8 +321,7 @@ class Normal(distribution.Distribution):
with ops.op_scope([self._mu, self._sigma, n], name):
broadcast_shape = (self._mu + self._sigma).get_shape()
n = ops.convert_to_tensor(n)
shape = array_ops.concat(
0, [array_ops.pack([n]), array_ops.shape(self.mean())])
shape = array_ops.concat(0, ([n], array_ops.shape(self.mean())))
sampled = random_ops.random_normal(
shape=shape, mean=0, stddev=1, dtype=self._mu.dtype, seed=seed)

View File

@ -18,6 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import six
from tensorflow.contrib.distributions.python.ops import operator_pd
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@ -26,11 +29,190 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
class OperatorPDSqrtDiag(operator_pd.OperatorPDBase):
@six.add_metaclass(abc.ABCMeta)
class OperatorPDDiagBase(operator_pd.OperatorPDBase):
"""Base class for diagonal operators."""
def __init__(self, diag, verify_pd=True, name='OperatorPDDiagBase'):
self._verify_pd = verify_pd
self._name = name
with ops.name_scope(name):
with ops.op_scope([diag], 'init'):
self._diag = self._check_diag(diag)
def _check_diag(self, diag):
"""Verify that `diag` is positive."""
diag = ops.convert_to_tensor(diag, name='diag')
if not self.verify_pd:
return diag
deps = [check_ops.assert_positive(diag)]
return control_flow_ops.with_dependencies(deps, diag)
@property
def name(self):
"""String name identifying this `Operator`."""
return self._name
@property
def verify_pd(self):
"""Whether to verify that this `Operator` is positive definite."""
return self._verify_pd
@property
def dtype(self):
"""Data type of matrix elements of `A`."""
return self._diag.dtype
@property
def inputs(self):
"""Initialization arguments."""
return [self._diag]
def get_shape(self):
"""`TensorShape` giving static shape."""
# If d_shape = [5, 3], we return [5, 3, 3].
d_shape = self._diag.get_shape()
return d_shape.concatenate(d_shape[-1:])
def _shape(self):
d_shape = array_ops.shape(self._diag)
k = array_ops.gather(d_shape, array_ops.size(d_shape) - 1)
return array_ops.concat(0, (d_shape, [k]))
@abc.abstractmethod
def _batch_log_det(self):
pass
@abc.abstractmethod
def _inv_quadratic_form_on_vectors(self, x):
pass
@abc.abstractmethod
def _batch_matmul(self, x, transpose_x=False):
pass
@abc.abstractmethod
def _batch_sqrt_matmul(self, x, transpose_x=False):
pass
@abc.abstractmethod
def _batch_solve(self, rhs):
pass
@abc.abstractmethod
def _batch_sqrt_solve(self, rhs):
pass
@abc.abstractmethod
def _to_dense(self):
pass
@abc.abstractmethod
def _sqrt_to_dense(self):
pass
@abc.abstractmethod
def _add_to_tensor(self, mat):
pass
class OperatorPDDiag(OperatorPDDiagBase):
"""Class representing a (batch) of positive definite matrices `A`.
This class provides access to functions of a batch of symmetric positive
definite (PD) matrices `A` in `R^{k x k}` defined by their square root,
definite (PD) matrices `A` in `R^{k x k}`.
In this case, `A` is diagonal and is defined by a provided tensor `diag`,
`A_{ii} = diag[i]`.
Determinants, solves, and storage are `O(k)`.
In practice, this operator represents a (batch) matrix `A` with shape
`[N1,...,Nn, k, k]` for some `n >= 0`. The first `n` indices designate a
batch member. For every batch member `(i1,...,ib)`, `A[i1,...,ib, : :]` is
a `k x k` matrix.
For example,
```python
distributions = tf.contrib.distributions
diag = [1.0, 2.0]
operator = OperatorPDDiag(diag)
operator.det() # ==> (1 * 2)
# Compute the quadratic form x^T A^{-1} x for vector x.
x = [1.0, 2.0]
operator.inv_quadratic_form_on_vectors(x)
# Matrix multiplication by the square root, S w, with A = S S^T.
# Recall A is diagonal, and so is S, with S_{ij} = sqrt(A_{ij}).
# If w is iid normal, S w has covariance A.
w = [[1.0],
[2.0]]
operator.sqrt_matmul(w)
```
The above three methods, `log_det`, `inv_quadratic_form_on_vectors`, and
`sqrt_matmul` provide "all" that is necessary to use a covariance matrix
in a multi-variate normal distribution. See the class
`MultivariateNormalDiag`.
"""
def __init__(self, diag, verify_pd=True, name='OperatorPDDiag'):
"""Initialize an OperatorPDDiag.
Args:
diag: Shape `[N1,...,Nn, k]` positive tensor with `n >= 0`, `k >= 1`.
verify_pd: Whether to check `diag` is positive.
name: A name to prepend to all ops created by this class.
"""
super(OperatorPDDiag, self).__init__(
diag, verify_pd=verify_pd, name=name)
def _batch_log_det(self):
return math_ops.reduce_sum(
math_ops.log(self._diag), reduction_indices=[-1])
def _inv_quadratic_form_on_vectors(self, x):
return self._iqfov_via_solve(x)
def _batch_matmul(self, x, transpose_x=False):
if transpose_x:
x = array_ops.batch_matrix_transpose(x)
diag_mat = array_ops.expand_dims(self._diag, -1)
return diag_mat * x
def _batch_sqrt_matmul(self, x, transpose_x=False):
if transpose_x:
x = array_ops.batch_matrix_transpose(x)
diag_mat = array_ops.expand_dims(self._diag, -1)
return math_ops.sqrt(diag_mat) * x
def _batch_solve(self, rhs):
diag_mat = array_ops.expand_dims(self._diag, -1)
return rhs / diag_mat
def _batch_sqrt_solve(self, rhs):
diag_mat = array_ops.expand_dims(self._diag, -1)
return rhs / math_ops.sqrt(diag_mat)
def _to_dense(self):
return array_ops.batch_matrix_diag(self._diag)
def _sqrt_to_dense(self):
return array_ops.batch_matrix_diag(math_ops.sqrt(self._diag))
def _add_to_tensor(self, mat):
mat_diag = array_ops.batch_matrix_diag_part(mat)
new_diag = self._diag + mat_diag
return array_ops.batch_matrix_set_diag(mat, new_diag)
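# Note the contrast with OperatorPDSqrtDiag below: here `diag` parameterizes
# A itself, so log det = sum(log(diag)); there `diag` parameterizes the square
# root S with A = S S^T, so the log det picks up a factor of 2.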
class OperatorPDSqrtDiag(OperatorPDDiagBase):
"""Class representing a (batch) of positive definite matrices `A`.
This class provides access to functions of a batch of symmetric positive
definite (PD) matrices `A` in `R^{k x k}` defined by their square root,
`S`, such that `A = SS^T`.
In this case, `S` is diagonal and is defined by a provided tensor `diag`,
@ -75,58 +257,17 @@ class OperatorPDSqrtDiag(operator_pd.OperatorPDBase):
verify_pd: Whether to check `diag` is positive.
name: A name to prepend to all ops created by this class.
"""
self._verify_pd = verify_pd
self._name = name
with ops.name_scope(name):
with ops.op_scope([diag], 'init'):
self._diag = self._check_diag(diag)
def _check_diag(self, diag):
"""Verify that `diag` is positive."""
diag = ops.convert_to_tensor(diag, name='diag')
if not self.verify_pd:
return diag
deps = [check_ops.assert_positive(diag)]
return control_flow_ops.with_dependencies(deps, diag)
@property
def name(self):
"""String name identifying this `Operator`."""
return self._name
@property
def verify_pd(self):
"""Whether to verify that this `Operator` is positive definite."""
return self._verify_pd
@property
def dtype(self):
"""Data type of matrix elements of `A`."""
return self._diag.dtype
super(OperatorPDSqrtDiag, self).__init__(
diag, verify_pd=verify_pd, name=name)
def _batch_log_det(self):
return 2 * math_ops.reduce_sum(
math_ops.log(self._diag), reduction_indices=[-1])
@property
def inputs(self):
"""List of tensors that were provided as initialization inputs."""
return [self._diag]
def _inv_quadratic_form_on_vectors(self, x):
# This Operator is defined in terms of diagonal entries of the sqrt.
return self._iqfov_via_sqrt_solve(x)
def get_shape(self):
"""`TensorShape` giving static shape."""
d_shape = self._diag.get_shape()
return d_shape.concatenate(d_shape[-1:])
def _shape(self):
d_shape = array_ops.shape(self._diag)
k = array_ops.gather(d_shape, array_ops.size(d_shape) - 1)
return array_ops.concat(0, (d_shape, [k]))
def _batch_matmul(self, x, transpose_x=False):
if transpose_x:
x = array_ops.batch_matrix_transpose(x)

View File

@ -0,0 +1,207 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Identity operator in `R^k`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import operator_pd
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
class OperatorPDIdentity(operator_pd.OperatorPDBase):
"""Identity operator in `R^k`: `Ax = x`.
This provides an efficient implementation of the identity as an `OperatorPD`.
Storage, solves, and matmul are all `O(1)`, independent of batch size.
In order to be a drop-in replacement for other operators, shape and dtype
of arguments (e.g. to `matmul`) are checked statically as though this operator
were an instantiated matrix.
Dynamic shape checks of arguments are not done since that could impede
performance.
"""
def __init__(self, shape, dtype, verify_pd=True, name='OperatorPDIdentity'):
"""Initialize an `OperatorPDIdentity`.
Args:
shape: `int32` rank 1 `Tensor` of length at least 2, and with the last
two entries equal (since this is a square matrix).
dtype: Data type of the matrix that this operator represents.
verify_pd: `Boolean`, if `True`, asserts are added to the initialization
args to ensure they define this operator as a square (batch) matrix.
name: Name to prepend to `Ops`.
"""
# Grab static shape if available now.
with ops.name_scope(name):
with ops.op_scope([shape], 'init'):
self._dtype = dtypes.as_dtype(dtype)
self._verify_pd = verify_pd
self._name = name
# Store the static shape (if possible) right now before adding the
# asserts, since the asserts prevent .constant_value from working.
shape = ops.convert_to_tensor(shape, name='shape')
self._get_shape = tensor_shape.TensorShape(
tensor_util.constant_value(shape))
self._shape_arg = self._check_shape(shape)
def _check_shape(self, shape):
"""Check that the init arg `shape` defines a valid operator."""
shape = ops.convert_to_tensor(shape, name='shape')
if not self._verify_pd:
return shape
# Further checks are equivalent to verification that this is positive
    # definite. Why? Because the further checks simply check that this is a
    # square matrix. Combining the fact that this is square (and thus maps
    # R^k onto itself) with the behavior of .matmul(), this must be the
    # identity operator.
rank = array_ops.size(shape)
assert_matrix = check_ops.assert_less_equal(2, rank)
with ops.control_dependencies([assert_matrix]):
last_dim = array_ops.gather(shape, rank - 1)
second_to_last_dim = array_ops.gather(shape, rank - 2)
assert_square = check_ops.assert_equal(last_dim, second_to_last_dim)
return control_flow_ops.with_dependencies([assert_matrix, assert_square],
shape)
def _check_x(self, x):
"""Static check that the argument `x` is proper `shape`, `dtype`."""
# x is a typical argument e.g. to matmul or solve. In both cases, x should
# have the same type/shape since this is a square matrix. These checks are
    # usually not needed since we usually have some tensor backing this
# distribution, and the calls to tf.matmul do a shape/type check.
#
# Static checks only for efficiency, the identity should be fast.
#
# Why check at all? Because we want this operator to be swappable for a
# real Operator.
if self.dtype != x.dtype:
raise TypeError(
'Expected argument "x" to have same dtype as this operator (%s). '
'Found: %s' % (self.dtype, x.dtype))
x_shape = x.get_shape()
self_shape = self.get_shape()
found_msg = (
'Found: operator.shape = %s, x.shape = %s' % (self_shape, x_shape))
if x_shape.ndims is not None and self_shape.ndims is not None:
if x_shape.ndims != self_shape.ndims:
raise ValueError(
'Expected argument "x" to have same tensor rank as this operator. '
+ found_msg)
if x_shape.is_fully_defined() and self_shape.is_fully_defined():
if x_shape[-2] != self_shape[-1]:
raise ValueError(
'Incompatible shapes for matrix-matrix operation. ' + found_msg)
@property
def name(self):
"""String name identifying this `Operator`."""
return self._name
@property
def verify_pd(self):
"""Whether to verify that this `Operator` is positive definite."""
return self._verify_pd
@property
def dtype(self):
"""Data type of matrix elements of `A`."""
return self._dtype
def _add_to_tensor(self, mat):
# Add to a tensor in O(k) time!
mat_diag = array_ops.batch_matrix_diag_part(mat)
new_diag = constant_op.constant(1, dtype=self.dtype) + mat_diag
return array_ops.batch_matrix_set_diag(mat, new_diag)
def _inv_quadratic_form_on_vectors(self, x):
self._check_x(x)
return self._iqfov_via_sqrt_solve(x)
@property
def inputs(self):
"""List of tensors that were provided as initialization inputs."""
    return [self._shape_arg]
def get_shape(self):
"""Static `TensorShape` of entire operator.
If this operator represents the batch matrix `A` with
`A.shape = [N1,...,Nn, k, k]`, then this returns
`TensorShape([N1,...,Nn, k, k])`
Returns:
`TensorShape`, statically determined, may be undefined.
"""
return self._get_shape
def _shape(self):
return self._shape_arg
def _det(self):
det = array_ops.ones(self.batch_shape(), dtype=self.dtype)
det.set_shape(self.get_batch_shape())
return det
def _batch_log_det(self):
log_det = array_ops.zeros(self.batch_shape(), dtype=self.dtype)
log_det.set_shape(self.get_batch_shape())
return log_det
def _batch_sqrt_log_det(self):
s_log_det = array_ops.zeros(self.batch_shape(), dtype=self.dtype)
s_log_det.set_shape(self.get_batch_shape())
return s_log_det
def _batch_matmul(self, x, transpose_x=False):
if transpose_x:
x = array_ops.batch_matrix_transpose(x)
self._check_x(x)
return x
def _batch_sqrt_matmul(self, x, transpose_x=False):
return self._batch_matmul(x, transpose_x=transpose_x)
def _batch_solve(self, rhs):
self._check_x(rhs)
return rhs
def _batch_sqrt_solve(self, rhs):
self._check_x(rhs)
return rhs
def _to_dense(self):
diag = array_ops.ones(self.vector_shape(), dtype=self.dtype)
dense = array_ops.batch_matrix_diag(diag)
dense.set_shape(self.get_shape())
return dense
def _sqrt_to_dense(self):
return self.to_dense()
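A minimal usage sketch (relying on the public `matmul` and `solve` wrappers of
`OperatorPDBase`, as used elsewhere in this package; shapes and values are
illustrative):
```python
import numpy as np
import tensorflow as tf

operator = OperatorPDIdentity(shape=[3, 3], dtype=tf.float32)
x = tf.constant([[1.], [2.], [3.]])
with tf.Session() as sess:
  # Both matmul and solve are the identity map: A x = x and A^{-1} x = x.
  ax, ainv_x = sess.run([operator.matmul(x), operator.solve(x)])
assert np.allclose(ax, ainv_x)
```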

View File

@ -0,0 +1,475 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operator defined: `A = SS^T` where `S = M + VDV^T`, for `OperatorPD` `M`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import operator_pd
from tensorflow.contrib.distributions.python.ops import operator_pd_diag
from tensorflow.contrib.distributions.python.ops import operator_pd_identity
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
class OperatorPDSqrtVDVTUpdate(operator_pd.OperatorPDBase):
r"""Operator defined by `A=SS^T`, where `S = M + VDV^T` for `OperatorPD` `M`.
This provides efficient low-rank updates of arbitrary `OperatorPD`.
Some math:
Given positive definite operator representing positive definite (batch) matrix
`M` in `R^{k x k}`, diagonal matrix `D` in `R^{r x r}`, and low rank `V` in
`R^{k x r}` this class represents the batch matrix `A`, defined by its square
root `S` as follows:
```
A = SS^T, where
S := M + VDV^T
```
Defining an operator in terms of its square root means that
`A_{ij} = S_i S_j^T`, where `S_i` is the ith row of `S`. The update
`VDV^T` has `ij` coordinate equal to `sum_k V_{ik} D_{kk} V_{jk}`.
Computational efficiency:
Defining `A` via its square root eliminates the need to compute the square
root.
  Performance depends on the operator representing `M`, the batch size `B`, and
  the width `L` of the matrix being multiplied or the system being solved.
Since `V` is rank `r`, the update adds
* `O(B L k r)` to matmul, which requires a call to `M.matmul`.
* `O(B L r^3)` to solves, which require a call to `M.solve` as well as the
solution to a batch of rank `r` systems.
* `O(B r^3)` to determinants, which require a call to `M.solve` as well as the
solution to a batch of rank `r` systems.
The rank `r` solve and determinant are both done through a Cholesky
factorization, thus some computation is shared.
See
https://en.wikipedia.org/wiki/Woodbury_matrix_identity
https://en.wikipedia.org/wiki/Matrix_determinant_lemma
"""
  # Note that diag must be nonsingular to use the Woodbury lemma, and must be
  # positive definite to use a Cholesky factorization, so we enforce that here.
def __init__(self,
operator,
v,
diag=None,
verify_pd=True,
verify_shapes=True,
name='OperatorPDSqrtVDVTUpdate'):
"""Initialize an `OperatorPDSqrtVDVTUpdate`.
Args:
operator: Subclass of `OperatorPDBase`. Represents the (batch) positive
definite matrix `M` in `R^{k x k}`.
v: `Tensor` defining batch matrix of same `dtype` and `batch_shape` as
`operator`, and last two dimensions of shape `(k, r)`.
diag: Optional `Tensor` defining batch vector of same `dtype` and
`batch_shape` as `operator`, and last dimension of size `r`. If `None`,
the update becomes `VV^T` rather than `VDV^T`.
verify_pd: `Boolean`. If `True`, add asserts that `diag > 0`, which,
along with the positive definiteness of `operator`, is sufficient to
make the resulting operator positive definite.
verify_shapes: `Boolean`. If `True`, check that `operator`, `v`, and
`diag` have compatible shapes.
name: A name to prepend to `Op` names.
"""
if not isinstance(operator, operator_pd.OperatorPDBase):
raise TypeError('operator was not instance of OperatorPDBase.')
with ops.name_scope(name):
with ops.op_scope(operator.inputs + [v, diag], 'init'):
self._operator = operator
self._v = ops.convert_to_tensor(v, name='v')
self._verify_pd = verify_pd
self._verify_shapes = verify_shapes
self._name = name
# This operator will be PD so long as the diag is PSD, but Woodbury
# and determinant lemmas require diag to be PD. So require diag PD
# whenever we ask to "verify_pd".
if diag is not None:
self._diag = ops.convert_to_tensor(diag, name='diag')
self._diag_operator = operator_pd_diag.OperatorPDDiag(
diag, verify_pd=self.verify_pd)
# No need to verify that the inverse of a PD is PD.
self._diag_inv_operator = operator_pd_diag.OperatorPDDiag(
1 / self._diag, verify_pd=False)
else:
self._diag = None
self._diag_operator = self._get_identity_operator(self._v)
self._diag_inv_operator = self._diag_operator
self._check_types(operator, self._v, self._diag)
# Always check static.
checked = self._check_shapes_static(operator, self._v, self._diag)
if not checked and self._verify_shapes:
self._v, self._diag = self._check_shapes_dynamic(
operator, self._v, self._diag)
def _get_identity_operator(self, v):
"""Get an `OperatorPDIdentity` to play the role of `D` in `VDV^T`."""
with ops.op_scope([v], 'get_identity_operator'):
if v.get_shape().is_fully_defined():
v_shape = v.get_shape().as_list()
v_batch_shape = v_shape[:-2]
r = v_shape[-1]
id_shape = v_batch_shape + [r, r]
else:
v_shape = array_ops.shape(v)
v_rank = array_ops.rank(v)
v_batch_shape = array_ops.slice(v_shape, [0], [v_rank - 2])
r = array_ops.gather(v_shape, v_rank - 1) # Last dim of v
id_shape = array_ops.concat(0, (v_batch_shape, [r, r]))
return operator_pd_identity.OperatorPDIdentity(
id_shape, v.dtype, verify_pd=self._verify_pd)
def _check_types(self, operator, v, diag):
def msg():
string = (
'dtypes must match: Found operator.dtype = %s, v.dtype = %s'
% (operator.dtype, v.dtype))
return string
if operator.dtype != v.dtype:
raise TypeError(msg())
if diag is not None:
if diag.dtype != v.dtype:
raise TypeError('%s, diag.dtype = %s' % (msg(), diag.dtype))
def _check_shapes_static(self, operator, v, diag):
"""True if they are compatible. Raise if not. False if could not check."""
def msg():
# Error message when shapes don't match.
string = ' Found: operator.shape = %s, v.shape = %s' % (s_op, s_v)
if diag is not None:
        string += ', diag.shape = %s' % s_d
return string
s_op = operator.get_shape()
s_v = v.get_shape()
# If everything is not fully defined, return False because we couldn't check
if not (s_op.is_fully_defined() and s_v.is_fully_defined()):
return False
if diag is not None:
s_d = diag.get_shape()
if not s_d.is_fully_defined():
return False
# Now perform the checks, raising ValueError if they fail.
# Check tensor rank.
if s_v.ndims != s_op.ndims:
raise ValueError('v should have same rank as operator' + msg())
if diag is not None:
if s_d.ndims != s_op.ndims - 1:
raise ValueError('diag should have rank 1 less than operator' + msg())
# Check batch shape
if s_v[:-2] != s_op[:-2]:
raise ValueError('v and operator should have same batch shape' + msg())
if diag is not None:
if s_d[:-1] != s_op[:-2]:
raise ValueError(
'diag and operator should have same batch shape' + msg())
# Check event shape
if s_v[-2] != s_op[-1]:
raise ValueError(
'v and operator should be compatible for matmul' + msg())
if diag is not None:
if s_d[-1] != s_v[-1]:
raise ValueError('diag and v should have same last dimension' + msg())
return True
def _check_shapes_dynamic(self, operator, v, diag):
"""Return (v, diag) with Assert dependencies, which check shape."""
checks = []
with ops.op_scope([operator, v, diag], 'check_shapes'):
s_v = array_ops.shape(v)
r_op = operator.rank()
r_v = array_ops.rank(v)
if diag is not None:
s_d = array_ops.shape(diag)
r_d = array_ops.rank(diag)
# Check tensor rank.
checks.append(check_ops.assert_rank(v, r_op))
if diag is not None:
checks.append(check_ops.assert_rank(diag, r_op - 1))
# Check batch shape
checks.append(check_ops.assert_equal(
operator.batch_shape(), array_ops.slice(s_v, [0], [r_v - 2])))
if diag is not None:
checks.append(check_ops.assert_equal(
operator.batch_shape(), array_ops.slice(s_d, [0], [r_d - 1])))
# Check event shape
checks.append(check_ops.assert_equal(
operator.vector_space_dimension(), array_ops.gather(s_v, r_v - 2)))
if diag is not None:
checks.append(check_ops.assert_equal(
array_ops.gather(s_v, r_v - 1), array_ops.gather(s_d, r_d - 1)))
v = control_flow_ops.with_dependencies(checks, v)
if diag is not None:
diag = control_flow_ops.with_dependencies(checks, diag)
return v, diag
@property
def name(self):
"""String name identifying this `Operator`."""
return self._name
@property
def verify_pd(self):
"""Whether to verify that this `Operator` is positive definite."""
return self._verify_pd
@property
def dtype(self):
"""Data type of matrix elements of `A`."""
return self._v.dtype
def _inv_quadratic_form_on_vectors(self, x):
return self._iqfov_via_sqrt_solve(x)
@property
def inputs(self):
"""List of tensors that were provided as initialization inputs."""
return self._operator.inputs + self._diag_operator.inputs + [self._v]
def get_shape(self):
"""Static `TensorShape` of entire operator.
If this operator represents the batch matrix `A` with
`A.shape = [N1,...,Nn, k, k]`, then this returns
`TensorShape([N1,...,Nn, k, k])`
Returns:
`TensorShape`, statically determined, may be undefined.
"""
return self._operator.get_shape()
def _shape(self):
return self._operator.shape()
def _det(self):
return math_ops.exp(self.log_det())
def _batch_log_det(self):
return 2 * self._batch_sqrt_log_det()
def _log_det(self):
return 2 * self._sqrt_log_det()
def _sqrt_log_det(self):
# The matrix determinant lemma states:
# det(M + VDV^T) = det(D^{-1} + V^T M^{-1} V) * det(D) * det(M)
# = det(C) * det(D) * det(M)
#
# Here we compute the Cholesky factor of "C", then pass the result on.
diag_chol_c = array_ops.batch_matrix_diag_part(self._chol_capacitance(
batch_mode=False))
return self._sqrt_log_det_core(diag_chol_c)
def _batch_sqrt_log_det(self):
# Here we compute the Cholesky factor of "C", then pass the result on.
diag_chol_c = array_ops.batch_matrix_diag_part(self._chol_capacitance(
batch_mode=True))
return self._sqrt_log_det_core(diag_chol_c)
def _chol_capacitance(self, batch_mode):
"""Cholesky factorization of the capacitance term."""
# Cholesky factor for (D^{-1} + V^T M^{-1} V), which is sometimes
# known as the "capacitance" matrix.
# self._operator will use batch if need be. Automatically. We cannot force
# that here.
# M^{-1} V
minv_v = self._operator.solve(self._v)
# V^T M^{-1} V
if batch_mode:
vt_minv_v = math_ops.batch_matmul(self._v, minv_v, adj_x=True)
else:
vt_minv_v = math_ops.matmul(self._v, minv_v, transpose_a=True)
# D^{-1} + V^T M^{-1} V
capacitance = self._diag_inv_operator.add_to_tensor(vt_minv_v)
# Cholesky[D^{-1} + V^T M^{-1} V]
if batch_mode:
return linalg_ops.batch_cholesky(capacitance)
else:
return linalg_ops.cholesky(capacitance)
def _sqrt_log_det_core(self, diag_chol_c):
"""Finish computation of Sqrt[Log[Det]]."""
# Complete computation of ._log_det and ._batch_log_det, after the initial
# Cholesky factor has been taken with the appropriate batch/non-batch method
# det(M + VDV^T) = det(D^{-1} + V^T M^{-1} V) * det(D) * det(M)
# = det(C) * det(D) * det(M)
# Multiply by 2 here because this is the log-det of the Cholesky factor of C
log_det_c = 2 * math_ops.reduce_sum(
math_ops.log(diag_chol_c),
reduction_indices=[-1])
# Add together to get Log[det(M + VDV^T)], the Log-det of the updated square
# root.
log_det_updated_sqrt = (
log_det_c + self._diag_operator.log_det() + self._operator.log_det())
return log_det_updated_sqrt
def _batch_matmul(self, x, transpose_x=False):
# Since the square root is PD, it is symmetric, and so A = SS^T = SS.
s_x = self._batch_sqrt_matmul(x, transpose_x=transpose_x)
return self._batch_sqrt_matmul(s_x)
def _matmul(self, x, transpose_x=False):
# Since the square root is PD, it is symmetric, and so A = SS^T = SS.
s_x = self._sqrt_matmul(x, transpose_x=transpose_x)
return self._sqrt_matmul(s_x)
def _batch_sqrt_matmul(self, x, transpose_x=False):
v = self._v
m = self._operator
d = self._diag_operator
# The operators call the appropriate matmul/batch_matmul automatically. We
# cannot override.
# batch_matmul is defined as: x * y, so adj_x and adj_y are the ways to
# transpose the left and right.
mx = m.matmul(x, transpose_x=transpose_x)
vt_x = math_ops.batch_matmul(v, x, adj_x=True, adj_y=transpose_x)
d_vt_x = d.matmul(vt_x)
v_d_vt_x = math_ops.batch_matmul(v, d_vt_x)
return mx + v_d_vt_x
def _sqrt_matmul(self, x, transpose_x=False):
v = self._v
m = self._operator
d = self._diag_operator
# The operators call the appropriate matmul/batch_matmul automatically. We
# cannot override.
    # matmul is defined as: a * b, so transpose_a and transpose_b are the ways
    # to transpose the left and right.
mx = m.matmul(x, transpose_x=transpose_x)
vt_x = math_ops.matmul(v, x, transpose_a=True, transpose_b=transpose_x)
d_vt_x = d.matmul(vt_x)
v_d_vt_x = math_ops.matmul(v, d_vt_x)
return mx + v_d_vt_x
def _solve(self, rhs):
# This operator represents A = SS^T, but S is symmetric, so A = SS,
    # which means A^{-1} = S^{-1}S^{-1}
# S^{-1} rhs
sqrtinv_rhs = self._sqrt_solve(rhs)
return self._sqrt_solve(sqrtinv_rhs)
def _batch_solve(self, rhs):
sqrtinv_rhs = self._batch_sqrt_solve(rhs)
return self._batch_sqrt_solve(sqrtinv_rhs)
def _sqrt_solve(self, rhs):
# Recall the square root of this operator is M + VDV^T.
# The Woodbury formula gives:
# (M + VDV^T)^{-1}
# = M^{-1} - M^{-1} V (D^{-1} + V^T M^{-1} V)^{-1} V^T M^{-1}
# = M^{-1} - M^{-1} V C^{-1} V^T M^{-1}
# where C is the capacitance matrix.
# TODO(jvdillon) Determine if recursively applying rank-1 updates is more
# efficient. May not be possible because a general n x n matrix can be
    # represented as n rank-1 updates, and solving with this matrix is always
# done in O(n^3) time.
m = self._operator
v = self._v
cchol = self._chol_capacitance(batch_mode=False)
# The operators will use batch/singleton mode automatically. We don't
# override.
# M^{-1} rhs
minv_rhs = m.solve(rhs)
# V^T M^{-1} rhs
vt_minv_rhs = math_ops.matmul(v, minv_rhs, transpose_a=True)
# C^{-1} V^T M^{-1} rhs
cinv_vt_minv_rhs = linalg_ops.cholesky_solve(cchol, vt_minv_rhs)
# V C^{-1} V^T M^{-1} rhs
v_cinv_vt_minv_rhs = math_ops.matmul(v, cinv_vt_minv_rhs)
# M^{-1} V C^{-1} V^T M^{-1} rhs
minv_v_cinv_vt_minv_rhs = m.solve(v_cinv_vt_minv_rhs)
# M^{-1} - M^{-1} V C^{-1} V^T M^{-1}
return minv_rhs - minv_v_cinv_vt_minv_rhs
def _batch_sqrt_solve(self, rhs):
# Recall the square root of this operator is M + VDV^T.
# The Woodbury formula gives:
# (M + VDV^T)^{-1}
# = M^{-1} - M^{-1} V (D^{-1} + V^T M^{-1} V)^{-1} V^T M^{-1}
# = M^{-1} - M^{-1} V C^{-1} V^T M^{-1}
# where C is the capacitance matrix.
m = self._operator
v = self._v
cchol = self._chol_capacitance(batch_mode=True)
# The operators will use batch/singleton mode automatically. We don't
# override.
# M^{-1} rhs
minv_rhs = m.solve(rhs)
# V^T M^{-1} rhs
vt_minv_rhs = math_ops.batch_matmul(v, minv_rhs, adj_x=True)
# C^{-1} V^T M^{-1} rhs
cinv_vt_minv_rhs = linalg_ops.batch_cholesky_solve(cchol, vt_minv_rhs)
# V C^{-1} V^T M^{-1} rhs
v_cinv_vt_minv_rhs = math_ops.batch_matmul(v, cinv_vt_minv_rhs)
# M^{-1} V C^{-1} V^T M^{-1} rhs
minv_v_cinv_vt_minv_rhs = m.solve(v_cinv_vt_minv_rhs)
# M^{-1} - M^{-1} V C^{-1} V^T M^{-1}
return minv_rhs - minv_v_cinv_vt_minv_rhs
def _to_dense(self):
sqrt = self.sqrt_to_dense()
return math_ops.batch_matmul(sqrt, sqrt, adj_y=True)
def _sqrt_to_dense(self):
v = self._v
d = self._diag_operator
m = self._operator
d_vt = d.matmul(v, transpose_x=True)
# Batch op won't be efficient for singletons. Currently we don't break
# to_dense into batch/singleton methods.
v_d_vt = math_ops.batch_matmul(v, d_vt)
m_plus_v_d_vt = m.to_dense() + v_d_vt
return m_plus_v_d_vt
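Both identities this operator relies on can be verified with dense matrices; a
standalone NumPy sketch with illustrative sizes `k = 4`, `r = 2`:
```python
import numpy as np

rng = np.random.RandomState(0)
k, r = 4, 2
m = np.diag(rng.rand(k) + 1.)  # M, positive definite diagonal
v = rng.randn(k, r)            # V
d = np.diag(rng.rand(r) + 1.)  # D, positive definite diagonal
s = m + v.dot(d).dot(v.T)      # S = M + V D V^T

# Matrix determinant lemma:
#   det(M + VDV^T) = det(D^{-1} + V^T M^{-1} V) * det(D) * det(M)
minv = np.linalg.inv(m)
capacitance = np.linalg.inv(d) + v.T.dot(minv).dot(v)
assert np.allclose(np.linalg.det(s),
                   np.linalg.det(capacitance) * np.linalg.det(d)
                   * np.linalg.det(m))

# Woodbury identity:
#   (M + VDV^T)^{-1} = M^{-1} - M^{-1} V (D^{-1} + V^T M^{-1} V)^{-1} V^T M^{-1}
woodbury = minv - minv.dot(v).dot(
    np.linalg.inv(capacitance)).dot(v.T).dot(minv)
assert np.allclose(woodbury, np.linalg.inv(s))
```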

View File

@ -0,0 +1,396 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A helper class for inferring Distribution shape."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
class _ShapeUtil(object):
"""Class which helps infer/identify subsets of tensor dimensions.
Terminology:
Recall that a `Tensor` has:
shape: sizes of tensor dimensions,
ndims: size of shape; number of tensor dimensions,
dims: indexes into shape; useful for transpose, reduce.
Tensors sampled from a `Distribution` can be partitioned by:
sample dims: indexes independent, identically distributed (iid) draws,
batch dims: indexes non-identical draws,
event dims: indexes coordinates of a single draw.
The sample, batch, and event dimensions constitute the entirety of a
`Tensor` shape. The dimensions are always in sample, batch, event order.
Assumptions:
We assume that batch_ndims and event_ndims are statically known for both
creating this object and for inputs to its functions.
TODO(jvdillon): Relax this assumption and support fully unknown shape.
We also assume that the `Tensor` rank is static, i.e., `x.get_shape().ndims
is not None`.
Possible use-cases:
~ Sample dimensions:
Computing summary statistics, i.e., the average is a reduction over sample
dimensions.
~ Batch dimensions:
Log-likelihood under model predicted location:
```python
mu = ... # vector of predictions, one for each covariate.
      neg_log_likelihood = -tf.reduce_mean(
          Normal(mu=mu, sigma=1.).log_pdf(x),
          reduction_indices=[0])
```
Monte Carlo estimation of a marginal probability:
Average over batch dimensions where batch dimensions are associated with
random draws of a prior.
E.g., suppose we want to find the Monte Carlo estimate of the marginal
distribution of a Normal with a random Laplace location:
```
P(X=x) = integral P(X=x|y) P(Y=y) dy
~= 1/n sum_{i=1}^n P(X=x|y_i), y_i ~iid Laplace(0,1)
      = tf.reduce_mean(Normal(mu=Laplace(0, 1).sample_n(n=1000),
                              sigma=tf.ones([1000, 1])).pdf(x),
                       reduction_indices=[0])
```
The `Laplace` distribution generates a tensor of shape [1000, 1]. When fed
to a `Normal`, this is interpreted as 1000 different locations, i.e.,
1000 non-identical Normals. Therefore a single call to pdf(x) yields 1000
probabilities, one for every location. The average over this batch yields
the marginal.
~ Event dimensions:
Computing the determinant of the Jacobian of a function of a random
variable involves a reduction over event dimensions.
Examples:
Write S, B, E for sample shape, batch shape, and event shape (resp.).
```python
x.get_shape() == S + B + E # For statically known x shape.
# 100 iid samples from one multivariate Normal with two
# degrees of freedom (DF).
mu = [0., 0]
sigma = [[1., 0],
[0, 1]]
X = MultivariateNormal(loc=mu, scale=sigma).sample_n(n=100)
# S = [100]
# B = []
# E = [2]
# 100 iid samples from one Wishart with 2x2 DF.
sigma = [[1., 0],
[0, 1]]
X = Wishart(scale=sigma).sample_n(n=100)
# S = [100]
# B = []
# E = [2, 2]
# 100 iid samples (with shape [2, 50]) from two, non-identical bivariate
# Normal distributions.
mu = ... # shape(2, 2)
sigma = ... # shape(2, 2, 2)
X = MultivariateNormal(loc=mu, scale=sigma).sample(shape=[2, 50])
# S = [2, 50]
# B = [2]
# E = [2]
```
"""
def __init__(self, batch_ndims=None, event_ndims=None, name='ShapeUtil'):
"""Construct ShapeUtil with known sample, batch, and/or event ndims.
Typically, batch_ndims and event_ndims are fixed throughout the lifetime of
a Distribution.
Args:
batch_ndims: number of dims (rank) of the batch portion of indexes of a
        `Tensor`. A "batch" is a non-identical distribution, i.e., Normals with
different parameters.
event_ndims: number of dims (rank) of the event portion of indexes of a
`Tensor`. An "event" is what is sampled from a distribution, i.e., a
        trivariate Normal has an event shape of [3] and a Wishart on 4 x 4 matrices
has an event shape of [4, 4].
name: `String`. The name to give Ops created by this class.
Raises:
ValueError: if batch_ndims or event_ndims are invalid.
"""
    if batch_ndims < 0:
      raise ValueError('must specify non-negative batch_ndims(%d)'
                       % batch_ndims)
    if batch_ndims > 0 and event_ndims < 1:
      raise ValueError('must specify positive event_ndims(%d) when '
                       'batch_ndims(%d) is positive'
                       % (event_ndims, batch_ndims))
# TODO(jvdillon): Support batches of scalars.
self._name = name
self._batch_ndims = batch_ndims
self._event_ndims = event_ndims
@property
def name(self):
"""Name given to ops created by this class."""
return self._name
@property
def batch_ndims(self):
"""Returns number of dimensions corresponding to non-identical draws."""
return self._batch_ndims
@property
def event_ndims(self):
"""Returns number of dimensions needed to index a sample's coordinates."""
return self._event_ndims
def get_ndims(self, x, name='get_ndims'):
"""Get tensor ndims (rank).
Args:
x: `Tensor`.
name: `String`. The name to give this op.
Raises:
ValueError: if ndims is not statically known.
Returns:
`Scalar` number of dimensions associated with a `Tensor`.
"""
if x is None:
raise ValueError('Input was None which does not have known ndims.')
with ops.name_scope(self.name):
with ops.op_scope([x], name):
ndims = ops.convert_to_tensor(x).get_shape().ndims
if ndims is None:
          raise ValueError('ShapeUtil assumes a static number of '
                           'dimensions, but ndims was not statically known.')
return ndims
def get_sample_ndims(self, x):
"""Returns number of dimensions corresponding to iid draws.
Args:
x: `Tensor`.
Raises:
ValueError: if batch_ndims or event_ndims are not statically known.
      ValueError: if inferred sample_ndims is negative, i.e., if
        batch_ndims + event_ndims exceeds the tensor rank.
Returns:
Scalar number of dimensions associated with a sample.
"""
ndims = self.get_ndims(x)
sample_ndims = ndims - self.batch_ndims - self.event_ndims
if sample_ndims < 0:
      raise ValueError('expected batch_ndims(%d) + event_ndims(%d) <= ndims(%d)'
                       % (self.batch_ndims, self.event_ndims, ndims))
return sample_ndims
def get_dims(self, x, sample=True, batch=True, event=True):
"""Returns subset of tensor's dimension indexes (indexes into shape).
Args:
x: `Tensor`.
sample: `Boolean`. Include sample dimensions or not.
batch: `Boolean`. Include batch dimensions or not.
event: `Boolean`. Include event dimensions or not.
Raises:
ValueError: if `x.get_shape().ndims` is `None`
Returns:
List enumerating requested dimensions.
"""
ndims = self.get_ndims(x)
if sample and batch and event:
return list(range(ndims))
sample_start = 0
batch_start = self.get_sample_ndims(x)
event_start = batch_start + self.batch_ndims
sample_shape = list(range(sample_start, batch_start)) if sample else []
batch_shape = list(range(batch_start, event_start)) if batch else []
event_shape = list(range(event_start, ndims)) if event else []
return sample_shape + batch_shape + event_shape
def get_shape(self, x, sample=True, batch=True, event=True, name='get_shape'):
"""Returns subset of tensor's shape (size of dimensions).
Args:
x: `Tensor`.
sample: `Boolean`. Include sample shape or not.
batch: `Boolean`. Include batch shape or not.
event: `Boolean`. Include event shape or not.
name: `String`. The name to give this op.
Raises:
ValueError: if `x.get_shape().ndims` is `None`
Returns:
List describing event shape if known statically, `Tensor` otherwise.
"""
if not sample and not batch and not event:
return []
with ops.name_scope(self._name):
with ops.op_scope([x], name):
x = ops.convert_to_tensor(x)
shape = (x.get_shape().as_list()
if x.get_shape().is_fully_defined()
else array_ops.shape(x))
if sample and batch and event:
return shape
sample_start = 0
batch_start = self.get_sample_ndims(x)
event_start = batch_start + self.batch_ndims
sample_shape = shape[sample_start:batch_start] if sample else []
batch_shape = shape[batch_start:event_start] if batch else []
event_shape = shape[event_start:] if event else []
if not batch and not event:
return sample_shape
if not sample and not event:
return batch_shape
if not sample and not batch:
return event_shape
if x.get_shape().is_fully_defined():
return sample_shape + batch_shape + event_shape
else:
return array_ops.concat(0, [sample_shape, batch_shape, event_shape])
def get_sample_dims(self, x):
"""Returns dimension indexes corresponding to sample.
Convenience function; identical to:
```python
get_dims(x, sample=True, batch=False, event=False)
```
Args:
x: `Tensor`.
Raises:
ValueError: if `x.get_shape().ndims` is `None`
Returns:
List enumerating sample dimensions.
"""
return self.get_dims(x, sample=True, batch=False, event=False)
def get_batch_dims(self, x):
"""Returns dimension indexes corresponding to batch.
Convenience function; identical to:
```python
get_dims(x, sample=False, batch=True, event=False)
```
Args:
x: `Tensor`.
Raises:
ValueError: if `x.get_shape().ndims` is `None`
Returns:
List enumerating batch dimensions.
"""
return self.get_dims(x, sample=False, batch=True, event=False)
def get_event_dims(self, x):
"""Returns dimension indexes corresponding to event.
Convenience function; identical to:
```python
get_dims(x, sample=False, batch=False, event=True)
```
Args:
x: `Tensor`.
Raises:
ValueError: if `x.get_shape().ndims` is `None`
Returns:
List enumerating event dimensions.
"""
return self.get_dims(x, sample=False, batch=False, event=True)
def get_sample_shape(self, x):
"""Returns shape corresponding to sample.
Convenience function; identical to:
```python
get_shape(x, sample=True, batch=False, event=False)
```
Args:
x: `Tensor`.
Returns:
List describing sample shape if known statically, `Tensor` otherwise.
"""
return self.get_shape(x, sample=True, batch=False, event=False)
def get_batch_shape(self, x):
"""Returns shape corresponding to batch.
Convenience function; identical to:
```python
get_shape(x, sample=False, batch=True, event=False)
```
Args:
x: `Tensor`.
Returns:
List describing batch shape if known statically, `Tensor` otherwise.
"""
return self.get_shape(x, sample=False, batch=True, event=False)
def get_event_shape(self, x):
"""Returns shape corresponding to event.
Convenience function; identical to:
```python
get_shape(x, sample=False, batch=False, event=True)
```
Args:
x: `Tensor`.
Returns:
List describing event shape if known statically, `Tensor` otherwise.
"""
return self.get_shape(x, sample=False, batch=False, event=True)
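To make the sample/batch/event split above concrete, here is a small self-contained sketch of the same index arithmetic (the `ndims`, `batch_ndims`, and `event_ndims` values are assumptions chosen for illustration):
```python
# Hypothetical sketch: a tensor x of shape [6, 5, 3, 2] with
# batch_ndims=1 and event_ndims=1 (values chosen for illustration).
ndims = 4
batch_ndims, event_ndims = 1, 1
sample_ndims = ndims - batch_ndims - event_ndims  # 2

sample_dims = list(range(0, sample_ndims))                          # [0, 1]
batch_dims = list(range(sample_ndims, sample_ndims + batch_ndims))  # [2]
event_dims = list(range(sample_ndims + batch_ndims, ndims))         # [3]

shape = [6, 5, 3, 2]
assert [shape[i] for i in sample_dims] == [6, 5]  # sample shape
assert [shape[i] for i in batch_dims] == [3]      # batch shape
assert [shape[i] for i in event_dims] == [2]      # event shape
```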

View File

@ -82,6 +82,7 @@ class StudentT(distribution.Distribution):
# returning a length 2 tensor.
dist.pdf(3.0)
```
"""
def __init__(self,
@ -99,19 +100,19 @@ class StudentT(distribution.Distribution):
broadcasting (e.g. `df + mu + sigma` is a valid operation).
Args:
df: `float` or `double` tensor, the degrees of freedom of the
df: Floating point tensor, the degrees of freedom of the
distribution(s). `df` must contain only positive values.
mu: `float` or `double` tensor, the means of the distribution(s).
sigma: `float` or `double` tensor, the scaling factor for the
mu: Floating point tensor, the means of the distribution(s).
sigma: Floating point tensor, the scaling factor for the
distribution(s). `sigma` must contain only positive values.
Note that `sigma` is not the standard deviation of this distribution.
validate_args: Whether to assert that `df > 0, sigma > 0`. If
`validate_args` is False and inputs are invalid, correct behavior is not
guaranteed.
allow_nan_stats: Boolean, default False. If False, raise an exception if
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
If True, batch members with valid parameters leading to undefined
statistics will return NaN for this statistic.
`validate_args` is `False` and inputs are invalid, correct behavior is
not guaranteed.
allow_nan_stats: Boolean, default `False`. If `False`, raise an
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
batch member. If `True`, batch members with valid parameters leading to
undefined statistics will return NaN for this statistic.
name: The name to give Ops created by the initializer.
Raises:
@ -185,9 +186,12 @@ class StudentT(distribution.Distribution):
nan = np.nan + self._zeros()
return math_ops.select(df_gt_1, result_if_defined, nan)
else:
one = ops.convert_to_tensor(1.0, dtype=self.dtype)
one = constant_op.constant(1.0, dtype=self.dtype)
return control_flow_ops.with_dependencies(
[check_ops.assert_less(one, self._df)], result_if_defined)
[check_ops.assert_less(
one, self._df,
message="mean not defined for components of df <= 1"
)], result_if_defined)
def mode(self, name="mode"):
with ops.name_scope(self.name):
@ -232,9 +236,12 @@ class StudentT(distribution.Distribution):
result_where_defined,
self._zeros() + np.nan)
else:
one = ops.convert_to_tensor(1.0, self.dtype)
one = constant_op.constant(1.0, dtype=self.dtype)
return control_flow_ops.with_dependencies(
[check_ops.assert_less(one, self._df)], result_where_defined)
[check_ops.assert_less(
one, self._df,
message="variance not defined for components of df <= 1"
)], result_where_defined)
def std(self, name="std"):
with ops.name_scope(self.name):
@ -348,8 +355,7 @@ class StudentT(distribution.Distribution):
# Let X = R*cos(theta), and let Y = R*sin(theta).
# Then X ~ t_df and Y ~ t_df.
# The variates X and Y are not independent.
shape = array_ops.concat(0, [array_ops.pack([2, n]),
self.batch_shape()])
shape = array_ops.concat(0, ([2, n], self.batch_shape()))
uniform = random_ops.random_uniform(shape=shape,
dtype=self.dtype,
seed=seed)

View File

@ -57,6 +57,7 @@ class TransformedDistribution(distribution.Distribution):
name="LogitNormalTransformedDistribution"
)
```
"""
def __init__(self,

View File

@ -67,14 +67,14 @@ class Uniform(distribution.Distribution):
```
Args:
a: `float` or `double` tensor, the minimum endpoint.
b: `float` or `double` tensor, the maximum endpoint. Must be > `a`.
validate_args: Whether to assert that `a > b`. If `validate_args` is False
and inputs are invalid, correct behavior is not guaranteed.
allow_nan_stats: Boolean, default False. If False, raise an exception if
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
If True, batch members with valid parameters leading to undefined
statistics will return NaN for this statistic.
a: Floating point tensor, the minimum endpoint.
b: Floating point tensor, the maximum endpoint. Must be > `a`.
validate_args: Whether to assert that `a > b`. If `validate_args` is
`False` and inputs are invalid, correct behavior is not guaranteed.
allow_nan_stats: Boolean, default `False`. If `False`, raise an
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
batch member. If `True`, batch members with valid parameters leading to
undefined statistics will return NaN for this statistic.
name: The name to prefix Ops created by this distribution class.
Raises:
@ -83,8 +83,9 @@ class Uniform(distribution.Distribution):
self._allow_nan_stats = allow_nan_stats
self._validate_args = validate_args
with ops.op_scope([a, b], name):
with ops.control_dependencies([check_ops.assert_less(a, b)] if
validate_args else []):
with ops.control_dependencies([check_ops.assert_less(
a, b, message="uniform not defined when a > b.")] if validate_args
else []):
a = array_ops.identity(a, name="a")
b = array_ops.identity(b, name="b")
@ -228,7 +229,7 @@ class Uniform(distribution.Distribution):
n = ops.convert_to_tensor(n, name="n")
n_val = tensor_util.constant_value(n)
shape = array_ops.concat(0, [array_ops.pack([n]), self.batch_shape()])
shape = array_ops.concat(0, ([n], self.batch_shape()))
samples = random_ops.random_uniform(shape=shape,
dtype=self.dtype,
seed=seed)

View File

@ -94,6 +94,30 @@ tf_py_test(
],
)
tf_py_test(
name = "gmm_test",
srcs = [
"python/ops/gmm_test.py",
],
additional_deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
tf_py_test(
name = "gmm_ops_test",
srcs = [
"python/ops/gmm_ops_test.py",
],
additional_deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
tf_py_test(
name = "factorization_ops_test",
srcs = ["python/ops/factorization_ops_test.py"],

View File

@ -304,7 +304,7 @@ class WalsModelTest(tf.test.TestCase):
col_factors2 = [x.eval() for x in wals_model.col_factors]
for c1, c2 in zip(col_factors1, col_factors2):
self.assertAllClose(c1, c2, atol=1e-3)
self.assertAllClose(c1, c2, rtol=5e-3, atol=1e-2)
def test_als_transposed(self):
with self.test_session():
@ -383,7 +383,7 @@ class WalsModelTest(tf.test.TestCase):
regularization=1e-5,
row_weights=None,
col_weights=None)
self.simple_train(model, inp, 15)
self.simple_train(model, inp, 25)
row_factor = model.row_factors[0].eval()
col_factor = model.col_factors[0].eval()
self.assertAllClose(data,
@ -407,7 +407,7 @@ class WalsModelTest(tf.test.TestCase):
regularization=1e-5,
row_weights=[0] * rows,
col_weights=[0] * cols)
self.simple_train(model, inp, 15)
self.simple_train(model, inp, 25)
row_factor = model.row_factors[0].eval()
col_factor = model.col_factors[0].eval()
self.assertAllClose(data,
@ -438,7 +438,7 @@ class WalsModelTest(tf.test.TestCase):
regularization=0.001,
row_weights=row_wts,
col_weights=col_wts)
self.simple_train(model, inp, 10)
self.simple_train(model, inp, 25)
row_factor = model.row_factors[0].eval()
col_factor = model.col_factors[0].eval()
out = np.dot(row_factor, np.transpose(col_factor))
@ -446,7 +446,7 @@ class WalsModelTest(tf.test.TestCase):
for j in xrange(cols):
if keep_index([i, j]):
self.assertNear(data[i][j], out[i][j],
err=0.2, msg="%d, %d" % (i, j))
err=0.4, msg="%d, %d" % (i, j))
else:
self.assertNear(0, out[i][j], err=0.5, msg="%d, %d" % (i, j))

View File

@ -0,0 +1,211 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Gaussian mixture model (GMM) clustering.
This is built on top of the skflow API.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.contrib.factorization.python.ops import gmm_ops
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators._sklearn import TransformerMixin
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder
from tensorflow.contrib.learn.python.learn.utils import checkpoints
from tensorflow.python.ops.control_flow_ops import with_dependencies
class GMM(estimator.Estimator, TransformerMixin):
"""GMM clustering."""
SCORES = 'scores'
ASSIGNMENTS = 'assignments'
ALL_SCORES = 'all_scores'
def __init__(self,
num_clusters,
model_dir=None,
random_seed=0,
params='wmc',
initial_clusters='random',
covariance_type='full',
batch_size=128,
steps=10,
continue_training=False,
config=None,
verbose=1):
"""Creates a model for running GMM training and inference.
Args:
num_clusters: number of clusters to train.
model_dir: the directory to save the model results and log files.
random_seed: Python integer. Seed for PRNG used to initialize centers.
params: Controls which parameters are updated in the training process.
Can contain any combination of "w" for weights, "m" for means,
and "c" for covars.
initial_clusters: specifies how to initialize the clusters for training.
See gmm_ops.gmm for the possible values.
covariance_type: one of "full", "diag".
batch_size: See TensorFlowEstimator.
steps: See TensorFlowEstimator.
continue_training: See TensorFlowEstimator.
config: See TensorFlowEstimator.
verbose: See TensorFlowEstimator.
"""
super(GMM, self).__init__(
model_dir=model_dir,
config=config)
self.batch_size = batch_size
self.steps = steps
self.continue_training = continue_training
self.verbose = verbose
self._num_clusters = num_clusters
self._params = params
self._training_initial_clusters = initial_clusters
self._covariance_type = covariance_type
self._training_graph = None
self._random_seed = random_seed
def fit(self, x, y=None, monitors=None, logdir=None, steps=None):
"""Trains a GMM clustering on x.
Note: See TensorFlowEstimator for the logic of continuous training and
graph construction across multiple calls to fit.
Args:
x: training input matrix of shape [n_samples, n_features].
y: labels. Should be None.
monitors: List of `Monitor` objects to print training progress and
invoke early stopping.
logdir: the directory to save the log file that can be used for optional
visualization.
steps: number of training steps. If not None, overrides the value passed
in constructor.
Returns:
Returns self.
"""
if logdir is not None:
self._model_dir = logdir
self._data_feeder = data_feeder.setup_train_data_feeder(
x, None, self._num_clusters, self.batch_size)
self._train_model(input_fn=self._data_feeder.input_builder,
feed_fn=self._data_feeder.get_feed_dict_fn(),
steps=steps or self.steps,
monitors=monitors,
init_feed_fn=self._data_feeder.get_feed_dict_fn())
return self
def predict(self, x, batch_size=None):
"""Predict cluster id for each element in x.
Args:
x: 2-D matrix or iterator.
batch_size: size to use for batching up x for querying the model.
Returns:
Array with same number of rows as x, containing cluster ids.
"""
return super(GMM, self).predict(x=x, batch_size=batch_size)[GMM.ASSIGNMENTS]
def score(self, x, batch_size=None):
"""Predict total sum of distances to nearest clusters.
Args:
x: 2-D matrix or iterator.
batch_size: size to use for batching up x for querying the model.
Returns:
Total score.
"""
return np.sum(self.evaluate(x=x, batch_size=batch_size)[GMM.SCORES])
def transform(self, x, batch_size=None):
"""Transforms each element in x to distances to cluster centers.
Args:
x: 2-D matrix or iterator.
batch_size: size to use for batching up x for querying the model.
Returns:
Array with same number of rows as x, and num_clusters columns, containing
distances to the cluster centers.
"""
return super(GMM, self).predict(x=x, batch_size=batch_size)[GMM.ALL_SCORES]
def clusters(self):
"""Returns cluster centers."""
clusters = checkpoints.load_variable(self.model_dir,
gmm_ops.GmmAlgorithm.CLUSTERS_VARIABLE)
return np.squeeze(clusters, 1)
def covariances(self):
"""Returns the covariances."""
return checkpoints.load_variable(
self.model_dir,
gmm_ops.GmmAlgorithm.CLUSTERS_COVS_VARIABLE)
def _get_train_ops(self, features, _):
(_,
_,
losses,
training_op) = gmm_ops.gmm(
features,
self._training_initial_clusters,
self._num_clusters,
self._random_seed,
self._covariance_type,
self._params)
incr_step = tf.assign_add(tf.contrib.framework.get_global_step(), 1)
loss = tf.reduce_sum(losses)
training_op = with_dependencies([training_op, incr_step], loss)
return training_op, loss
def _get_predict_ops(self, features):
(all_scores,
model_predictions,
_,
_) = gmm_ops.gmm(
features,
self._training_initial_clusters,
self._num_clusters,
self._random_seed,
self._covariance_type,
self._params)
return {
GMM.ALL_SCORES: all_scores[0],
GMM.ASSIGNMENTS: model_predictions[0]
}
def _get_eval_ops(self, features, _, unused_metrics):
(_,
_,
losses,
_) = gmm_ops.gmm(
features,
self._training_initial_clusters,
self._num_clusters,
self._random_seed,
self._covariance_type,
self._params)
return {
GMM.SCORES: tf.reduce_sum(losses),
}
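A minimal usage sketch of this estimator, assuming a 2-D float32 training matrix (the data and hyperparameters here are illustrative, not part of the change):
```python
import numpy as np
# Assumed import path, mirroring this file's package location.
from tensorflow.contrib.factorization.python.ops.gmm import GMM

x = np.random.randn(1000, 2).astype(np.float32)  # [n_samples, n_features]
gmm = GMM(num_clusters=3, steps=20, batch_size=100)
gmm.fit(x)                     # trains weights, means and covariances
assignments = gmm.predict(x)   # one cluster id per row of x
centers = gmm.clusters()       # array of shape [num_clusters, n_features]
total = gmm.score(x)           # total distance to the assigned clusters
```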

View File

@ -0,0 +1,461 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gaussian mixture models Operations."""
# TODO(xavigonzalvo): Factor out covariance matrix operations to make
# code reusable for different types (e.g. diag).
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.python.ops.embedding_ops import embedding_lookup
# Machine epsilon.
MEPS = np.finfo(float).eps
FULL_COVARIANCE = 'full'
DIAG_COVARIANCE = 'diag'
def _covariance(x, diag):
"""Defines the covariance operation of a matrix.
Args:
x: a matrix Tensor. Dimension 0 should contain the number of examples.
diag: if True, it computes the diagonal covariance.
Returns:
A Tensor representing the covariance of x. In the case of
diagonal matrix just the diagonal is returned.
"""
num_points = tf.to_float(tf.shape(x)[0])
x -= tf.reduce_mean(x, 0, keep_dims=True)
if diag:
cov = tf.reduce_sum(
tf.square(x), 0, keep_dims=True) / (num_points - 1)
else:
cov = tf.matmul(x, x, transpose_a=True) / (num_points - 1)
return cov
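For intuition, with `diag=False` this matches numpy's sample covariance; a sanity sketch, assuming the function above is in scope:
```python
import numpy as np
import tensorflow as tf

x = np.random.randn(100, 3).astype(np.float32)
with tf.Session() as sess:
  tf_cov = sess.run(_covariance(tf.constant(x), diag=False))
# np.cov expects variables in rows, hence the transpose.
np.testing.assert_array_almost_equal(np.cov(x.T), tf_cov, decimal=4)
```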
def _init_clusters_random(data, num_clusters, random_seed):
"""Does random initialization of clusters.
Args:
data: a list of Tensors with a matrix of data, each row is an example.
num_clusters: an integer with the number of clusters.
random_seed: Seed for the PRNG used to sample initial cluster rows.
Returns:
A Tensor with num_clusters random rows of data.
"""
assert isinstance(data, list)
num_data = tf.add_n([tf.shape(inp)[0] for inp in data])
with tf.control_dependencies([tf.assert_less_equal(num_clusters, num_data)]):
indices = tf.random_uniform([num_clusters],
minval=0,
maxval=tf.cast(num_data, tf.int64),
seed=random_seed,
dtype=tf.int64)
indices = tf.cast(indices, tf.int32) % num_data
clusters_init = embedding_lookup(data, indices, partition_strategy='div')
return clusters_init
class GmmAlgorithm(object):
"""Tensorflow Gaussian mixture model clustering class."""
CLUSTERS_VARIABLE = 'clusters'
CLUSTERS_COVS_VARIABLE = 'clusters_covs'
def __init__(self, data, num_classes, initial_means=None, params='wmc',
covariance_type=FULL_COVARIANCE, random_seed=0):
"""Constructor.
Args:
data: a list of Tensors with data, each row is a new example.
num_classes: number of clusters.
initial_means: a Tensor with a matrix of means. If None, means are
computed by sampling randomly.
params: Controls which parameters are updated in the training
process. Can contain any combination of "w" for weights, "m" for
means, and "c" for covariances.
covariance_type: one of "full", "diag".
random_seed: Seed for the PRNG used to sample initial cluster centers.
Raises:
Exception: if covariance_type is unknown.
"""
self._params = params
self._random_seed = random_seed
self._covariance_type = covariance_type
if self._covariance_type not in [DIAG_COVARIANCE, FULL_COVARIANCE]:
raise Exception( # pylint: disable=g-doc-exception
'programmer error: Invalid covariance type: %s' %
self._covariance_type)
# Create sharded variables for multiple shards. The following
# lists are indexed by shard.
# Probability per example in a class.
num_shards = len(data)
self._probs = [None] * num_shards
# Prior probability.
self._prior_probs = [None] * num_shards
# Membership weights w_{ik} where "i" is the i-th example and "k"
# is the k-th mixture.
self._w = [None] * num_shards
# Number of examples in a class.
self._points_in_k = [None] * num_shards
first_shard = data[0]
self._dimensions = tf.shape(first_shard)[1]
self._num_classes = num_classes
# Small value to guarantee that covariances are invertible.
self._min_var = tf.diag(tf.ones(tf.pack([self._dimensions]))) * 1e-3
self._create_variables(data, initial_means)
# Operations of partial statistics for the computation of the means.
self._w_mul_x = []
# Operations of partial statistics for the computation of the covariances.
self._w_mul_x2 = []
self._define_graph(data)
def _create_variables(self, data, initial_means=None):
"""Initializes GMM algorithm.
Args:
data: a list of Tensors with data, each row is a new example.
initial_means: a Tensor with a matrix of means.
"""
first_shard = data[0]
# Initialize means: num_classes X 1 X dimensions.
if initial_means is not None:
self._means = tf.Variable(tf.expand_dims(initial_means, 1),
name=self.CLUSTERS_VARIABLE,
validate_shape=False, dtype=tf.float32)
else:
# Sample data randomly
self._means = tf.Variable(tf.expand_dims(
_init_clusters_random(data, self._num_classes, self._random_seed), 1),
name=self.CLUSTERS_VARIABLE,
validate_shape=False)
# Initialize covariances.
if self._covariance_type == FULL_COVARIANCE:
cov = _covariance(first_shard, False) + self._min_var
# A matrix per class, num_classes X dimensions X dimensions
covs = tf.tile(
tf.expand_dims(cov, 0), [self._num_classes, 1, 1])
elif self._covariance_type == DIAG_COVARIANCE:
cov = _covariance(first_shard, True) + self._min_var
# A diagonal per row, num_classes X dimensions.
covs = tf.tile(tf.expand_dims(tf.diag_part(cov), 0),
[self._num_classes, 1])
self._covs = tf.Variable(covs, name='clusters_covs', validate_shape=False)
# Mixture weights, representing the probability that a randomly
# selected unobservable data (in EM terms) was generated by component k.
self._alpha = tf.Variable(tf.tile([1.0 / self._num_classes],
[self._num_classes]))
def training_ops(self):
"""Returns the training operation."""
return self._train_ops
def alphas(self):
return self._alpha
def clusters(self):
"""Returns the clusters with dimensions num_classes X 1 X num_dimensions."""
return self._means
def covariances(self):
"""Returns the covariances matrices."""
return self._covs
def assignments(self):
"""Returns a list of Tensors with the matrix of assignments per shard."""
ret = []
for w in self._w:
ret.append(tf.argmax(w, 1))
return ret
def scores(self):
"""Returns the distances to each class.
Returns:
A tuple with two Tensors. The first contains the distance to
each class. The second contains the distance to the assigned
class.
"""
return (self._all_scores, self._scores)
def _define_graph(self, data):
"""Define graph for a single iteration.
Args:
data: a list of Tensors defining the training data.
"""
for shard_id, shard in enumerate(data):
self._num_examples = tf.shape(shard)[0]
shard = tf.expand_dims(shard, 0)
self._define_log_prob_operation(shard_id, shard)
self._define_prior_log_prob_operation(shard_id)
self._define_expectation_operation(shard_id)
self._define_partial_maximization_operation(shard_id, shard)
self._define_maximization_operation(len(data))
self._define_distance_to_clusters(data)
def _define_full_covariance_probs(self, shard_id, shard):
"""Defines the full covariance probabilties per example in a class.
Updates a matrix with dimension num_examples X num_classes.
Args:
shard_id: id of the current shard.
shard: current data shard, 1 X num_examples X dimensions.
"""
diff = shard - self._means
cholesky = tf.batch_cholesky(self._covs + self._min_var)
log_det_covs = 2.0 * tf.reduce_sum(tf.log(
tf.batch_matrix_diag_part(cholesky)), 1)
x_mu_cov = tf.square(tf.batch_matrix_triangular_solve(
cholesky, tf.transpose(diff, perm=[0, 2, 1]),
lower=True))
diag_m = tf.transpose(tf.reduce_sum(x_mu_cov, 1))
self._probs[shard_id] = -0.5 * (
diag_m + tf.to_float(self._dimensions) * tf.log(2 * np.pi) +
log_det_covs)
def _define_diag_covariance_probs(self, shard_id, shard):
"""Defines the diagonal covariance probabilities per example in a class.
Args:
shard_id: id of the current shard.
shard: current data shard, 1 X num_examples X dimensions.
Returns a matrix num_examples X num_classes.
"""
# num_classes X 1
# TODO(xavigonzalvo): look into alternatives to log for
# reparametrization of variance parameters.
det_expanded = tf.reduce_sum(tf.log(self._covs + 1e-3),
1, keep_dims=True)
diff = shard - self._means
x2 = tf.square(diff)
cov_expanded = tf.expand_dims(1.0 / (self._covs + 1e-3), 2)
# num_classes X num_examples
x2_cov = tf.batch_matmul(x2, cov_expanded)
x2_cov = tf.transpose(tf.squeeze(x2_cov, [2]))
self._probs[shard_id] = -0.5 * (
tf.to_float(self._dimensions) * tf.log(2.0 * np.pi) +
tf.transpose(det_expanded) + x2_cov)
def _define_log_prob_operation(self, shard_id, shard):
"""Probability per example in a class.
Updates a matrix with dimension num_examples X num_classes.
Args:
shard_id: id of the current shard.
shard: current data shard, 1 X num_examples X dimensions.
"""
# TODO(xavigonzalvo): Use the pdf defined in
# third_party/tensorflow/contrib/distributions/python/ops/gaussian.py
if self._covariance_type == FULL_COVARIANCE:
self._define_full_covariance_probs(shard_id, shard)
elif self._covariance_type == DIAG_COVARIANCE:
self._define_diag_covariance_probs(shard_id, shard)
self._probs[shard_id] += tf.log(self._alpha)
def _define_prior_log_prob_operation(self, shard_id):
"""Computes the prior probability of all samples.
Updates a vector where each item is the prior probability of an
input example.
Args:
shard_id: id of current shard_id.
"""
self._prior_probs[shard_id] = tf.log(
tf.reduce_sum(tf.exp(self._probs[shard_id]), 1, keep_dims=True))
def _define_expectation_operation(self, shard_id):
# Shape broadcasting.
probs = tf.expand_dims(self._probs[shard_id], 0)
# Membership weights are computed as:
# w_{ik} = \frac{\alpha_k f(\mathbf{y_i}|\mathbf{\theta}_k)}
# {\sum_{m=1}^{K}\alpha_mf(\mathbf{y_i}|\mathbf{\theta}_m)}
# where "i" is the i-th example, "k" is the k-th mixture, theta are
# the model parameters and y_i the observations.
# These are defined for each shard.
self._w[shard_id] = tf.reshape(
tf.exp(probs - self._prior_probs[shard_id]),
tf.pack([self._num_examples, self._num_classes]))
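Numerically, the weights above are a row-wise softmax over the per-class log probabilities; a small numpy sketch of the same computation (values are illustrative):
```python
import numpy as np

# log(alpha_k) + log f(y_i | theta_k) for 2 examples and 3 classes.
log_probs = np.array([[-1.0, -2.0, -3.0],
                      [-0.5, -0.5, -4.0]])
# Prior log probability of each example: log sum_k exp(log_probs).
prior = np.log(np.sum(np.exp(log_probs), axis=1, keepdims=True))
w = np.exp(log_probs - prior)            # membership weights w_{ik}
assert np.allclose(w.sum(axis=1), 1.0)   # each row sums to one
```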
def _define_partial_maximization_operation(self, shard_id, shard):
"""Computes the partial statistics of the means and covariances.
Args:
shard_id: current shard id.
shard: current data shard, 1 X num_examples X dimensions.
"""
# Soft assignment of each data point to each of the clusters.
self._points_in_k[shard_id] = tf.reduce_sum(self._w[shard_id], 0,
keep_dims=True)
# Partial means.
w_mul_x = tf.expand_dims(
tf.matmul(self._w[shard_id],
tf.squeeze(shard, [0]), transpose_a=True), 1)
self._w_mul_x.append(w_mul_x)
# Partial covariances.
x = tf.concat(0, [shard for _ in range(self._num_classes)])
x_trans = tf.transpose(x, perm=[0, 2, 1])
x_mul_w = tf.concat(0, [
tf.expand_dims(x_trans[k, :, :] * self._w[shard_id][:, k], 0)
for k in range(self._num_classes)])
self._w_mul_x2.append(tf.batch_matmul(x_mul_w, x))
def _define_maximization_operation(self, num_batches):
"""Maximization operations."""
# TODO(xavigonzalvo): some of these operations could be moved to C++.
# Compute the effective number of data points assigned to component k.
with tf.control_dependencies(self._w):
points_in_k = tf.squeeze(tf.add_n(self._points_in_k), squeeze_dims=[0])
# Update alpha.
if 'w' in self._params:
final_points_in_k = points_in_k / num_batches
num_examples = tf.to_float(tf.reduce_sum(final_points_in_k))
self._alpha_op = self._alpha.assign(
final_points_in_k / (num_examples + MEPS))
else:
self._alpha_op = tf.no_op()
self._train_ops = [self._alpha_op]
# Update means.
points_in_k_expanded = tf.reshape(points_in_k,
[self._num_classes, 1, 1])
if 'm' in self._params:
self._means_op = self._means.assign(
tf.div(tf.add_n(self._w_mul_x), points_in_k_expanded + MEPS))
else:
self._means_op = tf.no_op()
# means are (num_classes x 1 x dims)
# Update covariances.
with tf.control_dependencies([self._means_op]):
b = tf.add_n(self._w_mul_x2) / (points_in_k_expanded + MEPS)
new_covs = []
for k in range(self._num_classes):
mean = self._means.ref()[k, :, :]
square_mean = tf.matmul(mean, mean, transpose_a=True)
new_cov = b[k, :, :] - square_mean + self._min_var
if self._covariance_type == FULL_COVARIANCE:
new_covs.append(tf.expand_dims(new_cov, 0))
elif self._covariance_type == DIAG_COVARIANCE:
new_covs.append(tf.expand_dims(tf.diag_part(new_cov), 0))
new_covs = tf.concat(0, new_covs)
if 'c' in self._params:
# Train operations don't need to take care of the means
# because covariances already depend on it.
with tf.control_dependencies([self._means_op, new_covs]):
self._train_ops.append(
tf.assign(self._covs, new_covs, validate_shape=False))
def _define_distance_to_clusters(self, data):
"""Defines the Mahalanobis distance to the assigned Gaussian."""
# TODO(xavigonzalvo): reuse (input - mean) * cov^-1 * (input -
# mean) from log probability function.
self._all_scores = []
for shard in data:
all_scores = []
shard = tf.expand_dims(shard, 0)
for c in xrange(self._num_classes):
if self._covariance_type == FULL_COVARIANCE:
cov = self._covs[c, :, :]
elif self._covariance_type == DIAG_COVARIANCE:
cov = tf.diag(self._covs[c, :])
inverse = tf.matrix_inverse(cov + self._min_var)
inv_cov = tf.tile(
tf.expand_dims(inverse, 0),
tf.pack([self._num_examples, 1, 1]))
diff = tf.transpose(shard - self._means[c, :, :], perm=[1, 0, 2])
m_left = tf.batch_matmul(diff, inv_cov)
all_scores.append(tf.sqrt(tf.batch_matmul(
m_left, tf.transpose(diff, perm=[0, 2, 1])
)))
self._all_scores.append(tf.reshape(
tf.concat(1, all_scores),
tf.pack([self._num_examples, self._num_classes])))
# Distance to the associated class.
self._all_scores = tf.concat(0, self._all_scores)
assignments = tf.concat(0, self.assignments())
rows = tf.to_int64(tf.range(0, self._num_examples))
indices = tf.concat(1, [tf.expand_dims(rows, 1),
tf.expand_dims(assignments, 1)])
self._scores = tf.gather_nd(self._all_scores, indices)
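The final `gather_nd` picks, for each row, the score of its assigned class; in numpy terms (an illustrative sketch):
```python
import numpy as np

all_scores = np.array([[0.3, 1.2, 2.0],
                       [1.5, 0.1, 0.9]])    # [num_examples, num_classes]
assignments = np.array([0, 1])              # argmax of membership weights
scores = all_scores[np.arange(2), assignments]
assert list(scores) == [0.3, 0.1]
```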
def _define_loglikelihood_operation(self):
"""Defines the total log-likelihood of current iteration."""
self._ll_op = []
for prior_probs in self._prior_probs:
self._ll_op.append(tf.reduce_sum(tf.log(prior_probs)))
tf.scalar_summary('ll', tf.reduce_sum(self._ll_op))
def gmm(inp, initial_clusters, num_clusters, random_seed,
covariance_type=FULL_COVARIANCE, params='wmc'):
"""Creates the graph for Gaussian mixture model (GMM) clustering.
Args:
inp: An input tensor or list of input tensors
initial_clusters: Specifies the clusters used during
initialization. Can be a tensor or numpy array, or a function
that generates the clusters. Can also be "random" to specify
that clusters should be chosen randomly from input data. Note: the
type is kept flexible to be consistent with skflow.
num_clusters: number of clusters.
random_seed: Python integer. Seed for PRNG used to initialize centers.
covariance_type: one of "diag", "full".
params: Controls which parameters are updated in the training
process. Can contain any combination of "w" for weights, "m" for
means, and "c" for covars.
Returns:
Note: a tuple of lists is returned to be consistent with skflow.
A tuple consisting of:
all_scores: A matrix (or list of matrices) of dimensions (num_input,
num_clusters) where the value is the distance of an input vector and a
cluster center.
assignments: A vector (or list of vectors). Each element in the vector
corresponds to an input row in 'inp' and specifies the cluster id
corresponding to the input.
scores: Similar to assignments but specifies the distance to the
assigned cluster instead.
training_op: an op that runs an iteration of training.
"""
initial_means = None
if initial_clusters != 'random' and not isinstance(
initial_clusters, tf.Tensor):
initial_means = tf.constant(initial_clusters, dtype=tf.float32)
# Implementation of GMM.
inp = inp if isinstance(inp, list) else [inp]
gmm_tool = GmmAlgorithm(inp, num_clusters, initial_means, params,
covariance_type, random_seed)
training_ops = gmm_tool.training_ops()
assignments = gmm_tool.assignments()
all_scores, scores = gmm_tool.scores()
return [all_scores], [assignments], [scores], tf.group(*training_ops)
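Tying the pieces together, a hedged end-to-end sketch of driving this graph (hyperparameters are illustrative; this mirrors the pattern used in the tests below):
```python
import numpy as np
import tensorflow as tf

data = tf.constant(np.random.randn(500, 2), dtype=tf.float32)
all_scores, assignments, scores, training_op = gmm(
    data, 'random', num_clusters=2, random_seed=3)
with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  for _ in range(20):            # each run performs one EM-style update
    sess.run(training_op)
  ids = sess.run(assignments)    # one cluster-id vector per input shard
```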

View File

@ -0,0 +1,198 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for gmm_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.factorization.python.ops import gmm_ops
from tensorflow.python.platform import tf_logging as logging
class GmmOpsTest(tf.test.TestCase):
def setUp(self):
self.num_examples = 1000
self.iterations = 40
self.seed = 4
tf.set_random_seed(self.seed)
np.random.seed(self.seed * 2)
self.data, self.true_assignments = self.make_data(self.num_examples)
# Generate more complicated data.
self.centers = [[1, 1], [-1, 0.5], [2, 1]]
self.more_data, self.more_true_assignments = self.make_data_from_centers(
self.num_examples, self.centers)
@staticmethod
def make_data(num_vectors):
"""Generates 2-dimensional data centered on (2,2), (-1,-1).
Args:
num_vectors: number of training examples.
Returns:
A tuple containing the data as a numpy array and the cluster ids.
"""
vectors = []
classes = []
for _ in xrange(num_vectors):
if np.random.random() > 0.5:
vectors.append([np.random.normal(2.0, 0.6),
np.random.normal(2.0, 0.9)])
classes.append(0)
else:
vectors.append([np.random.normal(-1.0, 0.4),
np.random.normal(-1.0, 0.5)])
classes.append(1)
return np.asarray(vectors), classes
@staticmethod
def make_data_from_centers(num_vectors, centers):
"""Generates 2-dimensional data with random centers.
Args:
num_vectors: number of training examples.
centers: a list of random 2-dimensional centers.
Returns:
A tuple containing the data as a numpy array and the cluster ids.
"""
vectors = []
classes = []
for _ in xrange(num_vectors):
current_class = np.random.random_integers(0, len(centers) - 1)
vectors.append([np.random.normal(centers[current_class][0],
np.random.random_sample()),
np.random.normal(centers[current_class][1],
np.random.random_sample())])
classes.append(current_class)
return np.asarray(vectors), classes
def test_covariance(self):
start_time = time.time()
data = self.data.T
np_cov = np.cov(data)
logging.info('Numpy took %f', time.time() - start_time)
start_time = time.time()
with self.test_session() as sess:
op = gmm_ops._covariance(
tf.constant(data.T, dtype=tf.float32),
False)
op_diag = gmm_ops._covariance(
tf.constant(data.T, dtype=tf.float32),
True)
tf.initialize_all_variables().run()
tf_cov = sess.run(op)
np.testing.assert_array_almost_equal(np_cov, tf_cov)
logging.info('Tensorflow took %f', time.time() - start_time)
tf_cov = sess.run(op_diag)
np.testing.assert_array_almost_equal(
np.diag(np_cov), np.ravel(tf_cov), decimal=5)
def test_simple_cluster(self):
"""Tests that the clusters are correct."""
num_classes = 2
graph = tf.Graph()
with graph.as_default() as g:
g.seed = 5
with self.test_session() as sess:
data = tf.constant(self.data, dtype=tf.float32)
_, assignments, _, training_op = gmm_ops.gmm(data, 'random',
num_classes,
random_seed=self.seed)
tf.initialize_all_variables().run()
for _ in xrange(self.iterations):
sess.run(training_op)
assignments = sess.run(assignments)
accuracy = np.mean(
np.asarray(self.true_assignments) == np.squeeze(assignments))
logging.info('Accuracy: %f', accuracy)
self.assertGreater(accuracy, 0.98)
def testParams(self):
"""Tests that the params work as intended."""
num_classes = 2
with self.test_session() as sess:
# Experiment 1. Update weights only.
data = tf.constant(self.data, dtype=tf.float32)
gmm_tool = gmm_ops.GmmAlgorithm([data], num_classes,
[[3.0, 3.0], [0.0, 0.0]], 'w')
training_ops = gmm_tool.training_ops()
tf.initialize_all_variables().run()
for _ in xrange(self.iterations):
sess.run(training_ops)
# Only the probability to each class is updated.
alphas = sess.run(gmm_tool.alphas())
self.assertGreater(alphas[1], 0.6)
means = sess.run(gmm_tool.clusters())
np.testing.assert_almost_equal(
np.expand_dims([[3.0, 3.0], [0.0, 0.0]], 1), means)
covs = sess.run(gmm_tool.covariances())
np.testing.assert_almost_equal(covs[0], covs[1])
# Experiment 2. Update means and covariances.
gmm_tool = gmm_ops.GmmAlgorithm([data], num_classes,
[[3.0, 3.0], [0.0, 0.0]], 'mc')
training_ops = gmm_tool.training_ops()
tf.initialize_all_variables().run()
for _ in xrange(self.iterations):
sess.run(training_ops)
alphas = sess.run(gmm_tool.alphas())
self.assertAlmostEqual(alphas[0], alphas[1])
means = sess.run(gmm_tool.clusters())
np.testing.assert_almost_equal(
np.expand_dims([[2.0, 2.0], [-1.0, -1.0]], 1), means, decimal=1)
covs = sess.run(gmm_tool.covariances())
np.testing.assert_almost_equal(
[[0.371111, -0.0050774], [-0.0050774, 0.8651744]],
covs[0], decimal=4)
np.testing.assert_almost_equal(
[[0.146976, 0.0259463], [0.0259463, 0.2543971]],
covs[1], decimal=4)
# Experiment 3. Update covariances only.
gmm_tool = gmm_ops.GmmAlgorithm([data], num_classes,
[[-1.0, -1.0], [1.0, 1.0]], 'c')
training_ops = gmm_tool.training_ops()
tf.initialize_all_variables().run()
for _ in xrange(self.iterations):
sess.run(training_ops)
alphas = sess.run(gmm_tool.alphas())
self.assertAlmostEqual(alphas[0], alphas[1])
means = sess.run(gmm_tool.clusters())
np.testing.assert_almost_equal(
np.expand_dims([[-1.0, -1.0], [1.0, 1.0]], 1), means)
covs = sess.run(gmm_tool.covariances())
np.testing.assert_almost_equal(
[[0.1299582, 0.0435872], [0.0435872, 0.2558578]],
covs[0], decimal=5)
np.testing.assert_almost_equal(
[[3.195385, 2.6989155], [2.6989155, 3.3881593]],
covs[1], decimal=5)
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,172 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ops.gmm."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.factorization.python.ops.gmm import GMM
from tensorflow.contrib.factorization.python.ops.kmeans import KMeansClustering as KMeans
from tensorflow.contrib.learn.python.learn.estimators import run_config
FLAGS = tf.app.flags.FLAGS
class GMMTest(tf.test.TestCase):
def setUp(self):
np.random.seed(3)
tf.set_random_seed(2)
self.num_centers = 2
self.num_dims = 2
self.num_points = 4000
self.batch_size = 100
self.true_centers = self.make_random_centers(self.num_centers,
self.num_dims)
self.points, self.assignments, self.scores = self.make_random_points(
self.true_centers,
self.num_points)
self.true_score = np.add.reduce(self.scores)
# Use initial means from kmeans (just like scikit-learn does).
clusterer = KMeans(num_clusters=self.num_centers)
clusterer.fit(self.points, steps=30)
self.initial_means = clusterer.clusters()
@staticmethod
def make_random_centers(num_centers, num_dims):
return np.round(np.random.rand(num_centers,
num_dims).astype(np.float32) * 500)
@staticmethod
def make_random_points(centers, num_points):
num_centers, num_dims = centers.shape
assignments = np.random.choice(num_centers, num_points)
offsets = np.round(np.random.randn(num_points,
num_dims).astype(np.float32) * 20)
points = centers[assignments] + offsets
means = [np.mean(points[assignments == center], axis=0)
for center in xrange(num_centers)]
covs = [np.cov(points[assignments == center].T)
for center in xrange(num_centers)]
scores = []
for r in xrange(num_points):
scores.append(np.sqrt(np.dot(
np.dot(points[r, :] - means[assignments[r]],
np.linalg.inv(covs[assignments[r]])),
points[r, :] - means[assignments[r]])))
return (points, assignments, scores)
def test_clusters(self):
"""Tests the shape of the clusters."""
gmm = GMM(self.num_centers,
initial_clusters=self.initial_means,
batch_size=self.batch_size,
steps=40,
continue_training=True,
random_seed=4,
config=run_config.RunConfig(tf_random_seed=2))
gmm.fit(x=self.points, steps=0)
clusters = gmm.clusters()
self.assertAllEqual(list(clusters.shape),
[self.num_centers, self.num_dims])
def test_fit(self):
gmm = GMM(self.num_centers,
initial_clusters='random',
batch_size=self.batch_size,
random_seed=4,
config=run_config.RunConfig(tf_random_seed=2))
gmm.fit(x=self.points, steps=1)
score1 = gmm.score(x=self.points)
gmm = GMM(self.num_centers,
initial_clusters='random',
batch_size=self.batch_size,
random_seed=4,
config=run_config.RunConfig(tf_random_seed=2))
gmm.fit(x=self.points, steps=10)
score2 = gmm.score(x=self.points)
self.assertGreater(score1, score2)
self.assertNear(self.true_score, score2, self.true_score * 0.15)
def test_infer(self):
gmm = GMM(self.num_centers,
initial_clusters=self.initial_means,
batch_size=self.batch_size,
steps=40,
continue_training=True,
random_seed=4,
config=run_config.RunConfig(tf_random_seed=2))
gmm.fit(x=self.points, steps=60)
clusters = gmm.clusters()
# Make a small test set
points, true_assignments, true_offsets = (
self.make_random_points(clusters, 40))
assignments = np.ravel(gmm.predict(points))
self.assertAllEqual(true_assignments, assignments)
# Test score
score = gmm.score(points)
self.assertNear(score, np.sum(true_offsets), 4.05)
def _compare_with_sklearn(self, cov_type):
# sklearn version.
iterations = 40
np.random.seed(5)
sklearn_assignments = np.asarray([0, 0, 1, 0, 0, 0, 1, 0, 0, 1])
sklearn_means = np.asarray([[144.83417719, 254.20130341],
[274.38754816, 353.16074346]])
sklearn_covs = np.asarray([[[395.0081194, -4.50389512],
[-4.50389512, 408.27543989]],
[[385.17484203, -31.27834935],
[-31.27834935, 391.74249925]]])
# skflow version.
gmm = GMM(self.num_centers,
initial_clusters=self.initial_means,
covariance_type=cov_type,
batch_size=self.num_points,
steps=iterations,
continue_training=True,
config=run_config.RunConfig(tf_random_seed=2))
gmm.fit(self.points)
skflow_assignments = gmm.predict(self.points[:10, :]).astype(int)
self.assertAllClose(sklearn_assignments,
np.ravel(skflow_assignments))
self.assertAllClose(sklearn_means, gmm.clusters())
if cov_type == 'full':
self.assertAllClose(sklearn_covs, gmm.covariances(), rtol=0.01)
else:
for d in [0, 1]:
self.assertAllClose(np.diag(sklearn_covs[d]),
gmm.covariances()[d, :], rtol=0.01)
def test_compare_full(self):
self._compare_with_sklearn('full')
def test_compare_diag(self):
self._compare_with_sklearn('diag')
if __name__ == '__main__':
tf.test.main()

View File

@ -153,9 +153,11 @@ class KMeansTest(tf.test.TestCase):
def test_fit_with_cosine_distance(self):
# Create points on y=x and y=1.5x lines to check the cosine similarity.
# Note that euclidean distance will give different results in this case.
points = np.array([[9, 9], [0.5, 0.5], [10, 15], [0.4, 0.6]])
points = np.array(
[[9, 9], [0.5, 0.5], [10, 15], [0.4, 0.6]], dtype=np.float32)
# true centers are the unit vectors on lines y=x and y=1.5x
true_centers = np.array([[0.70710678, 0.70710678], [0.5547002, 0.83205029]])
true_centers = np.array(
[[0.70710678, 0.70710678], [0.5547002, 0.83205029]], dtype=np.float32)
kmeans = KMeans(2,
initial_clusters=kmeans_ops.RANDOM_INIT,
distance_metric=kmeans_ops.COSINE_DISTANCE,
@ -168,8 +170,9 @@ class KMeansTest(tf.test.TestCase):
np.sort(true_centers, axis=0))
def test_transform_with_cosine_distance(self):
points = np.array([[2.5, 3.5], [2, 8], [3, 1], [3, 18],
[-2.5, -3.5], [-2, -8], [-3, -1], [-3, -18]])
points = np.array(
[[2.5, 0.1], [2, 0.2], [3, 0.1], [4, 0.2],
[0.1, 2.5], [0.2, 2], [0.1, 3], [0.2, 4]], dtype=np.float32)
true_centers = [normalize(np.mean(normalize(points)[4:, :], axis=0,
keepdims=True))[0],
@ -180,8 +183,8 @@ class KMeansTest(tf.test.TestCase):
initial_clusters=kmeans_ops.RANDOM_INIT,
distance_metric=kmeans_ops.COSINE_DISTANCE,
use_mini_batch=self.use_mini_batch,
config=self.config(3))
kmeans.fit(x=points, steps=30, batch_size=8)
config=self.config(5))
kmeans.fit(x=points, steps=50, batch_size=8)
centers = normalize(kmeans.clusters())
self.assertAllClose(np.sort(centers, axis=0),
@ -193,16 +196,16 @@ class KMeansTest(tf.test.TestCase):
self.assertAllClose(transform, true_transform, atol=1e-3)
def test_predict_with_cosine_distance(self):
points = np.array([[2.5, 3.5], [2, 8], [3, 1], [3, 18],
[-2.5, -3.5], [-2, -8], [-3, -1], [-3, -18]]).astype(
np.float32)
points = np.array(
[[2.5, 0.1], [2, 0.2], [3, 0.1], [4, 0.2],
[0.1, 2.5], [0.2, 2], [0.1, 3], [0.2, 4]], dtype=np.float32)
true_centers = np.array(
[normalize(np.mean(normalize(points)[0:4, :],
axis=0,
keepdims=True))[0],
normalize(np.mean(normalize(points)[4:, :],
axis=0,
keepdims=True))[0]])
keepdims=True))[0]], dtype=np.float32)
true_assignments = [0] * 4 + [1] * 4
true_score = len(points) - np.tensordot(normalize(points),
true_centers[true_assignments])
@ -230,14 +233,14 @@ class KMeansTest(tf.test.TestCase):
# the less populated centers.
points = np.array([[2.5, 3.5], [2.5, 3.5], [-2, 3], [-2, 3], [-3, -3],
[-3.1, -3.2], [-2.8, -3.], [-2.9, -3.1], [-3., -3.1],
[-3., -3.1], [-3.2, -3.], [-3., -3.]]).astype(np.float32)
[-3., -3.1], [-3.2, -3.], [-3., -3.]], dtype=np.float32)
true_centers = np.array(
[normalize(np.mean(normalize(points)[0:2, :], axis=0,
keepdims=True))[0],
normalize(np.mean(normalize(points)[2:4, :], axis=0,
keepdims=True))[0],
normalize(np.mean(normalize(points)[4:, :], axis=0,
keepdims=True))[0]])
keepdims=True))[0]], dtype=np.float32)
true_assignments = [0] * 2 + [1] * 2 + [2] * 8
true_score = len(points) - np.tensordot(normalize(points),
true_centers[true_assignments])
@ -262,7 +265,7 @@ class KMeansTest(tf.test.TestCase):
self.assertAllClose(score, true_score, atol=1e-2)
def test_fit_raise_if_num_clusters_larger_than_num_points_random_init(self):
points = np.array([[2.0, 3.0], [1.6, 8.2]])
points = np.array([[2.0, 3.0], [1.6, 8.2]], dtype=np.float32)
with self.assertRaisesOpError('less'):
kmeans = KMeans(num_clusters=3, initial_clusters=kmeans_ops.RANDOM_INIT)
@ -270,7 +273,7 @@ class KMeansTest(tf.test.TestCase):
def test_fit_raise_if_num_clusters_larger_than_num_points_kmeans_plus_plus(
self):
points = np.array([[2.0, 3.0], [1.6, 8.2]])
points = np.array([[2.0, 3.0], [1.6, 8.2]], dtype=np.float32)
with self.assertRaisesOpError(AssertionError):
kmeans = KMeans(num_clusters=3,

View File

@ -21,10 +21,12 @@
#include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
namespace ffmpeg {
@ -62,13 +64,11 @@ class FileDeleter {
class DecodeAudioOp : public OpKernel {
public:
explicit DecodeAudioOp(OpKernelConstruction* context)
: OpKernel(context) {
explicit DecodeAudioOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("file_format", &file_format_));
file_format_ = str_util::Lowercase(file_format_);
const std::set<string> valid_file_formats(
kValidFileFormats,
kValidFileFormats + TF_ARRAYSIZE(kValidFileFormats));
kValidFileFormats, kValidFileFormats + TF_ARRAYSIZE(kValidFileFormats));
OP_REQUIRES(context, valid_file_formats.count(file_format_) == 1,
errors::InvalidArgument(
"file_format arg must be in {",
@ -79,8 +79,7 @@ class DecodeAudioOp : public OpKernel {
OP_REQUIRES(context, samples_per_second_ > 0,
errors::InvalidArgument("samples_per_second must be > 0."));
OP_REQUIRES_OK(
context, context->GetAttr("channel_count", &channel_count_));
OP_REQUIRES_OK(context, context->GetAttr("channel_count", &channel_count_));
OP_REQUIRES(context, channel_count_ > 0,
errors::InvalidArgument("channel_count must be > 0."));
}
@ -112,12 +111,18 @@ class DecodeAudioOp : public OpKernel {
context, result.ok(),
errors::Unavailable("FFmpeg must be installed to run this op. FFmpeg "
"can be found at http://www.ffmpeg.org."));
} else if (result.code() == error::UNKNOWN) {
LOG(ERROR) << "Ffmpeg failed with error '" << result.error_message()
<< "'. Returning empty tensor.";
Tensor* output = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape({0, 0}), &output));
return;
} else {
OP_REQUIRES_OK(context, result);
}
OP_REQUIRES(
context, !output_samples.empty(),
errors::Unknown("No output created by FFmpeg."));
OP_REQUIRES(context, !output_samples.empty(),
errors::Unknown("No output created by FFmpeg."));
OP_REQUIRES(
context, output_samples.size() % channel_count_ == 0,
errors::Unknown("FFmpeg created non-integer number of audio frames."));
@ -125,9 +130,9 @@ class DecodeAudioOp : public OpKernel {
// Copy the output data to the output Tensor.
Tensor* output = nullptr;
const int64 frame_count = output_samples.size() / channel_count_;
OP_REQUIRES_OK(
context, context->allocate_output(
0, TensorShape({frame_count, channel_count_}), &output));
OP_REQUIRES_OK(context,
context->allocate_output(
0, TensorShape({frame_count, channel_count_}), &output));
auto matrix = output->tensor<float, 2>();
for (int32 frame = 0; frame < frame_count; ++frame) {
for (int32 channel = 0; channel < channel_count_; ++channel) {
@ -151,6 +156,15 @@ REGISTER_OP("DecodeAudio")
.Attr("file_format: string")
.Attr("samples_per_second: int")
.Attr("channel_count: int")
.SetShapeFn([](shape_inference::InferenceContext* c) {
int64 channels;
if (c->GetAttr("channel_count", &channels).ok()) {
c->set_output(0, c->Matrix(c->UnknownDim(), channels));
} else {
c->set_output(0, c->Matrix(c->UnknownDim(), c->UnknownDim()));
}
return Status::OK();
})
.Doc(R"doc(
Processes the contents of an audio file into a tensor using FFmpeg to decode
the file.
@ -162,7 +176,8 @@ different from the contents of the file, channels will be merged or created.
contents: The binary audio file contents.
sampled_audio: A rank 2 tensor containing all tracks of the audio. Dimension 0
is time and dimension 1 is the channel.
is time and dimension 1 is the channel. If ffmpeg fails to decode the audio
then an empty tensor will be returned.
file_format: A string describing the audio file format. This can be "wav" or
"mp3".
samples_per_second: The number of samples per second that the audio should have.

View File

@ -72,6 +72,14 @@ class DecodeAudioOpTest(tf.test.TestCase):
def testOgg(self):
self._loadFileAndTest('mono_10khz.ogg', 'ogg', 0.57, 10000, 1)
def testInvalidFile(self):
with self.test_session():
contents = 'invalid file'
audio_op = ffmpeg.decode_audio(contents, file_format='wav',
samples_per_second=10000, channel_count=2)
audio = audio_op.eval()
self.assertEqual(audio.shape, (0, 0))
if __name__ == '__main__':
tf.test.main()

View File

@ -38,7 +38,6 @@ namespace {
const char kFfmpegExecutable[] = "ffmpeg";
const int32 kDefaultProbeSize = 5000000; // 5MB
std::vector<string> FfmpegCommandLine(const string& input_filename,
const string& output_filename,
const string& input_format_id,
@ -63,6 +62,39 @@ std::vector<string> FfmpegCommandLine(const string& input_filename,
};
}
// Is a named binary installed and executable by the current process?
// Note that this is harder than it seems like it should be...
bool IsBinaryInstalled(const string& binary_name) {
string path = ::getenv("PATH");
for (const string& dir : str_util::Split(path, ':')) {
const string binary_path = io::JoinPath(dir, binary_name);
char absolute_path[PATH_MAX + 1];
if (::realpath(binary_path.c_str(), absolute_path) == nullptr) {
continue;
}
struct stat statinfo;
int result = ::stat(absolute_path, &statinfo);
if (result < 0) {
continue;
}
if (!S_ISREG(statinfo.st_mode)) {
continue;
}
// Is the current user able to execute the file?
if (statinfo.st_uid == ::geteuid() && statinfo.st_mode & S_IXUSR) {
return true;
}
// Is the current group able to execute the file?
if (statinfo.st_gid == ::getegid() && statinfo.st_mode & S_IXGRP) {
return true;
}
// Is anyone able to execute the file?
if (statinfo.st_mode & S_IXOTH) {
return true;
}
}
return false;
}
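For comparison only (not part of this change): in Python the same PATH scan and executability check is a one-liner via `shutil.which`, assuming Python 3.3+:
```python
import shutil

def is_binary_installed(binary_name):
  # shutil.which performs the same PATH walk and executability
  # check as the C++ helper above.
  return shutil.which(binary_name) is not None

print(is_binary_installed('ffmpeg'))
```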
[[noreturn]] int ExecuteFfmpeg(const std::vector<string>& args) {
std::vector<char*> args_chars;
std::transform(args.begin(), args.end(), std::back_inserter(args_chars),
@ -191,6 +223,14 @@ Status ReadAudioFile(const string& filename,
FfmpegCommandLine(filename, output_filename, audio_format_id,
samples_per_second, channel_count);
// Unfortunately, it's impossible to differentiate an exec failure due to the
// binary being missing and an error from the binary's execution. Therefore,
// check to see if the binary *should* be available. If not, return an error
// that will be converted into a helpful error message by the TensorFlow op.
if (!IsBinaryInstalled(kFfmpegExecutable)) {
return Status(error::Code::NOT_FOUND, StrCat("FFmpeg could not be found."));
}
// Execute ffmpeg and report errors.
pid_t child_pid = ::fork();
if (child_pid < 0) {
@ -202,7 +242,7 @@ Status ReadAudioFile(const string& filename,
int status_code;
::waitpid(child_pid, &status_code, 0);
if (status_code) {
return Status(error::Code::NOT_FOUND,
return Status(error::Code::UNKNOWN,
StrCat("FFmpeg execution failed: ", status_code));
}
*output_samples = ReadPcmFile(output_filename);

View File

@ -16,6 +16,7 @@
#include <limits>
#include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
@ -24,8 +25,7 @@ namespace ffmpeg {
class EncodeAudioOp : public OpKernel {
public:
explicit EncodeAudioOp(OpKernelConstruction* context)
: OpKernel(context) {
explicit EncodeAudioOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("file_format", &file_format_));
file_format_ = str_util::Lowercase(file_format_);
OP_REQUIRES(context, file_format_ == "wav",
@ -35,15 +35,15 @@ class EncodeAudioOp : public OpKernel {
context, context->GetAttr("samples_per_second", &samples_per_second_));
OP_REQUIRES(context, samples_per_second_ > 0,
errors::InvalidArgument("samples_per_second must be > 0."));
OP_REQUIRES_OK(
context, context->GetAttr("bits_per_second", &bits_per_second_));
OP_REQUIRES_OK(context,
context->GetAttr("bits_per_second", &bits_per_second_));
}
void Compute(OpKernelContext* context) override {
// Get and verify the input data.
OP_REQUIRES(context, context->num_inputs() == 1,
errors::InvalidArgument(
"EncodeAudio requires exactly one input."));
OP_REQUIRES(
context, context->num_inputs() == 1,
errors::InvalidArgument("EncodeAudio requires exactly one input."));
const Tensor& contents = context->input(0);
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(contents.shape()),
errors::InvalidArgument(
@ -88,6 +88,7 @@ REGISTER_OP("EncodeAudio")
.Attr("file_format: string")
.Attr("samples_per_second: int")
.Attr("bits_per_second: int = 192000")
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"doc(
Processes a `Tensor` containing sampled audio with the number of channels
and length of the audio specified by the dimensions of the `Tensor`. The

View File

@ -67,7 +67,8 @@ def decode_audio(contents, file_format=None, samples_per_second=None,
Returns:
A rank 2 tensor that has time along dimension 0 and channels along
dimension 1. Dimension 0 will be `samples_per_second * length` wide, and
dimension 1 will be `channel_count` wide.
dimension 1 will be `channel_count` wide. If ffmpeg fails to decode the
audio then an empty tensor will be returned.
"""
return gen_decode_audio_op_py.decode_audio(
contents, file_format=file_format, samples_per_second=samples_per_second,

View File

@ -14,6 +14,7 @@ py_library(
srcs = [
"__init__.py",
"python/framework/__init__.py",
"python/framework/checkpoint_utils.py",
"python/framework/deprecation.py",
"python/framework/tensor_util.py",
"python/ops/__init__.py",
@ -35,10 +36,19 @@ py_test(
deps = ["//tensorflow:tensorflow_py"],
)
py_test(
name = "checkpoint_utils_test",
size = "small",
srcs = ["python/framework/checkpoint_utils_test.py"],
srcs_version = "PY2AND3",
tags = ["manual"], # http://b/30468735
deps = ["//tensorflow:tensorflow_py"],
)
py_test(
name = "ops_test",
size = "small",
srcs = glob(["python/ops/ops_test.py"]),
srcs = ["python/ops/ops_test.py"],
srcs_version = "PY2AND3",
deps = ["//tensorflow:tensorflow_py"],
)
@ -51,9 +61,16 @@ py_test(
deps = ["//tensorflow:tensorflow_py"],
)
py_test(
name = "deprecation_test",
srcs = ["python/framework/deprecation_test.py"],
srcs_version = "PY2AND3",
deps = ["//tensorflow:tensorflow_py"],
)
py_test(
name = "tensor_util_test",
srcs = glob(["python/framework/tensor_util_test.py"]),
srcs = ["python/framework/tensor_util_test.py"],
srcs_version = "PY2AND3",
deps = ["//tensorflow:tensorflow_py"],
)
@ -61,7 +78,7 @@ py_test(
py_test(
name = "variables_test",
size = "small",
srcs = glob(["python/ops/variables_test.py"]),
srcs = ["python/ops/variables_test.py"],
srcs_version = "PY2AND3",
deps = ["//tensorflow:tensorflow_py"],
)
@ -74,6 +91,15 @@ py_test(
deps = ["//tensorflow:tensorflow_py"],
)
py_test(
name = "sampling_ops_threading_test",
size = "small",
srcs = ["python/ops/sampling_ops_threading_test.py"],
srcs_version = "PY2AND3",
tags = ["notsan"],
deps = ["//tensorflow:tensorflow_py"],
)
filegroup(
name = "all_files",
srcs = glob(

View File

@ -30,6 +30,7 @@
## Deprecation
@@deprecated
@@deprecated_arg_values
## Arg_Scope
@@arg_scope

View File

@ -19,5 +19,7 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import
from tensorflow.contrib.framework.python.framework.checkpoint_utils import *
from tensorflow.contrib.framework.python.framework.deprecation import deprecated
from tensorflow.contrib.framework.python.framework.deprecation import deprecated_arg_values
from tensorflow.contrib.framework.python.framework.tensor_util import *

View File

@ -0,0 +1,288 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools to work with checkpoints."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver
from tensorflow.python.training import training as train
__all__ = [
"load_checkpoint",
"load_variable",
"list_variables",
"init_from_checkpoint"]
def _get_checkpoint_filename(filepattern):
"""Returns checkpoint filename given directory or specific filepattern."""
if gfile.IsDirectory(filepattern):
return saver.latest_checkpoint(filepattern)
return filepattern
def load_checkpoint(filepattern):
"""Returns CheckpointReader for latest checkpoint.
Args:
filepattern: Directory with checkpoints file or path to checkpoint.
Returns:
`CheckpointReader` object.
Raises:
ValueError: if checkpoint_dir doesn't have 'checkpoint' file or checkpoints.
"""
filename = _get_checkpoint_filename(filepattern)
if filename is None:
raise ValueError("Couldn't find 'checkpoint' file or checkpoints in "
"given directory %s" % filepattern)
return train.NewCheckpointReader(filename)
def load_variable(checkpoint_dir, name):
"""Returns a Tensor with the contents of the given variable in the checkpoint.
Args:
checkpoint_dir: Directory with checkpoints file or path to checkpoint.
name: Name of the tensor to return.
Returns:
`Tensor` object.
"""
# TODO(b/29227106): Fix this in the right place and remove this.
if name.endswith(":0"):
name = name[:-2]
reader = load_checkpoint(checkpoint_dir)
return reader.get_tensor(name)
def list_variables(checkpoint_dir):
"""Returns list of all variables in the latest checkpoint.
Args:
checkpoint_dir: Directory with checkpoints file or path to checkpoint.
Returns:
List of tuples `(name, shape)`.
"""
reader = load_checkpoint(checkpoint_dir)
variable_map = reader.get_variable_to_shape_map()
names = sorted(variable_map.keys())
result = []
for name in names:
result.append((name, variable_map[name]))
return result
# pylint: disable=protected-access
# Currently variable_scope doesn't provide very good APIs to access
# all variables under scope and retrieve and check existing scopes.
# TODO(ipolosukhin): Refactor variable_scope module to provide nicer APIs.
def _set_checkpoint_initializer(variable, file_pattern, tensor_name, slice_spec,
name="checkpoint_initializer"):
"""Sets variable initializer to assign op form value in checkpoint's tensor.
Args:
variable: `Variable` object.
file_pattern: string, where to load checkpoints from.
tensor_name: Name of the `Tensor` to load from checkpoint reader.
slice_spec: Slice specification for loading partitioned variables.
name: Name of the operation.
"""
base_type = variable.dtype.base_dtype
restore_op = gen_io_ops._restore_slice(
file_pattern,
tensor_name,
slice_spec,
base_type,
preferred_shard=-1,
name=name)
variable._initializer_op = state_ops.assign(variable, restore_op)
def _set_variable_or_list_initializer(variable_or_list, file_pattern,
tensor_name):
if isinstance(variable_or_list, (list, tuple)):
# A set of slices.
slice_name = None
for v in variable_or_list:
if slice_name is None:
slice_name = v._save_slice_info.full_name
elif slice_name != v._save_slice_info.full_name:
raise ValueError("Slices must all be from the same tensor: %s != %s" %
(slice_name, v._save_slice_info.full_name))
_set_checkpoint_initializer(v, file_pattern, tensor_name,
v._save_slice_info.spec)
else:
_set_checkpoint_initializer(variable_or_list, file_pattern, tensor_name, "")
def init_from_checkpoint(checkpoint_dir, assignment_map):
"""Using assingment map initializes current variables with loaded tensors.
Note: This overrides default initialization ops of specified variables and
redefines dtype.
Assignment map supports following syntax:
`'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
current `scope_name` from `checkpoint_scope_name` with matching variable
names.
`'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
will initialize `scope_name/variable_name` variable
from `checkpoint_scope_name/some_other_variable`.
`'scope_variable_name': variable` - will initialize given `tf.Variable`
object with variable from the checkpoint.
`'scope_variable_name': list(variable)` - will initialize list of
partitioned variables with variable from the checkpoint.
`'scope_name/': '/'` - will load all variables in current `scope_name` from
checkpoint's root (i.e., no scope).
Supports loading into partitioned variables, which are represented as
'<variable>/part_<part #>'.
Example:
```python
# Create variables.
with tf.variable_scope('test'):
m = tf.get_variable('my_var')
with tf.variable_scope('test2'):
var2 = tf.get_variable('my_var')
var3 = tf.get_variable(name="my1", shape=[100, 100],
partitioner=lambda shape, dtype: [5, 1])
...
# Specify which variables to initialize from checkpoint.
init_from_checkpoint(checkpoint_dir, {
'some_var': 'test/my_var',
'some_scope/': 'test2/'})
...
# Or use `Variable` objects to identify what to initialize.
init_from_checkpoint(checkpoint_dir, {
'some_scope/var2': var2,
})
# Initialize partitioned variables
init_from_checkpoint(checkpoint_dir, {
'some_var_from_ckpt': 'part_var',
})
# Or specifying the list of `Variable` objects.
init_from_checkpoint(checkpoint_dir, {
'some_var_from_ckpt': var3._get_variable_list(),
})
...
# Initialize variables as usual.
session.run(tf.get_all_variables())
```
Args:
checkpoint_dir: Directory with checkpoints file or path to checkpoint.
assignment_map: Dict, where keys are names of the variables in the
checkpoint and values are current variables or names of current variables
(in default graph).
Raises:
tf.errors.OpError: If missing checkpoints or tensors in checkpoints.
ValueError: If missing variables in current graph.
"""
filepattern = _get_checkpoint_filename(checkpoint_dir)
reader = load_checkpoint(checkpoint_dir)
variable_map = reader.get_variable_to_shape_map()
for tensor_name_in_ckpt, current_var_or_name in six.iteritems(assignment_map):
var = None
# Check if this is Variable object or list of Variable objects (in case of
# partitioned variables).
is_var = lambda x: isinstance(x, variables.Variable)
if is_var(current_var_or_name) or (
isinstance(current_var_or_name, list)
and all(is_var(v) for v in current_var_or_name)):
var = current_var_or_name
else:
var_scope = vs._get_default_variable_store()
# Check if this variable is in var_store.
var = var_scope._vars.get(current_var_or_name, None)
# Also check if variable is partitioned as list.
if var is None:
if current_var_or_name + "/part_0" in var_scope._vars:
var = []
i = 0
while current_var_or_name + "/part_%d" % i in var_scope._vars:
var.append(var_scope._vars[current_var_or_name + "/part_%d" % i])
i += 1
if var is not None:
# If 1 to 1 mapping was provided, find variable in the checkpoint.
if tensor_name_in_ckpt not in variable_map:
raise ValueError("Tensor %s is not found in %s checkpoint" % (
tensor_name_in_ckpt, checkpoint_dir
))
if is_var(var):
# Additional at-call-time checks.
if not var.get_shape().is_compatible_with(
variable_map[tensor_name_in_ckpt]):
raise ValueError(
"Shape of variable %s (%s) doesn't match with shape of "
"tensor %s (%s) from checkpoint reader." % (
var.name, str(var.get_shape()),
tensor_name_in_ckpt, str(variable_map[tensor_name_in_ckpt])
))
var_name = var.name
else:
var_name = ",".join([v.name for v in var])
_set_variable_or_list_initializer(var, filepattern, tensor_name_in_ckpt)
logging.info("Initialize variable %s from checkpoint %s with %s" % (
var_name, checkpoint_dir, tensor_name_in_ckpt
))
else:
scopes = ""
# TODO(vihanjain): Support list of 'current_var_or_name' here.
if "/" in current_var_or_name:
scopes = current_var_or_name[:current_var_or_name.rindex("/")]
if not tensor_name_in_ckpt.endswith("/"):
raise ValueError(
"Assignment map with scope only name {} should map to scope only "
"{}. Should be 'scope/': 'other_scope/'.".format(
scopes, tensor_name_in_ckpt))
# If scope to scope mapping was provided, find all variables in the scope.
for var_name in var_scope._vars:
if var_name.startswith(scopes):
# Lookup name with specified prefix and suffix from current variable.
# If tensor_name given is '/' (root), don't use it for full name.
if tensor_name_in_ckpt != "/":
full_tensor_name = tensor_name_in_ckpt + var_name[len(scopes) + 1:]
else:
full_tensor_name = var_name[len(scopes) + 1:]
if full_tensor_name not in variable_map:
raise ValueError(
"Tensor %s (%s in %s) is not found in %s checkpoint" % (
full_tensor_name, var_name[len(scopes) + 1:],
tensor_name_in_ckpt, checkpoint_dir
))
var = var_scope._vars[var_name]
_set_variable_or_list_initializer(var, filepattern, full_tensor_name)
logging.info("Initialize variable %s from checkpoint %s with %s" % (
var_name, checkpoint_dir, full_tensor_name
))
# pylint: enable=protected-access

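Taken together, a short usage sketch of the new helpers; the checkpoint directory and variable names are illustrative:

```python
import tensorflow as tf
from tensorflow.contrib.framework.python.framework import checkpoint_utils

checkpoint_dir = '/tmp/model_ckpt'  # hypothetical directory

# Enumerate (name, shape) pairs stored in the latest checkpoint.
for name, shape in checkpoint_utils.list_variables(checkpoint_dir):
  print(name, shape)

# Fetch a single tensor's value without building a graph.
value = checkpoint_utils.load_variable(checkpoint_dir, 'some_scope/my_var')

# Override the initializer of a current variable with the checkpoint value.
with tf.variable_scope('new_scope'):
  my_var = tf.get_variable('my_var', shape=value.shape)
checkpoint_utils.init_from_checkpoint(
    checkpoint_dir, {'some_scope/my_var': my_var})
```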
View File

@ -23,8 +23,6 @@ import os
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.utils import checkpoints
def _create_checkpoints(sess, checkpoint_dir):
checkpoint_prefix = os.path.join(checkpoint_dir, "model")
@ -45,15 +43,14 @@ def _create_checkpoints(sess, checkpoint_dir):
def _create_partition_checkpoints(sess, checkpoint_dir):
checkpoint_prefix = os.path.join(checkpoint_dir, "model")
checkpoint_state_name = "checkpoint"
# TODO(ipolosukhin): Enable this when get_variable partitioning works.
# v1 = tf.get_variable("var1", [100, 100],
# partitioner=tf.variable_axis_size_partitioner(axis=0,
# max_shard_bytes=512))
v1 = tf.create_partitioned_variables(
shape=[100, 100], slicing=[5, 1], name="var1",
initializer=tf.truncated_normal_initializer(0.5))
v1 = tf.get_variable(
name="var1",
shape=[100, 100],
initializer=tf.truncated_normal_initializer(0.5),
partitioner=tf.min_max_variable_partitioner(max_partitions=5, axis=0,
min_slice_size=8 << 10))
sess.run(tf.initialize_all_variables())
v1_value = sess.run(v1)
v1_value = sess.run(v1._get_variable_list())
saver = tf.train.Saver()
saver.save(sess, checkpoint_prefix, global_step=0,
latest_filename=checkpoint_state_name)
@ -65,30 +62,36 @@ class CheckpointsTest(tf.test.TestCase):
def testNoCheckpoints(self):
checkpoint_dir = self.get_temp_dir() + "/no_checkpoints"
with self.assertRaises(tf.errors.OpError):
self.assertAllEqual(checkpoints.load_variable(checkpoint_dir, "var1"), [])
self.assertAllEqual(tf.contrib.framework.load_variable(
checkpoint_dir, "var1"), [])
def testNoTensor(self):
checkpoint_dir = self.get_temp_dir()
with self.test_session() as session:
_, _, _, _ = _create_checkpoints(session, checkpoint_dir)
with self.assertRaises(tf.errors.OpError):
self.assertAllEqual(checkpoints.load_variable(checkpoint_dir, "var5"), [])
self.assertAllEqual(tf.contrib.framework.load_variable(
checkpoint_dir, "var5"), [])
def testGetTensor(self):
checkpoint_dir = self.get_temp_dir()
with self.test_session() as session:
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
self.assertAllEqual(checkpoints.load_variable(checkpoint_dir, "var1"), v1)
self.assertAllEqual(checkpoints.load_variable(checkpoint_dir, "var2"), v2)
self.assertAllEqual(checkpoints.load_variable(checkpoint_dir, "var3"), v3)
self.assertAllEqual(tf.contrib.framework.load_variable(
checkpoint_dir, "var1"), v1)
self.assertAllEqual(tf.contrib.framework.load_variable(
checkpoint_dir, "var2"), v2)
self.assertAllEqual(tf.contrib.framework.load_variable(
checkpoint_dir, "var3"), v3)
self.assertAllEqual(
checkpoints.load_variable(checkpoint_dir, "useful_scope/var4"), v4)
tf.contrib.framework.load_variable(
checkpoint_dir, "useful_scope/var4"), v4)
def testGetAllVariables(self):
checkpoint_dir = self.get_temp_dir()
with self.test_session() as session:
_create_checkpoints(session, checkpoint_dir)
self.assertEqual(checkpoints.list_variables(checkpoint_dir),
self.assertEqual(tf.contrib.framework.list_variables(checkpoint_dir),
[("useful_scope/var4", [9, 9]),
("var1", [1, 10]),
("var2", [10, 10]),
@ -110,13 +113,13 @@ class CheckpointsTest(tf.test.TestCase):
my4 = tf.get_variable("var4", [9, 9])
my3 = tf.get_variable("my3", [100, 100])
checkpoints.init_from_checkpoint(checkpoint_dir, {
"some_scope/my1": "var1",
"some_scope/some_other_scope/other_useful_scope/": "useful_scope/",
tf.contrib.framework.init_from_checkpoint(checkpoint_dir, {
"var1": "some_scope/my1",
"useful_scope/": "some_scope/some_other_scope/other_useful_scope/",
})
checkpoints.init_from_checkpoint(checkpoint_dir, {
"some_scope/some_other_scope/my2": "var2",
my3: "var3",
tf.contrib.framework.init_from_checkpoint(checkpoint_dir, {
"var2": "some_scope/some_other_scope/my2",
"var3": my3,
})
session.run(tf.initialize_all_variables())
@ -143,8 +146,8 @@ class CheckpointsTest(tf.test.TestCase):
with tf.variable_scope("useful_scope"):
my4 = tf.get_variable("var4", [9, 9])
checkpoints.init_from_checkpoint(checkpoint_dir, {
"some_scope/": "/",
tf.contrib.framework.init_from_checkpoint(checkpoint_dir, {
"/": "some_scope/",
})
session.run(tf.initialize_all_variables())
@ -162,23 +165,40 @@ class CheckpointsTest(tf.test.TestCase):
with tf.Graph().as_default() as g:
with self.test_session(graph=g) as session:
with tf.variable_scope("some_scope"):
# TODO(ipolosukhin): Enable this when get_variable partitioning works.
# Currently get_variable with partitioner doesn't return Variable,
# but returns a concat op.
# my1 = tf.get_variable(
# "my1", [100, 100],
# partitioner=tf.variable_axis_size_partitioner(axis=0,
# max_shard_bytes=100))
my1 = tf.create_partitioned_variables(
shape=[100, 100], slicing=[5, 1], name="my1",
initializer=tf.truncated_normal_initializer(0.5))
my1 = tf.get_variable(
name="my1",
shape=[100, 100],
initializer=tf.truncated_normal_initializer(0.5),
partitioner=tf.min_max_variable_partitioner(
max_partitions=5, axis=0, min_slice_size=8 << 10))
my1_var_list = my1._get_variable_list()
checkpoints.init_from_checkpoint(checkpoint_dir, {
"some_scope/my1": "var1",
tf.contrib.framework.init_from_checkpoint(checkpoint_dir, {
"var1": "some_scope/my1",
})
session.run(tf.initialize_all_variables())
my1_values = session.run(my1)
my1_values = session.run(my1_var_list)
self.assertAllEqual(my1_values, v1)
# New graph and session.
with tf.Graph().as_default() as g:
with self.test_session(graph=g) as session:
with tf.variable_scope("some_scope"):
my1 = tf.get_variable(
name="my1",
shape=[100, 100],
initializer=tf.truncated_normal_initializer(0.5),
partitioner=tf.min_max_variable_partitioner(
max_partitions=5, axis=0, min_slice_size=8 << 10))
my1_var_list = my1._get_variable_list()
tf.contrib.framework.init_from_checkpoint(checkpoint_dir, {
"var1": my1_var_list,
})
session.run(tf.initialize_all_variables())
my1_values = session.run(my1_var_list)
self.assertAllEqual(my1_values, v1)
def testInitFromCheckpointMissing(self):
@ -196,33 +216,33 @@ class CheckpointsTest(tf.test.TestCase):
# No directory.
with self.assertRaises(tf.errors.OpError):
checkpoints.init_from_checkpoint("no_dir", {
"some_scope/my1": "var1"})
tf.contrib.framework.init_from_checkpoint("no_dir", {
"var1": "some_scope/my1"})
# No variable in checkpoint.
with self.assertRaises(ValueError):
checkpoints.init_from_checkpoint(checkpoint_dir, {
"some_scope/my1": "no_var"})
tf.contrib.framework.init_from_checkpoint(checkpoint_dir, {
"no_var": "some_scope/my1"})
# No variable in the graph.
with self.assertRaises(ValueError):
checkpoints.init_from_checkpoint(checkpoint_dir, {
"some_scope/no_var": "var3"})
tf.contrib.framework.init_from_checkpoint(checkpoint_dir, {
"var3": "some_scope/no_var"})
# Shape mismatch.
with self.assertRaises(ValueError):
checkpoints.init_from_checkpoint(checkpoint_dir, {
"some_scope/my1": "var1"})
tf.contrib.framework.init_from_checkpoint(checkpoint_dir, {
"var1": "some_scope/my1"})
# Variables 'my1' and 'my2' are missing in given checkpoint scope.
with self.assertRaises(ValueError):
checkpoints.init_from_checkpoint(checkpoint_dir, {
"some_scope/": "useful_scope/"})
tf.contrib.framework.init_from_checkpoint(checkpoint_dir, {
"useful_scope/": "some_scope/"})
# Mapping is not to scope name.
with self.assertRaises(ValueError):
checkpoints.init_from_checkpoint(checkpoint_dir, {
"some_scope/": "useful_scope"})
tf.contrib.framework.init_from_checkpoint(checkpoint_dir, {
"useful_scope": "some_scope/"})
if __name__ == "__main__":
tf.test.main()

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import inspect
import re
from tensorflow.python.platform import tf_logging as logging
@ -34,43 +36,77 @@ def _get_qualified_name(function):
return function.__name__
def _add_deprecation_to_docstring(doc, date, instructions):
def _add_deprecation_to_docstring(
doc, instructions, no_doc_str, suffix_str, notice):
"""Adds a deprecation notice to a docstring."""
lines = doc.splitlines()
if not doc:
lines = [no_doc_str]
else:
lines = doc.splitlines()
lines[0] += ' ' + suffix_str
lines[0] += ' (deprecated)'
notice = [
'',
'THIS FUNCTION IS DEPRECATED. It will be removed after %s.' % date,
'Instructions for updating:',
'%s' % instructions,
]
notice = [''] + notice + [instructions]
if len(lines) > 1:
# Make sure that we keep our distance from the main body
if lines[1].strip():
notice += ['']
notice.append('')
lines = [lines[0]] + notice + lines[1:]
lines[1:1] = notice
else:
lines += notice
return '\n'.join(lines)
def _add_deprecated_function_notice_to_docstring(doc, date, instructions):
"""Adds a deprecation notice to a docstring for deprecated functions."""
return _add_deprecation_to_docstring(
doc, instructions,
'DEPRECATED FUNCTION',
'(deprecated)', [
'THIS FUNCTION IS DEPRECATED. It will be removed after %s.' % date,
'Instructions for updating:'])
def _add_deprecated_arg_notice_to_docstring(doc, date, instructions):
"""Adds a deprecation notice to a docstring for deprecated arguments."""
return _add_deprecation_to_docstring(
doc, instructions,
'DEPRECATED FUNCTION ARGUMENTS',
'(deprecated arguments)', [
'SOME ARGUMENTS ARE DEPRECATED. '
'They will be removed after %s.' % date,
'Instructions for updating:'])
def _validate_deprecation_args(date, instructions):
if not date:
raise ValueError('Tell us what date this will be deprecated!')
if not re.match(r'20\d\d-[01]\d-[0123]\d', date):
raise ValueError('Date must be YYYY-MM-DD.')
if not instructions:
raise ValueError('Don\'t deprecate things without conversion instructions!')
def _validate_callable(func, decorator_name):
if not hasattr(func, '__call__'):
raise ValueError(
'%s is not a function. If this is a property, '
'apply @%s after @property.' % (func, decorator_name))
def deprecated(date, instructions):
"""Decorator for marking functions or methods deprecated.
This decorator adds a deprecation warning to a function's docstring. It has
the following format:
This decorator logs a deprecation warning whenever the decorated function is
called. It has the following format:
<function> (from <module>) is deprecated and will be removed after <date>.
Instructions for updating:
<instructions>
whenever the decorated function is called. <function> will include the class
name if it is a method.
<function> will include the class name if it is a method.
It also edits the docstring of the function: ' (deprecated)' is appended
to the first line of the docstring and a deprecation notice is prepended
@ -88,24 +124,73 @@ def deprecated(date, instructions):
Raises:
ValueError: If date is not in ISO 8601 format, or instructions are empty.
"""
if not date:
raise ValueError('Tell us what date this will be deprecated!')
if not re.match(r'20\d\d-[01]\d-[0123]\d', date):
raise ValueError('Date must be YYYY-MM-DD.')
if not instructions:
raise ValueError('Don\'t deprecate things without conversion instructions!')
_validate_deprecation_args(date, instructions)
def deprecated_wrapper(func):
"""Deprecation wrapper."""
_validate_callable(func, 'deprecated')
@functools.wraps(func)
def new_func(*args, **kwargs):
logging.warn('%s (from %s) is deprecated and will be removed after %s.\n'
'Instructions for updating:\n%s',
_get_qualified_name(func), func.__module__,
date, instructions)
logging.warning(
'%s (from %s) is deprecated and will be removed after %s.\n'
'Instructions for updating:\n%s',
_get_qualified_name(func), func.__module__, date, instructions)
return func(*args, **kwargs)
new_func.__name__ = func.__name__
new_func.__doc__ = _add_deprecation_to_docstring(func.__doc__, date,
instructions)
new_func.__dict__.update(func.__dict__)
new_func.__doc__ = _add_deprecated_function_notice_to_docstring(
func.__doc__, date, instructions)
return new_func
return deprecated_wrapper
def deprecated_arg_values(date, instructions, **deprecated_kwargs):
"""Decorator for marking specific function argument values as deprecated.
This decorator logs a deprecation warning whenever the decorated function is
called with the deprecated argument values. It has the following format:
Calling <function> (from <module>) with <arg>=<value> is deprecated and
will be removed after <date>. Instructions for updating:
<instructions>
<function> will include the class name if it is a method.
It also edits the docstring of the function: ' (deprecated arguments)' is
appended to the first line of the docstring and a deprecation notice is
prepended to the rest of the docstring.
Args:
date: String. The date the function is scheduled to be removed. Must be
ISO 8601 (YYYY-MM-DD).
instructions: String. Instructions on how to update code using the
deprecated function.
**deprecated_kwargs: The deprecated argument values.
Returns:
Decorated function or method.
Raises:
ValueError: If date is not in ISO 8601 format, or instructions are empty.
"""
_validate_deprecation_args(date, instructions)
if not deprecated_kwargs:
raise ValueError('Specify which argument values are deprecated.')
def deprecated_wrapper(func):
"""Deprecation decorator."""
_validate_callable(func, 'deprecated_arg_values')
@functools.wraps(func)
def new_func(*args, **kwargs):
"""Deprecation wrapper."""
named_args = inspect.getcallargs(func, *args, **kwargs)
for arg_name, arg_value in deprecated_kwargs.items():
if arg_name in named_args and named_args[arg_name] == arg_value:
logging.warning(
'Calling %s (from %s) with %s=%s is deprecated and will be '
'removed after %s.\nInstructions for updating:\n%s',
_get_qualified_name(func), func.__module__,
arg_name, arg_value, date, instructions)
return func(*args, **kwargs)
new_func.__doc__ = _add_deprecated_arg_notice_to_docstring(
func.__doc__, date, instructions)
return new_func
return deprecated_wrapper

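A brief sketch of both decorators in use; the function names and dates are made up for illustration:

```python
from tensorflow.contrib.framework.python.framework import deprecation


@deprecation.deprecated('2016-12-31', 'Use new_fn instead.')
def old_fn(x):
  """Adds one."""
  return x + 1


@deprecation.deprecated_arg_values(
    '2016-12-31', 'Stop passing fast=True.', fast=True)
def flexible_fn(x, fast=False):
  """Doubles x."""
  return x * 2


old_fn(1)                  # Logs a deprecation warning on every call.
flexible_fn(1, fast=True)  # Logs only when the deprecated value is passed.
flexible_fn(1)             # No warning.
```

Both docstrings are rewritten in place, so `help(old_fn)` shows the deprecation notice without the function ever being called.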
View File

@ -0,0 +1,488 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""tensor_util tests."""
# pylint: disable=unused-import
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.framework.python.framework import deprecation
from tensorflow.python.platform import tf_logging as logging
class DeprecationTest(tf.test.TestCase):
def _assert_subset(self, expected_subset, actual_set):
self.assertTrue(
actual_set.issuperset(expected_subset),
msg="%s is not a superset of %s." % (actual_set, expected_subset))
def test_deprecated_illegal_args(self):
instructions = "This is how you update..."
with self.assertRaisesRegexp(ValueError, "date"):
deprecation.deprecated(None, instructions)
with self.assertRaisesRegexp(ValueError, "date"):
deprecation.deprecated("", instructions)
with self.assertRaisesRegexp(ValueError, "YYYY-MM-DD"):
deprecation.deprecated("07-04-2016", instructions)
date = "2016-07-04"
with self.assertRaisesRegexp(ValueError, "instructions"):
deprecation.deprecated(date, None)
with self.assertRaisesRegexp(ValueError, "instructions"):
deprecation.deprecated(date, "")
@tf.test.mock.patch.object(logging, "warning", autospec=True)
def test_static_fn_with_doc(self, mock_warning):
date = "2016-07-04"
instructions = "This is how you update..."
@deprecation.deprecated(date, instructions)
def _fn(arg0, arg1):
"""fn doc.
Args:
arg0: Arg 0.
arg1: Arg 1.
Returns:
Sum of args.
"""
return arg0 + arg1
# Assert function docs are properly updated.
self.assertEqual("_fn", _fn.__name__)
self.assertEqual(
"fn doc. (deprecated)"
"\n"
"\nTHIS FUNCTION IS DEPRECATED. It will be removed after %s."
"\nInstructions for updating:\n%s"
"\n"
"\n Args:"
"\n arg0: Arg 0."
"\n arg1: Arg 1."
"\n"
"\n Returns:"
"\n Sum of args."
"\n " % (date, instructions),
_fn.__doc__)
# Assert calling new fn issues log warning.
self.assertEqual(3, _fn(1, 2))
self.assertEqual(1, mock_warning.call_count)
(args, _) = mock_warning.call_args
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
self._assert_subset(set([date, instructions]), set(args[1:]))
@tf.test.mock.patch.object(logging, "warning", autospec=True)
def test_static_fn_with_one_line_doc(self, mock_warning):
date = "2016-07-04"
instructions = "This is how you update..."
@deprecation.deprecated(date, instructions)
def _fn(arg0, arg1):
"""fn doc."""
return arg0 + arg1
# Assert function docs are properly updated.
self.assertEqual("_fn", _fn.__name__)
self.assertEqual(
"fn doc. (deprecated)"
"\n"
"\nTHIS FUNCTION IS DEPRECATED. It will be removed after %s."
"\nInstructions for updating:\n%s" % (date, instructions),
_fn.__doc__)
# Assert calling new fn issues log warning.
self.assertEqual(3, _fn(1, 2))
self.assertEqual(1, mock_warning.call_count)
(args, _) = mock_warning.call_args
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
self._assert_subset(set([date, instructions]), set(args[1:]))
@tf.test.mock.patch.object(logging, "warning", autospec=True)
def test_static_fn_no_doc(self, mock_warning):
date = "2016-07-04"
instructions = "This is how you update..."
@deprecation.deprecated(date, instructions)
def _fn(arg0, arg1):
return arg0 + arg1
# Assert function docs are properly updated.
self.assertEqual("_fn", _fn.__name__)
self.assertEqual(
"DEPRECATED FUNCTION"
"\n"
"\nTHIS FUNCTION IS DEPRECATED. It will be removed after %s."
"\nInstructions for updating:"
"\n%s" % (date, instructions),
_fn.__doc__)
# Assert calling new fn issues log warning.
self.assertEqual(3, _fn(1, 2))
self.assertEqual(1, mock_warning.call_count)
(args, _) = mock_warning.call_args
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
self._assert_subset(set([date, instructions]), set(args[1:]))
@tf.test.mock.patch.object(logging, "warning", autospec=True)
def test_instance_fn_with_doc(self, mock_warning):
date = "2016-07-04"
instructions = "This is how you update..."
class _Object(object):
def __init__(self):
pass
@deprecation.deprecated(date, instructions)
def _fn(self, arg0, arg1):
"""fn doc.
Args:
arg0: Arg 0.
arg1: Arg 1.
Returns:
Sum of args.
"""
return arg0 + arg1
# Assert function docs are properly updated.
self.assertEqual(
"fn doc. (deprecated)"
"\n"
"\nTHIS FUNCTION IS DEPRECATED. It will be removed after %s."
"\nInstructions for updating:\n%s"
"\n"
"\n Args:"
"\n arg0: Arg 0."
"\n arg1: Arg 1."
"\n"
"\n Returns:"
"\n Sum of args."
"\n " % (date, instructions),
getattr(_Object, "_fn").__doc__)
# Assert calling new fn issues log warning.
self.assertEqual(3, _Object()._fn(1, 2))
self.assertEqual(1, mock_warning.call_count)
(args, _) = mock_warning.call_args
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
self._assert_subset(set([date, instructions]), set(args[1:]))
@tf.test.mock.patch.object(logging, "warning", autospec=True)
def test_instance_fn_with_one_line_doc(self, mock_warning):
date = "2016-07-04"
instructions = "This is how you update..."
class _Object(object):
def __init__(self):
pass
@deprecation.deprecated(date, instructions)
def _fn(self, arg0, arg1):
"""fn doc."""
return arg0 + arg1
# Assert function docs are properly updated.
self.assertEqual(
"fn doc. (deprecated)"
"\n"
"\nTHIS FUNCTION IS DEPRECATED. It will be removed after %s."
"\nInstructions for updating:\n%s" % (date, instructions),
getattr(_Object, "_fn").__doc__)
# Assert calling new fn issues log warning.
self.assertEqual(3, _Object()._fn(1, 2))
self.assertEqual(1, mock_warning.call_count)
(args, _) = mock_warning.call_args
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
self._assert_subset(set([date, instructions]), set(args[1:]))
@tf.test.mock.patch.object(logging, "warning", autospec=True)
def test_instance_fn_no_doc(self, mock_warning):
date = "2016-07-04"
instructions = "This is how you update..."
class _Object(object):
def __init__(self):
pass
@deprecation.deprecated(date, instructions)
def _fn(self, arg0, arg1):
return arg0 + arg1
# Assert function docs are properly updated.
self.assertEqual(
"DEPRECATED FUNCTION"
"\n"
"\nTHIS FUNCTION IS DEPRECATED. It will be removed after %s."
"\nInstructions for updating:"
"\n%s" % (date, instructions),
getattr(_Object, "_fn").__doc__)
# Assert calling new fn issues log warning.
self.assertEqual(3, _Object()._fn(1, 2))
self.assertEqual(1, mock_warning.call_count)
(args, _) = mock_warning.call_args
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
self._assert_subset(set([date, instructions]), set(args[1:]))
@tf.test.mock.patch.object(logging, "warning", autospec=True)
def test_prop_wrong_order(self, mock_warning):
with self.assertRaisesRegexp(
ValueError, "apply @deprecated after @property"):
# pylint: disable=unused-variable
class _Object(object):
def __init__(self):
pass
@deprecation.deprecated("2016-07-04", "Instructions.")
@property
def _prop(self):
return "prop_wrong_order"
@tf.test.mock.patch.object(logging, "warning", autospec=True)
def test_prop_with_doc(self, mock_warning):
date = "2016-07-04"
instructions = "This is how you update..."
class _Object(object):
def __init__(self):
pass
@property
@deprecation.deprecated(date, instructions)
def _prop(self):
"""prop doc.
Returns:
String.
"""
return "prop_with_doc"
# Assert function docs are properly updated.
self.assertEqual(
"prop doc. (deprecated)"
"\n"
"\nTHIS FUNCTION IS DEPRECATED. It will be removed after %s."
"\nInstructions for updating:"
"\n%s"
"\n"
"\n Returns:"
"\n String."
"\n " % (date, instructions),
getattr(_Object, "_prop").__doc__)
# Assert calling new fn issues log warning.
self.assertEqual("prop_with_doc", _Object()._prop)
self.assertEqual(1, mock_warning.call_count)
(args, _) = mock_warning.call_args
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
self._assert_subset(set([date, instructions]), set(args[1:]))
@tf.test.mock.patch.object(logging, "warning", autospec=True)
def test_prop_no_doc(self, mock_warning):
date = "2016-07-04"
instructions = "This is how you update..."
class _Object(object):
def __init__(self):
pass
@property
@deprecation.deprecated(date, instructions)
def _prop(self):
return "prop_no_doc"
# Assert function docs are properly updated.
self.assertEqual(
"DEPRECATED FUNCTION"
"\n"
"\nTHIS FUNCTION IS DEPRECATED. It will be removed after %s."
"\nInstructions for updating:"
"\n%s" % (date, instructions),
getattr(_Object, "_prop").__doc__)
# Assert calling new fn issues log warning.
self.assertEqual("prop_no_doc", _Object()._prop)
self.assertEqual(1, mock_warning.call_count)
(args, _) = mock_warning.call_args
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
self._assert_subset(set([date, instructions]), set(args[1:]))
class DeprecatedArgsTest(tf.test.TestCase):
def _assert_subset(self, expected_subset, actual_set):
self.assertTrue(
actual_set.issuperset(expected_subset),
msg="%s is not a superset of %s." % (actual_set, expected_subset))
def test_deprecated_illegal_args(self):
instructions = "This is how you update..."
with self.assertRaisesRegexp(ValueError, "date"):
deprecation.deprecated_arg_values(
None, instructions, deprecated=True)
with self.assertRaisesRegexp(ValueError, "date"):
deprecation.deprecated_arg_values(
"", instructions, deprecated=True)
with self.assertRaisesRegexp(ValueError, "YYYY-MM-DD"):
deprecation.deprecated_arg_values(
"07-04-2016", instructions, deprecated=True)
date = "2016-07-04"
with self.assertRaisesRegexp(ValueError, "instructions"):
deprecation.deprecated_arg_values(
date, None, deprecated=True)
with self.assertRaisesRegexp(ValueError, "instructions"):
deprecation.deprecated_arg_values(
date, "", deprecated=True)
with self.assertRaisesRegexp(ValueError, "argument", deprecated=True):
deprecation.deprecated_arg_values(
date, instructions)
@tf.test.mock.patch.object(logging, "warning", autospec=True)
def test_static_fn_with_doc(self, mock_warning):
date = "2016-07-04"
instructions = "This is how you update..."
@deprecation.deprecated_arg_values(date, instructions, deprecated=True)
def _fn(arg0, arg1, deprecated=True):
"""fn doc.
Args:
arg0: Arg 0.
arg1: Arg 1.
deprecated: Deprecated!
Returns:
Sum of args.
"""
return arg0 + arg1 if deprecated else arg1 + arg0
# Assert function docs are properly updated.
self.assertEqual("_fn", _fn.__name__)
self.assertEqual(
"fn doc. (deprecated arguments)"
"\n"
"\nSOME ARGUMENTS ARE DEPRECATED. They will be removed after %s."
"\nInstructions for updating:\n%s"
"\n"
"\n Args:"
"\n arg0: Arg 0."
"\n arg1: Arg 1."
"\n deprecated: Deprecated!"
"\n"
"\n Returns:"
"\n Sum of args."
"\n " % (date, instructions),
_fn.__doc__)
# Assert calling new fn with non-deprecated value logs nothing.
self.assertEqual(3, _fn(1, 2, deprecated=False))
self.assertEqual(0, mock_warning.call_count)
# Assert calling new fn with deprecated value issues log warning.
self.assertEqual(3, _fn(1, 2, deprecated=True))
self.assertEqual(1, mock_warning.call_count)
(args, _) = mock_warning.call_args
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
self._assert_subset(set([date, instructions]), set(args[1:]))
# Assert calling new fn with default deprecated value issues log warning.
self.assertEqual(3, _fn(1, 2))
self.assertEqual(2, mock_warning.call_count)
@tf.test.mock.patch.object(logging, "warning", autospec=True)
def test_static_fn_with_one_line_doc(self, mock_warning):
date = "2016-07-04"
instructions = "This is how you update..."
@deprecation.deprecated_arg_values(date, instructions, deprecated=True)
def _fn(arg0, arg1, deprecated=True):
"""fn doc."""
return arg0 + arg1 if deprecated else arg1 + arg0
# Assert function docs are properly updated.
self.assertEqual("_fn", _fn.__name__)
self.assertEqual(
"fn doc. (deprecated arguments)"
"\n"
"\nSOME ARGUMENTS ARE DEPRECATED. They will be removed after %s."
"\nInstructions for updating:\n%s" % (date, instructions),
_fn.__doc__)
# Assert calling new fn with non-deprecated value logs nothing.
self.assertEqual(3, _fn(1, 2, deprecated=False))
self.assertEqual(0, mock_warning.call_count)
# Assert calling new fn with deprecated value issues log warning.
self.assertEqual(3, _fn(1, 2, deprecated=True))
self.assertEqual(1, mock_warning.call_count)
(args, _) = mock_warning.call_args
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
self._assert_subset(set([date, instructions]), set(args[1:]))
# Assert calling new fn with default deprecated value issues log warning.
self.assertEqual(3, _fn(1, 2))
self.assertEqual(2, mock_warning.call_count)
@tf.test.mock.patch.object(logging, "warning", autospec=True)
def test_static_fn_no_doc(self, mock_warning):
date = "2016-07-04"
instructions = "This is how you update..."
@deprecation.deprecated_arg_values(date, instructions, deprecated=True)
def _fn(arg0, arg1, deprecated=True):
return arg0 + arg1 if deprecated else arg1 + arg0
# Assert function docs are properly updated.
self.assertEqual("_fn", _fn.__name__)
self.assertEqual(
"DEPRECATED FUNCTION ARGUMENTS"
"\n"
"\nSOME ARGUMENTS ARE DEPRECATED. They will be removed after %s."
"\nInstructions for updating:"
"\n%s" % (date, instructions),
_fn.__doc__)
# Assert calling new fn with non-deprecated value logs nothing.
self.assertEqual(3, _fn(1, 2, deprecated=False))
self.assertEqual(0, mock_warning.call_count)
# Assert calling new fn issues log warning.
self.assertEqual(3, _fn(1, 2, deprecated=True))
self.assertEqual(1, mock_warning.call_count)
(args, _) = mock_warning.call_args
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
self._assert_subset(set([date, instructions]), set(args[1:]))
# Assert calling new fn with default deprecated value issues log warning.
self.assertEqual(3, _fn(1, 2))
self.assertEqual(2, mock_warning.call_count)
if __name__ == "__main__":
tf.test.main()

View File

@ -27,6 +27,7 @@ from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import input as input_ops
from tensorflow.python.training import queue_runner
@ -34,10 +35,8 @@ __all__ = ['stratified_sample',
'stratified_sample_unknown_dist',]
# TODO(joelshor): Use an exponential-moving-average to estimate the initial
# class distribution and remove the requirement that it be provided.
def stratified_sample(tensors, labels, init_probs, target_probs, batch_size,
enqueue_many=False, queue_capacity=16,
def stratified_sample(tensors, labels, target_probs, batch_size,
init_probs=None, enqueue_many=False, queue_capacity=16,
threads_per_queue=1, name=None):
"""Stochastically creates batches based on per-class probabilities.
@ -52,11 +51,12 @@ def stratified_sample(tensors, labels, init_probs, target_probs, batch_size,
batch, according to enqueue_many.
labels: Tensor for label of data. Label is a single integer or a batch,
depending on enqueue_many. It is not a one-hot vector.
init_probs: Class proportions in the data. An object whose type has a
registered Tensor conversion function.
target_probs: Target class proportions in batch. An object whose type has a
registered Tensor conversion function.
batch_size: Size of batch to be returned.
init_probs: Class proportions in the data. An object whose type has a
registered Tensor conversion function, or `None` for estimating the
initial distribution.
enqueue_many: Bool. If true, interpret input tensors as having a batch
dimension.
queue_capacity: Capacity of the large queue that holds input examples.
@ -81,10 +81,9 @@ def stratified_sample(tensors, labels, init_probs, target_probs, batch_size,
data, label = data_provider.Get(['data', 'label'])
# Get stratified batch according to per-class probabilities.
init_probs = [1.0/NUM_CLASSES for _ in range(NUM_CLASSES)]
target_probs = [...distribution you want...]
[data_batch], labels = tf.contrib.framework.sampling_ops.stratified_sample(
[data], label, init_probs, target_probs)
[data], label, target_probs)
# Run batch through network.
...
@ -92,22 +91,34 @@ def stratified_sample(tensors, labels, init_probs, target_probs, batch_size,
with ops.op_scope(tensors + [labels], name, 'stratified_sample'):
tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensors)
labels = ops.convert_to_tensor(labels)
init_probs = ops.convert_to_tensor(init_probs, dtype=dtypes.float32)
target_probs = ops.convert_to_tensor(target_probs, dtype=dtypes.float32)
# Reduce the case of a single example to that of a batch of size 1.
if not enqueue_many:
tensor_list = [array_ops.expand_dims(tensor, 0) for tensor in tensor_list]
labels = array_ops.expand_dims(labels, 0)
# If `init_probs` is `None`, set up online estimation of data distribution.
if init_probs is None:
# We use `target_probs` to get the number of classes, so its shape must be
# fully defined at graph construction time.
target_probs.get_shape().assert_is_fully_defined()
init_probs = _estimate_data_distribution(
labels, target_probs.get_shape().num_elements())
else:
init_probs = ops.convert_to_tensor(init_probs, dtype=dtypes.float32)
# Validate that input is consistent.
tensor_list, labels, [init_probs, target_probs] = _verify_input(
tensor_list, labels, [init_probs, target_probs])
# Check that all zero initial probabilities also have zero target
# probabilities.
assert_op = logging_ops.Assert(math_ops.reduce_all(math_ops.logical_or(
math_ops.not_equal(init_probs, 0),
math_ops.equal(target_probs, 0))), [init_probs, target_probs])
assert_op = logging_ops.Assert(
math_ops.reduce_all(math_ops.logical_or(
math_ops.not_equal(init_probs, 0),
math_ops.equal(target_probs, 0))),
['All classes with zero initial probability must also have zero target '
'probability: ', init_probs, target_probs])
init_probs = control_flow_ops.with_dependencies([assert_op], init_probs)
# Calculate acceptance sampling probabilities.
@ -212,6 +223,40 @@ def stratified_sample_unknown_dist(tensors, labels, probs, batch_size,
per_class_queues, probs, batch_size)
def _estimate_data_distribution(labels, num_classes):
"""Estimate data distribution as labels are seen."""
# Variable to track running count of classes. Add 1 to avoid division-by-zero,
# and to guarantee that calculation of acceptance probabilities is (mostly)
# correct.
num_examples_per_class_seen = variables.Variable(
initial_value=[1] * num_classes, trainable=False, name='class_count',
dtype=dtypes.int64)
# Update the class-count based on what labels are seen in batch.
num_examples_per_class_seen = num_examples_per_class_seen.assign_add(
math_ops.reduce_sum(array_ops.one_hot(labels, num_classes,
dtype=dtypes.int64), 0))
# Normalize count into a probability.
# NOTE: Without the `+= 0` line below, the test
# `testMultiThreadedEstimateDataDistribution` fails. The reason is that
# before this line, `num_examples_per_class_seen` is a Tensor that shares a
# buffer with an underlying `ref` object. When the `ref` is changed by another
# thread, `num_examples_per_class_seen` changes as well. Since this can happen
# in the middle of the normalization computation, we get probabilities that
# are very far from summing to one. Adding `+= 0` copies the contents of the
# tensor to a new buffer, which will be consistent from the start to the end
# of the normalization computation.
num_examples_per_class_seen += 0
init_prob_estimate = math_ops.truediv(
num_examples_per_class_seen,
math_ops.reduce_sum(num_examples_per_class_seen))
# Must return float32 (not float64) to agree with downstream `_verify_input`
# checks.
return math_ops.cast(init_prob_estimate, dtypes.float32)
def _verify_input(tensor_list, labels, probs_list):
"""Verify that batched inputs are well-formed."""
checked_probs_list = []

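With `init_probs` now optional, the initial class distribution can be estimated online from the labels seen; a sketch with an illustrative two-class toy stream (note the variable initialization that the running class count requires):

```python
import tensorflow as tf

# A toy stream: a random binary label and a data vector derived from it.
label = tf.cast(tf.round(tf.random_uniform([])), tf.int32)
data = tf.to_float(label) * tf.ones([3])
target_probs = [0.5, 0.5]  # must be fully defined to infer the class count

# Omitting init_probs switches on online estimation of the data distribution.
[data_batch], label_batch = (
    tf.contrib.framework.sampling_ops.stratified_sample(
        [data], label, target_probs, batch_size=16))

with tf.Session() as sess:
  # The running class-count variable must be initialized first.
  sess.run(tf.initialize_all_variables())
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(coord=coord)
  print(sess.run([data_batch, label_batch]))
  coord.request_stop()
  coord.join(threads)
```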
View File

@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import tf_logging as logging
class SamplingOpsTest(tf.test.TestCase):
@ -33,15 +34,22 @@ class SamplingOpsTest(tf.test.TestCase):
# Curry the rejection sampler so we can easily run the same tests on both
# stratified_sample and stratified_sample_unknown_dist.
def curried_sampler(val, lbls, probs, batch, enqueue_many=True):
def curried_sampler(tensors, labels, probs, batch_size, enqueue_many=True):
return tf.contrib.framework.sampling_ops.stratified_sample(
val, lbls, initial_p, probs, batch, enqueue_many=enqueue_many)
tensors=tensors,
labels=labels,
target_probs=probs,
batch_size=batch_size,
init_probs=initial_p,
enqueue_many=enqueue_many)
samplers = [
tf.contrib.framework.sampling_ops.stratified_sample_unknown_dist,
curried_sampler,
]
for sampler in samplers:
logging.info('Now testing `%s`', sampler.__class__.__name__)
# Label must have only batch dimension if enqueue_many is True.
with self.assertRaises(ValueError):
sampler(val, tf.zeros([]), probs, batch_size, enqueue_many=True)
@ -70,20 +78,21 @@ class SamplingOpsTest(tf.test.TestCase):
# Probabilities shape must be fully defined.
with self.assertRaises(ValueError):
sampler(val, label, tf.placeholder(tf.float32, shape=[None]),
batch_size)
sampler(
val, label, tf.placeholder(
tf.float32, shape=[None]), batch_size)
# In the rejection sampling case, make sure that probability lengths are
# the same.
with self.assertRaises(ValueError):
tf.contrib.framework.sampling_ops.stratified_sample(
val, label, [.2] * 5, [.1] * 10, batch_size)
val, label, [.1] * 10, batch_size, init_probs=[.2] * 5)
# In the rejection sampling case, make sure that zero initial probability
# classes also have zero target probability.
with self.assertRaises(ValueError):
tf.contrib.framework.sampling_ops.stratified_sample(
val, label, [0, .5, .5], [.2, .4, .4], batch_size)
val, label, [.2, .4, .4], batch_size, init_probs=[0, .5, .5])
# Probabilities must be 1D.
with self.assertRaises(ValueError):
@ -116,15 +125,17 @@ class SamplingOpsTest(tf.test.TestCase):
# Run session that should fail.
with self.test_session() as sess:
with self.assertRaises(tf.errors.InvalidArgumentError):
sess.run([val_tf, lbl_tf], feed_dict={label_ph: illegal_label,
probs_ph: valid_probs})
sess.run([val_tf, lbl_tf],
feed_dict={label_ph: illegal_label,
probs_ph: valid_probs})
for illegal_prob in illegal_probs:
# Run session that should fail.
with self.test_session() as sess:
with self.assertRaises(tf.errors.InvalidArgumentError):
sess.run([prob_tf], feed_dict={label_ph: valid_labels,
probs_ph: illegal_prob})
sess.run([prob_tf],
feed_dict={label_ph: valid_labels,
probs_ph: illegal_prob})
def batchingBehaviorHelper(self, sampler):
batch_size = 20
@ -152,15 +163,14 @@ class SamplingOpsTest(tf.test.TestCase):
lbl_input_batch = tf.ones([], dtype=tf.int32)
probs = np.array([0, 1, 0, 0, 0])
batches = tf.contrib.framework.sampling_ops.stratified_sample(
val_input_batch, lbl_input_batch, probs, probs, batch_size)
val_input_batch, lbl_input_batch, probs, batch_size, init_probs=probs)
batches += tf.contrib.framework.sampling_ops.stratified_sample(
val_input_batch, lbl_input_batch, probs, probs, batch_size)
val_input_batch, lbl_input_batch, probs, batch_size, init_probs=probs)
batches += tf.contrib.framework.sampling_ops.stratified_sample_unknown_dist(
val_input_batch, lbl_input_batch, probs, batch_size)
batches += tf.contrib.framework.sampling_ops.stratified_sample_unknown_dist(
val_input_batch, lbl_input_batch, probs, batch_size)
summary_op = tf.merge_summary(tf.get_collection(
tf.GraphKeys.SUMMARIES))
summary_op = tf.merge_summary(tf.get_collection(tf.GraphKeys.SUMMARIES))
with self.test_session() as sess:
coord = tf.train.Coordinator()
@ -177,9 +187,15 @@ class SamplingOpsTest(tf.test.TestCase):
def testRejectionBatchingBehavior(self):
initial_p = [0, .3, 0, .7, 0]
def curried_sampler(val, lbls, probs, batch, enqueue_many=True):
return tf.contrib.framework.sampling_ops.stratified_sample(
val, lbls, initial_p, probs, batch, enqueue_many=enqueue_many)
val,
lbls,
probs,
batch,
init_probs=initial_p,
enqueue_many=enqueue_many)
self.batchingBehaviorHelper(curried_sampler)
@ -190,8 +206,7 @@ class SamplingOpsTest(tf.test.TestCase):
lbl2 = 3
# This cond allows the necessary class queues to be populated.
label = tf.cond(
tf.greater(.5, tf.random_uniform([])),
lambda: tf.constant(lbl1),
tf.greater(.5, tf.random_uniform([])), lambda: tf.constant(lbl1),
lambda: tf.constant(lbl2))
val = [np.array([1, 4]) * label]
probs = tf.placeholder(tf.float32, shape=[5])
@ -225,7 +240,7 @@ class SamplingOpsTest(tf.test.TestCase):
def testBatchDimensionNotRequired(self):
classes = 5
# Probs must be a tensor, since we pass it directly to _verify_input.
probs = tf.constant([1.0/classes] * classes)
probs = tf.constant([1.0 / classes] * classes)
# Make sure that these vals/labels pairs don't throw any runtime exceptions.
legal_input_pairs = [
@ -243,16 +258,17 @@ class SamplingOpsTest(tf.test.TestCase):
# Run graph to make sure there are no shape-related runtime errors.
for vals, labels in legal_input_pairs:
with self.test_session() as sess:
sess.run([val_tf, labels_tf], feed_dict={vals_ph: vals,
labels_ph: labels})
sess.run([val_tf, labels_tf],
feed_dict={vals_ph: vals,
labels_ph: labels})
def dataListHelper(self, sampler):
batch_size = 20
val_input_batch = [tf.zeros([2, 3, 4]), tf.ones([2, 4]), tf.ones(2) * 3]
lbl_input_batch = tf.ones([], dtype=tf.int32)
probs = np.array([0, 1, 0, 0, 0])
val_list, lbls = sampler(
val_input_batch, lbl_input_batch, probs, batch_size)
val_list, lbls = sampler(val_input_batch, lbl_input_batch, probs,
batch_size)
# Check output shapes.
self.assertTrue(isinstance(val_list, list))
@ -277,9 +293,16 @@ class SamplingOpsTest(tf.test.TestCase):
def testRejectionDataListInput(self):
initial_p = [0, 1, 0, 0, 0]
def curried_sampler(val, lbls, probs, batch, enqueue_many=False):
return tf.contrib.framework.sampling_ops.stratified_sample(
val, lbls, initial_p, probs, batch, enqueue_many=enqueue_many)
val,
lbls,
probs,
batch,
init_probs=initial_p,
enqueue_many=enqueue_many)
self.dataListHelper(curried_sampler)
def normalBehaviorHelper(self, sampler):
@ -289,8 +312,7 @@ class SamplingOpsTest(tf.test.TestCase):
lbl2 = 3
# This cond allows the necessary class queues to be populated.
label = tf.cond(
tf.greater(.5, tf.random_uniform([])),
lambda: tf.constant(lbl1),
tf.greater(.5, tf.random_uniform([])), lambda: tf.constant(lbl1),
lambda: tf.constant(lbl2))
val = [np.array([1, 4]) * label]
probs = np.array([.8, 0, 0, .2, 0])
@ -302,6 +324,9 @@ class SamplingOpsTest(tf.test.TestCase):
data_l = []
label_l = []
with self.test_session() as sess:
# Need to initialize variables that keep running total of classes seen.
tf.initialize_all_variables().run()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
@ -329,7 +354,7 @@ class SamplingOpsTest(tf.test.TestCase):
# is fixed, for a given implementation, this test will pass or fail 100% of
# the time. This use of assertNear is to cover cases where someone changes
# an implementation detail, which would cause the random behavior to differ.
self.assertNear(actual_lbl, expected_label, 3*lbl_std_dev_of_mean)
self.assertNear(actual_lbl, expected_label, 3 * lbl_std_dev_of_mean)
def testNormalBehavior(self):
self.normalBehaviorHelper(
@ -337,10 +362,26 @@ class SamplingOpsTest(tf.test.TestCase):
def testRejectionNormalBehavior(self):
initial_p = [.7, 0, 0, .3, 0]
def curried_sampler(val, lbls, probs, batch, enqueue_many=False):
return tf.contrib.framework.sampling_ops.stratified_sample(
val, lbls, initial_p, probs, batch, enqueue_many=enqueue_many)
val,
lbls,
probs,
batch,
init_probs=initial_p,
enqueue_many=enqueue_many)
self.normalBehaviorHelper(curried_sampler)
def testRejectionNormalBehaviorWithOnlineInitPEstimate(self):
def curried_sampler(val, lbls, probs, batch, enqueue_many=False):
return tf.contrib.framework.sampling_ops.stratified_sample(
val, lbls, probs, batch, init_probs=None, enqueue_many=enqueue_many)
self.normalBehaviorHelper(curried_sampler)
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,65 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unused-import
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
class SamplingOpsThreadingTest(tf.test.TestCase):
def testMultiThreadedEstimateDataDistribution(self):
num_classes = 10
# Set up graph.
tf.set_random_seed(1234)
label = tf.cast(tf.round(tf.random_uniform([1]) * num_classes), tf.int32)
prob_estimate = tf.contrib.framework.sampling_ops._estimate_data_distribution( # pylint: disable=line-too-long
label, num_classes)
# Check that prob_estimate is well-behaved in a multithreaded context.
_, _, [prob_estimate] = tf.contrib.framework.sampling_ops._verify_input(
[], label, [prob_estimate])
# Use queues to run multiple threads over the graph, each of which
# fetches `prob_estimate`.
queue = tf.FIFOQueue(
capacity=25,
dtypes=[prob_estimate.dtype],
shapes=[prob_estimate.get_shape()])
enqueue_op = queue.enqueue([prob_estimate])
tf.train.add_queue_runner(tf.train.QueueRunner(queue, [enqueue_op] * 25))
out_tensor = queue.dequeue()
# Run the multi-threaded session.
with self.test_session() as sess:
# Need to initialize variables that keep running total of classes seen.
tf.initialize_all_variables().run()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for _ in range(25):
sess.run([out_tensor])
coord.request_stop()
coord.join(threads)
if __name__ == '__main__':
tf.test.main()

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
REGISTER_OP("SparseFeatureCross")
@ -31,6 +32,12 @@ REGISTER_OP("SparseFeatureCross")
.Attr("dense_types: list({int64, string}) >= 0")
.Attr("out_type: {int64, string}")
.Attr("internal_type: {int64, string}")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->Matrix(c->UnknownDim(), 2));
c->set_output(1, c->Vector(c->UnknownDim()));
c->set_output(2, c->Vector(2));
return Status::OK();
})
.Doc(R"doc(
Generates sparse cross from a list of sparse tensors.

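The new shape function simply declares that the op emits the three components of a rank-2 `SparseTensor`; illustratively, under that reading:

```python
import tensorflow as tf

# The three outputs of SparseFeatureCross line up with a rank-2 SparseTensor:
indices = tf.placeholder(tf.int64, shape=[None, 2])  # output 0: Matrix(?, 2)
values = tf.placeholder(tf.string, shape=[None])     # output 1: Vector(?)
shape = tf.placeholder(tf.int64, shape=[2])          # output 2: Vector(2)
cross = tf.SparseTensor(indices, values, shape)
```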
View File

@ -75,6 +75,7 @@ import abc
import collections
import math
from tensorflow.contrib.framework.python.framework import checkpoint_utils
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.layers.python.layers import embedding_ops
from tensorflow.contrib.layers.python.ops import bucketization_op
@ -149,6 +150,7 @@ class _FeatureColumn(object):
raise ValueError("Calling an abstract method.")
# TODO(b/30410315): Support warm starting in all feature columns.
class _SparseColumn(_FeatureColumn,
collections.namedtuple("_SparseColumn",
["column_name", "is_integerized",
@ -191,35 +193,36 @@ class _SparseColumn(_FeatureColumn,
combiner="sum",
dtype=dtypes.string):
if is_integerized and bucket_size is None:
raise ValueError("bucket_size should be set if is_integerized=True. "
raise ValueError("bucket_size must be set if is_integerized is True. "
"column_name: {}".format(column_name))
if is_integerized and not dtype.is_integer:
raise ValueError("dtype should be an integer if is_integerized is True. "
"Column {}.".format(column_name))
raise ValueError("dtype must be an integer if is_integerized is True. "
"dtype: {}, column_name: {}.".format(dtype, column_name))
if bucket_size is None and lookup_config is None:
raise ValueError("one of bucket_size or lookup_config should be "
"set. column_name: {}".format(column_name))
raise ValueError("one of bucket_size or lookup_config must be set. "
"column_name: {}".format(column_name))
if bucket_size is not None and lookup_config:
raise ValueError("one and only one of bucket_size or lookup_config "
"should be set. column_name: {}".format(column_name))
"must be set. column_name: {}".format(column_name))
if bucket_size is not None and bucket_size < 2:
raise ValueError("bucket_size should be at least 2. "
"column_name: {}".format(column_name))
raise ValueError("bucket_size must be at least 2. "
"bucket_size: {}, column_name: {}".format(bucket_size,
column_name))
if ((lookup_config) and
(not isinstance(lookup_config, _SparseIdLookupConfig))):
raise TypeError(
"lookup_config should be an instance of _SparseIdLookupConfig. "
"lookup_config must be an instance of _SparseIdLookupConfig. "
"Given one is in type {} for column_name {}".format(
type(lookup_config), column_name))
if (lookup_config and lookup_config.vocabulary_file and
lookup_config.vocab_size is None):
raise ValueError("vocab_size should be defined. "
raise ValueError("vocab_size must be defined. "
"column_name: {}".format(column_name))
return super(_SparseColumn, cls).__new__(cls, column_name, is_integerized,
@ -260,8 +263,8 @@ class _SparseColumn(_FeatureColumn,
input_tensor,
weight_collections=None,
trainable=True):
raise ValueError("Column {} is not supported in DNN. "
"Please use embedding_column.".format(self))
raise ValueError("SparseColumn is not supported in DNN. "
"Please use embedding_column. column: {}".format(self))
def to_weighted_sum(self,
input_tensor,
@ -277,7 +280,7 @@ class _SparseColumn(_FeatureColumn,
initializer=init_ops.zeros_initializer,
combiner=self.combiner,
trainable=trainable,
name=self.name + "_weights")
name=self.name)
class _SparseColumnIntegerized(_SparseColumn):
@ -289,8 +292,8 @@ class _SparseColumnIntegerized(_SparseColumn):
combiner="sum",
dtype=dtypes.int64):
if not dtype.is_integer:
raise ValueError("dtype should be an integer. Given {}".format(
column_name))
raise ValueError("dtype must be an integer. "
"dtype: {}, column_name: {}".format(dtype, column_name))
return super(_SparseColumnIntegerized, cls).__new__(cls,
column_name,
@ -505,8 +508,8 @@ class _WeightedSparseColumn(_FeatureColumn, collections.namedtuple(
input_tensor,
weight_collections=None,
trainable=True):
raise ValueError("Column {} is not supported in DNN. "
"Please use embedding_column.".format(self))
raise ValueError("WeightedSparseColumn is not supported in DNN. "
"Please use embedding_column. column: {}".format(self))
def to_weighted_sum(self,
input_tensor,
@ -522,7 +525,7 @@ class _WeightedSparseColumn(_FeatureColumn, collections.namedtuple(
initializer=init_ops.zeros_initializer,
combiner=self.sparse_id_column.combiner,
trainable=trainable,
name=self.name + "_weights")
name=self.name)
def weighted_sparse_column(sparse_id_column,
@ -568,7 +571,8 @@ def weighted_sparse_column(sparse_id_column,
class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
"_EmbeddingColumn",
["sparse_id_column", "dimension", "combiner", "initializer"])):
["sparse_id_column", "dimension", "combiner", "initializer",
"ckpt_to_load_from", "tensor_name_in_ckpt"])):
"""Represents an embedding column.
Args:
@ -586,15 +590,33 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
variable initialization. If not specified, defaults to
`tf.truncated_normal_initializer` with mean 0.0 and standard deviation
1/sqrt(sparse_id_column.length).
ckpt_to_load_from: (Optional). String representing checkpoint name/pattern
to restore the column weights. Required if `tensor_name_in_ckpt` is not
None.
tensor_name_in_ckpt: (Optional). Name of the `Tensor` in the provided
checkpoint from which to restore the column weights. Required if
`ckpt_to_load_from` is not None.
Raises:
ValueError: if `initializer` is specified and is not callable. Also,
if only one of `ckpt_to_load_from` and `tensor_name_in_ckpt` is specified.
"""
def __new__(cls,
sparse_id_column,
dimension,
combiner="mean",
initializer=None):
initializer=None,
ckpt_to_load_from=None,
tensor_name_in_ckpt=None):
if initializer is not None and not callable(initializer):
raise ValueError("initializer must be callable if specified.")
raise ValueError("initializer must be callable if specified. "
"Embedding of column_name: {}".format(
sparse_id_column.name))
if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
raise ValueError("Must specify both `ckpt_to_load_from` and "
"`tensor_name_in_ckpt` or none of them.")
if initializer is None:
stddev = 1 / math.sqrt(sparse_id_column.length)
# TODO(b/25671353): Better initial value?
@ -602,7 +624,8 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
stddev=stddev)
return super(_EmbeddingColumn, cls).__new__(cls, sparse_id_column,
dimension, combiner,
initializer)
initializer, ckpt_to_load_from,
tensor_name_in_ckpt)
@property
def name(self):
@ -645,7 +668,7 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
input_tensor,
weight_collections=None,
trainable=True):
output, _ = _create_embedding_lookup(
output, embedding_weights = _create_embedding_lookup(
input_tensor=self.sparse_id_column.id_tensor(input_tensor),
weight_tensor=self.sparse_id_column.weight_tensor(input_tensor),
vocab_size=self.length,
@ -654,7 +677,14 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
initializer=self.initializer,
combiner=self.combiner,
trainable=trainable,
name=self.name + "_weights")
name=self.name)
if self.ckpt_to_load_from is not None:
weights_to_restore = embedding_weights
if len(embedding_weights) == 1:
weights_to_restore = embedding_weights[0]
checkpoint_utils.init_from_checkpoint(
self.ckpt_to_load_from,
{self.tensor_name_in_ckpt: weights_to_restore})
return output
# pylint: disable=unused-argument
@ -663,19 +693,22 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
num_outputs=1,
weight_collections=None,
trainable=True):
raise ValueError("Column {} is not supported in linear models. "
"Please use sparse_column.".format(self))
raise ValueError("EmbeddingColumn is not supported in linear models. "
"Please use sparse_column. column: {}".format(self))
def embedding_column(sparse_id_column,
dimension,
combiner="mean",
initializer=None):
initializer=None,
ckpt_to_load_from=None,
tensor_name_in_ckpt=None):
"""Creates an _EmbeddingColumn.
Args:
sparse_id_column: A _SparseColumn which is created by `sparse_column_with_*`
functions. Note that `combiner` defined in `sparse_id_column` is ignored.
or crossed_column functions. Note that `combiner` defined in
`sparse_id_column` is ignored.
dimension: An integer specifying dimension of the embedding.
combiner: A string specifying how to reduce if there are multiple entries
in a single row. Currently "mean", "sqrtn" and "sum" are supported. Each
@ -688,11 +721,18 @@ def embedding_column(sparse_id_column,
variable initialization. If not specified, defaults to
`tf.truncated_normal_initializer` with mean 0.0 and standard deviation
1/sqrt(sparse_id_column.length).
ckpt_to_load_from: (Optional). String representing checkpoint name/pattern
to restore the column weights. Required if `tensor_name_in_ckpt` is not
None.
tensor_name_in_ckpt: (Optional). Name of the `Tensor` in the provided
checkpoint from which to restore the column weights. Required if
`ckpt_to_load_from` is not None.
Returns:
An _EmbeddingColumn.
"""
return _EmbeddingColumn(sparse_id_column, dimension, combiner, initializer)
return _EmbeddingColumn(sparse_id_column, dimension, combiner, initializer,
ckpt_to_load_from, tensor_name_in_ckpt)
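A minimal usage sketch of the new checkpoint-restore arguments, mirroring `testInitEmbeddingColumnWeightsFromCkpt` later in this diff; the column name, checkpoint path, and tensor name below are hypothetical:

```python
import tensorflow as tf

sparse_col = tf.contrib.layers.sparse_column_with_hash_bucket(
    column_name="product_id", hash_bucket_size=1000)  # hypothetical column

# Plain embedding: weights are randomly initialized.
embedding = tf.contrib.layers.embedding_column(sparse_col, dimension=16)

# Warm-started embedding: weights are restored from an existing checkpoint.
# Both arguments must be set together, else __new__ raises ValueError.
pretrained = tf.contrib.layers.embedding_column(
    sparse_col,
    dimension=16,
    ckpt_to_load_from="/tmp/model.ckpt",  # hypothetical path
    tensor_name_in_ckpt="run_1/product_id_embedding_weights")  # hypothetical name
```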
class _HashedEmbeddingColumn(collections.namedtuple(
@ -707,7 +747,8 @@ class _HashedEmbeddingColumn(collections.namedtuple(
combiner="mean",
initializer=None):
if initializer is not None and not callable(initializer):
raise ValueError("initializer must be callable if specified.")
raise ValueError("initializer must be callable if specified. "
"column_name: {}".format(column_name))
if initializer is None:
stddev = 0.1
# TODO(b/25671353): Better initial value?
@ -733,7 +774,7 @@ class _HashedEmbeddingColumn(collections.namedtuple(
weight_collections=None,
trainable=True):
embeddings = _create_embeddings(
name=self.name + "_weights",
name=self.name,
shape=[self.size],
initializer=self.initializer,
dtype=dtypes.float32,
@ -778,10 +819,14 @@ def hashed_embedding_column(column_name,
"""
if (dimension < 1) or (size < 1):
raise ValueError("Dimension and size must be greater than 0.")
raise ValueError("Dimension and size must be greater than 0. "
"dimension: {}, size: {}, column_name: {}".format(
dimension, size, column_name))
if combiner not in ("mean", "sqrtn", "sum"):
raise ValueError("Combiner must be one of 'mean', 'sqrtn' or 'sum'.")
raise ValueError("Combiner must be one of 'mean', 'sqrtn' or 'sum'. "
"combiner: {}, column_name: {}".format(
combiner, column_name))
return _HashedEmbeddingColumn(column_name, size, dimension, combiner,
initializer)
@ -892,14 +937,18 @@ def real_valued_column(column_name,
"""
if not isinstance(dimension, int):
raise TypeError("dimension must be an integer")
raise TypeError("dimension must be an integer. "
"dimension: {}, column_name: {}".format(dimension,
column_name))
if dimension < 1:
raise ValueError("dimension must be greater than 0")
raise ValueError("dimension must be greater than 0. "
"dimension: {}, column_name: {}".format(dimension,
column_name))
if not (dtype.is_integer or dtype.is_floating):
raise ValueError("dtype is not convertible to tf.float32. Given {}".format(
dtype))
raise ValueError("dtype must be convertible to float. "
"dtype: {}, column_name: {}".format(dtype, column_name))
if default_value is None:
return _RealValuedColumn(column_name, dimension, default_value, dtype)
@ -920,9 +969,10 @@ def real_valued_column(column_name,
if isinstance(default_value, list):
if len(default_value) != dimension:
raise ValueError("The length of default_value is not equal to the "
"value of dimension. default_value is {}.".format(
default_value))
raise ValueError(
"The length of default_value must be equal to dimension. "
"default_value: {}, dimension: {}, column_name: {}".format(
default_value, dimension, column_name))
# Check if the values in the list are all integers or are convertible to
# floats.
is_list_all_int = True
@ -943,8 +993,9 @@ def real_valued_column(column_name,
default_value = [float(v) for v in default_value]
return _RealValuedColumn(column_name, dimension, default_value, dtype)
raise TypeError("default_value is not compatible with dtype. "
"default_value is {}.".format(default_value))
raise TypeError("default_value must be compatible with dtype. "
"default_value: {}, dtype: {}, column_name: {}".format(
default_value, dtype, column_name))
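A quick sketch of how these validations play out, with illustrative column names; the behavior mirrors `FeatureColumnTest` later in this diff:

```python
import tensorflow as tf

# A scalar default is broadcast to `dimension` and cast to float.
c1 = tf.contrib.layers.real_valued_column("age", default_value=2)
assert list(c1.default_value) == [2.]

# A list default must match `dimension` exactly.
c2 = tf.contrib.layers.real_valued_column(
    "histogram", dimension=3, default_value=[2., 2., 2.])

# Each of these raises (TypeError or ValueError, as documented above):
# tf.contrib.layers.real_valued_column("bad", dimension=1.0)    # non-int dimension
# tf.contrib.layers.real_valued_column("bad", dtype=tf.string)  # non-float dtype
# tf.contrib.layers.real_valued_column("bad", dimension=3, default_value=[2.])
```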
class _BucketizedColumn(_FeatureColumn, collections.namedtuple(
@ -971,10 +1022,12 @@ class _BucketizedColumn(_FeatureColumn, collections.namedtuple(
def __new__(cls, source_column, boundaries):
if not isinstance(source_column, _RealValuedColumn):
raise TypeError(
"source_column should be an instance of _RealValuedColumn.")
"source_column must be an instance of _RealValuedColumn. "
"source_column: {}".format(source_column))
if not isinstance(boundaries, list) or not boundaries:
raise ValueError("boundaries must be a list and it should not be empty.")
raise ValueError("boundaries must be a non-empty list. "
"boundaries: {}".format(boundaries))
# We allow bucket boundaries to be monotonically increasing
# (ie a[i+1] >= a[i]). When two bucket boundaries are the same, we
@ -986,7 +1039,8 @@ class _BucketizedColumn(_FeatureColumn, collections.namedtuple(
elif boundaries[i] < boundaries[i + 1]:
sanitized_boundaries.append(boundaries[i])
else:
raise ValueError("boundaries must be a sorted list")
raise ValueError("boundaries must be a sorted list. "
"boundaries: {}".format(boundaries))
sanitized_boundaries.append(boundaries[len(boundaries) - 1])
return super(_BucketizedColumn, cls).__new__(cls, source_column,
@ -1067,7 +1121,7 @@ class _BucketizedColumn(_FeatureColumn, collections.namedtuple(
initializer=init_ops.zeros_initializer,
combiner="sum",
trainable=trainable,
name=self.name + "_weights")
name=self.name)
def bucketized_column(source_column, boundaries):
@ -1087,7 +1141,8 @@ def bucketized_column(source_column, boundaries):
class _CrossedColumn(_FeatureColumn, collections.namedtuple(
"_CrossedColumn", ["columns", "hash_bucket_size", "combiner"])):
"_CrossedColumn", ["columns", "hash_bucket_size", "combiner",
"ckpt_to_load_from", "tensor_name_in_ckpt"])):
"""Represents a cross transformation also known as composition or union.
Instances of this class are immutable. It crosses given `columns`. Crossed
@ -1124,13 +1179,19 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple(
* "mean": do l1 normalization
* "sqrtn": do l2 normalization
For more information: `tf.embedding_lookup_sparse`.
ckpt_to_load_from: (Optional). String representing checkpoint name/pattern
to restore the column weights. Required if `tensor_name_in_ckpt` is not
None.
tensor_name_in_ckpt: (Optional). Name of the `Tensor` in the provided
checkpoint from which to restore the column weights. Required if
`ckpt_to_load_from` is not None.
Raises:
TypeError: if all items in columns are not an instance of _SparseColumn,
_CrossedColumn, or _BucketizedColumn or
hash_bucket_size is not an int.
ValueError: if hash_bucket_size is not > 1 or
len(columns) is not > 1.
ValueError: if hash_bucket_size is not > 1 or len(columns) is not > 1. Also,
if only one of `ckpt_to_load_from` and `tensor_name_in_ckpt` is specified.
"""
@staticmethod
@ -1138,26 +1199,36 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple(
return isinstance(column,
(_SparseColumn, _CrossedColumn, _BucketizedColumn))
def __new__(cls, columns, hash_bucket_size, combiner="sum"):
def __new__(cls, columns, hash_bucket_size, combiner="sum",
ckpt_to_load_from=None, tensor_name_in_ckpt=None):
for column in columns:
if not _CrossedColumn._is_crossable(column):
raise TypeError("columns should be a set of "
"_SparseColumn, _CrossedColumn, or _BucketizedColumn. "
"Column is {}".format(column))
raise TypeError("columns must be a set of _SparseColumn, "
"_CrossedColumn, or _BucketizedColumn instances. "
"column: {}".format(column))
if len(columns) < 2:
raise ValueError("columns should contain at least 2 elements.")
raise ValueError("columns must contain at least 2 elements. "
"columns: {}".format(columns))
if not isinstance(hash_bucket_size, int):
raise TypeError("hash_bucket_size should be an int.")
raise TypeError("hash_bucket_size must be an int. "
"hash_bucket_size: {}".format(hash_bucket_size))
if hash_bucket_size < 2:
raise ValueError("hash_bucket_size should be at least 2.")
raise ValueError("hash_bucket_size must be at least 2. "
"hash_bucket_size: {}".format(hash_bucket_size))
if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
raise ValueError("Must specify both `ckpt_to_load_from` and "
"`tensor_name_in_ckpt` or none of them.")
sorted_columns = sorted([column for column in columns],
key=lambda column: column.name)
return super(_CrossedColumn, cls).__new__(cls, tuple(sorted_columns),
hash_bucket_size, combiner)
hash_bucket_size, combiner,
ckpt_to_load_from,
tensor_name_in_ckpt)
@property
def name(self):
@ -1181,6 +1252,15 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple(
"""Returns a string which will be used as a key when we do sorting."""
return "{}".format(self)
def id_tensor(self, input_tensor):
"""Returns the id tensor from the given transformed input_tensor."""
return input_tensor
# pylint: disable=unused-argument
def weight_tensor(self, input_tensor):
"""Returns the weight tensor from the given transformed input_tensor."""
return None
def insert_transformed_feature(self, columns_to_tensors):
"""Handles cross transformation."""
@ -1215,15 +1295,15 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple(
input_tensor,
weight_collections=None,
trainable=True):
raise ValueError("Column {} is not supported in DNN. "
"Please use embedding_column.".format(self))
raise ValueError("CrossedColumn is not supported in DNN. "
"Please use embedding_column. column: {}".format(self))
def to_weighted_sum(self,
input_tensor,
num_outputs=1,
weight_collections=None,
trainable=True):
return _create_embedding_lookup(
output, embedding_weights = _create_embedding_lookup(
input_tensor=input_tensor,
weight_tensor=None,
vocab_size=self.length,
@ -1232,10 +1312,20 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple(
initializer=init_ops.zeros_initializer,
combiner=self.combiner,
trainable=trainable,
name=self.name + "_weights")
name=self.name)
if self.ckpt_to_load_from is not None:
weights_to_restore = embedding_weights
if len(embedding_weights) == 1:
weights_to_restore = embedding_weights[0]
checkpoint_utils.init_from_checkpoint(
self.ckpt_to_load_from,
{self.tensor_name_in_ckpt: weights_to_restore})
return output, embedding_weights
def crossed_column(columns, hash_bucket_size, combiner="sum"):
def crossed_column(columns, hash_bucket_size, combiner="sum",
ckpt_to_load_from=None,
tensor_name_in_ckpt=None):
"""Creates a _CrossedColumn.
Args:
@ -1243,6 +1333,12 @@ def crossed_column(columns, hash_bucket_size, combiner="sum"):
_SparseColumn, _CrossedColumn, or _BucketizedColumn.
hash_bucket_size: An int that is > 1. The number of buckets.
combiner: A combiner string, supports sum, mean, sqrtn.
ckpt_to_load_from: (Optional). String representing checkpoint name/pattern
to restore the column weights. Required if `tensor_name_in_ckpt` is not
None.
tensor_name_in_ckpt: (Optional). Name of the `Tensor` in the provided
checkpoint from which to restore the column weights. Required if
`ckpt_to_load_from` is not None.
Returns:
A _CrossedColumn.
@ -1254,12 +1350,14 @@ def crossed_column(columns, hash_bucket_size, combiner="sum"):
ValueError: if hash_bucket_size is not > 1 or
len(columns) is not > 1.
"""
return _CrossedColumn(columns, hash_bucket_size, combiner=combiner)
return _CrossedColumn(columns, hash_bucket_size, combiner=combiner,
ckpt_to_load_from=ckpt_to_load_from,
tensor_name_in_ckpt=tensor_name_in_ckpt)
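A hedged sketch of crossing two sparse columns and warm-starting the crossed weights, mirroring `testInitCrossedColumnWeightsFromCkpt` later in this diff; all names and the checkpoint path are illustrative:

```python
import tensorflow as tf

query = tf.contrib.layers.sparse_column_with_hash_bucket(
    "query", hash_bucket_size=100)
country = tf.contrib.layers.sparse_column_with_hash_bucket(
    "country", hash_bucket_size=100)

# Cross the two columns into a single hashed feature space.
query_x_country = tf.contrib.layers.crossed_column(
    set([query, country]), hash_bucket_size=10000)

# Optionally restore the cross's linear-model weights from a checkpoint;
# as with embedding_column, both arguments must be set together.
warm_started = tf.contrib.layers.crossed_column(
    set([query, country]),
    hash_bucket_size=10000,
    ckpt_to_load_from="/tmp/model.ckpt",  # hypothetical path
    tensor_name_in_ckpt="run_1/country_X_query_weights")  # hypothetical name
```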
class DataFrameColumn(_FeatureColumn,
collections.namedtuple("DataFrameColumn",
["name", "series"])):
["column_name", "series"])):
"""Represents a feature column produced from a `DataFrame`.
Instances of this class are immutable. A `DataFrame` column may be dense or
@ -1267,13 +1365,17 @@ class DataFrameColumn(_FeatureColumn,
batch_size.
Args:
name: a name for this column
column_name: a name for this column
series: a `Series` to be wrapped, which has already had its base features
substituted with `PredefinedSeries`.
"""
def __new__(cls, name, series):
return super(DataFrameColumn, cls).__new__(cls, name, series)
def __new__(cls, column_name, series):
return super(DataFrameColumn, cls).__new__(cls, column_name, series)
@property
def name(self):
return self.column_name
@property
def config(self):
@ -1301,7 +1403,17 @@ class DataFrameColumn(_FeatureColumn,
input_tensor,
weight_collections=None,
trainable=True):
return input_tensor
# DataFrame typically provides Tensors of shape [batch_size],
# but Estimator requires shape [batch_size, 1]
dims = input_tensor.get_shape().ndims
if dims == 0:
raise ValueError(
"Can't build input layer from tensor of shape (): {}".format(
self.column_name))
elif dims == 1:
return array_ops.expand_dims(input_tensor, 1)
else:
return input_tensor
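The reshape above is just an expand_dims on axis 1; a minimal illustration, separate from the diff:

```python
import tensorflow as tf

batch = tf.constant([1.0, 2.0, 3.0])    # DataFrame-style shape: [batch_size]
as_layer = tf.expand_dims(batch, 1)     # Estimator-style shape: [batch_size, 1]
print(as_layer.get_shape().as_list())   # [3, 1]
```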
# TODO(soergel): This mirrors RealValuedColumn for now, but should become
# better abstracted with less code duplication when we add other kinds.
@ -1469,7 +1581,7 @@ def _create_embeddings(name, shape, dtype, initializer, trainable,
with just one variable.
Args:
name: A string specifying the name of the embedding variable.
name: A string. The name of the embedding variable will be name + _weights.
shape: shape of the embedding. Note this is not the shape of partitioned
variables.
dtype: type of the embedding. Also the shape of each partitioned variable.
@ -1531,7 +1643,7 @@ def _create_embedding_lookup(input_tensor, weight_tensor, vocab_size, dimension,
A Tensor with shape [batch_size, dimension] and embedding Variable.
"""
embeddings = _create_embeddings(name=name,
embeddings = _create_embeddings(name=name + "_weights",
shape=[vocab_size, dimension],
dtype=dtypes.float32,
initializer=initializer,
@ -1543,4 +1655,4 @@ def _create_embedding_lookup(input_tensor, weight_tensor, vocab_size, dimension,
sparse_weights=weight_tensor,
default_id=0,
combiner=combiner,
name=name), embeddings
name=name + "_weights"), embeddings

View File

@ -393,6 +393,24 @@ class InputLayerTest(tf.test.TestCase):
tf.initialize_all_tables().run()
self.assertAllEqual(output.eval().shape, [2, 10])
def testEmbeddingColumnWithCrossedColumn(self):
a = tf.contrib.layers.sparse_column_with_hash_bucket("aaa",
hash_bucket_size=100)
b = tf.contrib.layers.sparse_column_with_hash_bucket("bbb",
hash_bucket_size=100)
crossed = tf.contrib.layers.crossed_column(
set([a, b]), hash_bucket_size=10000)
wire_tensor = tf.SparseTensor(values=["omar", "stringer", "marlo"],
indices=[[0, 0], [1, 0], [1, 1]],
shape=[2, 2])
features = {"aaa": wire_tensor, "bbb": wire_tensor}
embedded_sparse = tf.contrib.layers.embedding_column(crossed, 10)
output = tf.contrib.layers.input_from_feature_columns(features,
[embedded_sparse])
with self.test_session():
tf.initialize_all_variables().run()
self.assertAllEqual(output.eval().shape, [2, 10])
def testSparseColumn(self):
hashed_sparse = tf.contrib.layers.sparse_column_with_hash_bucket("wire", 10)
wire_tensor = tf.SparseTensor(values=["omar", "stringer", "marlo"],

View File

@ -19,6 +19,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
@ -58,14 +60,17 @@ class FeatureColumnTest(tf.test.TestCase):
self.assertEqual(b.dimension, 10)
self.assertTrue(b.default_value is None)
# dimension is an integer
with self.assertRaises(TypeError):
with self.assertRaisesRegexp(TypeError, "dimension must be an integer"):
tf.contrib.layers.real_valued_column("d3", dimension=1.0)
# dimension is a positive integer
with self.assertRaises(ValueError):
with self.assertRaisesRegexp(ValueError,
"dimension must be greater than 0"):
tf.contrib.layers.real_valued_column("d3", dimension=0)
with self.assertRaisesRegexp(ValueError,
"dtype must be convertible to float"):
tf.contrib.layers.real_valued_column("d3", dtype=tf.string)
# default_value is an integer.
c1 = tf.contrib.layers.real_valued_column("c1", default_value=2)
self.assertListEqual(list(c1.default_value), [2.])
@ -90,15 +95,18 @@ class FeatureColumnTest(tf.test.TestCase):
dimension=4,
default_value=2.)
self.assertListEqual(list(d2.default_value), [2., 2., 2., 2.])
with self.assertRaises(TypeError):
with self.assertRaisesRegexp(TypeError,
"default_value must be compatible with dtype"):
tf.contrib.layers.real_valued_column("d3",
default_value=2.,
dtype=tf.int32)
# default_value is neither interger nor float.
with self.assertRaises(TypeError):
# default_value is neither integer nor float.
with self.assertRaisesRegexp(
TypeError, "default_value must be compatible with dtype"):
tf.contrib.layers.real_valued_column("e1", default_value="string")
with self.assertRaises(TypeError):
with self.assertRaisesRegexp(
TypeError, "default_value must be compatible with dtype"):
tf.contrib.layers.real_valued_column("e1",
dimension=3,
default_value=[1, 3., "string"])
@ -123,11 +131,13 @@ class FeatureColumnTest(tf.test.TestCase):
dimension=3,
default_value=[2., 2, 2])
self.assertListEqual(list(g2.default_value), [2., 2., 2.])
with self.assertRaises(TypeError):
with self.assertRaisesRegexp(
TypeError, "default_value must be compatible with dtype"):
tf.contrib.layers.real_valued_column("g3",
default_value=[2.],
dtype=tf.int32)
with self.assertRaises(ValueError):
with self.assertRaisesRegexp(
ValueError, "The length of default_value must be equal to dimension"):
tf.contrib.layers.real_valued_column("g4",
dimension=3,
default_value=[2.])
@ -138,11 +148,19 @@ class FeatureColumnTest(tf.test.TestCase):
self.assertEqual(a.name, "aaa_BUCKETIZED")
def testBucketizedColumnRequiresRealValuedColumn(self):
with self.assertRaises(TypeError):
with self.assertRaisesRegexp(
TypeError, "source_column must be an instance of _RealValuedColumn"):
tf.contrib.layers.bucketized_column("bbb", [0])
with self.assertRaisesRegexp(
TypeError, "source_column must be an instance of _RealValuedColumn"):
tf.contrib.layers.bucketized_column(
tf.contrib.layers.sparse_column_with_integerized_feature(
column_name="bbb", bucket_size=10),
[0])
def testBucketizedColumnRequiresSortedBuckets(self):
with self.assertRaises(ValueError):
with self.assertRaisesRegexp(
ValueError, "boundaries must be a sorted list"):
tf.contrib.layers.bucketized_column(
tf.contrib.layers.real_valued_column("ccc"), [5, 0, 4])
@ -171,7 +189,10 @@ class FeatureColumnTest(tf.test.TestCase):
def testCrossedColumnNotSupportRealValuedColumn(self):
b = tf.contrib.layers.sparse_column_with_hash_bucket("bbb",
hash_bucket_size=100)
with self.assertRaises(TypeError):
with self.assertRaisesRegexp(
TypeError,
"columns must be a set of _SparseColumn, _CrossedColumn, "
"or _BucketizedColumn instances"):
tf.contrib.layers.crossed_column(
set([b, tf.contrib.layers.real_valued_column("real")]),
hash_bucket_size=10000)
@ -192,7 +213,8 @@ class FeatureColumnTest(tf.test.TestCase):
"weights": tf.VarLenFeature(tf.int32)},
weighted_ids.config)
with self.assertRaises(ValueError):
with self.assertRaisesRegexp(ValueError,
"dtype is not convertible to float"):
weighted_ids = tf.contrib.layers.weighted_sparse_column(ids, "weights",
dtype=tf.string)
@ -209,7 +231,8 @@ class FeatureColumnTest(tf.test.TestCase):
[1], dtype=tf.int32)},
rvc.config)
with self.assertRaises(ValueError):
with self.assertRaisesRegexp(ValueError,
"dtype must be convertible to float"):
tf.contrib.layers.real_valued_column("rvc", dtype=tf.string)
def testSparseColumnDtypes(self):
@ -220,7 +243,8 @@ class FeatureColumnTest(tf.test.TestCase):
"sc", 10, dtype=tf.int32)
self.assertDictEqual({"sc": tf.VarLenFeature(dtype=tf.int32)}, sc.config)
with self.assertRaises(ValueError):
with self.assertRaisesRegexp(ValueError,
"dtype must be an integer"):
tf.contrib.layers.sparse_column_with_integerized_feature("sc",
10,
dtype=tf.float32)
@ -323,6 +347,107 @@ class FeatureColumnTest(tf.test.TestCase):
self.assertEqual(tf.float32, placeholder.dtype)
self.assertEqual([None, 1], placeholder.get_shape().as_list())
def testInitEmbeddingColumnWeightsFromCkpt(self):
sparse_col = tf.contrib.layers.sparse_column_with_hash_bucket(
column_name="object_in_image",
hash_bucket_size=4)
# Create _EmbeddingColumn which randomly initializes embedding of size
# [4, 16].
embedding_col = tf.contrib.layers.embedding_column(sparse_col, dimension=16)
# Creating a SparseTensor which has all the ids possible for the given
# vocab.
input_tensor = tf.SparseTensor(indices=[[0, 0], [1, 1], [2, 2], [3, 3]],
values=[0, 1, 2, 3],
shape=[4, 4])
# Invoking 'embedding_column.to_dnn_input_layer' will create the embedding
# variable. Creating under scope 'run_1' so as to prevent name conflicts
# when creating embedding variable for 'embedding_column_pretrained'.
with tf.variable_scope("run_1"):
# This will return a [4, 16] tensor, which is the same as the embedding
# variable.
embeddings = embedding_col.to_dnn_input_layer(input_tensor)
save = tf.train.Saver()
checkpoint_path = os.path.join(self.get_temp_dir(), "model.ckpt")
with self.test_session() as sess:
sess.run(tf.initialize_all_variables())
saved_embedding = embeddings.eval()
save.save(sess, checkpoint_path)
embedding_col_initialized = tf.contrib.layers.embedding_column(
sparse_id_column=sparse_col,
dimension=16,
ckpt_to_load_from=checkpoint_path,
tensor_name_in_ckpt="run_1/object_in_image_embedding_weights")
with tf.variable_scope("run_2"):
# This will initialize the embedding from the provided checkpoint and return a
# [4, 16] tensor, which is the same as the embedding variable. Since we didn't
# modify the embeddings, this should be the same as 'saved_embedding'.
pretrained_embeddings = embedding_col_initialized.to_dnn_input_layer(
input_tensor)
with self.test_session() as sess:
sess.run(tf.initialize_all_variables())
loaded_embedding = pretrained_embeddings.eval()
self.assertAllClose(saved_embedding, loaded_embedding)
def testInitCrossedColumnWeightsFromCkpt(self):
sparse_col_1 = tf.contrib.layers.sparse_column_with_hash_bucket(
column_name="col_1", hash_bucket_size=4)
sparse_col_2 = tf.contrib.layers.sparse_column_with_hash_bucket(
column_name="col_2", hash_bucket_size=4)
crossed_col = tf.contrib.layers.crossed_column(
columns=[sparse_col_1, sparse_col_2],
hash_bucket_size=4)
input_tensor = tf.SparseTensor(indices=[[0, 0], [1, 1], [2, 2], [3, 3]],
values=[0, 1, 2, 3],
shape=[4, 4])
# Invoking 'crossed_col.to_weighted_sum' will create the crossed column
# weights variable.
with tf.variable_scope("run_1"):
# Returns the looked-up column weights, which are the same as the crossed
# column weights, as well as actual references to the weight variables.
col_weights, weights = crossed_col.to_weighted_sum(input_tensor)
# Update the weights since default initializer initializes all weights to
# 0.0.
for weight in weights:
assign_op = tf.assign(weight, weight + 0.5)
save = tf.train.Saver()
checkpoint_path = os.path.join(self.get_temp_dir(), "model.ckpt")
with self.test_session() as sess:
sess.run(tf.initialize_all_variables())
sess.run(assign_op)
saved_col_weights = col_weights.eval()
save.save(sess, checkpoint_path)
crossed_col_initialized = tf.contrib.layers.crossed_column(
columns=[sparse_col_1, sparse_col_2],
hash_bucket_size=4,
ckpt_to_load_from=checkpoint_path,
tensor_name_in_ckpt="run_1/col_1_X_col_2_weights")
with tf.variable_scope("run_2"):
# This will initialize the crossed column weights from the provided checkpoint
# and return a [4, 1] tensor, which is the same as the weights variable. Since
# we won't modify the weights, this should be the same as 'saved_col_weights'.
col_weights_from_ckpt, _ = crossed_col_initialized.to_weighted_sum(
input_tensor)
with self.test_session() as sess:
sess.run(tf.initialize_all_variables())
loaded_col_weights = col_weights_from_ckpt.eval()
self.assertAllClose(saved_col_weights, loaded_col_weights)
if __name__ == "__main__":
tf.test.main()

View File

@ -102,12 +102,13 @@ def variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False,
TypeError: if `mode` is not in ['FAN_IN', 'FAN_OUT', 'FAN_AVG'].
"""
if not dtype.is_floating:
raise TypeError('Cannot create initializer for non-floating point '
'type.')
raise TypeError('Cannot create initializer for non-floating point type.')
if mode not in ['FAN_IN', 'FAN_OUT', 'FAN_AVG']:
raise TypeError('Unknown mode %s [FAN_IN, FAN_OUT, FAN_AVG]' % mode)
def _initializer(shape, dtype=dtype):
"""Initializer function."""
if not dtype.is_floating:
raise TypeError('Cannot create initializer for non-floating point type.')
# Estimating fan_in and fan_out is not possible to do perfectly, but we try.
# This is the right thing for matrix multiply and convolutions.
fan_in = float(shape[-2])
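A small sketch of the two failure points this fix distinguishes; it assumes the initializer is exported as `tf.contrib.layers.variance_scaling_initializer`, as the test below does:

```python
import tensorflow as tf

init = tf.contrib.layers.variance_scaling_initializer(factor=2.0, mode='FAN_IN')
weights = init([64, 128], dtype=tf.float32)  # OK: floating-point dtype

# Both of these now raise TypeError: the dtype check runs at construction
# time and, after this change, again when the initializer itself is called.
# tf.contrib.layers.variance_scaling_initializer(dtype=tf.int32)
# init([64, 128], dtype=tf.int32)
```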

View File

@ -64,6 +64,11 @@ class VarianceScalingInitializerTest(tf.test.TestCase):
TypeError,
'Cannot create initializer for non-floating point type.'):
tf.contrib.layers.variance_scaling_initializer(dtype=tf.int32)
initializer = tf.contrib.layers.variance_scaling_initializer()
with self.assertRaisesRegexp(
TypeError,
'Cannot create initializer for non-floating point type.'):
initializer([], dtype=tf.int32)
def _test_variance(self, initializer, shape, variance, factor, mode, uniform):
with tf.Graph().as_default() as g:

View File

@ -75,25 +75,24 @@ def avg_pool2d(inputs,
padding='VALID',
outputs_collections=None,
scope=None):
"""Adds a Avg Pooling op.
"""Adds a 2D average pooling op.
It is assumed by the wrapper that the pooling is only done per image and not
in depth or batch.
It is assumed that the pooling is done per image but not in batch or channels.
Args:
inputs: a tensor of size [batch_size, height, width, depth].
kernel_size: a list of length 2: [kernel_height, kernel_width] of the
inputs: A `Tensor` of size [batch_size, height, width, channels].
kernel_size: A list of length 2: [kernel_height, kernel_width] of the
pooling kernel over which the op is computed. Can be an int if both
values are the same.
stride: a list of length 2: [stride_height, stride_width].
Can be an int if both strides are the same. Note that presently
stride: A list of length 2: [stride_height, stride_width].
Can be an int if both strides are the same. Note that presently
both strides must have the same value.
padding: the padding method, either 'VALID' or 'SAME'.
outputs_collections: collection to add the outputs.
padding: The padding method, either 'VALID' or 'SAME'.
outputs_collections: The collections to which the outputs are added.
scope: Optional scope for op_scope.
Returns:
a tensor representing the results of the pooling operation.
A `Tensor` representing the results of the pooling operation.
"""
with ops.op_scope([inputs], scope, 'AvgPool2D') as sc:
inputs = ops.convert_to_tensor(inputs)
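A minimal call matching the docstring above; the resulting shape follows the `AvgPool2DTest` cases later in this diff:

```python
import tensorflow as tf

images = tf.random_uniform((5, 3, 3, 3), seed=1)  # [batch, height, width, channels]
pooled = tf.contrib.layers.avg_pool2d(images, kernel_size=[3, 3])
print(pooled.get_shape().as_list())               # [5, 1, 1, 3] with 'VALID' padding
```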
@ -843,27 +842,27 @@ def max_pool2d(inputs,
padding='VALID',
outputs_collections=None,
scope=None):
"""Adds a Max Pooling op.
"""Adds a 2D Max Pooling op.
It is assumed by the wrapper that the pooling is only done per image and not
in depth or batch.
It is assumed that the pooling is done per image but not in batch or channels.
Args:
inputs: a tensor of size [batch_size, height, width, depth].
kernel_size: a list of length 2: [kernel_height, kernel_width] of the
inputs: A `Tensor` of size [batch_size, height, width, channels].
kernel_size: A list of length 2: [kernel_height, kernel_width] of the
pooling kernel over which the op is computed. Can be an int if both
values are the same.
stride: a list of length 2: [stride_height, stride_width].
Can be an int if both strides are the same. Note that presently
stride: A list of length 2: [stride_height, stride_width].
Can be an int if both strides are the same. Note that presently
both strides must have the same value.
padding: the padding method, either 'VALID' or 'SAME'.
outputs_collections: collection to add the outputs.
padding: The padding method, either 'VALID' or 'SAME'.
outputs_collections: The collections to which the outputs are added.
scope: Optional scope for op_scope.
Returns:
a tensor representing the results of the pooling operation.
A `Tensor` representing the results of the pooling operation.
Raises:
ValueError: if 'kernel_size' is not a 2-D list
ValueError: If 'kernel_size' is not a 2-D list
"""
with ops.op_scope([inputs], scope, 'MaxPool2D') as sc:
inputs = ops.convert_to_tensor(inputs)
@ -1037,6 +1036,7 @@ def separable_convolution2d(
depthwise_weights = variables.model_variable(
'depthwise_weights',
shape=depthwise_shape,
dtype=dtype,
initializer=weights_initializer,
regularizer=weights_regularizer,
trainable=trainable,
@ -1049,6 +1049,7 @@ def separable_convolution2d(
pointwise_weights = variables.model_variable(
'pointwise_weights',
shape=pointwise_shape,
dtype=dtype,
initializer=weights_initializer,
regularizer=weights_regularizer,
trainable=trainable,

View File

@ -30,59 +30,52 @@ class AvgPool2DTest(tf.test.TestCase):
def testCreateAvgPool(self):
height, width = 3, 3
with self.test_session():
images = np.random.uniform(size=(5, height, width, 3))
output = tf.contrib.layers.avg_pool2d(images, [3, 3])
self.assertEquals(output.op.name, 'AvgPool2D/AvgPool')
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
images = np.random.uniform(size=(5, height, width, 3))
output = tf.contrib.layers.avg_pool2d(images, [3, 3])
self.assertEquals(output.op.name, 'AvgPool2D/AvgPool')
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
def testCollectOutputs(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.avg_pool2d(images, [3, 3],
outputs_collections='outputs')
c_output = tf.get_collection('outputs')[0]
self.assertEquals(c_output.name, 'AvgPool2D')
self.assertEquals(c_output.outputs, output)
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.avg_pool2d(images, [3, 3],
outputs_collections='outputs')
output_collection = tf.get_collection('outputs')[0]
self.assertEquals(output_collection.name, 'AvgPool2D')
self.assertEquals(output_collection.outputs, output)
def testCreateSquareAvgPool(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.avg_pool2d(images, 3)
self.assertEquals(output.op.name, 'AvgPool2D/AvgPool')
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.avg_pool2d(images, 3)
self.assertEquals(output.op.name, 'AvgPool2D/AvgPool')
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
def testCreateAvgPoolWithScope(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.avg_pool2d(images, [3, 3], scope='pool1')
self.assertEquals(output.op.name, 'pool1/AvgPool')
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.avg_pool2d(images, [3, 3], scope='pool1')
self.assertEquals(output.op.name, 'pool1/AvgPool')
def testCreateAvgPoolSAME(self):
def testCreateAvgPoolWithSamePadding(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.avg_pool2d(images, [3, 3], padding='SAME')
self.assertListEqual(output.get_shape().as_list(), [5, 2, 2, 3])
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.avg_pool2d(images, [3, 3], padding='SAME')
self.assertListEqual(output.get_shape().as_list(), [5, 2, 2, 3])
def testCreateAvgPoolStrideSAME(self):
def testCreateAvgPoolStrideWithSamePadding(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.avg_pool2d(images, [3, 3], stride=1,
padding='SAME')
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.avg_pool2d(images, [3, 3], stride=1,
padding='SAME')
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])
def testGlobalAvgPool(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.avg_pool2d(images, images.get_shape()[1:3],
stride=1)
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.avg_pool2d(images, images.get_shape()[1:3],
stride=1)
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
class BiasAddTest(tf.test.TestCase):
@ -825,7 +818,7 @@ class DropoutTest(tf.test.TestCase):
with self.test_session():
images = np.random.uniform(size=(5, height, width, 3))
output = tf.contrib.layers.dropout(images)
self.assertEquals(output.op.name, 'Dropout/dropout/mul_1')
self.assertEquals(output.op.name, 'Dropout/dropout/mul')
output.get_shape().assert_is_compatible_with(
tf.convert_to_tensor(images).get_shape())
@ -835,7 +828,7 @@ class DropoutTest(tf.test.TestCase):
is_training = tf.constant(True)
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.dropout(images, is_training=is_training)
self.assertEquals(output.op.name, 'Dropout/dropout/mul_1')
self.assertEquals(output.op.name, 'Dropout/dropout/mul')
output.get_shape().assert_is_compatible_with(images.get_shape())
def testCreateDropoutWithConstantFalse(self):
@ -1502,59 +1495,52 @@ class MaxPool2DTest(tf.test.TestCase):
def testCreateMaxPool(self):
height, width = 3, 3
with self.test_session():
images = np.random.uniform(size=(5, height, width, 3)).astype(np.float32)
output = tf.contrib.layers.max_pool2d(images, [3, 3])
self.assertEquals(output.op.name, 'MaxPool2D/MaxPool')
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
images = np.random.uniform(size=(5, height, width, 3)).astype(np.float32)
output = tf.contrib.layers.max_pool2d(images, [3, 3])
self.assertEquals(output.op.name, 'MaxPool2D/MaxPool')
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
def testCollectOutputs(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.max_pool2d(images, [3, 3],
outputs_collections='outputs')
c_output = tf.get_collection('outputs')[0]
self.assertEquals(c_output.name, 'MaxPool2D')
self.assertEquals(c_output.outputs, output)
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.max_pool2d(images, [3, 3],
outputs_collections='outputs')
outputs_collection = tf.get_collection('outputs')[0]
self.assertEquals(outputs_collection.name, 'MaxPool2D')
self.assertEquals(outputs_collection.outputs, output)
def testCreateSquareMaxPool(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.max_pool2d(images, 3)
self.assertEquals(output.op.name, 'MaxPool2D/MaxPool')
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.max_pool2d(images, 3)
self.assertEquals(output.op.name, 'MaxPool2D/MaxPool')
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
def testCreateMaxPoolWithScope(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.max_pool2d(images, [3, 3], scope='pool1')
self.assertEquals(output.op.name, 'pool1/MaxPool')
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.max_pool2d(images, [3, 3], scope='pool1')
self.assertEquals(output.op.name, 'pool1/MaxPool')
def testCreateMaxPoolSAME(self):
def testCreateMaxPoolWithSamePadding(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.max_pool2d(images, [3, 3], padding='SAME')
self.assertListEqual(output.get_shape().as_list(), [5, 2, 2, 3])
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.max_pool2d(images, [3, 3], padding='SAME')
self.assertListEqual(output.get_shape().as_list(), [5, 2, 2, 3])
def testCreateMaxPoolStrideSAME(self):
def testCreateMaxPoolStrideWithSamePadding(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.max_pool2d(images, [3, 3], stride=1,
padding='SAME')
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.max_pool2d(images, [3, 3], stride=1,
padding='SAME')
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])
def testGlobalMaxPool(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.max_pool2d(images, images.get_shape()[1:3],
stride=1)
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.max_pool2d(images, images.get_shape()[1:3],
stride=1)
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
class OneHotEncodingTest(tf.test.TestCase):
@ -1618,10 +1604,28 @@ class RepeatTests(tf.test.TestCase):
class SeparableConv2dTest(tf.test.TestCase):
def testCreateConv(self):
def testCreateConvInt32(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
images = tf.random_uniform(
(5, height, width, 3), seed=1, dtype=tf.int32, maxval=12345)
with self.assertRaisesRegexp(TypeError, 'non-floating point type'):
tf.contrib.layers.separable_conv2d(images, 32, [3, 3], 2)
def testCreateConvFloat32(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform(
(5, height, width, 3), seed=1, dtype=tf.float32)
output = tf.contrib.layers.separable_conv2d(images, 32, [3, 3], 2)
self.assertEquals(output.op.name, 'SeparableConv2d/Relu')
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
def testCreateConvFloat64(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform(
(5, height, width, 3), seed=1, dtype=tf.float64)
output = tf.contrib.layers.separable_conv2d(images, 32, [3, 3], 2)
self.assertEquals(output.op.name, 'SeparableConv2d/Relu')
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])

View File

@ -31,6 +31,7 @@ from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as vars_
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import optimizer as optimizer_
from tensorflow.python.training import training as train
@ -43,6 +44,13 @@ OPTIMIZER_CLS_NAMES = {
"SGD": train.GradientDescentOptimizer,
}
OPTIMIZER_SUMMARIES = [
"learning_rate",
"loss",
"gradients",
"gradient_norm",
]
def optimize_loss(loss,
global_step,
@ -51,11 +59,12 @@ def optimize_loss(loss,
gradient_noise_scale=None,
gradient_multipliers=None,
clip_gradients=None,
moving_average_decay=0.9,
moving_average_decay=None,
learning_rate_decay_fn=None,
update_ops=None,
variables=None,
name=None):
name=None,
summaries=None):
"""Given loss and parameters for optimizer, returns a training op.
Args:
@ -75,8 +84,8 @@ def optimize_loss(loss,
If present, gradients for specified
variables will be multiplied by given constant.
clip_gradients: float or `None`, clips gradients by this value.
moving_average_decay: float or None, takes into account previous loss
to make learning smoother due to outliers.
moving_average_decay: Deprecated. float or None, takes into account previous
loss to make learning smoother due to outliers.
learning_rate_decay_fn: function, takes `learning_rate` and `global_step`
`Tensor`s, returns `Tensor`.
Can be used to implement any learning rate decay
@ -87,6 +96,9 @@ def optimize_loss(loss,
variables: list of variables to optimize or
`None` to use all trainable variables.
name: The name for this operation is used to scope operations and summaries.
summaries: List of internal quantities to visualize on TensorBoard. If not
set, only the loss and the learning rate will be reported. The
complete list is in OPTIMIZER_SUMMARIES (see the sketch below).
Returns:
Training op.
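A hedged sketch of the new `summaries` argument, assuming `optimize_loss` is exported via `tf.contrib.layers`; the toy loss and variable are placeholders:

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [])
var = tf.get_variable("weight", [], initializer=tf.constant_initializer(3.0))
loss = tf.square(var * x)
global_step = tf.Variable(0, trainable=False, name="global_step")

# Any subset of OPTIMIZER_SUMMARIES may be requested; when `summaries` is
# None, only "loss" and "learning_rate" are reported.
train_op = tf.contrib.layers.optimize_loss(
    loss, global_step, learning_rate=0.1, optimizer="SGD",
    summaries=["loss", "learning_rate", "gradients", "gradient_norm"])
```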
@ -96,8 +108,8 @@ def optimize_loss(loss,
"""
with vs.variable_op_scope([loss, global_step], name, "OptimizeLoss"):
# Update ops take UPDATE_OPS collection if not provided.
update_ops = (set(update_ops or []) or
set(ops.get_collection(ops.GraphKeys.UPDATE_OPS)))
if update_ops is None:
update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
# Make sure update ops are run before computing loss.
if update_ops:
with ops.control_dependencies(update_ops):
@ -105,7 +117,10 @@ def optimize_loss(loss,
loss = control_flow_ops.with_dependencies([barrier], loss)
# Moving average of the loss with decay.
# TODO(b/30439864): moving_average_decay should be removed.
if moving_average_decay is not None:
logging.warn("'moving_average_decay' is deprecated. Please use "
"tensorboard's builtin averaging instead.")
# Generate moving averages of the loss.
loss_averages = train.ExponentialMovingAverage(moving_average_decay,
name="avg")
@ -125,9 +140,12 @@ def optimize_loss(loss,
raise ValueError("Learning rate should be 0d Tensor or float. "
"Got %s of type %s" % (
str(learning_rate), str(type(learning_rate))))
if summaries is None:
summaries = ["loss", "learning_rate"]
if learning_rate_decay_fn is not None:
lr = learning_rate_decay_fn(lr, global_step)
logging_ops.scalar_summary("learning_rate", lr)
if "learning_rate" in summaries:
logging_ops.scalar_summary("learning_rate", lr)
# Create optimizer, given specified parameters.
if isinstance(optimizer, six.string_types):
@ -167,7 +185,8 @@ def optimize_loss(loss,
gradients = _clip_gradients_by_norm(gradients, clip_gradients)
# Add scalar summary for loss.
logging_ops.scalar_summary("loss", loss)
if "loss" in summaries:
logging_ops.scalar_summary("loss", loss)
# Add histograms for variables, gradients and gradient norms.
for gradient, variable in gradients:
@ -177,10 +196,12 @@ def optimize_loss(loss,
grad_values = gradient
if grad_values is not None:
logging_ops.histogram_summary(variable.name, variable)
logging_ops.histogram_summary(variable.name + "/gradients", grad_values)
logging_ops.histogram_summary(variable.name + "/gradient_norm",
clip_ops.global_norm([grad_values]))
if "gradients" in summaries:
logging_ops.histogram_summary(variable.name + "/gradients",
grad_values)
if "gradient_norm" in summaries:
logging_ops.histogram_summary(variable.name + "/gradient_norm",
clip_ops.global_norm([grad_values]))
# Create gradient updates.
grad_updates = opt.apply_gradients(gradients,

View File

@ -75,7 +75,8 @@ class OptimizersTest(tf.test.TestCase):
tf.initialize_all_variables().run()
session.run(train, feed_dict={x: 5})
var_value, global_step_value = session.run([var, global_step])
self.assertAlmostEqual(var_value, 8.58150, 4)
# Due to randomness, the following number may change if the graph is different.
self.assertAlmostEqual(var_value, 8.5591021, 4)
self.assertEqual(global_step_value, 1)
def testGradientNoiseWithClipping(self):

View File

@ -22,6 +22,7 @@ import inspect
import six
from tensorflow.contrib import losses
from tensorflow.contrib import metrics as metrics_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -29,7 +30,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
def regression_target(label_name=None,
@ -70,7 +70,7 @@ def multi_class_target(n_classes, label_name=None, weight_column_name=None):
will be multiplied by the loss of the example.
Returns:
An instance of _TargetColumn
An instance of _MultiClassTargetColumn.
Raises:
ValueError: if n_classes is < 2
@ -297,8 +297,17 @@ class _BinarySvmTargetColumn(_MultiClassTargetColumn):
"""_TargetColumn for binary classification using SVMs."""
def __init__(self, label_name, weight_column_name):
def loss_fn(logits, target):
check_shape_op = logging_ops.Assert(
math_ops.less_equal(array_ops.rank(target), 2),
["target's shape should be either [batch_size, 1] or [batch_size]"])
with ops.control_dependencies([check_shape_op]):
target = array_ops.reshape(
target, shape=[array_ops.shape(target)[0], 1])
return losses.hinge_loss(logits, target)
super(_BinarySvmTargetColumn, self).__init__(
loss_fn=_binary_hinge_loss,
loss_fn=loss_fn,
n_classes=2,
label_name=label_name,
weight_column_name=weight_column_name)
@ -331,22 +340,6 @@ def _log_loss_with_two_classes(logits, target):
return loss_vec
# TODO(sibyl-vie3Poto): Move this to contrib/losses/python/losses/loss_ops.py.
def _binary_hinge_loss(logits, target):
"""Method that returns the loss vector for binary hinge loss."""
check_shape_op = logging_ops.Assert(
math_ops.less_equal(
array_ops.rank(target), 2),
["target's shape should be either [batch_size, 1] or [batch_size]"])
with ops.control_dependencies([check_shape_op]):
target = array_ops.reshape(target, shape=[array_ops.shape(target)[0], 1])
# First need to convert binary labels to -1/1 labels (as floats).
all_ones = array_ops.ones_like(logits)
labels = math_ops.sub(2 * math_ops.to_float(target), all_ones)
loss_vec = nn_ops.relu(math_ops.sub(all_ones, math_ops.mul(labels, logits)))
return loss_vec
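For reference, the helper removed here computes the standard binary hinge loss; the same arithmetic in plain NumPy (not part of the diff):

```python
import numpy as np

def binary_hinge_loss(logits, target):
    # target in {0, 1} -> labels in {-1, +1}, then max(0, 1 - label * logit).
    labels = 2.0 * target - 1.0
    return np.maximum(0.0, 1.0 - labels * logits)

print(binary_hinge_loss(np.array([0.5, -2.0]), np.array([1.0, 0.0])))  # [0.5 0.]
```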
def _softmax_cross_entropy_loss(logits, target):
# sigmoid_cross_entropy_with_logits requires [batch_size, 1] target.
# Check that we got int32/int64 for classification.

View File

@ -36,6 +36,18 @@ py_test(
],
)
py_test(
name = "load_csv_test",
size = "small",
srcs = ["python/learn/tests/load_csv_test.py"],
srcs_version = "PY2AND3",
deps = [
":learn",
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
],
)
py_test(
name = "data_feeder_test",
size = "small",
@ -235,9 +247,9 @@ py_test(
)
py_test(
name = "compare_test",
name = "binary_transform_test",
size = "small",
srcs = ["python/learn/tests/dataframe/compare_test.py"],
srcs = ["python/learn/tests/dataframe/binary_transform_test.py"],
srcs_version = "PY2AND3",
deps = [
":learn",
@ -625,19 +637,6 @@ py_test(
],
)
py_test(
name = "checkpoints_test",
size = "small",
srcs = ["python/learn/utils/checkpoints_test.py"],
srcs_version = "PY2AND3",
deps = [
":learn",
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework",
"//tensorflow/python:framework_test_lib",
],
)
py_test(
name = "graph_io_test",
size = "small",

View File

@ -56,6 +56,7 @@ Below are a few simple examples of the API. For more examples, please see [example
Simple linear classification:
```python
import tensorflow.contrib.learn.python.learn as learn
from sklearn import datasets, metrics
iris = datasets.load_iris()
@ -70,6 +71,7 @@ print("Accuracy: %f" % score)
Simple linear regression:
```python
import tensorflow.contrib.learn.python.learn as learn
from sklearn import datasets, metrics, preprocessing
boston = datasets.load_boston()
@ -85,6 +87,7 @@ print ("MSE: %f" % score)
Example of a 3-layer network with 10, 20 and 10 hidden units respectively:
```python
import tensorflow.contrib.learn.python.learn as learn
from sklearn import datasets, metrics
iris = datasets.load_iris()
@ -99,6 +102,7 @@ print("Accuracy: %f" % score)
Example of how to pass a custom model to the Estimator:
```python
import tensorflow.contrib.learn.python.learn as learn
from sklearn import datasets, metrics
iris = datasets.load_iris()

View File

@ -33,6 +33,7 @@ from tensorflow.contrib.learn.python.learn import preprocessing
from tensorflow.contrib.learn.python.learn import utils
from tensorflow.contrib.learn.python.learn.dataframe import *
from tensorflow.contrib.learn.python.learn.estimators import *
from tensorflow.contrib.learn.python.learn.evaluable import Evaluable
from tensorflow.contrib.learn.python.learn.experiment import Experiment
from tensorflow.contrib.learn.python.learn.graph_actions import evaluate
from tensorflow.contrib.learn.python.learn.graph_actions import infer
@ -41,4 +42,5 @@ from tensorflow.contrib.learn.python.learn.graph_actions import run_feeds
from tensorflow.contrib.learn.python.learn.graph_actions import run_n
from tensorflow.contrib.learn.python.learn.graph_actions import train
from tensorflow.contrib.learn.python.learn.learn_io import *
from tensorflow.contrib.learn.python.learn.trainable import Trainable
# pylint: enable=wildcard-import

View File

@ -29,11 +29,14 @@ from tensorflow.contrib.learn.python.learn.dataframe.transform import Transform
# Transforms
from tensorflow.contrib.learn.python.learn.dataframe.transforms.boolean_mask import BooleanMask
from tensorflow.contrib.learn.python.learn.dataframe.transforms.difference import Difference
from tensorflow.contrib.learn.python.learn.dataframe.transforms.hashes import HashFast
from tensorflow.contrib.learn.python.learn.dataframe.transforms.in_memory_source import NumpySource
from tensorflow.contrib.learn.python.learn.dataframe.transforms.in_memory_source import PandasSource
from tensorflow.contrib.learn.python.learn.dataframe.transforms.reader_source import ReaderSource
from tensorflow.contrib.learn.python.learn.dataframe.transforms.sum import Sum
# pylint: disable=g-import-not-at-top,g-bad-import-order
# Unary Transform registration
@ -42,9 +45,9 @@ for ut_def in _ut.UNARY_TRANSFORMS:
_ut.register_unary_op(*ut_def)
# Comparison Transform registration
from tensorflow.contrib.learn.python.learn.dataframe.transforms import compare as _cmp
for ct_def in _cmp.COMPARISON_TRANSFORMS:
_cmp.register_comparison_ops(*ct_def)
from tensorflow.contrib.learn.python.learn.dataframe.transforms import binary_transforms as _bt
for bt_def in _bt.BINARY_TRANSFORMS:
_bt.register_binary_op(*bt_def)
__all__ = ['DataFrame', 'Series', 'PredefinedSeries', 'TransformedSeries',
'TensorFlowDataFrame', 'parameter', 'Transform']

View File

@ -117,10 +117,11 @@ class DataFrame(object):
value = [value]
self.assign(**dict(zip(key, value)))
def build(self):
def build(self, **kwargs):
# We do not allow passing a cache here, because that would encourage
# working around the rule that DataFrames cannot be expected to be
# synced with each other (e.g., they shuffle independently).
cache = {}
tensors = {name: c.build(cache) for name, c in self._columns.items()}
tensors = {name: c.build(cache, **kwargs)
for name, c in self._columns.items()}
return tensors

Some files were not shown because too many files have changed in this diff.