commit
ee221cb625
@ -33,10 +33,10 @@ and discussion.**
|
|||||||
|
|
||||||
People who are a little more adventurous can also try our nightly binaries:
|
People who are a little more adventurous can also try our nightly binaries:
|
||||||
|
|
||||||
* Linux CPU-only: [Python 2](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.9.0-cp27-none-linux_x86_64.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/)) / [Python 3.4](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](http://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
|
* Linux CPU-only: [Python 2](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-cp27-none-linux_x86_64.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/)) / [Python 3.4](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-cp34-cp34m-linux_x86_64.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](http://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-cp35-cp35m-linux_x86_64.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
|
||||||
* Linux GPU: [Python 2](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.9.0-cp27-none-linux_x86_64.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/140/artifact/pip_test/whl/tensorflow-0.8.0-cp35-cp35m-linux_x86_64.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
|
* Linux GPU: [Python 2](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-cp27-none-linux_x86_64.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-cp34-cp34m-linux_x86_64.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/140/artifact/pip_test/whl/tensorflow-0.8.0-cp35-cp35m-linux_x86_64.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-linux-gpu/TF_BUILD_CONTAINER_TYPE=GPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
|
||||||
* Mac CPU-only: [Python 2](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac1-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.9.0-py2-none-any.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac1-slave/)) / [Python 3](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac1-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.9.0-py3-none-any.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac1-slave/))
|
* Mac CPU-only: [Python 2](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac1-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-py2-none-any.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac1-slave/)) / [Python 3](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac1-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-py3-none-any.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac1-slave/))
|
||||||
* Mac GPU: [Python 2](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.9.0-py2-none-any.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.9.0-py3-none-any.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/))
|
* Mac GPU: [Python 2](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-py2-none-any.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-py3-none-any.whl) ([build history](http://ci.tensorflow.org/view/Nightly/job/nigntly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/))
|
||||||
* [Android](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-android/TF_BUILD_CONTAINER_TYPE=ANDROID,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=NO_PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=android-slave/lastSuccessfulBuild/artifact/bazel-out/local_linux/bin/tensorflow/examples/android/tensorflow_demo.apk) ([build history](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-android/TF_BUILD_CONTAINER_TYPE=ANDROID,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=NO_PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=android-slave/))
|
* [Android](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-android/TF_BUILD_CONTAINER_TYPE=ANDROID,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=NO_PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=android-slave/lastSuccessfulBuild/artifact/bazel-out/local_linux/bin/tensorflow/examples/android/tensorflow_demo.apk) ([build history](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-android/TF_BUILD_CONTAINER_TYPE=ANDROID,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=NO_PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=android-slave/))
|
||||||
|
|
||||||
#### *Try your first TensorFlow program*
|
#### *Try your first TensorFlow program*
|
||||||
|
50
RELEASE.md
50
RELEASE.md
@ -1,16 +1,40 @@
|
|||||||
# Changes Since Last Release
|
|
||||||
|
|
||||||
## Features and Improvements
|
# Release 0.10.0
|
||||||
* Connectionist Temporal Classification ops are now "official" (see, e.g.,
|
|
||||||
`tf.nn.ctc_loss`)
|
|
||||||
* Preliminary graph-construction C API, for use by language bindings.
|
|
||||||
* Major revision to the graph-construction C++ API. Scoping mechanism to make op
|
|
||||||
naming, specifying control dependencies etc. more consistent. C++ values can
|
|
||||||
be used directly as operands, making op construction more concise.
|
|
||||||
|
|
||||||
## Breaking Changes to the API
|
## Major Features and Improvements
|
||||||
* `env.h` replaces use of `New*File()` functions to use `std::unique_ptr`
|
|
||||||
return arguments, removing the old raw pointer returns.
|
* Added support for C++ shape inference
|
||||||
|
* Added graph-construction C API
|
||||||
|
* Major revision to the graph-construction C++ API
|
||||||
|
* Support makefile build for iOS
|
||||||
|
* Added Mac GPU support
|
||||||
|
* Full version of TF-Slim available as `tf.contrib.slim`
|
||||||
|
* Added k-Means clustering and WALS matrix factorization
|
||||||
|
|
||||||
|
## Big Fixes and Other Changes
|
||||||
|
|
||||||
|
* Allow gradient computation for scalar values.
|
||||||
|
* Performance improvements for gRPC
|
||||||
|
* Improved support for fp16
|
||||||
|
* New high-level ops in tf.contrib.{layers,metrics}
|
||||||
|
* New features for TensorBoard, such as shape display, exponential smoothing
|
||||||
|
* Faster and more stable Google Cloud Storage (GCS) filesystem support
|
||||||
|
* Support for zlib compression and decompression for TFRecordReader and TFRecordWriter
|
||||||
|
* Support for reading (animated) GIFs
|
||||||
|
* Improved support for SparseTensor
|
||||||
|
* Added support for more probability distributions (Dirichlet, Beta, Bernoulli, etc.)
|
||||||
|
* Added Python interfaces to reset resource containers.
|
||||||
|
* Many bugfixes and performance improvements
|
||||||
|
* Many documentation fixes
|
||||||
|
|
||||||
|
## Thanks to our Contributors
|
||||||
|
|
||||||
|
This release contains contributions from many people at Google, as well as:
|
||||||
|
|
||||||
|
Alex Rothberg, Andrew Royer, Austin Marshall, @BlackCoal, Bob Adolf, Brian Diesel, Charles-Emmanuel Dias, @chemelnucfin, Chris Lesniewski, Daeyun Shin, Daniel Rodriguez, Danijar Hafner, Darcy Liu, Kristinn R. Thórisson, Daniel Castro, Dmitry Savintsev, Kashif Rasul, Dylan Paiton, Emmanuel T. Odeke, Ernest Grzybowski, Gavin Sherry, Gideon Dresdner, Gregory King, Harold Cooper, @heinzbeinz, Henry Saputra, Huarong Huo, Huazuo Gao, Igor Babuschkin, Igor Macedo Quintanilha, Ivan Ukhov, James Fysh, Jan Wilken Dörrie, Jihun Choi, Johnny Lim, Jonathan Raiman, Justin Francis, @lilac, Li Yi, Marc Khoury, Marco Marchesi, Max Melnick, Micael Carvalho, @mikowals, Mostafa Gazar, Nico Galoppo, Nishant Agrawal, Petr Janda, Yuncheng Li, @raix852, Robert Rose, @Robin-des-Bois, Rohit Girdhar, Sam Abrahams, satok16, Sergey Kishchenko, Sharkd Tu, @shotat, Siddharth Agrawal, Simon Denel, @sono-bfio, SunYeop Lee, Thijs Vogels, @tobegit3hub, @Undo1, Wang Yang, Wenjian Huang, Yaroslav Bulatov, Yuan Tang, Yunfeng Wang, Ziming Dong
|
||||||
|
|
||||||
|
We are also grateful to all who filed issues or helped resolve them, asked and
|
||||||
|
answered questions, and were part of inspiring discussions.
|
||||||
|
|
||||||
# Release 0.9.0
|
# Release 0.9.0
|
||||||
|
|
||||||
@ -55,7 +79,7 @@
|
|||||||
|
|
||||||
This release contains contributions from many people at Google, as well as:
|
This release contains contributions from many people at Google, as well as:
|
||||||
|
|
||||||
Aaron Schumacher, Aidan Dang, Akihiko ITOH, Aki Sukegawa, Arbit Chen, Aziz Alto, Danijar Hafner, Erik Erwitt, Fabrizio Milo, Felix Maximilian Möller, Henry Saputra, Sung Kim, Igor Babuschkin, Jan Zikes, Jeremy Barnes, Jesper Steen Møller, Johannes Mayer, Justin Harris, Kashif Rasul, Kevin Robinson, Loo Rong Jie, Lucas Moura, Łukasz Bieniasz-Krzywiec, Mario Cho, Maxim Grechkin, Michael Heilman, Mostafa Rahmani, Mourad Mourafiq, @ninotoshi, Orion Reblitz-Richardson, Yuncheng Li, @raoqiyu, Robert DiPietro, Sam Abrahams, Sebastian Raschka, Siddharth Agrawal, @snakecharmer1024, Stephen Roller, Sung Kim, SunYeop Lee, Thijs Vogels, Till Hoffmann, Victor Melo, Ville Kallioniemi, Waleed Abdulla, Wenjian Huang, Yaroslav Bulatov, Yeison Rodriguez, Yuan (Terry) Tang, Yuxin Wu, @zhongzyd, Ziming Dong, Zohar Jackson
|
Aaron Schumacher, Aidan Dang, Akihiko ITOH, Aki Sukegawa, Arbit Chen, Aziz Alto, Danijar Hafner, Erik Erwitt, Fabrizio Milo, Felix Maximilian Möller, Henry Saputra, Sung Kim, Igor Babuschkin, Jan Zikes, Jeremy Barnes, Jesper Steen Møller, Johannes Mayer, Justin Harris, Kashif Rasul, Kevin Robinson, Loo Rong Jie, Lucas Moura, Łukasz Bieniasz-Krzywiec, Mario Cho, Maxim Grechkin, Michael Heilman, Mostafa Rahmani, Mourad Mourafiq, @ninotoshi, Orion Reblitz-Richardson, Yuncheng Li, @raoqiyu, Robert DiPietro, Sam Abrahams, Sebastian Raschka, Siddharth Agrawal, @snakecharmer1024, Stephen Roller, Sung Kim, SunYeop Lee, Thijs Vogels, Till Hoffmann, Victor Melo, Ville Kallioniemi, Waleed Abdulla, Wenjian Huang, Yaroslav Bulatov, Yeison Rodriguez, Yuan Tang, Yuxin Wu, @zhongzyd, Ziming Dong, Zohar Jackson
|
||||||
|
|
||||||
We are also grateful to all who filed issues or helped resolve them, asked and
|
We are also grateful to all who filed issues or helped resolve them, asked and
|
||||||
answered questions, and were part of inspiring discussions.
|
answered questions, and were part of inspiring discussions.
|
||||||
@ -97,7 +121,7 @@ answered questions, and were part of inspiring discussions.
|
|||||||
|
|
||||||
This release contains contributions from many people at Google, as well as:
|
This release contains contributions from many people at Google, as well as:
|
||||||
|
|
||||||
Abhinav Upadhyay, Aggelos Avgerinos, Alan Wu, Alexander G. de G. Matthews, Aleksandr Yahnev, @amchercashin, Andy Kitchen, Aurelien Geron, Awni Hannun, @BanditCat, Bas Veeling, Cameron Chen, @cg31, Cheng-Lung Sung, Christopher Bonnett, Dan Becker, Dan Van Boxel, Daniel Golden, Danijar Hafner, Danny Goodman, Dave Decker, David Dao, David Kretch, Dongjoon Hyun, Dustin Dorroh, @e-lin, Eurico Doirado, Erik Erwitt, Fabrizio Milo, @gaohuazuo, Iblis Lin, Igor Babuschkin, Isaac Hodes, Isaac Turner, Iván Vallés, J Yegerlehner, Jack Zhang, James Wexler, Jan Zikes, Jay Young, Jeff Hodges, @jmtatsch, Johnny Lim, Jonas Meinertz Hansen, Kanit Wongsuphasawat, Kashif Rasul, Ken Shirriff, Kenneth Mitchner, Kenta Yonekura, Konrad Magnusson, Konstantin Lopuhin, @lahwran, @lekaha, @liyongsea, Lucas Adams, @makseq, Mandeep Singh, @manipopopo, Mark Amery, Memo Akten, Michael Heilman, Michael Peteuil, Nathan Daly, Nicolas Fauchereau, @ninotoshi, Olav Nymoen, @panmari, @papelita1234, Pedro Lopes, Pranav Sailesh Mani, RJ Ryan, Rob Culliton, Robert DiPietro, @ronrest, Sam Abrahams, Sarath Shekkizhar, Scott Graham, Sebastian Raschka, Sung Kim, Surya Bhupatiraju, Syed Ahmed, Till Hoffmann, @timsl, @urimend, @vesnica, Vlad Frolov, Vlad Zagorodniy, Wei-Ting Kuo, Wenjian Huang, William Dmitri Breaden Madden, Wladimir Schmidt, Yuwen Yan, Yuxin Wu, Yuya Kusakabe, @zhongzyd, @znah.
|
Abhinav Upadhyay, Aggelos Avgerinos, Alan Wu, Alexander G. de G. Matthews, Aleksandr Yahnev, @amchercashin, Andy Kitchen, Aurelien Geron, Awni Hannun, @BanditCat, Bas Veeling, Cameron Chen, @cg31, Cheng-Lung Sung, Christopher Bonnett, Dan Becker, Dan Van Boxel, Daniel Golden, Danijar Hafner, Danny Goodman, Dave Decker, David Dao, David Kretch, Dongjoon Hyun, Dustin Dorroh, @e-lin, Eurico Doirado, Erik Erwitt, Fabrizio Milo, @gaohuazuo, Iblis Lin, Igor Babuschkin, Isaac Hodes, Isaac Turner, Iván Vallés, J Yegerlehner, Jack Zhang, James Wexler, Jan Zikes, Jay Young, Jeff Hodges, @jmtatsch, Johnny Lim, Jonas Meinertz Hansen, Kanit Wongsuphasawat, Kashif Rasul, Ken Shirriff, Kenneth Mitchner, Kenta Yonekura, Konrad Magnusson, Konstantin Lopuhin, @lahwran, @lekaha, @liyongsea, Lucas Adams, @makseq, Mandeep Singh, @manipopopo, Mark Amery, Memo Akten, Michael Heilman, Michael Peteuil, Nathan Daly, Nicolas Fauchereau, @ninotoshi, Olav Nymoen, @panmari, @papelita1234, Pedro Lopes, Pranav Sailesh Mani, RJ Ryan, Rob Culliton, Robert DiPietro, @ronrest, Sam Abrahams, Sarath Shekkizhar, Scott Graham, Sebastian Raschka, Sung Kim, Surya Bhupatiraju, Syed Ahmed, Till Hoffmann, @timsl, @urimend, @vesnica, Vlad Frolov, Vlad Zagorodniy, Wei-Ting Kuo, Wenjian Huang, William Dmitri Breaden Madden, Wladimir Schmidt, Yuan Tang, Yuwen Yan, Yuxin Wu, Yuya Kusakabe, @zhongzyd, @znah.
|
||||||
|
|
||||||
We are also grateful to all who filed issues or helped resolve them, asked and
|
We are also grateful to all who filed issues or helped resolve them, asked and
|
||||||
answered questions, and were part of inspiring discussions.
|
answered questions, and were part of inspiring discussions.
|
||||||
|
@ -37,7 +37,10 @@ config_setting(
|
|||||||
|
|
||||||
package_group(
|
package_group(
|
||||||
name = "internal",
|
name = "internal",
|
||||||
packages = ["//tensorflow/..."],
|
packages = [
|
||||||
|
"//learning/vis/...",
|
||||||
|
"//tensorflow/...",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
sh_binary(
|
sh_binary(
|
||||||
@ -71,6 +74,7 @@ filegroup(
|
|||||||
name = "all_opensource_files",
|
name = "all_opensource_files",
|
||||||
data = [
|
data = [
|
||||||
":all_files",
|
":all_files",
|
||||||
|
"//tensorflow/c:all_files",
|
||||||
"//tensorflow/cc:all_files",
|
"//tensorflow/cc:all_files",
|
||||||
"//tensorflow/contrib:all_files",
|
"//tensorflow/contrib:all_files",
|
||||||
"//tensorflow/contrib/copy_graph:all_files",
|
"//tensorflow/contrib/copy_graph:all_files",
|
||||||
@ -103,6 +107,7 @@ filegroup(
|
|||||||
"//tensorflow/contrib/testing:all_files",
|
"//tensorflow/contrib/testing:all_files",
|
||||||
"//tensorflow/contrib/util:all_files",
|
"//tensorflow/contrib/util:all_files",
|
||||||
"//tensorflow/core:all_files",
|
"//tensorflow/core:all_files",
|
||||||
|
"//tensorflow/core/debug:all_files",
|
||||||
"//tensorflow/core/distributed_runtime:all_files",
|
"//tensorflow/core/distributed_runtime:all_files",
|
||||||
"//tensorflow/core/distributed_runtime/rpc:all_files",
|
"//tensorflow/core/distributed_runtime/rpc:all_files",
|
||||||
"//tensorflow/core/kernels:all_files",
|
"//tensorflow/core/kernels:all_files",
|
||||||
|
95
tensorflow/c/BUILD
Normal file
95
tensorflow/c/BUILD
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
# Description:
|
||||||
|
# C API for TensorFlow, for use by client language bindings.
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
load(
|
||||||
|
"//tensorflow:tensorflow.bzl",
|
||||||
|
"tf_cc_test",
|
||||||
|
"tf_cuda_library",
|
||||||
|
)
|
||||||
|
|
||||||
|
# For platform specific build config
|
||||||
|
load(
|
||||||
|
"//tensorflow/core:platform/default/build_config.bzl",
|
||||||
|
"tf_kernel_tests_linkstatic",
|
||||||
|
)
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Public targets
|
||||||
|
|
||||||
|
tf_cuda_library(
|
||||||
|
name = "c_api",
|
||||||
|
srcs = ["c_api.cc"],
|
||||||
|
hdrs = ["c_api.h"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:core_cpu",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cuda_library(
|
||||||
|
name = "tf_status_helper",
|
||||||
|
srcs = ["tf_status_helper.cc"],
|
||||||
|
hdrs = ["tf_status_helper.h"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":c_api",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cuda_library(
|
||||||
|
name = "checkpoint_reader",
|
||||||
|
srcs = ["checkpoint_reader.cc"],
|
||||||
|
hdrs = ["checkpoint_reader.h"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":tf_status_helper",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Tests
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "c_api_test",
|
||||||
|
size = "small",
|
||||||
|
linkopts = select({
|
||||||
|
"//tensorflow:darwin": ["-headerpad_max_install_names"],
|
||||||
|
"//conditions:default": [],
|
||||||
|
}),
|
||||||
|
linkstatic = tf_kernel_tests_linkstatic(),
|
||||||
|
deps = [
|
||||||
|
":c_api",
|
||||||
|
"//tensorflow/core:direct_session",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:proto_text",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core:testlib",
|
||||||
|
"//tensorflow/core/kernels:array",
|
||||||
|
"//tensorflow/core/kernels:math",
|
||||||
|
"//third_party/eigen3",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Google-internal targets.
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "all_files",
|
||||||
|
srcs = glob(
|
||||||
|
["**/*"],
|
||||||
|
exclude = [
|
||||||
|
"**/METADATA",
|
||||||
|
"**/OWNERS",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
visibility = ["//tensorflow:__subpackages__"],
|
||||||
|
)
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/public/tensor_c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@ -482,7 +482,6 @@ static void TF_Run_Helper(
|
|||||||
result = session->PRun(handle, input_pairs, output_tensor_names, &outputs);
|
result = session->PRun(handle, input_pairs, output_tensor_names, &outputs);
|
||||||
}
|
}
|
||||||
if (!result.ok()) {
|
if (!result.ok()) {
|
||||||
LOG(ERROR) << result.error_message();
|
|
||||||
status->status = result;
|
status->status = result;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
@ -13,9 +13,8 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
// TODO(jeff,sanjay): Rename to tensorflow/public/c_api.h
|
#ifndef TENSORFLOW_C_C_API_H_
|
||||||
#ifndef TENSORFLOW_PUBLIC_TENSOR_C_API_H_
|
#define TENSORFLOW_C_C_API_H_
|
||||||
#define TENSORFLOW_PUBLIC_TENSOR_C_API_H_
|
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
@ -699,4 +698,4 @@ extern TF_Buffer TF_GetOpList(TF_Library* lib_handle);
|
|||||||
} /* end extern "C" */
|
} /* end extern "C" */
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#endif // TENSORFLOW_PUBLIC_TENSOR_C_API_H_
|
#endif // TENSORFLOW_C_C_API_H_
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/public/tensor_c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "tensorflow/core/framework/graph.pb_text.h"
|
#include "tensorflow/core/framework/graph.pb_text.h"
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/util/checkpoint_reader.h"
|
#include "tensorflow/c/checkpoint_reader.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef TENSORFLOW_CORE_UTIL_CHECKPOINT_READER_H
|
#ifndef TENSORFLOW_C_CHECKPOINT_READER_H
|
||||||
#define TENSORFLOW_CORE_UTIL_CHECKPOINT_READER_H
|
#define TENSORFLOW_C_CHECKPOINT_READER_H
|
||||||
|
|
||||||
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
#include "tensorflow/core/util/tensor_slice_reader.h"
|
#include "tensorflow/core/util/tensor_slice_reader.h"
|
||||||
#include "tensorflow/core/util/tf_status_helper.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -60,4 +60,4 @@ class CheckpointReader {
|
|||||||
} // namespace checkpoint
|
} // namespace checkpoint
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_UTIL_CHECKPOINT_READER_H
|
#endif // TENSORFLOW_C_CHECKPOINT_READER_H
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/util/tf_status_helper.h"
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef TENSORFLOW_CORE_UTIL_TF_STATUS_HELPER_H
|
#ifndef TENSORFLOW_C_TF_STATUS_HELPER_H
|
||||||
#define TENSORFLOW_CORE_UTIL_TF_STATUS_HELPER_H
|
#define TENSORFLOW_C_TF_STATUS_HELPER_H
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/public/tensor_c_api.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -26,4 +26,4 @@ void Set_TF_Status_from_Status(TF_Status* tf_status, const Status& status);
|
|||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_UTIL_TF_STATUS_HELPER_H
|
#endif // TENSORFLOW_C_TF_STATUS_HELPER_H
|
@ -73,6 +73,46 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "grad_op_registry",
|
||||||
|
srcs = ["framework/grad_op_registry.cc"],
|
||||||
|
hdrs = ["framework/grad_op_registry.h"],
|
||||||
|
deps = [
|
||||||
|
":ops",
|
||||||
|
":scope",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "math_grad",
|
||||||
|
srcs = ["gradients/math_grad.cc"],
|
||||||
|
deps = [
|
||||||
|
":cc_ops",
|
||||||
|
":grad_op_registry",
|
||||||
|
":ops",
|
||||||
|
":scope",
|
||||||
|
"//tensorflow/core:core_cpu",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "gradients/math_grad_test",
|
||||||
|
deps = [
|
||||||
|
":cc_ops",
|
||||||
|
":grad_op_registry",
|
||||||
|
":math_grad",
|
||||||
|
"//tensorflow/core:all_kernels",
|
||||||
|
"//tensorflow/core:core_cpu",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
|
"//tensorflow/core:tensorflow",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core:testlib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_gen_op_wrappers_cc(
|
tf_gen_op_wrappers_cc(
|
||||||
name = "cc_ops",
|
name = "cc_ops",
|
||||||
op_lib_names = [
|
op_lib_names = [
|
||||||
|
42
tensorflow/cc/framework/grad_op_registry.cc
Normal file
42
tensorflow/cc/framework/grad_op_registry.cc
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/framework/grad_op_registry.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
// static
|
||||||
|
GradOpRegistry* GradOpRegistry::Global() {
|
||||||
|
static GradOpRegistry* grad_op_registry = new GradOpRegistry;
|
||||||
|
return grad_op_registry;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GradOpRegistry::Register(const string& op, GradFunc func) {
|
||||||
|
CHECK(registry_.insert({op, func}).second) << "Existing gradient for " << op;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GradOpRegistry::Lookup(const string& op, GradFunc* func) {
|
||||||
|
auto iter = registry_.find(op);
|
||||||
|
if (iter == registry_.end()) {
|
||||||
|
return errors::NotFound("No gradient defined for op: ", op);
|
||||||
|
}
|
||||||
|
*func = iter->second;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // end namespace ops
|
||||||
|
} // namespace tensorflow
|
75
tensorflow/cc/framework/grad_op_registry.h
Normal file
75
tensorflow/cc/framework/grad_op_registry.h
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
|
||||||
|
#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
|
||||||
|
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "tensorflow/cc/framework/ops.h"
|
||||||
|
#include "tensorflow/cc/framework/scope.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
// GradFunc is the signature for all gradient functions in GradOpRegistry.
|
||||||
|
// Implementations should add operations to compute the gradient outputs of 'op'
|
||||||
|
// (returned in 'grad_outputs') using 'scope' and 'grad_inputs'.
|
||||||
|
typedef Status (*GradFunc)(const Scope& scope, const Operation& op,
|
||||||
|
const std::vector<Output>& grad_inputs,
|
||||||
|
std::vector<Output>* grad_outputs);
|
||||||
|
|
||||||
|
// GradOpRegistry maintains a static registry of gradient functions.
|
||||||
|
// Gradient functions are indexed in the registry by the forward op name (i.e.
|
||||||
|
// "MatMul" -> MatMulGrad func).
|
||||||
|
class GradOpRegistry {
|
||||||
|
public:
|
||||||
|
// Registers 'func' as the the gradient function for 'op'.
|
||||||
|
// Returns true if registration was succesful, check fails otherwise.
|
||||||
|
bool Register(const string& op, GradFunc func);
|
||||||
|
|
||||||
|
// Sets 'func' to the gradient function for 'op' and returns Status OK if
|
||||||
|
// the gradient function for 'op' exists in the registry.
|
||||||
|
// Note that 'func' can be null for ops that have registered no-gradient with
|
||||||
|
// the registry.
|
||||||
|
// Returns error status otherwise.
|
||||||
|
Status Lookup(const string& op, GradFunc* func);
|
||||||
|
|
||||||
|
// Returns a pointer to the global gradient function registry.
|
||||||
|
static GradOpRegistry* Global();
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unordered_map<string, GradFunc> registry_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ops
|
||||||
|
|
||||||
|
// Macros used to define gradient functions for ops.
|
||||||
|
#define REGISTER_GRADIENT_OP(name, fn) \
|
||||||
|
REGISTER_GRADIENT_OP_UNIQ_HELPER(__COUNTER__, name, fn)
|
||||||
|
|
||||||
|
#define REGISTER_NO_GRADIENT_OP(name) \
|
||||||
|
REGISTER_GRADIENT_OP_UNIQ_HELPER(__COUNTER__, name, nullptr)
|
||||||
|
|
||||||
|
#define REGISTER_GRADIENT_OP_UNIQ_HELPER(ctr, name, fn) \
|
||||||
|
REGISTER_GRADIENT_OP_UNIQ(ctr, name, fn)
|
||||||
|
|
||||||
|
#define REGISTER_GRADIENT_OP_UNIQ(ctr, name, fn) \
|
||||||
|
static bool unused_ret_val_##ctr = \
|
||||||
|
::tensorflow::ops::GradOpRegistry::Global()->Register(name, fn)
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
|
@ -18,6 +18,44 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
|
Operation::Operation(Node* n) : inputs_(GetInputs(n)), node_(n) {}
|
||||||
|
|
||||||
|
Output Operation::input(int i) const {
|
||||||
|
CHECK_NOTNULL(node_);
|
||||||
|
CHECK_GE(i, 0);
|
||||||
|
CHECK_LT(i, node_->num_inputs());
|
||||||
|
// Handle the case where the input was unknown at the time this
|
||||||
|
// Operation was constructed.
|
||||||
|
if (inputs_[i].first == nullptr && inputs_[i].second == -1) {
|
||||||
|
for (const Edge* e : node_->in_edges()) {
|
||||||
|
if (e->IsControlEdge()) continue;
|
||||||
|
if (e->dst_input() == i) {
|
||||||
|
return Output(e->src(), e->src_output());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Output(inputs_[i].first, inputs_[i].second);
|
||||||
|
}
|
||||||
|
|
||||||
|
Output Operation::output(int i) const {
|
||||||
|
CHECK_NOTNULL(node_);
|
||||||
|
CHECK_GE(i, 0);
|
||||||
|
CHECK_LT(i, node_->num_outputs());
|
||||||
|
return Output(node_, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation::Inputs Operation::GetInputs(Node* node) {
|
||||||
|
Operation::Inputs inputs;
|
||||||
|
if (node != nullptr) {
|
||||||
|
inputs.resize(node->num_inputs(), {nullptr, -1});
|
||||||
|
for (const Edge* e : node->in_edges()) {
|
||||||
|
if (e->IsControlEdge()) continue;
|
||||||
|
inputs[e->dst_input()] = std::make_pair(e->src(), e->src_output());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return inputs;
|
||||||
|
}
|
||||||
|
|
||||||
Input::Initializer::Initializer(
|
Input::Initializer::Initializer(
|
||||||
const std::initializer_list<Input::Initializer>& v) {
|
const std::initializer_list<Input::Initializer>& v) {
|
||||||
if (v.size() < 1) {
|
if (v.size() < 1) {
|
||||||
|
@ -27,17 +27,29 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
|
class Output;
|
||||||
|
|
||||||
// Represents a node in the computation graph.
|
// Represents a node in the computation graph.
|
||||||
class Operation {
|
class Operation {
|
||||||
public:
|
public:
|
||||||
Operation() : node_(nullptr) {}
|
Operation() : node_(nullptr) {}
|
||||||
explicit Operation(Node* n) : node_(n) {}
|
explicit Operation(Node* n);
|
||||||
|
|
||||||
|
int num_inputs() const { return node_->num_inputs(); }
|
||||||
|
DataType input_type(int o) const { return node_->input_type(o); }
|
||||||
|
Output input(int i) const;
|
||||||
|
|
||||||
int num_outputs() const { return node_->num_outputs(); }
|
int num_outputs() const { return node_->num_outputs(); }
|
||||||
DataType output_type(int o) const { return node_->output_type(o); }
|
DataType output_type(int o) const { return node_->output_type(o); }
|
||||||
|
Output output(int i) const;
|
||||||
|
|
||||||
Node* node() const { return node_; }
|
Node* node() const { return node_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
typedef std::vector<std::pair<Node*, int64>> Inputs;
|
||||||
|
static Inputs GetInputs(Node* node);
|
||||||
|
|
||||||
|
Inputs inputs_;
|
||||||
Node* node_;
|
Node* node_;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -81,7 +93,7 @@ class Input {
|
|||||||
tensor = t;
|
tensor = t;
|
||||||
}
|
}
|
||||||
|
|
||||||
explicit Initializer(const Tensor& t) : tensor(t) {}
|
Initializer(const Tensor& t) : tensor(t) {} // NOLINT(runtime/explicit)
|
||||||
|
|
||||||
// Construct from a scalar value and an explicit shape
|
// Construct from a scalar value and an explicit shape
|
||||||
template <typename T, typename = typename std::enable_if<
|
template <typename T, typename = typename std::enable_if<
|
||||||
|
91
tensorflow/cc/gradients/math_grad.cc
Normal file
91
tensorflow/cc/gradients/math_grad.cc
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
|
||||||
|
#include "tensorflow/cc/framework/grad_op_registry.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace ops {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// TODO(andydavis) Move this to a more appropriate file.
|
||||||
|
REGISTER_NO_GRADIENT_OP("Const");
|
||||||
|
|
||||||
|
// MatMulGrad helper function used to compute two MatMul operations
|
||||||
|
// based on input matrix transposition combinations.
|
||||||
|
Status MatMulGradHelper(const Scope& scope, const Output& x0, const bool adj_x0,
|
||||||
|
const Output& x1, const bool adj_x1, const Output& y0,
|
||||||
|
const bool adj_y0, const Output& y1, const bool adj_y1,
|
||||||
|
std::vector<Output>* grad_outputs) {
|
||||||
|
auto dx =
|
||||||
|
MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
|
||||||
|
grad_outputs->push_back(dx);
|
||||||
|
auto dy =
|
||||||
|
MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
|
||||||
|
grad_outputs->push_back(dy);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// MatMulGrad common used to read and check node attr state, and determine
|
||||||
|
// proper MatMul products for gradients based on input matrix transposition
|
||||||
|
// combinations.
|
||||||
|
// TODO(andydavis) Re-use this function for BatchMatMulGrad.
|
||||||
|
Status MatMulGradCommon(const Scope& scope, const Operation& op,
|
||||||
|
const std::vector<Output>& grad_inputs,
|
||||||
|
const string& attr_adj_x, const string& attr_adj_y,
|
||||||
|
std::vector<Output>* grad_outputs) {
|
||||||
|
DataType dtype;
|
||||||
|
TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), "T", &dtype));
|
||||||
|
if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
|
||||||
|
return errors::Unimplemented(
|
||||||
|
"MatMul gradient for complex data type is not supported yet.");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ta;
|
||||||
|
bool tb;
|
||||||
|
TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_x, &ta));
|
||||||
|
TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_y, &tb));
|
||||||
|
|
||||||
|
if (!ta && !tb) {
|
||||||
|
return MatMulGradHelper(scope, grad_inputs[0], false, op.input(1), true,
|
||||||
|
op.input(0), true, grad_inputs[0], false,
|
||||||
|
grad_outputs);
|
||||||
|
} else if (!ta && tb) {
|
||||||
|
return MatMulGradHelper(scope, grad_inputs[0], false, op.input(1), false,
|
||||||
|
grad_inputs[0], true, op.input(0), false,
|
||||||
|
grad_outputs);
|
||||||
|
} else if (ta && !tb) {
|
||||||
|
return MatMulGradHelper(scope, op.input(1), false, grad_inputs[0], true,
|
||||||
|
op.input(0), false, grad_inputs[0], false,
|
||||||
|
grad_outputs);
|
||||||
|
}
|
||||||
|
return MatMulGradHelper(scope, op.input(1), true, grad_inputs[0], true,
|
||||||
|
grad_inputs[0], true, op.input(0), true,
|
||||||
|
grad_outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status MatMulGrad(const Scope& scope, const Operation& op,
|
||||||
|
const std::vector<Output>& grad_inputs,
|
||||||
|
std::vector<Output>* grad_outputs) {
|
||||||
|
return MatMulGradCommon(scope, op, grad_inputs, "transpose_a", "transpose_b",
|
||||||
|
grad_outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);
|
||||||
|
|
||||||
|
} // anonymous namespace
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace tensorflow
|
183
tensorflow/cc/gradients/math_grad_test.cc
Normal file
183
tensorflow/cc/gradients/math_grad_test.cc
Normal file
@ -0,0 +1,183 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/framework/grad_op_registry.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/graph/default_device.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/lib/random/random.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
using namespace ops; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// TODO(andydavis) Test gradient function against numeric gradients output.
|
||||||
|
// TODO(andydavis) As more gradients are added move common test functions
|
||||||
|
// to a testutil library.
|
||||||
|
class MathGradTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
MathGradTest() : root_(Scope::NewRootScope()) {}
|
||||||
|
|
||||||
|
void ComputeMatMulGrad(const Output& x, const bool t_x, const Output& y,
|
||||||
|
const bool t_y, const Output& dz,
|
||||||
|
std::vector<Tensor>* out) {
|
||||||
|
// Compute forward MatMul: z = MatMul(x, y).
|
||||||
|
auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
|
||||||
|
TF_EXPECT_OK(root_.status());
|
||||||
|
CHECK_NOTNULL(z.node());
|
||||||
|
std::vector<Output> grad_outputs;
|
||||||
|
// Call MatMulGrad which populates 'grad_outputs'.
|
||||||
|
CallGradFunction(Operation(z.node()), {dz}, &grad_outputs);
|
||||||
|
EXPECT_EQ(2, grad_outputs.size());
|
||||||
|
// Run graph and return MatMul gradient tensors for 'dx' and 'dy' in 'out'.
|
||||||
|
GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CallGradFunction(const Operation& op,
|
||||||
|
const std::vector<Output>& grad_inputs,
|
||||||
|
std::vector<Output>* grad_outputs) {
|
||||||
|
GradFunc grad_fn;
|
||||||
|
TF_EXPECT_OK(GradOpRegistry::Global()->Lookup(op.node()->name(), &grad_fn));
|
||||||
|
TF_EXPECT_OK(grad_fn(root_, op, grad_inputs, grad_outputs));
|
||||||
|
TF_EXPECT_OK(root_.status());
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor ComputeMatMul(const Output& x, const bool t_x, const Output& y,
|
||||||
|
const bool t_y) {
|
||||||
|
auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
|
||||||
|
TF_EXPECT_OK(root_.status());
|
||||||
|
Tensor out;
|
||||||
|
GetTensor(root_, z, &out);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
void RandMatMulGradData(const bool tx, const bool ty,
|
||||||
|
std::vector<Tensor>* data) {
|
||||||
|
// z = MatMul(x, y)
|
||||||
|
const int m = Rand();
|
||||||
|
const int k = Rand();
|
||||||
|
const int n = Rand();
|
||||||
|
// x.shape = [m, k]
|
||||||
|
const TensorShape x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
|
||||||
|
data->emplace_back(DT_FLOAT, x_shape);
|
||||||
|
RandTensor(&data->back());
|
||||||
|
// y.shape = [k, n]
|
||||||
|
const TensorShape y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
|
||||||
|
data->emplace_back(DT_FLOAT, y_shape);
|
||||||
|
RandTensor(&data->back());
|
||||||
|
// z.shape = [m, n]
|
||||||
|
data->emplace_back(DT_FLOAT, TensorShape({m, n}));
|
||||||
|
RandTensor(&data->back());
|
||||||
|
}
|
||||||
|
|
||||||
|
void RandTensor(Tensor* t) {
|
||||||
|
test::FillFn<float>(
|
||||||
|
t, [this](const int i) { return static_cast<float>(Rand()); });
|
||||||
|
}
|
||||||
|
|
||||||
|
int Rand() { return 1 + (random::New64() % 10); }
|
||||||
|
|
||||||
|
// TODO(andydavis) Move 'GetTensors/GetTensor' to some testutil class.
|
||||||
|
// Note: they should be moved to a general/non-grad specific testutil class.
|
||||||
|
void GetTensors(const Scope& scope, OutputList tensors,
|
||||||
|
std::vector<Tensor>* out) {
|
||||||
|
SessionOptions options;
|
||||||
|
std::unique_ptr<Session> session(NewSession(options));
|
||||||
|
GraphDef def;
|
||||||
|
scope.graph()->ToGraphDef(&def);
|
||||||
|
|
||||||
|
graph::SetDefaultDevice("/cpu:0", &def);
|
||||||
|
|
||||||
|
TF_CHECK_OK(session->Create(def));
|
||||||
|
std::vector<string> names;
|
||||||
|
for (const auto& t : tensors) {
|
||||||
|
names.push_back(strings::StrCat(t.node()->name(), ":", t.index()));
|
||||||
|
}
|
||||||
|
TF_CHECK_OK(session->Run({}, names, {}, out));
|
||||||
|
TF_CHECK_OK(session->Close());
|
||||||
|
}
|
||||||
|
|
||||||
|
void GetTensor(const Scope& scope, Output tensor, Tensor* out) {
|
||||||
|
std::vector<Tensor> outputs;
|
||||||
|
GetTensors(scope, {tensor}, &outputs);
|
||||||
|
*out = outputs[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
Scope root_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
|
||||||
|
std::vector<Tensor> data;
|
||||||
|
RandMatMulGradData(false, false, &data);
|
||||||
|
auto x = Const(root_, data[0]);
|
||||||
|
auto y = Const(root_, data[1]);
|
||||||
|
auto dz = Const(root_, data[2]);
|
||||||
|
|
||||||
|
std::vector<Tensor> grad_outputs;
|
||||||
|
ComputeMatMulGrad(x, false, y, false, dz, &grad_outputs);
|
||||||
|
|
||||||
|
test::ExpectClose(grad_outputs[0], ComputeMatMul(dz, false, y, true));
|
||||||
|
test::ExpectClose(grad_outputs[1], ComputeMatMul(x, true, dz, false));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MathGradTest, MatMulGrad_TransposeX) {
|
||||||
|
std::vector<Tensor> data;
|
||||||
|
RandMatMulGradData(true, false, &data);
|
||||||
|
auto x = Const(root_, data[0]);
|
||||||
|
auto y = Const(root_, data[1]);
|
||||||
|
auto dz = Const(root_, data[2]);
|
||||||
|
|
||||||
|
std::vector<Tensor> grad_outputs;
|
||||||
|
ComputeMatMulGrad(x, true, y, false, dz, &grad_outputs);
|
||||||
|
|
||||||
|
test::ExpectClose(grad_outputs[0], ComputeMatMul(y, false, dz, true));
|
||||||
|
test::ExpectClose(grad_outputs[1], ComputeMatMul(x, false, dz, false));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MathGradTest, MatMulGrad_TransposeY) {
|
||||||
|
std::vector<Tensor> data;
|
||||||
|
RandMatMulGradData(false, true, &data);
|
||||||
|
auto x = Const(root_, data[0]);
|
||||||
|
auto y = Const(root_, data[1]);
|
||||||
|
auto dz = Const(root_, data[2]);
|
||||||
|
|
||||||
|
std::vector<Tensor> grad_outputs;
|
||||||
|
ComputeMatMulGrad(x, false, y, true, dz, &grad_outputs);
|
||||||
|
|
||||||
|
test::ExpectClose(grad_outputs[0], ComputeMatMul(dz, false, y, false));
|
||||||
|
test::ExpectClose(grad_outputs[1], ComputeMatMul(dz, true, x, false));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
|
||||||
|
std::vector<Tensor> data;
|
||||||
|
RandMatMulGradData(true, true, &data);
|
||||||
|
auto x = Const(root_, data[0]);
|
||||||
|
auto y = Const(root_, data[1]);
|
||||||
|
auto dz = Const(root_, data[2]);
|
||||||
|
|
||||||
|
std::vector<Tensor> grad_outputs;
|
||||||
|
ComputeMatMulGrad(x, true, y, true, dz, &grad_outputs);
|
||||||
|
|
||||||
|
test::ExpectClose(grad_outputs[0], ComputeMatMul(y, true, dz, true));
|
||||||
|
test::ExpectClose(grad_outputs[1], ComputeMatMul(dz, true, x, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
@ -39,6 +39,7 @@ set (DOWNLOAD_LOCATION "${CMAKE_CURRENT_BINARY_DIR}/downloads"
|
|||||||
mark_as_advanced(DOWNLOAD_LOCATION)
|
mark_as_advanced(DOWNLOAD_LOCATION)
|
||||||
|
|
||||||
# External dependencies
|
# External dependencies
|
||||||
|
include(gif)
|
||||||
include(png)
|
include(png)
|
||||||
include(jpeg)
|
include(jpeg)
|
||||||
include(re2)
|
include(re2)
|
||||||
|
38
tensorflow/contrib/cmake/external/gif.cmake
vendored
Normal file
38
tensorflow/contrib/cmake/external/gif.cmake
vendored
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
include (ExternalProject)
|
||||||
|
|
||||||
|
set(gif_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/gif_archive)
|
||||||
|
set(gif_URL http://ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz)
|
||||||
|
set(gif_HASH SHA256=34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1)
|
||||||
|
set(gif_INSTALL ${CMAKE_BINARY_DIR}/gif/install)
|
||||||
|
set(gif_STATIC_LIBRARIES ${gif_INSTALL}/lib/libgif.a)
|
||||||
|
|
||||||
|
set(gif_HEADERS
|
||||||
|
"${gif_INSTALL}/include/gif_lib.h"
|
||||||
|
)
|
||||||
|
|
||||||
|
ExternalProject_Add(gif
|
||||||
|
PREFIX gif
|
||||||
|
URL ${gif_URL}
|
||||||
|
URL_HASH ${gif_HASH}
|
||||||
|
INSTALL_DIR ${gif_INSTALL}
|
||||||
|
DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
|
||||||
|
BUILD_COMMAND $(MAKE)
|
||||||
|
INSTALL_COMMAND $(MAKE) install
|
||||||
|
CONFIGURE_COMMAND
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR}/gif/src/gif/configure
|
||||||
|
--prefix=${gif_INSTALL}
|
||||||
|
--enable-shared=yes
|
||||||
|
)
|
||||||
|
|
||||||
|
# put gif includes in the directory where they are expected
|
||||||
|
add_custom_target(gif_create_destination_dir
|
||||||
|
COMMAND ${CMAKE_COMMAND} -E make_directory ${gif_INCLUDE_DIR}/giflib-5.1.4/lib
|
||||||
|
DEPENDS gif)
|
||||||
|
|
||||||
|
add_custom_target(gif_copy_headers_to_destination
|
||||||
|
DEPENDS gif_create_destination_dir)
|
||||||
|
|
||||||
|
foreach(header_file ${gif_HEADERS})
|
||||||
|
add_custom_command(TARGET gif_copy_headers_to_destination PRE_BUILD
|
||||||
|
COMMAND ${CMAKE_COMMAND} -E copy ${header_file} ${gif_INCLUDE_DIR}/giflib-5.1.4/lib/)
|
||||||
|
endforeach()
|
@ -1,10 +1,39 @@
|
|||||||
|
########################################################
|
||||||
|
# tf_cc_framework library
|
||||||
|
########################################################
|
||||||
|
set(tf_cc_framework_srcs
|
||||||
|
"${tensorflow_source_dir}/tensorflow/cc/framework/ops.h"
|
||||||
|
"${tensorflow_source_dir}/tensorflow/cc/framework/ops.cc"
|
||||||
|
"${tensorflow_source_dir}/tensorflow/cc/framework/scope.h"
|
||||||
|
"${tensorflow_source_dir}/tensorflow/cc/framework/scope.cc"
|
||||||
|
)
|
||||||
|
|
||||||
|
add_library(tf_cc_framework OBJECT ${tf_cc_framework_srcs})
|
||||||
|
|
||||||
|
add_dependencies(tf_cc_framework tf_core_framework)
|
||||||
|
|
||||||
|
target_include_directories(tf_cc_framework PRIVATE
|
||||||
|
${tensorflow_source_dir}
|
||||||
|
${eigen_INCLUDE_DIRS}
|
||||||
|
)
|
||||||
|
|
||||||
|
target_compile_options(tf_cc_framework PRIVATE
|
||||||
|
-fno-exceptions
|
||||||
|
-DEIGEN_AVOID_STL_ARRAY
|
||||||
|
)
|
||||||
|
|
||||||
|
# C++11
|
||||||
|
target_compile_features(tf_cc_framework PRIVATE
|
||||||
|
cxx_rvalue_references
|
||||||
|
)
|
||||||
|
|
||||||
########################################################
|
########################################################
|
||||||
# tf_cc_op_gen_main library
|
# tf_cc_op_gen_main library
|
||||||
########################################################
|
########################################################
|
||||||
set(tf_cc_op_gen_main_srcs
|
set(tf_cc_op_gen_main_srcs
|
||||||
"${tensorflow_source_dir}/tensorflow/cc/ops/cc_op_gen.cc"
|
"${tensorflow_source_dir}/tensorflow/cc/framework/cc_op_gen.cc"
|
||||||
"${tensorflow_source_dir}/tensorflow/cc/ops/cc_op_gen_main.cc"
|
"${tensorflow_source_dir}/tensorflow/cc/framework/cc_op_gen_main.cc"
|
||||||
"${tensorflow_source_dir}/tensorflow/cc/ops/cc_op_gen.h"
|
"${tensorflow_source_dir}/tensorflow/cc/framework/cc_op_gen.h"
|
||||||
)
|
)
|
||||||
|
|
||||||
add_library(tf_cc_op_gen_main OBJECT ${tf_cc_op_gen_main_srcs})
|
add_library(tf_cc_op_gen_main OBJECT ${tf_cc_op_gen_main_srcs})
|
||||||
@ -120,6 +149,7 @@ foreach(tf_cc_op_lib_name ${tf_cc_op_lib_names})
|
|||||||
${PROTOBUF_LIBRARIES}
|
${PROTOBUF_LIBRARIES}
|
||||||
tf_protos_cc
|
tf_protos_cc
|
||||||
re2_lib
|
re2_lib
|
||||||
|
${gif_STATIC_LIBRARIES}
|
||||||
${jpeg_STATIC_LIBRARIES}
|
${jpeg_STATIC_LIBRARIES}
|
||||||
${png_STATIC_LIBRARIES}
|
${png_STATIC_LIBRARIES}
|
||||||
${ZLIB_LIBRARIES}
|
${ZLIB_LIBRARIES}
|
||||||
|
@ -4,8 +4,17 @@
|
|||||||
file(GLOB tf_core_direct_session_srcs
|
file(GLOB tf_core_direct_session_srcs
|
||||||
"${tensorflow_source_dir}/tensorflow/core/common_runtime/direct_session.cc"
|
"${tensorflow_source_dir}/tensorflow/core/common_runtime/direct_session.cc"
|
||||||
"${tensorflow_source_dir}/tensorflow/core/common_runtime/direct_session.h"
|
"${tensorflow_source_dir}/tensorflow/core/common_runtime/direct_session.h"
|
||||||
|
"${tensorflow_source_dir}/tensorflow/core/debug/*.h"
|
||||||
|
"${tensorflow_source_dir}/tensorflow/core/debug/*.cc"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
file(GLOB_RECURSE tf_core_direct_session_test_srcs
|
||||||
|
"${tensorflow_source_dir}/tensorflow/core/debug/*test*.h"
|
||||||
|
"${tensorflow_source_dir}/tensorflow/core/debug/*test*.cc"
|
||||||
|
)
|
||||||
|
|
||||||
|
list(REMOVE_ITEM tf_core_direct_session_srcs ${tf_core_direct_session_test_srcs})
|
||||||
|
|
||||||
add_library(tf_core_direct_session OBJECT ${tf_core_direct_session_srcs})
|
add_library(tf_core_direct_session OBJECT ${tf_core_direct_session_srcs})
|
||||||
|
|
||||||
add_dependencies(tf_core_direct_session tf_core_cpu)
|
add_dependencies(tf_core_direct_session tf_core_cpu)
|
||||||
|
@ -150,6 +150,7 @@ list(REMOVE_ITEM tf_core_lib_srcs ${tf_core_lib_test_srcs})
|
|||||||
add_library(tf_core_lib OBJECT ${tf_core_lib_srcs})
|
add_library(tf_core_lib OBJECT ${tf_core_lib_srcs})
|
||||||
target_include_directories(tf_core_lib PUBLIC
|
target_include_directories(tf_core_lib PUBLIC
|
||||||
${tensorflow_source_dir}
|
${tensorflow_source_dir}
|
||||||
|
${gif_INCLUDE_DIR}
|
||||||
${jpeg_INCLUDE_DIR}
|
${jpeg_INCLUDE_DIR}
|
||||||
${png_INCLUDE_DIR}
|
${png_INCLUDE_DIR}
|
||||||
${eigen_INCLUDE_DIRS}
|
${eigen_INCLUDE_DIRS}
|
||||||
@ -168,6 +169,7 @@ target_compile_features(tf_core_lib PRIVATE
|
|||||||
)
|
)
|
||||||
|
|
||||||
add_dependencies(tf_core_lib
|
add_dependencies(tf_core_lib
|
||||||
|
gif_copy_headers_to_destination
|
||||||
jpeg_copy_headers_to_destination
|
jpeg_copy_headers_to_destination
|
||||||
png_copy_headers_to_destination
|
png_copy_headers_to_destination
|
||||||
re2_copy_headers_to_destination
|
re2_copy_headers_to_destination
|
||||||
|
@ -71,7 +71,7 @@ target_include_directories(tf_models_word2vec_kernels PRIVATE
|
|||||||
${re2_INCLUDES}
|
${re2_INCLUDES}
|
||||||
)
|
)
|
||||||
|
|
||||||
add_dependencies(tf_models_word2vec_ops
|
add_dependencies(tf_models_word2vec_kernels
|
||||||
tf_core_cpu
|
tf_core_cpu
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -22,6 +22,7 @@ target_link_libraries(${proto_text} PUBLIC
|
|||||||
${PROTOBUF_LIBRARIES}
|
${PROTOBUF_LIBRARIES}
|
||||||
# tf_protos_cc
|
# tf_protos_cc
|
||||||
# re2_lib
|
# re2_lib
|
||||||
|
${gif_STATIC_LIBRARIES}
|
||||||
${jpeg_STATIC_LIBRARIES}
|
${jpeg_STATIC_LIBRARIES}
|
||||||
${png_STATIC_LIBRARIES}
|
${png_STATIC_LIBRARIES}
|
||||||
${ZLIB_LIBRARIES}
|
${ZLIB_LIBRARIES}
|
||||||
|
@ -23,6 +23,7 @@ add_executable(tf_tutorials_example_trainer
|
|||||||
$<TARGET_OBJECTS:tf_core_cpu>
|
$<TARGET_OBJECTS:tf_core_cpu>
|
||||||
$<TARGET_OBJECTS:tf_core_framework>
|
$<TARGET_OBJECTS:tf_core_framework>
|
||||||
$<TARGET_OBJECTS:tf_core_kernels>
|
$<TARGET_OBJECTS:tf_core_kernels>
|
||||||
|
$<TARGET_OBJECTS:tf_cc_framework>
|
||||||
$<TARGET_OBJECTS:tf_cc_ops>
|
$<TARGET_OBJECTS:tf_cc_ops>
|
||||||
$<TARGET_OBJECTS:tf_core_ops>
|
$<TARGET_OBJECTS:tf_core_ops>
|
||||||
$<TARGET_OBJECTS:tf_core_direct_session>
|
$<TARGET_OBJECTS:tf_core_direct_session>
|
||||||
@ -40,6 +41,7 @@ target_link_libraries(tf_tutorials_example_trainer PUBLIC
|
|||||||
re2_lib
|
re2_lib
|
||||||
${boringssl_STATIC_LIBRARIES}
|
${boringssl_STATIC_LIBRARIES}
|
||||||
${farmhash_STATIC_LIBRARIES}
|
${farmhash_STATIC_LIBRARIES}
|
||||||
|
${gif_STATIC_LIBRARIES}
|
||||||
${jpeg_STATIC_LIBRARIES}
|
${jpeg_STATIC_LIBRARIES}
|
||||||
${jsoncpp_STATIC_LIBRARIES}
|
${jsoncpp_STATIC_LIBRARIES}
|
||||||
${png_STATIC_LIBRARIES}
|
${png_STATIC_LIBRARIES}
|
||||||
|
@ -54,6 +54,29 @@ cuda_py_tests(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cuda_py_tests(
|
||||||
|
name = "operator_pd_identity_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["python/kernel_tests/operator_pd_identity_test.py"],
|
||||||
|
additional_deps = [
|
||||||
|
":distributions_py",
|
||||||
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"//tensorflow/python:platform_test",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cuda_py_tests(
|
||||||
|
name = "operator_pd_vdvt_update_test",
|
||||||
|
size = "large",
|
||||||
|
srcs = ["python/kernel_tests/operator_pd_vdvt_update_test.py"],
|
||||||
|
additional_deps = [
|
||||||
|
":distributions_py",
|
||||||
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"//tensorflow/python:platform_test",
|
||||||
|
],
|
||||||
|
tags = ["notap"], # http://b/30441813
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "distributions_py",
|
name = "distributions_py",
|
||||||
srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
|
srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
|
||||||
@ -76,7 +99,16 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/beta_test.py"],
|
srcs = ["python/kernel_tests/beta_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:platform_test",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cuda_py_tests(
|
||||||
|
name = "binomial_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["python/kernel_tests/binomial_test.py"],
|
||||||
|
additional_deps = [
|
||||||
|
":distributions_py",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
tags = ["notsan"],
|
tags = ["notsan"],
|
||||||
@ -156,9 +188,8 @@ cuda_py_tests(
|
|||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_tests(
|
cuda_py_tests(
|
||||||
name = "kullback_leibler_test",
|
name = "laplace_test",
|
||||||
size = "small",
|
srcs = ["python/kernel_tests/laplace_test.py"],
|
||||||
srcs = ["python/kernel_tests/kullback_leibler_test.py"],
|
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
@ -167,13 +198,14 @@ cuda_py_tests(
|
|||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_tests(
|
cuda_py_tests(
|
||||||
name = "laplace_test",
|
name = "multinomial_test",
|
||||||
srcs = ["python/kernel_tests/laplace_test.py"],
|
srcs = ["python/kernel_tests/multinomial_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
|
tags = ["notsan"],
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_tests(
|
cuda_py_tests(
|
||||||
@ -216,6 +248,15 @@ cuda_py_tests(
|
|||||||
srcs = ["python/kernel_tests/uniform_test.py"],
|
srcs = ["python/kernel_tests/uniform_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distributions_py",
|
":distributions_py",
|
||||||
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cuda_py_tests(
|
||||||
|
name = "kullback_leibler_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["python/kernel_tests/kullback_leibler_test.py"],
|
||||||
|
additional_deps = [
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -240,6 +281,28 @@ cuda_py_tests(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cuda_py_tests(
|
||||||
|
name = "shape_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["python/kernel_tests/shape_test.py"],
|
||||||
|
additional_deps = [
|
||||||
|
":distributions_py",
|
||||||
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"//tensorflow/python:platform_test",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cuda_py_tests(
|
||||||
|
name = "bijector_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["python/kernel_tests/bijector_test.py"],
|
||||||
|
additional_deps = [
|
||||||
|
":distributions_py",
|
||||||
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"//tensorflow/python:platform_test",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "all_files",
|
name = "all_files",
|
||||||
srcs = glob(
|
srcs = glob(
|
||||||
|
@ -25,6 +25,7 @@ initialized with parameters that define the distributions.
|
|||||||
|
|
||||||
### Univariate (scalar) distributions
|
### Univariate (scalar) distributions
|
||||||
|
|
||||||
|
@@Binomial
|
||||||
@@Bernoulli
|
@@Bernoulli
|
||||||
@@Beta
|
@@Beta
|
||||||
@@Categorical
|
@@Categorical
|
||||||
@ -50,6 +51,7 @@ initialized with parameters that define the distributions.
|
|||||||
|
|
||||||
@@Dirichlet
|
@@Dirichlet
|
||||||
@@DirichletMultinomial
|
@@DirichletMultinomial
|
||||||
|
@@Multinomial
|
||||||
|
|
||||||
### Transformed distributions
|
### Transformed distributions
|
||||||
|
|
||||||
@ -79,6 +81,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from tensorflow.contrib.distributions.python.ops.bernoulli import *
|
from tensorflow.contrib.distributions.python.ops.bernoulli import *
|
||||||
from tensorflow.contrib.distributions.python.ops.beta import *
|
from tensorflow.contrib.distributions.python.ops.beta import *
|
||||||
|
from tensorflow.contrib.distributions.python.ops.binomial import *
|
||||||
from tensorflow.contrib.distributions.python.ops.categorical import *
|
from tensorflow.contrib.distributions.python.ops.categorical import *
|
||||||
from tensorflow.contrib.distributions.python.ops.chi2 import *
|
from tensorflow.contrib.distributions.python.ops.chi2 import *
|
||||||
from tensorflow.contrib.distributions.python.ops.dirichlet import *
|
from tensorflow.contrib.distributions.python.ops.dirichlet import *
|
||||||
@ -89,6 +92,7 @@ from tensorflow.contrib.distributions.python.ops.gamma import *
|
|||||||
from tensorflow.contrib.distributions.python.ops.inverse_gamma import *
|
from tensorflow.contrib.distributions.python.ops.inverse_gamma import *
|
||||||
from tensorflow.contrib.distributions.python.ops.kullback_leibler import *
|
from tensorflow.contrib.distributions.python.ops.kullback_leibler import *
|
||||||
from tensorflow.contrib.distributions.python.ops.laplace import *
|
from tensorflow.contrib.distributions.python.ops.laplace import *
|
||||||
|
from tensorflow.contrib.distributions.python.ops.multinomial import *
|
||||||
from tensorflow.contrib.distributions.python.ops.mvn import *
|
from tensorflow.contrib.distributions.python.ops.mvn import *
|
||||||
from tensorflow.contrib.distributions.python.ops.normal import *
|
from tensorflow.contrib.distributions.python.ops.normal import *
|
||||||
from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import *
|
from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import *
|
||||||
|
@ -57,10 +57,17 @@ class BernoulliTest(tf.test.TestCase):
|
|||||||
self.assertAllClose(scipy.special.logit(p), dist.logits.eval())
|
self.assertAllClose(scipy.special.logit(p), dist.logits.eval())
|
||||||
|
|
||||||
def testInvalidP(self):
|
def testInvalidP(self):
|
||||||
invalid_ps = [1.01, -0.01, 2., -3.]
|
invalid_ps = [1.01, 2.]
|
||||||
for p in invalid_ps:
|
for p in invalid_ps:
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
with self.assertRaisesOpError("x <= y"):
|
with self.assertRaisesOpError("p has components greater than 1"):
|
||||||
|
dist = tf.contrib.distributions.Bernoulli(p=p)
|
||||||
|
dist.p.eval()
|
||||||
|
|
||||||
|
invalid_ps = [-0.01, -3.]
|
||||||
|
for p in invalid_ps:
|
||||||
|
with self.test_session():
|
||||||
|
with self.assertRaisesOpError("Condition x >= 0"):
|
||||||
dist = tf.contrib.distributions.Bernoulli(p=p)
|
dist = tf.contrib.distributions.Bernoulli(p=p)
|
||||||
dist.p.eval()
|
dist.p.eval()
|
||||||
|
|
||||||
|
@ -0,0 +1,67 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for Bijector."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow.contrib.distributions.python.ops.bijector import _Exp # pylint: disable=line-too-long
|
||||||
|
from tensorflow.contrib.distributions.python.ops.bijector import _Identity # pylint: disable=line-too-long
|
||||||
|
from tensorflow.contrib.distributions.python.ops.shape import _ShapeUtil # pylint: disable=line-too-long
|
||||||
|
|
||||||
|
|
||||||
|
class IdentityBijectorTest(tf.test.TestCase):
|
||||||
|
"""Tests the correctness of the Y = g(X) = X transformation."""
|
||||||
|
|
||||||
|
def testBijector(self):
|
||||||
|
with self.test_session():
|
||||||
|
bijector = _Identity(_ShapeUtil(batch_ndims=1, event_ndims=1))
|
||||||
|
self.assertEqual(bijector.name, 'Identity')
|
||||||
|
x = [[[0.], [1]]]
|
||||||
|
self.assertAllEqual(bijector.forward(x).eval(), x)
|
||||||
|
self.assertAllEqual(bijector.inverse(x).eval(), x)
|
||||||
|
self.assertAllEqual(bijector.inverse_log_det_jacobian(x).eval(),
|
||||||
|
[[0., 0]])
|
||||||
|
rev, jac = bijector.inverse_and_inverse_log_det_jacobian(x)
|
||||||
|
self.assertAllEqual(rev.eval(), x)
|
||||||
|
self.assertAllEqual(jac.eval(), [[0., 0]])
|
||||||
|
|
||||||
|
|
||||||
|
class ExpBijectorTest(tf.test.TestCase):
|
||||||
|
"""Tests the correctness of the Y = g(X) = exp(X) transformation."""
|
||||||
|
|
||||||
|
def testBijector(self):
|
||||||
|
with self.test_session():
|
||||||
|
bijector = _Exp(_ShapeUtil(batch_ndims=1, event_ndims=1))
|
||||||
|
self.assertEqual(bijector.name, 'Exp')
|
||||||
|
x = [[[1.], [2]]]
|
||||||
|
self.assertAllClose(bijector.forward(x).eval(),
|
||||||
|
[[[math.exp(1.)], [math.exp(2.)]]])
|
||||||
|
self.assertAllClose(bijector.inverse(x).eval(),
|
||||||
|
[[[math.log(1.)], [math.log(2.)]]])
|
||||||
|
self.assertAllClose(bijector.inverse_log_det_jacobian(x).eval(),
|
||||||
|
[[0., -math.log(2.)]])
|
||||||
|
rev, jac = bijector.inverse_and_inverse_log_det_jacobian(x)
|
||||||
|
self.assertAllClose(rev.eval(), [[[math.log(1.)], [math.log(2.)]]])
|
||||||
|
self.assertAllClose(jac.eval(), [[0., -math.log(2.)]])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
@ -0,0 +1,173 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from scipy import stats
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
class BinomialTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def testSimpleShapes(self):
|
||||||
|
with self.test_session():
|
||||||
|
p = np.float32(np.random.beta(1, 1))
|
||||||
|
binom = tf.contrib.distributions.Binomial(n=1., p=p)
|
||||||
|
self.assertAllEqual([], binom.event_shape().eval())
|
||||||
|
self.assertAllEqual([], binom.batch_shape().eval())
|
||||||
|
self.assertEqual(tf.TensorShape([]), binom.get_event_shape())
|
||||||
|
self.assertEqual(tf.TensorShape([]), binom.get_batch_shape())
|
||||||
|
|
||||||
|
def testComplexShapes(self):
|
||||||
|
with self.test_session():
|
||||||
|
p = np.random.beta(1, 1, size=(3, 2)).astype(np.float32)
|
||||||
|
n = [[3., 2], [4, 5], [6, 7]]
|
||||||
|
binom = tf.contrib.distributions.Binomial(n=n, p=p)
|
||||||
|
self.assertAllEqual([], binom.event_shape().eval())
|
||||||
|
self.assertAllEqual([3, 2], binom.batch_shape().eval())
|
||||||
|
self.assertEqual(tf.TensorShape([]), binom.get_event_shape())
|
||||||
|
self.assertEqual(tf.TensorShape([3, 2]), binom.get_batch_shape())
|
||||||
|
|
||||||
|
def testNProperty(self):
|
||||||
|
p = [[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]]
|
||||||
|
n = [[3.], [4]]
|
||||||
|
with self.test_session():
|
||||||
|
binom = tf.contrib.distributions.Binomial(n=n, p=p)
|
||||||
|
self.assertEqual((2, 1), binom.n.get_shape())
|
||||||
|
self.assertAllClose(n, binom.n.eval())
|
||||||
|
|
||||||
|
def testPProperty(self):
|
||||||
|
p = [[0.1, 0.2, 0.7]]
|
||||||
|
with self.test_session():
|
||||||
|
binom = tf.contrib.distributions.Binomial(n=3., p=p)
|
||||||
|
self.assertEqual((1, 3), binom.p.get_shape())
|
||||||
|
self.assertEqual((1, 3), binom.logits.get_shape())
|
||||||
|
self.assertAllClose(p, binom.p.eval())
|
||||||
|
|
||||||
|
def testLogitsProperty(self):
|
||||||
|
logits = [[0., 9., -0.5]]
|
||||||
|
with self.test_session():
|
||||||
|
binom = tf.contrib.distributions.Binomial(n=3., logits=logits)
|
||||||
|
self.assertEqual((1, 3), binom.p.get_shape())
|
||||||
|
self.assertEqual((1, 3), binom.logits.get_shape())
|
||||||
|
self.assertAllClose(logits, binom.logits.eval())
|
||||||
|
|
||||||
|
def testPmfNandCountsAgree(self):
|
||||||
|
p = [[0.1, 0.2, 0.7]]
|
||||||
|
n = [[5.]]
|
||||||
|
with self.test_session():
|
||||||
|
binom = tf.contrib.distributions.Binomial(n=n, p=p)
|
||||||
|
binom.pmf([2., 3, 2]).eval()
|
||||||
|
binom.pmf([3., 1, 2]).eval()
|
||||||
|
with self.assertRaisesOpError('Condition x >= 0.*'):
|
||||||
|
binom.pmf([-1., 4, 2]).eval()
|
||||||
|
with self.assertRaisesOpError('Condition x <= y.*'):
|
||||||
|
binom.pmf([7., 3, 0]).eval()
|
||||||
|
|
||||||
|
def testPmf_non_integer_counts(self):
|
||||||
|
p = [[0.1, 0.2, 0.7]]
|
||||||
|
n = [[5.]]
|
||||||
|
with self.test_session():
|
||||||
|
# No errors with integer n.
|
||||||
|
binom = tf.contrib.distributions.Binomial(n=n, p=p)
|
||||||
|
binom.pmf([2., 3, 2]).eval()
|
||||||
|
binom.pmf([3., 1, 2]).eval()
|
||||||
|
# Both equality and integer checking fail.
|
||||||
|
with self.assertRaisesOpError('Condition x == y.*'):
|
||||||
|
binom.pmf([1.0, 2.5, 1.5]).eval()
|
||||||
|
|
||||||
|
binom = tf.contrib.distributions.Binomial(n=n, p=p, validate_args=False)
|
||||||
|
binom.pmf([1., 2., 3.]).eval()
|
||||||
|
# Non-integer arguments work.
|
||||||
|
binom.pmf([1.0, 2.5, 1.5]).eval()
|
||||||
|
|
||||||
|
def testPmfBothZeroBatches(self):
|
||||||
|
with self.test_session():
|
||||||
|
# Both zero-batches. No broadcast
|
||||||
|
p = 0.5
|
||||||
|
counts = 1.
|
||||||
|
pmf = tf.contrib.distributions.Binomial(n=1., p=p).pmf(counts)
|
||||||
|
self.assertAllClose(0.5, pmf.eval())
|
||||||
|
self.assertEqual((), pmf.get_shape())
|
||||||
|
|
||||||
|
def testPmfBothZeroBatchesNontrivialN(self):
|
||||||
|
with self.test_session():
|
||||||
|
# Both zero-batches. No broadcast
|
||||||
|
p = 0.1
|
||||||
|
counts = 3.
|
||||||
|
binom = tf.contrib.distributions.Binomial(n=5., p=p)
|
||||||
|
pmf = binom.pmf(counts)
|
||||||
|
self.assertAllClose(stats.binom.pmf(counts, n=5., p=p), pmf.eval())
|
||||||
|
self.assertEqual((), pmf.get_shape())
|
||||||
|
|
||||||
|
def testPmfPStretchedInBroadcastWhenSameRank(self):
|
||||||
|
with self.test_session():
|
||||||
|
p = [[0.1, 0.9]]
|
||||||
|
counts = [[1., 2.]]
|
||||||
|
pmf = tf.contrib.distributions.Binomial(n=3., p=p).pmf(counts)
|
||||||
|
self.assertAllClose(stats.binom.pmf(counts, n=3., p=p), pmf.eval())
|
||||||
|
self.assertEqual((1, 2), pmf.get_shape())
|
||||||
|
|
||||||
|
def testPmfPStretchedInBroadcastWhenLowerRank(self):
|
||||||
|
with self.test_session():
|
||||||
|
p = [0.1, 0.4]
|
||||||
|
counts = [[1.], [0.]]
|
||||||
|
pmf = tf.contrib.distributions.Binomial(n=1., p=p).pmf(counts)
|
||||||
|
self.assertAllClose([[0.1, 0.4], [0.9, 0.6]], pmf.eval())
|
||||||
|
self.assertEqual((2, 2), pmf.get_shape())
|
||||||
|
|
||||||
|
def testBinomialMean(self):
|
||||||
|
with self.test_session():
|
||||||
|
n = 5.
|
||||||
|
p = [0.1, 0.2, 0.7]
|
||||||
|
binom = tf.contrib.distributions.Binomial(n=n, p=p)
|
||||||
|
expected_means = stats.binom.mean(n, p)
|
||||||
|
self.assertEqual((3,), binom.mean().get_shape())
|
||||||
|
self.assertAllClose(expected_means, binom.mean().eval())
|
||||||
|
|
||||||
|
def testBinomialVariance(self):
|
||||||
|
with self.test_session():
|
||||||
|
n = 5.
|
||||||
|
p = [0.1, 0.2, 0.7]
|
||||||
|
binom = tf.contrib.distributions.Binomial(n=n, p=p)
|
||||||
|
expected_variances = stats.binom.var(n, p)
|
||||||
|
self.assertEqual((3,), binom.variance().get_shape())
|
||||||
|
self.assertAllClose(expected_variances, binom.variance().eval())
|
||||||
|
|
||||||
|
def testBinomialMode(self):
|
||||||
|
with self.test_session():
|
||||||
|
n = 5.
|
||||||
|
p = [0.1, 0.2, 0.7]
|
||||||
|
binom = tf.contrib.distributions.Binomial(n=n, p=p)
|
||||||
|
expected_modes = [0., 1, 4]
|
||||||
|
self.assertEqual((3,), binom.mode().get_shape())
|
||||||
|
self.assertAllClose(expected_modes, binom.mode().eval())
|
||||||
|
|
||||||
|
def testBinomialMultipleMode(self):
|
||||||
|
with self.test_session():
|
||||||
|
n = 9.
|
||||||
|
p = [0.1, 0.2, 0.7]
|
||||||
|
binom = tf.contrib.distributions.Binomial(n=n, p=p)
|
||||||
|
# For the case where (n + 1) * p is an integer, the modes are:
|
||||||
|
# (n + 1) * p and (n + 1) * p - 1. In this case, we get back
|
||||||
|
# the larger of the two modes.
|
||||||
|
expected_modes = [1., 2, 7]
|
||||||
|
self.assertEqual((3,), binom.mode().get_shape())
|
||||||
|
self.assertAllClose(expected_modes, binom.mode().eval())
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
@ -61,14 +61,14 @@ class DirichletMultinomialTest(tf.test.TestCase):
|
|||||||
n = [[5.]]
|
n = [[5.]]
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
dist = tf.contrib.distributions.DirichletMultinomial(n, alpha)
|
dist = tf.contrib.distributions.DirichletMultinomial(n, alpha)
|
||||||
dist.pmf([2, 3, 0]).eval()
|
dist.pmf([2., 3, 0]).eval()
|
||||||
dist.pmf([3, 0, 2]).eval()
|
dist.pmf([3., 0, 2]).eval()
|
||||||
with self.assertRaisesOpError('Condition x >= 0.*'):
|
with self.assertRaisesOpError('Condition x >= 0.*'):
|
||||||
dist.pmf([-1, 4, 2]).eval()
|
dist.pmf([-1., 4, 2]).eval()
|
||||||
with self.assertRaisesOpError('Condition x == y.*'):
|
with self.assertRaisesOpError('counts do not sum to n'):
|
||||||
dist.pmf([3, 3, 0]).eval()
|
dist.pmf([3., 3, 0]).eval()
|
||||||
|
|
||||||
def testPmfArbitraryCounts(self):
|
def testPmf_non_integer_counts(self):
|
||||||
alpha = [[1., 2, 3]]
|
alpha = [[1., 2, 3]]
|
||||||
n = [[5.]]
|
n = [[5.]]
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
@ -80,8 +80,10 @@ class DirichletMultinomialTest(tf.test.TestCase):
|
|||||||
with self.assertRaisesOpError('Condition x == y.*'):
|
with self.assertRaisesOpError('Condition x == y.*'):
|
||||||
dist.pmf([1.0, 2.5, 1.5]).eval()
|
dist.pmf([1.0, 2.5, 1.5]).eval()
|
||||||
dist = tf.contrib.distributions.DirichletMultinomial(
|
dist = tf.contrib.distributions.DirichletMultinomial(
|
||||||
n, alpha, allow_arbitrary_counts=True)
|
n, alpha, validate_args=False)
|
||||||
dist.pmf(np.array([1.0, 2.5, 1.5])).eval()
|
dist.pmf([1., 2., 3.]).eval()
|
||||||
|
# Non-integer arguments work.
|
||||||
|
dist.pmf([1.0, 2.5, 1.5]).eval()
|
||||||
|
|
||||||
def testPmfBothZeroBatches(self):
|
def testPmfBothZeroBatches(self):
|
||||||
# The probabilities of one vote falling into class k is the mean for class
|
# The probabilities of one vote falling into class k is the mean for class
|
||||||
@ -90,7 +92,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
|
|||||||
# Both zero-batches. No broadcast
|
# Both zero-batches. No broadcast
|
||||||
alpha = [1., 2]
|
alpha = [1., 2]
|
||||||
counts = [1., 0]
|
counts = [1., 0]
|
||||||
dist = tf.contrib.distributions.DirichletMultinomial(1, alpha)
|
dist = tf.contrib.distributions.DirichletMultinomial(1., alpha)
|
||||||
pmf = dist.pmf(counts)
|
pmf = dist.pmf(counts)
|
||||||
self.assertAllClose(1 / 3., pmf.eval())
|
self.assertAllClose(1 / 3., pmf.eval())
|
||||||
self.assertEqual((), pmf.get_shape())
|
self.assertEqual((), pmf.get_shape())
|
||||||
@ -102,7 +104,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
|
|||||||
# Both zero-batches. No broadcast
|
# Both zero-batches. No broadcast
|
||||||
alpha = [1., 2]
|
alpha = [1., 2]
|
||||||
counts = [3., 2]
|
counts = [3., 2]
|
||||||
dist = tf.contrib.distributions.DirichletMultinomial(5, alpha)
|
dist = tf.contrib.distributions.DirichletMultinomial(5., alpha)
|
||||||
pmf = dist.pmf(counts)
|
pmf = dist.pmf(counts)
|
||||||
self.assertAllClose(1 / 7., pmf.eval())
|
self.assertAllClose(1 / 7., pmf.eval())
|
||||||
self.assertEqual((), pmf.get_shape())
|
self.assertEqual((), pmf.get_shape())
|
||||||
@ -113,7 +115,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
|
|||||||
with self.test_session():
|
with self.test_session():
|
||||||
alpha = [1., 2]
|
alpha = [1., 2]
|
||||||
counts = [3., 2]
|
counts = [3., 2]
|
||||||
n = np.full([4, 3], 5.)
|
n = np.full([4, 3], 5., dtype=np.float32)
|
||||||
dist = tf.contrib.distributions.DirichletMultinomial(n, alpha)
|
dist = tf.contrib.distributions.DirichletMultinomial(n, alpha)
|
||||||
pmf = dist.pmf(counts)
|
pmf = dist.pmf(counts)
|
||||||
self.assertAllClose([[1 / 7., 1 / 7., 1 / 7.]] * 4, pmf.eval())
|
self.assertAllClose([[1 / 7., 1 / 7., 1 / 7.]] * 4, pmf.eval())
|
||||||
@ -125,7 +127,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
|
|||||||
with self.test_session():
|
with self.test_session():
|
||||||
alpha = [[1., 2]]
|
alpha = [[1., 2]]
|
||||||
counts = [[1., 0], [0., 1]]
|
counts = [[1., 0], [0., 1]]
|
||||||
dist = tf.contrib.distributions.DirichletMultinomial([1], alpha)
|
dist = tf.contrib.distributions.DirichletMultinomial([1.], alpha)
|
||||||
pmf = dist.pmf(counts)
|
pmf = dist.pmf(counts)
|
||||||
self.assertAllClose([1 / 3., 2 / 3.], pmf.eval())
|
self.assertAllClose([1 / 3., 2 / 3.], pmf.eval())
|
||||||
self.assertEqual((2), pmf.get_shape())
|
self.assertEqual((2), pmf.get_shape())
|
||||||
@ -231,12 +233,12 @@ class DirichletMultinomialTest(tf.test.TestCase):
|
|||||||
|
|
||||||
def testVariance_n_alpha_broadcast(self):
|
def testVariance_n_alpha_broadcast(self):
|
||||||
alpha_v = [1., 2, 3]
|
alpha_v = [1., 2, 3]
|
||||||
alpha_0 = np.sum(alpha_v)
|
alpha_0 = 6.
|
||||||
|
|
||||||
# Shape [4, 3]
|
# Shape [4, 3]
|
||||||
alpha = np.array(4 * [alpha_v])
|
alpha = np.array(4 * [alpha_v], dtype=np.float32)
|
||||||
# Shape [4, 1]
|
# Shape [4, 1]
|
||||||
ns = np.array([[2.], [3.], [4.], [5.]])
|
ns = np.array([[2.], [3.], [4.], [5.]], dtype=np.float32)
|
||||||
|
|
||||||
variance_entry = lambda a, a_sum: a / a_sum * (1 - a / a_sum)
|
variance_entry = lambda a, a_sum: a / a_sum * (1 - a / a_sum)
|
||||||
covariance_entry = lambda a, b, a_sum: -a * b/ a_sum**2
|
covariance_entry = lambda a, b, a_sum: -a * b/ a_sum**2
|
||||||
@ -250,7 +252,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
|
|||||||
covariance_entry(alpha_v[1], alpha_v[2], alpha_0)],
|
covariance_entry(alpha_v[1], alpha_v[2], alpha_0)],
|
||||||
[covariance_entry(alpha_v[2], alpha_v[0], alpha_0),
|
[covariance_entry(alpha_v[2], alpha_v[0], alpha_0),
|
||||||
covariance_entry(alpha_v[2], alpha_v[1], alpha_0),
|
covariance_entry(alpha_v[2], alpha_v[1], alpha_0),
|
||||||
variance_entry(alpha_v[2], alpha_0)]]])
|
variance_entry(alpha_v[2], alpha_0)]]], dtype=np.float32)
|
||||||
|
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
# ns is shape [4, 1], and alpha is shape [4, 3].
|
# ns is shape [4, 1], and alpha is shape [4, 3].
|
||||||
@ -263,11 +265,11 @@ class DirichletMultinomialTest(tf.test.TestCase):
|
|||||||
self.assertAllClose(expected_variance, variance.eval())
|
self.assertAllClose(expected_variance, variance.eval())
|
||||||
|
|
||||||
def testVariance_multidimensional(self):
|
def testVariance_multidimensional(self):
|
||||||
alpha = np.random.rand(3, 5, 4)
|
alpha = np.random.rand(3, 5, 4).astype(np.float32)
|
||||||
alpha2 = np.random.rand(6, 3, 3)
|
alpha2 = np.random.rand(6, 3, 3).astype(np.float32)
|
||||||
# Ensure n > 0.
|
|
||||||
ns = np.random.geometric(p=0.8, size=[3, 5, 1]) + 1
|
ns = np.random.randint(low=1, high=11, size=[3, 5, 1]).astype(np.float32)
|
||||||
ns2 = np.random.geometric(p=0.8, size=[6, 1, 1]) + 1
|
ns2 = np.random.randint(low=1, high=11, size=[6, 1, 1]).astype(np.float32)
|
||||||
|
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
dist = tf.contrib.distributions.DirichletMultinomial(ns, alpha)
|
dist = tf.contrib.distributions.DirichletMultinomial(ns, alpha)
|
||||||
@ -297,7 +299,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
|
|||||||
|
|
||||||
# One (three sided) coin flip. Prob[coin 3] = 0.8.
|
# One (three sided) coin flip. Prob[coin 3] = 0.8.
|
||||||
# Note that since it was one flip, value of tau didn't matter.
|
# Note that since it was one flip, value of tau didn't matter.
|
||||||
counts = [0, 0, 1]
|
counts = [0., 0, 1]
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
dist = tf.contrib.distributions.DirichletMultinomial(1., alpha)
|
dist = tf.contrib.distributions.DirichletMultinomial(1., alpha)
|
||||||
pmf = dist.pmf(counts)
|
pmf = dist.pmf(counts)
|
||||||
@ -305,9 +307,9 @@ class DirichletMultinomialTest(tf.test.TestCase):
|
|||||||
self.assertEqual((), pmf.get_shape())
|
self.assertEqual((), pmf.get_shape())
|
||||||
|
|
||||||
# Two (three sided) coin flips. Prob[coin 3] = 0.8.
|
# Two (three sided) coin flips. Prob[coin 3] = 0.8.
|
||||||
counts = [0, 0, 2]
|
counts = [0., 0, 2]
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
dist = tf.contrib.distributions.DirichletMultinomial(2, alpha)
|
dist = tf.contrib.distributions.DirichletMultinomial(2., alpha)
|
||||||
pmf = dist.pmf(counts)
|
pmf = dist.pmf(counts)
|
||||||
self.assertAllClose(0.8**2, pmf.eval(), atol=1e-2)
|
self.assertAllClose(0.8**2, pmf.eval(), atol=1e-2)
|
||||||
self.assertEqual((), pmf.get_shape())
|
self.assertEqual((), pmf.get_shape())
|
||||||
@ -315,7 +317,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
|
|||||||
# Three (three sided) coin flips.
|
# Three (three sided) coin flips.
|
||||||
counts = [1., 0, 2]
|
counts = [1., 0, 2]
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
dist = tf.contrib.distributions.DirichletMultinomial(3, alpha)
|
dist = tf.contrib.distributions.DirichletMultinomial(3., alpha)
|
||||||
pmf = dist.pmf(counts)
|
pmf = dist.pmf(counts)
|
||||||
self.assertAllClose(3 * 0.1 * 0.8 * 0.8, pmf.eval(), atol=1e-2)
|
self.assertAllClose(3 * 0.1 * 0.8 * 0.8, pmf.eval(), atol=1e-2)
|
||||||
self.assertEqual((), pmf.get_shape())
|
self.assertEqual((), pmf.get_shape())
|
||||||
@ -336,10 +338,10 @@ class DirichletMultinomialTest(tf.test.TestCase):
|
|||||||
self.assertEqual((), pmf.get_shape())
|
self.assertEqual((), pmf.get_shape())
|
||||||
|
|
||||||
# If there are two draws, it is much more likely that they are the same.
|
# If there are two draws, it is much more likely that they are the same.
|
||||||
counts_same = [2, 0]
|
counts_same = [2., 0]
|
||||||
counts_different = [1, 1.]
|
counts_different = [1, 1.]
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
dist = tf.contrib.distributions.DirichletMultinomial(2, alpha)
|
dist = tf.contrib.distributions.DirichletMultinomial(2., alpha)
|
||||||
pmf_same = dist.pmf(counts_same)
|
pmf_same = dist.pmf(counts_same)
|
||||||
pmf_different = dist.pmf(counts_different)
|
pmf_different = dist.pmf(counts_different)
|
||||||
self.assertLess(5 * pmf_different.eval(), pmf_same.eval())
|
self.assertLess(5 * pmf_different.eval(), pmf_same.eval())
|
||||||
|
@ -0,0 +1,226 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
class MultinomialTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def testSimpleShapes(self):
|
||||||
|
with self.test_session():
|
||||||
|
p = [.1, .3, .6]
|
||||||
|
dist = tf.contrib.distributions.Multinomial(n=1., p=p)
|
||||||
|
self.assertEqual(3, dist.event_shape().eval())
|
||||||
|
self.assertAllEqual([], dist.batch_shape().eval())
|
||||||
|
self.assertEqual(tf.TensorShape([3]), dist.get_event_shape())
|
||||||
|
self.assertEqual(tf.TensorShape([]), dist.get_batch_shape())
|
||||||
|
|
||||||
|
def testComplexShapes(self):
|
||||||
|
with self.test_session():
|
||||||
|
p = 0.5 * np.ones([3, 2, 2], dtype=np.float32)
|
||||||
|
n = [[3., 2], [4, 5], [6, 7]]
|
||||||
|
dist = tf.contrib.distributions.Multinomial(n=n, p=p)
|
||||||
|
self.assertEqual(2, dist.event_shape().eval())
|
||||||
|
self.assertAllEqual([3, 2], dist.batch_shape().eval())
|
||||||
|
self.assertEqual(tf.TensorShape([2]), dist.get_event_shape())
|
||||||
|
self.assertEqual(tf.TensorShape([3, 2]), dist.get_batch_shape())
|
||||||
|
|
||||||
|
def testNProperty(self):
|
||||||
|
p = [[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]]
|
||||||
|
n = [[3.], [4]]
|
||||||
|
with self.test_session():
|
||||||
|
dist = tf.contrib.distributions.Multinomial(n=n, p=p)
|
||||||
|
self.assertEqual((2, 1), dist.n.get_shape())
|
||||||
|
self.assertAllClose(n, dist.n.eval())
|
||||||
|
|
||||||
|
def testPProperty(self):
|
||||||
|
p = [[0.1, 0.2, 0.7]]
|
||||||
|
with self.test_session():
|
||||||
|
dist = tf.contrib.distributions.Multinomial(n=3., p=p)
|
||||||
|
self.assertEqual((1, 3), dist.p.get_shape())
|
||||||
|
self.assertEqual((1, 3), dist.logits.get_shape())
|
||||||
|
self.assertAllClose(p, dist.p.eval())
|
||||||
|
|
||||||
|
def testLogitsProperty(self):
|
||||||
|
logits = [[0., 9., -0.5]]
|
||||||
|
with self.test_session():
|
||||||
|
multinom = tf.contrib.distributions.Multinomial(n=3., logits=logits)
|
||||||
|
self.assertEqual((1, 3), multinom.p.get_shape())
|
||||||
|
self.assertEqual((1, 3), multinom.logits.get_shape())
|
||||||
|
self.assertAllClose(logits, multinom.logits.eval())
|
||||||
|
|
||||||
|
def testPmfNandCountsAgree(self):
|
||||||
|
p = [[0.1, 0.2, 0.7]]
|
||||||
|
n = [[5.]]
|
||||||
|
with self.test_session():
|
||||||
|
dist = tf.contrib.distributions.Multinomial(n=n, p=p)
|
||||||
|
dist.pmf([2., 3, 0]).eval()
|
||||||
|
dist.pmf([3., 0, 2]).eval()
|
||||||
|
with self.assertRaisesOpError('Condition x >= 0.*'):
|
||||||
|
dist.pmf([-1., 4, 2]).eval()
|
||||||
|
with self.assertRaisesOpError('counts do not sum to n'):
|
||||||
|
dist.pmf([3., 3, 0]).eval()
|
||||||
|
|
||||||
|
def testPmf_non_integer_counts(self):
|
||||||
|
p = [[0.1, 0.2, 0.7]]
|
||||||
|
n = [[5.]]
|
||||||
|
with self.test_session():
|
||||||
|
# No errors with integer n.
|
||||||
|
multinom = tf.contrib.distributions.Multinomial(n=n, p=p)
|
||||||
|
multinom.pmf([2., 1, 2]).eval()
|
||||||
|
multinom.pmf([3., 0, 2]).eval()
|
||||||
|
# Counts don't sum to n.
|
||||||
|
with self.assertRaisesOpError('counts do not sum to n'):
|
||||||
|
multinom.pmf([2., 3, 2]).eval()
|
||||||
|
# Counts are non-integers.
|
||||||
|
with self.assertRaisesOpError('Condition x == y.*'):
|
||||||
|
multinom.pmf([1.0, 2.5, 1.5]).eval()
|
||||||
|
|
||||||
|
multinom = tf.contrib.distributions.Multinomial(
|
||||||
|
n=n, p=p, validate_args=False)
|
||||||
|
multinom.pmf([1., 2., 2.]).eval()
|
||||||
|
# Non-integer arguments work.
|
||||||
|
multinom.pmf([1.0, 2.5, 1.5]).eval()
|
||||||
|
|
||||||
|
def testPmfBothZeroBatches(self):
|
||||||
|
with self.test_session():
|
||||||
|
# Both zero-batches. No broadcast
|
||||||
|
p = [0.5, 0.5]
|
||||||
|
counts = [1., 0]
|
||||||
|
pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
|
||||||
|
self.assertAllClose(0.5, pmf.eval())
|
||||||
|
self.assertEqual((), pmf.get_shape())
|
||||||
|
|
||||||
|
def testPmfBothZeroBatchesNontrivialN(self):
|
||||||
|
with self.test_session():
|
||||||
|
# Both zero-batches. No broadcast
|
||||||
|
p = [0.1, 0.9]
|
||||||
|
counts = [3., 2]
|
||||||
|
dist = tf.contrib.distributions.Multinomial(n=5., p=p)
|
||||||
|
pmf = dist.pmf(counts)
|
||||||
|
# 5 choose 3 = 5 choose 2 = 10. 10 * (.9)^2 * (.1)^3 = 81/10000.
|
||||||
|
self.assertAllClose(81./10000, pmf.eval())
|
||||||
|
self.assertEqual((), pmf.get_shape())
|
||||||
|
|
||||||
|
def testPmfPStretchedInBroadcastWhenSameRank(self):
|
||||||
|
with self.test_session():
|
||||||
|
p = [[0.1, 0.9]]
|
||||||
|
counts = [[1., 0], [0, 1]]
|
||||||
|
pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
|
||||||
|
self.assertAllClose([0.1, 0.9], pmf.eval())
|
||||||
|
self.assertEqual((2), pmf.get_shape())
|
||||||
|
|
||||||
|
def testPmfPStretchedInBroadcastWhenLowerRank(self):
|
||||||
|
with self.test_session():
|
||||||
|
p = [0.1, 0.9]
|
||||||
|
counts = [[1., 0], [0, 1]]
|
||||||
|
pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
|
||||||
|
self.assertAllClose([0.1, 0.9], pmf.eval())
|
||||||
|
self.assertEqual((2), pmf.get_shape())
|
||||||
|
|
||||||
|
def testPmfCountsStretchedInBroadcastWhenSameRank(self):
|
||||||
|
with self.test_session():
|
||||||
|
p = [[0.1, 0.9], [0.7, 0.3]]
|
||||||
|
counts = [[1., 0]]
|
||||||
|
pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
|
||||||
|
self.assertAllClose(pmf.eval(), [0.1, 0.7])
|
||||||
|
self.assertEqual((2), pmf.get_shape())
|
||||||
|
|
||||||
|
def testPmfCountsStretchedInBroadcastWhenLowerRank(self):
|
||||||
|
with self.test_session():
|
||||||
|
p = [[0.1, 0.9], [0.7, 0.3]]
|
||||||
|
counts = [1., 0]
|
||||||
|
pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
|
||||||
|
self.assertAllClose(pmf.eval(), [0.1, 0.7])
|
||||||
|
self.assertEqual(pmf.get_shape(), (2))
|
||||||
|
|
||||||
|
def testPmfShapeCountsStretched_N(self):
|
||||||
|
with self.test_session():
|
||||||
|
# [2, 2, 2]
|
||||||
|
p = [[[0.1, 0.9], [0.1, 0.9]], [[0.7, 0.3], [0.7, 0.3]]]
|
||||||
|
# [2, 2]
|
||||||
|
n = [[3., 3], [3, 3]]
|
||||||
|
# [2]
|
||||||
|
counts = [2., 1]
|
||||||
|
pmf = tf.contrib.distributions.Multinomial(n=n, p=p).pmf(counts)
|
||||||
|
pmf.eval()
|
||||||
|
self.assertEqual(pmf.get_shape(), (2, 2))
|
||||||
|
|
||||||
|
def testPmfShapeCountsPStretched_N(self):
|
||||||
|
with self.test_session():
|
||||||
|
p = [0.1, 0.9]
|
||||||
|
counts = [3., 2]
|
||||||
|
n = np.full([4, 3], 5., dtype=np.float32)
|
||||||
|
pmf = tf.contrib.distributions.Multinomial(n=n, p=p).pmf(counts)
|
||||||
|
pmf.eval()
|
||||||
|
self.assertEqual((4, 3), pmf.get_shape())
|
||||||
|
|
||||||
|
def testMultinomialMean(self):
|
||||||
|
with self.test_session():
|
||||||
|
n = 5.
|
||||||
|
p = [0.1, 0.2, 0.7]
|
||||||
|
dist = tf.contrib.distributions.Multinomial(n=n, p=p)
|
||||||
|
expected_means = 5 * np.array(p, dtype=np.float32)
|
||||||
|
self.assertEqual((3,), dist.mean().get_shape())
|
||||||
|
self.assertAllClose(expected_means, dist.mean().eval())
|
||||||
|
|
||||||
|
def testMultinomialVariance(self):
|
||||||
|
with self.test_session():
|
||||||
|
n = 5.
|
||||||
|
p = [0.1, 0.2, 0.7]
|
||||||
|
dist = tf.contrib.distributions.Multinomial(n=n, p=p)
|
||||||
|
expected_variances = [
|
||||||
|
[9./20, -1/10, -7/20], [-1/10, 4/5, -7/10], [-7/20, -7/10, 21/20]]
|
||||||
|
self.assertEqual((3, 3), dist.variance().get_shape())
|
||||||
|
self.assertAllClose(expected_variances, dist.variance().eval())
|
||||||
|
|
||||||
|
def testMultinomialVariance_batch(self):
|
||||||
|
with self.test_session():
|
||||||
|
# Shape [2]
|
||||||
|
n = [5.] * 2
|
||||||
|
# Shape [4, 1, 2]
|
||||||
|
p = [[[0.1, 0.9]], [[0.1, 0.9]]] * 2
|
||||||
|
dist = tf.contrib.distributions.Multinomial(n=n, p=p)
|
||||||
|
# Shape [2, 2]
|
||||||
|
inner_var = [[9./20, -9/20], [-9/20, 9/20]]
|
||||||
|
# Shape [4, 2, 2, 2]
|
||||||
|
expected_variances = [[inner_var, inner_var]] * 4
|
||||||
|
self.assertEqual((4, 2, 2, 2), dist.variance().get_shape())
|
||||||
|
self.assertAllClose(expected_variances, dist.variance().eval())
|
||||||
|
|
||||||
|
def testVariance_multidimensional(self):
|
||||||
|
# Shape [3, 5, 4]
|
||||||
|
p = np.random.dirichlet([.25, .25, .25, .25], [3, 5]).astype(np.float32)
|
||||||
|
# Shape [6, 3, 3]
|
||||||
|
p2 = np.random.dirichlet([.3, .3, .4], [6, 3]).astype(np.float32)
|
||||||
|
|
||||||
|
ns = np.random.randint(low=1, high=11, size=[3, 5]).astype(np.float32)
|
||||||
|
ns2 = np.random.randint(low=1, high=11, size=[6, 1]).astype(np.float32)
|
||||||
|
|
||||||
|
with self.test_session():
|
||||||
|
dist = tf.contrib.distributions.Multinomial(ns, p)
|
||||||
|
dist2 = tf.contrib.distributions.Multinomial(ns2, p2)
|
||||||
|
|
||||||
|
variance = dist.variance()
|
||||||
|
variance2 = dist2.variance()
|
||||||
|
self.assertEqual((3, 5, 4, 4), variance.get_shape())
|
||||||
|
self.assertEqual((6, 3, 3, 3), variance2.get_shape())
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
@ -117,6 +117,61 @@ class MultivariateNormalDiagTest(tf.test.TestCase):
|
|||||||
self.assertAllClose(cov_mat, np.cov(samps.T), atol=0.1)
|
self.assertAllClose(cov_mat, np.cov(samps.T), atol=0.1)
|
||||||
|
|
||||||
|
|
||||||
|
class MultivariateNormalDiagPlusVDVTTest(tf.test.TestCase):
|
||||||
|
"""Well tested because this is a simple override of the base class."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self._rng = np.random.RandomState(42)
|
||||||
|
|
||||||
|
def testMean(self):
|
||||||
|
mu = [-1.0, 1.0]
|
||||||
|
diag_large = [1.0, 5.0]
|
||||||
|
v = [[2.0], [3.0]]
|
||||||
|
diag_small = [3.0]
|
||||||
|
with self.test_session():
|
||||||
|
dist = distributions.MultivariateNormalDiagPlusVDVT(
|
||||||
|
mu, diag_large, v, diag_small=diag_small)
|
||||||
|
self.assertAllEqual(mu, dist.mean().eval())
|
||||||
|
|
||||||
|
def testNonmatchingMuAndSigmaDimensionFailsStatic(self):
|
||||||
|
mu = self._rng.rand(2)
|
||||||
|
# With this diag_large and v, the covariance is 3 x 3
|
||||||
|
diag_large = self._rng.rand(3)
|
||||||
|
v = self._rng.rand(3, 2) # v works with diag_large.
|
||||||
|
with self.test_session():
|
||||||
|
with self.assertRaisesRegexp(ValueError, "shape.*should match"):
|
||||||
|
distributions.MultivariateNormalDiagPlusVDVT(
|
||||||
|
mu, diag_large, v)
|
||||||
|
|
||||||
|
def testNonmatchingMuDiagDimensionsFailsDynamic(self):
|
||||||
|
mu = self._rng.rand(2)
|
||||||
|
# With this diag_large and v, the covariance is 3 x 3
|
||||||
|
diag_large = self._rng.rand(3)
|
||||||
|
v = self._rng.rand(3, 2) # v works with diag_large.
|
||||||
|
|
||||||
|
with self.test_session():
|
||||||
|
mu_ph = tf.placeholder(tf.float32, name="mu_ph")
|
||||||
|
v_ph = tf.placeholder(tf.float32, name="v_ph")
|
||||||
|
diag_ph = tf.placeholder(tf.float32, name="diag_ph")
|
||||||
|
dist = distributions.MultivariateNormalDiagPlusVDVT(
|
||||||
|
mu_ph, diag_ph, v_ph)
|
||||||
|
with self.assertRaisesOpError("mu.*cov.*shape"):
|
||||||
|
dist.mean().eval(feed_dict={mu_ph: mu, diag_ph: diag_large, v_ph: v})
|
||||||
|
|
||||||
|
def testSample(self):
|
||||||
|
mu = [-1.0, 1.0]
|
||||||
|
diag_large = [1.0, 0.5]
|
||||||
|
v = [[0.2], [0.3]]
|
||||||
|
with self.test_session():
|
||||||
|
dist = distributions.MultivariateNormalDiagPlusVDVT(mu, diag_large, v)
|
||||||
|
|
||||||
|
samps = dist.sample_n(1000, seed=0).eval()
|
||||||
|
cov_mat = dist.sigma.eval()
|
||||||
|
|
||||||
|
self.assertAllClose(mu, samps.mean(axis=0), atol=0.1)
|
||||||
|
self.assertAllClose(cov_mat, np.cov(samps.T), atol=0.1)
|
||||||
|
|
||||||
|
|
||||||
class MultivariateNormalCholeskyTest(tf.test.TestCase):
|
class MultivariateNormalCholeskyTest(tf.test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -314,5 +369,87 @@ class MultivariateNormalCholeskyTest(tf.test.TestCase):
|
|||||||
self.assertEqual((3, 5), tuple(mvn.batch_shape().eval()))
|
self.assertEqual((3, 5), tuple(mvn.batch_shape().eval()))
|
||||||
|
|
||||||
|
|
||||||
|
class MultivariateNormalFullTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self._rng = np.random.RandomState(42)
|
||||||
|
|
||||||
|
def _random_mu_and_sigma(self, batch_shape, event_shape):
|
||||||
|
# This ensures sigma is positive def.
|
||||||
|
mat_shape = batch_shape + event_shape + event_shape
|
||||||
|
mat = self._rng.randn(*mat_shape)
|
||||||
|
sigma = tf.batch_matmul(mat, mat, adj_y=True).eval()
|
||||||
|
|
||||||
|
mu_shape = batch_shape + event_shape
|
||||||
|
mu = self._rng.randn(*mu_shape)
|
||||||
|
|
||||||
|
return mu, sigma
|
||||||
|
|
||||||
|
def testKLNonBatch(self):
|
||||||
|
batch_shape = ()
|
||||||
|
event_shape = (2,)
|
||||||
|
with self.test_session():
|
||||||
|
mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
|
||||||
|
mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
|
||||||
|
mvn_a = distributions.MultivariateNormalFull(mu_a, sigma_a)
|
||||||
|
mvn_b = distributions.MultivariateNormalFull(mu_b, sigma_b)
|
||||||
|
|
||||||
|
kl = distributions.kl(mvn_a, mvn_b)
|
||||||
|
self.assertEqual(batch_shape, kl.get_shape())
|
||||||
|
|
||||||
|
kl_v = kl.eval()
|
||||||
|
expected_kl = _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b)
|
||||||
|
self.assertAllClose(expected_kl, kl_v)
|
||||||
|
|
||||||
|
def testKLBatch(self):
|
||||||
|
batch_shape = (2,)
|
||||||
|
event_shape = (3,)
|
||||||
|
with self.test_session():
|
||||||
|
mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
|
||||||
|
mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
|
||||||
|
mvn_a = distributions.MultivariateNormalFull(mu_a, sigma_a)
|
||||||
|
mvn_b = distributions.MultivariateNormalFull(mu_b, sigma_b)
|
||||||
|
|
||||||
|
kl = distributions.kl(mvn_a, mvn_b)
|
||||||
|
self.assertEqual(batch_shape, kl.get_shape())
|
||||||
|
|
||||||
|
kl_v = kl.eval()
|
||||||
|
expected_kl_0 = _compute_non_batch_kl(
|
||||||
|
mu_a[0, :], sigma_a[0, :, :], mu_b[0, :], sigma_b[0, :])
|
||||||
|
expected_kl_1 = _compute_non_batch_kl(
|
||||||
|
mu_a[1, :], sigma_a[1, :, :], mu_b[1, :], sigma_b[1, :])
|
||||||
|
self.assertAllClose(expected_kl_0, kl_v[0])
|
||||||
|
self.assertAllClose(expected_kl_1, kl_v[1])
|
||||||
|
|
||||||
|
def testKLTwoIdenticalDistributionsIsZero(self):
|
||||||
|
batch_shape = (2,)
|
||||||
|
event_shape = (3,)
|
||||||
|
with self.test_session():
|
||||||
|
mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
|
||||||
|
mvn_a = distributions.MultivariateNormalFull(mu_a, sigma_a)
|
||||||
|
|
||||||
|
# Should be zero since KL(p || p) = =.
|
||||||
|
kl = distributions.kl(mvn_a, mvn_a)
|
||||||
|
self.assertEqual(batch_shape, kl.get_shape())
|
||||||
|
|
||||||
|
kl_v = kl.eval()
|
||||||
|
self.assertAllClose(np.zeros(*batch_shape), kl_v)
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b):
|
||||||
|
"""Non-batch KL for N(mu_a, sigma_a), N(mu_b, sigma_b)."""
|
||||||
|
# Check using numpy operations
|
||||||
|
# This mostly repeats the tensorflow code _kl_mvn_mvn(), but in numpy.
|
||||||
|
# So it is important to also check that KL(mvn, mvn) = 0.
|
||||||
|
sigma_b_inv = np.linalg.inv(sigma_b)
|
||||||
|
|
||||||
|
t = np.trace(sigma_b_inv.dot(sigma_a))
|
||||||
|
q = (mu_b - mu_a).dot(sigma_b_inv).dot(mu_b - mu_a)
|
||||||
|
k = mu_a.shape[0]
|
||||||
|
l = np.log(np.linalg.det(sigma_b) / np.linalg.det(sigma_a))
|
||||||
|
|
||||||
|
return 0.5 * (t + q - k + l)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
@ -17,14 +17,17 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import abc
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import six
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow.contrib.distributions.python.ops import operator_pd_diag
|
from tensorflow.contrib.distributions.python.ops import operator_pd_diag
|
||||||
from tensorflow.contrib.distributions.python.ops import operator_test_util
|
from tensorflow.contrib.distributions.python.ops import operator_test_util
|
||||||
|
|
||||||
|
|
||||||
class OperatorPDSqrtDiagTest(operator_test_util.OperatorPDDerivedClassTest):
|
@six.add_metaclass(abc.ABCMeta)
|
||||||
|
class OperatorPDDiagBaseTest(object):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self._rng = np.random.RandomState(42)
|
self._rng = np.random.RandomState(42)
|
||||||
@ -32,8 +35,14 @@ class OperatorPDSqrtDiagTest(operator_test_util.OperatorPDDerivedClassTest):
|
|||||||
def _random_pd_diag(self, diag_shape):
|
def _random_pd_diag(self, diag_shape):
|
||||||
return self._rng.rand(*diag_shape) + 0.1
|
return self._rng.rand(*diag_shape) + 0.1
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
def _diag_to_matrix(self, diag):
|
def _diag_to_matrix(self, diag):
|
||||||
return tf.batch_matrix_diag(diag**2).eval()
|
pass
|
||||||
|
|
||||||
|
@abc.abstractproperty
|
||||||
|
def operator_class(self):
|
||||||
|
# Return the operator class that this tests.
|
||||||
|
pass
|
||||||
|
|
||||||
def _build_operator_and_mat(self, batch_shape, k, dtype=np.float64):
|
def _build_operator_and_mat(self, batch_shape, k, dtype=np.float64):
|
||||||
# Create a diagonal matrix explicitly.
|
# Create a diagonal matrix explicitly.
|
||||||
@ -46,7 +55,7 @@ class OperatorPDSqrtDiagTest(operator_test_util.OperatorPDDerivedClassTest):
|
|||||||
# The diag is the square root.
|
# The diag is the square root.
|
||||||
diag = self._random_pd_diag(diag_shape).astype(dtype)
|
diag = self._random_pd_diag(diag_shape).astype(dtype)
|
||||||
mat = self._diag_to_matrix(diag).astype(dtype)
|
mat = self._diag_to_matrix(diag).astype(dtype)
|
||||||
operator = operator_pd_diag.OperatorPDSqrtDiag(diag)
|
operator = self.operator_class(diag)
|
||||||
|
|
||||||
return operator, mat
|
return operator, mat
|
||||||
|
|
||||||
@ -66,5 +75,29 @@ class OperatorPDSqrtDiagTest(operator_test_util.OperatorPDDerivedClassTest):
|
|||||||
operator.to_dense().eval() # Should not raise
|
operator.to_dense().eval() # Should not raise
|
||||||
|
|
||||||
|
|
||||||
|
class OperatorPDDiagTest(
|
||||||
|
OperatorPDDiagBaseTest, operator_test_util.OperatorPDDerivedClassTest):
|
||||||
|
"""Most tests done in the base classes."""
|
||||||
|
|
||||||
|
def _diag_to_matrix(self, diag):
|
||||||
|
return tf.batch_matrix_diag(diag).eval()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def operator_class(self):
|
||||||
|
return operator_pd_diag.OperatorPDDiag
|
||||||
|
|
||||||
|
|
||||||
|
class OperatorPDSqrtDiagTest(
|
||||||
|
OperatorPDDiagBaseTest, operator_test_util.OperatorPDDerivedClassTest):
|
||||||
|
"""Most tests done in the base classes."""
|
||||||
|
|
||||||
|
def _diag_to_matrix(self, diag):
|
||||||
|
return tf.batch_matrix_diag(diag**2).eval()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def operator_class(self):
|
||||||
|
return operator_pd_diag.OperatorPDSqrtDiag
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
@ -0,0 +1,115 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow.contrib.distributions.python.ops import operator_pd_identity
|
||||||
|
from tensorflow.contrib.distributions.python.ops import operator_test_util
|
||||||
|
|
||||||
|
distributions = tf.contrib.distributions
|
||||||
|
|
||||||
|
|
||||||
|
class OperatorPDIdentityTest(operator_test_util.OperatorPDDerivedClassTest):
|
||||||
|
"""Most tests done in the base class."""
|
||||||
|
|
||||||
|
def _build_operator_and_mat(self, batch_shape, k, dtype=np.float64):
|
||||||
|
# Build an identity matrix with right shape and dtype.
|
||||||
|
# Build an operator that should act the same way.
|
||||||
|
batch_shape = list(batch_shape)
|
||||||
|
diag_shape = batch_shape + [k]
|
||||||
|
matrix_shape = batch_shape + [k, k]
|
||||||
|
diag = tf.ones(diag_shape, dtype=dtype)
|
||||||
|
identity_matrix = tf.batch_matrix_diag(diag)
|
||||||
|
operator = operator_pd_identity.OperatorPDIdentity(matrix_shape, dtype)
|
||||||
|
return operator, identity_matrix.eval()
|
||||||
|
|
||||||
|
def test_bad_dtype_args_raise(self):
|
||||||
|
dtype = np.float32
|
||||||
|
batch_shape = [2, 3]
|
||||||
|
k = 4
|
||||||
|
with self.test_session():
|
||||||
|
operator, _ = self._build_operator_and_mat(batch_shape, k, dtype=dtype)
|
||||||
|
|
||||||
|
x_good_shape = batch_shape + [k, 5]
|
||||||
|
x_good = self._rng.randn(*x_good_shape).astype(dtype)
|
||||||
|
x_bad = x_good.astype(np.float64)
|
||||||
|
|
||||||
|
operator.matmul(x_good).eval() # Should not raise.
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(TypeError, 'dtype'):
|
||||||
|
operator.matmul(x_bad)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(TypeError, 'dtype'):
|
||||||
|
operator.solve(x_bad)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(TypeError, 'dtype'):
|
||||||
|
operator.sqrt_solve(x_bad)
|
||||||
|
|
||||||
|
def test_bad_rank_args_raise(self):
|
||||||
|
# Prepend a singleton dimension, changing the rank of 'x', but not the size.
|
||||||
|
dtype = np.float32
|
||||||
|
batch_shape = [2, 3]
|
||||||
|
k = 4
|
||||||
|
with self.test_session():
|
||||||
|
operator, _ = self._build_operator_and_mat(batch_shape, k, dtype=dtype)
|
||||||
|
|
||||||
|
x_good_shape = batch_shape + [k, 5]
|
||||||
|
x_good = self._rng.randn(*x_good_shape).astype(dtype)
|
||||||
|
x_bad = x_good.reshape(1, 2, 3, 4, 5)
|
||||||
|
|
||||||
|
operator.matmul(x_good).eval() # Should not raise.
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'tensor rank'):
|
||||||
|
operator.matmul(x_bad)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'tensor rank'):
|
||||||
|
operator.solve(x_bad)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'tensor rank'):
|
||||||
|
operator.sqrt_solve(x_bad)
|
||||||
|
|
||||||
|
def test_incompatible_shape_args_raise(self):
|
||||||
|
# Test shapes that are the same rank but incompatible for matrix
|
||||||
|
# multiplication.
|
||||||
|
dtype = np.float32
|
||||||
|
batch_shape = [2, 3]
|
||||||
|
k = 4
|
||||||
|
with self.test_session():
|
||||||
|
operator, _ = self._build_operator_and_mat(batch_shape, k, dtype=dtype)
|
||||||
|
|
||||||
|
x_good_shape = batch_shape + [k, 5]
|
||||||
|
x_good = self._rng.randn(*x_good_shape).astype(dtype)
|
||||||
|
x_bad_shape = batch_shape + [5, k]
|
||||||
|
x_bad = x_good.reshape(*x_bad_shape)
|
||||||
|
|
||||||
|
operator.matmul(x_good).eval() # Should not raise.
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'Incompatible'):
|
||||||
|
operator.matmul(x_bad)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'Incompatible'):
|
||||||
|
operator.solve(x_bad)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'Incompatible'):
|
||||||
|
operator.sqrt_solve(x_bad)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
@ -0,0 +1,273 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow.contrib.distributions.python.ops import operator_pd_full
|
||||||
|
from tensorflow.contrib.distributions.python.ops import operator_pd_vdvt_update
|
||||||
|
from tensorflow.contrib.distributions.python.ops import operator_test_util
|
||||||
|
|
||||||
|
distributions = tf.contrib.distributions
|
||||||
|
|
||||||
|
|
||||||
|
class OperatorPDSqrtVDVTUpdateTest(
|
||||||
|
operator_test_util.OperatorPDDerivedClassTest):
|
||||||
|
"""Most tests done in the base class."""
|
||||||
|
_diag_is_none = False
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self._rng = np.random.RandomState(42)
|
||||||
|
|
||||||
|
def _random_pd_matrix(self, shape):
|
||||||
|
# With probability 1 this is positive definite.
|
||||||
|
sqrt = self._rng.randn(*shape)
|
||||||
|
mat = tf.batch_matmul(sqrt, sqrt, adj_y=True)
|
||||||
|
return mat.eval()
|
||||||
|
|
||||||
|
def _random_v_and_diag(self, mat_shape, v_matrix_rank):
|
||||||
|
# Get the necessary elements to make the sqrt update.
|
||||||
|
mat_shape = list(mat_shape)
|
||||||
|
batch_shape = mat_shape[:-2]
|
||||||
|
diag_shape = mat_shape[:-2] + [v_matrix_rank]
|
||||||
|
k = mat_shape[-1]
|
||||||
|
assert k == mat_shape[-2], 'Must be a square matrix'
|
||||||
|
v_shape = batch_shape + [k, v_matrix_rank]
|
||||||
|
v = self._rng.randn(*v_shape) # anything goes with "v"!
|
||||||
|
|
||||||
|
if self._diag_is_none:
|
||||||
|
diag = None
|
||||||
|
else:
|
||||||
|
diag = self._rng.rand(*diag_shape) + 0.1 # Positive diag!
|
||||||
|
return v, diag
|
||||||
|
|
||||||
|
def _updated_mat(self, mat, v, diag):
|
||||||
|
# Get dense matrix defined by its square root, which is an update of `mat`:
|
||||||
|
# A = (mat + v D v^T) (mat + v D v^T)^T
|
||||||
|
# D is the diagonal matrix with `diag` on the diagonal.
|
||||||
|
|
||||||
|
# If diag is None, then it defaults to the identity matrix, so DV^T = V^T
|
||||||
|
if diag is None:
|
||||||
|
diag_vt = tf.batch_matrix_transpose(v)
|
||||||
|
else:
|
||||||
|
diag_mat = tf.batch_matrix_diag(diag)
|
||||||
|
diag_vt = tf.batch_matmul(diag_mat, v, adj_y=True)
|
||||||
|
|
||||||
|
v_diag_vt = tf.batch_matmul(v, diag_vt)
|
||||||
|
sqrt = mat + v_diag_vt
|
||||||
|
a = tf.batch_matmul(sqrt, sqrt, adj_y=True)
|
||||||
|
return a.eval()
|
||||||
|
|
||||||
|
def _build_operator_and_mat(self, batch_shape, k, dtype=np.float64):
|
||||||
|
"""This method is called by base class, enabling many standard tests."""
|
||||||
|
# Create a matrix then explicitly update it with v and diag.
|
||||||
|
# Create an OperatorPDSqrtVDVTUpdate from the matrix and v and diag
|
||||||
|
# The operator should have the same behavior.
|
||||||
|
#
|
||||||
|
# The low-rank matrix V will have rank 1/2 of k, unless k is 1, in which
|
||||||
|
# case it will be 1 as well.
|
||||||
|
if k == 1:
|
||||||
|
v_matrix_rank = k
|
||||||
|
else:
|
||||||
|
v_matrix_rank = k // 2
|
||||||
|
mat_shape = list(batch_shape) + [k, k]
|
||||||
|
mat = self._random_pd_matrix(mat_shape)
|
||||||
|
v, diag = self._random_v_and_diag(mat_shape, v_matrix_rank)
|
||||||
|
|
||||||
|
# Set dtypes
|
||||||
|
mat = mat.astype(dtype)
|
||||||
|
v = v.astype(dtype)
|
||||||
|
if diag is not None:
|
||||||
|
diag = diag.astype(dtype)
|
||||||
|
|
||||||
|
# The matrix: (mat + v*diag*v^T) * (mat + v*diag*v^T)^T
|
||||||
|
# Our final updated operator should behave like this.
|
||||||
|
updated_mat = self._updated_mat(mat, v, diag)
|
||||||
|
|
||||||
|
# Represents the matrix: `mat`, before updating.
|
||||||
|
# This is the Operator that we will update.
|
||||||
|
o_made_with_mat = operator_pd_full.OperatorPDFull(mat)
|
||||||
|
|
||||||
|
# Represents the matrix: (mat + v*diag*v^T) * (mat + v*diag*v^T)^T,
|
||||||
|
# achieved by updating the operator "o_made_with_mat".
|
||||||
|
# This is the operator we're testing.
|
||||||
|
operator = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
|
||||||
|
o_made_with_mat, v, diag)
|
||||||
|
|
||||||
|
return operator, updated_mat
|
||||||
|
|
||||||
|
def test_to_dense_placeholder(self):
|
||||||
|
# Test simple functionality when the inputs are placeholders.
|
||||||
|
mat_shape = [3, 3]
|
||||||
|
v_matrix_rank = 2
|
||||||
|
with self.test_session():
|
||||||
|
# Make an OperatorPDFull with a matrix placeholder.
|
||||||
|
mat_ph = tf.placeholder(tf.float64, name='mat_ph')
|
||||||
|
mat = self._random_pd_matrix(mat_shape)
|
||||||
|
o_made_with_mat = operator_pd_full.OperatorPDFull(mat_ph)
|
||||||
|
|
||||||
|
# Make the placeholders and arrays for the updated operator.
|
||||||
|
v_ph = tf.placeholder(tf.float64, name='v_ph')
|
||||||
|
v, diag = self._random_v_and_diag(mat_shape, v_matrix_rank)
|
||||||
|
if self._diag_is_none:
|
||||||
|
diag_ph = None
|
||||||
|
feed_dict = {v_ph: v, mat_ph: mat}
|
||||||
|
else:
|
||||||
|
diag_ph = tf.placeholder(tf.float64, name='diag_ph')
|
||||||
|
feed_dict = {v_ph: v, diag_ph: diag, mat_ph: mat}
|
||||||
|
|
||||||
|
# Make the OperatorPDSqrtVDVTUpdate with v and diag placeholders.
|
||||||
|
operator = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
|
||||||
|
o_made_with_mat, v_ph, diag=diag_ph)
|
||||||
|
|
||||||
|
# Should not fail
|
||||||
|
operator.to_dense().eval(feed_dict=feed_dict)
|
||||||
|
operator.log_det().eval(feed_dict=feed_dict)
|
||||||
|
|
||||||
|
def test_operator_not_subclass_of_operator_pd_raises(self):
|
||||||
|
# We enforce that `operator` is an `OperatorPDBase`.
|
||||||
|
with self.test_session():
|
||||||
|
v, diag = self._random_v_and_diag((3, 3), 2)
|
||||||
|
operator_m = 'I am not a subclass of OperatorPDBase'
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(TypeError, 'not instance'):
|
||||||
|
operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(operator_m, v, diag)
|
||||||
|
|
||||||
|
def test_non_pos_def_diag_raises(self):
|
||||||
|
if self._diag_is_none:
|
||||||
|
return
|
||||||
|
# We enforce that the diag is positive definite.
|
||||||
|
with self.test_session():
|
||||||
|
matrix_shape = (3, 3)
|
||||||
|
v_rank = 2
|
||||||
|
v, diag = self._random_v_and_diag(matrix_shape, v_rank)
|
||||||
|
mat = self._random_pd_matrix(matrix_shape)
|
||||||
|
diag[0] = 0.0
|
||||||
|
|
||||||
|
operator_m = operator_pd_full.OperatorPDFull(mat)
|
||||||
|
operator = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
|
||||||
|
operator_m, v, diag)
|
||||||
|
|
||||||
|
with self.assertRaisesOpError('positive'):
|
||||||
|
operator.to_dense().eval()
|
||||||
|
|
||||||
|
def test_non_pos_def_diag_doesnt_raise_if_verify_pd_false(self):
|
||||||
|
# We enforce that the diag is positive definite.
|
||||||
|
if self._diag_is_none:
|
||||||
|
return
|
||||||
|
with self.test_session():
|
||||||
|
matrix_shape = (3, 3)
|
||||||
|
v_rank = 2
|
||||||
|
v, diag = self._random_v_and_diag(matrix_shape, v_rank)
|
||||||
|
mat = self._random_pd_matrix(matrix_shape)
|
||||||
|
diag[0] = 0.0
|
||||||
|
|
||||||
|
operator_m = operator_pd_full.OperatorPDFull(mat)
|
||||||
|
operator = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
|
||||||
|
operator_m, v, diag, verify_pd=False)
|
||||||
|
|
||||||
|
operator.to_dense().eval() # Should not raise.
|
||||||
|
|
||||||
|
def test_event_shape_mismatch_v_and_diag_raises_static(self):
|
||||||
|
v = self._rng.rand(4, 3, 2)
|
||||||
|
diag = self._rng.rand(4, 1) # Should be shape (4, 2,) to match v.
|
||||||
|
with self.test_session():
|
||||||
|
|
||||||
|
mat = self._random_pd_matrix((4, 3, 3)) # mat and v match
|
||||||
|
operator_m = operator_pd_full.OperatorPDFull(mat)
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'diag.*v.*last dimension'):
|
||||||
|
operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(operator_m, v, diag)
|
||||||
|
|
||||||
|
def test_batch_shape_mismatch_v_and_diag_raises_static(self):
|
||||||
|
v = self._rng.rand(4, 3, 2)
|
||||||
|
diag = self._rng.rand(5, 1) # Should be shape (4, 2,) to match v.
|
||||||
|
with self.test_session():
|
||||||
|
|
||||||
|
mat = self._random_pd_matrix((4, 3, 3)) # mat and v match
|
||||||
|
operator_m = operator_pd_full.OperatorPDFull(mat)
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'diag.*batch shape'):
|
||||||
|
operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(operator_m, v, diag)
|
||||||
|
|
||||||
|
def test_tensor_rank_shape_mismatch_v_and_diag_raises_static(self):
|
||||||
|
v = self._rng.rand(1, 2, 2, 2)
|
||||||
|
diag = self._rng.rand(5, 1) # Should have rank 1 less than v.
|
||||||
|
with self.test_session():
|
||||||
|
|
||||||
|
mat = self._random_pd_matrix((1, 2, 2, 2)) # mat and v match
|
||||||
|
operator_m = operator_pd_full.OperatorPDFull(mat)
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'diag.*rank'):
|
||||||
|
operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(operator_m, v, diag)
|
||||||
|
|
||||||
|
def test_event_shape_mismatch_v_and_diag_raises_dynamic(self):
|
||||||
|
with self.test_session():
|
||||||
|
|
||||||
|
v = self._rng.rand(4, 3, 2)
|
||||||
|
diag = self._rng.rand(4, 1) # Should be shape (4, 2,) to match v.
|
||||||
|
mat = self._random_pd_matrix((4, 3, 3)) # mat and v match
|
||||||
|
|
||||||
|
v_ph = tf.placeholder(tf.float32, name='v_ph')
|
||||||
|
diag_ph = tf.placeholder(tf.float32, name='diag_ph')
|
||||||
|
mat_ph = tf.placeholder(tf.float32, name='mat_ph')
|
||||||
|
|
||||||
|
operator_m = operator_pd_full.OperatorPDFull(mat_ph)
|
||||||
|
updated = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
|
||||||
|
operator_m, v_ph, diag_ph)
|
||||||
|
with self.assertRaisesOpError('x == y'):
|
||||||
|
updated.to_dense().eval(feed_dict={v_ph: v, diag_ph: diag, mat_ph: mat})
|
||||||
|
|
||||||
|
def test_batch_shape_mismatch_v_and_diag_raises_dynamic(self):
|
||||||
|
with self.test_session():
|
||||||
|
v = self._rng.rand(4, 3, 2)
|
||||||
|
diag = self._rng.rand(5, 1) # Should be shape (4, 2,) to match v.
|
||||||
|
mat = self._random_pd_matrix((4, 3, 3)) # mat and v match
|
||||||
|
|
||||||
|
v_ph = tf.placeholder(tf.float32, name='v_ph')
|
||||||
|
diag_ph = tf.placeholder(tf.float32, name='diag_ph')
|
||||||
|
mat_ph = tf.placeholder(tf.float32, name='mat_ph')
|
||||||
|
|
||||||
|
operator_m = operator_pd_full.OperatorPDFull(mat_ph)
|
||||||
|
updated = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
|
||||||
|
operator_m, v_ph, diag_ph)
|
||||||
|
with self.assertRaisesOpError('x == y'):
|
||||||
|
updated.to_dense().eval(feed_dict={v_ph: v, diag_ph: diag, mat_ph: mat})
|
||||||
|
|
||||||
|
def test_tensor_rank_shape_mismatch_v_and_diag_raises_dynamic(self):
|
||||||
|
with self.test_session():
|
||||||
|
|
||||||
|
v = self._rng.rand(2, 2, 2, 2)
|
||||||
|
diag = self._rng.rand(2, 2) # Should have rank 1 less than v.
|
||||||
|
mat = self._random_pd_matrix((2, 2, 2, 2)) # mat and v match
|
||||||
|
|
||||||
|
v_ph = tf.placeholder(tf.float32, name='v_ph')
|
||||||
|
diag_ph = tf.placeholder(tf.float32, name='diag_ph')
|
||||||
|
mat_ph = tf.placeholder(tf.float32, name='mat_ph')
|
||||||
|
|
||||||
|
operator_m = operator_pd_full.OperatorPDFull(mat_ph)
|
||||||
|
updated = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
|
||||||
|
operator_m, v_ph, diag_ph)
|
||||||
|
with self.assertRaisesOpError('rank'):
|
||||||
|
updated.to_dense().eval(feed_dict={v_ph: v, diag_ph: diag, mat_ph: mat})
|
||||||
|
|
||||||
|
|
||||||
|
class OperatorPDSqrtVDVTUpdateNoneDiagTest(OperatorPDSqrtVDVTUpdateTest):
|
||||||
|
_diag_is_none = True
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
@ -0,0 +1,165 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for ShapeUtil."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.contrib.distributions.python.ops.shape import _ShapeUtil # pylint: disable=line-too-long
|
||||||
|
|
||||||
|
|
||||||
|
class ShapeUtilTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def testShapeUtilGetNdims(self):
|
||||||
|
with self.test_session():
|
||||||
|
shaper = _ShapeUtil(batch_ndims=0, event_ndims=0)
|
||||||
|
x = 1
|
||||||
|
self.assertEqual(shaper.get_sample_ndims(x), 0)
|
||||||
|
self.assertEqual(shaper.batch_ndims, 0)
|
||||||
|
self.assertEqual(shaper.event_ndims, 0)
|
||||||
|
|
||||||
|
shaper = _ShapeUtil(batch_ndims=1, event_ndims=1)
|
||||||
|
x = [[[0., 1, 2], [3, 4, 5]]]
|
||||||
|
self.assertAllEqual(shaper.get_ndims(x), 3)
|
||||||
|
self.assertEqual(shaper.get_sample_ndims(x), 1)
|
||||||
|
self.assertEqual(shaper.batch_ndims, 1)
|
||||||
|
self.assertEqual(shaper.event_ndims, 1)
|
||||||
|
|
||||||
|
x += [[[6, 7, 8], [9, 10, 11]]]
|
||||||
|
self.assertAllEqual(shaper.get_ndims(x), 3)
|
||||||
|
self.assertEqual(shaper.get_sample_ndims(x), 1)
|
||||||
|
self.assertEqual(shaper.batch_ndims, 1)
|
||||||
|
self.assertEqual(shaper.event_ndims, 1)
|
||||||
|
|
||||||
|
# Test ndims functions work, even despite unfed Tensors.
|
||||||
|
y = tf.placeholder(tf.float32, shape=(1024, None, 1024))
|
||||||
|
self.assertAllEqual(shaper.get_ndims(y), 3)
|
||||||
|
self.assertEqual(shaper.get_sample_ndims(y), 1)
|
||||||
|
self.assertEqual(shaper.batch_ndims, 1)
|
||||||
|
self.assertEqual(shaper.event_ndims, 1)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
y = tf.placeholder(tf.float32)
|
||||||
|
shaper.get_ndims(y)
|
||||||
|
|
||||||
|
def testShapeUtilGetDims(self):
|
||||||
|
with self.test_session():
|
||||||
|
shaper = _ShapeUtil(batch_ndims=0, event_ndims=0)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
y = tf.placeholder(tf.float32)
|
||||||
|
shaper.get_sample_dims(y)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
y = tf.placeholder(tf.float32)
|
||||||
|
shaper.get_batch_dims(y)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
y = tf.placeholder(tf.float32)
|
||||||
|
shaper.get_event_dims(y)
|
||||||
|
|
||||||
|
shaper = _ShapeUtil(batch_ndims=0, event_ndims=0)
|
||||||
|
x = 1
|
||||||
|
self.assertAllEqual(shaper.get_sample_dims(x), [])
|
||||||
|
self.assertAllEqual(shaper.get_batch_dims(x), [])
|
||||||
|
self.assertAllEqual(shaper.get_event_dims(x), [])
|
||||||
|
self.assertAllEqual(shaper.get_dims(x, sample=False), [])
|
||||||
|
|
||||||
|
shaper = _ShapeUtil(batch_ndims=1, event_ndims=2)
|
||||||
|
x = [[[[0., 1], [2, 4]]]]
|
||||||
|
self.assertAllEqual(shaper.get_sample_dims(x), [0])
|
||||||
|
self.assertAllEqual(shaper.get_batch_dims(x), [1])
|
||||||
|
self.assertAllEqual(shaper.get_event_dims(x), [2, 3])
|
||||||
|
self.assertAllEqual(shaper.get_dims(x, sample=False), [1, 2, 3])
|
||||||
|
|
||||||
|
x += x
|
||||||
|
self.assertAllEqual(shaper.get_sample_dims(x), [0])
|
||||||
|
self.assertAllEqual(shaper.get_batch_dims(x), [1])
|
||||||
|
self.assertAllEqual(shaper.get_event_dims(x), [2, 3])
|
||||||
|
self.assertAllEqual(shaper.get_dims(x, sample=False), [1, 2, 3])
|
||||||
|
|
||||||
|
# Test dims functions work, despite unfed Tensors.
|
||||||
|
y = tf.placeholder(tf.float32, shape=(1024, None, 5, 5))
|
||||||
|
self.assertAllEqual(shaper.get_sample_dims(y), [0])
|
||||||
|
self.assertAllEqual(shaper.get_batch_dims(y), [1])
|
||||||
|
self.assertAllEqual(shaper.get_event_dims(y), [2, 3])
|
||||||
|
|
||||||
|
def testShapeUtilGetShape(self):
|
||||||
|
with self.test_session() as sess:
|
||||||
|
shaper = _ShapeUtil(batch_ndims=0, event_ndims=0)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
y = tf.placeholder(tf.float32)
|
||||||
|
shaper.get_sample_shape(y)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
y = tf.placeholder(tf.float32)
|
||||||
|
shaper.get_batch_shape(y)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
y = tf.placeholder(tf.float32)
|
||||||
|
shaper.get_event_shape(y)
|
||||||
|
|
||||||
|
shaper = _ShapeUtil(batch_ndims=0, event_ndims=0)
|
||||||
|
x = 1
|
||||||
|
self.assertAllEqual(shaper.get_sample_shape(x), [])
|
||||||
|
self.assertAllEqual(shaper.get_batch_shape(x), [])
|
||||||
|
self.assertAllEqual(shaper.get_event_shape(x), [])
|
||||||
|
self.assertAllEqual(shaper.get_shape(x, batch=False), [])
|
||||||
|
|
||||||
|
shaper = _ShapeUtil(batch_ndims=1, event_ndims=1)
|
||||||
|
x = [[[0., 1, 2], [3, 4, 5]]]
|
||||||
|
self.assertAllEqual(shaper.get_sample_shape(x), [1])
|
||||||
|
self.assertAllEqual(shaper.get_batch_shape(x), [2])
|
||||||
|
self.assertAllEqual(shaper.get_event_shape(x), [3])
|
||||||
|
self.assertAllEqual(shaper.get_shape(x, batch=False), [1, 3])
|
||||||
|
|
||||||
|
x += [[[6, 7, 8], [9, 10, 11]]]
|
||||||
|
self.assertAllEqual(shaper.get_sample_shape(x), [2])
|
||||||
|
self.assertAllEqual(shaper.get_batch_shape(x), [2])
|
||||||
|
self.assertAllEqual(shaper.get_event_shape(x), [3])
|
||||||
|
self.assertAllEqual(shaper.get_shape(x, batch=False), [2, 3])
|
||||||
|
|
||||||
|
shaper = _ShapeUtil(batch_ndims=0, event_ndims=1)
|
||||||
|
x = tf.ones((3, 2))
|
||||||
|
self.assertAllEqual(shaper.get_shape(x, sample=False), (2,))
|
||||||
|
|
||||||
|
def feed_eval(fun, build_shape=(None, None, 2), graph_shape=(3, 4, 2)):
|
||||||
|
"""Helper to use a deferred-shape tensor eval'ed at graph runtime."""
|
||||||
|
y = tf.placeholder(tf.int32, shape=build_shape)
|
||||||
|
y_value = np.ones(graph_shape, dtype=y.dtype.as_numpy_dtype())
|
||||||
|
return sess.run(fun(y),
|
||||||
|
feed_dict={y: y_value})
|
||||||
|
|
||||||
|
shaper = _ShapeUtil(batch_ndims=1, event_ndims=1)
|
||||||
|
self.assertAllEqual(feed_eval(shaper.get_sample_shape), [3])
|
||||||
|
self.assertAllEqual(feed_eval(shaper.get_batch_shape), [4])
|
||||||
|
self.assertAllEqual(feed_eval(shaper.get_event_shape), [2])
|
||||||
|
self.assertAllEqual(
|
||||||
|
feed_eval(lambda y: shaper.get_shape(y, batch=False)),
|
||||||
|
[3, 2])
|
||||||
|
|
||||||
|
shaper = _ShapeUtil(batch_ndims=0, event_ndims=1)
|
||||||
|
self.assertAllEqual(
|
||||||
|
feed_eval(lambda y: shaper.get_shape(y, batch=False),
|
||||||
|
(None, None),
|
||||||
|
(3, 2)),
|
||||||
|
[3, 2])
|
||||||
|
self.assertAllEqual(
|
||||||
|
feed_eval(lambda y: shaper.get_shape(y, sample=False),
|
||||||
|
(None, None),
|
||||||
|
(3, 2)),
|
||||||
|
[2])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.test.main()
|
@ -19,13 +19,13 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.contrib.distributions.python.ops import distribution
|
from tensorflow.contrib.distributions.python.ops import distribution
|
||||||
|
from tensorflow.contrib.distributions.python.ops import distribution_util
|
||||||
from tensorflow.contrib.distributions.python.ops import kullback_leibler # pylint: disable=line-too-long
|
from tensorflow.contrib.distributions.python.ops import kullback_leibler # pylint: disable=line-too-long
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import check_ops
|
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import nn
|
from tensorflow.python.ops import nn
|
||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
@ -36,10 +36,6 @@ class Bernoulli(distribution.Distribution):
|
|||||||
|
|
||||||
The Bernoulli distribution is parameterized by p, the probability of a
|
The Bernoulli distribution is parameterized by p, the probability of a
|
||||||
positive event.
|
positive event.
|
||||||
|
|
||||||
Note, the following methods of the base class aren't implemented:
|
|
||||||
* cdf
|
|
||||||
* log_cdf
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -62,10 +58,10 @@ class Bernoulli(distribution.Distribution):
|
|||||||
dtype: dtype for samples.
|
dtype: dtype for samples.
|
||||||
validate_args: Whether to assert that `0 <= p <= 1`. If not validate_args,
|
validate_args: Whether to assert that `0 <= p <= 1`. If not validate_args,
|
||||||
`log_pmf` may return nans.
|
`log_pmf` may return nans.
|
||||||
allow_nan_stats: Boolean, default False. If False, raise an exception if
|
allow_nan_stats: Boolean, default `False`. If `False`, raise an
|
||||||
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
|
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
|
||||||
If True, batch members with valid parameters leading to undefined
|
batch member. If `True`, batch members with valid parameters leading to
|
||||||
statistics will return NaN for this statistic.
|
undefined statistics will return NaN for this statistic.
|
||||||
name: A name for this distribution.
|
name: A name for this distribution.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@ -75,25 +71,8 @@ class Bernoulli(distribution.Distribution):
|
|||||||
self._name = name
|
self._name = name
|
||||||
self._dtype = dtype
|
self._dtype = dtype
|
||||||
self._validate_args = validate_args
|
self._validate_args = validate_args
|
||||||
check_op = check_ops.assert_less_equal
|
self._logits, self._p = distribution_util.get_logits_and_prob(
|
||||||
if p is None and logits is None:
|
name=name, logits=logits, p=p, validate_args=validate_args)
|
||||||
raise ValueError("Must pass p or logits.")
|
|
||||||
elif p is not None and logits is not None:
|
|
||||||
raise ValueError("Must pass either p or logits, not both.")
|
|
||||||
elif p is None:
|
|
||||||
with ops.op_scope([logits], name):
|
|
||||||
self._logits = array_ops.identity(logits, name="logits")
|
|
||||||
with ops.name_scope(name):
|
|
||||||
with ops.name_scope("p"):
|
|
||||||
self._p = math_ops.sigmoid(self._logits)
|
|
||||||
elif logits is None:
|
|
||||||
with ops.name_scope(name):
|
|
||||||
with ops.name_scope("p"):
|
|
||||||
with ops.control_dependencies([check_op(p, 1.), check_op(0., p)] if
|
|
||||||
validate_args else []):
|
|
||||||
self._p = array_ops.identity(p)
|
|
||||||
with ops.name_scope("logits"):
|
|
||||||
self._logits = math_ops.log(self._p) - math_ops.log(1. - self._p)
|
|
||||||
with ops.name_scope(name):
|
with ops.name_scope(name):
|
||||||
with ops.name_scope("q"):
|
with ops.name_scope("q"):
|
||||||
self._q = 1. - self._p
|
self._q = 1. - self._p
|
||||||
@ -180,8 +159,12 @@ class Bernoulli(distribution.Distribution):
|
|||||||
event = ops.convert_to_tensor(event, name="event")
|
event = ops.convert_to_tensor(event, name="event")
|
||||||
event = math_ops.cast(event, self.logits.dtype)
|
event = math_ops.cast(event, self.logits.dtype)
|
||||||
logits = self.logits
|
logits = self.logits
|
||||||
if ((event.get_shape().ndims is not None) or
|
# sigmoid_cross_entropy_with_logits doesn't broadcast shape,
|
||||||
(logits.get_shape().ndims is not None) or
|
# so we do this here.
|
||||||
|
# TODO(b/30637701): Check dynamic shape, and don't broadcast if the
|
||||||
|
# dynamic shapes are the same.
|
||||||
|
if (not event.get_shape().is_fully_defined() or
|
||||||
|
not logits.get_shape().is_fully_defined() or
|
||||||
event.get_shape() != logits.get_shape()):
|
event.get_shape() != logits.get_shape()):
|
||||||
logits = array_ops.ones_like(event) * logits
|
logits = array_ops.ones_like(event) * logits
|
||||||
event = array_ops.ones_like(logits) * event
|
event = array_ops.ones_like(logits) * event
|
||||||
@ -202,8 +185,7 @@ class Bernoulli(distribution.Distribution):
|
|||||||
with ops.name_scope(self.name):
|
with ops.name_scope(self.name):
|
||||||
with ops.op_scope([self.p, n], name):
|
with ops.op_scope([self.p, n], name):
|
||||||
n = ops.convert_to_tensor(n, name="n")
|
n = ops.convert_to_tensor(n, name="n")
|
||||||
new_shape = array_ops.concat(
|
new_shape = array_ops.concat(0, ([n], self.batch_shape()))
|
||||||
0, [array_ops.expand_dims(n, 0), self.batch_shape()])
|
|
||||||
uniform = random_ops.random_uniform(
|
uniform = random_ops.random_uniform(
|
||||||
new_shape, seed=seed, dtype=dtypes.float32)
|
new_shape, seed=seed, dtype=dtypes.float32)
|
||||||
sample = math_ops.less(uniform, self.p)
|
sample = math_ops.less(uniform, self.p)
|
||||||
|
@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""The Beta distribution class."""
|
"""The Beta distribution class."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
@ -95,6 +96,7 @@ class Beta(distribution.Distribution):
|
|||||||
x = [.2, .3, .9]
|
x = [.2, .3, .9]
|
||||||
dist.pdf(x) # Shape [2]
|
dist.pdf(x) # Shape [2]
|
||||||
```
|
```
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, a, b, validate_args=True, allow_nan_stats=False,
|
def __init__(self, a, b, validate_args=True, allow_nan_stats=False,
|
||||||
@ -102,20 +104,20 @@ class Beta(distribution.Distribution):
|
|||||||
"""Initialize a batch of Beta distributions.
|
"""Initialize a batch of Beta distributions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
a: Positive `float` or `double` tensor with shape broadcastable to
|
a: Positive floating point tensor with shape broadcastable to
|
||||||
`[N1,..., Nm]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
|
`[N1,..., Nm]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
|
||||||
different Beta distributions. This also defines the
|
different Beta distributions. This also defines the
|
||||||
dtype of the distribution.
|
dtype of the distribution.
|
||||||
b: Positive `float` or `double` tensor with shape broadcastable to
|
b: Positive floating point tensor with shape broadcastable to
|
||||||
`[N1,..., Nm]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
|
`[N1,..., Nm]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
|
||||||
different Beta distributions.
|
different Beta distributions.
|
||||||
validate_args: Whether to assert valid values for parameters `a` and `b`,
|
validate_args: Whether to assert valid values for parameters `a` and `b`,
|
||||||
and `x` in `prob` and `log_prob`. If False, correct behavior is not
|
and `x` in `prob` and `log_prob`. If `False`, correct behavior is not
|
||||||
guaranteed.
|
guaranteed.
|
||||||
allow_nan_stats: Boolean, default False. If False, raise an exception if
|
allow_nan_stats: Boolean, default `False`. If `False`, raise an
|
||||||
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
|
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
|
||||||
If True, batch members with valid parameters leading to undefined
|
batch member. If `True`, batch members with valid parameters leading to
|
||||||
statistics will return NaN for this statistic.
|
undefined statistics will return NaN for this statistic.
|
||||||
name: The name to prefix Ops created by this distribution class.
|
name: The name to prefix Ops created by this distribution class.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
@ -127,6 +129,7 @@ class Beta(distribution.Distribution):
|
|||||||
# Define a 2-batch.
|
# Define a 2-batch.
|
||||||
dist = Beta([1.0, 2.0], [4.0, 5.0])
|
dist = Beta([1.0, 2.0], [4.0, 5.0])
|
||||||
```
|
```
|
||||||
|
|
||||||
"""
|
"""
|
||||||
with ops.op_scope([a, b], name):
|
with ops.op_scope([a, b], name):
|
||||||
with ops.control_dependencies([
|
with ops.control_dependencies([
|
||||||
@ -276,8 +279,14 @@ class Beta(distribution.Distribution):
|
|||||||
array_ops.ones_like(a_b_sum, dtype=self.dtype)))
|
array_ops.ones_like(a_b_sum, dtype=self.dtype)))
|
||||||
else:
|
else:
|
||||||
return control_flow_ops.with_dependencies([
|
return control_flow_ops.with_dependencies([
|
||||||
check_ops.assert_less(one, a),
|
check_ops.assert_less(
|
||||||
check_ops.assert_less(one, b)], mode)
|
one, a,
|
||||||
|
message="mode not defined for components of a <= 1"
|
||||||
|
),
|
||||||
|
check_ops.assert_less(
|
||||||
|
one, b,
|
||||||
|
message="mode not defined for components of b <= 1"
|
||||||
|
)], mode)
|
||||||
|
|
||||||
def entropy(self, name="entropy"):
|
def entropy(self, name="entropy"):
|
||||||
"""Entropy of the distribution in nats."""
|
"""Entropy of the distribution in nats."""
|
||||||
@ -306,7 +315,7 @@ class Beta(distribution.Distribution):
|
|||||||
"""`Log(P[counts])`, computed for every batch member.
|
"""`Log(P[counts])`, computed for every batch member.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: Non-negative `float` or `double`, tensor whose shape can
|
x: Non-negative floating point tensor whose shape can
|
||||||
be broadcast with `self.a` and `self.b`. For fixed leading
|
be broadcast with `self.a` and `self.b`. For fixed leading
|
||||||
dimensions, the last dimension represents counts for the corresponding
|
dimensions, the last dimension represents counts for the corresponding
|
||||||
Beta distribution in `self.a` and `self.b`. `x` is only legal if
|
Beta distribution in `self.a` and `self.b`. `x` is only legal if
|
||||||
@ -334,7 +343,7 @@ class Beta(distribution.Distribution):
|
|||||||
"""`P[x]`, computed for every batch member.
|
"""`P[x]`, computed for every batch member.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: Non-negative `float`, `double` tensor whose shape can
|
x: Non-negative floating point tensor whose shape can
|
||||||
be broadcast with `self.a` and `self.b`. For fixed leading
|
be broadcast with `self.a` and `self.b`. For fixed leading
|
||||||
dimensions, the last dimension represents x for the corresponding Beta
|
dimensions, the last dimension represents x for the corresponding Beta
|
||||||
distribution in `self.a` and `self.b`. `x` is only legal if is
|
distribution in `self.a` and `self.b`. `x` is only legal if is
|
||||||
|
350
tensorflow/contrib/distributions/python/ops/bijector.py
Normal file
350
tensorflow/contrib/distributions/python/ops/bijector.py
Normal file
@ -0,0 +1,350 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""An API for reversible (bijective) transformations of random variables."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
|
||||||
|
|
||||||
|
class _Bijector(object):
|
||||||
|
"""An interface for transforming random variable(s).
|
||||||
|
|
||||||
|
A bijector is characterized by three operations:
|
||||||
|
|
||||||
|
1) Forward Evaluation
|
||||||
|
Useful for turning one random outcome into another random outcome from a
|
||||||
|
different distribution.
|
||||||
|
|
||||||
|
2) Inverse Evaluation
|
||||||
|
Useful for "reversing" a transformation to compute one probability in terms
|
||||||
|
of another.
|
||||||
|
|
||||||
|
3) (log o det o Jacobian o inverse)(x)
|
||||||
|
"The log of the determinant of the matrix of all first-order partial
|
||||||
|
derivatives of the inverse function."
|
||||||
|
Useful for inverting a transformation to compute one probability in terms
|
||||||
|
of another. Geometrically, the det(Jacobian) is the volume of the
|
||||||
|
transformation and is used to scale the probability.
|
||||||
|
|
||||||
|
By convention, transformations of random variables are named in terms of the
|
||||||
|
forward transformation. The forward transformation creates samples, the
|
||||||
|
inverse is useful for computing probabilities.
|
||||||
|
|
||||||
|
Example transformations:
|
||||||
|
"Exponential"
|
||||||
|
|
||||||
|
```
|
||||||
|
Y = g(X) = exp(X)
|
||||||
|
X ~ Normal(0, 1) # Univariate.
|
||||||
|
```
|
||||||
|
|
||||||
|
Implies:
|
||||||
|
|
||||||
|
```
|
||||||
|
g^{-1}(Y) = log(Y)
|
||||||
|
|Jacobian(g^{-1})(y)| = 1 / y
|
||||||
|
Y ~ LogNormal(0, 1), i.e.,
|
||||||
|
prob(Y=y) = |Jacobian(g^{-1})(y)| * prob(X=g^{-1}(y))
|
||||||
|
= (1 / y) Normal(log(y); 0, 1)
|
||||||
|
```
|
||||||
|
|
||||||
|
"ShiftAndScale"
|
||||||
|
|
||||||
|
```
|
||||||
|
Y = g(X) = sqrtSigma * X + mu
|
||||||
|
X ~ MultivariateNormal(0, I_d)
|
||||||
|
```
|
||||||
|
|
||||||
|
Implies:
|
||||||
|
|
||||||
|
```
|
||||||
|
g^{-1}(Y) = inv(sqrtSigma) * (Y - mu)
|
||||||
|
|Jacobian(g^{-1})(y)| = det(inv(sqrtSigma))
|
||||||
|
Y ~ MultivariateNormal(mu, sqrtSigma) , i.e.,
|
||||||
|
prob(Y=y) = |Jacobian(g^{-1})(y)| * prob(X=g^{-1}(y))
|
||||||
|
= det(sqrtSigma)^(-d) *
|
||||||
|
MultivariateNormal(inv(sqrtSigma) * (y - mu); 0, I_d)
|
||||||
|
```
|
||||||
|
|
||||||
|
Example use:
|
||||||
|
Basic properties:
|
||||||
|
|
||||||
|
```python
|
||||||
|
x = ... # A tensor.
|
||||||
|
# Evaluate forward transformation.
|
||||||
|
fwd_x = my_bijector.forward(x)
|
||||||
|
x != my_bijector.forward(fwd_x) # Not equal because g(x) != g(g(x)).
|
||||||
|
x == my_bijector.inverse(fwd_x)
|
||||||
|
```
|
||||||
|
|
||||||
|
Computing a log-likelihood:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def transformed_log_pdf(bijector, log_pdf, x):
|
||||||
|
return (bijector.inverse_log_det_jacobian(x) +
|
||||||
|
log_pdf(bijector.inverse(x)))
|
||||||
|
```
|
||||||
|
|
||||||
|
Transforming a random outcome:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def transformed_sample(bijector, x):
|
||||||
|
return bijector.forward(x)
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO(b/30476956): Try to remove constructor dependence on shape util.
|
||||||
|
def __init__(self, shaper=None, name=None):
|
||||||
|
"""Constructs Bijector.
|
||||||
|
|
||||||
|
A bijector transforms random variables into new random variables. Managing
|
||||||
|
shape is typically an important piece of this so a Bijector is usually
|
||||||
|
composed of ShapeUtil. The ShapeUtil object handles input shape checks as
|
||||||
|
well as reshaping/transposing for easier linear algebra operations.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
# Create the Y = g(X) = X transform which operates on 4-Tensors of vectors.
|
||||||
|
identity = Identity(ShapeUtil(batch_ndims=4, event_ndims=1))
|
||||||
|
|
||||||
|
# Create the Y = g(X) = exp(X) transform which operates on matrices.
|
||||||
|
exp = Exp(ShapeUtil(batch_ndims=0, event_ndims=2))
|
||||||
|
```
|
||||||
|
|
||||||
|
See Bijector subclass doc for more details and examples.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shaper: object used for managing and manipulating shape, typically an
|
||||||
|
instance of ShapeUtil.
|
||||||
|
name: The name to give Ops created by the initializer.
|
||||||
|
"""
|
||||||
|
self._shaper = shaper
|
||||||
|
self._name = name or type(self).__name__
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shaper(self):
|
||||||
|
"""Returns shape object used to manage shape constraints."""
|
||||||
|
return self._shaper
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
"""Returns the string name of this bijector."""
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
def forward(self, x, name='forward'):
|
||||||
|
"""Returns the forward bijector evaluation, i.e., X = g(Y).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: `Tensor`. The input to the "forward" evaluation.
|
||||||
|
name: The name to give this op.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tensor`.
|
||||||
|
"""
|
||||||
|
with ops.name_scope(self.name):
|
||||||
|
with ops.op_scope([x], name):
|
||||||
|
x = ops.convert_to_tensor(x)
|
||||||
|
return self._forward(x)
|
||||||
|
|
||||||
|
def inverse(self, x, name='inverse'):
|
||||||
|
"""Returns the inverse bijector evaluation, i.e., X = g^{-1}(Y).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: `Tensor`. The input to the "inverse" evaluation.
|
||||||
|
name: The name to give this op.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tensor`.
|
||||||
|
"""
|
||||||
|
with ops.name_scope(self.name):
|
||||||
|
with ops.op_scope([x], name):
|
||||||
|
x = ops.convert_to_tensor(x)
|
||||||
|
try:
|
||||||
|
return self._inverse(x)
|
||||||
|
except NotImplementedError:
|
||||||
|
return self._inverse_and_inverse_log_det_jacobian(x)[0]
|
||||||
|
|
||||||
|
def inverse_log_det_jacobian(self, x, name='inverse_log_det_jacobian'):
|
||||||
|
"""Returns the (log o det o Jacobian o inverse)(x).
|
||||||
|
|
||||||
|
Mathematically, returns: log(det(dY/dX g^{-1}))(Y).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: `Tensor`. The input to the "inverse" Jacobian evaluation.
|
||||||
|
name: The name to give this op.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tensor`.
|
||||||
|
"""
|
||||||
|
with ops.name_scope(self.name):
|
||||||
|
with ops.op_scope([x], name):
|
||||||
|
x = ops.convert_to_tensor(x)
|
||||||
|
try:
|
||||||
|
return self._inverse_log_det_jacobian(x)
|
||||||
|
except NotImplementedError:
|
||||||
|
return self._inverse_and_inverse_log_det_jacobian(x)[1]
|
||||||
|
|
||||||
|
def inverse_and_inverse_log_det_jacobian(
|
||||||
|
self, x, name='inverse_and_inverse_log_det_jacobian'):
|
||||||
|
"""Returns both the inverse evaluation and inverse_log_det_jacobian.
|
||||||
|
|
||||||
|
Enables possibly more efficient calculation when both inverse and
|
||||||
|
corresponding Jacobian are needed.
|
||||||
|
|
||||||
|
See `inverse()`, `inverse_log_det_jacobian()` for more details.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: `Tensor`. The input to the "inverse" Jacobian evaluation.
|
||||||
|
name: The name to give this op.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tensor`.
|
||||||
|
"""
|
||||||
|
with ops.name_scope(self.name):
|
||||||
|
with ops.op_scope([x], name):
|
||||||
|
x = ops.convert_to_tensor(x)
|
||||||
|
try:
|
||||||
|
return self._inverse_and_inverse_log_det_jacobian(x)
|
||||||
|
except NotImplementedError:
|
||||||
|
return self._inverse(x), self._inverse_log_det_jacobian(x)
|
||||||
|
|
||||||
|
# Subclass interface.
|
||||||
|
def _forward(self, x):
|
||||||
|
"""Subclass implementation of forward().
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: `Tensor`. The input to the "forward" evaluation.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
`NotImplementedError`: if subclass implementation not provided
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tensor`.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError('_forward not implemented')
|
||||||
|
|
||||||
|
def _inverse(self, x):
|
||||||
|
"""Subclass implementation of inverse().
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: `Tensor`. The input to the "inverse" evaluation.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
`NotImplementedError`: if subclass implementation not provided
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tensor`.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError('_inverse not implemented')
|
||||||
|
|
||||||
|
def _inverse_log_det_jacobian(self, x):
|
||||||
|
"""Subclass implementation of inverse_log_det_jacobian().
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: `Tensor`. The input to the "inverse" Jacobian evaluation.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
`NotImplementedError`: if subclass implementation not provided
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tensor`.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError('_inverse_log_det_jacobian not implemented')
|
||||||
|
|
||||||
|
def _inverse_and_inverse_log_det_jacobian(self, x):
|
||||||
|
"""Subclass implementation of inverse_and_inverse_log_det_jacobian().
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: `Tensor`. The input to the "inverse" evaluation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of two `Tensor` items, inverse and inverse_log_det_jacobian.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
'_inverse_and_inverse_log_det_jacobian not implemented')
|
||||||
|
|
||||||
|
|
||||||
|
class _Identity(_Bijector):
|
||||||
|
"""Bijector which computes Y = g(X) = X.
|
||||||
|
|
||||||
|
Example Use:
|
||||||
|
```python
|
||||||
|
# Create the Y=g(X)=X transform which works only on Tensors with 1 batch
|
||||||
|
# ndims and 1 event ndim (i.e., vector of vectors).
|
||||||
|
identity = Identity(ShapeUtil(batch_ndims=1, event_ndims=1))
|
||||||
|
x = [[1., 2],
|
||||||
|
[3, 4]]
|
||||||
|
x == identity.forward(x) == identity.inverse(x)
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO(b/30476956): Try to remove constructor dependence on shape util.
|
||||||
|
def __init__(self, shaper=None, name='Identity'):
|
||||||
|
super(_Identity, self).__init__(shaper, name)
|
||||||
|
|
||||||
|
def _forward(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _inverse(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _inverse_log_det_jacobian(self, x):
|
||||||
|
result_shape = self.shaper.get_shape(
|
||||||
|
x, sample=True, batch=True, event=False)
|
||||||
|
return array_ops.zeros(result_shape, dtype=x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class _Exp(_Bijector):
|
||||||
|
"""Bijector which computes Y = g(X) = exp(X).
|
||||||
|
|
||||||
|
Example Use:
|
||||||
|
```python
|
||||||
|
# Create the Y=g(X)=exp(X) transform which works only on Tensors with 1
|
||||||
|
# batch ndims and 2 event ndim (i.e., vector of matrices).
|
||||||
|
exp = Exp(ShapeUtil(batch_ndims=1, event_ndims=2))
|
||||||
|
x = [[[1., 2],
|
||||||
|
[3, 4]],
|
||||||
|
[[5, 6],
|
||||||
|
[7, 8]]]
|
||||||
|
exp(x) == exp.forward(x)
|
||||||
|
log(x) == exp.inverse(x)
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO(b/30476956): Try to remove constructor dependence on shape util.
|
||||||
|
def __init__(self, shaper=None, name='Exp'):
|
||||||
|
super(_Exp, self).__init__(shaper, name)
|
||||||
|
|
||||||
|
def _forward(self, x):
|
||||||
|
return math_ops.exp(x)
|
||||||
|
|
||||||
|
def _inverse(self, x):
|
||||||
|
return math_ops.log(x)
|
||||||
|
|
||||||
|
def _inverse_log_det_jacobian(self, x):
|
||||||
|
d = self.shaper.get_event_dims(x)
|
||||||
|
return -math_ops.reduce_sum(math_ops.log(x), d)
|
||||||
|
|
||||||
|
def _inverse_and_inverse_log_det_jacobian(self, x):
|
||||||
|
y = math_ops.log(x)
|
||||||
|
d = self.shaper.get_event_dims(x)
|
||||||
|
return y, -math_ops.reduce_sum(y, d)
|
340
tensorflow/contrib/distributions/python/ops/binomial.py
Normal file
340
tensorflow/contrib/distributions/python/ops/binomial.py
Normal file
@ -0,0 +1,340 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""The Binomial distribution class."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
# pylint: disable=line-too-long
|
||||||
|
|
||||||
|
from tensorflow.contrib.distributions.python.ops import distribution
|
||||||
|
from tensorflow.contrib.distributions.python.ops import distribution_util
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import check_ops
|
||||||
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
|
||||||
|
# pylint: enable=line-too-long
|
||||||
|
|
||||||
|
|
||||||
|
class Binomial(distribution.Distribution):
|
||||||
|
"""Binomial distribution.
|
||||||
|
|
||||||
|
This distribution is parameterized by a vector `p` of probabilities and `n`,
|
||||||
|
the total counts.
|
||||||
|
|
||||||
|
#### Mathematical details
|
||||||
|
|
||||||
|
The Binomial is a distribution over the number of successes in `n` independent
|
||||||
|
trials, with each trial having the same probability of success `p`.
|
||||||
|
The probability mass function (pmf):
|
||||||
|
|
||||||
|
```pmf(k) = n! / (k! * (n - k)!) * (p)^k * (1 - p)^(n - k)```
|
||||||
|
|
||||||
|
#### Examples
|
||||||
|
|
||||||
|
Create a single distribution, corresponding to 5 coin flips.
|
||||||
|
|
||||||
|
```python
|
||||||
|
dist = Binomial(n=5., p=.5)
|
||||||
|
```
|
||||||
|
|
||||||
|
Create a single distribution (using logits), corresponding to 5 coin flips.
|
||||||
|
|
||||||
|
```python
|
||||||
|
dist = Binomial(n=5., logits=0.)
|
||||||
|
```
|
||||||
|
|
||||||
|
Creates 3 distributions with the third distribution most likely to have
|
||||||
|
successes.
|
||||||
|
|
||||||
|
```python
|
||||||
|
p = [.2, .3, .8]
|
||||||
|
# n will be broadcast to [4., 4., 4.], to match p.
|
||||||
|
dist = Binomial(n=4., p=p)
|
||||||
|
```
|
||||||
|
|
||||||
|
The distribution functions can be evaluated on counts.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# counts same shape as p.
|
||||||
|
counts = [1., 2, 3]
|
||||||
|
dist.prob(counts) # Shape [3]
|
||||||
|
|
||||||
|
# p will be broadcast to [[.2, .3, .8], [.2, .3, .8]] to match counts.
|
||||||
|
counts = [[1., 2, 1], [2, 2, 4]]
|
||||||
|
dist.prob(counts) # Shape [2, 3]
|
||||||
|
|
||||||
|
# p will be broadcast to shape [5, 7, 3] to match counts.
|
||||||
|
counts = [[...]] # Shape [5, 7, 3]
|
||||||
|
dist.prob(counts) # Shape [5, 7, 3]
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
n,
|
||||||
|
logits=None,
|
||||||
|
p=None,
|
||||||
|
validate_args=True,
|
||||||
|
allow_nan_stats=False,
|
||||||
|
name="Binomial"):
|
||||||
|
"""Initialize a batch of Binomial distributions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n: Non-negative floating point tensor with shape broadcastable to
|
||||||
|
`[N1,..., Nm]` with `m >= 0` and the same dtype as `p` or `logits`.
|
||||||
|
Defines this as a batch of `N1 x ... x Nm` different Binomial
|
||||||
|
distributions. Its components should be equal to integer values.
|
||||||
|
logits: Floating point tensor representing the log-odds of a
|
||||||
|
positive event with shape broadcastable to `[N1,..., Nm]` `m >= 0`, and
|
||||||
|
the same dtype as `n`. Each entry represents logits for the probability
|
||||||
|
of success for independent Binomial distributions.
|
||||||
|
p: Positive floating point tensor with shape broadcastable to
|
||||||
|
`[N1,..., Nm]` `m >= 0`, `p in [0, 1]`. Each entry represents the
|
||||||
|
probability of success for independent Binomial distributions.
|
||||||
|
validate_args: Whether to assert valid values for parameters `n` and `p`,
|
||||||
|
and `x` in `prob` and `log_prob`. If `False`, correct behavior is not
|
||||||
|
guaranteed.
|
||||||
|
allow_nan_stats: Boolean, default `False`. If `False`, raise an
|
||||||
|
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
|
||||||
|
batch member. If `True`, batch members with valid parameters leading to
|
||||||
|
undefined statistics will return NaN for this statistic.
|
||||||
|
name: The name to prefix Ops created by this distribution class.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Define 1-batch of a binomial distribution.
|
||||||
|
dist = Binomial(n=2., p=.9)
|
||||||
|
|
||||||
|
# Define a 2-batch.
|
||||||
|
dist = Binomial(n=[4., 5], p=[.1, .3])
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
self._logits, self._p = distribution_util.get_logits_and_prob(
|
||||||
|
name=name, logits=logits, p=p, validate_args=validate_args)
|
||||||
|
|
||||||
|
with ops.op_scope([n], name):
|
||||||
|
with ops.control_dependencies([
|
||||||
|
check_ops.assert_non_negative(
|
||||||
|
n, message="n has negative components."),
|
||||||
|
distribution_util.assert_integer_form(
|
||||||
|
n, message="n has non-integer components."
|
||||||
|
)] if validate_args else []):
|
||||||
|
self._n = array_ops.identity(n, name="convert_n")
|
||||||
|
|
||||||
|
self._name = name
|
||||||
|
self._validate_args = validate_args
|
||||||
|
self._allow_nan_stats = allow_nan_stats
|
||||||
|
|
||||||
|
self._mean = self._n * self._p
|
||||||
|
self._get_batch_shape = self._mean.get_shape()
|
||||||
|
self._get_event_shape = tensor_shape.TensorShape([])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
"""Name to prepend to all ops."""
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
"""dtype of samples from this distribution."""
|
||||||
|
return self._p.dtype
|
||||||
|
|
||||||
|
@property
|
||||||
|
def validate_args(self):
|
||||||
|
"""Boolean describing behavior on invalid input."""
|
||||||
|
return self._validate_args
|
||||||
|
|
||||||
|
@property
|
||||||
|
def allow_nan_stats(self):
|
||||||
|
"""Boolean describing behavior when a stat is undefined for batch member."""
|
||||||
|
return self._allow_nan_stats
|
||||||
|
|
||||||
|
def batch_shape(self, name="batch_shape"):
|
||||||
|
"""Batch dimensions of this instance as a 1-D int32 `Tensor`.
|
||||||
|
|
||||||
|
The product of the dimensions of the `batch_shape` is the number of
|
||||||
|
independent distributions of this kind the instance represents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: name to give to the op
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tensor` `batch_shape`
|
||||||
|
"""
|
||||||
|
return array_ops.shape(self._mean)
|
||||||
|
|
||||||
|
def get_batch_shape(self):
|
||||||
|
"""`TensorShape` available at graph construction time.
|
||||||
|
|
||||||
|
Same meaning as `batch_shape`. May be only partially defined.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
batch shape
|
||||||
|
"""
|
||||||
|
return self._get_batch_shape
|
||||||
|
|
||||||
|
def event_shape(self, name="event_shape"):
|
||||||
|
"""Shape of a sample from a single distribution as a 1-D int32 `Tensor`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: name to give to the op
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tensor` `event_shape`
|
||||||
|
"""
|
||||||
|
with ops.name_scope(self.name):
|
||||||
|
with ops.op_scope([], name):
|
||||||
|
return constant_op.constant([], name=name, dtype=dtypes.int32)
|
||||||
|
|
||||||
|
def get_event_shape(self):
|
||||||
|
"""`TensorShape` available at graph construction time.
|
||||||
|
|
||||||
|
Same meaning as `event_shape`. May be only partially defined.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
event shape
|
||||||
|
"""
|
||||||
|
return self._get_event_shape
|
||||||
|
|
||||||
|
@property
|
||||||
|
def n(self):
|
||||||
|
"""Number of trials."""
|
||||||
|
return self._n
|
||||||
|
|
||||||
|
@property
|
||||||
|
def logits(self):
|
||||||
|
"""Log-odds."""
|
||||||
|
return self._logits
|
||||||
|
|
||||||
|
@property
|
||||||
|
def p(self):
|
||||||
|
"""Probability of success."""
|
||||||
|
return self._p
|
||||||
|
|
||||||
|
def mean(self, name="mean"):
|
||||||
|
"""Mean of the distribution."""
|
||||||
|
with ops.name_scope(self.name):
|
||||||
|
return array_ops.identity(self._mean, name=name)
|
||||||
|
|
||||||
|
def variance(self, name="variance"):
|
||||||
|
"""Variance of the distribution."""
|
||||||
|
with ops.name_scope(self.name):
|
||||||
|
with ops.op_scope([self._n, self._p], name):
|
||||||
|
return self._n * self._p * (1 - self._p)
|
||||||
|
|
||||||
|
def std(self, name="std"):
|
||||||
|
"""Standard deviation of the distribution."""
|
||||||
|
with ops.name_scope(self.name):
|
||||||
|
with ops.op_scope([self._n, self._p], name):
|
||||||
|
return math_ops.sqrt(self.variance())
|
||||||
|
|
||||||
|
def mode(self, name="mode"):
|
||||||
|
"""Mode of the distribution.
|
||||||
|
|
||||||
|
Note that when `(n + 1) * p` is an integer, there are actually two modes.
|
||||||
|
Namely, `(n + 1) * p` and `(n + 1) * p - 1` are both modes. Here we return
|
||||||
|
only the larger of the two modes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The name for this op.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The mode of the Binomial distribution.
|
||||||
|
"""
|
||||||
|
with ops.name_scope(self.name):
|
||||||
|
with ops.op_scope([self._n, self._p], name):
|
||||||
|
return math_ops.floor((self._n + 1) * self._p)
|
||||||
|
|
||||||
|
def log_prob(self, counts, name="log_prob"):
|
||||||
|
"""`Log(P[counts])`, computed for every batch member.
|
||||||
|
|
||||||
|
For each batch member of counts `k`, `P[counts]` is the probability that
|
||||||
|
after sampling `n` draws from this Binomial distribution, the number of
|
||||||
|
successes is `k`. Note that different sequences of draws can result in the
|
||||||
|
same counts, thus the probability includes a combinatorial coefficient.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
counts: Non-negative tensor with dtype `dtype` and whose shape can be
|
||||||
|
broadcast with `self.p` and `self.n`. `counts` is only legal if it is
|
||||||
|
less than or equal to `n` and its components are equal to integer
|
||||||
|
values.
|
||||||
|
name: Name to give this Op, defaults to "log_prob".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Log probabilities for each record, shape `[N1,...,Nm]`.
|
||||||
|
"""
|
||||||
|
n = self._n
|
||||||
|
p = self._p
|
||||||
|
with ops.name_scope(self.name):
|
||||||
|
with ops.op_scope([self._n, self._p, counts], name):
|
||||||
|
counts = self._check_counts(counts)
|
||||||
|
|
||||||
|
prob_prob = counts * math_ops.log(p) + (
|
||||||
|
n - counts) * math_ops.log(1 - p)
|
||||||
|
|
||||||
|
combinations = math_ops.lgamma(n + 1) - math_ops.lgamma(
|
||||||
|
counts + 1) - math_ops.lgamma(n - counts + 1)
|
||||||
|
log_prob = prob_prob + combinations
|
||||||
|
return log_prob
|
||||||
|
|
||||||
|
def prob(self, counts, name="prob"):
|
||||||
|
"""`P[counts]`, computed for every batch member.
|
||||||
|
|
||||||
|
|
||||||
|
For each batch member of counts `k`, `P[counts]` is the probability that
|
||||||
|
after sampling `n` draws from this Binomial distribution, the number of
|
||||||
|
successes is `k`. Note that different sequences of draws can result in the
|
||||||
|
same counts, thus the probability includes a combinatorial coefficient.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
counts: Non-negative tensor with dtype `dtype` and whose shape can be
|
||||||
|
broadcast with `self.p` and `self.n`. `counts` is only legal if it is
|
||||||
|
less than or equal to `n` and its components are equal to integer
|
||||||
|
values.
|
||||||
|
name: Name to give this Op, defaults to "prob".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Probabilities for each record, shape `[N1,...,Nm]`.
|
||||||
|
"""
|
||||||
|
return super(Binomial, self).prob(counts, name=name)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_continuous(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_reparameterized(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _check_counts(self, counts):
|
||||||
|
"""Check counts for proper shape, values, then return tensor version."""
|
||||||
|
counts = ops.convert_to_tensor(counts, name="counts_before_deps")
|
||||||
|
if not self.validate_args:
|
||||||
|
return counts
|
||||||
|
return control_flow_ops.with_dependencies([
|
||||||
|
check_ops.assert_non_negative(
|
||||||
|
counts, message="counts has negative components."),
|
||||||
|
check_ops.assert_less_equal(
|
||||||
|
counts, self._n, message="counts are not less than or equal to n."),
|
||||||
|
distribution_util.assert_integer_form(
|
||||||
|
counts, message="counts have non-integer components.")], counts)
|
@ -34,11 +34,6 @@ class Categorical(distribution.Distribution):
|
|||||||
|
|
||||||
The categorical distribution is parameterized by the log-probabilities
|
The categorical distribution is parameterized by the log-probabilities
|
||||||
of a set of classes.
|
of a set of classes.
|
||||||
|
|
||||||
Note, the following methods of the base class aren't implemented:
|
|
||||||
* mean
|
|
||||||
* cdf
|
|
||||||
* log_cdf
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -57,10 +52,10 @@ class Categorical(distribution.Distribution):
|
|||||||
indexes into the classes.
|
indexes into the classes.
|
||||||
dtype: The type of the event samples (default: int32).
|
dtype: The type of the event samples (default: int32).
|
||||||
validate_args: Unused in this distribution.
|
validate_args: Unused in this distribution.
|
||||||
allow_nan_stats: Boolean, default False. If False, raise an exception if
|
allow_nan_stats: Boolean, default `False`. If `False`, raise an
|
||||||
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
|
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
|
||||||
If True, batch members with valid parameters leading to undefined
|
batch member. If `True`, batch members with valid parameters leading to
|
||||||
statistics will return NaN for this statistic.
|
undefined statistics will return NaN for this statistic.
|
||||||
name: A name for this distribution (optional).
|
name: A name for this distribution (optional).
|
||||||
"""
|
"""
|
||||||
self._allow_nan_stats = allow_nan_stats
|
self._allow_nan_stats = allow_nan_stats
|
||||||
@ -177,8 +172,7 @@ class Categorical(distribution.Distribution):
|
|||||||
samples = math_ops.cast(samples, self._dtype)
|
samples = math_ops.cast(samples, self._dtype)
|
||||||
ret = array_ops.reshape(
|
ret = array_ops.reshape(
|
||||||
array_ops.transpose(samples),
|
array_ops.transpose(samples),
|
||||||
array_ops.concat(
|
array_ops.concat(0, ([n], self.batch_shape())))
|
||||||
0, [array_ops.expand_dims(n, 0), self.batch_shape()]))
|
|
||||||
ret.set_shape(tensor_shape.vector(tensor_util.constant_value(n))
|
ret.set_shape(tensor_shape.vector(tensor_util.constant_value(n))
|
||||||
.concatenate(self.get_batch_shape()))
|
.concatenate(self.get_batch_shape()))
|
||||||
return ret
|
return ret
|
||||||
|
@ -42,15 +42,15 @@ class Chi2(gamma.Gamma):
|
|||||||
"""Construct Chi2 distributions with parameter `df`.
|
"""Construct Chi2 distributions with parameter `df`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
df: `float` or `double` tensor, the degrees of freedom of the
|
df: Floating point tensor, the degrees of freedom of the
|
||||||
distribution(s). `df` must contain only positive values.
|
distribution(s). `df` must contain only positive values.
|
||||||
validate_args: Whether to assert that `df > 0`, and that `x > 0` in the
|
validate_args: Whether to assert that `df > 0`, and that `x > 0` in the
|
||||||
methods `prob(x)` and `log_prob(x)`. If `validate_args` is False
|
methods `prob(x)` and `log_prob(x)`. If `validate_args` is `False`
|
||||||
and the inputs are invalid, correct behavior is not guaranteed.
|
and the inputs are invalid, correct behavior is not guaranteed.
|
||||||
allow_nan_stats: Boolean, default False. If False, raise an exception if
|
allow_nan_stats: Boolean, default `False`. If `False`, raise an
|
||||||
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
|
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
|
||||||
If True, batch members with valid parameters leading to undefined
|
batch member. If `True`, batch members with valid parameters leading to
|
||||||
statistics will return NaN for this statistic.
|
undefined statistics will return NaN for this statistic.
|
||||||
name: The name to prepend to all ops created by this distribution.
|
name: The name to prepend to all ops created by this distribution.
|
||||||
"""
|
"""
|
||||||
# Even though all stats of chi2 are defined for valid parameters, this is
|
# Even though all stats of chi2 are defined for valid parameters, this is
|
||||||
|
@ -19,9 +19,8 @@ from __future__ import print_function
|
|||||||
|
|
||||||
# pylint: disable=line-too-long
|
# pylint: disable=line-too-long
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from tensorflow.contrib.distributions.python.ops import distribution
|
from tensorflow.contrib.distributions.python.ops import distribution
|
||||||
|
from tensorflow.contrib.distributions.python.ops import distribution_util
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
@ -29,7 +28,6 @@ from tensorflow.python.framework import tensor_util
|
|||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import check_ops
|
from tensorflow.python.ops import check_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import logging_ops
|
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
from tensorflow.python.ops import special_math_ops
|
from tensorflow.python.ops import special_math_ops
|
||||||
@ -37,24 +35,6 @@ from tensorflow.python.ops import special_math_ops
|
|||||||
# pylint: enable=line-too-long
|
# pylint: enable=line-too-long
|
||||||
|
|
||||||
|
|
||||||
def _assert_close(x, y, data=None, summarize=None, name=None):
|
|
||||||
if x.dtype.is_integer:
|
|
||||||
return check_ops.assert_equal(
|
|
||||||
x, y, data=data, summarize=summarize, name=name)
|
|
||||||
|
|
||||||
with ops.op_scope([x, y, data], name, "assert_close"):
|
|
||||||
x = ops.convert_to_tensor(x, name="x")
|
|
||||||
y = ops.convert_to_tensor(y, name="y")
|
|
||||||
tol = np.finfo(x.dtype.as_numpy_dtype).resolution
|
|
||||||
if data is None:
|
|
||||||
data = [
|
|
||||||
"Condition x ~= y did not hold element-wise: x = ", x.name, x, "y = ",
|
|
||||||
y.name, y
|
|
||||||
]
|
|
||||||
condition = math_ops.reduce_all(math_ops.less_equal(math_ops.abs(x-y), tol))
|
|
||||||
return logging_ops.Assert(condition, data, summarize=summarize)
|
|
||||||
|
|
||||||
|
|
||||||
class Dirichlet(distribution.Distribution):
|
class Dirichlet(distribution.Distribution):
|
||||||
"""Dirichlet distribution.
|
"""Dirichlet distribution.
|
||||||
|
|
||||||
@ -117,6 +97,7 @@ class Dirichlet(distribution.Distribution):
|
|||||||
x = [.2, .3, .5]
|
x = [.2, .3, .5]
|
||||||
dist.prob(x) # Shape [2]
|
dist.prob(x) # Shape [2]
|
||||||
```
|
```
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -127,16 +108,16 @@ class Dirichlet(distribution.Distribution):
|
|||||||
"""Initialize a batch of Dirichlet distributions.
|
"""Initialize a batch of Dirichlet distributions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
alpha: Positive `float` or `double` tensor with shape broadcastable to
|
alpha: Positive floating point tensor with shape broadcastable to
|
||||||
`[N1,..., Nm, k]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
|
`[N1,..., Nm, k]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
|
||||||
different `k` class Dirichlet distributions.
|
different `k` class Dirichlet distributions.
|
||||||
validate_args: Whether to assert valid values for parameters `alpha` and
|
validate_args: Whether to assert valid values for parameters `alpha` and
|
||||||
`x` in `prob` and `log_prob`. If False, correct behavior is not
|
`x` in `prob` and `log_prob`. If `False`, correct behavior is not
|
||||||
guaranteed.
|
guaranteed.
|
||||||
allow_nan_stats: Boolean, default False. If False, raise an exception if
|
allow_nan_stats: Boolean, default `False`. If `False`, raise an
|
||||||
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
|
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
|
||||||
If True, batch members with valid parameters leading to undefined
|
batch member. If `True`, batch members with valid parameters leading to
|
||||||
statistics will return NaN for this statistic.
|
undefined statistics will return NaN for this statistic.
|
||||||
name: The name to prefix Ops created by this distribution class.
|
name: The name to prefix Ops created by this distribution class.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
@ -149,6 +130,7 @@ class Dirichlet(distribution.Distribution):
|
|||||||
# Define a 2-batch of 3-class distributions.
|
# Define a 2-batch of 3-class distributions.
|
||||||
dist = Dirichlet([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
dist = Dirichlet([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
||||||
```
|
```
|
||||||
|
|
||||||
"""
|
"""
|
||||||
with ops.op_scope([alpha], name):
|
with ops.op_scope([alpha], name):
|
||||||
alpha = ops.convert_to_tensor(alpha, name="alpha_before_deps")
|
alpha = ops.convert_to_tensor(alpha, name="alpha_before_deps")
|
||||||
@ -302,7 +284,9 @@ class Dirichlet(distribution.Distribution):
|
|||||||
array_ops.ones_like(self._alpha, dtype=self.dtype)))
|
array_ops.ones_like(self._alpha, dtype=self.dtype)))
|
||||||
else:
|
else:
|
||||||
return control_flow_ops.with_dependencies([
|
return control_flow_ops.with_dependencies([
|
||||||
check_ops.assert_less(one, self._alpha)
|
check_ops.assert_less(
|
||||||
|
one, self._alpha,
|
||||||
|
message="mode not defined for components of alpha <= 1")
|
||||||
], mode)
|
], mode)
|
||||||
|
|
||||||
def entropy(self, name="entropy"):
|
def entropy(self, name="entropy"):
|
||||||
@ -334,7 +318,7 @@ class Dirichlet(distribution.Distribution):
|
|||||||
"""`Log(P[counts])`, computed for every batch member.
|
"""`Log(P[counts])`, computed for every batch member.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: Non-negative `float` or `double`, tensor whose shape can
|
x: Non-negative tensor with dtype `dtype` and whose shape can
|
||||||
be broadcast with `self.alpha`. For fixed leading dimensions, the last
|
be broadcast with `self.alpha`. For fixed leading dimensions, the last
|
||||||
dimension represents counts for the corresponding Dirichlet distribution
|
dimension represents counts for the corresponding Dirichlet distribution
|
||||||
in `self.alpha`. `x` is only legal if it sums up to one.
|
in `self.alpha`. `x` is only legal if it sums up to one.
|
||||||
@ -359,7 +343,7 @@ class Dirichlet(distribution.Distribution):
|
|||||||
"""`P[x]`, computed for every batch member.
|
"""`P[x]`, computed for every batch member.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: Non-negative `float`, `double` tensor whose shape can
|
x: Non-negative tensor with dtype `dtype` and whose shape can
|
||||||
be broadcast with `self.alpha`. For fixed leading dimensions, the last
|
be broadcast with `self.alpha`. For fixed leading dimensions, the last
|
||||||
dimension represents x for the corresponding Dirichlet distribution in
|
dimension represents x for the corresponding Dirichlet distribution in
|
||||||
`self.alpha` and `self.beta`. `x` is only legal if it sums up to one.
|
`self.alpha` and `self.beta`. `x` is only legal if it sums up to one.
|
||||||
@ -407,7 +391,8 @@ class Dirichlet(distribution.Distribution):
|
|||||||
x = ops.convert_to_tensor(x, name="x_before_deps")
|
x = ops.convert_to_tensor(x, name="x_before_deps")
|
||||||
candidate_one = math_ops.reduce_sum(x, reduction_indices=[-1])
|
candidate_one = math_ops.reduce_sum(x, reduction_indices=[-1])
|
||||||
one = constant_op.constant(1., self.dtype)
|
one = constant_op.constant(1., self.dtype)
|
||||||
dependencies = [check_ops.assert_positive(x), check_ops.assert_less(x, one),
|
dependencies = [check_ops.assert_positive(x), check_ops.assert_less(
|
||||||
_assert_close(one, candidate_one)
|
x, one, message="x has components greater than or equal to 1"),
|
||||||
|
distribution_util.assert_close(one, candidate_one)
|
||||||
] if self.validate_args else []
|
] if self.validate_args else []
|
||||||
return control_flow_ops.with_dependencies(dependencies, x)
|
return control_flow_ops.with_dependencies(dependencies, x)
|
||||||
|
@ -13,13 +13,15 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""The Dirichlet Multinomial distribution class."""
|
"""The Dirichlet Multinomial distribution class."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
# pylint: disable=line-too-long
|
# pylint: disable=line-too-long
|
||||||
|
|
||||||
from tensorflow.contrib.distributions.python.ops import distribution # pylint: disable=line-too-long
|
from tensorflow.contrib.distributions.python.ops import distribution
|
||||||
|
from tensorflow.contrib.distributions.python.ops import distribution_util
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import check_ops
|
from tensorflow.python.ops import check_ops
|
||||||
@ -30,34 +32,6 @@ from tensorflow.python.ops import special_math_ops
|
|||||||
# pylint: enable=line-too-long
|
# pylint: enable=line-too-long
|
||||||
|
|
||||||
|
|
||||||
def _assert_integer_form(x):
|
|
||||||
"""Check x for integer components (or floats that are equal to integers)."""
|
|
||||||
x = ops.convert_to_tensor(x, name='x')
|
|
||||||
casted_x = math_ops.to_int64(x)
|
|
||||||
return check_ops.assert_equal(x, math_ops.cast(
|
|
||||||
math_ops.round(casted_x), x.dtype))
|
|
||||||
|
|
||||||
|
|
||||||
def _log_combinations(n, counts, name='log_combinations'):
|
|
||||||
"""Log number of ways counts could have come in."""
|
|
||||||
# First a bit about the number of ways counts could have come in:
|
|
||||||
# E.g. if counts = [1, 2], then this is 3 choose 2.
|
|
||||||
# In general, this is (sum counts)! / sum(counts!)
|
|
||||||
# The sum should be along the last dimension of counts. This is the
|
|
||||||
# "distribution" dimension. Here n a priori represents the sum of counts.
|
|
||||||
with ops.op_scope([counts], name):
|
|
||||||
# To compute factorials, use the fact that Gamma(n + 1) = n!
|
|
||||||
# Compute two terms, each a sum over counts. Compute each for each
|
|
||||||
# batch member.
|
|
||||||
# Log Gamma((sum counts) + 1) = Log((sum counts)!)
|
|
||||||
total_permutations = math_ops.lgamma(n + 1)
|
|
||||||
# sum(Log Gamma(counts + 1)) = Log sum(counts!)
|
|
||||||
counts_factorial = math_ops.lgamma(counts + 1)
|
|
||||||
redundant_permutations = math_ops.reduce_sum(counts_factorial,
|
|
||||||
reduction_indices=[-1])
|
|
||||||
return total_permutations - redundant_permutations
|
|
||||||
|
|
||||||
|
|
||||||
class DirichletMultinomial(distribution.Distribution):
|
class DirichletMultinomial(distribution.Distribution):
|
||||||
"""DirichletMultinomial mixture distribution.
|
"""DirichletMultinomial mixture distribution.
|
||||||
|
|
||||||
@ -126,38 +100,35 @@ class DirichletMultinomial(distribution.Distribution):
|
|||||||
counts = [2, 1, 0]
|
counts = [2, 1, 0]
|
||||||
dist.pmf(counts) # Shape [2]
|
dist.pmf(counts) # Shape [2]
|
||||||
```
|
```
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO(b/27419586) Change docstring for dtype of alpha once int allowed.
|
# TODO(b/27419586) Change docstring for dtype of alpha once int allowed.
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
n,
|
n,
|
||||||
alpha,
|
alpha,
|
||||||
allow_arbitrary_counts=False,
|
|
||||||
validate_args=True,
|
validate_args=True,
|
||||||
allow_nan_stats=False,
|
allow_nan_stats=False,
|
||||||
name='DirichletMultinomial'):
|
name="DirichletMultinomial"):
|
||||||
"""Initialize a batch of DirichletMultinomial distributions.
|
"""Initialize a batch of DirichletMultinomial distributions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
n: Non-negative `float` or `double` tensor with shape
|
n: Non-negative floating point tensor, whose dtype is the same as
|
||||||
broadcastable to `[N1,..., Nm]` with `m >= 0`. Defines this as a batch
|
`alpha`. The shape is broadcastable to `[N1,..., Nm]` with `m >= 0`.
|
||||||
of `N1 x ... x Nm` different Dirichlet multinomial distributions. Its
|
Defines this as a batch of `N1 x ... x Nm` different Dirichlet
|
||||||
components should be equal to integral values.
|
multinomial distributions. Its components should be equal to integer
|
||||||
alpha: Positive `float` or `double` tensor with shape broadcastable to
|
values.
|
||||||
`[N1,..., Nm, k]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
|
alpha: Positive floating point tensor, whose dtype is the same as
|
||||||
different `k` class Dirichlet multinomial distributions.
|
`n` with shape broadcastable to `[N1,..., Nm, k]` `m >= 0`. Defines
|
||||||
allow_arbitrary_counts: Boolean. This represents whether the pmf/cdf
|
this as a batch of `N1 x ... x Nm` different `k` class Dirichlet
|
||||||
allows for the `counts` tensor to be non-integral values.
|
multinomial distributions.
|
||||||
The pmf/cdf are functions that can be evaluated at non-integral values,
|
|
||||||
but are only a distribution over non-negative integers. If
|
|
||||||
`validate_args` is `False`, this assertion is turned off.
|
|
||||||
validate_args: Whether to assert valid values for parameters `alpha` and
|
validate_args: Whether to assert valid values for parameters `alpha` and
|
||||||
`n`, and `x` in `prob` and `log_prob`. If False, correct behavior is
|
`n`, and `x` in `prob` and `log_prob`. If `False`, correct behavior is
|
||||||
not guaranteed.
|
not guaranteed.
|
||||||
allow_nan_stats: Boolean, default False. If False, raise an exception if
|
allow_nan_stats: Boolean, default `False`. If `False`, raise an
|
||||||
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
|
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
|
||||||
If True, batch members with valid parameters leading to undefined
|
batch member. If `True`, batch members with valid parameters leading to
|
||||||
statistics will return NaN for this statistic.
|
undefined statistics will return NaN for this statistic.
|
||||||
name: The name to prefix Ops created by this distribution class.
|
name: The name to prefix Ops created by this distribution class.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
@ -170,11 +141,11 @@ class DirichletMultinomial(distribution.Distribution):
|
|||||||
# Define a 2-batch of 3-class distributions.
|
# Define a 2-batch of 3-class distributions.
|
||||||
dist = DirichletMultinomial([3., 4], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
dist = DirichletMultinomial([3., 4], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
||||||
```
|
```
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self._allow_nan_stats = allow_nan_stats
|
self._allow_nan_stats = allow_nan_stats
|
||||||
self._validate_args = validate_args
|
self._validate_args = validate_args
|
||||||
self._name = name
|
self._name = name
|
||||||
self._allow_arbitrary_counts = allow_arbitrary_counts
|
|
||||||
with ops.op_scope([n, alpha], name):
|
with ops.op_scope([n, alpha], name):
|
||||||
# Broadcasting works because:
|
# Broadcasting works because:
|
||||||
# * The broadcasting convention is to prepend dimensions of size [1], and
|
# * The broadcasting convention is to prepend dimensions of size [1], and
|
||||||
@ -186,8 +157,7 @@ class DirichletMultinomial(distribution.Distribution):
|
|||||||
# * All calls involving `counts` eventually require a broadcast between
|
# * All calls involving `counts` eventually require a broadcast between
|
||||||
# `counts` and alpha.
|
# `counts` and alpha.
|
||||||
self._alpha = self._check_alpha(alpha)
|
self._alpha = self._check_alpha(alpha)
|
||||||
n = self._check_n(n)
|
self._n = self._check_n(n)
|
||||||
self._n = math_ops.cast(n, self._alpha.dtype)
|
|
||||||
|
|
||||||
self._alpha_sum = math_ops.reduce_sum(
|
self._alpha_sum = math_ops.reduce_sum(
|
||||||
self._alpha, reduction_indices=[-1], keep_dims=False)
|
self._alpha, reduction_indices=[-1], keep_dims=False)
|
||||||
@ -227,7 +197,7 @@ class DirichletMultinomial(distribution.Distribution):
|
|||||||
"""dtype of samples from this distribution."""
|
"""dtype of samples from this distribution."""
|
||||||
return self._alpha.dtype
|
return self._alpha.dtype
|
||||||
|
|
||||||
def mean(self, name='mean'):
|
def mean(self, name="mean"):
|
||||||
"""Class means for every batch member."""
|
"""Class means for every batch member."""
|
||||||
alpha = self._alpha
|
alpha = self._alpha
|
||||||
alpha_sum = self._alpha_sum
|
alpha_sum = self._alpha_sum
|
||||||
@ -237,7 +207,7 @@ class DirichletMultinomial(distribution.Distribution):
|
|||||||
mean_no_n = alpha / array_ops.expand_dims(alpha_sum, -1)
|
mean_no_n = alpha / array_ops.expand_dims(alpha_sum, -1)
|
||||||
return array_ops.expand_dims(n, -1) * mean_no_n
|
return array_ops.expand_dims(n, -1) * mean_no_n
|
||||||
|
|
||||||
def variance(self, name='mean'):
|
def variance(self, name="mean"):
|
||||||
"""Class variances for every batch member.
|
"""Class variances for every batch member.
|
||||||
|
|
||||||
The variance for each batch member is defined as the following:
|
The variance for each batch member is defined as the following:
|
||||||
@ -279,7 +249,7 @@ class DirichletMultinomial(distribution.Distribution):
|
|||||||
variance *= array_ops.expand_dims(shared_factor, -1)
|
variance *= array_ops.expand_dims(shared_factor, -1)
|
||||||
return variance
|
return variance
|
||||||
|
|
||||||
def batch_shape(self, name='batch_shape'):
|
def batch_shape(self, name="batch_shape"):
|
||||||
"""Batch dimensions of this instance as a 1-D int32 `Tensor`.
|
"""Batch dimensions of this instance as a 1-D int32 `Tensor`.
|
||||||
|
|
||||||
The product of the dimensions of the `batch_shape` is the number of
|
The product of the dimensions of the `batch_shape` is the number of
|
||||||
@ -305,7 +275,7 @@ class DirichletMultinomial(distribution.Distribution):
|
|||||||
"""
|
"""
|
||||||
return self._get_batch_shape
|
return self._get_batch_shape
|
||||||
|
|
||||||
def event_shape(self, name='event_shape'):
|
def event_shape(self, name="event_shape"):
|
||||||
"""Shape of a sample from a single distribution as a 1-D int32 `Tensor`.
|
"""Shape of a sample from a single distribution as a 1-D int32 `Tensor`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -328,15 +298,15 @@ class DirichletMultinomial(distribution.Distribution):
|
|||||||
"""
|
"""
|
||||||
return self._get_event_shape
|
return self._get_event_shape
|
||||||
|
|
||||||
def cdf(self, x, name='cdf'):
|
def cdf(self, x, name="cdf"):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
'DirichletMultinomial does not have a well-defined cdf.')
|
"DirichletMultinomial does not have a well-defined cdf.")
|
||||||
|
|
||||||
def log_cdf(self, x, name='log_cdf'):
|
def log_cdf(self, x, name="log_cdf"):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
'DirichletMultinomial does not have a well-defined cdf.')
|
"DirichletMultinomial does not have a well-defined cdf.")
|
||||||
|
|
||||||
def log_prob(self, counts, name='log_prob'):
|
def log_prob(self, counts, name="log_prob"):
|
||||||
"""`Log(P[counts])`, computed for every batch member.
|
"""`Log(P[counts])`, computed for every batch member.
|
||||||
|
|
||||||
For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability
|
For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability
|
||||||
@ -346,12 +316,11 @@ class DirichletMultinomial(distribution.Distribution):
|
|||||||
probability includes a combinatorial coefficient.
|
probability includes a combinatorial coefficient.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
counts: Non-negative `float` or `double` tensor whose shape can
|
counts: Non-negative tensor with dtype `dtype` and whose shape can be
|
||||||
be broadcast with `self.alpha`. For fixed leading dimensions, the last
|
broadcast with `self.alpha`. For fixed leading dimensions, the last
|
||||||
dimension represents counts for the corresponding Dirichlet Multinomial
|
dimension represents counts for the corresponding Dirichlet Multinomial
|
||||||
distribution in `self.alpha`. `counts` is only legal if it sums up to
|
distribution in `self.alpha`. `counts` is only legal if it sums up to
|
||||||
`n` and its components are equal to integral values. The second
|
`n` and its components are equal to integer values.
|
||||||
condition is relaxed if `allow_arbitrary_counts` is set.
|
|
||||||
name: Name to give this Op, defaults to "log_prob".
|
name: Name to give this Op, defaults to "log_prob".
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -362,25 +331,14 @@ class DirichletMultinomial(distribution.Distribution):
|
|||||||
with ops.name_scope(self.name):
|
with ops.name_scope(self.name):
|
||||||
with ops.op_scope([n, alpha, counts], name):
|
with ops.op_scope([n, alpha, counts], name):
|
||||||
counts = self._check_counts(counts)
|
counts = self._check_counts(counts)
|
||||||
# Use the same dtype as alpha for computations.
|
|
||||||
counts = math_ops.cast(counts, self.dtype)
|
|
||||||
|
|
||||||
ordered_prob = (special_math_ops.lbeta(alpha + counts) -
|
ordered_prob = (special_math_ops.lbeta(alpha + counts) -
|
||||||
special_math_ops.lbeta(alpha))
|
special_math_ops.lbeta(alpha))
|
||||||
log_prob = ordered_prob + _log_combinations(n, counts)
|
log_prob = ordered_prob + distribution_util.log_combinations(
|
||||||
# If alpha = counts = [[]], ordered_prob carries the right shape, which
|
n, counts)
|
||||||
# is []. However, since reduce_sum([[]]) = [0], log_combinations = [0],
|
|
||||||
# which is not correct. Luckily, [] + [0] = [], so the sum is fine, but
|
|
||||||
# shape must be inferred from ordered_prob. We must also make this
|
|
||||||
# broadcastable with n, so this is multiplied by n to ensure the shape
|
|
||||||
# is correctly inferred.
|
|
||||||
# Note also that tf.constant([]).get_shape() =
|
|
||||||
# TensorShape([Dimension(0)])
|
|
||||||
broadcasted_tensor = ordered_prob * n
|
|
||||||
log_prob.set_shape(broadcasted_tensor.get_shape())
|
|
||||||
return log_prob
|
return log_prob
|
||||||
|
|
||||||
def prob(self, counts, name='prob'):
|
def prob(self, counts, name="prob"):
|
||||||
"""`P[counts]`, computed for every batch member.
|
"""`P[counts]`, computed for every batch member.
|
||||||
|
|
||||||
For each batch of counts `[c_1,...,c_k]`, `P[counts]` is the probability
|
For each batch of counts `[c_1,...,c_k]`, `P[counts]` is the probability
|
||||||
@ -390,12 +348,11 @@ class DirichletMultinomial(distribution.Distribution):
|
|||||||
probability includes a combinatorial coefficient.
|
probability includes a combinatorial coefficient.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
counts: Non-negative `float`, `double` tensor whose shape can
|
counts: Non-negative tensor with dtype `dtype` and whose shape can be
|
||||||
be broadcast with `self.alpha`. For fixed leading dimensions, the last
|
broadcast with `self.alpha`. For fixed leading dimensions, the last
|
||||||
dimension represents counts for the corresponding Dirichlet Multinomial
|
dimension represents counts for the corresponding Dirichlet Multinomial
|
||||||
distribution in `self.alpha`. `counts` is only legal if it sums up to
|
distribution in `self.alpha`. `counts` is only legal if it sums up to
|
||||||
`n` and its components are equal to integral values. The second
|
`n` and its components are equal to integer values.
|
||||||
condition is relaxed if `allow_arbitrary_counts` is set.
|
|
||||||
name: Name to give this Op, defaults to "prob".
|
name: Name to give this Op, defaults to "prob".
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -405,21 +362,21 @@ class DirichletMultinomial(distribution.Distribution):
|
|||||||
|
|
||||||
def _check_counts(self, counts):
|
def _check_counts(self, counts):
|
||||||
"""Check counts for proper shape, values, then return tensor version."""
|
"""Check counts for proper shape, values, then return tensor version."""
|
||||||
counts = ops.convert_to_tensor(counts, name='counts')
|
counts = ops.convert_to_tensor(counts, name="counts")
|
||||||
if not self.validate_args:
|
if not self.validate_args:
|
||||||
return counts
|
return counts
|
||||||
candidate_n = math_ops.reduce_sum(counts, reduction_indices=[-1])
|
candidate_n = math_ops.reduce_sum(counts, reduction_indices=[-1])
|
||||||
dependencies = [check_ops.assert_non_negative(counts),
|
|
||||||
check_ops.assert_equal(self._n,
|
|
||||||
math_ops.cast(candidate_n,
|
|
||||||
self._n.dtype))]
|
|
||||||
if not self._allow_arbitrary_counts:
|
|
||||||
dependencies += [_assert_integer_form(counts)]
|
|
||||||
|
|
||||||
return control_flow_ops.with_dependencies(dependencies, counts)
|
return control_flow_ops.with_dependencies([
|
||||||
|
check_ops.assert_non_negative(counts),
|
||||||
|
check_ops.assert_equal(
|
||||||
|
self._n, candidate_n,
|
||||||
|
message="counts do not sum to n"
|
||||||
|
),
|
||||||
|
distribution_util.assert_integer_form(counts)], counts)
|
||||||
|
|
||||||
def _check_alpha(self, alpha):
|
def _check_alpha(self, alpha):
|
||||||
alpha = ops.convert_to_tensor(alpha, name='alpha')
|
alpha = ops.convert_to_tensor(alpha, name="alpha")
|
||||||
if not self.validate_args:
|
if not self.validate_args:
|
||||||
return alpha
|
return alpha
|
||||||
return control_flow_ops.with_dependencies(
|
return control_flow_ops.with_dependencies(
|
||||||
@ -427,11 +384,12 @@ class DirichletMultinomial(distribution.Distribution):
|
|||||||
check_ops.assert_positive(alpha)], alpha)
|
check_ops.assert_positive(alpha)], alpha)
|
||||||
|
|
||||||
def _check_n(self, n):
|
def _check_n(self, n):
|
||||||
n = ops.convert_to_tensor(n, name='n')
|
n = ops.convert_to_tensor(n, name="n")
|
||||||
if not self.validate_args:
|
if not self.validate_args:
|
||||||
return n
|
return n
|
||||||
return control_flow_ops.with_dependencies(
|
return control_flow_ops.with_dependencies(
|
||||||
[check_ops.assert_non_negative(n), _assert_integer_form(n)], n)
|
[check_ops.assert_non_negative(n),
|
||||||
|
distribution_util.assert_integer_form(n)], n)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_continuous(self):
|
def is_continuous(self):
|
||||||
|
177
tensorflow/contrib/distributions/python/ops/distribution_util.py
Normal file
177
tensorflow/contrib/distributions/python/ops/distribution_util.py
Normal file
@ -0,0 +1,177 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Utilities for probability distributions."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import check_ops
|
||||||
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import logging_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
|
||||||
|
|
||||||
|
def assert_close(
|
||||||
|
x, y, data=None, summarize=None, message=None, name="assert_close"):
|
||||||
|
"""Assert that that x and y are within machine epsilon of each other.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Numeric `Tensor`
|
||||||
|
y: Numeric `Tensor`
|
||||||
|
data: The tensors to print out if the condition is `False`. Defaults to
|
||||||
|
error message and first few entries of `x` and `y`.
|
||||||
|
summarize: Print this many entries of each tensor.
|
||||||
|
message: A string to prefix to the default message.
|
||||||
|
name: A name for this operation (optional).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Op raising `InvalidArgumentError` if |x - y| > machine epsilon.
|
||||||
|
"""
|
||||||
|
message = message or ""
|
||||||
|
x = ops.convert_to_tensor(x, name="x")
|
||||||
|
y = ops.convert_to_tensor(y, name="y")
|
||||||
|
|
||||||
|
if x.dtype.is_integer:
|
||||||
|
return check_ops.assert_equal(
|
||||||
|
x, y, data=data, summarize=summarize, message=message, name=name)
|
||||||
|
|
||||||
|
with ops.op_scope([x, y, data], name, "assert_close"):
|
||||||
|
tol = np.finfo(x.dtype.as_numpy_dtype).resolution
|
||||||
|
if data is None:
|
||||||
|
data = [
|
||||||
|
message,
|
||||||
|
"Condition x ~= y did not hold element-wise: x = ", x.name, x, "y = ",
|
||||||
|
y.name, y
|
||||||
|
]
|
||||||
|
condition = math_ops.reduce_all(math_ops.less_equal(math_ops.abs(x-y), tol))
|
||||||
|
return logging_ops.Assert(
|
||||||
|
condition, data, summarize=summarize)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_integer_form(
|
||||||
|
x, data=None, summarize=None, message=None, name="assert_integer_form"):
|
||||||
|
"""Assert that x has integer components (or floats equal to integers).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Numeric `Tensor`
|
||||||
|
data: The tensors to print out if the condition is `False`. Defaults to
|
||||||
|
error message and first few entries of `x` and `y`.
|
||||||
|
summarize: Print this many entries of each tensor.
|
||||||
|
message: A string to prefix to the default message.
|
||||||
|
name: A name for this operation (optional).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Op raising `InvalidArgumentError` if round(x) != x.
|
||||||
|
"""
|
||||||
|
|
||||||
|
message = message or "x has non-integer components"
|
||||||
|
x = ops.convert_to_tensor(x, name="x")
|
||||||
|
casted_x = math_ops.to_int64(x)
|
||||||
|
return check_ops.assert_equal(
|
||||||
|
x, math_ops.cast(math_ops.round(casted_x), x.dtype),
|
||||||
|
data=data, summarize=summarize, message=message, name=name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_logits_and_prob(
|
||||||
|
logits=None, p=None, multidimensional=False, validate_args=True, name=None):
|
||||||
|
"""Converts logits to probabilities and vice-versa, and returns both.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits: Numeric `Tensor` representing log-odds.
|
||||||
|
p: Numeric `Tensor` representing probabilities.
|
||||||
|
multidimensional: Given `p` a [N1, N2, ... k] dimensional tensor,
|
||||||
|
whether the last dimension represents the probability between k classes.
|
||||||
|
This will additionally assert that the values in the last dimension
|
||||||
|
sum to one. If `False`, will instead assert that each value is in
|
||||||
|
`[0, 1]`.
|
||||||
|
validate_args: Whether to assert `0 <= p <= 1` if multidimensional is
|
||||||
|
`False`, otherwise that the last dimension of `p` sums to one.
|
||||||
|
name: A name for this operation (optional).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple with `logits` and `p`. If `p` has an entry that is `0` or `1`, then
|
||||||
|
the corresponding entry in the returned logits will be `-Inf` and `Inf`
|
||||||
|
respectively.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if neither `p` nor `logits` were passed in, or both were.
|
||||||
|
"""
|
||||||
|
if p is None and logits is None:
|
||||||
|
raise ValueError("Must pass p or logits.")
|
||||||
|
elif p is not None and logits is not None:
|
||||||
|
raise ValueError("Must pass either p or logits, not both.")
|
||||||
|
elif p is None:
|
||||||
|
with ops.op_scope([logits], name):
|
||||||
|
logits = array_ops.identity(logits, name="logits")
|
||||||
|
with ops.name_scope(name):
|
||||||
|
with ops.name_scope("p"):
|
||||||
|
p = math_ops.sigmoid(logits)
|
||||||
|
elif logits is None:
|
||||||
|
with ops.name_scope(name):
|
||||||
|
with ops.name_scope("p"):
|
||||||
|
p = array_ops.identity(p)
|
||||||
|
if validate_args:
|
||||||
|
one = constant_op.constant(1., p.dtype)
|
||||||
|
dependencies = [check_ops.assert_non_negative(p)]
|
||||||
|
if multidimensional:
|
||||||
|
dependencies += [assert_close(
|
||||||
|
math_ops.reduce_sum(p, reduction_indices=[-1]),
|
||||||
|
one, message="p does not sum to 1.")]
|
||||||
|
else:
|
||||||
|
dependencies += [check_ops.assert_less_equal(
|
||||||
|
p, one, message="p has components greater than 1.")]
|
||||||
|
p = control_flow_ops.with_dependencies(dependencies, p)
|
||||||
|
with ops.name_scope("logits"):
|
||||||
|
logits = math_ops.log(p) - math_ops.log(1. - p)
|
||||||
|
return (logits, p)
|
||||||
|
|
||||||
|
|
||||||
|
def log_combinations(n, counts, name="log_combinations"):
|
||||||
|
"""Multinomial coefficient.
|
||||||
|
|
||||||
|
Given `n` and `counts`, where `counts` has last dimension `k`, we compute
|
||||||
|
the multinomial coefficient as:
|
||||||
|
|
||||||
|
```n! / sum_i n_i!```
|
||||||
|
|
||||||
|
where `i` runs over all `k` classes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n: Numeric `Tensor` broadcastable with `counts`. This represents `n`
|
||||||
|
outcomes.
|
||||||
|
counts: Numeric `Tensor` broadcastable with `n`. This represents counts
|
||||||
|
in `k` classes, where `k` is the last dimension of the tensor.
|
||||||
|
name: A name for this operation (optional).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tensor` representing the multinomial coefficient between `n` and `counts`.
|
||||||
|
"""
|
||||||
|
# First a bit about the number of ways counts could have come in:
|
||||||
|
# E.g. if counts = [1, 2], then this is 3 choose 2.
|
||||||
|
# In general, this is (sum counts)! / sum(counts!)
|
||||||
|
# The sum should be along the last dimension of counts. This is the
|
||||||
|
# "distribution" dimension. Here n a priori represents the sum of counts.
|
||||||
|
with ops.op_scope([n, counts], name):
|
||||||
|
total_permutations = math_ops.lgamma(n + 1)
|
||||||
|
counts_factorial = math_ops.lgamma(counts + 1)
|
||||||
|
redundant_permutations = math_ops.reduce_sum(counts_factorial,
|
||||||
|
reduction_indices=[-1])
|
||||||
|
return total_permutations - redundant_permutations
|
@ -46,15 +46,15 @@ class Exponential(gamma.Gamma):
|
|||||||
"""Construct Exponential distribution with parameter `lam`.
|
"""Construct Exponential distribution with parameter `lam`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
lam: `float` or `double` tensor, the rate of the distribution(s).
|
lam: Floating point tensor, the rate of the distribution(s).
|
||||||
`lam` must contain only positive values.
|
`lam` must contain only positive values.
|
||||||
validate_args: Whether to assert that `lam > 0`, and that `x > 0` in the
|
validate_args: Whether to assert that `lam > 0`, and that `x > 0` in the
|
||||||
methods `prob(x)` and `log_prob(x)`. If `validate_args` is False
|
methods `prob(x)` and `log_prob(x)`. If `validate_args` is `False`
|
||||||
and the inputs are invalid, correct behavior is not guaranteed.
|
and the inputs are invalid, correct behavior is not guaranteed.
|
||||||
allow_nan_stats: Boolean, default False. If False, raise an exception if
|
allow_nan_stats: Boolean, default `False`. If `False`, raise an
|
||||||
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
|
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
|
||||||
If True, batch members with valid parameters leading to undefined
|
batch member. If `True`, batch members with valid parameters leading to
|
||||||
statistics will return NaN for this statistic.
|
undefined statistics will return NaN for this statistic.
|
||||||
name: The name to prepend to all ops created by this distribution.
|
name: The name to prepend to all ops created by this distribution.
|
||||||
"""
|
"""
|
||||||
# Even though all statistics of are defined for valid inputs, this is not
|
# Even though all statistics of are defined for valid inputs, this is not
|
||||||
@ -95,13 +95,13 @@ class Exponential(gamma.Gamma):
|
|||||||
broadcast_shape = self._lam.get_shape()
|
broadcast_shape = self._lam.get_shape()
|
||||||
with ops.op_scope([self.lam, n], name, "ExponentialSample"):
|
with ops.op_scope([self.lam, n], name, "ExponentialSample"):
|
||||||
n = ops.convert_to_tensor(n, name="n")
|
n = ops.convert_to_tensor(n, name="n")
|
||||||
shape = array_ops.concat(
|
shape = array_ops.concat(0, ([n], array_ops.shape(self._lam)))
|
||||||
0, [array_ops.pack([n]), array_ops.shape(self._lam)])
|
|
||||||
# Sample uniformly-at-random from the open-interval (0, 1).
|
# Sample uniformly-at-random from the open-interval (0, 1).
|
||||||
sampled = random_ops.random_uniform(
|
sampled = random_ops.random_uniform(
|
||||||
shape, minval=np.nextafter(
|
shape, minval=np.nextafter(
|
||||||
self.dtype.as_numpy_dtype(0.), self.dtype.as_numpy_dtype(1.)),
|
self.dtype.as_numpy_dtype(0.), self.dtype.as_numpy_dtype(1.)),
|
||||||
maxval=constant_op.constant(1.0, dtype=self.dtype),
|
maxval=constant_op.constant(1.0, dtype=self.dtype),
|
||||||
|
seed=seed,
|
||||||
dtype=self.dtype)
|
dtype=self.dtype)
|
||||||
|
|
||||||
n_val = tensor_util.constant_value(n)
|
n_val = tensor_util.constant_value(n)
|
||||||
|
@ -69,19 +69,19 @@ class Gamma(distribution.Distribution):
|
|||||||
broadcasting (e.g. `alpha + beta` is a valid operation).
|
broadcasting (e.g. `alpha + beta` is a valid operation).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
alpha: `float` or `double` tensor, the shape params of the
|
alpha: Floating point tensor, the shape params of the
|
||||||
distribution(s).
|
distribution(s).
|
||||||
alpha must contain only positive values.
|
alpha must contain only positive values.
|
||||||
beta: `float` or `double` tensor, the inverse scale params of the
|
beta: Floating point tensor, the inverse scale params of the
|
||||||
distribution(s).
|
distribution(s).
|
||||||
beta must contain only positive values.
|
beta must contain only positive values.
|
||||||
validate_args: Whether to assert that `a > 0, b > 0`, and that `x > 0` in
|
validate_args: Whether to assert that `a > 0, b > 0`, and that `x > 0` in
|
||||||
the methods `prob(x)` and `log_prob(x)`. If `validate_args` is False
|
the methods `prob(x)` and `log_prob(x)`. If `validate_args` is `False`
|
||||||
and the inputs are invalid, correct behavior is not guaranteed.
|
and the inputs are invalid, correct behavior is not guaranteed.
|
||||||
allow_nan_stats: Boolean, default False. If False, raise an exception if
|
allow_nan_stats: Boolean, default `False`. If `False`, raise an
|
||||||
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
|
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
|
||||||
If True, batch members with valid parameters leading to undefined
|
batch member. If `True`, batch members with valid parameters leading to
|
||||||
statistics will return NaN for this statistic.
|
undefined statistics will return NaN for this statistic.
|
||||||
name: The name to prepend to all ops created by this distribution.
|
name: The name to prepend to all ops created by this distribution.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@ -213,9 +213,12 @@ class Gamma(distribution.Distribution):
|
|||||||
nan = np.nan * self._ones()
|
nan = np.nan * self._ones()
|
||||||
return math_ops.select(alpha_ge_1, mode_if_defined, nan)
|
return math_ops.select(alpha_ge_1, mode_if_defined, nan)
|
||||||
else:
|
else:
|
||||||
one = ops.convert_to_tensor(1.0, dtype=self.dtype)
|
one = constant_op.constant(1.0, dtype=self.dtype)
|
||||||
return control_flow_ops.with_dependencies(
|
return control_flow_ops.with_dependencies(
|
||||||
[check_ops.assert_less(one, alpha)], mode_if_defined)
|
[check_ops.assert_less(
|
||||||
|
one, alpha,
|
||||||
|
message="mode not defined for components of alpha <= 1"
|
||||||
|
)], mode_if_defined)
|
||||||
|
|
||||||
def variance(self, name="variance"):
|
def variance(self, name="variance"):
|
||||||
"""Variance of each batch member."""
|
"""Variance of each batch member."""
|
||||||
|
@ -69,18 +69,18 @@ class InverseGamma(distribution.Distribution):
|
|||||||
broadcasting (e.g. `alpha + beta` is a valid operation).
|
broadcasting (e.g. `alpha + beta` is a valid operation).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
alpha: `float` or `double` tensor, the shape params of the
|
alpha: Floating point tensor, the shape params of the
|
||||||
distribution(s).
|
distribution(s).
|
||||||
alpha must contain only positive values.
|
alpha must contain only positive values.
|
||||||
beta: `float` or `double` tensor, the scale params of the distribution(s).
|
beta: Floating point tensor, the scale params of the distribution(s).
|
||||||
beta must contain only positive values.
|
beta must contain only positive values.
|
||||||
validate_args: Whether to assert that `a > 0, b > 0`, and that `x > 0` in
|
validate_args: Whether to assert that `a > 0, b > 0`, and that `x > 0` in
|
||||||
the methods `prob(x)` and `log_prob(x)`. If `validate_args` is False
|
the methods `prob(x)` and `log_prob(x)`. If `validate_args` is `False`
|
||||||
and the inputs are invalid, correct behavior is not guaranteed.
|
and the inputs are invalid, correct behavior is not guaranteed.
|
||||||
allow_nan_stats: Boolean, default False. If False, raise an exception if
|
allow_nan_stats: Boolean, default `False`. If `False`, raise an
|
||||||
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
|
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
|
||||||
If True, batch members with valid parameters leading to undefined
|
batch member. If `True`, batch members with valid parameters leading to
|
||||||
statistics will return NaN for this statistic.
|
undefined statistics will return NaN for this statistic.
|
||||||
name: The name to prepend to all ops created by this distribution.
|
name: The name to prepend to all ops created by this distribution.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@ -206,9 +206,12 @@ class InverseGamma(distribution.Distribution):
|
|||||||
nan = np.nan * self._ones()
|
nan = np.nan * self._ones()
|
||||||
return math_ops.select(alpha_gt_1, mean_if_defined, nan)
|
return math_ops.select(alpha_gt_1, mean_if_defined, nan)
|
||||||
else:
|
else:
|
||||||
one = ops.convert_to_tensor(1.0, dtype=self.dtype)
|
one = constant_op.constant(1.0, dtype=self.dtype)
|
||||||
return control_flow_ops.with_dependencies(
|
return control_flow_ops.with_dependencies(
|
||||||
[check_ops.assert_less(one, alpha)], mean_if_defined)
|
[check_ops.assert_less(
|
||||||
|
one, alpha,
|
||||||
|
message="mean not defined for components of alpha <= 1")],
|
||||||
|
mean_if_defined)
|
||||||
|
|
||||||
def mode(self, name="mode"):
|
def mode(self, name="mode"):
|
||||||
"""Mode of each batch member.
|
"""Mode of each batch member.
|
||||||
@ -250,9 +253,12 @@ class InverseGamma(distribution.Distribution):
|
|||||||
nan = np.nan * self._ones()
|
nan = np.nan * self._ones()
|
||||||
return math_ops.select(alpha_gt_2, var_if_defined, nan)
|
return math_ops.select(alpha_gt_2, var_if_defined, nan)
|
||||||
else:
|
else:
|
||||||
two = ops.convert_to_tensor(2.0, dtype=self.dtype)
|
two = constant_op.constant(2.0, dtype=self.dtype)
|
||||||
return control_flow_ops.with_dependencies(
|
return control_flow_ops.with_dependencies(
|
||||||
[check_ops.assert_less(two, alpha)], var_if_defined)
|
[check_ops.assert_less(
|
||||||
|
two, alpha,
|
||||||
|
message="variance not defined for components of alpha <= 2")],
|
||||||
|
var_if_defined)
|
||||||
|
|
||||||
def log_prob(self, x, name="log_prob"):
|
def log_prob(self, x, name="log_prob"):
|
||||||
"""Log prob of observations in `x` under these InverseGamma distribution(s).
|
"""Log prob of observations in `x` under these InverseGamma distribution(s).
|
||||||
|
@ -34,9 +34,9 @@ def kl(dist_a, dist_b, allow_nan=False, name=None):
|
|||||||
Args:
|
Args:
|
||||||
dist_a: instance of distributions.Distribution.
|
dist_a: instance of distributions.Distribution.
|
||||||
dist_b: instance of distributions.Distribution.
|
dist_b: instance of distributions.Distribution.
|
||||||
allow_nan: If False (default), a runtime error is raised
|
allow_nan: If `False` (default), a runtime error is raised
|
||||||
if the KL returns NaN values for any batch entry of the given
|
if the KL returns NaN values for any batch entry of the given
|
||||||
distributions. If True, the KL may return a NaN for the given entry.
|
distributions. If `True`, the KL may return a NaN for the given entry.
|
||||||
name: (optional) Name scope to use for created operations.
|
name: (optional) Name scope to use for created operations.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -60,17 +60,17 @@ class Laplace(distribution.Distribution):
|
|||||||
broadcasting (e.g., `loc / scale` is a valid operation).
|
broadcasting (e.g., `loc / scale` is a valid operation).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
loc: `float` or `double` tensor which characterizes the location (center)
|
loc: Floating point tensor which characterizes the location (center)
|
||||||
of the distribution.
|
of the distribution.
|
||||||
scale: `float` or `double`, positive-valued tensor which characterzes the
|
scale: Positive floating point tensor which characterizes the spread of
|
||||||
spread of the distribution.
|
the distribution.
|
||||||
validate_args: Whether to validate input with asserts. If `validate_args`
|
validate_args: Whether to validate input with asserts. If `validate_args`
|
||||||
is `False`, and the inputs are invalid, correct behavior is not
|
is `False`, and the inputs are invalid, correct behavior is not
|
||||||
guaranteed.
|
guaranteed.
|
||||||
allow_nan_stats: Boolean, default False. If False, raise an exception if
|
allow_nan_stats: Boolean, default `False`. If `False`, raise an
|
||||||
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
|
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
|
||||||
If True, batch members with valid parameters leading to undefined
|
batch member. If `True`, batch members with valid parameters leading to
|
||||||
statistics will return NaN for this statistic.
|
undefined statistics will return NaN for this statistic.
|
||||||
name: The name to give Ops created by the initializer.
|
name: The name to give Ops created by the initializer.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@ -294,8 +294,7 @@ class Laplace(distribution.Distribution):
|
|||||||
with ops.op_scope([self._loc, self._scale, n], name):
|
with ops.op_scope([self._loc, self._scale, n], name):
|
||||||
n = ops.convert_to_tensor(n)
|
n = ops.convert_to_tensor(n)
|
||||||
n_val = tensor_util.constant_value(n)
|
n_val = tensor_util.constant_value(n)
|
||||||
shape = array_ops.concat(
|
shape = array_ops.concat(0, ([n], self.batch_shape()))
|
||||||
0, [array_ops.pack([n]), self.batch_shape()])
|
|
||||||
# Sample uniformly-at-random from the open-interval (-1, 1).
|
# Sample uniformly-at-random from the open-interval (-1, 1).
|
||||||
uniform_samples = random_ops.random_uniform(
|
uniform_samples = random_ops.random_uniform(
|
||||||
shape=shape,
|
shape=shape,
|
||||||
|
343
tensorflow/contrib/distributions/python/ops/multinomial.py
Normal file
343
tensorflow/contrib/distributions/python/ops/multinomial.py
Normal file
@ -0,0 +1,343 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""The Multinomial distribution class."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
# pylint: disable=line-too-long
|
||||||
|
|
||||||
|
from tensorflow.contrib.distributions.python.ops import distribution
|
||||||
|
from tensorflow.contrib.distributions.python.ops import distribution_util
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import check_ops
|
||||||
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
|
||||||
|
# pylint: enable=line-too-long
|
||||||
|
|
||||||
|
|
||||||
|
class Multinomial(distribution.Distribution):
|
||||||
|
"""Multinomial distribution.
|
||||||
|
|
||||||
|
This distribution is parameterized by a vector `p` of probability
|
||||||
|
parameters for `k` classes and `n`, the counts per each class..
|
||||||
|
|
||||||
|
#### Mathematical details
|
||||||
|
|
||||||
|
The Multinomial is a distribution over k-class count data, meaning
|
||||||
|
for each k-tuple of non-negative integer `counts = [n_1,...,n_k]`, we have a
|
||||||
|
probability of these draws being made from the distribution. The distribution
|
||||||
|
has hyperparameters `p = (p_1,...,p_k)`, and probability mass
|
||||||
|
function (pmf):
|
||||||
|
|
||||||
|
```pmf(counts) = n! / (n_1!...n_k!) * (p_1)^n_1*(p_2)^n_2*...(p_k)^n_k```
|
||||||
|
|
||||||
|
where above `n = sum_j n_j`, `n!` is `n` factorial.
|
||||||
|
|
||||||
|
#### Examples
|
||||||
|
|
||||||
|
Create a 3-class distribution, with the 3rd class is most likely to be drawn,
|
||||||
|
using logits..
|
||||||
|
|
||||||
|
```python
|
||||||
|
logits = [-50., -43, 0]
|
||||||
|
dist = Multinomial(n=4., logits=logits)
|
||||||
|
```
|
||||||
|
|
||||||
|
Create a 3-class distribution, with the 3rd class is most likely to be drawn.
|
||||||
|
|
||||||
|
```python
|
||||||
|
p = [.2, .3, .5]
|
||||||
|
dist = Multinomial(n=4., p=p)
|
||||||
|
```
|
||||||
|
|
||||||
|
The distribution functions can be evaluated on counts.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# counts same shape as p.
|
||||||
|
counts = [1., 0, 3]
|
||||||
|
dist.prob(counts) # Shape []
|
||||||
|
|
||||||
|
# p will be broadcast to [[.2, .3, .5], [.2, .3, .5]] to match counts.
|
||||||
|
counts = [[1., 2, 1], [2, 2, 0]]
|
||||||
|
dist.prob(counts) # Shape [2]
|
||||||
|
|
||||||
|
# p will be broadcast to shape [5, 7, 3] to match counts.
|
||||||
|
counts = [[...]] # Shape [5, 7, 3]
|
||||||
|
dist.prob(counts) # Shape [5, 7]
|
||||||
|
```
|
||||||
|
|
||||||
|
Create a 2-batch of 3-class distributions.
|
||||||
|
|
||||||
|
```python
|
||||||
|
p = [[.1, .2, .7], [.3, .3, .4]] # Shape [2, 3]
|
||||||
|
dist = Multinomial(n=[4., 5], p=p)
|
||||||
|
|
||||||
|
counts = [[2., 1, 1], [3, 1, 1]]
|
||||||
|
dist.prob(counts) # Shape [2]
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
n,
|
||||||
|
logits=None,
|
||||||
|
p=None,
|
||||||
|
validate_args=True,
|
||||||
|
allow_nan_stats=False,
|
||||||
|
name="Multinomial"):
|
||||||
|
"""Initialize a batch of Multinomial distributions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n: Non-negative floating point tensor with shape broadcastable to
|
||||||
|
`[N1,..., Nm]` with `m >= 0`. Defines this as a batch of
|
||||||
|
`N1 x ... x Nm` different Multinomial distributions. Its components
|
||||||
|
should be equal to integer values.
|
||||||
|
logits: Floating point tensor representing the log-odds of a
|
||||||
|
positive event with shape broadcastable to `[N1,..., Nm, k], m >= 0`,
|
||||||
|
and the same dtype as `n`. Defines this as a batch of `N1 x ... x Nm`
|
||||||
|
different `k` class Multinomial distributions.
|
||||||
|
p: Positive floating point tensor with shape broadcastable to
|
||||||
|
`[N1,..., Nm, k]` `m >= 0` and same dtype as `n`. Defines this as
|
||||||
|
a batch of `N1 x ... x Nm` different `k` class Multinomial
|
||||||
|
distributions. `p`'s components in the last portion of its shape should
|
||||||
|
sum up to 1.
|
||||||
|
validate_args: Whether to assert valid values for parameters `n` and `p`,
|
||||||
|
and `x` in `prob` and `log_prob`. If `False`, correct behavior is not
|
||||||
|
guaranteed.
|
||||||
|
allow_nan_stats: Boolean, default `False`. If `False`, raise an
|
||||||
|
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
|
||||||
|
batch member. If `True`, batch members with valid parameters leading to
|
||||||
|
undefined statistics will return NaN for this statistic.
|
||||||
|
name: The name to prefix Ops created by this distribution class.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Define 1-batch of 2-class multinomial distribution,
|
||||||
|
# also known as a Binomial distribution.
|
||||||
|
dist = Multinomial(n=2., p=[.1, .9])
|
||||||
|
|
||||||
|
# Define a 2-batch of 3-class distributions.
|
||||||
|
dist = Multinomial(n=[4., 5], p=[[.1, .3, .6], [.4, .05, .55]])
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
self._logits, self._p = distribution_util.get_logits_and_prob(
|
||||||
|
name=name, logits=logits, p=p, validate_args=validate_args,
|
||||||
|
multidimensional=True)
|
||||||
|
with ops.op_scope([n, self._p], name):
|
||||||
|
with ops.control_dependencies([
|
||||||
|
check_ops.assert_non_negative(
|
||||||
|
n, message="n has negative components."),
|
||||||
|
distribution_util.assert_integer_form(
|
||||||
|
n, message="n has non-integer components."
|
||||||
|
)] if validate_args else []):
|
||||||
|
self._n = array_ops.identity(n, name="convert_n")
|
||||||
|
self._name = name
|
||||||
|
|
||||||
|
self._validate_args = validate_args
|
||||||
|
self._allow_nan_stats = allow_nan_stats
|
||||||
|
|
||||||
|
self._mean = array_ops.expand_dims(n, -1) * self._p
|
||||||
|
# Only used for inferring shape.
|
||||||
|
self._broadcast_shape = math_ops.reduce_sum(self._mean,
|
||||||
|
reduction_indices=[-1],
|
||||||
|
keep_dims=False)
|
||||||
|
|
||||||
|
self._get_batch_shape = self._broadcast_shape.get_shape()
|
||||||
|
self._get_event_shape = (
|
||||||
|
self._mean.get_shape().with_rank_at_least(1)[-1:])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def n(self):
|
||||||
|
"""Number of trials."""
|
||||||
|
return self._n
|
||||||
|
|
||||||
|
@property
|
||||||
|
def p(self):
|
||||||
|
"""Event probabilities."""
|
||||||
|
return self._p
|
||||||
|
|
||||||
|
@property
|
||||||
|
def logits(self):
|
||||||
|
"""Log-odds."""
|
||||||
|
return self._logits
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
"""Name to prepend to all ops."""
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
"""dtype of samples from this distribution."""
|
||||||
|
return self._p.dtype
|
||||||
|
|
||||||
|
@property
|
||||||
|
def validate_args(self):
|
||||||
|
"""Boolean describing behavior on invalid input."""
|
||||||
|
return self._validate_args
|
||||||
|
|
||||||
|
@property
|
||||||
|
def allow_nan_stats(self):
|
||||||
|
"""Boolean describing behavior when a stat is undefined for batch member."""
|
||||||
|
return self._allow_nan_stats
|
||||||
|
|
||||||
|
def batch_shape(self, name="batch_shape"):
|
||||||
|
"""Batch dimensions of this instance as a 1-D int32 `Tensor`.
|
||||||
|
|
||||||
|
The product of the dimensions of the `batch_shape` is the number of
|
||||||
|
independent distributions of this kind the instance represents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: name to give to the op
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tensor` `batch_shape`
|
||||||
|
"""
|
||||||
|
with ops.name_scope(self.name):
|
||||||
|
with ops.op_scope([self._broadcast_shape], name):
|
||||||
|
return array_ops.shape(self._broadcast_shape)
|
||||||
|
|
||||||
|
def get_batch_shape(self):
|
||||||
|
"""`TensorShape` available at graph construction time.
|
||||||
|
|
||||||
|
Same meaning as `batch_shape`. May be only partially defined.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
batch shape
|
||||||
|
"""
|
||||||
|
return self._get_batch_shape
|
||||||
|
|
||||||
|
def event_shape(self, name="event_shape"):
|
||||||
|
"""Shape of a sample from a single distribution as a 1-D int32 `Tensor`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: name to give to the op
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tensor` `event_shape`
|
||||||
|
"""
|
||||||
|
with ops.name_scope(self.name):
|
||||||
|
with ops.op_scope([self._mean], name):
|
||||||
|
return array_ops.gather(array_ops.shape(self._mean),
|
||||||
|
[array_ops.rank(self._mean) - 1])
|
||||||
|
|
||||||
|
def get_event_shape(self):
|
||||||
|
"""`TensorShape` available at graph construction time.
|
||||||
|
|
||||||
|
Same meaning as `event_shape`. May be only partially defined.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
event shape
|
||||||
|
"""
|
||||||
|
return self._get_event_shape
|
||||||
|
|
||||||
|
def mean(self, name="mean"):
|
||||||
|
"""Mean of the distribution."""
|
||||||
|
with ops.name_scope(self.name):
|
||||||
|
return array_ops.identity(self._mean, name=name)
|
||||||
|
|
||||||
|
def variance(self, name="variance"):
|
||||||
|
"""Variance of the distribution."""
|
||||||
|
with ops.name_scope(self.name):
|
||||||
|
with ops.op_scope([self._n, self._p, self._mean], name):
|
||||||
|
p = array_ops.expand_dims(
|
||||||
|
self._p * array_ops.expand_dims(
|
||||||
|
array_ops.ones_like(self._n), -1), -1)
|
||||||
|
variance = -math_ops.batch_matmul(
|
||||||
|
array_ops.expand_dims(self._mean, -1), p, adj_y=True)
|
||||||
|
variance += array_ops.batch_matrix_diag(self._mean)
|
||||||
|
return variance
|
||||||
|
|
||||||
|
def log_prob(self, counts, name="log_prob"):
|
||||||
|
"""`Log(P[counts])`, computed for every batch member.
|
||||||
|
|
||||||
|
For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability
|
||||||
|
that after sampling `n` draws from this Multinomial distribution, the
|
||||||
|
number of draws falling in class `j` is `n_j`. Note that different
|
||||||
|
sequences of draws can result in the same counts, thus the probability
|
||||||
|
includes a combinatorial coefficient.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
counts: Non-negative tensor with dtype `dtype` and whose shape can
|
||||||
|
be broadcast with `self.p` and `self.n`. For fixed leading dimensions,
|
||||||
|
the last dimension represents counts for the corresponding Multinomial
|
||||||
|
distribution in `self.p`. `counts` is only legal if it sums up to `n`
|
||||||
|
and its components are equal to integer values.
|
||||||
|
name: Name to give this Op, defaults to "log_prob".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Log probabilities for each record, shape `[N1,...,Nm]`.
|
||||||
|
"""
|
||||||
|
n = self._n
|
||||||
|
p = self._p
|
||||||
|
with ops.name_scope(self.name):
|
||||||
|
with ops.op_scope([n, p, counts], name):
|
||||||
|
counts = self._check_counts(counts)
|
||||||
|
|
||||||
|
prob_prob = math_ops.reduce_sum(counts * math_ops.log(self._p),
|
||||||
|
reduction_indices=[-1])
|
||||||
|
log_prob = prob_prob + distribution_util.log_combinations(
|
||||||
|
n, counts)
|
||||||
|
return log_prob
|
||||||
|
|
||||||
|
def prob(self, counts, name="prob"):
|
||||||
|
"""`P[counts]`, computed for every batch member.
|
||||||
|
|
||||||
|
For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability
|
||||||
|
that after sampling `n` draws from this Multinomial distribution, the
|
||||||
|
number of draws falling in class `j` is `n_j`. Note that different
|
||||||
|
sequences of draws can result in the same counts, thus the probability
|
||||||
|
includes a combinatorial coefficient.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
counts: Non-negative tensor with dtype `dtype` and whose shape can
|
||||||
|
be broadcast with `self.p` and `self.n`. For fixed leading dimensions,
|
||||||
|
the last dimension represents counts for the corresponding Multinomial
|
||||||
|
distribution in `self.p`. `counts` is only legal if it sums up to `n`
|
||||||
|
and its components are equal to integer values.
|
||||||
|
name: Name to give this Op, defaults to "prob".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Probabilities for each record, shape `[N1,...,Nm]`.
|
||||||
|
"""
|
||||||
|
return super(Multinomial, self).prob(counts, name=name)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_continuous(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_reparameterized(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _check_counts(self, counts):
|
||||||
|
"""Check counts for proper shape, values, then return tensor version."""
|
||||||
|
counts = ops.convert_to_tensor(counts, name="counts_before_deps")
|
||||||
|
candidate_n = math_ops.reduce_sum(counts, reduction_indices=[-1])
|
||||||
|
if not self.validate_args:
|
||||||
|
return counts
|
||||||
|
|
||||||
|
return control_flow_ops.with_dependencies([
|
||||||
|
check_ops.assert_non_negative(
|
||||||
|
counts, message="counts has negative components."),
|
||||||
|
check_ops.assert_equal(
|
||||||
|
self._n, candidate_n, message="counts do not sum to n."),
|
||||||
|
distribution_util.assert_integer_form(
|
||||||
|
counts, message="counts have non-integer components.")], counts)
|
@ -21,9 +21,11 @@ from __future__ import print_function
|
|||||||
import math
|
import math
|
||||||
|
|
||||||
from tensorflow.contrib.distributions.python.ops import distribution
|
from tensorflow.contrib.distributions.python.ops import distribution
|
||||||
|
from tensorflow.contrib.distributions.python.ops import kullback_leibler
|
||||||
from tensorflow.contrib.distributions.python.ops import operator_pd_cholesky
|
from tensorflow.contrib.distributions.python.ops import operator_pd_cholesky
|
||||||
from tensorflow.contrib.distributions.python.ops import operator_pd_diag
|
from tensorflow.contrib.distributions.python.ops import operator_pd_diag
|
||||||
from tensorflow.contrib.distributions.python.ops import operator_pd_full
|
from tensorflow.contrib.distributions.python.ops import operator_pd_full
|
||||||
|
from tensorflow.contrib.distributions.python.ops import operator_pd_vdvt_update
|
||||||
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
|
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -40,6 +42,7 @@ __all__ = [
|
|||||||
"MultivariateNormalDiag",
|
"MultivariateNormalDiag",
|
||||||
"MultivariateNormalCholesky",
|
"MultivariateNormalCholesky",
|
||||||
"MultivariateNormalFull",
|
"MultivariateNormalFull",
|
||||||
|
"MultivariateNormalDiagPlusVDVT",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -52,14 +55,13 @@ class MultivariateNormalOperatorPD(distribution.Distribution):
|
|||||||
|
|
||||||
#### Mathematical details
|
#### Mathematical details
|
||||||
|
|
||||||
The PDF of this distribution is:
|
With `C` the covariance matrix represented by the operator, the PDF of this
|
||||||
|
distribution is:
|
||||||
|
|
||||||
```
|
```
|
||||||
f(x) = (2*pi)^(-k/2) |det(sigma)|^(-1/2) exp(-1/2*(x-mu)^*.sigma^{-1}.(x-mu))
|
f(x) = (2 pi)^(-k/2) |det(C)|^(-1/2) exp(-1/2 (x - mu)^T C^{-1} (x - mu))
|
||||||
```
|
```
|
||||||
|
|
||||||
where `.` denotes the inner product on `R^k` and `^*` denotes transpose.
|
|
||||||
|
|
||||||
#### Examples
|
#### Examples
|
||||||
|
|
||||||
A single multi-variate Gaussian distribution is defined by a vector of means
|
A single multi-variate Gaussian distribution is defined by a vector of means
|
||||||
@ -103,16 +105,16 @@ class MultivariateNormalOperatorPD(distribution.Distribution):
|
|||||||
which determines the covariance.
|
which determines the covariance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mu: `float` or `double` tensor with shape `[N1,...,Nb, k]`, `b >= 0`.
|
mu: Floating point tensor with shape `[N1,...,Nb, k]`, `b >= 0`.
|
||||||
cov: `float` or `double` instance of `OperatorPDBase` with same `dtype`
|
cov: Instance of `OperatorPDBase` with same `dtype` as `mu` and shape
|
||||||
as `mu` and shape `[N1,...,Nb, k, k]`.
|
`[N1,...,Nb, k, k]`.
|
||||||
validate_args: Whether to validate input with asserts. If `validate_args`
|
validate_args: Whether to validate input with asserts. If `validate_args`
|
||||||
is `False`, and the inputs are invalid, correct behavior is not
|
is `False`, and the inputs are invalid, correct behavior is not
|
||||||
guaranteed.
|
guaranteed.
|
||||||
allow_nan_stats: Boolean, default False. If False, raise an exception if
|
allow_nan_stats: `Boolean`, default `False`. If `False`, raise an
|
||||||
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
|
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
|
||||||
If True, batch members with valid parameters leading to undefined
|
batch member If `True`, batch members with valid parameters leading to
|
||||||
statistics will return NaN for this statistic.
|
undefined statistics will return NaN for this statistic.
|
||||||
name: The name to give Ops created by the initializer.
|
name: The name to give Ops created by the initializer.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@ -148,7 +150,7 @@ class MultivariateNormalOperatorPD(distribution.Distribution):
|
|||||||
else:
|
else:
|
||||||
return mu
|
return mu
|
||||||
|
|
||||||
# Static checks could not be run, so possibly do dyamic checks.
|
# Static checks could not be run, so possibly do dynamic checks.
|
||||||
if not self.validate_args:
|
if not self.validate_args:
|
||||||
return mu
|
return mu
|
||||||
else:
|
else:
|
||||||
@ -170,12 +172,12 @@ class MultivariateNormalOperatorPD(distribution.Distribution):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def validate_args(self):
|
def validate_args(self):
|
||||||
"""Boolean describing behavior on invalid input."""
|
"""`Boolean` describing behavior on invalid input."""
|
||||||
return self._validate_args
|
return self._validate_args
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def allow_nan_stats(self):
|
def allow_nan_stats(self):
|
||||||
"""Boolean describing behavior when a stat is undefined for batch member."""
|
"""`Boolean` describing behavior when stats are undefined."""
|
||||||
return self._allow_nan_stats
|
return self._allow_nan_stats
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -417,7 +419,7 @@ class MultivariateNormalDiag(MultivariateNormalOperatorPD):
|
|||||||
determined by `diag_stdev`: `C_{ii} = diag_stdev[i]**2`.
|
determined by `diag_stdev`: `C_{ii} = diag_stdev[i]**2`.
|
||||||
|
|
||||||
```
|
```
|
||||||
f(x) = (2*pi)^(-k/2) |det(C)|^(-1/2) exp(-1/2 * (x - mu)^T C^{-1} (x - mu))
|
f(x) = (2 pi)^(-k/2) |det(C)|^(-1/2) exp(-1/2 (x - mu)^T C^{-1} (x - mu))
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Examples
|
#### Examples
|
||||||
@ -464,17 +466,17 @@ class MultivariateNormalDiag(MultivariateNormalOperatorPD):
|
|||||||
The mean of `X_i` is `mu[i]`, and the standard deviation is `diag_stdev[i]`.
|
The mean of `X_i` is `mu[i]`, and the standard deviation is `diag_stdev[i]`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mu: Rank `N + 1` `float` or `double` tensor with shape `[N1,...,Nb, k]`,
|
mu: Rank `N + 1` floating point tensor with shape `[N1,...,Nb, k]`,
|
||||||
`b >= 0`.
|
`b >= 0`.
|
||||||
diag_stdev: Rank `N + 1` `Tensor` with same `dtype` and shape as `mu`,
|
diag_stdev: Rank `N + 1` `Tensor` with same `dtype` and shape as `mu`,
|
||||||
representing the standard deviations.
|
representing the standard deviations. Must be positive.
|
||||||
validate_args: Whether to validate input with asserts. If `validate_args`
|
validate_args: Whether to validate input with asserts. If `validate_args`
|
||||||
is `False`,
|
is `False`,
|
||||||
and the inputs are invalid, correct behavior is not guaranteed.
|
and the inputs are invalid, correct behavior is not guaranteed.
|
||||||
allow_nan_stats: Boolean, default False. If False, raise an exception if
|
allow_nan_stats: `Boolean`, default `False`. If `False`, raise an
|
||||||
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
|
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
|
||||||
If True, batch members with valid parameters leading to undefined
|
batch member If `True`, batch members with valid parameters leading to
|
||||||
statistics will return NaN for this statistic.
|
undefined statistics will return NaN for this statistic.
|
||||||
name: The name to give Ops created by the initializer.
|
name: The name to give Ops created by the initializer.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@ -487,6 +489,125 @@ class MultivariateNormalDiag(MultivariateNormalOperatorPD):
|
|||||||
name=name)
|
name=name)
|
||||||
|
|
||||||
|
|
||||||
|
class MultivariateNormalDiagPlusVDVT(MultivariateNormalOperatorPD):
|
||||||
|
"""The multivariate normal distribution on `R^k`.
|
||||||
|
|
||||||
|
Every batch member of this distribution is defined by a mean and a lightweight
|
||||||
|
covariance matrix `C`.
|
||||||
|
|
||||||
|
#### Mathematical details
|
||||||
|
|
||||||
|
The PDF of this distribution in terms of the mean `mu` and covariance `C` is:
|
||||||
|
|
||||||
|
```
|
||||||
|
f(x) = (2 pi)^(-k/2) |det(C)|^(-1/2) exp(-1/2 (x - mu)^T C^{-1} (x - mu))
|
||||||
|
```
|
||||||
|
|
||||||
|
For every batch member, this distribution represents `k` random variables
|
||||||
|
`(X_1,...,X_k)`, with mean `E[X_i] = mu[i]`, and covariance matrix
|
||||||
|
`C_{ij} := E[(X_i - mu[i])(X_j - mu[j])]`
|
||||||
|
|
||||||
|
The user initializes this class by providing the mean `mu`, and a lightweight
|
||||||
|
definition of `C`:
|
||||||
|
|
||||||
|
```
|
||||||
|
C = SS^T = SS = (M + V D V^T) (M + V D V^T)
|
||||||
|
M is diagonal (k x k)
|
||||||
|
V = is shape (k x r), typically r << k
|
||||||
|
D = is diagonal (r x r), optional (defaults to identity).
|
||||||
|
```
|
||||||
|
|
||||||
|
This allows for `O(kr + r^3)` pdf evaluation and determinant, and `O(kr)`
|
||||||
|
sampling and storage (per batch member).
|
||||||
|
|
||||||
|
#### Examples
|
||||||
|
|
||||||
|
A single multi-variate Gaussian distribution is defined by a vector of means
|
||||||
|
of length `k`, and square root of the covariance `S = M + V D V^T`. Extra
|
||||||
|
leading dimensions, if provided, allow for batches.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Initialize a single 3-variate Gaussian with covariance square root
|
||||||
|
# S = M + V D V^T, where V D V^T is a matrix-rank 2 update.
|
||||||
|
mu = [1, 2, 3.]
|
||||||
|
diag_large = [1.1, 2.2, 3.3]
|
||||||
|
v = ... # shape 3 x 2
|
||||||
|
diag_small = [4., 5.]
|
||||||
|
dist = tf.contrib.distributions.MultivariateNormalDiagPlusVDVT(
|
||||||
|
mu, diag_large, v, diag_small=diag_small)
|
||||||
|
|
||||||
|
# Evaluate this on an observation in R^3, returning a scalar.
|
||||||
|
dist.pdf([-1, 0, 1])
|
||||||
|
|
||||||
|
# Initialize a batch of two 3-variate Gaussians. This time, don't provide
|
||||||
|
# diag_small. This means S = M + V V^T.
|
||||||
|
mu = [[1, 2, 3], [11, 22, 33]] # shape 2 x 3
|
||||||
|
diag_large = ... # shape 2 x 3
|
||||||
|
v = ... # shape 2 x 3 x 1, a matrix-rank 1 update.
|
||||||
|
dist = tf.contrib.distributions.MultivariateNormalDiagPlusVDVT(
|
||||||
|
mu, diag_large, v)
|
||||||
|
|
||||||
|
# Evaluate this on a two observations, each in R^3, returning a length two
|
||||||
|
# tensor.
|
||||||
|
x = [[-1, 0, 1], [-11, 0, 11]] # Shape 2 x 3.
|
||||||
|
dist.pdf(x)
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mu,
|
||||||
|
diag_large,
|
||||||
|
v,
|
||||||
|
diag_small=None,
|
||||||
|
validate_args=True,
|
||||||
|
allow_nan_stats=False,
|
||||||
|
name="MultivariateNormalDiagPlusVDVT"):
|
||||||
|
"""Multivariate Normal distributions on `R^k`.
|
||||||
|
|
||||||
|
For every batch member, this distribution represents `k` random variables
|
||||||
|
`(X_1,...,X_k)`, with mean `E[X_i] = mu[i]`, and covariance matrix
|
||||||
|
`C_{ij} := E[(X_i - mu[i])(X_j - mu[j])]`
|
||||||
|
|
||||||
|
The user initializes this class by providing the mean `mu`, and a
|
||||||
|
lightweight definition of `C`:
|
||||||
|
|
||||||
|
```
|
||||||
|
C = SS^T = SS = (M + V D V^T) (M + V D V^T)
|
||||||
|
M is diagonal (k x k)
|
||||||
|
V = is shape (k x r), typically r << k
|
||||||
|
D = is diagonal (r x r), optional (defaults to identity).
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mu: Rank `n + 1` floating point tensor with shape `[N1,...,Nn, k]`,
|
||||||
|
`n >= 0`. The means.
|
||||||
|
diag_large: Optional rank `n + 1` floating point tensor, shape
|
||||||
|
`[N1,...,Nn, k]` `n >= 0`. Defines the diagonal matrix `M`.
|
||||||
|
v: Rank `n + 1` floating point tensor, shape `[N1,...,Nn, k, r]`
|
||||||
|
`n >= 0`. Defines the matrix `V`.
|
||||||
|
diag_small: Rank `n + 1` floating point tensor, shape
|
||||||
|
`[N1,...,Nn, k]` `n >= 0`. Defines the diagonal matrix `D`. Default
|
||||||
|
is `None`, which means `D` will be the identity matrix.
|
||||||
|
validate_args: Whether to validate input with asserts. If `validate_args`
|
||||||
|
is `False`,
|
||||||
|
and the inputs are invalid, correct behavior is not guaranteed.
|
||||||
|
allow_nan_stats: `Boolean`, default `False`. If `False`, raise an
|
||||||
|
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
|
||||||
|
batch member If `True`, batch members with valid parameters leading to
|
||||||
|
undefined statistics will return NaN for this statistic.
|
||||||
|
name: The name to give Ops created by the initializer.
|
||||||
|
"""
|
||||||
|
m = operator_pd_diag.OperatorPDDiag(diag_large, verify_pd=validate_args)
|
||||||
|
cov = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
|
||||||
|
m, v, diag=diag_small, verify_pd=validate_args,
|
||||||
|
verify_shapes=validate_args)
|
||||||
|
super(MultivariateNormalDiagPlusVDVT, self).__init__(
|
||||||
|
mu, cov, allow_nan_stats=allow_nan_stats, validate_args=validate_args,
|
||||||
|
name=name)
|
||||||
|
|
||||||
|
|
||||||
class MultivariateNormalCholesky(MultivariateNormalOperatorPD):
|
class MultivariateNormalCholesky(MultivariateNormalOperatorPD):
|
||||||
"""The multivariate normal distribution on `R^k`.
|
"""The multivariate normal distribution on `R^k`.
|
||||||
|
|
||||||
@ -496,14 +617,14 @@ class MultivariateNormalCholesky(MultivariateNormalOperatorPD):
|
|||||||
|
|
||||||
#### Mathematical details
|
#### Mathematical details
|
||||||
|
|
||||||
The PDF of this distribution is:
|
The Cholesky factor `chol` defines the covariance matrix: `C = chol chol^T`.
|
||||||
|
|
||||||
|
The PDF of this distribution is then:
|
||||||
|
|
||||||
```
|
```
|
||||||
f(x) = (2*pi)^(-k/2) |det(sigma)|^(-1/2) exp(-1/2*(x-mu)^*.sigma^{-1}.(x-mu))
|
f(x) = (2 pi)^(-k/2) |det(C)|^(-1/2) exp(-1/2 (x - mu)^T C^{-1} (x - mu))
|
||||||
```
|
```
|
||||||
|
|
||||||
where `.` denotes the inner product on `R^k` and `^*` denotes transpose.
|
|
||||||
|
|
||||||
#### Examples
|
#### Examples
|
||||||
|
|
||||||
A single multi-variate Gaussian distribution is defined by a vector of means
|
A single multi-variate Gaussian distribution is defined by a vector of means
|
||||||
@ -546,20 +667,21 @@ class MultivariateNormalCholesky(MultivariateNormalOperatorPD):
|
|||||||
"""Multivariate Normal distributions on `R^k`.
|
"""Multivariate Normal distributions on `R^k`.
|
||||||
|
|
||||||
User must provide means `mu` and `chol` which holds the (batch) Cholesky
|
User must provide means `mu` and `chol` which holds the (batch) Cholesky
|
||||||
factors `S`, such that the covariance of each batch member is `S S^*`.
|
factors, such that the covariance of each batch member is `chol chol^T`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mu: `(N+1)-D` `float` or `double` tensor with shape `[N1,...,Nb, k]`,
|
mu: `(N+1)-D` floating point tensor with shape `[N1,...,Nb, k]`,
|
||||||
`b >= 0`.
|
`b >= 0`.
|
||||||
chol: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape
|
chol: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape
|
||||||
`[N1,...,Nb, k, k]`.
|
`[N1,...,Nb, k, k]`. The upper triangular part is ignored (treated as
|
||||||
|
though it is zero), and the diagonal must be positive.
|
||||||
validate_args: Whether to validate input with asserts. If `validate_args`
|
validate_args: Whether to validate input with asserts. If `validate_args`
|
||||||
is `False`,
|
is `False`, and the inputs are invalid, correct behavior is not
|
||||||
and the inputs are invalid, correct behavior is not guaranteed.
|
guaranteed.
|
||||||
allow_nan_stats: Boolean, default False. If False, raise an exception if
|
allow_nan_stats: `Boolean`, default `False`. If `False`, raise an
|
||||||
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
|
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
|
||||||
If True, batch members with valid parameters leading to undefined
|
batch member If `True`, batch members with valid parameters leading to
|
||||||
statistics will return NaN for this statistic.
|
undefined statistics will return NaN for this statistic.
|
||||||
name: The name to give Ops created by the initializer.
|
name: The name to give Ops created by the initializer.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@ -582,14 +704,12 @@ class MultivariateNormalFull(MultivariateNormalOperatorPD):
|
|||||||
|
|
||||||
#### Mathematical details
|
#### Mathematical details
|
||||||
|
|
||||||
The PDF of this distribution is:
|
With `C = sigma`, the PDF of this distribution is:
|
||||||
|
|
||||||
```
|
```
|
||||||
f(x) = (2*pi)^(-k/2) |det(sigma)|^(-1/2) exp(-1/2*(x-mu)^*.sigma^{-1}.(x-mu))
|
f(x) = (2 pi)^(-k/2) |det(C)|^(-1/2) exp(-1/2 (x - mu)^T C^{-1} (x - mu))
|
||||||
```
|
```
|
||||||
|
|
||||||
where `.` denotes the inner product on `R^k` and `^*` denotes transpose.
|
|
||||||
|
|
||||||
#### Examples
|
#### Examples
|
||||||
|
|
||||||
A single multi-variate Gaussian distribution is defined by a vector of means
|
A single multi-variate Gaussian distribution is defined by a vector of means
|
||||||
@ -630,17 +750,17 @@ class MultivariateNormalFull(MultivariateNormalOperatorPD):
|
|||||||
User must provide means `mu` and `sigma`, the mean and covariance.
|
User must provide means `mu` and `sigma`, the mean and covariance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mu: `(N+1)-D` `float` or `double` tensor with shape `[N1,...,Nb, k]`,
|
mu: `(N+1)-D` floating point tensor with shape `[N1,...,Nb, k]`,
|
||||||
`b >= 0`.
|
`b >= 0`.
|
||||||
sigma: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape
|
sigma: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape
|
||||||
`[N1,...,Nb, k, k]`.
|
`[N1,...,Nb, k, k]`. Each batch member must be positive definite.
|
||||||
validate_args: Whether to validate input with asserts. If `validate_args`
|
validate_args: Whether to validate input with asserts. If `validate_args`
|
||||||
is `False`, and the inputs are invalid, correct behavior is not
|
is `False`, and the inputs are invalid, correct behavior is not
|
||||||
guaranteed.
|
guaranteed.
|
||||||
allow_nan_stats: Boolean, default False. If False, raise an exception if
|
allow_nan_stats: `Boolean`, default `False`. If `False`, raise an
|
||||||
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
|
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
|
||||||
If True, batch members with valid parameters leading to undefined
|
batch member If `True`, batch members with valid parameters leading to
|
||||||
statistics will return NaN for this statistic.
|
undefined statistics will return NaN for this statistic.
|
||||||
name: The name to give Ops created by the initializer.
|
name: The name to give Ops created by the initializer.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@ -653,3 +773,72 @@ class MultivariateNormalFull(MultivariateNormalOperatorPD):
|
|||||||
allow_nan_stats=allow_nan_stats,
|
allow_nan_stats=allow_nan_stats,
|
||||||
validate_args=validate_args,
|
validate_args=validate_args,
|
||||||
name=name)
|
name=name)
|
||||||
|
|
||||||
|
|
||||||
|
def _kl_mvn_mvn_brute_force(mvn_a, mvn_b, name=None):
|
||||||
|
"""Batched KL divergence `KL(mvn_a || mvn_b)` for multivariate normals.
|
||||||
|
|
||||||
|
With `X`, `Y` both multivariate normals in `R^k` with means `mu_x`, `mu_y` and
|
||||||
|
covariance `C_x`, `C_y` respectively,
|
||||||
|
|
||||||
|
```
|
||||||
|
KL(X || Y) = 0.5 * ( T + Q + - k + L ),
|
||||||
|
T := trace(C_b^{-1} C_a),
|
||||||
|
Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a),
|
||||||
|
L := Log[Det(C_b)] - Log[Det(C_a)]
|
||||||
|
```
|
||||||
|
|
||||||
|
This `Op` computes the trace by solving `C_b^{-1} C_a`. Although efficient
|
||||||
|
methods for solving systems with `C_b` may be available, a dense version of
|
||||||
|
(the square root of) `C_a` is used, so performance is `O(B s k^2)` where `B`
|
||||||
|
is the batch size, and `s` is the cost of solving `C_b x = y` for vectors `x`
|
||||||
|
and `y`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mvn_a: Instance of subclass of `MultivariateNormalOperatorPD`.
|
||||||
|
mvn_b: Instance of subclass of `MultivariateNormalOperatorPD`.
|
||||||
|
name: (optional) name to use for created ops. Default "kl_mvn_mvn".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Batchwise `KL(mvn_a || mvn_b)`.
|
||||||
|
"""
|
||||||
|
# Access the "private" OperatorPD that each mvn is built from.
|
||||||
|
cov_a = mvn_a._cov # pylint: disable=protected-access
|
||||||
|
cov_b = mvn_b._cov # pylint: disable=protected-access
|
||||||
|
mu_a = mvn_a.mu
|
||||||
|
mu_b = mvn_b.mu
|
||||||
|
inputs = [mu_a, mu_b] + cov_a.inputs + cov_b.inputs
|
||||||
|
|
||||||
|
with ops.op_scope(inputs, name, "kl_mvn_mvn"):
|
||||||
|
# If Ca = AA', Cb = BB', then
|
||||||
|
# tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A']
|
||||||
|
# = tr[inv(B) A A' inv(B)']
|
||||||
|
# = tr[(inv(B) A) (inv(B) A)']
|
||||||
|
# = sum_{ik} (inv(B) A)_{ik}^2
|
||||||
|
# The second equality follows from the cyclic permutation property.
|
||||||
|
b_inv_a = cov_b.sqrt_solve(cov_a.sqrt_to_dense())
|
||||||
|
t = math_ops.reduce_sum(
|
||||||
|
math_ops.square(b_inv_a),
|
||||||
|
reduction_indices=[-1, -2])
|
||||||
|
q = cov_b.inv_quadratic_form_on_vectors(mu_b - mu_a)
|
||||||
|
k = math_ops.cast(cov_a.vector_space_dimension(), mvn_a.dtype)
|
||||||
|
one_half_l = cov_b.sqrt_log_det() - cov_a.sqrt_log_det()
|
||||||
|
return 0.5 * (t + q - k) + one_half_l
|
||||||
|
|
||||||
|
|
||||||
|
# Register KL divergences.
|
||||||
|
kl_classes = [
|
||||||
|
MultivariateNormalFull,
|
||||||
|
MultivariateNormalCholesky,
|
||||||
|
MultivariateNormalDiag,
|
||||||
|
MultivariateNormalDiagPlusVDVT,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
for mvn_aa in kl_classes:
|
||||||
|
# Register when they are the same here, and do not register when they are the
|
||||||
|
# same below because that would result in a repeated registration.
|
||||||
|
kullback_leibler.RegisterKL(mvn_aa, mvn_aa)(_kl_mvn_mvn_brute_force)
|
||||||
|
for mvn_bb in kl_classes:
|
||||||
|
if mvn_bb != mvn_aa:
|
||||||
|
kullback_leibler.RegisterKL(mvn_aa, mvn_bb)(_kl_mvn_mvn_brute_force)
|
||||||
|
@ -92,15 +92,15 @@ class Normal(distribution.Distribution):
|
|||||||
broadcasting (e.g. `mu + sigma` is a valid operation).
|
broadcasting (e.g. `mu + sigma` is a valid operation).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mu: `float` or `double` tensor, the means of the distribution(s).
|
mu: Floating point tensor, the means of the distribution(s).
|
||||||
sigma: `float` or `double` tensor, the stddevs of the distribution(s).
|
sigma: Floating point tensor, the stddevs of the distribution(s).
|
||||||
sigma must contain only positive values.
|
sigma must contain only positive values.
|
||||||
validate_args: Whether to assert that `sigma > 0`. If `validate_args` is
|
validate_args: Whether to assert that `sigma > 0`. If `validate_args` is
|
||||||
False, correct output is not guaranteed when input is invalid.
|
`False`, correct output is not guaranteed when input is invalid.
|
||||||
allow_nan_stats: Boolean, default False. If False, raise an exception if
|
allow_nan_stats: Boolean, default `False`. If `False`, raise an
|
||||||
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
|
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
|
||||||
If True, batch members with valid parameters leading to undefined
|
batch member. If `True`, batch members with valid parameters leading to
|
||||||
statistics will return NaN for this statistic.
|
undefined statistics will return NaN for this statistic.
|
||||||
name: The name to give Ops created by the initializer.
|
name: The name to give Ops created by the initializer.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@ -321,8 +321,7 @@ class Normal(distribution.Distribution):
|
|||||||
with ops.op_scope([self._mu, self._sigma, n], name):
|
with ops.op_scope([self._mu, self._sigma, n], name):
|
||||||
broadcast_shape = (self._mu + self._sigma).get_shape()
|
broadcast_shape = (self._mu + self._sigma).get_shape()
|
||||||
n = ops.convert_to_tensor(n)
|
n = ops.convert_to_tensor(n)
|
||||||
shape = array_ops.concat(
|
shape = array_ops.concat(0, ([n], array_ops.shape(self.mean())))
|
||||||
0, [array_ops.pack([n]), array_ops.shape(self.mean())])
|
|
||||||
sampled = random_ops.random_normal(
|
sampled = random_ops.random_normal(
|
||||||
shape=shape, mean=0, stddev=1, dtype=self._mu.dtype, seed=seed)
|
shape=shape, mean=0, stddev=1, dtype=self._mu.dtype, seed=seed)
|
||||||
|
|
||||||
|
@ -18,6 +18,9 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import abc
|
||||||
|
import six
|
||||||
|
|
||||||
from tensorflow.contrib.distributions.python.ops import operator_pd
|
from tensorflow.contrib.distributions.python.ops import operator_pd
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
@ -26,11 +29,190 @@ from tensorflow.python.ops import control_flow_ops
|
|||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
|
|
||||||
|
|
||||||
class OperatorPDSqrtDiag(operator_pd.OperatorPDBase):
|
@six.add_metaclass(abc.ABCMeta)
|
||||||
|
class OperatorPDDiagBase(operator_pd.OperatorPDBase):
|
||||||
|
"""Base class for diagonal operators."""
|
||||||
|
|
||||||
|
def __init__(self, diag, verify_pd=True, name='OperatorPDDiagBase'):
|
||||||
|
self._verify_pd = verify_pd
|
||||||
|
self._name = name
|
||||||
|
with ops.name_scope(name):
|
||||||
|
with ops.op_scope([diag], 'init'):
|
||||||
|
self._diag = self._check_diag(diag)
|
||||||
|
|
||||||
|
def _check_diag(self, diag):
|
||||||
|
"""Verify that `diag` is positive."""
|
||||||
|
diag = ops.convert_to_tensor(diag, name='diag')
|
||||||
|
if not self.verify_pd:
|
||||||
|
return diag
|
||||||
|
deps = [check_ops.assert_positive(diag)]
|
||||||
|
return control_flow_ops.with_dependencies(deps, diag)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
"""String name identifying this `Operator`."""
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def verify_pd(self):
|
||||||
|
"""Whether to verify that this `Operator` is positive definite."""
|
||||||
|
return self._verify_pd
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
"""Data type of matrix elements of `A`."""
|
||||||
|
return self._diag.dtype
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self):
|
||||||
|
"""Initialization arguments."""
|
||||||
|
return [self._diag]
|
||||||
|
|
||||||
|
def get_shape(self):
|
||||||
|
"""`TensorShape` giving static shape."""
|
||||||
|
# If d_shape = [5, 3], we return [5, 3, 3].
|
||||||
|
d_shape = self._diag.get_shape()
|
||||||
|
return d_shape.concatenate(d_shape[-1:])
|
||||||
|
|
||||||
|
def _shape(self):
|
||||||
|
d_shape = array_ops.shape(self._diag)
|
||||||
|
k = array_ops.gather(d_shape, array_ops.size(d_shape) - 1)
|
||||||
|
return array_ops.concat(0, (d_shape, [k]))
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _batch_log_det(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _inv_quadratic_form_on_vectors(self, x):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _batch_matmul(self, x, transpose_x=False):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _batch_sqrt_matmul(self, x, transpose_x=False):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _batch_solve(self, rhs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _batch_sqrt_solve(self, rhs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _to_dense(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _sqrt_to_dense(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _add_to_tensor(self, mat):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class OperatorPDDiag(OperatorPDDiagBase):
|
||||||
"""Class representing a (batch) of positive definite matrices `A`.
|
"""Class representing a (batch) of positive definite matrices `A`.
|
||||||
|
|
||||||
This class provides access to functions of a batch of symmetric positive
|
This class provides access to functions of a batch of symmetric positive
|
||||||
definite (PD) matrices `A` in `R^{k x k}` defined by their their square root,
|
definite (PD) matrices `A` in `R^{k x k}`.
|
||||||
|
|
||||||
|
In this case, `A` is diagonal and is defined by a provided tensor `diag`,
|
||||||
|
`A_{ii} = diag[i]`.
|
||||||
|
|
||||||
|
Determinants, solves, and storage are `O(k)`.
|
||||||
|
|
||||||
|
In practice, this operator represents a (batch) matrix `A` with shape
|
||||||
|
`[N1,...,Nn, k, k]` for some `n >= 0`. The first `n` indices designate a
|
||||||
|
batch member. For every batch member `(i1,...,ib)`, `A[i1,...,ib, : :]` is
|
||||||
|
a `k x k` matrix.
|
||||||
|
|
||||||
|
For example,
|
||||||
|
|
||||||
|
```python
|
||||||
|
distributions = tf.contrib.distributions
|
||||||
|
diag = [1.0, 2.0]
|
||||||
|
operator = OperatorPDDiag(diag)
|
||||||
|
operator.det() # ==> (1 * 2)
|
||||||
|
|
||||||
|
# Compute the quadratic form x^T A^{-1} x for vector x.
|
||||||
|
x = [1.0, 2.0]
|
||||||
|
operator.inv_quadratic_form_on_vectors(x)
|
||||||
|
|
||||||
|
# Matrix multiplication by the square root, S w, with A = S S^T.
|
||||||
|
# Recall A is diagonal, and so then is S, with S_{ij} = sqrt(A_{ij}).
|
||||||
|
# If w is iid normal, S w has covariance A.
|
||||||
|
w = [[1.0],
|
||||||
|
[2.0]]
|
||||||
|
operator.sqrt_matmul(w)
|
||||||
|
```
|
||||||
|
|
||||||
|
The above three methods, `log_det`, `inv_quadratic_form_on_vectors`, and
|
||||||
|
`sqrt_matmul` provide "all" that is necessary to use a covariance matrix
|
||||||
|
in a multi-variate normal distribution. See the class
|
||||||
|
`MultivariateNormalDiag`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, diag, verify_pd=True, name='OperatorPDDiag'):
|
||||||
|
"""Initialize an OperatorPDDiag.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
diag: Shape `[N1,...,Nn, k]` positive tensor with `n >= 0`, `k >= 1`.
|
||||||
|
verify_pd: Whether to check `diag` is positive.
|
||||||
|
name: A name to prepend to all ops created by this class.
|
||||||
|
"""
|
||||||
|
super(OperatorPDDiag, self).__init__(
|
||||||
|
diag, verify_pd=verify_pd, name=name)
|
||||||
|
|
||||||
|
def _batch_log_det(self):
|
||||||
|
return math_ops.reduce_sum(
|
||||||
|
math_ops.log(self._diag), reduction_indices=[-1])
|
||||||
|
|
||||||
|
def _inv_quadratic_form_on_vectors(self, x):
|
||||||
|
return self._iqfov_via_solve(x)
|
||||||
|
|
||||||
|
def _batch_matmul(self, x, transpose_x=False):
|
||||||
|
if transpose_x:
|
||||||
|
x = array_ops.batch_matrix_transpose(x)
|
||||||
|
diag_mat = array_ops.expand_dims(self._diag, -1)
|
||||||
|
return diag_mat * x
|
||||||
|
|
||||||
|
def _batch_sqrt_matmul(self, x, transpose_x=False):
|
||||||
|
if transpose_x:
|
||||||
|
x = array_ops.batch_matrix_transpose(x)
|
||||||
|
diag_mat = array_ops.expand_dims(self._diag, -1)
|
||||||
|
return math_ops.sqrt(diag_mat) * x
|
||||||
|
|
||||||
|
def _batch_solve(self, rhs):
|
||||||
|
diag_mat = array_ops.expand_dims(self._diag, -1)
|
||||||
|
return rhs / diag_mat
|
||||||
|
|
||||||
|
def _batch_sqrt_solve(self, rhs):
|
||||||
|
diag_mat = array_ops.expand_dims(self._diag, -1)
|
||||||
|
return rhs / math_ops.sqrt(diag_mat)
|
||||||
|
|
||||||
|
def _to_dense(self):
|
||||||
|
return array_ops.batch_matrix_diag(self._diag)
|
||||||
|
|
||||||
|
def _sqrt_to_dense(self):
|
||||||
|
return array_ops.batch_matrix_diag(math_ops.sqrt(self._diag))
|
||||||
|
|
||||||
|
def _add_to_tensor(self, mat):
|
||||||
|
mat_diag = array_ops.batch_matrix_diag_part(mat)
|
||||||
|
new_diag = self._diag + mat_diag
|
||||||
|
return array_ops.batch_matrix_set_diag(mat, new_diag)
|
||||||
|
|
||||||
|
|
||||||
|
class OperatorPDSqrtDiag(OperatorPDDiagBase):
|
||||||
|
"""Class representing a (batch) of positive definite matrices `A`.
|
||||||
|
|
||||||
|
This class provides access to functions of a batch of symmetric positive
|
||||||
|
definite (PD) matrices `A` in `R^{k x k}` defined by their square root,
|
||||||
`S`, such that `A = SS^T`.
|
`S`, such that `A = SS^T`.
|
||||||
|
|
||||||
In this case, `S` is diagonal and is defined by a provided tensor `diag`,
|
In this case, `S` is diagonal and is defined by a provided tensor `diag`,
|
||||||
@ -75,58 +257,17 @@ class OperatorPDSqrtDiag(operator_pd.OperatorPDBase):
|
|||||||
verify_pd: Whether to check `diag` is positive.
|
verify_pd: Whether to check `diag` is positive.
|
||||||
name: A name to prepend to all ops created by this class.
|
name: A name to prepend to all ops created by this class.
|
||||||
"""
|
"""
|
||||||
self._verify_pd = verify_pd
|
super(OperatorPDSqrtDiag, self).__init__(
|
||||||
self._name = name
|
diag, verify_pd=verify_pd, name=name)
|
||||||
with ops.name_scope(name):
|
|
||||||
with ops.op_scope([diag], 'init'):
|
|
||||||
self._diag = self._check_diag(diag)
|
|
||||||
|
|
||||||
def _check_diag(self, diag):
|
|
||||||
"""Verify that `diag` is positive."""
|
|
||||||
diag = ops.convert_to_tensor(diag, name='diag')
|
|
||||||
if not self.verify_pd:
|
|
||||||
return diag
|
|
||||||
deps = [check_ops.assert_positive(diag)]
|
|
||||||
return control_flow_ops.with_dependencies(deps, diag)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self):
|
|
||||||
"""String name identifying this `Operator`."""
|
|
||||||
return self._name
|
|
||||||
|
|
||||||
@property
|
|
||||||
def verify_pd(self):
|
|
||||||
"""Whether to verify that this `Operator` is positive definite."""
|
|
||||||
return self._verify_pd
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dtype(self):
|
|
||||||
"""Data type of matrix elements of `A`."""
|
|
||||||
return self._diag.dtype
|
|
||||||
|
|
||||||
def _batch_log_det(self):
|
def _batch_log_det(self):
|
||||||
return 2 * math_ops.reduce_sum(
|
return 2 * math_ops.reduce_sum(
|
||||||
math_ops.log(self._diag), reduction_indices=[-1])
|
math_ops.log(self._diag), reduction_indices=[-1])
|
||||||
|
|
||||||
@property
|
|
||||||
def inputs(self):
|
|
||||||
"""List of tensors that were provided as initialization inputs."""
|
|
||||||
return [self._diag]
|
|
||||||
|
|
||||||
def _inv_quadratic_form_on_vectors(self, x):
|
def _inv_quadratic_form_on_vectors(self, x):
|
||||||
# This Operator is defined in terms of diagonal entries of the sqrt.
|
# This Operator is defined in terms of diagonal entries of the sqrt.
|
||||||
return self._iqfov_via_sqrt_solve(x)
|
return self._iqfov_via_sqrt_solve(x)
|
||||||
|
|
||||||
def get_shape(self):
|
|
||||||
"""`TensorShape` giving static shape."""
|
|
||||||
d_shape = self._diag.get_shape()
|
|
||||||
return d_shape.concatenate(d_shape[-1:])
|
|
||||||
|
|
||||||
def _shape(self):
|
|
||||||
d_shape = array_ops.shape(self._diag)
|
|
||||||
k = array_ops.gather(d_shape, array_ops.size(d_shape) - 1)
|
|
||||||
return array_ops.concat(0, (d_shape, [k]))
|
|
||||||
|
|
||||||
def _batch_matmul(self, x, transpose_x=False):
|
def _batch_matmul(self, x, transpose_x=False):
|
||||||
if transpose_x:
|
if transpose_x:
|
||||||
x = array_ops.batch_matrix_transpose(x)
|
x = array_ops.batch_matrix_transpose(x)
|
||||||
|
@ -0,0 +1,207 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Identity operator in `R^k`."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
|
||||||
|
from tensorflow.contrib.distributions.python.ops import operator_pd
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.framework import tensor_util
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import check_ops
|
||||||
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
|
||||||
|
|
||||||
|
class OperatorPDIdentity(operator_pd.OperatorPDBase):
|
||||||
|
"""Identity operator in `R^k`: `Ax = x`.
|
||||||
|
|
||||||
|
This provides an efficient implementation of the identity as an `OperatorPD`.
|
||||||
|
Storage, solves, and matmul are all `O(1)`, independent of batch size.
|
||||||
|
|
||||||
|
In order to be a drop-in replacement for other operators, shape and dtype
|
||||||
|
of arguments (e.g. to `matmul`) are checked statically as though this operator
|
||||||
|
was an instantiated matrix.
|
||||||
|
|
||||||
|
Dynamic shape checks of arguments are not done since that could impede
|
||||||
|
performance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, shape, dtype, verify_pd=True, name='OperatorPDIdentity'):
|
||||||
|
"""Initialize an `OperatorPDIdentity`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shape: `int32` rank 1 `Tensor` of length at least 2, and with the last
|
||||||
|
two entries equal (since this is a square matrix).
|
||||||
|
dtype: Data type of the matrix that this operator represents.
|
||||||
|
verify_pd: `Boolean`, if `True`, asserts are added to the initialization
|
||||||
|
args to ensure they define this operator as a square (batch) matrix.
|
||||||
|
name: Name to prepend to `Ops`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Grab static shape if available now.
|
||||||
|
with ops.name_scope(name):
|
||||||
|
with ops.op_scope([shape], 'init'):
|
||||||
|
self._dtype = dtypes.as_dtype(dtype)
|
||||||
|
self._verify_pd = verify_pd
|
||||||
|
self._name = name
|
||||||
|
|
||||||
|
# Store the static shape (if possible) right now before adding the
|
||||||
|
# asserts, since the asserts prevent .constant_value from working.
|
||||||
|
shape = ops.convert_to_tensor(shape, name='shape')
|
||||||
|
self._get_shape = tensor_shape.TensorShape(
|
||||||
|
tensor_util.constant_value(shape))
|
||||||
|
self._shape_arg = self._check_shape(shape)
|
||||||
|
|
||||||
|
def _check_shape(self, shape):
|
||||||
|
"""Check that the init arg `shape` defines a valid operator."""
|
||||||
|
shape = ops.convert_to_tensor(shape, name='shape')
|
||||||
|
if not self._verify_pd:
|
||||||
|
return shape
|
||||||
|
|
||||||
|
# Further checks are equivalent to verification that this is positive
|
||||||
|
# definite. Why? Because the further checks simply check that this is a
|
||||||
|
# square matrix, and combining the fact that this is square (and thus maps
|
||||||
|
# a vector space R^k onto itself), with the behavior of .matmul(), this must
|
||||||
|
# be the identity operator.
|
||||||
|
rank = array_ops.size(shape)
|
||||||
|
assert_matrix = check_ops.assert_less_equal(2, rank)
|
||||||
|
with ops.control_dependencies([assert_matrix]):
|
||||||
|
last_dim = array_ops.gather(shape, rank - 1)
|
||||||
|
second_to_last_dim = array_ops.gather(shape, rank - 2)
|
||||||
|
assert_square = check_ops.assert_equal(last_dim, second_to_last_dim)
|
||||||
|
return control_flow_ops.with_dependencies([assert_matrix, assert_square],
|
||||||
|
shape)
|
||||||
|
|
||||||
|
def _check_x(self, x):
|
||||||
|
"""Static check that the argument `x` is proper `shape`, `dtype`."""
|
||||||
|
# x is a typical argument e.g. to matmul or solve. In both cases, x should
|
||||||
|
# have the same type/shape since this is a square matrix. These checks are
|
||||||
|
# ususally not needed since we ususally have some tensor backing this
|
||||||
|
# distribution, and the calls to tf.matmul do a shape/type check.
|
||||||
|
#
|
||||||
|
# Static checks only for efficiency, the identity should be fast.
|
||||||
|
#
|
||||||
|
# Why check at all? Because we want this operator to be swappable for a
|
||||||
|
# real Operator.
|
||||||
|
if self.dtype != x.dtype:
|
||||||
|
raise TypeError(
|
||||||
|
'Expected argument "x" to have same dtype as this operator (%s). '
|
||||||
|
'Found: %s' % (self.dtype, x.dtype))
|
||||||
|
|
||||||
|
x_shape = x.get_shape()
|
||||||
|
self_shape = self.get_shape()
|
||||||
|
found_msg = (
|
||||||
|
'Found: operator.shape = %s, x.shape = %s' % (self_shape, x_shape))
|
||||||
|
if x_shape.ndims is not None and self_shape.ndims is not None:
|
||||||
|
if x_shape.ndims != self_shape.ndims:
|
||||||
|
raise ValueError(
|
||||||
|
'Expected argument "x" to have same tensor rank as this operator. '
|
||||||
|
+ found_msg)
|
||||||
|
if x_shape.is_fully_defined() and self_shape.is_fully_defined():
|
||||||
|
if x_shape[-2] != self_shape[-1]:
|
||||||
|
raise ValueError(
|
||||||
|
'Incompatible shapes for matrix-matrix operation. ' + found_msg)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
"""String name identifying this `Operator`."""
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def verify_pd(self):
|
||||||
|
"""Whether to verify that this `Operator` is positive definite."""
|
||||||
|
return self._verify_pd
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
"""Data type of matrix elements of `A`."""
|
||||||
|
return self._dtype
|
||||||
|
|
||||||
|
def _add_to_tensor(self, mat):
|
||||||
|
# Add to a tensor in O(k) time!
|
||||||
|
mat_diag = array_ops.batch_matrix_diag_part(mat)
|
||||||
|
new_diag = constant_op.constant(1, dtype=self.dtype) + mat_diag
|
||||||
|
return array_ops.batch_matrix_set_diag(mat, new_diag)
|
||||||
|
|
||||||
|
def _inv_quadratic_form_on_vectors(self, x):
|
||||||
|
self._check_x(x)
|
||||||
|
return self._iqfov_via_sqrt_solve(x)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self):
|
||||||
|
"""List of tensors that were provided as initialization inputs."""
|
||||||
|
return [self._shape]
|
||||||
|
|
||||||
|
def get_shape(self):
|
||||||
|
"""Static `TensorShape` of entire operator.
|
||||||
|
|
||||||
|
If this operator represents the batch matrix `A` with
|
||||||
|
`A.shape = [N1,...,Nn, k, k]`, then this returns
|
||||||
|
`TensorShape([N1,...,Nn, k, k])`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`TensorShape`, statically determined, may be undefined.
|
||||||
|
"""
|
||||||
|
return self._get_shape
|
||||||
|
|
||||||
|
def _shape(self):
|
||||||
|
return self._shape_arg
|
||||||
|
|
||||||
|
def _det(self):
|
||||||
|
det = array_ops.ones(self.batch_shape(), dtype=self.dtype)
|
||||||
|
det.set_shape(self.get_batch_shape())
|
||||||
|
return det
|
||||||
|
|
||||||
|
def _batch_log_det(self):
|
||||||
|
log_det = array_ops.zeros(self.batch_shape(), dtype=self.dtype)
|
||||||
|
log_det.set_shape(self.get_batch_shape())
|
||||||
|
return log_det
|
||||||
|
|
||||||
|
def _batch_sqrt_log_det(self):
|
||||||
|
s_log_det = array_ops.zeros(self.batch_shape(), dtype=self.dtype)
|
||||||
|
s_log_det.set_shape(self.get_batch_shape())
|
||||||
|
return s_log_det
|
||||||
|
|
||||||
|
def _batch_matmul(self, x, transpose_x=False):
|
||||||
|
if transpose_x:
|
||||||
|
x = array_ops.batch_matrix_transpose(x)
|
||||||
|
self._check_x(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _batch_sqrt_matmul(self, x, transpose_x=False):
|
||||||
|
return self._batch_matmul(x, transpose_x=transpose_x)
|
||||||
|
|
||||||
|
def _batch_solve(self, rhs):
|
||||||
|
self._check_x(rhs)
|
||||||
|
return rhs
|
||||||
|
|
||||||
|
def _batch_sqrt_solve(self, rhs):
|
||||||
|
self._check_x(rhs)
|
||||||
|
return rhs
|
||||||
|
|
||||||
|
def _to_dense(self):
|
||||||
|
diag = array_ops.ones(self.vector_shape(), dtype=self.dtype)
|
||||||
|
dense = array_ops.batch_matrix_diag(diag)
|
||||||
|
dense.set_shape(self.get_shape())
|
||||||
|
return dense
|
||||||
|
|
||||||
|
def _sqrt_to_dense(self):
|
||||||
|
return self.to_dense()
|
@ -0,0 +1,475 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Operator defined: `A = SS^T` where `S = M + VDV^T`, for `OperatorPD` `M`."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.contrib.distributions.python.ops import operator_pd
|
||||||
|
from tensorflow.contrib.distributions.python.ops import operator_pd_diag
|
||||||
|
from tensorflow.contrib.distributions.python.ops import operator_pd_identity
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import check_ops
|
||||||
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import linalg_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
|
||||||
|
|
||||||
|
class OperatorPDSqrtVDVTUpdate(operator_pd.OperatorPDBase):
|
||||||
|
r"""Operator defined by `A=SS^T`, where `S = M + VDV^T` for `OperatorPD` `M`.
|
||||||
|
|
||||||
|
This provides efficient low-rank updates of arbitrary `OperatorPD`.
|
||||||
|
|
||||||
|
Some math:
|
||||||
|
|
||||||
|
Given positive definite operator representing positive definite (batch) matrix
|
||||||
|
`M` in `R^{k x k}`, diagonal matrix `D` in `R^{r x r}`, and low rank `V` in
|
||||||
|
`R^{k x r}` this class represents the batch matrix `A`, defined by its square
|
||||||
|
root `S` as follows:
|
||||||
|
|
||||||
|
```
|
||||||
|
A = SS^T, where
|
||||||
|
S := M + VDV^T
|
||||||
|
```
|
||||||
|
|
||||||
|
Defining an operator in terms of its square root means that
|
||||||
|
`A_{ij} = S_i S_j^T`, where `S_i` is the ith row of `S`. The update
|
||||||
|
`VDV^T` has `ij` coordinate equal to `sum_k V_{ik} D_{kk} V_{jk}`.
|
||||||
|
|
||||||
|
Computational efficiency:
|
||||||
|
|
||||||
|
Defining `A` via its square root eliminates the need to compute the square
|
||||||
|
root.
|
||||||
|
|
||||||
|
Performance depends on the operator representing `M`, the batch size `B`, and
|
||||||
|
the width of the matrix being multiplied, or systems being solved `L`.
|
||||||
|
|
||||||
|
Since `V` is rank `r`, the update adds
|
||||||
|
|
||||||
|
* `O(B L k r)` to matmul, which requires a call to `M.matmul`.
|
||||||
|
* `O(B L r^3)` to solves, which require a call to `M.solve` as well as the
|
||||||
|
solution to a batch of rank `r` systems.
|
||||||
|
* `O(B r^3)` to determinants, which require a call to `M.solve` as well as the
|
||||||
|
solution to a batch of rank `r` systems.
|
||||||
|
|
||||||
|
The rank `r` solve and determinant are both done through a Cholesky
|
||||||
|
factorization, thus some computation is shared.
|
||||||
|
|
||||||
|
See
|
||||||
|
https://en.wikipedia.org/wiki/Woodbury_matrix_identity
|
||||||
|
https://en.wikipedia.org/wiki/Matrix_determinant_lemma
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Note that diag must be nonsingular to use Woodbury lemma, and must be
|
||||||
|
# positive def to use a Cholesky factorization, so we enforce that here.
|
||||||
|
def __init__(self,
|
||||||
|
operator,
|
||||||
|
v,
|
||||||
|
diag=None,
|
||||||
|
verify_pd=True,
|
||||||
|
verify_shapes=True,
|
||||||
|
name='OperatorPDSqrtVDVTUpdate'):
|
||||||
|
"""Initialize an `OperatorPDSqrtVDVTUpdate`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
operator: Subclass of `OperatorPDBase`. Represents the (batch) positive
|
||||||
|
definite matrix `M` in `R^{k x k}`.
|
||||||
|
v: `Tensor` defining batch matrix of same `dtype` and `batch_shape` as
|
||||||
|
`operator`, and last two dimensions of shape `(k, r)`.
|
||||||
|
diag: Optional `Tensor` defining batch vector of same `dtype` and
|
||||||
|
`batch_shape` as `operator`, and last dimension of size `r`. If `None`,
|
||||||
|
the update becomes `VV^T` rather than `VDV^T`.
|
||||||
|
verify_pd: `Boolean`. If `True`, add asserts that `diag > 0`, which,
|
||||||
|
along with the positive definiteness of `operator`, is sufficient to
|
||||||
|
make the resulting operator positive definite.
|
||||||
|
verify_shapes: `Boolean`. If `True`, check that `operator`, `v`, and
|
||||||
|
`diag` have compatible shapes.
|
||||||
|
name: A name to prepend to `Op` names.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not isinstance(operator, operator_pd.OperatorPDBase):
|
||||||
|
raise TypeError('operator was not instance of OperatorPDBase.')
|
||||||
|
|
||||||
|
with ops.name_scope(name):
|
||||||
|
with ops.op_scope(operator.inputs + [v, diag], 'init'):
|
||||||
|
self._operator = operator
|
||||||
|
self._v = ops.convert_to_tensor(v, name='v')
|
||||||
|
self._verify_pd = verify_pd
|
||||||
|
self._verify_shapes = verify_shapes
|
||||||
|
self._name = name
|
||||||
|
|
||||||
|
# This operator will be PD so long as the diag is PSD, but Woodbury
|
||||||
|
# and determinant lemmas require diag to be PD. So require diag PD
|
||||||
|
# whenever we ask to "verify_pd".
|
||||||
|
if diag is not None:
|
||||||
|
self._diag = ops.convert_to_tensor(diag, name='diag')
|
||||||
|
self._diag_operator = operator_pd_diag.OperatorPDDiag(
|
||||||
|
diag, verify_pd=self.verify_pd)
|
||||||
|
# No need to verify that the inverse of a PD is PD.
|
||||||
|
self._diag_inv_operator = operator_pd_diag.OperatorPDDiag(
|
||||||
|
1 / self._diag, verify_pd=False)
|
||||||
|
else:
|
||||||
|
self._diag = None
|
||||||
|
self._diag_operator = self._get_identity_operator(self._v)
|
||||||
|
self._diag_inv_operator = self._diag_operator
|
||||||
|
|
||||||
|
self._check_types(operator, self._v, self._diag)
|
||||||
|
# Always check static.
|
||||||
|
checked = self._check_shapes_static(operator, self._v, self._diag)
|
||||||
|
if not checked and self._verify_shapes:
|
||||||
|
self._v, self._diag = self._check_shapes_dynamic(
|
||||||
|
operator, self._v, self._diag)
|
||||||
|
|
||||||
|
def _get_identity_operator(self, v):
|
||||||
|
"""Get an `OperatorPDIdentity` to play the role of `D` in `VDV^T`."""
|
||||||
|
with ops.op_scope([v], 'get_identity_operator'):
|
||||||
|
if v.get_shape().is_fully_defined():
|
||||||
|
v_shape = v.get_shape().as_list()
|
||||||
|
v_batch_shape = v_shape[:-2]
|
||||||
|
r = v_shape[-1]
|
||||||
|
id_shape = v_batch_shape + [r, r]
|
||||||
|
else:
|
||||||
|
v_shape = array_ops.shape(v)
|
||||||
|
v_rank = array_ops.rank(v)
|
||||||
|
v_batch_shape = array_ops.slice(v_shape, [0], [v_rank - 2])
|
||||||
|
r = array_ops.gather(v_shape, v_rank - 1) # Last dim of v
|
||||||
|
id_shape = array_ops.concat(0, (v_batch_shape, [r, r]))
|
||||||
|
return operator_pd_identity.OperatorPDIdentity(
|
||||||
|
id_shape, v.dtype, verify_pd=self._verify_pd)
|
||||||
|
|
||||||
|
def _check_types(self, operator, v, diag):
|
||||||
|
def msg():
|
||||||
|
string = (
|
||||||
|
'dtypes must match: Found operator.dtype = %s, v.dtype = %s'
|
||||||
|
% (operator.dtype, v.dtype))
|
||||||
|
return string
|
||||||
|
|
||||||
|
if operator.dtype != v.dtype:
|
||||||
|
raise TypeError(msg())
|
||||||
|
if diag is not None:
|
||||||
|
if diag.dtype != v.dtype:
|
||||||
|
raise TypeError('%s, diag.dtype = %s' % (msg(), diag.dtype))
|
||||||
|
|
||||||
|
def _check_shapes_static(self, operator, v, diag):
|
||||||
|
"""True if they are compatible. Raise if not. False if could not check."""
|
||||||
|
def msg():
|
||||||
|
# Error message when shapes don't match.
|
||||||
|
string = ' Found: operator.shape = %s, v.shape = %s' % (s_op, s_v)
|
||||||
|
if diag is not None:
|
||||||
|
string += ', diag.shape = ' % s_d
|
||||||
|
return string
|
||||||
|
|
||||||
|
s_op = operator.get_shape()
|
||||||
|
s_v = v.get_shape()
|
||||||
|
|
||||||
|
# If everything is not fully defined, return False because we couldn't check
|
||||||
|
if not (s_op.is_fully_defined() and s_v.is_fully_defined()):
|
||||||
|
return False
|
||||||
|
if diag is not None:
|
||||||
|
s_d = diag.get_shape()
|
||||||
|
if not s_d.is_fully_defined():
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Now perform the checks, raising ValueError if they fail.
|
||||||
|
|
||||||
|
# Check tensor rank.
|
||||||
|
if s_v.ndims != s_op.ndims:
|
||||||
|
raise ValueError('v should have same rank as operator' + msg())
|
||||||
|
if diag is not None:
|
||||||
|
if s_d.ndims != s_op.ndims - 1:
|
||||||
|
raise ValueError('diag should have rank 1 less than operator' + msg())
|
||||||
|
|
||||||
|
# Check batch shape
|
||||||
|
if s_v[:-2] != s_op[:-2]:
|
||||||
|
raise ValueError('v and operator should have same batch shape' + msg())
|
||||||
|
if diag is not None:
|
||||||
|
if s_d[:-1] != s_op[:-2]:
|
||||||
|
raise ValueError(
|
||||||
|
'diag and operator should have same batch shape' + msg())
|
||||||
|
|
||||||
|
# Check event shape
|
||||||
|
if s_v[-2] != s_op[-1]:
|
||||||
|
raise ValueError(
|
||||||
|
'v and operator should be compatible for matmul' + msg())
|
||||||
|
if diag is not None:
|
||||||
|
if s_d[-1] != s_v[-1]:
|
||||||
|
raise ValueError('diag and v should have same last dimension' + msg())
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _check_shapes_dynamic(self, operator, v, diag):
|
||||||
|
"""Return (v, diag) with Assert dependencies, which check shape."""
|
||||||
|
checks = []
|
||||||
|
with ops.op_scope([operator, v, diag], 'check_shapes'):
|
||||||
|
s_v = array_ops.shape(v)
|
||||||
|
r_op = operator.rank()
|
||||||
|
r_v = array_ops.rank(v)
|
||||||
|
if diag is not None:
|
||||||
|
s_d = array_ops.shape(diag)
|
||||||
|
r_d = array_ops.rank(diag)
|
||||||
|
|
||||||
|
# Check tensor rank.
|
||||||
|
checks.append(check_ops.assert_rank(v, r_op))
|
||||||
|
if diag is not None:
|
||||||
|
checks.append(check_ops.assert_rank(diag, r_op - 1))
|
||||||
|
|
||||||
|
# Check batch shape
|
||||||
|
checks.append(check_ops.assert_equal(
|
||||||
|
operator.batch_shape(), array_ops.slice(s_v, [0], [r_v - 2])))
|
||||||
|
if diag is not None:
|
||||||
|
checks.append(check_ops.assert_equal(
|
||||||
|
operator.batch_shape(), array_ops.slice(s_d, [0], [r_d - 1])))
|
||||||
|
|
||||||
|
# Check event shape
|
||||||
|
checks.append(check_ops.assert_equal(
|
||||||
|
operator.vector_space_dimension(), array_ops.gather(s_v, r_v - 2)))
|
||||||
|
if diag is not None:
|
||||||
|
checks.append(check_ops.assert_equal(
|
||||||
|
array_ops.gather(s_v, r_v - 1), array_ops.gather(s_d, r_d - 1)))
|
||||||
|
|
||||||
|
v = control_flow_ops.with_dependencies(checks, v)
|
||||||
|
if diag is not None:
|
||||||
|
diag = control_flow_ops.with_dependencies(checks, diag)
|
||||||
|
return v, diag
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
"""String name identifying this `Operator`."""
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def verify_pd(self):
|
||||||
|
"""Whether to verify that this `Operator` is positive definite."""
|
||||||
|
return self._verify_pd
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
"""Data type of matrix elements of `A`."""
|
||||||
|
return self._v.dtype
|
||||||
|
|
||||||
|
def _inv_quadratic_form_on_vectors(self, x):
|
||||||
|
return self._iqfov_via_sqrt_solve(x)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self):
|
||||||
|
"""List of tensors that were provided as initialization inputs."""
|
||||||
|
return self._operator.inputs + self._diag_operator.inputs + [self._v]
|
||||||
|
|
||||||
|
def get_shape(self):
|
||||||
|
"""Static `TensorShape` of entire operator.
|
||||||
|
|
||||||
|
If this operator represents the batch matrix `A` with
|
||||||
|
`A.shape = [N1,...,Nn, k, k]`, then this returns
|
||||||
|
`TensorShape([N1,...,Nn, k, k])`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`TensorShape`, statically determined, may be undefined.
|
||||||
|
"""
|
||||||
|
return self._operator.get_shape()
|
||||||
|
|
||||||
|
def _shape(self):
|
||||||
|
return self._operator.shape()
|
||||||
|
|
||||||
|
def _det(self):
|
||||||
|
return math_ops.exp(self.log_det())
|
||||||
|
|
||||||
|
def _batch_log_det(self):
|
||||||
|
return 2 * self._batch_sqrt_log_det()
|
||||||
|
|
||||||
|
def _log_det(self):
|
||||||
|
return 2 * self._sqrt_log_det()
|
||||||
|
|
||||||
|
def _sqrt_log_det(self):
|
||||||
|
# The matrix determinant lemma states:
|
||||||
|
# det(M + VDV^T) = det(D^{-1} + V^T M^{-1} V) * det(D) * det(M)
|
||||||
|
# = det(C) * det(D) * det(M)
|
||||||
|
#
|
||||||
|
# Here we compute the Cholesky factor of "C", then pass the result on.
|
||||||
|
diag_chol_c = array_ops.batch_matrix_diag_part(self._chol_capacitance(
|
||||||
|
batch_mode=False))
|
||||||
|
return self._sqrt_log_det_core(diag_chol_c)
|
||||||
|
|
||||||
|
def _batch_sqrt_log_det(self):
|
||||||
|
# Here we compute the Cholesky factor of "C", then pass the result on.
|
||||||
|
diag_chol_c = array_ops.batch_matrix_diag_part(self._chol_capacitance(
|
||||||
|
batch_mode=True))
|
||||||
|
return self._sqrt_log_det_core(diag_chol_c)
|
||||||
|
|
||||||
|
def _chol_capacitance(self, batch_mode):
|
||||||
|
"""Cholesky factorization of the capacitance term."""
|
||||||
|
# Cholesky factor for (D^{-1} + V^T M^{-1} V), which is sometimes
|
||||||
|
# known as the "capacitance" matrix.
|
||||||
|
|
||||||
|
# self._operator will use batch if need be. Automatically. We cannot force
|
||||||
|
# that here.
|
||||||
|
# M^{-1} V
|
||||||
|
minv_v = self._operator.solve(self._v)
|
||||||
|
# V^T M^{-1} V
|
||||||
|
if batch_mode:
|
||||||
|
vt_minv_v = math_ops.batch_matmul(self._v, minv_v, adj_x=True)
|
||||||
|
else:
|
||||||
|
vt_minv_v = math_ops.matmul(self._v, minv_v, transpose_a=True)
|
||||||
|
|
||||||
|
# D^{-1} + V^T M^{-1} V
|
||||||
|
capacitance = self._diag_inv_operator.add_to_tensor(vt_minv_v)
|
||||||
|
# Cholesky[D^{-1} + V^T M^{-1} V]
|
||||||
|
if batch_mode:
|
||||||
|
return linalg_ops.batch_cholesky(capacitance)
|
||||||
|
else:
|
||||||
|
return linalg_ops.cholesky(capacitance)
|
||||||
|
|
||||||
|
def _sqrt_log_det_core(self, diag_chol_c):
|
||||||
|
"""Finish computation of Sqrt[Log[Det]]."""
|
||||||
|
# Complete computation of ._log_det and ._batch_log_det, after the initial
|
||||||
|
# Cholesky factor has been taken with the appropriate batch/non-batch method
|
||||||
|
|
||||||
|
# det(M + VDV^T) = det(D^{-1} + V^T M^{-1} V) * det(D) * det(M)
|
||||||
|
# = det(C) * det(D) * det(M)
|
||||||
|
# Multiply by 2 here because this is the log-det of the Cholesky factor of C
|
||||||
|
log_det_c = 2 * math_ops.reduce_sum(
|
||||||
|
math_ops.log(diag_chol_c),
|
||||||
|
reduction_indices=[-1])
|
||||||
|
# Add together to get Log[det(M + VDV^T)], the Log-det of the updated square
|
||||||
|
# root.
|
||||||
|
log_det_updated_sqrt = (
|
||||||
|
log_det_c + self._diag_operator.log_det() + self._operator.log_det())
|
||||||
|
return log_det_updated_sqrt
|
||||||
|
|
||||||
|
def _batch_matmul(self, x, transpose_x=False):
|
||||||
|
# Since the square root is PD, it is symmetric, and so A = SS^T = SS.
|
||||||
|
s_x = self._batch_sqrt_matmul(x, transpose_x=transpose_x)
|
||||||
|
return self._batch_sqrt_matmul(s_x)
|
||||||
|
|
||||||
|
def _matmul(self, x, transpose_x=False):
|
||||||
|
# Since the square root is PD, it is symmetric, and so A = SS^T = SS.
|
||||||
|
s_x = self._sqrt_matmul(x, transpose_x=transpose_x)
|
||||||
|
return self._sqrt_matmul(s_x)
|
||||||
|
|
||||||
|
def _batch_sqrt_matmul(self, x, transpose_x=False):
|
||||||
|
v = self._v
|
||||||
|
m = self._operator
|
||||||
|
d = self._diag_operator
|
||||||
|
# The operators call the appropriate matmul/batch_matmul automatically. We
|
||||||
|
# cannot override.
|
||||||
|
# batch_matmul is defined as: x * y, so adj_x and adj_y are the ways to
|
||||||
|
# transpose the left and right.
|
||||||
|
mx = m.matmul(x, transpose_x=transpose_x)
|
||||||
|
vt_x = math_ops.batch_matmul(v, x, adj_x=True, adj_y=transpose_x)
|
||||||
|
d_vt_x = d.matmul(vt_x)
|
||||||
|
v_d_vt_x = math_ops.batch_matmul(v, d_vt_x)
|
||||||
|
|
||||||
|
return mx + v_d_vt_x
|
||||||
|
|
||||||
|
def _sqrt_matmul(self, x, transpose_x=False):
|
||||||
|
v = self._v
|
||||||
|
m = self._operator
|
||||||
|
d = self._diag_operator
|
||||||
|
# The operators call the appropriate matmul/batch_matmul automatically. We
|
||||||
|
# cannot override.
|
||||||
|
# matmul is defined as: a * b, so transpose_a, transpose_b are used.
|
||||||
|
# transpose the left and right.
|
||||||
|
mx = m.matmul(x, transpose_x=transpose_x)
|
||||||
|
vt_x = math_ops.matmul(v, x, transpose_a=True, transpose_b=transpose_x)
|
||||||
|
d_vt_x = d.matmul(vt_x)
|
||||||
|
v_d_vt_x = math_ops.matmul(v, d_vt_x)
|
||||||
|
|
||||||
|
return mx + v_d_vt_x
|
||||||
|
|
||||||
|
def _solve(self, rhs):
|
||||||
|
# This operator represents A = SS^T, but S is symmetric, so A = SS,
|
||||||
|
# which means A^{-1} = S^{-1}S^{-2}
|
||||||
|
# S^{-1} rhs
|
||||||
|
sqrtinv_rhs = self._sqrt_solve(rhs)
|
||||||
|
return self._sqrt_solve(sqrtinv_rhs)
|
||||||
|
|
||||||
|
def _batch_solve(self, rhs):
|
||||||
|
sqrtinv_rhs = self._batch_sqrt_solve(rhs)
|
||||||
|
return self._batch_sqrt_solve(sqrtinv_rhs)
|
||||||
|
|
||||||
|
def _sqrt_solve(self, rhs):
|
||||||
|
# Recall the square root of this operator is M + VDV^T.
|
||||||
|
# The Woodbury formula gives:
|
||||||
|
# (M + VDV^T)^{-1}
|
||||||
|
# = M^{-1} - M^{-1} V (D^{-1} + V^T M^{-1} V)^{-1} V^T M^{-1}
|
||||||
|
# = M^{-1} - M^{-1} V C^{-1} V^T M^{-1}
|
||||||
|
# where C is the capacitance matrix.
|
||||||
|
# TODO(jvdillon) Determine if recursively applying rank-1 updates is more
|
||||||
|
# efficient. May not be possible because a general n x n matrix can be
|
||||||
|
# represeneted as n rank-1 updates, and solving with this matrix is always
|
||||||
|
# done in O(n^3) time.
|
||||||
|
m = self._operator
|
||||||
|
v = self._v
|
||||||
|
cchol = self._chol_capacitance(batch_mode=False)
|
||||||
|
|
||||||
|
# The operators will use batch/singleton mode automatically. We don't
|
||||||
|
# override.
|
||||||
|
# M^{-1} rhs
|
||||||
|
minv_rhs = m.solve(rhs)
|
||||||
|
# V^T M^{-1} rhs
|
||||||
|
vt_minv_rhs = math_ops.matmul(v, minv_rhs, transpose_a=True)
|
||||||
|
# C^{-1} V^T M^{-1} rhs
|
||||||
|
cinv_vt_minv_rhs = linalg_ops.cholesky_solve(cchol, vt_minv_rhs)
|
||||||
|
# V C^{-1} V^T M^{-1} rhs
|
||||||
|
v_cinv_vt_minv_rhs = math_ops.matmul(v, cinv_vt_minv_rhs)
|
||||||
|
# M^{-1} V C^{-1} V^T M^{-1} rhs
|
||||||
|
minv_v_cinv_vt_minv_rhs = m.solve(v_cinv_vt_minv_rhs)
|
||||||
|
|
||||||
|
# M^{-1} - M^{-1} V C^{-1} V^T M^{-1}
|
||||||
|
return minv_rhs - minv_v_cinv_vt_minv_rhs
|
||||||
|
|
||||||
|
def _batch_sqrt_solve(self, rhs):
|
||||||
|
# Recall the square root of this operator is M + VDV^T.
|
||||||
|
# The Woodbury formula gives:
|
||||||
|
# (M + VDV^T)^{-1}
|
||||||
|
# = M^{-1} - M^{-1} V (D^{-1} + V^T M^{-1} V)^{-1} V^T M^{-1}
|
||||||
|
# = M^{-1} - M^{-1} V C^{-1} V^T M^{-1}
|
||||||
|
# where C is the capacitance matrix.
|
||||||
|
m = self._operator
|
||||||
|
v = self._v
|
||||||
|
cchol = self._chol_capacitance(batch_mode=True)
|
||||||
|
|
||||||
|
# The operators will use batch/singleton mode automatically. We don't
|
||||||
|
# override.
|
||||||
|
# M^{-1} rhs
|
||||||
|
minv_rhs = m.solve(rhs)
|
||||||
|
# V^T M^{-1} rhs
|
||||||
|
vt_minv_rhs = math_ops.batch_matmul(v, minv_rhs, adj_x=True)
|
||||||
|
# C^{-1} V^T M^{-1} rhs
|
||||||
|
cinv_vt_minv_rhs = linalg_ops.batch_cholesky_solve(cchol, vt_minv_rhs)
|
||||||
|
# V C^{-1} V^T M^{-1} rhs
|
||||||
|
v_cinv_vt_minv_rhs = math_ops.batch_matmul(v, cinv_vt_minv_rhs)
|
||||||
|
# M^{-1} V C^{-1} V^T M^{-1} rhs
|
||||||
|
minv_v_cinv_vt_minv_rhs = m.solve(v_cinv_vt_minv_rhs)
|
||||||
|
|
||||||
|
# M^{-1} - M^{-1} V C^{-1} V^T M^{-1}
|
||||||
|
return minv_rhs - minv_v_cinv_vt_minv_rhs
|
||||||
|
|
||||||
|
def _to_dense(self):
|
||||||
|
sqrt = self.sqrt_to_dense()
|
||||||
|
return math_ops.batch_matmul(sqrt, sqrt, adj_y=True)
|
||||||
|
|
||||||
|
def _sqrt_to_dense(self):
|
||||||
|
v = self._v
|
||||||
|
d = self._diag_operator
|
||||||
|
m = self._operator
|
||||||
|
|
||||||
|
d_vt = d.matmul(v, transpose_x=True)
|
||||||
|
# Batch op won't be efficient for singletons. Currently we don't break
|
||||||
|
# to_dense into batch/singleton methods.
|
||||||
|
v_d_vt = math_ops.batch_matmul(v, d_vt)
|
||||||
|
m_plus_v_d_vt = m.to_dense() + v_d_vt
|
||||||
|
return m_plus_v_d_vt
|
396
tensorflow/contrib/distributions/python/ops/shape.py
Normal file
396
tensorflow/contrib/distributions/python/ops/shape.py
Normal file
@ -0,0 +1,396 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""A helper class for inferring Distribution shape."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
|
||||||
|
|
||||||
|
class _ShapeUtil(object):
|
||||||
|
"""Class which helps infer/identify subsets of tensor dimensions.
|
||||||
|
|
||||||
|
Terminology:
|
||||||
|
Recall that a `Tensor` has:
|
||||||
|
shape: sizes of tensor dimensions,
|
||||||
|
ndims: size of shape; number of tensor dimensions,
|
||||||
|
dims: indexes into shape; useful for transpose, reduce.
|
||||||
|
|
||||||
|
Tensors sampled from a `Distribution` can be partitioned by:
|
||||||
|
sample dims: indexes independent, identically distributed (iid) draws,
|
||||||
|
batch dims: indexes non-identical draws,
|
||||||
|
event dims: indexes coordinates of a single draw.
|
||||||
|
|
||||||
|
The sample, batch, and event dimensions constitute the entirety of a
|
||||||
|
`Tensor` shape. The dimensions are always in sample, batch, event order.
|
||||||
|
|
||||||
|
Assumptions:
|
||||||
|
We assume that batch_ndims and event_ndims are statically known for both
|
||||||
|
creating this object and for inputs to its functions.
|
||||||
|
TODO(jvdillon): Relax this assumption and support fully unknown shape.
|
||||||
|
|
||||||
|
We also assume that the `Tensor` rank is static, i.e., `x.get_shape().ndims
|
||||||
|
is not None`.
|
||||||
|
|
||||||
|
Possible use-cases:
|
||||||
|
~ Sample dimensions:
|
||||||
|
Computing summary statistics, i.e., the average is a reduction over sample
|
||||||
|
dimensions.
|
||||||
|
|
||||||
|
~ Batch dimensions:
|
||||||
|
Log-likelihood under model predicted location:
|
||||||
|
```python
|
||||||
|
mu = ... # vector of predictions, one for each covariate.
|
||||||
|
neg_log_likelihood = -tf.reduce_mean(
|
||||||
|
Normal(loc=mu, scale=1).log_pdf(x),
|
||||||
|
reduce_dims=[0])
|
||||||
|
```
|
||||||
|
|
||||||
|
Monte Carlo estimation of a marginal probability:
|
||||||
|
Average over batch dimensions where batch dimensions are associated with
|
||||||
|
random draws of a prior.
|
||||||
|
E.g., suppose we want to find the Monte Carlo estimate of the marginal
|
||||||
|
distribution of a Normal with a random Laplace location:
|
||||||
|
```
|
||||||
|
P(X=x) = integral P(X=x|y) P(Y=y) dy
|
||||||
|
~= 1/n sum_{i=1}^n P(X=x|y_i), y_i ~iid Laplace(0,1)
|
||||||
|
= tf.reduce_mean(Normal(loc=Laplace(0, 1).sample_n(n=1000),
|
||||||
|
scale=tf.ones([1000, 1])).pdf(x),
|
||||||
|
reduce_dims=[0])
|
||||||
|
```
|
||||||
|
|
||||||
|
The `Laplace` distribution generates a tensor of shape [1000, 1]. When fed
|
||||||
|
to a `Normal`, this is interpreted as 1000 different locations, i.e.,
|
||||||
|
1000 non-identical Normals. Therefore a single call to pdf(x) yields 1000
|
||||||
|
probabilities, one for every location. The average over this batch yields
|
||||||
|
the marginal.
|
||||||
|
|
||||||
|
~ Event dimensions:
|
||||||
|
Computing the determinant of the Jacobian of a function of a random
|
||||||
|
variable involves a reduction over event dimensions.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Write S, B, E for sample shape, batch shape, and event shape (resp.).
|
||||||
|
|
||||||
|
```python
|
||||||
|
x.get_shape() == S + B + E # For statically known x shape.
|
||||||
|
|
||||||
|
# 100 iid samples from one multivariate Normal with two
|
||||||
|
# degrees of freedom (DF).
|
||||||
|
mu = [0., 0]
|
||||||
|
sigma = [[1., 0],
|
||||||
|
[0, 1]]
|
||||||
|
X = MultivariateNormal(loc=mu, scale=sigma).sample_n(n=100)
|
||||||
|
# S = [100]
|
||||||
|
# B = []
|
||||||
|
# E = [2]
|
||||||
|
|
||||||
|
# 100 iid samples from one Wishart with 2x2 DF.
|
||||||
|
sigma = [[1., 0],
|
||||||
|
[0, 1]]
|
||||||
|
X = Wishart(scale=sigma).sample_n(n=100)
|
||||||
|
# S = [100]
|
||||||
|
# B = []
|
||||||
|
# E = [2, 2]
|
||||||
|
|
||||||
|
# 100 iid samples (with shape [2, 50]) from two, non-identical bivariate
|
||||||
|
# Normal distributions.
|
||||||
|
mu = ... # shape(2, 2)
|
||||||
|
sigma = ... # shape(2, 2, 2)
|
||||||
|
X = MultivariateNormal(loc=mu, scale=sigma).sample(shape=[2, 50])
|
||||||
|
# S = [2, 50]
|
||||||
|
# B = [2]
|
||||||
|
# E = [2]
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, batch_ndims=None, event_ndims=None, name='ShapeUtil'):
|
||||||
|
"""Construct ShapeUtil with known sample, batch, and/or event ndims.
|
||||||
|
|
||||||
|
Typically, batch_ndims and event_ndims are fixed throughout the lifetime of
|
||||||
|
a Distribution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_ndims: number of dims (rank) of the batch portion of indexes of a
|
||||||
|
`Tensor`. A "batch" is a non-identical distribution, i.e, Normal with
|
||||||
|
different parameters.
|
||||||
|
event_ndims: number of dims (rank) of the event portion of indexes of a
|
||||||
|
`Tensor`. An "event" is what is sampled from a distribution, i.e., a
|
||||||
|
trivariate Normal has an event shape of [3] and a 4 dimensional Wishart
|
||||||
|
has an event shape of [4, 4].
|
||||||
|
name: `String`. The name to give Ops created by this class.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if batch_ndims or event_ndims are invalid.
|
||||||
|
"""
|
||||||
|
if batch_ndims < 0:
|
||||||
|
raise ValueError('must specify non-negative batch_ndims(%d)', batch_ndims)
|
||||||
|
if batch_ndims > 0 and event_ndims < 1:
|
||||||
|
raise ValueError('must specify positive event_ndims(%d) when '
|
||||||
|
'batch_ndims(%d) is positive', event_ndims, batch_ndims)
|
||||||
|
# TODO(jvdillon): Support batches of scalars.
|
||||||
|
self._name = name
|
||||||
|
self._batch_ndims = batch_ndims
|
||||||
|
self._event_ndims = event_ndims
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
"""Name given to ops created by this class."""
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def batch_ndims(self):
|
||||||
|
"""Returns number of dimensions corresponding to non-identical draws."""
|
||||||
|
return self._batch_ndims
|
||||||
|
|
||||||
|
@property
|
||||||
|
def event_ndims(self):
|
||||||
|
"""Returns number of dimensions needed to index a sample's coordinates."""
|
||||||
|
return self._event_ndims
|
||||||
|
|
||||||
|
def get_ndims(self, x, name='get_ndims'):
|
||||||
|
"""Get tensor ndims (rank).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: `Tensor`.
|
||||||
|
name: `String`. The name to give this op.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if ndims is not statically known.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Scalar` number of dimensions associated with a `Tensor`.
|
||||||
|
"""
|
||||||
|
if x is None:
|
||||||
|
raise ValueError('Input was None which does not have known ndims.')
|
||||||
|
with ops.name_scope(self.name):
|
||||||
|
with ops.op_scope([x], name):
|
||||||
|
ndims = ops.convert_to_tensor(x).get_shape().ndims
|
||||||
|
if ndims is None:
|
||||||
|
raise ValueError('ShapeUtil assumes static number of '
|
||||||
|
'dimensions(%d)', ndims)
|
||||||
|
return ndims
|
||||||
|
|
||||||
|
def get_sample_ndims(self, x):
|
||||||
|
"""Returns number of dimensions corresponding to iid draws.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: `Tensor`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if batch_ndims or event_ndims are not statically known.
|
||||||
|
ValueError: if static sample_ndims does not match inferred
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Scalar number of dimensions associated with a sample.
|
||||||
|
"""
|
||||||
|
ndims = self.get_ndims(x)
|
||||||
|
sample_ndims = ndims - self.batch_ndims - self.event_ndims
|
||||||
|
if sample_ndims < 0:
|
||||||
|
raise ValueError('expected batch_ndims(%d) + event_ndims(%d) < ndims(%d)',
|
||||||
|
self.batch_ndims, self.event_ndims, ndims)
|
||||||
|
return sample_ndims
|
||||||
|
|
||||||
|
def get_dims(self, x, sample=True, batch=True, event=True):
|
||||||
|
"""Returns subset of tensor's dimension indexes (indexes into shape).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: `Tensor`.
|
||||||
|
sample: `Boolean`. Include sample dimensions or not.
|
||||||
|
batch: `Boolean`. Include batch dimensions or not.
|
||||||
|
event: `Boolean`. Include event dimensions or not.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if `x.get_shape().ndims` is `None`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List enumerating requested dimensions.
|
||||||
|
"""
|
||||||
|
ndims = self.get_ndims(x)
|
||||||
|
|
||||||
|
if sample and batch and event:
|
||||||
|
return list(range(ndims))
|
||||||
|
|
||||||
|
sample_start = 0
|
||||||
|
batch_start = self.get_sample_ndims(x)
|
||||||
|
event_start = batch_start + self.batch_ndims
|
||||||
|
|
||||||
|
sample_shape = list(range(sample_start, batch_start)) if sample else []
|
||||||
|
batch_shape = list(range(batch_start, event_start)) if batch else []
|
||||||
|
event_shape = list(range(event_start, ndims)) if event else []
|
||||||
|
|
||||||
|
return sample_shape + batch_shape + event_shape
|
||||||
|
|
||||||
|
def get_shape(self, x, sample=True, batch=True, event=True, name='get_shape'):
|
||||||
|
"""Returns subset of tensor's shape (size of dimensions).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: `Tensor`.
|
||||||
|
sample: `Boolean`. Include sample shape or not.
|
||||||
|
batch: `Boolean`. Include batch shape or not.
|
||||||
|
event: `Boolean`. Include event shape or not.
|
||||||
|
name: `String`. The name to give this op.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if `x.get_shape().ndims` is `None`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List describing event shape if known statically, `Tensor` otherwise.
|
||||||
|
"""
|
||||||
|
if not sample and not batch and not event:
|
||||||
|
return []
|
||||||
|
with ops.name_scope(self._name):
|
||||||
|
with ops.op_scope([x], name):
|
||||||
|
x = ops.convert_to_tensor(x)
|
||||||
|
shape = (x.get_shape().as_list()
|
||||||
|
if x.get_shape().is_fully_defined()
|
||||||
|
else array_ops.shape(x))
|
||||||
|
|
||||||
|
if sample and batch and event:
|
||||||
|
return shape
|
||||||
|
|
||||||
|
sample_start = 0
|
||||||
|
batch_start = self.get_sample_ndims(x)
|
||||||
|
event_start = batch_start + self.batch_ndims
|
||||||
|
|
||||||
|
sample_shape = shape[sample_start:batch_start] if sample else []
|
||||||
|
batch_shape = shape[batch_start:event_start] if batch else []
|
||||||
|
event_shape = shape[event_start:] if event else []
|
||||||
|
|
||||||
|
if not batch and not event:
|
||||||
|
return sample_shape
|
||||||
|
if not sample and not event:
|
||||||
|
return batch_shape
|
||||||
|
if not sample and not batch:
|
||||||
|
return event_shape
|
||||||
|
|
||||||
|
if x.get_shape().is_fully_defined():
|
||||||
|
return sample_shape + batch_shape + event_shape
|
||||||
|
else:
|
||||||
|
return array_ops.concat(0, [sample_shape, batch_shape, event_shape])
|
||||||
|
|
||||||
|
def get_sample_dims(self, x):
|
||||||
|
"""Returns dimension indexes corresponding to sample.
|
||||||
|
|
||||||
|
Convenience function; identical to:
|
||||||
|
|
||||||
|
```python
|
||||||
|
get_dims(x, sample=True, batch=False, event=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: `Tensor`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if `x.get_shape().ndims` is `None`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List enumerating sample dimensions.
|
||||||
|
"""
|
||||||
|
return self.get_dims(x, sample=True, batch=False, event=False)
|
||||||
|
|
||||||
|
def get_batch_dims(self, x):
|
||||||
|
"""Returns dimension indexes corresponding to batch.
|
||||||
|
|
||||||
|
Convenience function; identical to:
|
||||||
|
|
||||||
|
```python
|
||||||
|
get_dims(x, sample=False, batch=True, event=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: `Tensor`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if `x.get_shape().ndims` is `None`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List enumerating batch dimensions.
|
||||||
|
"""
|
||||||
|
return self.get_dims(x, sample=False, batch=True, event=False)
|
||||||
|
|
||||||
|
def get_event_dims(self, x):
|
||||||
|
"""Returns dimension indexes corresponding to event.
|
||||||
|
|
||||||
|
Convenience function; identical to:
|
||||||
|
|
||||||
|
```python
|
||||||
|
get_dims(x, sample=False, batch=False, event=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: `Tensor`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if `x.get_shape().ndims` is `None`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List enumerating event dimensions.
|
||||||
|
"""
|
||||||
|
return self.get_dims(x, sample=False, batch=False, event=True)
|
||||||
|
|
||||||
|
def get_sample_shape(self, x):
|
||||||
|
"""Returns shape corresponding to sample.
|
||||||
|
|
||||||
|
Convenience function; identical to:
|
||||||
|
|
||||||
|
```python
|
||||||
|
get_shape(x, sample=True, batch=False, event=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: `Tensor`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List describing sample shape if known statically, `Tensor` otherwise.
|
||||||
|
"""
|
||||||
|
return self.get_shape(x, sample=True, batch=False, event=False)
|
||||||
|
|
||||||
|
def get_batch_shape(self, x):
|
||||||
|
"""Returns shape corresponding to batch.
|
||||||
|
|
||||||
|
Convenience function; identical to:
|
||||||
|
|
||||||
|
```python
|
||||||
|
get_shape(x, sample=False, batch=True, event=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: `Tensor`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List describing batch shape if known statically, `Tensor` otherwise.
|
||||||
|
"""
|
||||||
|
return self.get_shape(x, sample=False, batch=True, event=False)
|
||||||
|
|
||||||
|
def get_event_shape(self, x):
|
||||||
|
"""Returns shape corresponding to event.
|
||||||
|
|
||||||
|
Convenience function; identical to:
|
||||||
|
|
||||||
|
```python
|
||||||
|
get_shape(x, sample=False, batch=False, event=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: `Tensor`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List describing event shape if known statically, `Tensor` otherwise.
|
||||||
|
"""
|
||||||
|
return self.get_shape(x, sample=False, batch=False, event=True)
|
@ -82,6 +82,7 @@ class StudentT(distribution.Distribution):
|
|||||||
# returning a length 2 tensor.
|
# returning a length 2 tensor.
|
||||||
dist.pdf(3.0)
|
dist.pdf(3.0)
|
||||||
```
|
```
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -99,19 +100,19 @@ class StudentT(distribution.Distribution):
|
|||||||
broadcasting (e.g. `df + mu + sigma` is a valid operation).
|
broadcasting (e.g. `df + mu + sigma` is a valid operation).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
df: `float` or `double` tensor, the degrees of freedom of the
|
df: Floating point tensor, the degrees of freedom of the
|
||||||
distribution(s). `df` must contain only positive values.
|
distribution(s). `df` must contain only positive values.
|
||||||
mu: `float` or `double` tensor, the means of the distribution(s).
|
mu: Floating point tensor, the means of the distribution(s).
|
||||||
sigma: `float` or `double` tensor, the scaling factor for the
|
sigma: Floating point tensor, the scaling factor for the
|
||||||
distribution(s). `sigma` must contain only positive values.
|
distribution(s). `sigma` must contain only positive values.
|
||||||
Note that `sigma` is not the standard deviation of this distribution.
|
Note that `sigma` is not the standard deviation of this distribution.
|
||||||
validate_args: Whether to assert that `df > 0, sigma > 0`. If
|
validate_args: Whether to assert that `df > 0, sigma > 0`. If
|
||||||
`validate_args` is False and inputs are invalid, correct behavior is not
|
`validate_args` is `False` and inputs are invalid, correct behavior is
|
||||||
guaranteed.
|
not guaranteed.
|
||||||
allow_nan_stats: Boolean, default False. If False, raise an exception if
|
allow_nan_stats: Boolean, default `False`. If `False`, raise an
|
||||||
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
|
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
|
||||||
If True, batch members with valid parameters leading to undefined
|
batch member. If `True`, batch members with valid parameters leading to
|
||||||
statistics will return NaN for this statistic.
|
undefined statistics will return NaN for this statistic.
|
||||||
name: The name to give Ops created by the initializer.
|
name: The name to give Ops created by the initializer.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@ -185,9 +186,12 @@ class StudentT(distribution.Distribution):
|
|||||||
nan = np.nan + self._zeros()
|
nan = np.nan + self._zeros()
|
||||||
return math_ops.select(df_gt_1, result_if_defined, nan)
|
return math_ops.select(df_gt_1, result_if_defined, nan)
|
||||||
else:
|
else:
|
||||||
one = ops.convert_to_tensor(1.0, dtype=self.dtype)
|
one = constant_op.constant(1.0, dtype=self.dtype)
|
||||||
return control_flow_ops.with_dependencies(
|
return control_flow_ops.with_dependencies(
|
||||||
[check_ops.assert_less(one, self._df)], result_if_defined)
|
[check_ops.assert_less(
|
||||||
|
one, self._df,
|
||||||
|
message="mean not defined for components of df <= 1"
|
||||||
|
)], result_if_defined)
|
||||||
|
|
||||||
def mode(self, name="mode"):
|
def mode(self, name="mode"):
|
||||||
with ops.name_scope(self.name):
|
with ops.name_scope(self.name):
|
||||||
@ -232,9 +236,12 @@ class StudentT(distribution.Distribution):
|
|||||||
result_where_defined,
|
result_where_defined,
|
||||||
self._zeros() + np.nan)
|
self._zeros() + np.nan)
|
||||||
else:
|
else:
|
||||||
one = ops.convert_to_tensor(1.0, self.dtype)
|
one = constant_op.constant(1.0, dtype=self.dtype)
|
||||||
return control_flow_ops.with_dependencies(
|
return control_flow_ops.with_dependencies(
|
||||||
[check_ops.assert_less(one, self._df)], result_where_defined)
|
[check_ops.assert_less(
|
||||||
|
one, self._df,
|
||||||
|
message="variance not defined for components of df <= 1"
|
||||||
|
)], result_where_defined)
|
||||||
|
|
||||||
def std(self, name="std"):
|
def std(self, name="std"):
|
||||||
with ops.name_scope(self.name):
|
with ops.name_scope(self.name):
|
||||||
@ -348,8 +355,7 @@ class StudentT(distribution.Distribution):
|
|||||||
# Let X = R*cos(theta), and let Y = R*sin(theta).
|
# Let X = R*cos(theta), and let Y = R*sin(theta).
|
||||||
# Then X ~ t_df and Y ~ t_df.
|
# Then X ~ t_df and Y ~ t_df.
|
||||||
# The variates X and Y are not independent.
|
# The variates X and Y are not independent.
|
||||||
shape = array_ops.concat(0, [array_ops.pack([2, n]),
|
shape = array_ops.concat(0, ([2, n], self.batch_shape()))
|
||||||
self.batch_shape()])
|
|
||||||
uniform = random_ops.random_uniform(shape=shape,
|
uniform = random_ops.random_uniform(shape=shape,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
seed=seed)
|
seed=seed)
|
||||||
|
@ -57,6 +57,7 @@ class TransformedDistribution(distribution.Distribution):
|
|||||||
name="LogitNormalTransformedDistribution"
|
name="LogitNormalTransformedDistribution"
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
@ -67,14 +67,14 @@ class Uniform(distribution.Distribution):
|
|||||||
```
|
```
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
a: `float` or `double` tensor, the minimum endpoint.
|
a: Floating point tensor, the minimum endpoint.
|
||||||
b: `float` or `double` tensor, the maximum endpoint. Must be > `a`.
|
b: Floating point tensor, the maximum endpoint. Must be > `a`.
|
||||||
validate_args: Whether to assert that `a > b`. If `validate_args` is False
|
validate_args: Whether to assert that `a > b`. If `validate_args` is
|
||||||
and inputs are invalid, correct behavior is not guaranteed.
|
`False` and inputs are invalid, correct behavior is not guaranteed.
|
||||||
allow_nan_stats: Boolean, default False. If False, raise an exception if
|
allow_nan_stats: Boolean, default `False`. If `False`, raise an
|
||||||
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
|
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
|
||||||
If True, batch members with valid parameters leading to undefined
|
batch member. If `True`, batch members with valid parameters leading to
|
||||||
statistics will return NaN for this statistic.
|
undefined statistics will return NaN for this statistic.
|
||||||
name: The name to prefix Ops created by this distribution class.
|
name: The name to prefix Ops created by this distribution class.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@ -83,8 +83,9 @@ class Uniform(distribution.Distribution):
|
|||||||
self._allow_nan_stats = allow_nan_stats
|
self._allow_nan_stats = allow_nan_stats
|
||||||
self._validate_args = validate_args
|
self._validate_args = validate_args
|
||||||
with ops.op_scope([a, b], name):
|
with ops.op_scope([a, b], name):
|
||||||
with ops.control_dependencies([check_ops.assert_less(a, b)] if
|
with ops.control_dependencies([check_ops.assert_less(
|
||||||
validate_args else []):
|
a, b, message="uniform not defined when a > b.")] if validate_args
|
||||||
|
else []):
|
||||||
a = array_ops.identity(a, name="a")
|
a = array_ops.identity(a, name="a")
|
||||||
b = array_ops.identity(b, name="b")
|
b = array_ops.identity(b, name="b")
|
||||||
|
|
||||||
@ -228,7 +229,7 @@ class Uniform(distribution.Distribution):
|
|||||||
n = ops.convert_to_tensor(n, name="n")
|
n = ops.convert_to_tensor(n, name="n")
|
||||||
n_val = tensor_util.constant_value(n)
|
n_val = tensor_util.constant_value(n)
|
||||||
|
|
||||||
shape = array_ops.concat(0, [array_ops.pack([n]), self.batch_shape()])
|
shape = array_ops.concat(0, ([n], self.batch_shape()))
|
||||||
samples = random_ops.random_uniform(shape=shape,
|
samples = random_ops.random_uniform(shape=shape,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
seed=seed)
|
seed=seed)
|
||||||
|
@ -94,6 +94,30 @@ tf_py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_py_test(
|
||||||
|
name = "gmm_test",
|
||||||
|
srcs = [
|
||||||
|
"python/ops/gmm_test.py",
|
||||||
|
],
|
||||||
|
additional_deps = [
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"//tensorflow/python:platform_test",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_py_test(
|
||||||
|
name = "gmm_ops_test",
|
||||||
|
srcs = [
|
||||||
|
"python/ops/gmm_ops_test.py",
|
||||||
|
],
|
||||||
|
additional_deps = [
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"//tensorflow/python:platform_test",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
name = "factorization_ops_test",
|
name = "factorization_ops_test",
|
||||||
srcs = ["python/ops/factorization_ops_test.py"],
|
srcs = ["python/ops/factorization_ops_test.py"],
|
||||||
|
@ -304,7 +304,7 @@ class WalsModelTest(tf.test.TestCase):
|
|||||||
col_factors2 = [x.eval() for x in wals_model.col_factors]
|
col_factors2 = [x.eval() for x in wals_model.col_factors]
|
||||||
|
|
||||||
for c1, c2 in zip(col_factors1, col_factors2):
|
for c1, c2 in zip(col_factors1, col_factors2):
|
||||||
self.assertAllClose(c1, c2, atol=1e-3)
|
self.assertAllClose(c1, c2, rtol=5e-3, atol=1e-2)
|
||||||
|
|
||||||
def test_als_transposed(self):
|
def test_als_transposed(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
@ -383,7 +383,7 @@ class WalsModelTest(tf.test.TestCase):
|
|||||||
regularization=1e-5,
|
regularization=1e-5,
|
||||||
row_weights=None,
|
row_weights=None,
|
||||||
col_weights=None)
|
col_weights=None)
|
||||||
self.simple_train(model, inp, 15)
|
self.simple_train(model, inp, 25)
|
||||||
row_factor = model.row_factors[0].eval()
|
row_factor = model.row_factors[0].eval()
|
||||||
col_factor = model.col_factors[0].eval()
|
col_factor = model.col_factors[0].eval()
|
||||||
self.assertAllClose(data,
|
self.assertAllClose(data,
|
||||||
@ -407,7 +407,7 @@ class WalsModelTest(tf.test.TestCase):
|
|||||||
regularization=1e-5,
|
regularization=1e-5,
|
||||||
row_weights=[0] * rows,
|
row_weights=[0] * rows,
|
||||||
col_weights=[0] * cols)
|
col_weights=[0] * cols)
|
||||||
self.simple_train(model, inp, 15)
|
self.simple_train(model, inp, 25)
|
||||||
row_factor = model.row_factors[0].eval()
|
row_factor = model.row_factors[0].eval()
|
||||||
col_factor = model.col_factors[0].eval()
|
col_factor = model.col_factors[0].eval()
|
||||||
self.assertAllClose(data,
|
self.assertAllClose(data,
|
||||||
@ -438,7 +438,7 @@ class WalsModelTest(tf.test.TestCase):
|
|||||||
regularization=0.001,
|
regularization=0.001,
|
||||||
row_weights=row_wts,
|
row_weights=row_wts,
|
||||||
col_weights=col_wts)
|
col_weights=col_wts)
|
||||||
self.simple_train(model, inp, 10)
|
self.simple_train(model, inp, 25)
|
||||||
row_factor = model.row_factors[0].eval()
|
row_factor = model.row_factors[0].eval()
|
||||||
col_factor = model.col_factors[0].eval()
|
col_factor = model.col_factors[0].eval()
|
||||||
out = np.dot(row_factor, np.transpose(col_factor))
|
out = np.dot(row_factor, np.transpose(col_factor))
|
||||||
@ -446,7 +446,7 @@ class WalsModelTest(tf.test.TestCase):
|
|||||||
for j in xrange(cols):
|
for j in xrange(cols):
|
||||||
if keep_index([i, j]):
|
if keep_index([i, j]):
|
||||||
self.assertNear(data[i][j], out[i][j],
|
self.assertNear(data[i][j], out[i][j],
|
||||||
err=0.2, msg="%d, %d" % (i, j))
|
err=0.4, msg="%d, %d" % (i, j))
|
||||||
else:
|
else:
|
||||||
self.assertNear(0, out[i][j], err=0.5, msg="%d, %d" % (i, j))
|
self.assertNear(0, out[i][j], err=0.5, msg="%d, %d" % (i, j))
|
||||||
|
|
||||||
|
211
tensorflow/contrib/factorization/python/ops/gmm.py
Normal file
211
tensorflow/contrib/factorization/python/ops/gmm.py
Normal file
@ -0,0 +1,211 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Implementation of Gaussian mixture model (GMM) clustering.
|
||||||
|
|
||||||
|
This goes on top of skflow API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow.contrib.factorization.python.ops import gmm_ops
|
||||||
|
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||||
|
from tensorflow.contrib.learn.python.learn.estimators._sklearn import TransformerMixin
|
||||||
|
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder
|
||||||
|
from tensorflow.contrib.learn.python.learn.utils import checkpoints
|
||||||
|
from tensorflow.python.ops.control_flow_ops import with_dependencies
|
||||||
|
|
||||||
|
|
||||||
|
class GMM(estimator.Estimator, TransformerMixin):
|
||||||
|
"""GMM clustering."""
|
||||||
|
SCORES = 'scores'
|
||||||
|
ASSIGNMENTS = 'assignments'
|
||||||
|
ALL_SCORES = 'all_scores'
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
num_clusters,
|
||||||
|
model_dir=None,
|
||||||
|
random_seed=0,
|
||||||
|
params='wmc',
|
||||||
|
initial_clusters='random',
|
||||||
|
covariance_type='full',
|
||||||
|
batch_size=128,
|
||||||
|
steps=10,
|
||||||
|
continue_training=False,
|
||||||
|
config=None,
|
||||||
|
verbose=1):
|
||||||
|
"""Creates a model for running GMM training and inference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_clusters: number of clusters to train.
|
||||||
|
model_dir: the directory to save the model results and log files.
|
||||||
|
random_seed: Python integer. Seed for PRNG used to initialize centers.
|
||||||
|
params: Controls which parameters are updated in the training process.
|
||||||
|
Can contain any combination of "w" for weights, "m" for means,
|
||||||
|
and "c" for covars.
|
||||||
|
initial_clusters: specifies how to initialize the clusters for training.
|
||||||
|
See gmm_ops.gmm for the possible values.
|
||||||
|
covariance_type: one of "full", "diag".
|
||||||
|
batch_size: See TensorFlowEstimator
|
||||||
|
steps: See TensorFlowEstimator
|
||||||
|
continue_training: See TensorFlowEstimator
|
||||||
|
config: See TensorFlowEstimator
|
||||||
|
verbose: See TensorFlowEstimator
|
||||||
|
"""
|
||||||
|
super(GMM, self).__init__(
|
||||||
|
model_dir=model_dir,
|
||||||
|
config=config)
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.steps = steps
|
||||||
|
self.continue_training = continue_training
|
||||||
|
self.verbose = verbose
|
||||||
|
self._num_clusters = num_clusters
|
||||||
|
self._params = params
|
||||||
|
self._training_initial_clusters = initial_clusters
|
||||||
|
self._covariance_type = covariance_type
|
||||||
|
self._training_graph = None
|
||||||
|
self._random_seed = random_seed
|
||||||
|
|
||||||
|
def fit(self, x, y=None, monitors=None, logdir=None, steps=None):
|
||||||
|
"""Trains a GMM clustering on x.
|
||||||
|
|
||||||
|
Note: See TensorFlowEstimator for logic for continuous training and graph
|
||||||
|
construction across multiple calls to fit.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: training input matrix of shape [n_samples, n_features].
|
||||||
|
y: labels. Should be None.
|
||||||
|
monitors: List of `Monitor` objects to print training progress and
|
||||||
|
invoke early stopping.
|
||||||
|
logdir: the directory to save the log file that can be used for optional
|
||||||
|
visualization.
|
||||||
|
steps: number of training steps. If not None, overrides the value passed
|
||||||
|
in constructor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Returns self.
|
||||||
|
"""
|
||||||
|
if logdir is not None:
|
||||||
|
self._model_dir = logdir
|
||||||
|
self._data_feeder = data_feeder.setup_train_data_feeder(
|
||||||
|
x, None, self._num_clusters, self.batch_size)
|
||||||
|
self._train_model(input_fn=self._data_feeder.input_builder,
|
||||||
|
feed_fn=self._data_feeder.get_feed_dict_fn(),
|
||||||
|
steps=steps or self.steps,
|
||||||
|
monitors=monitors,
|
||||||
|
init_feed_fn=self._data_feeder.get_feed_dict_fn())
|
||||||
|
return self
|
||||||
|
|
||||||
|
def predict(self, x, batch_size=None):
|
||||||
|
"""Predict cluster id for each element in x.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: 2-D matrix or iterator.
|
||||||
|
batch_size: size to use for batching up x for querying the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Array with same number of rows as x, containing cluster ids.
|
||||||
|
"""
|
||||||
|
return super(GMM, self).predict(x=x, batch_size=batch_size)[GMM.ASSIGNMENTS]
|
||||||
|
|
||||||
|
def score(self, x, batch_size=None):
|
||||||
|
"""Predict total sum of distances to nearest clusters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: 2-D matrix or iterator.
|
||||||
|
batch_size: size to use for batching up x for querying the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total score.
|
||||||
|
"""
|
||||||
|
return np.sum(self.evaluate(x=x, batch_size=batch_size)[GMM.SCORES])
|
||||||
|
|
||||||
|
def transform(self, x, batch_size=None):
|
||||||
|
"""Transforms each element in x to distances to cluster centers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: 2-D matrix or iterator.
|
||||||
|
batch_size: size to use for batching up x for querying the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Array with same number of rows as x, and num_clusters columns, containing
|
||||||
|
distances to the cluster centers.
|
||||||
|
"""
|
||||||
|
return super(GMM, self).predict(x=x, batch_size=batch_size)[GMM.ALL_SCORES]
|
||||||
|
|
||||||
|
def clusters(self):
|
||||||
|
"""Returns cluster centers."""
|
||||||
|
clusters = checkpoints.load_variable(self.model_dir,
|
||||||
|
gmm_ops.GmmAlgorithm.CLUSTERS_VARIABLE)
|
||||||
|
return np.squeeze(clusters, 1)
|
||||||
|
|
||||||
|
def covariances(self):
|
||||||
|
"""Returns the covariances."""
|
||||||
|
return checkpoints.load_variable(
|
||||||
|
self.model_dir,
|
||||||
|
gmm_ops.GmmAlgorithm.CLUSTERS_COVS_VARIABLE)
|
||||||
|
|
||||||
|
def _get_train_ops(self, features, _):
|
||||||
|
(_,
|
||||||
|
_,
|
||||||
|
losses,
|
||||||
|
training_op) = gmm_ops.gmm(
|
||||||
|
features,
|
||||||
|
self._training_initial_clusters,
|
||||||
|
self._num_clusters,
|
||||||
|
self._random_seed,
|
||||||
|
self._covariance_type,
|
||||||
|
self._params)
|
||||||
|
incr_step = tf.assign_add(tf.contrib.framework.get_global_step(), 1)
|
||||||
|
loss = tf.reduce_sum(losses)
|
||||||
|
training_op = with_dependencies([training_op, incr_step], loss)
|
||||||
|
return training_op, loss
|
||||||
|
|
||||||
|
def _get_predict_ops(self, features):
|
||||||
|
(all_scores,
|
||||||
|
model_predictions,
|
||||||
|
_,
|
||||||
|
_) = gmm_ops.gmm(
|
||||||
|
features,
|
||||||
|
self._training_initial_clusters,
|
||||||
|
self._num_clusters,
|
||||||
|
self._random_seed,
|
||||||
|
self._covariance_type,
|
||||||
|
self._params)
|
||||||
|
return {
|
||||||
|
GMM.ALL_SCORES: all_scores[0],
|
||||||
|
GMM.ASSIGNMENTS: model_predictions[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
def _get_eval_ops(self, features, _, unused_metrics):
|
||||||
|
(_,
|
||||||
|
_,
|
||||||
|
losses,
|
||||||
|
_) = gmm_ops.gmm(
|
||||||
|
features,
|
||||||
|
self._training_initial_clusters,
|
||||||
|
self._num_clusters,
|
||||||
|
self._random_seed,
|
||||||
|
self._covariance_type,
|
||||||
|
self._params)
|
||||||
|
return {
|
||||||
|
GMM.SCORES: tf.reduce_sum(losses),
|
||||||
|
}
|
461
tensorflow/contrib/factorization/python/ops/gmm_ops.py
Normal file
461
tensorflow/contrib/factorization/python/ops/gmm_ops.py
Normal file
@ -0,0 +1,461 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Gaussian mixture models Operations."""
|
||||||
|
# TODO(xavigonzalvo): Factor out covariance matrix operations to make
|
||||||
|
# code reusable for different types (e.g. diag).
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow.python.ops.embedding_ops import embedding_lookup
|
||||||
|
|
||||||
|
# Machine epsilon.
|
||||||
|
MEPS = np.finfo(float).eps
|
||||||
|
FULL_COVARIANCE = 'full'
|
||||||
|
DIAG_COVARIANCE = 'diag'
|
||||||
|
|
||||||
|
|
||||||
|
def _covariance(x, diag):
|
||||||
|
"""Defines the covariance operation of a matrix.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: a matrix Tensor. Dimension 0 should contain the number of examples.
|
||||||
|
diag: if True, it computes the diagonal covariance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Tensor representing the covariance of x. In the case of
|
||||||
|
diagonal matrix just the diagonal is returned.
|
||||||
|
"""
|
||||||
|
num_points = tf.to_float(tf.shape(x)[0])
|
||||||
|
x -= tf.reduce_mean(x, 0, keep_dims=True)
|
||||||
|
if diag:
|
||||||
|
cov = tf.reduce_sum(
|
||||||
|
tf.square(x), 0, keep_dims=True) / (num_points - 1)
|
||||||
|
else:
|
||||||
|
cov = tf.matmul(x, x, transpose_a=True) / (num_points - 1)
|
||||||
|
return cov
|
||||||
|
|
||||||
|
|
||||||
|
def _init_clusters_random(data, num_clusters, random_seed):
|
||||||
|
"""Does random initialization of clusters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: a list of Tensors with a matrix of data, each row is an example.
|
||||||
|
num_clusters: an integer with the number of clusters.
|
||||||
|
random_seed: Seed for PRNG used to initialize seeds.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Tensor with num_clusters random rows of data.
|
||||||
|
"""
|
||||||
|
assert isinstance(data, list)
|
||||||
|
num_data = tf.add_n([tf.shape(inp)[0] for inp in data])
|
||||||
|
with tf.control_dependencies([tf.assert_less_equal(num_clusters, num_data)]):
|
||||||
|
indices = tf.random_uniform([num_clusters],
|
||||||
|
minval=0,
|
||||||
|
maxval=tf.cast(num_data, tf.int64),
|
||||||
|
seed=random_seed,
|
||||||
|
dtype=tf.int64)
|
||||||
|
indices = tf.cast(indices, tf.int32) % num_data
|
||||||
|
clusters_init = embedding_lookup(data, indices, partition_strategy='div')
|
||||||
|
return clusters_init
|
||||||
|
|
||||||
|
|
||||||
|
class GmmAlgorithm(object):
|
||||||
|
"""Tensorflow Gaussian mixture model clustering class."""
|
||||||
|
CLUSTERS_VARIABLE = 'clusters'
|
||||||
|
CLUSTERS_COVS_VARIABLE = 'clusters_covs'
|
||||||
|
|
||||||
|
def __init__(self, data, num_classes, initial_means=None, params='wmc',
|
||||||
|
covariance_type=FULL_COVARIANCE, random_seed=0):
|
||||||
|
"""Constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: a list of Tensors with data, each row is a new example.
|
||||||
|
num_classes: number of clusters.
|
||||||
|
initial_means: a Tensor with a matrix of means. If None, means are
|
||||||
|
computed by sampling randomly.
|
||||||
|
params: Controls which parameters are updated in the training
|
||||||
|
process. Can contain any combination of "w" for weights, "m" for
|
||||||
|
means, and "c" for covariances.
|
||||||
|
covariance_type: one of "full", "diag".
|
||||||
|
random_seed: Seed for PRNG used to initialize seeds.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception if covariance type is unknown.
|
||||||
|
"""
|
||||||
|
self._params = params
|
||||||
|
self._random_seed = random_seed
|
||||||
|
self._covariance_type = covariance_type
|
||||||
|
if self._covariance_type not in [DIAG_COVARIANCE, FULL_COVARIANCE]:
|
||||||
|
raise Exception( # pylint: disable=g-doc-exception
|
||||||
|
'programmer error: Invalid covariance type: %s' %
|
||||||
|
self._covariance_type)
|
||||||
|
# Create sharded variables for multiple shards. The following
|
||||||
|
# lists are indexed by shard.
|
||||||
|
# Probability per example in a class.
|
||||||
|
num_shards = len(data)
|
||||||
|
self._probs = [None] * num_shards
|
||||||
|
# Prior probability.
|
||||||
|
self._prior_probs = [None] * num_shards
|
||||||
|
# Membership weights w_{ik} where "i" is the i-th example and "k"
|
||||||
|
# is the k-th mixture.
|
||||||
|
self._w = [None] * num_shards
|
||||||
|
# Number of examples in a class.
|
||||||
|
self._points_in_k = [None] * num_shards
|
||||||
|
first_shard = data[0]
|
||||||
|
self._dimensions = tf.shape(first_shard)[1]
|
||||||
|
self._num_classes = num_classes
|
||||||
|
# Small value to guarantee that covariances are invertible.
|
||||||
|
self._min_var = tf.diag(tf.ones(tf.pack([self._dimensions]))) * 1e-3
|
||||||
|
self._create_variables(data, initial_means)
|
||||||
|
# Operations of partial statistics for the computation of the means.
|
||||||
|
self._w_mul_x = []
|
||||||
|
# Operations of partial statistics for the computation of the covariances.
|
||||||
|
self._w_mul_x2 = []
|
||||||
|
self._define_graph(data)
|
||||||
|
|
||||||
|
def _create_variables(self, data, initial_means=None):
|
||||||
|
"""Initializes GMM algorithm.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: a list of Tensors with data, each row is a new example.
|
||||||
|
initial_means: a Tensor with a matrix of means.
|
||||||
|
"""
|
||||||
|
first_shard = data[0]
|
||||||
|
# Initialize means: num_classes X 1 X dimensions.
|
||||||
|
if initial_means is not None:
|
||||||
|
self._means = tf.Variable(tf.expand_dims(initial_means, 1),
|
||||||
|
name=self.CLUSTERS_VARIABLE,
|
||||||
|
validate_shape=False, dtype=tf.float32)
|
||||||
|
else:
|
||||||
|
# Sample data randomly
|
||||||
|
self._means = tf.Variable(tf.expand_dims(
|
||||||
|
_init_clusters_random(data, self._num_classes, self._random_seed), 1),
|
||||||
|
name=self.CLUSTERS_VARIABLE,
|
||||||
|
validate_shape=False)
|
||||||
|
|
||||||
|
# Initialize covariances.
|
||||||
|
if self._covariance_type == FULL_COVARIANCE:
|
||||||
|
cov = _covariance(first_shard, False) + self._min_var
|
||||||
|
# A matrix per class, num_classes X dimensions X dimensions
|
||||||
|
covs = tf.tile(
|
||||||
|
tf.expand_dims(cov, 0), [self._num_classes, 1, 1])
|
||||||
|
elif self._covariance_type == DIAG_COVARIANCE:
|
||||||
|
cov = _covariance(first_shard, True) + self._min_var
|
||||||
|
# A diagonal per row, num_classes X dimensions.
|
||||||
|
covs = tf.tile(tf.expand_dims(tf.diag_part(cov), 0),
|
||||||
|
[self._num_classes, 1])
|
||||||
|
self._covs = tf.Variable(covs, name='clusters_covs', validate_shape=False)
|
||||||
|
# Mixture weights, representing the probability that a randomly
|
||||||
|
# selected unobservable data (in EM terms) was generated by component k.
|
||||||
|
self._alpha = tf.Variable(tf.tile([1.0 / self._num_classes],
|
||||||
|
[self._num_classes]))
|
||||||
|
|
||||||
|
def training_ops(self):
|
||||||
|
"""Returns the training operation."""
|
||||||
|
return self._train_ops
|
||||||
|
|
||||||
|
def alphas(self):
|
||||||
|
return self._alpha
|
||||||
|
|
||||||
|
def clusters(self):
|
||||||
|
"""Returns the clusters with dimensions num_classes X 1 X num_dimensions."""
|
||||||
|
return self._means
|
||||||
|
|
||||||
|
def covariances(self):
|
||||||
|
"""Returns the covariances matrices."""
|
||||||
|
return self._covs
|
||||||
|
|
||||||
|
def assignments(self):
|
||||||
|
"""Returns a list of Tensors with the matrix of assignments per shard."""
|
||||||
|
ret = []
|
||||||
|
for w in self._w:
|
||||||
|
ret.append(tf.argmax(w, 1))
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def scores(self):
|
||||||
|
"""Returns the distances to each class.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple with two Tensors. The first contains the distance to
|
||||||
|
each class. The second contains the distance to the assigned
|
||||||
|
class.
|
||||||
|
"""
|
||||||
|
return (self._all_scores, self._scores)
|
||||||
|
|
||||||
|
def _define_graph(self, data):
|
||||||
|
"""Define graph for a single iteration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: a list of Tensors defining the training data.
|
||||||
|
"""
|
||||||
|
for shard_id, shard in enumerate(data):
|
||||||
|
self._num_examples = tf.shape(shard)[0]
|
||||||
|
shard = tf.expand_dims(shard, 0)
|
||||||
|
self._define_log_prob_operation(shard_id, shard)
|
||||||
|
self._define_prior_log_prob_operation(shard_id)
|
||||||
|
self._define_expectation_operation(shard_id)
|
||||||
|
self._define_partial_maximization_operation(shard_id, shard)
|
||||||
|
self._define_maximization_operation(len(data))
|
||||||
|
self._define_distance_to_clusters(data)
|
||||||
|
|
||||||
|
def _define_full_covariance_probs(self, shard_id, shard):
|
||||||
|
"""Defines the full covariance probabilties per example in a class.
|
||||||
|
|
||||||
|
Updates a matrix with dimension num_examples X num_classes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shard_id: id of the current shard.
|
||||||
|
shard: current data shard, 1 X num_examples X dimensions.
|
||||||
|
"""
|
||||||
|
diff = shard - self._means
|
||||||
|
cholesky = tf.batch_cholesky(self._covs + self._min_var)
|
||||||
|
log_det_covs = 2.0 * tf.reduce_sum(tf.log(
|
||||||
|
tf.batch_matrix_diag_part(cholesky)), 1)
|
||||||
|
x_mu_cov = tf.square(tf.batch_matrix_triangular_solve(
|
||||||
|
cholesky, tf.transpose(diff, perm=[0, 2, 1]),
|
||||||
|
lower=True))
|
||||||
|
diag_m = tf.transpose(tf.reduce_sum(x_mu_cov, 1))
|
||||||
|
self._probs[shard_id] = -0.5 * (
|
||||||
|
diag_m + tf.to_float(self._dimensions) * tf.log(2 * np.pi) +
|
||||||
|
log_det_covs)
|
||||||
|
|
||||||
|
def _define_diag_covariance_probs(self, shard_id, shard):
|
||||||
|
"""Defines the diagonal covariance probabilities per example in a class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shard_id: id of the current shard.
|
||||||
|
shard: current data shard, 1 X num_examples X dimensions.
|
||||||
|
|
||||||
|
Returns a matrix num_examples * num_classes.
|
||||||
|
"""
|
||||||
|
# num_classes X 1
|
||||||
|
# TODO(xavigonzalvo): look into alternatives to log for
|
||||||
|
# reparametrization of variance parameters.
|
||||||
|
det_expanded = tf.reduce_sum(tf.log(self._covs + 1e-3),
|
||||||
|
1, keep_dims=True)
|
||||||
|
diff = shard - self._means
|
||||||
|
x2 = tf.square(diff)
|
||||||
|
cov_expanded = tf.expand_dims(1.0 / (self._covs + 1e-3), 2)
|
||||||
|
# num_classes X num_examples
|
||||||
|
x2_cov = tf.batch_matmul(x2, cov_expanded)
|
||||||
|
x2_cov = tf.transpose(tf.squeeze(x2_cov, [2]))
|
||||||
|
self._probs[shard_id] = -0.5 * (
|
||||||
|
tf.to_float(self._dimensions) * tf.log(2.0 * np.pi) +
|
||||||
|
tf.transpose(det_expanded) + x2_cov)
|
||||||
|
|
||||||
|
def _define_log_prob_operation(self, shard_id, shard):
|
||||||
|
"""Probability per example in a class.
|
||||||
|
|
||||||
|
Updates a matrix with dimension num_examples X num_classes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shard_id: id of the current shard.
|
||||||
|
shard: current data shard, 1 X num_examples X dimensions.
|
||||||
|
"""
|
||||||
|
# TODO(xavigonzalvo): Use the pdf defined in
|
||||||
|
# third_party/tensorflow/contrib/distributions/python/ops/gaussian.py
|
||||||
|
if self._covariance_type == FULL_COVARIANCE:
|
||||||
|
self._define_full_covariance_probs(shard_id, shard)
|
||||||
|
elif self._covariance_type == DIAG_COVARIANCE:
|
||||||
|
self._define_diag_covariance_probs(shard_id, shard)
|
||||||
|
self._probs[shard_id] += tf.log(self._alpha)
|
||||||
|
|
||||||
|
def _define_prior_log_prob_operation(self, shard_id):
|
||||||
|
"""Computes the prior probability of all samples.
|
||||||
|
|
||||||
|
Updates a vector where each item is the prior probabibility of an
|
||||||
|
input example.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shard_id: id of current shard_id.
|
||||||
|
"""
|
||||||
|
self._prior_probs[shard_id] = tf.log(
|
||||||
|
tf.reduce_sum(tf.exp(self._probs[shard_id]), 1, keep_dims=True))
|
||||||
|
|
||||||
|
def _define_expectation_operation(self, shard_id):
|
||||||
|
# Shape broadcasting.
|
||||||
|
probs = tf.expand_dims(self._probs[shard_id], 0)
|
||||||
|
# Membership weights are computed as:
|
||||||
|
# w_{ik} = \frac{\alpha_k f(\mathbf{y_i}|\mathbf{\theta}_k)}
|
||||||
|
# {\sum_{m=1}^{K}\alpha_mf(\mathbf{y_i}|\mathbf{\theta}_m)}
|
||||||
|
# where "i" is the i-th example, "k" is the k-th mixture, theta are
|
||||||
|
# the model parameters and y_i the observations.
|
||||||
|
# These are defined for each shard.
|
||||||
|
self._w[shard_id] = tf.reshape(
|
||||||
|
tf.exp(probs - self._prior_probs[shard_id]),
|
||||||
|
tf.pack([self._num_examples, self._num_classes]))
|
||||||
|
|
||||||
|
def _define_partial_maximization_operation(self, shard_id, shard):
|
||||||
|
"""Computes the partial statistics of the means and covariances.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shard_id: current shard id.
|
||||||
|
shard: current data shard, 1 X num_examples X dimensions.
|
||||||
|
"""
|
||||||
|
# Soft assignment of each data point to each of the two clusters.
|
||||||
|
self._points_in_k[shard_id] = tf.reduce_sum(self._w[shard_id], 0,
|
||||||
|
keep_dims=True)
|
||||||
|
# Partial means.
|
||||||
|
w_mul_x = tf.expand_dims(
|
||||||
|
tf.matmul(self._w[shard_id],
|
||||||
|
tf.squeeze(shard, [0]), transpose_a=True), 1)
|
||||||
|
self._w_mul_x.append(w_mul_x)
|
||||||
|
# Partial covariances.
|
||||||
|
x = tf.concat(0, [shard for _ in range(self._num_classes)])
|
||||||
|
x_trans = tf.transpose(x, perm=[0, 2, 1])
|
||||||
|
x_mul_w = tf.concat(0, [
|
||||||
|
tf.expand_dims(x_trans[k, :, :] * self._w[shard_id][:, k], 0)
|
||||||
|
for k in range(self._num_classes)])
|
||||||
|
self._w_mul_x2.append(tf.batch_matmul(x_mul_w, x))
|
||||||
|
|
||||||
|
def _define_maximization_operation(self, num_batches):
|
||||||
|
"""Maximization operations."""
|
||||||
|
# TODO(xavigonzalvo): some of these operations could be moved to C++.
|
||||||
|
# Compute the effective number of data points assigned to component k.
|
||||||
|
with tf.control_dependencies(self._w):
|
||||||
|
points_in_k = tf.squeeze(tf.add_n(self._points_in_k), squeeze_dims=[0])
|
||||||
|
# Update alpha.
|
||||||
|
if 'w' in self._params:
|
||||||
|
final_points_in_k = points_in_k / num_batches
|
||||||
|
num_examples = tf.to_float(tf.reduce_sum(final_points_in_k))
|
||||||
|
self._alpha_op = self._alpha.assign(
|
||||||
|
final_points_in_k / (num_examples + MEPS))
|
||||||
|
else:
|
||||||
|
self._alpha_op = tf.no_op()
|
||||||
|
self._train_ops = [self._alpha_op]
|
||||||
|
|
||||||
|
# Update means.
|
||||||
|
points_in_k_expanded = tf.reshape(points_in_k,
|
||||||
|
[self._num_classes, 1, 1])
|
||||||
|
if 'm' in self._params:
|
||||||
|
self._means_op = self._means.assign(
|
||||||
|
tf.div(tf.add_n(self._w_mul_x), points_in_k_expanded + MEPS))
|
||||||
|
else:
|
||||||
|
self._means_op = tf.no_op()
|
||||||
|
# means are (num_classes x 1 x dims)
|
||||||
|
|
||||||
|
# Update covariances.
|
||||||
|
with tf.control_dependencies([self._means_op]):
|
||||||
|
b = tf.add_n(self._w_mul_x2) / (points_in_k_expanded + MEPS)
|
||||||
|
new_covs = []
|
||||||
|
for k in range(self._num_classes):
|
||||||
|
mean = self._means.ref()[k, :, :]
|
||||||
|
square_mean = tf.matmul(mean, mean, transpose_a=True)
|
||||||
|
new_cov = b[k, :, :] - square_mean + self._min_var
|
||||||
|
if self._covariance_type == FULL_COVARIANCE:
|
||||||
|
new_covs.append(tf.expand_dims(new_cov, 0))
|
||||||
|
elif self._covariance_type == DIAG_COVARIANCE:
|
||||||
|
new_covs.append(tf.expand_dims(tf.diag_part(new_cov), 0))
|
||||||
|
new_covs = tf.concat(0, new_covs)
|
||||||
|
if 'c' in self._params:
|
||||||
|
# Train operations don't need to take care of the means
|
||||||
|
# because covariances already depend on it.
|
||||||
|
with tf.control_dependencies([self._means_op, new_covs]):
|
||||||
|
self._train_ops.append(
|
||||||
|
tf.assign(self._covs, new_covs, validate_shape=False))
|
||||||
|
|
||||||
|
def _define_distance_to_clusters(self, data):
|
||||||
|
"""Defines the Mahalanobis distance to the assigned Gaussian."""
|
||||||
|
# TODO(xavigonzalvo): reuse (input - mean) * cov^-1 * (input -
|
||||||
|
# mean) from log probability function.
|
||||||
|
self._all_scores = []
|
||||||
|
for shard in data:
|
||||||
|
all_scores = []
|
||||||
|
shard = tf.expand_dims(shard, 0)
|
||||||
|
for c in xrange(self._num_classes):
|
||||||
|
if self._covariance_type == FULL_COVARIANCE:
|
||||||
|
cov = self._covs[c, :, :]
|
||||||
|
elif self._covariance_type == DIAG_COVARIANCE:
|
||||||
|
cov = tf.diag(self._covs[c, :])
|
||||||
|
inverse = tf.matrix_inverse(cov + self._min_var)
|
||||||
|
inv_cov = tf.tile(
|
||||||
|
tf.expand_dims(inverse, 0),
|
||||||
|
tf.pack([self._num_examples, 1, 1]))
|
||||||
|
diff = tf.transpose(shard - self._means[c, :, :], perm=[1, 0, 2])
|
||||||
|
m_left = tf.batch_matmul(diff, inv_cov)
|
||||||
|
all_scores.append(tf.sqrt(tf.batch_matmul(
|
||||||
|
m_left, tf.transpose(diff, perm=[0, 2, 1])
|
||||||
|
)))
|
||||||
|
self._all_scores.append(tf.reshape(
|
||||||
|
tf.concat(1, all_scores),
|
||||||
|
tf.pack([self._num_examples, self._num_classes])))
|
||||||
|
|
||||||
|
# Distance to the associated class.
|
||||||
|
self._all_scores = tf.concat(0, self._all_scores)
|
||||||
|
assignments = tf.concat(0, self.assignments())
|
||||||
|
rows = tf.to_int64(tf.range(0, self._num_examples))
|
||||||
|
indices = tf.concat(1, [tf.expand_dims(rows, 1),
|
||||||
|
tf.expand_dims(assignments, 1)])
|
||||||
|
self._scores = tf.gather_nd(self._all_scores, indices)
|
||||||
|
|
||||||
|
def _define_loglikelihood_operation(self):
|
||||||
|
"""Defines the total log-likelihood of current iteration."""
|
||||||
|
self._ll_op = []
|
||||||
|
for prior_probs in self._prior_probs:
|
||||||
|
self._ll_op.append(tf.reduce_sum(tf.log(prior_probs)))
|
||||||
|
tf.scalar_summary('ll', tf.reduce_sum(self._ll_op))
|
||||||
|
|
||||||
|
|
||||||
|
def gmm(inp, initial_clusters, num_clusters, random_seed,
|
||||||
|
covariance_type=FULL_COVARIANCE, params='wmc'):
|
||||||
|
"""Creates the graph for Gaussian mixture model (GMM) clustering.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inp: An input tensor or list of input tensors
|
||||||
|
initial_clusters: Specifies the clusters used during
|
||||||
|
initialization. Can be a tensor or numpy array, or a function
|
||||||
|
that generates the clusters. Can also be "random" to specify
|
||||||
|
that clusters should be chosen randomly from input data. Note: type
|
||||||
|
is diverse to be consistent with skflow.
|
||||||
|
num_clusters: number of clusters.
|
||||||
|
random_seed: Python integer. Seed for PRNG used to initialize centers.
|
||||||
|
covariance_type: one of "diag", "full".
|
||||||
|
params: Controls which parameters are updated in the training
|
||||||
|
process. Can contain any combination of "w" for weights, "m" for
|
||||||
|
means, and "c" for covars.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Note: tuple of lists returned to be consistent with skflow
|
||||||
|
A tuple consisting of:
|
||||||
|
all_scores: A matrix (or list of matrices) of dimensions (num_input,
|
||||||
|
num_clusters) where the value is the distance of an input vector and a
|
||||||
|
cluster center.
|
||||||
|
assignments: A vector (or list of vectors). Each element in the vector
|
||||||
|
corresponds to an input row in 'inp' and specifies the cluster id
|
||||||
|
corresponding to the input.
|
||||||
|
scores: Similar to assignments but specifies the distance to the
|
||||||
|
assigned cluster instead.
|
||||||
|
training_op: an op that runs an iteration of training.
|
||||||
|
"""
|
||||||
|
initial_means = None
|
||||||
|
if initial_clusters != 'random' and not isinstance(
|
||||||
|
initial_clusters, tf.Tensor):
|
||||||
|
initial_means = tf.constant(initial_clusters, dtype=tf.float32)
|
||||||
|
|
||||||
|
# Implementation of GMM.
|
||||||
|
inp = inp if isinstance(inp, list) else [inp]
|
||||||
|
gmm_tool = GmmAlgorithm(inp, num_clusters, initial_means, params,
|
||||||
|
covariance_type, random_seed)
|
||||||
|
training_ops = gmm_tool.training_ops()
|
||||||
|
assignments = gmm_tool.assignments()
|
||||||
|
all_scores, scores = gmm_tool.scores()
|
||||||
|
return [all_scores], [assignments], [scores], tf.group(*training_ops)
|
198
tensorflow/contrib/factorization/python/ops/gmm_ops_test.py
Normal file
198
tensorflow/contrib/factorization/python/ops/gmm_ops_test.py
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Tests for gmm_ops."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow.contrib.factorization.python.ops import gmm_ops
|
||||||
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
|
||||||
|
|
||||||
|
class GmmOpsTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.num_examples = 1000
|
||||||
|
self.iterations = 40
|
||||||
|
self.seed = 4
|
||||||
|
tf.set_random_seed(self.seed)
|
||||||
|
np.random.seed(self.seed * 2)
|
||||||
|
self.data, self.true_assignments = self.make_data(self.num_examples)
|
||||||
|
# Generate more complicated data.
|
||||||
|
self.centers = [[1, 1], [-1, 0.5], [2, 1]]
|
||||||
|
self.more_data, self.more_true_assignments = self.make_data_from_centers(
|
||||||
|
self.num_examples, self.centers)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make_data(num_vectors):
|
||||||
|
"""Generates 2-dimensional data centered on (2,2), (-1,-1).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_vectors: number of training examples.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple containing the data as a numpy array and the cluster ids.
|
||||||
|
"""
|
||||||
|
vectors = []
|
||||||
|
classes = []
|
||||||
|
for _ in xrange(num_vectors):
|
||||||
|
if np.random.random() > 0.5:
|
||||||
|
vectors.append([np.random.normal(2.0, 0.6),
|
||||||
|
np.random.normal(2.0, 0.9)])
|
||||||
|
classes.append(0)
|
||||||
|
else:
|
||||||
|
vectors.append([np.random.normal(-1.0, 0.4),
|
||||||
|
np.random.normal(-1.0, 0.5)])
|
||||||
|
classes.append(1)
|
||||||
|
return np.asarray(vectors), classes
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make_data_from_centers(num_vectors, centers):
|
||||||
|
"""Generates 2-dimensional data with random centers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_vectors: number of training examples.
|
||||||
|
centers: a list of random 2-dimensional centers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple containing the data as a numpy array and the cluster ids.
|
||||||
|
"""
|
||||||
|
vectors = []
|
||||||
|
classes = []
|
||||||
|
for _ in xrange(num_vectors):
|
||||||
|
current_class = np.random.random_integers(0, len(centers) - 1)
|
||||||
|
vectors.append([np.random.normal(centers[current_class][0],
|
||||||
|
np.random.random_sample()),
|
||||||
|
np.random.normal(centers[current_class][1],
|
||||||
|
np.random.random_sample())])
|
||||||
|
classes.append(current_class)
|
||||||
|
return np.asarray(vectors), len(centers)
|
||||||
|
|
||||||
|
def test_covariance(self):
|
||||||
|
start_time = time.time()
|
||||||
|
data = self.data.T
|
||||||
|
np_cov = np.cov(data)
|
||||||
|
logging.info('Numpy took %f', time.time() - start_time)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
with self.test_session() as sess:
|
||||||
|
op = gmm_ops._covariance(
|
||||||
|
tf.constant(data.T, dtype=tf.float32),
|
||||||
|
False)
|
||||||
|
op_diag = gmm_ops._covariance(
|
||||||
|
tf.constant(data.T, dtype=tf.float32),
|
||||||
|
True)
|
||||||
|
tf.initialize_all_variables().run()
|
||||||
|
tf_cov = sess.run(op)
|
||||||
|
np.testing.assert_array_almost_equal(np_cov, tf_cov)
|
||||||
|
logging.info('Tensorflow took %f', time.time() - start_time)
|
||||||
|
tf_cov = sess.run(op_diag)
|
||||||
|
np.testing.assert_array_almost_equal(
|
||||||
|
np.diag(np_cov), np.ravel(tf_cov), decimal=5)
|
||||||
|
|
||||||
|
def test_simple_cluster(self):
|
||||||
|
"""Tests that the clusters are correct."""
|
||||||
|
num_classes = 2
|
||||||
|
graph = tf.Graph()
|
||||||
|
with graph.as_default() as g:
|
||||||
|
g.seed = 5
|
||||||
|
with self.test_session() as sess:
|
||||||
|
data = tf.constant(self.data, dtype=tf.float32)
|
||||||
|
_, assignments, _, training_op = gmm_ops.gmm(data, 'random',
|
||||||
|
num_classes,
|
||||||
|
random_seed=self.seed)
|
||||||
|
|
||||||
|
tf.initialize_all_variables().run()
|
||||||
|
for _ in xrange(self.iterations):
|
||||||
|
sess.run(training_op)
|
||||||
|
assignments = sess.run(assignments)
|
||||||
|
accuracy = np.mean(
|
||||||
|
np.asarray(self.true_assignments) == np.squeeze(assignments))
|
||||||
|
logging.info('Accuracy: %f', accuracy)
|
||||||
|
self.assertGreater(accuracy, 0.98)
|
||||||
|
|
||||||
|
def testParams(self):
|
||||||
|
"""Tests that the params work as intended."""
|
||||||
|
num_classes = 2
|
||||||
|
with self.test_session() as sess:
|
||||||
|
# Experiment 1. Update weights only.
|
||||||
|
data = tf.constant(self.data, dtype=tf.float32)
|
||||||
|
gmm_tool = gmm_ops.GmmAlgorithm([data], num_classes,
|
||||||
|
[[3.0, 3.0], [0.0, 0.0]], 'w')
|
||||||
|
training_ops = gmm_tool.training_ops()
|
||||||
|
tf.initialize_all_variables().run()
|
||||||
|
for _ in xrange(self.iterations):
|
||||||
|
sess.run(training_ops)
|
||||||
|
|
||||||
|
# Only the probability to each class is updated.
|
||||||
|
alphas = sess.run(gmm_tool.alphas())
|
||||||
|
self.assertGreater(alphas[1], 0.6)
|
||||||
|
means = sess.run(gmm_tool.clusters())
|
||||||
|
np.testing.assert_almost_equal(
|
||||||
|
np.expand_dims([[3.0, 3.0], [0.0, 0.0]], 1), means)
|
||||||
|
covs = sess.run(gmm_tool.covariances())
|
||||||
|
np.testing.assert_almost_equal(covs[0], covs[1])
|
||||||
|
|
||||||
|
# Experiment 2. Update means and covariances.
|
||||||
|
gmm_tool = gmm_ops.GmmAlgorithm([data], num_classes,
|
||||||
|
[[3.0, 3.0], [0.0, 0.0]], 'mc')
|
||||||
|
training_ops = gmm_tool.training_ops()
|
||||||
|
tf.initialize_all_variables().run()
|
||||||
|
for _ in xrange(self.iterations):
|
||||||
|
sess.run(training_ops)
|
||||||
|
alphas = sess.run(gmm_tool.alphas())
|
||||||
|
self.assertAlmostEqual(alphas[0], alphas[1])
|
||||||
|
means = sess.run(gmm_tool.clusters())
|
||||||
|
np.testing.assert_almost_equal(
|
||||||
|
np.expand_dims([[2.0, 2.0], [-1.0, -1.0]], 1), means, decimal=1)
|
||||||
|
covs = sess.run(gmm_tool.covariances())
|
||||||
|
np.testing.assert_almost_equal(
|
||||||
|
[[0.371111, -0.0050774], [-0.0050774, 0.8651744]],
|
||||||
|
covs[0], decimal=4)
|
||||||
|
np.testing.assert_almost_equal(
|
||||||
|
[[0.146976, 0.0259463], [0.0259463, 0.2543971]],
|
||||||
|
covs[1], decimal=4)
|
||||||
|
|
||||||
|
# Experiment 3. Update covariances only.
|
||||||
|
gmm_tool = gmm_ops.GmmAlgorithm([data], num_classes,
|
||||||
|
[[-1.0, -1.0], [1.0, 1.0]], 'c')
|
||||||
|
training_ops = gmm_tool.training_ops()
|
||||||
|
tf.initialize_all_variables().run()
|
||||||
|
for _ in xrange(self.iterations):
|
||||||
|
sess.run(training_ops)
|
||||||
|
alphas = sess.run(gmm_tool.alphas())
|
||||||
|
self.assertAlmostEqual(alphas[0], alphas[1])
|
||||||
|
means = sess.run(gmm_tool.clusters())
|
||||||
|
np.testing.assert_almost_equal(
|
||||||
|
np.expand_dims([[-1.0, -1.0], [1.0, 1.0]], 1), means)
|
||||||
|
covs = sess.run(gmm_tool.covariances())
|
||||||
|
np.testing.assert_almost_equal(
|
||||||
|
[[0.1299582, 0.0435872], [0.0435872, 0.2558578]],
|
||||||
|
covs[0], decimal=5)
|
||||||
|
np.testing.assert_almost_equal(
|
||||||
|
[[3.195385, 2.6989155], [2.6989155, 3.3881593]],
|
||||||
|
covs[1], decimal=5)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
172
tensorflow/contrib/factorization/python/ops/gmm_test.py
Normal file
172
tensorflow/contrib/factorization/python/ops/gmm_test.py
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Tests for ops.gmm."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow.contrib.factorization.python.ops.gmm import GMM
|
||||||
|
from tensorflow.contrib.factorization.python.ops.kmeans import KMeansClustering as KMeans
|
||||||
|
from tensorflow.contrib.learn.python.learn.estimators import run_config
|
||||||
|
|
||||||
|
FLAGS = tf.app.flags.FLAGS
|
||||||
|
|
||||||
|
|
||||||
|
class GMMTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
np.random.seed(3)
|
||||||
|
tf.set_random_seed(2)
|
||||||
|
self.num_centers = 2
|
||||||
|
self.num_dims = 2
|
||||||
|
self.num_points = 4000
|
||||||
|
self.batch_size = 100
|
||||||
|
self.true_centers = self.make_random_centers(self.num_centers,
|
||||||
|
self.num_dims)
|
||||||
|
self.points, self.assignments, self.scores = self.make_random_points(
|
||||||
|
self.true_centers,
|
||||||
|
self.num_points)
|
||||||
|
self.true_score = np.add.reduce(self.scores)
|
||||||
|
|
||||||
|
# Use initial means from kmeans (just like scikit-learn does).
|
||||||
|
clusterer = KMeans(num_clusters=self.num_centers)
|
||||||
|
clusterer.fit(self.points, steps=30)
|
||||||
|
self.initial_means = clusterer.clusters()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make_random_centers(num_centers, num_dims):
|
||||||
|
return np.round(np.random.rand(num_centers,
|
||||||
|
num_dims).astype(np.float32) * 500)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make_random_points(centers, num_points):
|
||||||
|
num_centers, num_dims = centers.shape
|
||||||
|
assignments = np.random.choice(num_centers, num_points)
|
||||||
|
offsets = np.round(np.random.randn(num_points,
|
||||||
|
num_dims).astype(np.float32) * 20)
|
||||||
|
points = centers[assignments] + offsets
|
||||||
|
means = [np.mean(points[assignments == center], axis=0)
|
||||||
|
for center in xrange(num_centers)]
|
||||||
|
covs = [np.cov(points[assignments == center].T)
|
||||||
|
for center in xrange(num_centers)]
|
||||||
|
scores = []
|
||||||
|
for r in xrange(num_points):
|
||||||
|
scores.append(np.sqrt(np.dot(
|
||||||
|
np.dot(points[r, :] - means[assignments[r]],
|
||||||
|
np.linalg.inv(covs[assignments[r]])),
|
||||||
|
points[r, :] - means[assignments[r]])))
|
||||||
|
return (points, assignments, scores)
|
||||||
|
|
||||||
|
def test_clusters(self):
|
||||||
|
"""Tests the shape of the clusters."""
|
||||||
|
gmm = GMM(self.num_centers,
|
||||||
|
initial_clusters=self.initial_means,
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
steps=40,
|
||||||
|
continue_training=True,
|
||||||
|
random_seed=4,
|
||||||
|
config=run_config.RunConfig(tf_random_seed=2))
|
||||||
|
gmm.fit(x=self.points, steps=0)
|
||||||
|
clusters = gmm.clusters()
|
||||||
|
self.assertAllEqual(list(clusters.shape),
|
||||||
|
[self.num_centers, self.num_dims])
|
||||||
|
|
||||||
|
def test_fit(self):
|
||||||
|
gmm = GMM(self.num_centers,
|
||||||
|
initial_clusters='random',
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
random_seed=4,
|
||||||
|
config=run_config.RunConfig(tf_random_seed=2))
|
||||||
|
gmm.fit(x=self.points, steps=1)
|
||||||
|
score1 = gmm.score(x=self.points)
|
||||||
|
gmm = GMM(self.num_centers,
|
||||||
|
initial_clusters='random',
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
random_seed=4,
|
||||||
|
config=run_config.RunConfig(tf_random_seed=2))
|
||||||
|
gmm.fit(x=self.points, steps=10)
|
||||||
|
score2 = gmm.score(x=self.points)
|
||||||
|
self.assertGreater(score1, score2)
|
||||||
|
self.assertNear(self.true_score, score2, self.true_score * 0.15)
|
||||||
|
|
||||||
|
def test_infer(self):
|
||||||
|
gmm = GMM(self.num_centers,
|
||||||
|
initial_clusters=self.initial_means,
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
steps=40,
|
||||||
|
continue_training=True,
|
||||||
|
random_seed=4,
|
||||||
|
config=run_config.RunConfig(tf_random_seed=2))
|
||||||
|
gmm.fit(x=self.points, steps=60)
|
||||||
|
clusters = gmm.clusters()
|
||||||
|
|
||||||
|
# Make a small test set
|
||||||
|
points, true_assignments, true_offsets = (
|
||||||
|
self.make_random_points(clusters, 40))
|
||||||
|
|
||||||
|
assignments = np.ravel(gmm.predict(points))
|
||||||
|
self.assertAllEqual(true_assignments, assignments)
|
||||||
|
|
||||||
|
# Test score
|
||||||
|
score = gmm.score(points)
|
||||||
|
self.assertNear(score, np.sum(true_offsets), 4.05)
|
||||||
|
|
||||||
|
def _compare_with_sklearn(self, cov_type):
|
||||||
|
# sklearn version.
|
||||||
|
iterations = 40
|
||||||
|
np.random.seed(5)
|
||||||
|
sklearn_assignments = np.asarray([0, 0, 1, 0, 0, 0, 1, 0, 0, 1])
|
||||||
|
sklearn_means = np.asarray([[144.83417719, 254.20130341],
|
||||||
|
[274.38754816, 353.16074346]])
|
||||||
|
sklearn_covs = np.asarray([[[395.0081194, -4.50389512],
|
||||||
|
[-4.50389512, 408.27543989]],
|
||||||
|
[[385.17484203, -31.27834935],
|
||||||
|
[-31.27834935, 391.74249925]]])
|
||||||
|
|
||||||
|
# skflow version.
|
||||||
|
gmm = GMM(self.num_centers,
|
||||||
|
initial_clusters=self.initial_means,
|
||||||
|
covariance_type=cov_type,
|
||||||
|
batch_size=self.num_points,
|
||||||
|
steps=iterations,
|
||||||
|
continue_training=True,
|
||||||
|
config=run_config.RunConfig(tf_random_seed=2))
|
||||||
|
gmm.fit(self.points)
|
||||||
|
skflow_assignments = gmm.predict(self.points[:10, :]).astype(int)
|
||||||
|
self.assertAllClose(sklearn_assignments,
|
||||||
|
np.ravel(skflow_assignments))
|
||||||
|
self.assertAllClose(sklearn_means, gmm.clusters())
|
||||||
|
if cov_type == 'full':
|
||||||
|
self.assertAllClose(sklearn_covs, gmm.covariances(), rtol=0.01)
|
||||||
|
else:
|
||||||
|
for d in [0, 1]:
|
||||||
|
self.assertAllClose(np.diag(sklearn_covs[d]),
|
||||||
|
gmm.covariances()[d, :], rtol=0.01)
|
||||||
|
|
||||||
|
def test_compare_full(self):
|
||||||
|
self._compare_with_sklearn('full')
|
||||||
|
|
||||||
|
def test_compare_diag(self):
|
||||||
|
self._compare_with_sklearn('diag')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
@ -153,9 +153,11 @@ class KMeansTest(tf.test.TestCase):
|
|||||||
def test_fit_with_cosine_distance(self):
|
def test_fit_with_cosine_distance(self):
|
||||||
# Create points on y=x and y=1.5x lines to check the cosine similarity.
|
# Create points on y=x and y=1.5x lines to check the cosine similarity.
|
||||||
# Note that euclidean distance will give different results in this case.
|
# Note that euclidean distance will give different results in this case.
|
||||||
points = np.array([[9, 9], [0.5, 0.5], [10, 15], [0.4, 0.6]])
|
points = np.array(
|
||||||
|
[[9, 9], [0.5, 0.5], [10, 15], [0.4, 0.6]], dtype=np.float32)
|
||||||
# true centers are the unit vectors on lines y=x and y=1.5x
|
# true centers are the unit vectors on lines y=x and y=1.5x
|
||||||
true_centers = np.array([[0.70710678, 0.70710678], [0.5547002, 0.83205029]])
|
true_centers = np.array(
|
||||||
|
[[0.70710678, 0.70710678], [0.5547002, 0.83205029]], dtype=np.float32)
|
||||||
kmeans = KMeans(2,
|
kmeans = KMeans(2,
|
||||||
initial_clusters=kmeans_ops.RANDOM_INIT,
|
initial_clusters=kmeans_ops.RANDOM_INIT,
|
||||||
distance_metric=kmeans_ops.COSINE_DISTANCE,
|
distance_metric=kmeans_ops.COSINE_DISTANCE,
|
||||||
@ -168,8 +170,9 @@ class KMeansTest(tf.test.TestCase):
|
|||||||
np.sort(true_centers, axis=0))
|
np.sort(true_centers, axis=0))
|
||||||
|
|
||||||
def test_transform_with_cosine_distance(self):
|
def test_transform_with_cosine_distance(self):
|
||||||
points = np.array([[2.5, 3.5], [2, 8], [3, 1], [3, 18],
|
points = np.array(
|
||||||
[-2.5, -3.5], [-2, -8], [-3, -1], [-3, -18]])
|
[[2.5, 0.1], [2, 0.2], [3, 0.1], [4, 0.2],
|
||||||
|
[0.1, 2.5], [0.2, 2], [0.1, 3], [0.2, 4]], dtype=np.float32)
|
||||||
|
|
||||||
true_centers = [normalize(np.mean(normalize(points)[4:, :], axis=0,
|
true_centers = [normalize(np.mean(normalize(points)[4:, :], axis=0,
|
||||||
keepdims=True))[0],
|
keepdims=True))[0],
|
||||||
@ -180,8 +183,8 @@ class KMeansTest(tf.test.TestCase):
|
|||||||
initial_clusters=kmeans_ops.RANDOM_INIT,
|
initial_clusters=kmeans_ops.RANDOM_INIT,
|
||||||
distance_metric=kmeans_ops.COSINE_DISTANCE,
|
distance_metric=kmeans_ops.COSINE_DISTANCE,
|
||||||
use_mini_batch=self.use_mini_batch,
|
use_mini_batch=self.use_mini_batch,
|
||||||
config=self.config(3))
|
config=self.config(5))
|
||||||
kmeans.fit(x=points, steps=30, batch_size=8)
|
kmeans.fit(x=points, steps=50, batch_size=8)
|
||||||
|
|
||||||
centers = normalize(kmeans.clusters())
|
centers = normalize(kmeans.clusters())
|
||||||
self.assertAllClose(np.sort(centers, axis=0),
|
self.assertAllClose(np.sort(centers, axis=0),
|
||||||
@ -193,16 +196,16 @@ class KMeansTest(tf.test.TestCase):
|
|||||||
self.assertAllClose(transform, true_transform, atol=1e-3)
|
self.assertAllClose(transform, true_transform, atol=1e-3)
|
||||||
|
|
||||||
def test_predict_with_cosine_distance(self):
|
def test_predict_with_cosine_distance(self):
|
||||||
points = np.array([[2.5, 3.5], [2, 8], [3, 1], [3, 18],
|
points = np.array(
|
||||||
[-2.5, -3.5], [-2, -8], [-3, -1], [-3, -18]]).astype(
|
[[2.5, 0.1], [2, 0.2], [3, 0.1], [4, 0.2],
|
||||||
np.float32)
|
[0.1, 2.5], [0.2, 2], [0.1, 3], [0.2, 4]], dtype=np.float32)
|
||||||
true_centers = np.array(
|
true_centers = np.array(
|
||||||
[normalize(np.mean(normalize(points)[0:4, :],
|
[normalize(np.mean(normalize(points)[0:4, :],
|
||||||
axis=0,
|
axis=0,
|
||||||
keepdims=True))[0],
|
keepdims=True))[0],
|
||||||
normalize(np.mean(normalize(points)[4:, :],
|
normalize(np.mean(normalize(points)[4:, :],
|
||||||
axis=0,
|
axis=0,
|
||||||
keepdims=True))[0]])
|
keepdims=True))[0]], dtype=np.float32)
|
||||||
true_assignments = [0] * 4 + [1] * 4
|
true_assignments = [0] * 4 + [1] * 4
|
||||||
true_score = len(points) - np.tensordot(normalize(points),
|
true_score = len(points) - np.tensordot(normalize(points),
|
||||||
true_centers[true_assignments])
|
true_centers[true_assignments])
|
||||||
@ -230,14 +233,14 @@ class KMeansTest(tf.test.TestCase):
|
|||||||
# the less populated centers.
|
# the less populated centers.
|
||||||
points = np.array([[2.5, 3.5], [2.5, 3.5], [-2, 3], [-2, 3], [-3, -3],
|
points = np.array([[2.5, 3.5], [2.5, 3.5], [-2, 3], [-2, 3], [-3, -3],
|
||||||
[-3.1, -3.2], [-2.8, -3.], [-2.9, -3.1], [-3., -3.1],
|
[-3.1, -3.2], [-2.8, -3.], [-2.9, -3.1], [-3., -3.1],
|
||||||
[-3., -3.1], [-3.2, -3.], [-3., -3.]]).astype(np.float32)
|
[-3., -3.1], [-3.2, -3.], [-3., -3.]], dtype=np.float32)
|
||||||
true_centers = np.array(
|
true_centers = np.array(
|
||||||
[normalize(np.mean(normalize(points)[0:2, :], axis=0,
|
[normalize(np.mean(normalize(points)[0:2, :], axis=0,
|
||||||
keepdims=True))[0],
|
keepdims=True))[0],
|
||||||
normalize(np.mean(normalize(points)[2:4, :], axis=0,
|
normalize(np.mean(normalize(points)[2:4, :], axis=0,
|
||||||
keepdims=True))[0],
|
keepdims=True))[0],
|
||||||
normalize(np.mean(normalize(points)[4:, :], axis=0,
|
normalize(np.mean(normalize(points)[4:, :], axis=0,
|
||||||
keepdims=True))[0]])
|
keepdims=True))[0]], dtype=np.float32)
|
||||||
true_assignments = [0] * 2 + [1] * 2 + [2] * 8
|
true_assignments = [0] * 2 + [1] * 2 + [2] * 8
|
||||||
true_score = len(points) - np.tensordot(normalize(points),
|
true_score = len(points) - np.tensordot(normalize(points),
|
||||||
true_centers[true_assignments])
|
true_centers[true_assignments])
|
||||||
@ -262,7 +265,7 @@ class KMeansTest(tf.test.TestCase):
|
|||||||
self.assertAllClose(score, true_score, atol=1e-2)
|
self.assertAllClose(score, true_score, atol=1e-2)
|
||||||
|
|
||||||
def test_fit_raise_if_num_clusters_larger_than_num_points_random_init(self):
|
def test_fit_raise_if_num_clusters_larger_than_num_points_random_init(self):
|
||||||
points = np.array([[2.0, 3.0], [1.6, 8.2]])
|
points = np.array([[2.0, 3.0], [1.6, 8.2]], dtype=np.float32)
|
||||||
|
|
||||||
with self.assertRaisesOpError('less'):
|
with self.assertRaisesOpError('less'):
|
||||||
kmeans = KMeans(num_clusters=3, initial_clusters=kmeans_ops.RANDOM_INIT)
|
kmeans = KMeans(num_clusters=3, initial_clusters=kmeans_ops.RANDOM_INIT)
|
||||||
@ -270,7 +273,7 @@ class KMeansTest(tf.test.TestCase):
|
|||||||
|
|
||||||
def test_fit_raise_if_num_clusters_larger_than_num_points_kmeans_plus_plus(
|
def test_fit_raise_if_num_clusters_larger_than_num_points_kmeans_plus_plus(
|
||||||
self):
|
self):
|
||||||
points = np.array([[2.0, 3.0], [1.6, 8.2]])
|
points = np.array([[2.0, 3.0], [1.6, 8.2]], dtype=np.float32)
|
||||||
|
|
||||||
with self.assertRaisesOpError(AssertionError):
|
with self.assertRaisesOpError(AssertionError):
|
||||||
kmeans = KMeans(num_clusters=3,
|
kmeans = KMeans(num_clusters=3,
|
||||||
|
@ -21,10 +21,12 @@
|
|||||||
#include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h"
|
#include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
#include "tensorflow/core/lib/io/path.h"
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace ffmpeg {
|
namespace ffmpeg {
|
||||||
@ -62,13 +64,11 @@ class FileDeleter {
|
|||||||
|
|
||||||
class DecodeAudioOp : public OpKernel {
|
class DecodeAudioOp : public OpKernel {
|
||||||
public:
|
public:
|
||||||
explicit DecodeAudioOp(OpKernelConstruction* context)
|
explicit DecodeAudioOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||||
: OpKernel(context) {
|
|
||||||
OP_REQUIRES_OK(context, context->GetAttr("file_format", &file_format_));
|
OP_REQUIRES_OK(context, context->GetAttr("file_format", &file_format_));
|
||||||
file_format_ = str_util::Lowercase(file_format_);
|
file_format_ = str_util::Lowercase(file_format_);
|
||||||
const std::set<string> valid_file_formats(
|
const std::set<string> valid_file_formats(
|
||||||
kValidFileFormats,
|
kValidFileFormats, kValidFileFormats + TF_ARRAYSIZE(kValidFileFormats));
|
||||||
kValidFileFormats + TF_ARRAYSIZE(kValidFileFormats));
|
|
||||||
OP_REQUIRES(context, valid_file_formats.count(file_format_) == 1,
|
OP_REQUIRES(context, valid_file_formats.count(file_format_) == 1,
|
||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
"file_format arg must be in {",
|
"file_format arg must be in {",
|
||||||
@ -79,8 +79,7 @@ class DecodeAudioOp : public OpKernel {
|
|||||||
OP_REQUIRES(context, samples_per_second_ > 0,
|
OP_REQUIRES(context, samples_per_second_ > 0,
|
||||||
errors::InvalidArgument("samples_per_second must be > 0."));
|
errors::InvalidArgument("samples_per_second must be > 0."));
|
||||||
|
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(context, context->GetAttr("channel_count", &channel_count_));
|
||||||
context, context->GetAttr("channel_count", &channel_count_));
|
|
||||||
OP_REQUIRES(context, channel_count_ > 0,
|
OP_REQUIRES(context, channel_count_ > 0,
|
||||||
errors::InvalidArgument("channel_count must be > 0."));
|
errors::InvalidArgument("channel_count must be > 0."));
|
||||||
}
|
}
|
||||||
@ -112,12 +111,18 @@ class DecodeAudioOp : public OpKernel {
|
|||||||
context, result.ok(),
|
context, result.ok(),
|
||||||
errors::Unavailable("FFmpeg must be installed to run this op. FFmpeg "
|
errors::Unavailable("FFmpeg must be installed to run this op. FFmpeg "
|
||||||
"can be found at http://www.ffmpeg.org."));
|
"can be found at http://www.ffmpeg.org."));
|
||||||
|
} else if (result.code() == error::UNKNOWN) {
|
||||||
|
LOG(ERROR) << "Ffmpeg failed with error '" << result.error_message()
|
||||||
|
<< "'. Returning empty tensor.";
|
||||||
|
Tensor* output = nullptr;
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
context->allocate_output(0, TensorShape({0, 0}), &output));
|
||||||
|
return;
|
||||||
} else {
|
} else {
|
||||||
OP_REQUIRES_OK(context, result);
|
OP_REQUIRES_OK(context, result);
|
||||||
}
|
}
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(context, !output_samples.empty(),
|
||||||
context, !output_samples.empty(),
|
errors::Unknown("No output created by FFmpeg."));
|
||||||
errors::Unknown("No output created by FFmpeg."));
|
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
context, output_samples.size() % channel_count_ == 0,
|
context, output_samples.size() % channel_count_ == 0,
|
||||||
errors::Unknown("FFmpeg created non-integer number of audio frames."));
|
errors::Unknown("FFmpeg created non-integer number of audio frames."));
|
||||||
@ -125,9 +130,9 @@ class DecodeAudioOp : public OpKernel {
|
|||||||
// Copy the output data to the output Tensor.
|
// Copy the output data to the output Tensor.
|
||||||
Tensor* output = nullptr;
|
Tensor* output = nullptr;
|
||||||
const int64 frame_count = output_samples.size() / channel_count_;
|
const int64 frame_count = output_samples.size() / channel_count_;
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(context,
|
||||||
context, context->allocate_output(
|
context->allocate_output(
|
||||||
0, TensorShape({frame_count, channel_count_}), &output));
|
0, TensorShape({frame_count, channel_count_}), &output));
|
||||||
auto matrix = output->tensor<float, 2>();
|
auto matrix = output->tensor<float, 2>();
|
||||||
for (int32 frame = 0; frame < frame_count; ++frame) {
|
for (int32 frame = 0; frame < frame_count; ++frame) {
|
||||||
for (int32 channel = 0; channel < channel_count_; ++channel) {
|
for (int32 channel = 0; channel < channel_count_; ++channel) {
|
||||||
@ -151,6 +156,15 @@ REGISTER_OP("DecodeAudio")
|
|||||||
.Attr("file_format: string")
|
.Attr("file_format: string")
|
||||||
.Attr("samples_per_second: int")
|
.Attr("samples_per_second: int")
|
||||||
.Attr("channel_count: int")
|
.Attr("channel_count: int")
|
||||||
|
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||||
|
int64 channels;
|
||||||
|
if (c->GetAttr("channel_count", &channels).ok()) {
|
||||||
|
c->set_output(0, c->Matrix(c->UnknownDim(), channels));
|
||||||
|
} else {
|
||||||
|
c->set_output(0, c->Matrix(c->UnknownDim(), c->UnknownDim()));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
})
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Processes the contents of an audio file into a tensor using FFmpeg to decode
|
Processes the contents of an audio file into a tensor using FFmpeg to decode
|
||||||
the file.
|
the file.
|
||||||
@ -162,7 +176,8 @@ different from the contents of the file, channels will be merged or created.
|
|||||||
|
|
||||||
contents: The binary audio file contents.
|
contents: The binary audio file contents.
|
||||||
sampled_audio: A rank 2 tensor containing all tracks of the audio. Dimension 0
|
sampled_audio: A rank 2 tensor containing all tracks of the audio. Dimension 0
|
||||||
is time and dimension 1 is the channel.
|
is time and dimension 1 is the channel. If ffmpeg fails to decode the audio
|
||||||
|
then an empty tensor will be returned.
|
||||||
file_format: A string describing the audio file format. This can be "wav" or
|
file_format: A string describing the audio file format. This can be "wav" or
|
||||||
"mp3".
|
"mp3".
|
||||||
samples_per_second: The number of samples per second that the audio should have.
|
samples_per_second: The number of samples per second that the audio should have.
|
||||||
|
@ -72,6 +72,14 @@ class DecodeAudioOpTest(tf.test.TestCase):
|
|||||||
def testOgg(self):
|
def testOgg(self):
|
||||||
self._loadFileAndTest('mono_10khz.ogg', 'ogg', 0.57, 10000, 1)
|
self._loadFileAndTest('mono_10khz.ogg', 'ogg', 0.57, 10000, 1)
|
||||||
|
|
||||||
|
def testInvalidFile(self):
|
||||||
|
with self.test_session():
|
||||||
|
contents = 'invalid file'
|
||||||
|
audio_op = ffmpeg.decode_audio(contents, file_format='wav',
|
||||||
|
samples_per_second=10000, channel_count=2)
|
||||||
|
audio = audio_op.eval()
|
||||||
|
self.assertEqual(audio.shape, (0, 0))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
@ -38,7 +38,6 @@ namespace {
|
|||||||
const char kFfmpegExecutable[] = "ffmpeg";
|
const char kFfmpegExecutable[] = "ffmpeg";
|
||||||
const int32 kDefaultProbeSize = 5000000; // 5MB
|
const int32 kDefaultProbeSize = 5000000; // 5MB
|
||||||
|
|
||||||
|
|
||||||
std::vector<string> FfmpegCommandLine(const string& input_filename,
|
std::vector<string> FfmpegCommandLine(const string& input_filename,
|
||||||
const string& output_filename,
|
const string& output_filename,
|
||||||
const string& input_format_id,
|
const string& input_format_id,
|
||||||
@ -63,6 +62,39 @@ std::vector<string> FfmpegCommandLine(const string& input_filename,
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Is a named binary installed and executable by the current process?
|
||||||
|
// Note that this is harder than it seems like it should be...
|
||||||
|
bool IsBinaryInstalled(const string& binary_name) {
|
||||||
|
string path = ::getenv("PATH");
|
||||||
|
for (const string& dir : str_util::Split(path, ':')) {
|
||||||
|
const string binary_path = io::JoinPath(dir, binary_name);
|
||||||
|
char absolute_path[PATH_MAX + 1];
|
||||||
|
::realpath(binary_path.c_str(), absolute_path);
|
||||||
|
struct stat statinfo;
|
||||||
|
int result = ::stat(absolute_path, &statinfo);
|
||||||
|
if (result < 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!S_ISREG(statinfo.st_mode)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Is the current user able to execute the file?
|
||||||
|
if (statinfo.st_uid == ::geteuid() && statinfo.st_mode & S_IXUSR) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// Is the current group able to execute the file?
|
||||||
|
if (statinfo.st_uid == ::getegid() && statinfo.st_mode & S_IXGRP) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// Is anyone able to execute the file?
|
||||||
|
if (statinfo.st_mode & S_IXOTH) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
[[noreturn]] int ExecuteFfmpeg(const std::vector<string>& args) {
|
[[noreturn]] int ExecuteFfmpeg(const std::vector<string>& args) {
|
||||||
std::vector<char*> args_chars;
|
std::vector<char*> args_chars;
|
||||||
std::transform(args.begin(), args.end(), std::back_inserter(args_chars),
|
std::transform(args.begin(), args.end(), std::back_inserter(args_chars),
|
||||||
@ -191,6 +223,14 @@ Status ReadAudioFile(const string& filename,
|
|||||||
FfmpegCommandLine(filename, output_filename, audio_format_id,
|
FfmpegCommandLine(filename, output_filename, audio_format_id,
|
||||||
samples_per_second, channel_count);
|
samples_per_second, channel_count);
|
||||||
|
|
||||||
|
// Unfortunately, it's impossible to differentiate an exec failure due to the
|
||||||
|
// binary being missing and an error from the binary's execution. Therefore,
|
||||||
|
// check to see if the binary *should* be available. If not, return an error
|
||||||
|
// that will be converted into a helpful error message by the TensorFlow op.
|
||||||
|
if (!IsBinaryInstalled(kFfmpegExecutable)) {
|
||||||
|
return Status(error::Code::NOT_FOUND, StrCat("FFmpeg could not be found."));
|
||||||
|
}
|
||||||
|
|
||||||
// Execute ffmpeg and report errors.
|
// Execute ffmpeg and report errors.
|
||||||
pid_t child_pid = ::fork();
|
pid_t child_pid = ::fork();
|
||||||
if (child_pid < 0) {
|
if (child_pid < 0) {
|
||||||
@ -202,7 +242,7 @@ Status ReadAudioFile(const string& filename,
|
|||||||
int status_code;
|
int status_code;
|
||||||
::waitpid(child_pid, &status_code, 0);
|
::waitpid(child_pid, &status_code, 0);
|
||||||
if (status_code) {
|
if (status_code) {
|
||||||
return Status(error::Code::NOT_FOUND,
|
return Status(error::Code::UNKNOWN,
|
||||||
StrCat("FFmpeg execution failed: ", status_code));
|
StrCat("FFmpeg execution failed: ", status_code));
|
||||||
}
|
}
|
||||||
*output_samples = ReadPcmFile(output_filename);
|
*output_samples = ReadPcmFile(output_filename);
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
#include <limits>
|
#include <limits>
|
||||||
|
|
||||||
#include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h"
|
#include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h"
|
||||||
|
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
|
||||||
@ -24,8 +25,7 @@ namespace ffmpeg {
|
|||||||
|
|
||||||
class EncodeAudioOp : public OpKernel {
|
class EncodeAudioOp : public OpKernel {
|
||||||
public:
|
public:
|
||||||
explicit EncodeAudioOp(OpKernelConstruction* context)
|
explicit EncodeAudioOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||||
: OpKernel(context) {
|
|
||||||
OP_REQUIRES_OK(context, context->GetAttr("file_format", &file_format_));
|
OP_REQUIRES_OK(context, context->GetAttr("file_format", &file_format_));
|
||||||
file_format_ = str_util::Lowercase(file_format_);
|
file_format_ = str_util::Lowercase(file_format_);
|
||||||
OP_REQUIRES(context, file_format_ == "wav",
|
OP_REQUIRES(context, file_format_ == "wav",
|
||||||
@ -35,15 +35,15 @@ class EncodeAudioOp : public OpKernel {
|
|||||||
context, context->GetAttr("samples_per_second", &samples_per_second_));
|
context, context->GetAttr("samples_per_second", &samples_per_second_));
|
||||||
OP_REQUIRES(context, samples_per_second_ > 0,
|
OP_REQUIRES(context, samples_per_second_ > 0,
|
||||||
errors::InvalidArgument("samples_per_second must be > 0."));
|
errors::InvalidArgument("samples_per_second must be > 0."));
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(context,
|
||||||
context, context->GetAttr("bits_per_second", &bits_per_second_));
|
context->GetAttr("bits_per_second", &bits_per_second_));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
// Get and verify the input data.
|
// Get and verify the input data.
|
||||||
OP_REQUIRES(context, context->num_inputs() == 1,
|
OP_REQUIRES(
|
||||||
errors::InvalidArgument(
|
context, context->num_inputs() == 1,
|
||||||
"EncodeAudio requires exactly one input."));
|
errors::InvalidArgument("EncodeAudio requires exactly one input."));
|
||||||
const Tensor& contents = context->input(0);
|
const Tensor& contents = context->input(0);
|
||||||
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(contents.shape()),
|
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(contents.shape()),
|
||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
@ -88,6 +88,7 @@ REGISTER_OP("EncodeAudio")
|
|||||||
.Attr("file_format: string")
|
.Attr("file_format: string")
|
||||||
.Attr("samples_per_second: int")
|
.Attr("samples_per_second: int")
|
||||||
.Attr("bits_per_second: int = 192000")
|
.Attr("bits_per_second: int = 192000")
|
||||||
|
.SetShapeFn(shape_inference::ScalarShape)
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Processes a `Tensor` containing sampled audio with the number of channels
|
Processes a `Tensor` containing sampled audio with the number of channels
|
||||||
and length of the audio specified by the dimensions of the `Tensor`. The
|
and length of the audio specified by the dimensions of the `Tensor`. The
|
||||||
|
@ -67,7 +67,8 @@ def decode_audio(contents, file_format=None, samples_per_second=None,
|
|||||||
Returns:
|
Returns:
|
||||||
A rank 2 tensor that has time along dimension 0 and channels along
|
A rank 2 tensor that has time along dimension 0 and channels along
|
||||||
dimension 1. Dimension 0 will be `samples_per_second * length` wide, and
|
dimension 1. Dimension 0 will be `samples_per_second * length` wide, and
|
||||||
dimension 1 will be `channel_count` wide.
|
dimension 1 will be `channel_count` wide. If ffmpeg fails to decode the
|
||||||
|
audio then an empty tensor will be returned.
|
||||||
"""
|
"""
|
||||||
return gen_decode_audio_op_py.decode_audio(
|
return gen_decode_audio_op_py.decode_audio(
|
||||||
contents, file_format=file_format, samples_per_second=samples_per_second,
|
contents, file_format=file_format, samples_per_second=samples_per_second,
|
||||||
|
@ -14,6 +14,7 @@ py_library(
|
|||||||
srcs = [
|
srcs = [
|
||||||
"__init__.py",
|
"__init__.py",
|
||||||
"python/framework/__init__.py",
|
"python/framework/__init__.py",
|
||||||
|
"python/framework/checkpoint_utils.py",
|
||||||
"python/framework/deprecation.py",
|
"python/framework/deprecation.py",
|
||||||
"python/framework/tensor_util.py",
|
"python/framework/tensor_util.py",
|
||||||
"python/ops/__init__.py",
|
"python/ops/__init__.py",
|
||||||
@ -35,10 +36,19 @@ py_test(
|
|||||||
deps = ["//tensorflow:tensorflow_py"],
|
deps = ["//tensorflow:tensorflow_py"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "checkpoint_utils_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["python/framework/checkpoint_utils_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
tags = ["manual"], # http://b/30468735
|
||||||
|
deps = ["//tensorflow:tensorflow_py"],
|
||||||
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "ops_test",
|
name = "ops_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = glob(["python/ops/ops_test.py"]),
|
srcs = ["python/ops/ops_test.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = ["//tensorflow:tensorflow_py"],
|
deps = ["//tensorflow:tensorflow_py"],
|
||||||
)
|
)
|
||||||
@ -51,9 +61,16 @@ py_test(
|
|||||||
deps = ["//tensorflow:tensorflow_py"],
|
deps = ["//tensorflow:tensorflow_py"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "deprecation_test",
|
||||||
|
srcs = ["python/framework/deprecation_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = ["//tensorflow:tensorflow_py"],
|
||||||
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "tensor_util_test",
|
name = "tensor_util_test",
|
||||||
srcs = glob(["python/framework/tensor_util_test.py"]),
|
srcs = ["python/framework/tensor_util_test.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = ["//tensorflow:tensorflow_py"],
|
deps = ["//tensorflow:tensorflow_py"],
|
||||||
)
|
)
|
||||||
@ -61,7 +78,7 @@ py_test(
|
|||||||
py_test(
|
py_test(
|
||||||
name = "variables_test",
|
name = "variables_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = glob(["python/ops/variables_test.py"]),
|
srcs = ["python/ops/variables_test.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = ["//tensorflow:tensorflow_py"],
|
deps = ["//tensorflow:tensorflow_py"],
|
||||||
)
|
)
|
||||||
@ -74,6 +91,15 @@ py_test(
|
|||||||
deps = ["//tensorflow:tensorflow_py"],
|
deps = ["//tensorflow:tensorflow_py"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "sampling_ops_threading_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["python/ops/sampling_ops_threading_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
tags = ["notsan"],
|
||||||
|
deps = ["//tensorflow:tensorflow_py"],
|
||||||
|
)
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "all_files",
|
name = "all_files",
|
||||||
srcs = glob(
|
srcs = glob(
|
||||||
|
@ -30,6 +30,7 @@
|
|||||||
|
|
||||||
## Deprecation
|
## Deprecation
|
||||||
@@deprecated
|
@@deprecated
|
||||||
|
@@deprecated_arg_values
|
||||||
|
|
||||||
## Arg_Scope
|
## Arg_Scope
|
||||||
@@arg_scope
|
@@arg_scope
|
||||||
|
@ -19,5 +19,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
# pylint: disable=wildcard-import
|
# pylint: disable=wildcard-import
|
||||||
|
from tensorflow.contrib.framework.python.framework.checkpoint_utils import *
|
||||||
from tensorflow.contrib.framework.python.framework.deprecation import deprecated
|
from tensorflow.contrib.framework.python.framework.deprecation import deprecated
|
||||||
|
from tensorflow.contrib.framework.python.framework.deprecation import deprecated_arg_values
|
||||||
from tensorflow.contrib.framework.python.framework.tensor_util import *
|
from tensorflow.contrib.framework.python.framework.tensor_util import *
|
||||||
|
@ -0,0 +1,288 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Tools to work with checkpoints."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
|
from tensorflow.python.ops import gen_io_ops
|
||||||
|
from tensorflow.python.ops import state_ops
|
||||||
|
from tensorflow.python.ops import variable_scope as vs
|
||||||
|
from tensorflow.python.ops import variables
|
||||||
|
from tensorflow.python.platform import gfile
|
||||||
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
from tensorflow.python.training import saver
|
||||||
|
from tensorflow.python.training import training as train
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"load_checkpoint",
|
||||||
|
"load_variable",
|
||||||
|
"list_variables",
|
||||||
|
"init_from_checkpoint"]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_checkpoint_filename(filepattern):
|
||||||
|
"""Returns checkpoint filename given directory or specific filepattern."""
|
||||||
|
if gfile.IsDirectory(filepattern):
|
||||||
|
return saver.latest_checkpoint(filepattern)
|
||||||
|
return filepattern
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint(filepattern):
|
||||||
|
"""Returns CheckpointReader for latest checkpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filepattern: Directory with checkpoints file or path to checkpoint.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`CheckpointReader` object.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if checkpoint_dir doesn't have 'checkpoint' file or checkpoints.
|
||||||
|
"""
|
||||||
|
filename = _get_checkpoint_filename(filepattern)
|
||||||
|
if filename is None:
|
||||||
|
raise ValueError("Couldn't find 'checkpoint' file or checkpoints in "
|
||||||
|
"given directory %s" % filepattern)
|
||||||
|
return train.NewCheckpointReader(filename)
|
||||||
|
|
||||||
|
|
||||||
|
def load_variable(checkpoint_dir, name):
|
||||||
|
"""Returns a Tensor with the contents of the given variable in the checkpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint_dir: Directory with checkpoints file or path to checkpoint.
|
||||||
|
name: Name of the tensor to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tensor` object.
|
||||||
|
"""
|
||||||
|
# TODO(b/29227106): Fix this in the right place and remove this.
|
||||||
|
if name.endswith(":0"):
|
||||||
|
name = name[:-2]
|
||||||
|
reader = load_checkpoint(checkpoint_dir)
|
||||||
|
return reader.get_tensor(name)
|
||||||
|
|
||||||
|
|
||||||
|
def list_variables(checkpoint_dir):
|
||||||
|
"""Returns list of all variables in the latest checkpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint_dir: Directory with checkpoints file or path to checkpoint.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of tuples `(name, shape)`.
|
||||||
|
"""
|
||||||
|
reader = load_checkpoint(checkpoint_dir)
|
||||||
|
variable_map = reader.get_variable_to_shape_map()
|
||||||
|
names = sorted(variable_map.keys())
|
||||||
|
result = []
|
||||||
|
for name in names:
|
||||||
|
result.append((name, variable_map[name]))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
# Currently variable_scope doesn't provide very good APIs to access
|
||||||
|
# all variables under scope and retrieve and check existing scopes.
|
||||||
|
# TODO(ipolosukhin): Refactor variable_scope module to provide nicer APIs.
|
||||||
|
|
||||||
|
|
||||||
|
def _set_checkpoint_initializer(variable, file_pattern, tensor_name, slice_spec,
|
||||||
|
name="checkpoint_initializer"):
|
||||||
|
"""Sets variable initializer to assign op form value in checkpoint's tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
variable: `Variable` object.
|
||||||
|
file_pattern: string, where to load checkpoints from.
|
||||||
|
tensor_name: Name of the `Tensor` to load from checkpoint reader.
|
||||||
|
slice_spec: Slice specification for loading partitioned variables.
|
||||||
|
name: Name of the operation.
|
||||||
|
"""
|
||||||
|
base_type = variable.dtype.base_dtype
|
||||||
|
restore_op = gen_io_ops._restore_slice(
|
||||||
|
file_pattern,
|
||||||
|
tensor_name,
|
||||||
|
slice_spec,
|
||||||
|
base_type,
|
||||||
|
preferred_shard=-1,
|
||||||
|
name=name)
|
||||||
|
variable._initializer_op = state_ops.assign(variable, restore_op)
|
||||||
|
|
||||||
|
|
||||||
|
def _set_variable_or_list_initializer(variable_or_list, file_pattern,
|
||||||
|
tensor_name):
|
||||||
|
if isinstance(variable_or_list, (list, tuple)):
|
||||||
|
# A set of slices.
|
||||||
|
slice_name = None
|
||||||
|
for v in variable_or_list:
|
||||||
|
if slice_name is None:
|
||||||
|
slice_name = v._save_slice_info.full_name
|
||||||
|
elif slice_name != v._save_slice_info.full_name:
|
||||||
|
raise ValueError("Slices must all be from the same tensor: %s != %s" %
|
||||||
|
(slice_name, v._save_slice_info.full_name))
|
||||||
|
_set_checkpoint_initializer(v, file_pattern, tensor_name,
|
||||||
|
v._save_slice_info.spec)
|
||||||
|
else:
|
||||||
|
_set_checkpoint_initializer(variable_or_list, file_pattern, tensor_name, "")
|
||||||
|
|
||||||
|
|
||||||
|
def init_from_checkpoint(checkpoint_dir, assignment_map):
|
||||||
|
"""Using assingment map initializes current variables with loaded tensors.
|
||||||
|
|
||||||
|
Note: This overrides default initialization ops of specified variables and
|
||||||
|
redefines dtype.
|
||||||
|
|
||||||
|
Assignment map supports following syntax:
|
||||||
|
`'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
|
||||||
|
current `scope_name` from `checkpoint_scope_name` with matching variable
|
||||||
|
names.
|
||||||
|
`'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
|
||||||
|
will initalize `scope_name/variable_name` variable
|
||||||
|
from `checkpoint_scope_name/some_other_variable`.
|
||||||
|
`'scope_variable_name': variable` - will initialize given `tf.Variable`
|
||||||
|
object with variable from the checkpoint.
|
||||||
|
`'scope_variable_name': list(variable)` - will initialize list of
|
||||||
|
partitioned variables with variable from the checkpoint.
|
||||||
|
`'scope_name/': '/'` - will load all variables in current `scope_name` from
|
||||||
|
checkpoint's root (e.g. no scope).
|
||||||
|
|
||||||
|
Supports loading into partitioned variables, which are represented as
|
||||||
|
'<variable>/part_<part #>'.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
# Create variables.
|
||||||
|
with tf.variable_scope('test'):
|
||||||
|
m = tf.get_variable('my_var')
|
||||||
|
with tf.variable_scope('test2'):
|
||||||
|
var2 = tf.get_variable('my_var')
|
||||||
|
var3 = tf.get_variable(name="my1", shape=[100, 100],
|
||||||
|
partitioner=lambda shape, dtype: [5, 1])
|
||||||
|
...
|
||||||
|
# Specify which variables to intialize from checkpoint.
|
||||||
|
init_from_checkpoint(checkpoint_dir, {
|
||||||
|
'some_var': 'test/my_var',
|
||||||
|
'some_scope/': 'test2/'})
|
||||||
|
...
|
||||||
|
# Or use `Variable` objects to identify what to initialize.
|
||||||
|
init_from_checkpoint(checkpoint_dir, {
|
||||||
|
'some_scope/var2': var2,
|
||||||
|
})
|
||||||
|
# Initialize partitioned variables
|
||||||
|
init_from_checkpoint(checkpoint_dir, {
|
||||||
|
'some_var_from_ckpt': 'part_var',
|
||||||
|
})
|
||||||
|
# Or specifying the list of `Variable` objects.
|
||||||
|
init_from_checkpoint(checkpoint_dir, {
|
||||||
|
'some_var_from_ckpt': var3._get_variable_list(),
|
||||||
|
})
|
||||||
|
...
|
||||||
|
# Initialize variables as usual.
|
||||||
|
session.run(tf.get_all_variables())
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint_dir: Directory with checkpoints file or path to checkpoint.
|
||||||
|
assignment_map: Dict, where keys are names of the variables in the
|
||||||
|
checkpoint and values are current variables or names of current variables
|
||||||
|
(in default graph).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
tf.errors.OpError: If missing checkpoints or tensors in checkpoints.
|
||||||
|
ValueError: If missing variables in current graph.
|
||||||
|
"""
|
||||||
|
filepattern = _get_checkpoint_filename(checkpoint_dir)
|
||||||
|
reader = load_checkpoint(checkpoint_dir)
|
||||||
|
variable_map = reader.get_variable_to_shape_map()
|
||||||
|
for tensor_name_in_ckpt, current_var_or_name in six.iteritems(assignment_map):
|
||||||
|
var = None
|
||||||
|
# Check if this is Variable object or list of Variable objects (in case of
|
||||||
|
# partitioned variables).
|
||||||
|
is_var = lambda x: isinstance(x, variables.Variable)
|
||||||
|
if is_var(current_var_or_name) or (
|
||||||
|
isinstance(current_var_or_name, list)
|
||||||
|
and all(is_var(v) for v in current_var_or_name)):
|
||||||
|
var = current_var_or_name
|
||||||
|
else:
|
||||||
|
var_scope = vs._get_default_variable_store()
|
||||||
|
# Check if this variable is in var_store.
|
||||||
|
var = var_scope._vars.get(current_var_or_name, None)
|
||||||
|
# Also check if variable is partitioned as list.
|
||||||
|
if var is None:
|
||||||
|
if current_var_or_name + "/part_0" in var_scope._vars:
|
||||||
|
var = []
|
||||||
|
i = 0
|
||||||
|
while current_var_or_name + "/part_%d" % i in var_scope._vars:
|
||||||
|
var.append(var_scope._vars[current_var_or_name + "/part_%d" % i])
|
||||||
|
i += 1
|
||||||
|
if var is not None:
|
||||||
|
# If 1 to 1 mapping was provided, find variable in the checkpoint.
|
||||||
|
if tensor_name_in_ckpt not in variable_map:
|
||||||
|
raise ValueError("Tensor %s is not found in %s checkpoint" % (
|
||||||
|
tensor_name_in_ckpt, checkpoint_dir
|
||||||
|
))
|
||||||
|
if is_var(var):
|
||||||
|
# Additional at-call-time checks.
|
||||||
|
if not var.get_shape().is_compatible_with(
|
||||||
|
variable_map[tensor_name_in_ckpt]):
|
||||||
|
raise ValueError(
|
||||||
|
"Shape of variable %s (%s) doesn't match with shape of "
|
||||||
|
"tensor %s (%s) from checkpoint reader." % (
|
||||||
|
var.name, str(var.get_shape()),
|
||||||
|
tensor_name_in_ckpt, str(variable_map[tensor_name_in_ckpt])
|
||||||
|
))
|
||||||
|
var_name = var.name
|
||||||
|
else:
|
||||||
|
var_name = ",".join([v.name for v in var])
|
||||||
|
_set_variable_or_list_initializer(var, filepattern, tensor_name_in_ckpt)
|
||||||
|
logging.info("Initialize variable %s from checkpoint %s with %s" % (
|
||||||
|
var_name, checkpoint_dir, tensor_name_in_ckpt
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
scopes = ""
|
||||||
|
# TODO(vihanjain): Support list of 'current_var_or_name' here.
|
||||||
|
if "/" in current_var_or_name:
|
||||||
|
scopes = current_var_or_name[:current_var_or_name.rindex("/")]
|
||||||
|
if not tensor_name_in_ckpt.endswith("/"):
|
||||||
|
raise ValueError(
|
||||||
|
"Assignment map with scope only name {} should map to scope only "
|
||||||
|
"{}. Should be 'scope/': 'other_scope/'.".format(
|
||||||
|
scopes, tensor_name_in_ckpt))
|
||||||
|
# If scope to scope mapping was provided, find all variables in the scope.
|
||||||
|
for var_name in var_scope._vars:
|
||||||
|
if var_name.startswith(scopes):
|
||||||
|
# Lookup name with specified prefix and suffix from current variable.
|
||||||
|
# If tensor_name given is '/' (root), don't use it for full name.
|
||||||
|
if tensor_name_in_ckpt != "/":
|
||||||
|
full_tensor_name = tensor_name_in_ckpt + var_name[len(scopes) + 1:]
|
||||||
|
else:
|
||||||
|
full_tensor_name = var_name[len(scopes) + 1:]
|
||||||
|
if full_tensor_name not in variable_map:
|
||||||
|
raise ValueError(
|
||||||
|
"Tensor %s (%s in %s) is not found in %s checkpoint" % (
|
||||||
|
full_tensor_name, var_name[len(scopes) + 1:],
|
||||||
|
tensor_name_in_ckpt, checkpoint_dir
|
||||||
|
))
|
||||||
|
var = var_scope._vars[var_name]
|
||||||
|
_set_variable_or_list_initializer(var, filepattern, full_tensor_name)
|
||||||
|
logging.info("Initialize variable %s from checkpoint %s with %s" % (
|
||||||
|
var_name, checkpoint_dir, full_tensor_name
|
||||||
|
))
|
||||||
|
# pylint: enable=protected-access
|
@ -23,8 +23,6 @@ import os
|
|||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow.contrib.learn.python.learn.utils import checkpoints
|
|
||||||
|
|
||||||
|
|
||||||
def _create_checkpoints(sess, checkpoint_dir):
|
def _create_checkpoints(sess, checkpoint_dir):
|
||||||
checkpoint_prefix = os.path.join(checkpoint_dir, "model")
|
checkpoint_prefix = os.path.join(checkpoint_dir, "model")
|
||||||
@ -45,15 +43,14 @@ def _create_checkpoints(sess, checkpoint_dir):
|
|||||||
def _create_partition_checkpoints(sess, checkpoint_dir):
|
def _create_partition_checkpoints(sess, checkpoint_dir):
|
||||||
checkpoint_prefix = os.path.join(checkpoint_dir, "model")
|
checkpoint_prefix = os.path.join(checkpoint_dir, "model")
|
||||||
checkpoint_state_name = "checkpoint"
|
checkpoint_state_name = "checkpoint"
|
||||||
# TODO(ipolosukhin): Enable this when get_variable partitioning works.
|
v1 = tf.get_variable(
|
||||||
# v1 = tf.get_variable("var1", [100, 100],
|
name="var1",
|
||||||
# partitioner=tf.variable_axis_size_partitioner(axis=0,
|
shape=[100, 100],
|
||||||
# max_shard_bytes=512))
|
initializer=tf.truncated_normal_initializer(0.5),
|
||||||
v1 = tf.create_partitioned_variables(
|
partitioner=tf.min_max_variable_partitioner(max_partitions=5, axis=0,
|
||||||
shape=[100, 100], slicing=[5, 1], name="var1",
|
min_slice_size=8 << 10))
|
||||||
initializer=tf.truncated_normal_initializer(0.5))
|
|
||||||
sess.run(tf.initialize_all_variables())
|
sess.run(tf.initialize_all_variables())
|
||||||
v1_value = sess.run(v1)
|
v1_value = sess.run(v1._get_variable_list())
|
||||||
saver = tf.train.Saver()
|
saver = tf.train.Saver()
|
||||||
saver.save(sess, checkpoint_prefix, global_step=0,
|
saver.save(sess, checkpoint_prefix, global_step=0,
|
||||||
latest_filename=checkpoint_state_name)
|
latest_filename=checkpoint_state_name)
|
||||||
@ -65,30 +62,36 @@ class CheckpointsTest(tf.test.TestCase):
|
|||||||
def testNoCheckpoints(self):
|
def testNoCheckpoints(self):
|
||||||
checkpoint_dir = self.get_temp_dir() + "/no_checkpoints"
|
checkpoint_dir = self.get_temp_dir() + "/no_checkpoints"
|
||||||
with self.assertRaises(tf.errors.OpError):
|
with self.assertRaises(tf.errors.OpError):
|
||||||
self.assertAllEqual(checkpoints.load_variable(checkpoint_dir, "var1"), [])
|
self.assertAllEqual(tf.contrib.framework.load_variable(
|
||||||
|
checkpoint_dir, "var1"), [])
|
||||||
|
|
||||||
def testNoTensor(self):
|
def testNoTensor(self):
|
||||||
checkpoint_dir = self.get_temp_dir()
|
checkpoint_dir = self.get_temp_dir()
|
||||||
with self.test_session() as session:
|
with self.test_session() as session:
|
||||||
_, _, _, _ = _create_checkpoints(session, checkpoint_dir)
|
_, _, _, _ = _create_checkpoints(session, checkpoint_dir)
|
||||||
with self.assertRaises(tf.errors.OpError):
|
with self.assertRaises(tf.errors.OpError):
|
||||||
self.assertAllEqual(checkpoints.load_variable(checkpoint_dir, "var5"), [])
|
self.assertAllEqual(tf.contrib.framework.load_variable(
|
||||||
|
checkpoint_dir, "var5"), [])
|
||||||
|
|
||||||
def testGetTensor(self):
|
def testGetTensor(self):
|
||||||
checkpoint_dir = self.get_temp_dir()
|
checkpoint_dir = self.get_temp_dir()
|
||||||
with self.test_session() as session:
|
with self.test_session() as session:
|
||||||
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
|
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
|
||||||
self.assertAllEqual(checkpoints.load_variable(checkpoint_dir, "var1"), v1)
|
self.assertAllEqual(tf.contrib.framework.load_variable(
|
||||||
self.assertAllEqual(checkpoints.load_variable(checkpoint_dir, "var2"), v2)
|
checkpoint_dir, "var1"), v1)
|
||||||
self.assertAllEqual(checkpoints.load_variable(checkpoint_dir, "var3"), v3)
|
self.assertAllEqual(tf.contrib.framework.load_variable(
|
||||||
|
checkpoint_dir, "var2"), v2)
|
||||||
|
self.assertAllEqual(tf.contrib.framework.load_variable(
|
||||||
|
checkpoint_dir, "var3"), v3)
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
checkpoints.load_variable(checkpoint_dir, "useful_scope/var4"), v4)
|
tf.contrib.framework.load_variable(
|
||||||
|
checkpoint_dir, "useful_scope/var4"), v4)
|
||||||
|
|
||||||
def testGetAllVariables(self):
|
def testGetAllVariables(self):
|
||||||
checkpoint_dir = self.get_temp_dir()
|
checkpoint_dir = self.get_temp_dir()
|
||||||
with self.test_session() as session:
|
with self.test_session() as session:
|
||||||
_create_checkpoints(session, checkpoint_dir)
|
_create_checkpoints(session, checkpoint_dir)
|
||||||
self.assertEqual(checkpoints.list_variables(checkpoint_dir),
|
self.assertEqual(tf.contrib.framework.list_variables(checkpoint_dir),
|
||||||
[("useful_scope/var4", [9, 9]),
|
[("useful_scope/var4", [9, 9]),
|
||||||
("var1", [1, 10]),
|
("var1", [1, 10]),
|
||||||
("var2", [10, 10]),
|
("var2", [10, 10]),
|
||||||
@ -110,13 +113,13 @@ class CheckpointsTest(tf.test.TestCase):
|
|||||||
my4 = tf.get_variable("var4", [9, 9])
|
my4 = tf.get_variable("var4", [9, 9])
|
||||||
my3 = tf.get_variable("my3", [100, 100])
|
my3 = tf.get_variable("my3", [100, 100])
|
||||||
|
|
||||||
checkpoints.init_from_checkpoint(checkpoint_dir, {
|
tf.contrib.framework.init_from_checkpoint(checkpoint_dir, {
|
||||||
"some_scope/my1": "var1",
|
"var1": "some_scope/my1",
|
||||||
"some_scope/some_other_scope/other_useful_scope/": "useful_scope/",
|
"useful_scope/": "some_scope/some_other_scope/other_useful_scope/",
|
||||||
})
|
})
|
||||||
checkpoints.init_from_checkpoint(checkpoint_dir, {
|
tf.contrib.framework.init_from_checkpoint(checkpoint_dir, {
|
||||||
"some_scope/some_other_scope/my2": "var2",
|
"var2": "some_scope/some_other_scope/my2",
|
||||||
my3: "var3",
|
"var3": my3,
|
||||||
})
|
})
|
||||||
|
|
||||||
session.run(tf.initialize_all_variables())
|
session.run(tf.initialize_all_variables())
|
||||||
@ -143,8 +146,8 @@ class CheckpointsTest(tf.test.TestCase):
|
|||||||
with tf.variable_scope("useful_scope"):
|
with tf.variable_scope("useful_scope"):
|
||||||
my4 = tf.get_variable("var4", [9, 9])
|
my4 = tf.get_variable("var4", [9, 9])
|
||||||
|
|
||||||
checkpoints.init_from_checkpoint(checkpoint_dir, {
|
tf.contrib.framework.init_from_checkpoint(checkpoint_dir, {
|
||||||
"some_scope/": "/",
|
"/": "some_scope/",
|
||||||
})
|
})
|
||||||
|
|
||||||
session.run(tf.initialize_all_variables())
|
session.run(tf.initialize_all_variables())
|
||||||
@ -162,23 +165,40 @@ class CheckpointsTest(tf.test.TestCase):
|
|||||||
with tf.Graph().as_default() as g:
|
with tf.Graph().as_default() as g:
|
||||||
with self.test_session(graph=g) as session:
|
with self.test_session(graph=g) as session:
|
||||||
with tf.variable_scope("some_scope"):
|
with tf.variable_scope("some_scope"):
|
||||||
# TODO(ipolosukhin): Enable this when get_variable partitioning works.
|
my1 = tf.get_variable(
|
||||||
# Currently get_variable with partitioner doesn't return Variable,
|
name="my1",
|
||||||
# but returns a concat op.
|
shape=[100, 100],
|
||||||
# my1 = tf.get_variable(
|
initializer=tf.truncated_normal_initializer(0.5),
|
||||||
# "my1", [100, 100],
|
partitioner=tf.min_max_variable_partitioner(
|
||||||
# partitioner=tf.variable_axis_size_partitioner(axis=0,
|
max_partitions=5, axis=0, min_slice_size=8 << 10))
|
||||||
# max_shard_bytes=100))
|
my1_var_list = my1._get_variable_list()
|
||||||
my1 = tf.create_partitioned_variables(
|
|
||||||
shape=[100, 100], slicing=[5, 1], name="my1",
|
|
||||||
initializer=tf.truncated_normal_initializer(0.5))
|
|
||||||
|
|
||||||
checkpoints.init_from_checkpoint(checkpoint_dir, {
|
tf.contrib.framework.init_from_checkpoint(checkpoint_dir, {
|
||||||
"some_scope/my1": "var1",
|
"var1": "some_scope/my1",
|
||||||
})
|
})
|
||||||
|
|
||||||
session.run(tf.initialize_all_variables())
|
session.run(tf.initialize_all_variables())
|
||||||
my1_values = session.run(my1)
|
my1_values = session.run(my1_var_list)
|
||||||
|
self.assertAllEqual(my1_values, v1)
|
||||||
|
|
||||||
|
# New graph and session.
|
||||||
|
with tf.Graph().as_default() as g:
|
||||||
|
with self.test_session(graph=g) as session:
|
||||||
|
with tf.variable_scope("some_scope"):
|
||||||
|
my1 = tf.get_variable(
|
||||||
|
name="my1",
|
||||||
|
shape=[100, 100],
|
||||||
|
initializer=tf.truncated_normal_initializer(0.5),
|
||||||
|
partitioner=tf.min_max_variable_partitioner(
|
||||||
|
max_partitions=5, axis=0, min_slice_size=8 << 10))
|
||||||
|
my1_var_list = my1._get_variable_list()
|
||||||
|
|
||||||
|
tf.contrib.framework.init_from_checkpoint(checkpoint_dir, {
|
||||||
|
"var1": my1_var_list,
|
||||||
|
})
|
||||||
|
|
||||||
|
session.run(tf.initialize_all_variables())
|
||||||
|
my1_values = session.run(my1_var_list)
|
||||||
self.assertAllEqual(my1_values, v1)
|
self.assertAllEqual(my1_values, v1)
|
||||||
|
|
||||||
def testInitFromCheckpointMissing(self):
|
def testInitFromCheckpointMissing(self):
|
||||||
@ -196,33 +216,33 @@ class CheckpointsTest(tf.test.TestCase):
|
|||||||
|
|
||||||
# No directory.
|
# No directory.
|
||||||
with self.assertRaises(tf.errors.OpError):
|
with self.assertRaises(tf.errors.OpError):
|
||||||
checkpoints.init_from_checkpoint("no_dir", {
|
tf.contrib.framework.init_from_checkpoint("no_dir", {
|
||||||
"some_scope/my1": "var1"})
|
"var1": "some_scope/my1"})
|
||||||
|
|
||||||
# No variable in checkpoint.
|
# No variable in checkpoint.
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
checkpoints.init_from_checkpoint(checkpoint_dir, {
|
tf.contrib.framework.init_from_checkpoint(checkpoint_dir, {
|
||||||
"some_scope/my1": "no_var"})
|
"no_var": "some_scope/my1"})
|
||||||
|
|
||||||
# No variable in the graph.
|
# No variable in the graph.
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
checkpoints.init_from_checkpoint(checkpoint_dir, {
|
tf.contrib.framework.init_from_checkpoint(checkpoint_dir, {
|
||||||
"some_scope/no_var": "var3"})
|
"var3": "some_scope/no_var"})
|
||||||
|
|
||||||
# Shape mismatch.
|
# Shape mismatch.
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
checkpoints.init_from_checkpoint(checkpoint_dir, {
|
tf.contrib.framework.init_from_checkpoint(checkpoint_dir, {
|
||||||
"some_scope/my1": "var1"})
|
"var1": "some_scope/my1"})
|
||||||
|
|
||||||
# Variable 'my1' and 'my2' are missing in given checkpoint scope.
|
# Variable 'my1' and 'my2' are missing in given checkpoint scope.
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
checkpoints.init_from_checkpoint(checkpoint_dir, {
|
tf.contrib.framework.init_from_checkpoint(checkpoint_dir, {
|
||||||
"some_scope/": "useful_scope/"})
|
"useful_scope/": "some_scope/"})
|
||||||
|
|
||||||
# Mapping is not to scope name.
|
# Mapping is not to scope name.
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
checkpoints.init_from_checkpoint(checkpoint_dir, {
|
tf.contrib.framework.init_from_checkpoint(checkpoint_dir, {
|
||||||
"some_scope/": "useful_scope"})
|
"useful_scope": "some_scope/"})
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tf.test.main()
|
tf.test.main()
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import inspect
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
@ -34,43 +36,77 @@ def _get_qualified_name(function):
|
|||||||
return function.__name__
|
return function.__name__
|
||||||
|
|
||||||
|
|
||||||
def _add_deprecation_to_docstring(doc, date, instructions):
|
def _add_deprecation_to_docstring(
|
||||||
|
doc, instructions, no_doc_str, suffix_str, notice):
|
||||||
"""Adds a deprecation notice to a docstring."""
|
"""Adds a deprecation notice to a docstring."""
|
||||||
lines = doc.splitlines()
|
if not doc:
|
||||||
|
lines = [no_doc_str]
|
||||||
|
else:
|
||||||
|
lines = doc.splitlines()
|
||||||
|
lines[0] += ' ' + suffix_str
|
||||||
|
|
||||||
lines[0] += ' (deprecated)'
|
notice = [''] + notice + [instructions]
|
||||||
|
|
||||||
notice = [
|
|
||||||
'',
|
|
||||||
'THIS FUNCTION IS DEPRECATED. It will be removed after %s.' % date,
|
|
||||||
'Instructions for updating:',
|
|
||||||
'%s' % instructions,
|
|
||||||
]
|
|
||||||
|
|
||||||
if len(lines) > 1:
|
if len(lines) > 1:
|
||||||
# Make sure that we keep our distance from the main body
|
# Make sure that we keep our distance from the main body
|
||||||
if lines[1].strip():
|
if lines[1].strip():
|
||||||
notice += ['']
|
notice.append('')
|
||||||
|
|
||||||
lines = [lines[0]] + notice + lines[1:]
|
lines[1:1] = notice
|
||||||
else:
|
else:
|
||||||
lines += notice
|
lines += notice
|
||||||
|
|
||||||
return '\n'.join(lines)
|
return '\n'.join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def _add_deprecated_function_notice_to_docstring(doc, date, instructions):
|
||||||
|
"""Adds a deprecation notice to a docstring for deprecated functions."""
|
||||||
|
return _add_deprecation_to_docstring(
|
||||||
|
doc, instructions,
|
||||||
|
'DEPRECATED FUNCTION',
|
||||||
|
'(deprecated)', [
|
||||||
|
'THIS FUNCTION IS DEPRECATED. It will be removed after %s.' % date,
|
||||||
|
'Instructions for updating:'])
|
||||||
|
|
||||||
|
|
||||||
|
def _add_deprecated_arg_notice_to_docstring(doc, date, instructions):
|
||||||
|
"""Adds a deprecation notice to a docstring for deprecated arguments."""
|
||||||
|
return _add_deprecation_to_docstring(
|
||||||
|
doc, instructions,
|
||||||
|
'DEPRECATED FUNCTION ARGUMENTS',
|
||||||
|
'(deprecated arguments)', [
|
||||||
|
'SOME ARGUMENTS ARE DEPRECATED. '
|
||||||
|
'They will be removed after %s.' % date,
|
||||||
|
'Instructions for updating:'])
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_deprecation_args(date, instructions):
|
||||||
|
if not date:
|
||||||
|
raise ValueError('Tell us what date this will be deprecated!')
|
||||||
|
if not re.match(r'20\d\d-[01]\d-[0123]\d', date):
|
||||||
|
raise ValueError('Date must be YYYY-MM-DD.')
|
||||||
|
if not instructions:
|
||||||
|
raise ValueError('Don\'t deprecate things without conversion instructions!')
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_callable(func, decorator_name):
|
||||||
|
if not hasattr(func, '__call__'):
|
||||||
|
raise ValueError(
|
||||||
|
'%s is not a function. If this is a property, '
|
||||||
|
'apply @%s after @property.' % (func, decorator_name))
|
||||||
|
|
||||||
|
|
||||||
def deprecated(date, instructions):
|
def deprecated(date, instructions):
|
||||||
"""Decorator for marking functions or methods deprecated.
|
"""Decorator for marking functions or methods deprecated.
|
||||||
|
|
||||||
This decorator adds a deprecation warning to a function's docstring. It has
|
This decorator logs a deprecation warning whenever the decorated function is
|
||||||
the following format:
|
called. It has the following format:
|
||||||
|
|
||||||
<function> (from <module>) is deprecated and will be removed after <date>.
|
<function> (from <module>) is deprecated and will be removed after <date>.
|
||||||
Instructions for updating:
|
Instructions for updating:
|
||||||
<instructions>
|
<instructions>
|
||||||
|
|
||||||
whenever the decorated function is called. <function> will include the class
|
<function> will include the class name if it is a method.
|
||||||
name if it is a method.
|
|
||||||
|
|
||||||
It also edits the docstring of the function: ' (deprecated)' is appended
|
It also edits the docstring of the function: ' (deprecated)' is appended
|
||||||
to the first line of the docstring and a deprecation notice is prepended
|
to the first line of the docstring and a deprecation notice is prepended
|
||||||
@ -88,24 +124,73 @@ def deprecated(date, instructions):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If date is not in ISO 8601 format, or instructions are empty.
|
ValueError: If date is not in ISO 8601 format, or instructions are empty.
|
||||||
"""
|
"""
|
||||||
if not date:
|
_validate_deprecation_args(date, instructions)
|
||||||
raise ValueError('Tell us what date this will be deprecated!')
|
|
||||||
if not re.match(r'20\d\d-[01]\d-[0123]\d', date):
|
|
||||||
raise ValueError('Date must be YYYY-MM-DD.')
|
|
||||||
if not instructions:
|
|
||||||
raise ValueError('Don\'t deprecate things without conversion instructions!')
|
|
||||||
|
|
||||||
def deprecated_wrapper(func):
|
def deprecated_wrapper(func):
|
||||||
"""Deprecation wrapper."""
|
"""Deprecation wrapper."""
|
||||||
|
_validate_callable(func, 'deprecated')
|
||||||
|
@functools.wraps(func)
|
||||||
def new_func(*args, **kwargs):
|
def new_func(*args, **kwargs):
|
||||||
logging.warn('%s (from %s) is deprecated and will be removed after %s.\n'
|
logging.warning(
|
||||||
'Instructions for updating:\n%s',
|
'%s (from %s) is deprecated and will be removed after %s.\n'
|
||||||
_get_qualified_name(func), func.__module__,
|
'Instructions for updating:\n%s',
|
||||||
date, instructions)
|
_get_qualified_name(func), func.__module__, date, instructions)
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
new_func.__name__ = func.__name__
|
new_func.__doc__ = _add_deprecated_function_notice_to_docstring(
|
||||||
new_func.__doc__ = _add_deprecation_to_docstring(func.__doc__, date,
|
func.__doc__, date, instructions)
|
||||||
instructions)
|
return new_func
|
||||||
new_func.__dict__.update(func.__dict__)
|
return deprecated_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def deprecated_arg_values(date, instructions, **deprecated_kwargs):
|
||||||
|
"""Decorator for marking specific function argument values as deprecated.
|
||||||
|
|
||||||
|
This decorator logs a deprecation warning whenever the decorated function is
|
||||||
|
called with the deprecated argument values. It has the following format:
|
||||||
|
|
||||||
|
Calling <function> (from <module>) with <arg>=<value> is deprecated and
|
||||||
|
will be removed after <date>. Instructions for updating:
|
||||||
|
<instructions>
|
||||||
|
|
||||||
|
<function> will include the class name if it is a method.
|
||||||
|
|
||||||
|
It also edits the docstring of the function: ' (deprecated arguments)' is
|
||||||
|
appended to the first line of the docstring and a deprecation notice is
|
||||||
|
prepended to the rest of the docstring.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
date: String. The date the function is scheduled to be removed. Must be
|
||||||
|
ISO 8601 (YYYY-MM-DD).
|
||||||
|
instructions: String. Instructions on how to update code using the
|
||||||
|
deprecated function.
|
||||||
|
**deprecated_kwargs: The deprecated argument values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Decorated function or method.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If date is not in ISO 8601 format, or instructions are empty.
|
||||||
|
"""
|
||||||
|
_validate_deprecation_args(date, instructions)
|
||||||
|
if not deprecated_kwargs:
|
||||||
|
raise ValueError('Specify which argument values are deprecated.')
|
||||||
|
|
||||||
|
def deprecated_wrapper(func):
|
||||||
|
"""Deprecation decorator."""
|
||||||
|
_validate_callable(func, 'deprecated_arg_values')
|
||||||
|
@functools.wraps(func)
|
||||||
|
def new_func(*args, **kwargs):
|
||||||
|
"""Deprecation wrapper."""
|
||||||
|
named_args = inspect.getcallargs(func, *args, **kwargs)
|
||||||
|
for arg_name, arg_value in deprecated_kwargs.items():
|
||||||
|
if arg_name in named_args and named_args[arg_name] == arg_value:
|
||||||
|
logging.warning(
|
||||||
|
'Calling %s (from %s) with %s=%s is deprecated and will be '
|
||||||
|
'removed after %s.\nInstructions for updating:\n%s',
|
||||||
|
_get_qualified_name(func), func.__module__,
|
||||||
|
arg_name, arg_value, date, instructions)
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
new_func.__doc__ = _add_deprecated_arg_notice_to_docstring(
|
||||||
|
func.__doc__, date, instructions)
|
||||||
return new_func
|
return new_func
|
||||||
return deprecated_wrapper
|
return deprecated_wrapper
|
||||||
|
@ -0,0 +1,488 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""tensor_util tests."""
|
||||||
|
|
||||||
|
# pylint: disable=unused-import
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.contrib.framework.python.framework import deprecation
|
||||||
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
|
||||||
|
|
||||||
|
class DeprecationTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def _assert_subset(self, expected_subset, actual_set):
|
||||||
|
self.assertTrue(
|
||||||
|
actual_set.issuperset(expected_subset),
|
||||||
|
msg="%s is not a superset of %s." % (actual_set, expected_subset))
|
||||||
|
|
||||||
|
def test_deprecated_illegal_args(self):
|
||||||
|
instructions = "This is how you update..."
|
||||||
|
with self.assertRaisesRegexp(ValueError, "date"):
|
||||||
|
deprecation.deprecated(None, instructions)
|
||||||
|
with self.assertRaisesRegexp(ValueError, "date"):
|
||||||
|
deprecation.deprecated("", instructions)
|
||||||
|
with self.assertRaisesRegexp(ValueError, "YYYY-MM-DD"):
|
||||||
|
deprecation.deprecated("07-04-2016", instructions)
|
||||||
|
date = "2016-07-04"
|
||||||
|
with self.assertRaisesRegexp(ValueError, "instructions"):
|
||||||
|
deprecation.deprecated(date, None)
|
||||||
|
with self.assertRaisesRegexp(ValueError, "instructions"):
|
||||||
|
deprecation.deprecated(date, "")
|
||||||
|
|
||||||
|
@tf.test.mock.patch.object(logging, "warning", autospec=True)
|
||||||
|
def test_static_fn_with_doc(self, mock_warning):
|
||||||
|
date = "2016-07-04"
|
||||||
|
instructions = "This is how you update..."
|
||||||
|
|
||||||
|
@deprecation.deprecated(date, instructions)
|
||||||
|
def _fn(arg0, arg1):
|
||||||
|
"""fn doc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arg0: Arg 0.
|
||||||
|
arg1: Arg 1.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sum of args.
|
||||||
|
"""
|
||||||
|
return arg0 + arg1
|
||||||
|
|
||||||
|
# Assert function docs are properly updated.
|
||||||
|
self.assertEqual("_fn", _fn.__name__)
|
||||||
|
self.assertEqual(
|
||||||
|
"fn doc. (deprecated)"
|
||||||
|
"\n"
|
||||||
|
"\nTHIS FUNCTION IS DEPRECATED. It will be removed after %s."
|
||||||
|
"\nInstructions for updating:\n%s"
|
||||||
|
"\n"
|
||||||
|
"\n Args:"
|
||||||
|
"\n arg0: Arg 0."
|
||||||
|
"\n arg1: Arg 1."
|
||||||
|
"\n"
|
||||||
|
"\n Returns:"
|
||||||
|
"\n Sum of args."
|
||||||
|
"\n " % (date, instructions),
|
||||||
|
_fn.__doc__)
|
||||||
|
|
||||||
|
# Assert calling new fn issues log warning.
|
||||||
|
self.assertEqual(3, _fn(1, 2))
|
||||||
|
self.assertEqual(1, mock_warning.call_count)
|
||||||
|
(args, _) = mock_warning.call_args
|
||||||
|
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
|
||||||
|
self._assert_subset(set([date, instructions]), set(args[1:]))
|
||||||
|
|
||||||
|
@tf.test.mock.patch.object(logging, "warning", autospec=True)
|
||||||
|
def test_static_fn_with_one_line_doc(self, mock_warning):
|
||||||
|
date = "2016-07-04"
|
||||||
|
instructions = "This is how you update..."
|
||||||
|
|
||||||
|
@deprecation.deprecated(date, instructions)
|
||||||
|
def _fn(arg0, arg1):
|
||||||
|
"""fn doc."""
|
||||||
|
return arg0 + arg1
|
||||||
|
|
||||||
|
# Assert function docs are properly updated.
|
||||||
|
self.assertEqual("_fn", _fn.__name__)
|
||||||
|
self.assertEqual(
|
||||||
|
"fn doc. (deprecated)"
|
||||||
|
"\n"
|
||||||
|
"\nTHIS FUNCTION IS DEPRECATED. It will be removed after %s."
|
||||||
|
"\nInstructions for updating:\n%s" % (date, instructions),
|
||||||
|
_fn.__doc__)
|
||||||
|
|
||||||
|
# Assert calling new fn issues log warning.
|
||||||
|
self.assertEqual(3, _fn(1, 2))
|
||||||
|
self.assertEqual(1, mock_warning.call_count)
|
||||||
|
(args, _) = mock_warning.call_args
|
||||||
|
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
|
||||||
|
self._assert_subset(set([date, instructions]), set(args[1:]))
|
||||||
|
|
||||||
|
@tf.test.mock.patch.object(logging, "warning", autospec=True)
|
||||||
|
def test_static_fn_no_doc(self, mock_warning):
|
||||||
|
date = "2016-07-04"
|
||||||
|
instructions = "This is how you update..."
|
||||||
|
|
||||||
|
@deprecation.deprecated(date, instructions)
|
||||||
|
def _fn(arg0, arg1):
|
||||||
|
return arg0 + arg1
|
||||||
|
|
||||||
|
# Assert function docs are properly updated.
|
||||||
|
self.assertEqual("_fn", _fn.__name__)
|
||||||
|
self.assertEqual(
|
||||||
|
"DEPRECATED FUNCTION"
|
||||||
|
"\n"
|
||||||
|
"\nTHIS FUNCTION IS DEPRECATED. It will be removed after %s."
|
||||||
|
"\nInstructions for updating:"
|
||||||
|
"\n%s" % (date, instructions),
|
||||||
|
_fn.__doc__)
|
||||||
|
|
||||||
|
# Assert calling new fn issues log warning.
|
||||||
|
self.assertEqual(3, _fn(1, 2))
|
||||||
|
self.assertEqual(1, mock_warning.call_count)
|
||||||
|
(args, _) = mock_warning.call_args
|
||||||
|
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
|
||||||
|
self._assert_subset(set([date, instructions]), set(args[1:]))
|
||||||
|
|
||||||
|
@tf.test.mock.patch.object(logging, "warning", autospec=True)
|
||||||
|
def test_instance_fn_with_doc(self, mock_warning):
|
||||||
|
date = "2016-07-04"
|
||||||
|
instructions = "This is how you update..."
|
||||||
|
|
||||||
|
class _Object(object):
|
||||||
|
|
||||||
|
def __init(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@deprecation.deprecated(date, instructions)
|
||||||
|
def _fn(self, arg0, arg1):
|
||||||
|
"""fn doc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arg0: Arg 0.
|
||||||
|
arg1: Arg 1.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sum of args.
|
||||||
|
"""
|
||||||
|
return arg0 + arg1
|
||||||
|
|
||||||
|
# Assert function docs are properly updated.
|
||||||
|
self.assertEqual(
|
||||||
|
"fn doc. (deprecated)"
|
||||||
|
"\n"
|
||||||
|
"\nTHIS FUNCTION IS DEPRECATED. It will be removed after %s."
|
||||||
|
"\nInstructions for updating:\n%s"
|
||||||
|
"\n"
|
||||||
|
"\n Args:"
|
||||||
|
"\n arg0: Arg 0."
|
||||||
|
"\n arg1: Arg 1."
|
||||||
|
"\n"
|
||||||
|
"\n Returns:"
|
||||||
|
"\n Sum of args."
|
||||||
|
"\n " % (date, instructions),
|
||||||
|
getattr(_Object, "_fn").__doc__)
|
||||||
|
|
||||||
|
# Assert calling new fn issues log warning.
|
||||||
|
self.assertEqual(3, _Object()._fn(1, 2))
|
||||||
|
self.assertEqual(1, mock_warning.call_count)
|
||||||
|
(args, _) = mock_warning.call_args
|
||||||
|
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
|
||||||
|
self._assert_subset(set([date, instructions]), set(args[1:]))
|
||||||
|
|
||||||
|
@tf.test.mock.patch.object(logging, "warning", autospec=True)
|
||||||
|
def test_instance_fn_with_one_line_doc(self, mock_warning):
|
||||||
|
date = "2016-07-04"
|
||||||
|
instructions = "This is how you update..."
|
||||||
|
|
||||||
|
class _Object(object):
|
||||||
|
|
||||||
|
def __init(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@deprecation.deprecated(date, instructions)
|
||||||
|
def _fn(self, arg0, arg1):
|
||||||
|
"""fn doc."""
|
||||||
|
return arg0 + arg1
|
||||||
|
|
||||||
|
# Assert function docs are properly updated.
|
||||||
|
self.assertEqual(
|
||||||
|
"fn doc. (deprecated)"
|
||||||
|
"\n"
|
||||||
|
"\nTHIS FUNCTION IS DEPRECATED. It will be removed after %s."
|
||||||
|
"\nInstructions for updating:\n%s" % (date, instructions),
|
||||||
|
getattr(_Object, "_fn").__doc__)
|
||||||
|
|
||||||
|
# Assert calling new fn issues log warning.
|
||||||
|
self.assertEqual(3, _Object()._fn(1, 2))
|
||||||
|
self.assertEqual(1, mock_warning.call_count)
|
||||||
|
(args, _) = mock_warning.call_args
|
||||||
|
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
|
||||||
|
self._assert_subset(set([date, instructions]), set(args[1:]))
|
||||||
|
|
||||||
|
@tf.test.mock.patch.object(logging, "warning", autospec=True)
|
||||||
|
def test_instance_fn_no_doc(self, mock_warning):
|
||||||
|
date = "2016-07-04"
|
||||||
|
instructions = "This is how you update..."
|
||||||
|
|
||||||
|
class _Object(object):
|
||||||
|
|
||||||
|
def __init(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@deprecation.deprecated(date, instructions)
|
||||||
|
def _fn(self, arg0, arg1):
|
||||||
|
return arg0 + arg1
|
||||||
|
|
||||||
|
# Assert function docs are properly updated.
|
||||||
|
self.assertEqual(
|
||||||
|
"DEPRECATED FUNCTION"
|
||||||
|
"\n"
|
||||||
|
"\nTHIS FUNCTION IS DEPRECATED. It will be removed after %s."
|
||||||
|
"\nInstructions for updating:"
|
||||||
|
"\n%s" % (date, instructions),
|
||||||
|
getattr(_Object, "_fn").__doc__)
|
||||||
|
|
||||||
|
# Assert calling new fn issues log warning.
|
||||||
|
self.assertEqual(3, _Object()._fn(1, 2))
|
||||||
|
self.assertEqual(1, mock_warning.call_count)
|
||||||
|
(args, _) = mock_warning.call_args
|
||||||
|
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
|
||||||
|
self._assert_subset(set([date, instructions]), set(args[1:]))
|
||||||
|
|
||||||
|
@tf.test.mock.patch.object(logging, "warning", autospec=True)
|
||||||
|
def test_prop_wrong_order(self, mock_warning):
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, "apply @deprecated after @property"):
|
||||||
|
# pylint: disable=unused-variable
|
||||||
|
|
||||||
|
class _Object(object):
|
||||||
|
|
||||||
|
def __init(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@deprecation.deprecated("2016-07-04", "Instructions.")
|
||||||
|
@property
|
||||||
|
def _prop(self):
|
||||||
|
return "prop_wrong_order"
|
||||||
|
|
||||||
|
@tf.test.mock.patch.object(logging, "warning", autospec=True)
|
||||||
|
def test_prop_with_doc(self, mock_warning):
|
||||||
|
date = "2016-07-04"
|
||||||
|
instructions = "This is how you update..."
|
||||||
|
|
||||||
|
class _Object(object):
|
||||||
|
|
||||||
|
def __init(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@deprecation.deprecated(date, instructions)
|
||||||
|
def _prop(self):
|
||||||
|
"""prop doc.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
String.
|
||||||
|
"""
|
||||||
|
return "prop_with_doc"
|
||||||
|
|
||||||
|
# Assert function docs are properly updated.
|
||||||
|
self.assertEqual(
|
||||||
|
"prop doc. (deprecated)"
|
||||||
|
"\n"
|
||||||
|
"\nTHIS FUNCTION IS DEPRECATED. It will be removed after %s."
|
||||||
|
"\nInstructions for updating:"
|
||||||
|
"\n%s"
|
||||||
|
"\n"
|
||||||
|
"\n Returns:"
|
||||||
|
"\n String."
|
||||||
|
"\n " % (date, instructions),
|
||||||
|
getattr(_Object, "_prop").__doc__)
|
||||||
|
|
||||||
|
# Assert calling new fn issues log warning.
|
||||||
|
self.assertEqual("prop_with_doc", _Object()._prop)
|
||||||
|
self.assertEqual(1, mock_warning.call_count)
|
||||||
|
(args, _) = mock_warning.call_args
|
||||||
|
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
|
||||||
|
self._assert_subset(set([date, instructions]), set(args[1:]))
|
||||||
|
|
||||||
|
@tf.test.mock.patch.object(logging, "warning", autospec=True)
|
||||||
|
def test_prop_no_doc(self, mock_warning):
|
||||||
|
date = "2016-07-04"
|
||||||
|
instructions = "This is how you update..."
|
||||||
|
|
||||||
|
class _Object(object):
|
||||||
|
|
||||||
|
def __init(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@deprecation.deprecated(date, instructions)
|
||||||
|
def _prop(self):
|
||||||
|
return "prop_no_doc"
|
||||||
|
|
||||||
|
# Assert function docs are properly updated.
|
||||||
|
self.assertEqual(
|
||||||
|
"DEPRECATED FUNCTION"
|
||||||
|
"\n"
|
||||||
|
"\nTHIS FUNCTION IS DEPRECATED. It will be removed after %s."
|
||||||
|
"\nInstructions for updating:"
|
||||||
|
"\n%s" % (date, instructions),
|
||||||
|
getattr(_Object, "_prop").__doc__)
|
||||||
|
|
||||||
|
# Assert calling new fn issues log warning.
|
||||||
|
self.assertEqual("prop_no_doc", _Object()._prop)
|
||||||
|
self.assertEqual(1, mock_warning.call_count)
|
||||||
|
(args, _) = mock_warning.call_args
|
||||||
|
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
|
||||||
|
self._assert_subset(set([date, instructions]), set(args[1:]))
|
||||||
|
|
||||||
|
|
||||||
|
class DeprecatedArgsTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def _assert_subset(self, expected_subset, actual_set):
|
||||||
|
self.assertTrue(
|
||||||
|
actual_set.issuperset(expected_subset),
|
||||||
|
msg="%s is not a superset of %s." % (actual_set, expected_subset))
|
||||||
|
|
||||||
|
def test_deprecated_illegal_args(self):
|
||||||
|
instructions = "This is how you update..."
|
||||||
|
with self.assertRaisesRegexp(ValueError, "date"):
|
||||||
|
deprecation.deprecated_arg_values(
|
||||||
|
None, instructions, deprecated=True)
|
||||||
|
with self.assertRaisesRegexp(ValueError, "date"):
|
||||||
|
deprecation.deprecated_arg_values(
|
||||||
|
"", instructions, deprecated=True)
|
||||||
|
with self.assertRaisesRegexp(ValueError, "YYYY-MM-DD"):
|
||||||
|
deprecation.deprecated_arg_values(
|
||||||
|
"07-04-2016", instructions, deprecated=True)
|
||||||
|
date = "2016-07-04"
|
||||||
|
with self.assertRaisesRegexp(ValueError, "instructions"):
|
||||||
|
deprecation.deprecated_arg_values(
|
||||||
|
date, None, deprecated=True)
|
||||||
|
with self.assertRaisesRegexp(ValueError, "instructions"):
|
||||||
|
deprecation.deprecated_arg_values(
|
||||||
|
date, "", deprecated=True)
|
||||||
|
with self.assertRaisesRegexp(ValueError, "argument", deprecated=True):
|
||||||
|
deprecation.deprecated_arg_values(
|
||||||
|
date, instructions)
|
||||||
|
|
||||||
|
@tf.test.mock.patch.object(logging, "warning", autospec=True)
|
||||||
|
def test_static_fn_with_doc(self, mock_warning):
|
||||||
|
date = "2016-07-04"
|
||||||
|
instructions = "This is how you update..."
|
||||||
|
|
||||||
|
@deprecation.deprecated_arg_values(date, instructions, deprecated=True)
|
||||||
|
def _fn(arg0, arg1, deprecated=True):
|
||||||
|
"""fn doc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arg0: Arg 0.
|
||||||
|
arg1: Arg 1.
|
||||||
|
deprecated: Deprecated!
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sum of args.
|
||||||
|
"""
|
||||||
|
return arg0 + arg1 if deprecated else arg1 + arg0
|
||||||
|
|
||||||
|
# Assert function docs are properly updated.
|
||||||
|
self.assertEqual("_fn", _fn.__name__)
|
||||||
|
self.assertEqual(
|
||||||
|
"fn doc. (deprecated arguments)"
|
||||||
|
"\n"
|
||||||
|
"\nSOME ARGUMENTS ARE DEPRECATED. They will be removed after %s."
|
||||||
|
"\nInstructions for updating:\n%s"
|
||||||
|
"\n"
|
||||||
|
"\n Args:"
|
||||||
|
"\n arg0: Arg 0."
|
||||||
|
"\n arg1: Arg 1."
|
||||||
|
"\n deprecated: Deprecated!"
|
||||||
|
"\n"
|
||||||
|
"\n Returns:"
|
||||||
|
"\n Sum of args."
|
||||||
|
"\n " % (date, instructions),
|
||||||
|
_fn.__doc__)
|
||||||
|
|
||||||
|
# Assert calling new fn with non-deprecated value logs nothing.
|
||||||
|
self.assertEqual(3, _fn(1, 2, deprecated=False))
|
||||||
|
self.assertEqual(0, mock_warning.call_count)
|
||||||
|
|
||||||
|
# Assert calling new fn with deprecated value issues log warning.
|
||||||
|
self.assertEqual(3, _fn(1, 2, deprecated=True))
|
||||||
|
self.assertEqual(1, mock_warning.call_count)
|
||||||
|
(args, _) = mock_warning.call_args
|
||||||
|
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
|
||||||
|
self._assert_subset(set([date, instructions]), set(args[1:]))
|
||||||
|
|
||||||
|
# Assert calling new fn with default deprecated value issues log warning.
|
||||||
|
self.assertEqual(3, _fn(1, 2))
|
||||||
|
self.assertEqual(2, mock_warning.call_count)
|
||||||
|
|
||||||
|
@tf.test.mock.patch.object(logging, "warning", autospec=True)
|
||||||
|
def test_static_fn_with_one_line_doc(self, mock_warning):
|
||||||
|
date = "2016-07-04"
|
||||||
|
instructions = "This is how you update..."
|
||||||
|
|
||||||
|
@deprecation.deprecated_arg_values(date, instructions, deprecated=True)
|
||||||
|
def _fn(arg0, arg1, deprecated=True):
|
||||||
|
"""fn doc."""
|
||||||
|
return arg0 + arg1 if deprecated else arg1 + arg0
|
||||||
|
|
||||||
|
# Assert function docs are properly updated.
|
||||||
|
self.assertEqual("_fn", _fn.__name__)
|
||||||
|
self.assertEqual(
|
||||||
|
"fn doc. (deprecated arguments)"
|
||||||
|
"\n"
|
||||||
|
"\nSOME ARGUMENTS ARE DEPRECATED. They will be removed after %s."
|
||||||
|
"\nInstructions for updating:\n%s" % (date, instructions),
|
||||||
|
_fn.__doc__)
|
||||||
|
|
||||||
|
# Assert calling new fn with non-deprecated value logs nothing.
|
||||||
|
self.assertEqual(3, _fn(1, 2, deprecated=False))
|
||||||
|
self.assertEqual(0, mock_warning.call_count)
|
||||||
|
|
||||||
|
# Assert calling new fn with deprecated value issues log warning.
|
||||||
|
self.assertEqual(3, _fn(1, 2, deprecated=True))
|
||||||
|
self.assertEqual(1, mock_warning.call_count)
|
||||||
|
(args, _) = mock_warning.call_args
|
||||||
|
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
|
||||||
|
self._assert_subset(set([date, instructions]), set(args[1:]))
|
||||||
|
|
||||||
|
# Assert calling new fn with default deprecated value issues log warning.
|
||||||
|
self.assertEqual(3, _fn(1, 2))
|
||||||
|
self.assertEqual(2, mock_warning.call_count)
|
||||||
|
|
||||||
|
@tf.test.mock.patch.object(logging, "warning", autospec=True)
|
||||||
|
def test_static_fn_no_doc(self, mock_warning):
|
||||||
|
date = "2016-07-04"
|
||||||
|
instructions = "This is how you update..."
|
||||||
|
|
||||||
|
@deprecation.deprecated_arg_values(date, instructions, deprecated=True)
|
||||||
|
def _fn(arg0, arg1, deprecated=True):
|
||||||
|
return arg0 + arg1 if deprecated else arg1 + arg0
|
||||||
|
|
||||||
|
# Assert function docs are properly updated.
|
||||||
|
self.assertEqual("_fn", _fn.__name__)
|
||||||
|
self.assertEqual(
|
||||||
|
"DEPRECATED FUNCTION ARGUMENTS"
|
||||||
|
"\n"
|
||||||
|
"\nSOME ARGUMENTS ARE DEPRECATED. They will be removed after %s."
|
||||||
|
"\nInstructions for updating:"
|
||||||
|
"\n%s" % (date, instructions),
|
||||||
|
_fn.__doc__)
|
||||||
|
|
||||||
|
# Assert calling new fn with non-deprecated value logs nothing.
|
||||||
|
self.assertEqual(3, _fn(1, 2, deprecated=False))
|
||||||
|
self.assertEqual(0, mock_warning.call_count)
|
||||||
|
|
||||||
|
# Assert calling new fn issues log warning.
|
||||||
|
self.assertEqual(3, _fn(1, 2, deprecated=True))
|
||||||
|
self.assertEqual(1, mock_warning.call_count)
|
||||||
|
(args, _) = mock_warning.call_args
|
||||||
|
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
|
||||||
|
self._assert_subset(set([date, instructions]), set(args[1:]))
|
||||||
|
|
||||||
|
# Assert calling new fn with default deprecated value issues log warning.
|
||||||
|
self.assertEqual(3, _fn(1, 2))
|
||||||
|
self.assertEqual(2, mock_warning.call_count)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.test.main()
|
@ -27,6 +27,7 @@ from tensorflow.python.ops import data_flow_ops
|
|||||||
from tensorflow.python.ops import logging_ops
|
from tensorflow.python.ops import logging_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.training import input as input_ops
|
from tensorflow.python.training import input as input_ops
|
||||||
from tensorflow.python.training import queue_runner
|
from tensorflow.python.training import queue_runner
|
||||||
|
|
||||||
@ -34,10 +35,8 @@ __all__ = ['stratified_sample',
|
|||||||
'stratified_sample_unknown_dist',]
|
'stratified_sample_unknown_dist',]
|
||||||
|
|
||||||
|
|
||||||
# TODO(joelshor): Use an exponential-moving-average to estimate the initial
|
def stratified_sample(tensors, labels, target_probs, batch_size,
|
||||||
# class distribution and remove the requirement that it be provided.
|
init_probs=None, enqueue_many=False, queue_capacity=16,
|
||||||
def stratified_sample(tensors, labels, init_probs, target_probs, batch_size,
|
|
||||||
enqueue_many=False, queue_capacity=16,
|
|
||||||
threads_per_queue=1, name=None):
|
threads_per_queue=1, name=None):
|
||||||
"""Stochastically creates batches based on per-class probabilities.
|
"""Stochastically creates batches based on per-class probabilities.
|
||||||
|
|
||||||
@ -52,11 +51,12 @@ def stratified_sample(tensors, labels, init_probs, target_probs, batch_size,
|
|||||||
batch, according to enqueue_many.
|
batch, according to enqueue_many.
|
||||||
labels: Tensor for label of data. Label is a single integer or a batch,
|
labels: Tensor for label of data. Label is a single integer or a batch,
|
||||||
depending on enqueue_many. It is not a one-hot vector.
|
depending on enqueue_many. It is not a one-hot vector.
|
||||||
init_probs: Class proportions in the data. An object whose type has a
|
|
||||||
registered Tensor conversion function.
|
|
||||||
target_probs: Target class proportions in batch. An object whose type has a
|
target_probs: Target class proportions in batch. An object whose type has a
|
||||||
registered Tensor conversion function.
|
registered Tensor conversion function.
|
||||||
batch_size: Size of batch to be returned.
|
batch_size: Size of batch to be returned.
|
||||||
|
init_probs: Class proportions in the data. An object whose type has a
|
||||||
|
registered Tensor conversion function, or `None` for estimating the
|
||||||
|
initial distribution.
|
||||||
enqueue_many: Bool. If true, interpret input tensors as having a batch
|
enqueue_many: Bool. If true, interpret input tensors as having a batch
|
||||||
dimension.
|
dimension.
|
||||||
queue_capacity: Capacity of the large queue that holds input examples.
|
queue_capacity: Capacity of the large queue that holds input examples.
|
||||||
@ -81,10 +81,9 @@ def stratified_sample(tensors, labels, init_probs, target_probs, batch_size,
|
|||||||
data, label = data_provider.Get(['data', 'label'])
|
data, label = data_provider.Get(['data', 'label'])
|
||||||
|
|
||||||
# Get stratified batch according to per-class probabilities.
|
# Get stratified batch according to per-class probabilities.
|
||||||
init_probs = [1.0/NUM_CLASSES for _ in range(NUM_CLASSES)]
|
|
||||||
target_probs = [...distribution you want...]
|
target_probs = [...distribution you want...]
|
||||||
[data_batch], labels = tf.contrib.framework.sampling_ops.stratified_sample(
|
[data_batch], labels = tf.contrib.framework.sampling_ops.stratified_sample(
|
||||||
[data], label, init_probs, target_probs)
|
[data], label, target_probs)
|
||||||
|
|
||||||
# Run batch through network.
|
# Run batch through network.
|
||||||
...
|
...
|
||||||
@ -92,22 +91,34 @@ def stratified_sample(tensors, labels, init_probs, target_probs, batch_size,
|
|||||||
with ops.op_scope(tensors + [labels], name, 'stratified_sample'):
|
with ops.op_scope(tensors + [labels], name, 'stratified_sample'):
|
||||||
tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensors)
|
tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensors)
|
||||||
labels = ops.convert_to_tensor(labels)
|
labels = ops.convert_to_tensor(labels)
|
||||||
init_probs = ops.convert_to_tensor(init_probs, dtype=dtypes.float32)
|
|
||||||
target_probs = ops.convert_to_tensor(target_probs, dtype=dtypes.float32)
|
target_probs = ops.convert_to_tensor(target_probs, dtype=dtypes.float32)
|
||||||
# Reduce the case of a single example to that of a batch of size 1.
|
# Reduce the case of a single example to that of a batch of size 1.
|
||||||
if not enqueue_many:
|
if not enqueue_many:
|
||||||
tensor_list = [array_ops.expand_dims(tensor, 0) for tensor in tensor_list]
|
tensor_list = [array_ops.expand_dims(tensor, 0) for tensor in tensor_list]
|
||||||
labels = array_ops.expand_dims(labels, 0)
|
labels = array_ops.expand_dims(labels, 0)
|
||||||
|
|
||||||
|
# If `init_probs` is `None`, set up online estimation of data distribution.
|
||||||
|
if init_probs is None:
|
||||||
|
# We use `target_probs` to get the number of classes, so its shape must be
|
||||||
|
# fully defined at graph construction time.
|
||||||
|
target_probs.get_shape().assert_is_fully_defined()
|
||||||
|
init_probs = _estimate_data_distribution(
|
||||||
|
labels, target_probs.get_shape().num_elements())
|
||||||
|
else:
|
||||||
|
init_probs = ops.convert_to_tensor(init_probs, dtype=dtypes.float32)
|
||||||
|
|
||||||
# Validate that input is consistent.
|
# Validate that input is consistent.
|
||||||
tensor_list, labels, [init_probs, target_probs] = _verify_input(
|
tensor_list, labels, [init_probs, target_probs] = _verify_input(
|
||||||
tensor_list, labels, [init_probs, target_probs])
|
tensor_list, labels, [init_probs, target_probs])
|
||||||
|
|
||||||
# Check that all zero initial probabilities also have zero target
|
# Check that all zero initial probabilities also have zero target
|
||||||
# probabilities.
|
# probabilities.
|
||||||
assert_op = logging_ops.Assert(math_ops.reduce_all(math_ops.logical_or(
|
assert_op = logging_ops.Assert(
|
||||||
math_ops.not_equal(init_probs, 0),
|
math_ops.reduce_all(math_ops.logical_or(
|
||||||
math_ops.equal(target_probs, 0))), [init_probs, target_probs])
|
math_ops.not_equal(init_probs, 0),
|
||||||
|
math_ops.equal(target_probs, 0))),
|
||||||
|
['All classes with zero initial probability must also have zero target '
|
||||||
|
'probability: ', init_probs, target_probs])
|
||||||
init_probs = control_flow_ops.with_dependencies([assert_op], init_probs)
|
init_probs = control_flow_ops.with_dependencies([assert_op], init_probs)
|
||||||
|
|
||||||
# Calculate acceptance sampling probabilities.
|
# Calculate acceptance sampling probabilities.
|
||||||
@ -212,6 +223,40 @@ def stratified_sample_unknown_dist(tensors, labels, probs, batch_size,
|
|||||||
per_class_queues, probs, batch_size)
|
per_class_queues, probs, batch_size)
|
||||||
|
|
||||||
|
|
||||||
|
def _estimate_data_distribution(labels, num_classes):
|
||||||
|
"""Estimate data distribution as labels are seen."""
|
||||||
|
# Variable to track running count of classes. Add 1 to avoid division-by-zero,
|
||||||
|
# and to guarantee that calculation of acceptance probabilities is (mostly)
|
||||||
|
# correct.
|
||||||
|
num_examples_per_class_seen = variables.Variable(
|
||||||
|
initial_value=[1] * num_classes, trainable=False, name='class_count',
|
||||||
|
dtype=dtypes.int64)
|
||||||
|
|
||||||
|
# Update the class-count based on what labels are seen in batch.
|
||||||
|
num_examples_per_class_seen = num_examples_per_class_seen.assign_add(
|
||||||
|
math_ops.reduce_sum(array_ops.one_hot(labels, num_classes,
|
||||||
|
dtype=dtypes.int64), 0))
|
||||||
|
|
||||||
|
# Normalize count into a probability.
|
||||||
|
# NOTE: Without the `+= 0` line below, the test
|
||||||
|
# `testMultiThreadedEstimateDataDistribution` fails. The reason is that
|
||||||
|
# before this line, `num_examples_per_class_seen` is a Tensor that shares a
|
||||||
|
# buffer with an underlying `ref` object. When the `ref` is changed by another
|
||||||
|
# thread, `num_examples_per_class_seen` changes as well. Since this can happen
|
||||||
|
# in the middle of the normalization computation, we get probabilities that
|
||||||
|
# are very far from summing to one. Adding `+= 0` copies the contents of the
|
||||||
|
# tensor to a new buffer, which will be consistent from the start to the end
|
||||||
|
# of the normalization computation.
|
||||||
|
num_examples_per_class_seen += 0
|
||||||
|
init_prob_estimate = math_ops.truediv(
|
||||||
|
num_examples_per_class_seen,
|
||||||
|
math_ops.reduce_sum(num_examples_per_class_seen))
|
||||||
|
|
||||||
|
# Must return float32 (not float64) to agree with downstream `_verify_input`
|
||||||
|
# checks.
|
||||||
|
return math_ops.cast(init_prob_estimate, dtypes.float32)
|
||||||
|
|
||||||
|
|
||||||
def _verify_input(tensor_list, labels, probs_list):
|
def _verify_input(tensor_list, labels, probs_list):
|
||||||
"""Verify that batched inputs are well-formed."""
|
"""Verify that batched inputs are well-formed."""
|
||||||
checked_probs_list = []
|
checked_probs_list = []
|
||||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
|
||||||
|
|
||||||
class SamplingOpsTest(tf.test.TestCase):
|
class SamplingOpsTest(tf.test.TestCase):
|
||||||
@ -33,15 +34,22 @@ class SamplingOpsTest(tf.test.TestCase):
|
|||||||
|
|
||||||
# Curry the rejection sampler so we can easily run the same tests on both
|
# Curry the rejection sampler so we can easily run the same tests on both
|
||||||
# stratified_sample and stratified_sample_unknown_dist.
|
# stratified_sample and stratified_sample_unknown_dist.
|
||||||
def curried_sampler(val, lbls, probs, batch, enqueue_many=True):
|
def curried_sampler(tensors, labels, probs, batch_size, enqueue_many=True):
|
||||||
return tf.contrib.framework.sampling_ops.stratified_sample(
|
return tf.contrib.framework.sampling_ops.stratified_sample(
|
||||||
val, lbls, initial_p, probs, batch, enqueue_many=enqueue_many)
|
tensors=tensors,
|
||||||
|
labels=labels,
|
||||||
|
target_probs=probs,
|
||||||
|
batch_size=batch_size,
|
||||||
|
init_probs=initial_p,
|
||||||
|
enqueue_many=enqueue_many)
|
||||||
|
|
||||||
samplers = [
|
samplers = [
|
||||||
tf.contrib.framework.sampling_ops.stratified_sample_unknown_dist,
|
tf.contrib.framework.sampling_ops.stratified_sample_unknown_dist,
|
||||||
curried_sampler,
|
curried_sampler,
|
||||||
]
|
]
|
||||||
|
|
||||||
for sampler in samplers:
|
for sampler in samplers:
|
||||||
|
logging.info('Now testing `%s`', sampler.__class__.__name__)
|
||||||
# Label must have only batch dimension if enqueue_many is True.
|
# Label must have only batch dimension if enqueue_many is True.
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
sampler(val, tf.zeros([]), probs, batch_size, enqueue_many=True)
|
sampler(val, tf.zeros([]), probs, batch_size, enqueue_many=True)
|
||||||
@ -70,20 +78,21 @@ class SamplingOpsTest(tf.test.TestCase):
|
|||||||
|
|
||||||
# Probabilities shape must be fully defined.
|
# Probabilities shape must be fully defined.
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
sampler(val, label, tf.placeholder(tf.float32, shape=[None]),
|
sampler(
|
||||||
batch_size)
|
val, label, tf.placeholder(
|
||||||
|
tf.float32, shape=[None]), batch_size)
|
||||||
|
|
||||||
# In the rejection sampling case, make sure that probability lengths are
|
# In the rejection sampling case, make sure that probability lengths are
|
||||||
# the same.
|
# the same.
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
tf.contrib.framework.sampling_ops.stratified_sample(
|
tf.contrib.framework.sampling_ops.stratified_sample(
|
||||||
val, label, [.2] * 5, [.1] * 10, batch_size)
|
val, label, [.1] * 10, batch_size, init_probs=[.2] * 5)
|
||||||
|
|
||||||
# In the rejection sampling case, make sure that zero initial probability
|
# In the rejection sampling case, make sure that zero initial probability
|
||||||
# classes also have zero target probability.
|
# classes also have zero target probability.
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
tf.contrib.framework.sampling_ops.stratified_sample(
|
tf.contrib.framework.sampling_ops.stratified_sample(
|
||||||
val, label, [0, .5, .5], [.2, .4, .4], batch_size)
|
val, label, [.2, .4, .4], batch_size, init_probs=[0, .5, .5])
|
||||||
|
|
||||||
# Probabilities must be 1D.
|
# Probabilities must be 1D.
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
@ -116,15 +125,17 @@ class SamplingOpsTest(tf.test.TestCase):
|
|||||||
# Run session that should fail.
|
# Run session that should fail.
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
with self.assertRaises(tf.errors.InvalidArgumentError):
|
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||||
sess.run([val_tf, lbl_tf], feed_dict={label_ph: illegal_label,
|
sess.run([val_tf, lbl_tf],
|
||||||
probs_ph: valid_probs})
|
feed_dict={label_ph: illegal_label,
|
||||||
|
probs_ph: valid_probs})
|
||||||
|
|
||||||
for illegal_prob in illegal_probs:
|
for illegal_prob in illegal_probs:
|
||||||
# Run session that should fail.
|
# Run session that should fail.
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
with self.assertRaises(tf.errors.InvalidArgumentError):
|
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||||
sess.run([prob_tf], feed_dict={label_ph: valid_labels,
|
sess.run([prob_tf],
|
||||||
probs_ph: illegal_prob})
|
feed_dict={label_ph: valid_labels,
|
||||||
|
probs_ph: illegal_prob})
|
||||||
|
|
||||||
def batchingBehaviorHelper(self, sampler):
|
def batchingBehaviorHelper(self, sampler):
|
||||||
batch_size = 20
|
batch_size = 20
|
||||||
@ -152,15 +163,14 @@ class SamplingOpsTest(tf.test.TestCase):
|
|||||||
lbl_input_batch = tf.ones([], dtype=tf.int32)
|
lbl_input_batch = tf.ones([], dtype=tf.int32)
|
||||||
probs = np.array([0, 1, 0, 0, 0])
|
probs = np.array([0, 1, 0, 0, 0])
|
||||||
batches = tf.contrib.framework.sampling_ops.stratified_sample(
|
batches = tf.contrib.framework.sampling_ops.stratified_sample(
|
||||||
val_input_batch, lbl_input_batch, probs, probs, batch_size)
|
val_input_batch, lbl_input_batch, probs, batch_size, init_probs=probs)
|
||||||
batches += tf.contrib.framework.sampling_ops.stratified_sample(
|
batches += tf.contrib.framework.sampling_ops.stratified_sample(
|
||||||
val_input_batch, lbl_input_batch, probs, probs, batch_size)
|
val_input_batch, lbl_input_batch, probs, batch_size, init_probs=probs)
|
||||||
batches += tf.contrib.framework.sampling_ops.stratified_sample_unknown_dist(
|
batches += tf.contrib.framework.sampling_ops.stratified_sample_unknown_dist(
|
||||||
val_input_batch, lbl_input_batch, probs, batch_size)
|
val_input_batch, lbl_input_batch, probs, batch_size)
|
||||||
batches += tf.contrib.framework.sampling_ops.stratified_sample_unknown_dist(
|
batches += tf.contrib.framework.sampling_ops.stratified_sample_unknown_dist(
|
||||||
val_input_batch, lbl_input_batch, probs, batch_size)
|
val_input_batch, lbl_input_batch, probs, batch_size)
|
||||||
summary_op = tf.merge_summary(tf.get_collection(
|
summary_op = tf.merge_summary(tf.get_collection(tf.GraphKeys.SUMMARIES))
|
||||||
tf.GraphKeys.SUMMARIES))
|
|
||||||
|
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
coord = tf.train.Coordinator()
|
coord = tf.train.Coordinator()
|
||||||
@ -177,9 +187,15 @@ class SamplingOpsTest(tf.test.TestCase):
|
|||||||
|
|
||||||
def testRejectionBatchingBehavior(self):
|
def testRejectionBatchingBehavior(self):
|
||||||
initial_p = [0, .3, 0, .7, 0]
|
initial_p = [0, .3, 0, .7, 0]
|
||||||
|
|
||||||
def curried_sampler(val, lbls, probs, batch, enqueue_many=True):
|
def curried_sampler(val, lbls, probs, batch, enqueue_many=True):
|
||||||
return tf.contrib.framework.sampling_ops.stratified_sample(
|
return tf.contrib.framework.sampling_ops.stratified_sample(
|
||||||
val, lbls, initial_p, probs, batch, enqueue_many=enqueue_many)
|
val,
|
||||||
|
lbls,
|
||||||
|
probs,
|
||||||
|
batch,
|
||||||
|
init_probs=initial_p,
|
||||||
|
enqueue_many=enqueue_many)
|
||||||
|
|
||||||
self.batchingBehaviorHelper(curried_sampler)
|
self.batchingBehaviorHelper(curried_sampler)
|
||||||
|
|
||||||
@ -190,8 +206,7 @@ class SamplingOpsTest(tf.test.TestCase):
|
|||||||
lbl2 = 3
|
lbl2 = 3
|
||||||
# This cond allows the necessary class queues to be populated.
|
# This cond allows the necessary class queues to be populated.
|
||||||
label = tf.cond(
|
label = tf.cond(
|
||||||
tf.greater(.5, tf.random_uniform([])),
|
tf.greater(.5, tf.random_uniform([])), lambda: tf.constant(lbl1),
|
||||||
lambda: tf.constant(lbl1),
|
|
||||||
lambda: tf.constant(lbl2))
|
lambda: tf.constant(lbl2))
|
||||||
val = [np.array([1, 4]) * label]
|
val = [np.array([1, 4]) * label]
|
||||||
probs = tf.placeholder(tf.float32, shape=[5])
|
probs = tf.placeholder(tf.float32, shape=[5])
|
||||||
@ -225,7 +240,7 @@ class SamplingOpsTest(tf.test.TestCase):
|
|||||||
def testBatchDimensionNotRequired(self):
|
def testBatchDimensionNotRequired(self):
|
||||||
classes = 5
|
classes = 5
|
||||||
# Probs must be a tensor, since we pass it directly to _verify_input.
|
# Probs must be a tensor, since we pass it directly to _verify_input.
|
||||||
probs = tf.constant([1.0/classes] * classes)
|
probs = tf.constant([1.0 / classes] * classes)
|
||||||
|
|
||||||
# Make sure that these vals/labels pairs don't throw any runtime exceptions.
|
# Make sure that these vals/labels pairs don't throw any runtime exceptions.
|
||||||
legal_input_pairs = [
|
legal_input_pairs = [
|
||||||
@ -243,16 +258,17 @@ class SamplingOpsTest(tf.test.TestCase):
|
|||||||
# Run graph to make sure there are no shape-related runtime errors.
|
# Run graph to make sure there are no shape-related runtime errors.
|
||||||
for vals, labels in legal_input_pairs:
|
for vals, labels in legal_input_pairs:
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
sess.run([val_tf, labels_tf], feed_dict={vals_ph: vals,
|
sess.run([val_tf, labels_tf],
|
||||||
labels_ph: labels})
|
feed_dict={vals_ph: vals,
|
||||||
|
labels_ph: labels})
|
||||||
|
|
||||||
def dataListHelper(self, sampler):
|
def dataListHelper(self, sampler):
|
||||||
batch_size = 20
|
batch_size = 20
|
||||||
val_input_batch = [tf.zeros([2, 3, 4]), tf.ones([2, 4]), tf.ones(2) * 3]
|
val_input_batch = [tf.zeros([2, 3, 4]), tf.ones([2, 4]), tf.ones(2) * 3]
|
||||||
lbl_input_batch = tf.ones([], dtype=tf.int32)
|
lbl_input_batch = tf.ones([], dtype=tf.int32)
|
||||||
probs = np.array([0, 1, 0, 0, 0])
|
probs = np.array([0, 1, 0, 0, 0])
|
||||||
val_list, lbls = sampler(
|
val_list, lbls = sampler(val_input_batch, lbl_input_batch, probs,
|
||||||
val_input_batch, lbl_input_batch, probs, batch_size)
|
batch_size)
|
||||||
|
|
||||||
# Check output shapes.
|
# Check output shapes.
|
||||||
self.assertTrue(isinstance(val_list, list))
|
self.assertTrue(isinstance(val_list, list))
|
||||||
@ -277,9 +293,16 @@ class SamplingOpsTest(tf.test.TestCase):
|
|||||||
|
|
||||||
def testRejectionDataListInput(self):
|
def testRejectionDataListInput(self):
|
||||||
initial_p = [0, 1, 0, 0, 0]
|
initial_p = [0, 1, 0, 0, 0]
|
||||||
|
|
||||||
def curried_sampler(val, lbls, probs, batch, enqueue_many=False):
|
def curried_sampler(val, lbls, probs, batch, enqueue_many=False):
|
||||||
return tf.contrib.framework.sampling_ops.stratified_sample(
|
return tf.contrib.framework.sampling_ops.stratified_sample(
|
||||||
val, lbls, initial_p, probs, batch, enqueue_many=enqueue_many)
|
val,
|
||||||
|
lbls,
|
||||||
|
probs,
|
||||||
|
batch,
|
||||||
|
init_probs=initial_p,
|
||||||
|
enqueue_many=enqueue_many)
|
||||||
|
|
||||||
self.dataListHelper(curried_sampler)
|
self.dataListHelper(curried_sampler)
|
||||||
|
|
||||||
def normalBehaviorHelper(self, sampler):
|
def normalBehaviorHelper(self, sampler):
|
||||||
@ -289,8 +312,7 @@ class SamplingOpsTest(tf.test.TestCase):
|
|||||||
lbl2 = 3
|
lbl2 = 3
|
||||||
# This cond allows the necessary class queues to be populated.
|
# This cond allows the necessary class queues to be populated.
|
||||||
label = tf.cond(
|
label = tf.cond(
|
||||||
tf.greater(.5, tf.random_uniform([])),
|
tf.greater(.5, tf.random_uniform([])), lambda: tf.constant(lbl1),
|
||||||
lambda: tf.constant(lbl1),
|
|
||||||
lambda: tf.constant(lbl2))
|
lambda: tf.constant(lbl2))
|
||||||
val = [np.array([1, 4]) * label]
|
val = [np.array([1, 4]) * label]
|
||||||
probs = np.array([.8, 0, 0, .2, 0])
|
probs = np.array([.8, 0, 0, .2, 0])
|
||||||
@ -302,6 +324,9 @@ class SamplingOpsTest(tf.test.TestCase):
|
|||||||
data_l = []
|
data_l = []
|
||||||
label_l = []
|
label_l = []
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
|
# Need to initialize variables that keep running total of classes seen.
|
||||||
|
tf.initialize_all_variables().run()
|
||||||
|
|
||||||
coord = tf.train.Coordinator()
|
coord = tf.train.Coordinator()
|
||||||
threads = tf.train.start_queue_runners(coord=coord)
|
threads = tf.train.start_queue_runners(coord=coord)
|
||||||
|
|
||||||
@ -329,7 +354,7 @@ class SamplingOpsTest(tf.test.TestCase):
|
|||||||
# is fixed, for a given implementation, this test will pass or fail 100% of
|
# is fixed, for a given implementation, this test will pass or fail 100% of
|
||||||
# the time. This use of assertNear is to cover cases where someone changes
|
# the time. This use of assertNear is to cover cases where someone changes
|
||||||
# an implementation detail, which would cause the random behavior to differ.
|
# an implementation detail, which would cause the random behavior to differ.
|
||||||
self.assertNear(actual_lbl, expected_label, 3*lbl_std_dev_of_mean)
|
self.assertNear(actual_lbl, expected_label, 3 * lbl_std_dev_of_mean)
|
||||||
|
|
||||||
def testNormalBehavior(self):
|
def testNormalBehavior(self):
|
||||||
self.normalBehaviorHelper(
|
self.normalBehaviorHelper(
|
||||||
@ -337,10 +362,26 @@ class SamplingOpsTest(tf.test.TestCase):
|
|||||||
|
|
||||||
def testRejectionNormalBehavior(self):
|
def testRejectionNormalBehavior(self):
|
||||||
initial_p = [.7, 0, 0, .3, 0]
|
initial_p = [.7, 0, 0, .3, 0]
|
||||||
|
|
||||||
def curried_sampler(val, lbls, probs, batch, enqueue_many=False):
|
def curried_sampler(val, lbls, probs, batch, enqueue_many=False):
|
||||||
return tf.contrib.framework.sampling_ops.stratified_sample(
|
return tf.contrib.framework.sampling_ops.stratified_sample(
|
||||||
val, lbls, initial_p, probs, batch, enqueue_many=enqueue_many)
|
val,
|
||||||
|
lbls,
|
||||||
|
probs,
|
||||||
|
batch,
|
||||||
|
init_probs=initial_p,
|
||||||
|
enqueue_many=enqueue_many)
|
||||||
|
|
||||||
self.normalBehaviorHelper(curried_sampler)
|
self.normalBehaviorHelper(curried_sampler)
|
||||||
|
|
||||||
|
def testRejectionNormalBehaviorWithOnlineInitPEstimate(self):
|
||||||
|
|
||||||
|
def curried_sampler(val, lbls, probs, batch, enqueue_many=False):
|
||||||
|
return tf.contrib.framework.sampling_ops.stratified_sample(
|
||||||
|
val, lbls, probs, batch, init_probs=None, enqueue_many=enqueue_many)
|
||||||
|
|
||||||
|
self.normalBehaviorHelper(curried_sampler)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
@ -0,0 +1,65 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
# pylint: disable=unused-import
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
class SamplingOpsThreadingTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def testMultiThreadedEstimateDataDistribution(self):
|
||||||
|
num_classes = 10
|
||||||
|
|
||||||
|
# Set up graph.
|
||||||
|
tf.set_random_seed(1234)
|
||||||
|
label = tf.cast(tf.round(tf.random_uniform([1]) * num_classes), tf.int32)
|
||||||
|
|
||||||
|
prob_estimate = tf.contrib.framework.sampling_ops._estimate_data_distribution( # pylint: disable=line-too-long
|
||||||
|
label, num_classes)
|
||||||
|
# Check that prob_estimate is well-behaved in a multithreaded context.
|
||||||
|
_, _, [prob_estimate] = tf.contrib.framework.sampling_ops._verify_input(
|
||||||
|
[], label, [prob_estimate])
|
||||||
|
|
||||||
|
# Use queues to run multiple threads over the graph, each of which
|
||||||
|
# fetches `prob_estimate`.
|
||||||
|
queue = tf.FIFOQueue(
|
||||||
|
capacity=25,
|
||||||
|
dtypes=[prob_estimate.dtype],
|
||||||
|
shapes=[prob_estimate.get_shape()])
|
||||||
|
enqueue_op = queue.enqueue([prob_estimate])
|
||||||
|
tf.train.add_queue_runner(tf.train.QueueRunner(queue, [enqueue_op] * 25))
|
||||||
|
out_tensor = queue.dequeue()
|
||||||
|
|
||||||
|
# Run the multi-threaded session.
|
||||||
|
with self.test_session() as sess:
|
||||||
|
# Need to initialize variables that keep running total of classes seen.
|
||||||
|
tf.initialize_all_variables().run()
|
||||||
|
|
||||||
|
coord = tf.train.Coordinator()
|
||||||
|
threads = tf.train.start_queue_runners(coord=coord)
|
||||||
|
|
||||||
|
for _ in range(25):
|
||||||
|
sess.run([out_tensor])
|
||||||
|
|
||||||
|
coord.request_stop()
|
||||||
|
coord.join(threads)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
@ -14,6 +14,7 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
REGISTER_OP("SparseFeatureCross")
|
REGISTER_OP("SparseFeatureCross")
|
||||||
@ -31,6 +32,12 @@ REGISTER_OP("SparseFeatureCross")
|
|||||||
.Attr("dense_types: list({int64, string}) >= 0")
|
.Attr("dense_types: list({int64, string}) >= 0")
|
||||||
.Attr("out_type: {int64, string}")
|
.Attr("out_type: {int64, string}")
|
||||||
.Attr("internal_type: {int64, string}")
|
.Attr("internal_type: {int64, string}")
|
||||||
|
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||||
|
c->set_output(0, c->Matrix(c->UnknownDim(), 2));
|
||||||
|
c->set_output(1, c->Vector(c->UnknownDim()));
|
||||||
|
c->set_output(2, c->Vector(2));
|
||||||
|
return Status::OK();
|
||||||
|
})
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Generates sparse cross form a list of sparse tensors.
|
Generates sparse cross form a list of sparse tensors.
|
||||||
|
|
||||||
|
@ -75,6 +75,7 @@ import abc
|
|||||||
import collections
|
import collections
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
from tensorflow.contrib.framework.python.framework import checkpoint_utils
|
||||||
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
|
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
|
||||||
from tensorflow.contrib.layers.python.layers import embedding_ops
|
from tensorflow.contrib.layers.python.layers import embedding_ops
|
||||||
from tensorflow.contrib.layers.python.ops import bucketization_op
|
from tensorflow.contrib.layers.python.ops import bucketization_op
|
||||||
@ -149,6 +150,7 @@ class _FeatureColumn(object):
|
|||||||
raise ValueError("Calling an abstract method.")
|
raise ValueError("Calling an abstract method.")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(b/30410315): Support warm starting in all feature columns.
|
||||||
class _SparseColumn(_FeatureColumn,
|
class _SparseColumn(_FeatureColumn,
|
||||||
collections.namedtuple("_SparseColumn",
|
collections.namedtuple("_SparseColumn",
|
||||||
["column_name", "is_integerized",
|
["column_name", "is_integerized",
|
||||||
@ -191,35 +193,36 @@ class _SparseColumn(_FeatureColumn,
|
|||||||
combiner="sum",
|
combiner="sum",
|
||||||
dtype=dtypes.string):
|
dtype=dtypes.string):
|
||||||
if is_integerized and bucket_size is None:
|
if is_integerized and bucket_size is None:
|
||||||
raise ValueError("bucket_size should be set if is_integerized=True. "
|
raise ValueError("bucket_size must be set if is_integerized is True. "
|
||||||
"column_name: {}".format(column_name))
|
"column_name: {}".format(column_name))
|
||||||
|
|
||||||
if is_integerized and not dtype.is_integer:
|
if is_integerized and not dtype.is_integer:
|
||||||
raise ValueError("dtype should be an integer if is_integerized is True. "
|
raise ValueError("dtype must be an integer if is_integerized is True. "
|
||||||
"Column {}.".format(column_name))
|
"dtype: {}, column_name: {}.".format(dtype, column_name))
|
||||||
|
|
||||||
if bucket_size is None and lookup_config is None:
|
if bucket_size is None and lookup_config is None:
|
||||||
raise ValueError("one of bucket_size or lookup_config should be "
|
raise ValueError("one of bucket_size or lookup_config must be set. "
|
||||||
"set. column_name: {}".format(column_name))
|
"column_name: {}".format(column_name))
|
||||||
|
|
||||||
if bucket_size is not None and lookup_config:
|
if bucket_size is not None and lookup_config:
|
||||||
raise ValueError("one and only one of bucket_size or lookup_config "
|
raise ValueError("one and only one of bucket_size or lookup_config "
|
||||||
"should be set. column_name: {}".format(column_name))
|
"must be set. column_name: {}".format(column_name))
|
||||||
|
|
||||||
if bucket_size is not None and bucket_size < 2:
|
if bucket_size is not None and bucket_size < 2:
|
||||||
raise ValueError("bucket_size should be at least 2. "
|
raise ValueError("bucket_size must be at least 2. "
|
||||||
"column_name: {}".format(column_name))
|
"bucket_size: {}, column_name: {}".format(bucket_size,
|
||||||
|
column_name))
|
||||||
|
|
||||||
if ((lookup_config) and
|
if ((lookup_config) and
|
||||||
(not isinstance(lookup_config, _SparseIdLookupConfig))):
|
(not isinstance(lookup_config, _SparseIdLookupConfig))):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"lookup_config should be an instance of _SparseIdLookupConfig. "
|
"lookup_config must be an instance of _SparseIdLookupConfig. "
|
||||||
"Given one is in type {} for column_name {}".format(
|
"Given one is in type {} for column_name {}".format(
|
||||||
type(lookup_config), column_name))
|
type(lookup_config), column_name))
|
||||||
|
|
||||||
if (lookup_config and lookup_config.vocabulary_file and
|
if (lookup_config and lookup_config.vocabulary_file and
|
||||||
lookup_config.vocab_size is None):
|
lookup_config.vocab_size is None):
|
||||||
raise ValueError("vocab_size should be defined. "
|
raise ValueError("vocab_size must be defined. "
|
||||||
"column_name: {}".format(column_name))
|
"column_name: {}".format(column_name))
|
||||||
|
|
||||||
return super(_SparseColumn, cls).__new__(cls, column_name, is_integerized,
|
return super(_SparseColumn, cls).__new__(cls, column_name, is_integerized,
|
||||||
@ -260,8 +263,8 @@ class _SparseColumn(_FeatureColumn,
|
|||||||
input_tensor,
|
input_tensor,
|
||||||
weight_collections=None,
|
weight_collections=None,
|
||||||
trainable=True):
|
trainable=True):
|
||||||
raise ValueError("Column {} is not supported in DNN. "
|
raise ValueError("SparseColumn is not supported in DNN. "
|
||||||
"Please use embedding_column.".format(self))
|
"Please use embedding_column. column: {}".format(self))
|
||||||
|
|
||||||
def to_weighted_sum(self,
|
def to_weighted_sum(self,
|
||||||
input_tensor,
|
input_tensor,
|
||||||
@ -277,7 +280,7 @@ class _SparseColumn(_FeatureColumn,
|
|||||||
initializer=init_ops.zeros_initializer,
|
initializer=init_ops.zeros_initializer,
|
||||||
combiner=self.combiner,
|
combiner=self.combiner,
|
||||||
trainable=trainable,
|
trainable=trainable,
|
||||||
name=self.name + "_weights")
|
name=self.name)
|
||||||
|
|
||||||
|
|
||||||
class _SparseColumnIntegerized(_SparseColumn):
|
class _SparseColumnIntegerized(_SparseColumn):
|
||||||
@ -289,8 +292,8 @@ class _SparseColumnIntegerized(_SparseColumn):
|
|||||||
combiner="sum",
|
combiner="sum",
|
||||||
dtype=dtypes.int64):
|
dtype=dtypes.int64):
|
||||||
if not dtype.is_integer:
|
if not dtype.is_integer:
|
||||||
raise ValueError("dtype should be an integer. Given {}".format(
|
raise ValueError("dtype must be an integer. "
|
||||||
column_name))
|
"dtype: {}, column_name: {}".format(dtype, column_name))
|
||||||
|
|
||||||
return super(_SparseColumnIntegerized, cls).__new__(cls,
|
return super(_SparseColumnIntegerized, cls).__new__(cls,
|
||||||
column_name,
|
column_name,
|
||||||
@ -505,8 +508,8 @@ class _WeightedSparseColumn(_FeatureColumn, collections.namedtuple(
|
|||||||
input_tensor,
|
input_tensor,
|
||||||
weight_collections=None,
|
weight_collections=None,
|
||||||
trainable=True):
|
trainable=True):
|
||||||
raise ValueError("Column {} is not supported in DNN. "
|
raise ValueError("WeightedSparseColumn is not supported in DNN. "
|
||||||
"Please use embedding_column.".format(self))
|
"Please use embedding_column. column: {}".format(self))
|
||||||
|
|
||||||
def to_weighted_sum(self,
|
def to_weighted_sum(self,
|
||||||
input_tensor,
|
input_tensor,
|
||||||
@ -522,7 +525,7 @@ class _WeightedSparseColumn(_FeatureColumn, collections.namedtuple(
|
|||||||
initializer=init_ops.zeros_initializer,
|
initializer=init_ops.zeros_initializer,
|
||||||
combiner=self.sparse_id_column.combiner,
|
combiner=self.sparse_id_column.combiner,
|
||||||
trainable=trainable,
|
trainable=trainable,
|
||||||
name=self.name + "_weights")
|
name=self.name)
|
||||||
|
|
||||||
|
|
||||||
def weighted_sparse_column(sparse_id_column,
|
def weighted_sparse_column(sparse_id_column,
|
||||||
@ -568,7 +571,8 @@ def weighted_sparse_column(sparse_id_column,
|
|||||||
|
|
||||||
class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
|
class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
|
||||||
"_EmbeddingColumn",
|
"_EmbeddingColumn",
|
||||||
["sparse_id_column", "dimension", "combiner", "initializer"])):
|
["sparse_id_column", "dimension", "combiner", "initializer",
|
||||||
|
"ckpt_to_load_from", "tensor_name_in_ckpt"])):
|
||||||
"""Represents an embedding column.
|
"""Represents an embedding column.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -586,15 +590,33 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
|
|||||||
variable initialization. If not specified, defaults to
|
variable initialization. If not specified, defaults to
|
||||||
`tf.truncated_normal_initializer` with mean 0.0 and standard deviation
|
`tf.truncated_normal_initializer` with mean 0.0 and standard deviation
|
||||||
1/sqrt(sparse_id_column.length).
|
1/sqrt(sparse_id_column.length).
|
||||||
|
ckpt_to_load_from: (Optional). String representing checkpoint name/pattern
|
||||||
|
to restore the column weights. Required if `tensor_name_in_ckpt` is not
|
||||||
|
None.
|
||||||
|
tensor_name_in_ckpt: (Optional). Name of the `Tensor` in the provided
|
||||||
|
checkpoint from which to restore the column weights. Required if
|
||||||
|
`ckpt_to_load_from` is not None.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if `initializer` is specified and is not callable. Also,
|
||||||
|
if only one of `ckpt_to_load_from` and `tensor_name_in_ckpt` is specified.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls,
|
def __new__(cls,
|
||||||
sparse_id_column,
|
sparse_id_column,
|
||||||
dimension,
|
dimension,
|
||||||
combiner="mean",
|
combiner="mean",
|
||||||
initializer=None):
|
initializer=None,
|
||||||
|
ckpt_to_load_from=None,
|
||||||
|
tensor_name_in_ckpt=None):
|
||||||
if initializer is not None and not callable(initializer):
|
if initializer is not None and not callable(initializer):
|
||||||
raise ValueError("initializer must be callable if specified.")
|
raise ValueError("initializer must be callable if specified. "
|
||||||
|
"Embedding of column_name: {}".format(
|
||||||
|
sparse_id_column.name))
|
||||||
|
|
||||||
|
if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
|
||||||
|
raise ValueError("Must specify both `ckpt_to_load_from` and "
|
||||||
|
"`tensor_name_in_ckpt` or none of them.")
|
||||||
if initializer is None:
|
if initializer is None:
|
||||||
stddev = 1 / math.sqrt(sparse_id_column.length)
|
stddev = 1 / math.sqrt(sparse_id_column.length)
|
||||||
# TODO(b/25671353): Better initial value?
|
# TODO(b/25671353): Better initial value?
|
||||||
@ -602,7 +624,8 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
|
|||||||
stddev=stddev)
|
stddev=stddev)
|
||||||
return super(_EmbeddingColumn, cls).__new__(cls, sparse_id_column,
|
return super(_EmbeddingColumn, cls).__new__(cls, sparse_id_column,
|
||||||
dimension, combiner,
|
dimension, combiner,
|
||||||
initializer)
|
initializer, ckpt_to_load_from,
|
||||||
|
tensor_name_in_ckpt)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
@ -645,7 +668,7 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
|
|||||||
input_tensor,
|
input_tensor,
|
||||||
weight_collections=None,
|
weight_collections=None,
|
||||||
trainable=True):
|
trainable=True):
|
||||||
output, _ = _create_embedding_lookup(
|
output, embedding_weights = _create_embedding_lookup(
|
||||||
input_tensor=self.sparse_id_column.id_tensor(input_tensor),
|
input_tensor=self.sparse_id_column.id_tensor(input_tensor),
|
||||||
weight_tensor=self.sparse_id_column.weight_tensor(input_tensor),
|
weight_tensor=self.sparse_id_column.weight_tensor(input_tensor),
|
||||||
vocab_size=self.length,
|
vocab_size=self.length,
|
||||||
@ -654,7 +677,14 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
|
|||||||
initializer=self.initializer,
|
initializer=self.initializer,
|
||||||
combiner=self.combiner,
|
combiner=self.combiner,
|
||||||
trainable=trainable,
|
trainable=trainable,
|
||||||
name=self.name + "_weights")
|
name=self.name)
|
||||||
|
if self.ckpt_to_load_from is not None:
|
||||||
|
weights_to_restore = embedding_weights
|
||||||
|
if len(embedding_weights) == 1:
|
||||||
|
weights_to_restore = embedding_weights[0]
|
||||||
|
checkpoint_utils.init_from_checkpoint(
|
||||||
|
self.ckpt_to_load_from,
|
||||||
|
{self.tensor_name_in_ckpt: weights_to_restore})
|
||||||
return output
|
return output
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
@ -663,19 +693,22 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
|
|||||||
num_outputs=1,
|
num_outputs=1,
|
||||||
weight_collections=None,
|
weight_collections=None,
|
||||||
trainable=True):
|
trainable=True):
|
||||||
raise ValueError("Column {} is not supported in linear models. "
|
raise ValueError("EmbeddingColumn is not supported in linear models. "
|
||||||
"Please use sparse_column.".format(self))
|
"Please use sparse_column. column: {}".format(self))
|
||||||
|
|
||||||
|
|
||||||
def embedding_column(sparse_id_column,
|
def embedding_column(sparse_id_column,
|
||||||
dimension,
|
dimension,
|
||||||
combiner="mean",
|
combiner="mean",
|
||||||
initializer=None):
|
initializer=None,
|
||||||
|
ckpt_to_load_from=None,
|
||||||
|
tensor_name_in_ckpt=None):
|
||||||
"""Creates an _EmbeddingColumn.
|
"""Creates an _EmbeddingColumn.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sparse_id_column: A _SparseColumn which is created by `sparse_column_with_*`
|
sparse_id_column: A _SparseColumn which is created by `sparse_column_with_*`
|
||||||
functions. Note that `combiner` defined in `sparse_id_column` is ignored.
|
or crossed_column functions. Note that `combiner` defined in
|
||||||
|
`sparse_id_column` is ignored.
|
||||||
dimension: An integer specifying dimension of the embedding.
|
dimension: An integer specifying dimension of the embedding.
|
||||||
combiner: A string specifying how to reduce if there are multiple entries
|
combiner: A string specifying how to reduce if there are multiple entries
|
||||||
in a single row. Currently "mean", "sqrtn" and "sum" are supported. Each
|
in a single row. Currently "mean", "sqrtn" and "sum" are supported. Each
|
||||||
@ -688,11 +721,18 @@ def embedding_column(sparse_id_column,
|
|||||||
variable initialization. If not specified, defaults to
|
variable initialization. If not specified, defaults to
|
||||||
`tf.truncated_normal_initializer` with mean 0.0 and standard deviation
|
`tf.truncated_normal_initializer` with mean 0.0 and standard deviation
|
||||||
1/sqrt(sparse_id_column.length).
|
1/sqrt(sparse_id_column.length).
|
||||||
|
ckpt_to_load_from: (Optional). String representing checkpoint name/pattern
|
||||||
|
to restore the column weights. Required if `tensor_name_in_ckpt` is not
|
||||||
|
None.
|
||||||
|
tensor_name_in_ckpt: (Optional). Name of the `Tensor` in the provided
|
||||||
|
checkpoint from which to restore the column weights. Required if
|
||||||
|
`ckpt_to_load_from` is not None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An _EmbeddingColumn.
|
An _EmbeddingColumn.
|
||||||
"""
|
"""
|
||||||
return _EmbeddingColumn(sparse_id_column, dimension, combiner, initializer)
|
return _EmbeddingColumn(sparse_id_column, dimension, combiner, initializer,
|
||||||
|
ckpt_to_load_from, tensor_name_in_ckpt)
|
||||||
|
|
||||||
|
|
||||||
class _HashedEmbeddingColumn(collections.namedtuple(
|
class _HashedEmbeddingColumn(collections.namedtuple(
|
||||||
@ -707,7 +747,8 @@ class _HashedEmbeddingColumn(collections.namedtuple(
|
|||||||
combiner="mean",
|
combiner="mean",
|
||||||
initializer=None):
|
initializer=None):
|
||||||
if initializer is not None and not callable(initializer):
|
if initializer is not None and not callable(initializer):
|
||||||
raise ValueError("initializer must be callable if specified.")
|
raise ValueError("initializer must be callable if specified. "
|
||||||
|
"column_name: {}".format(column_name))
|
||||||
if initializer is None:
|
if initializer is None:
|
||||||
stddev = 0.1
|
stddev = 0.1
|
||||||
# TODO(b/25671353): Better initial value?
|
# TODO(b/25671353): Better initial value?
|
||||||
@ -733,7 +774,7 @@ class _HashedEmbeddingColumn(collections.namedtuple(
|
|||||||
weight_collections=None,
|
weight_collections=None,
|
||||||
trainable=True):
|
trainable=True):
|
||||||
embeddings = _create_embeddings(
|
embeddings = _create_embeddings(
|
||||||
name=self.name + "_weights",
|
name=self.name,
|
||||||
shape=[self.size],
|
shape=[self.size],
|
||||||
initializer=self.initializer,
|
initializer=self.initializer,
|
||||||
dtype=dtypes.float32,
|
dtype=dtypes.float32,
|
||||||
@ -778,10 +819,14 @@ def hashed_embedding_column(column_name,
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
if (dimension < 1) or (size < 1):
|
if (dimension < 1) or (size < 1):
|
||||||
raise ValueError("Dimension and size must be greater than 0.")
|
raise ValueError("Dimension and size must be greater than 0. "
|
||||||
|
"dimension: {}, size: {}, column_name: {}".format(
|
||||||
|
dimension, size, column_name))
|
||||||
|
|
||||||
if combiner not in ("mean", "sqrtn", "sum"):
|
if combiner not in ("mean", "sqrtn", "sum"):
|
||||||
raise ValueError("Combiner must be one of 'mean', 'sqrtn' or 'sum'.")
|
raise ValueError("Combiner must be one of 'mean', 'sqrtn' or 'sum'. "
|
||||||
|
"combiner: {}, column_name: {}".format(
|
||||||
|
combiner, column_name))
|
||||||
|
|
||||||
return _HashedEmbeddingColumn(column_name, size, dimension, combiner,
|
return _HashedEmbeddingColumn(column_name, size, dimension, combiner,
|
||||||
initializer)
|
initializer)
|
||||||
@ -892,14 +937,18 @@ def real_valued_column(column_name,
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if not isinstance(dimension, int):
|
if not isinstance(dimension, int):
|
||||||
raise TypeError("dimension must be an integer")
|
raise TypeError("dimension must be an integer. "
|
||||||
|
"dimension: {}, column_name: {}".format(dimension,
|
||||||
|
column_name))
|
||||||
|
|
||||||
if dimension < 1:
|
if dimension < 1:
|
||||||
raise ValueError("dimension must be greater than 0")
|
raise ValueError("dimension must be greater than 0. "
|
||||||
|
"dimension: {}, column_name: {}".format(dimension,
|
||||||
|
column_name))
|
||||||
|
|
||||||
if not (dtype.is_integer or dtype.is_floating):
|
if not (dtype.is_integer or dtype.is_floating):
|
||||||
raise ValueError("dtype is not convertible to tf.float32. Given {}".format(
|
raise ValueError("dtype must be convertible to float. "
|
||||||
dtype))
|
"dtype: {}, column_name: {}".format(dtype, column_name))
|
||||||
|
|
||||||
if default_value is None:
|
if default_value is None:
|
||||||
return _RealValuedColumn(column_name, dimension, default_value, dtype)
|
return _RealValuedColumn(column_name, dimension, default_value, dtype)
|
||||||
@ -920,9 +969,10 @@ def real_valued_column(column_name,
|
|||||||
|
|
||||||
if isinstance(default_value, list):
|
if isinstance(default_value, list):
|
||||||
if len(default_value) != dimension:
|
if len(default_value) != dimension:
|
||||||
raise ValueError("The length of default_value is not equal to the "
|
raise ValueError(
|
||||||
"value of dimension. default_value is {}.".format(
|
"The length of default_value must be equal to dimension. "
|
||||||
default_value))
|
"default_value: {}, dimension: {}, column_name: {}".format(
|
||||||
|
default_value, dimension, column_name))
|
||||||
# Check if the values in the list are all integers or are convertible to
|
# Check if the values in the list are all integers or are convertible to
|
||||||
# floats.
|
# floats.
|
||||||
is_list_all_int = True
|
is_list_all_int = True
|
||||||
@ -943,8 +993,9 @@ def real_valued_column(column_name,
|
|||||||
default_value = [float(v) for v in default_value]
|
default_value = [float(v) for v in default_value]
|
||||||
return _RealValuedColumn(column_name, dimension, default_value, dtype)
|
return _RealValuedColumn(column_name, dimension, default_value, dtype)
|
||||||
|
|
||||||
raise TypeError("default_value is not compatible with dtype. "
|
raise TypeError("default_value must be compatible with dtype. "
|
||||||
"default_value is {}.".format(default_value))
|
"default_value: {}, dtype: {}, column_name: {}".format(
|
||||||
|
default_value, dtype, column_name))
|
||||||
|
|
||||||
|
|
||||||
class _BucketizedColumn(_FeatureColumn, collections.namedtuple(
|
class _BucketizedColumn(_FeatureColumn, collections.namedtuple(
|
||||||
@ -971,10 +1022,12 @@ class _BucketizedColumn(_FeatureColumn, collections.namedtuple(
|
|||||||
def __new__(cls, source_column, boundaries):
|
def __new__(cls, source_column, boundaries):
|
||||||
if not isinstance(source_column, _RealValuedColumn):
|
if not isinstance(source_column, _RealValuedColumn):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"source_column should be an instance of _RealValuedColumn.")
|
"source_column must be an instance of _RealValuedColumn. "
|
||||||
|
"source_column: {}".format(source_column))
|
||||||
|
|
||||||
if not isinstance(boundaries, list) or not boundaries:
|
if not isinstance(boundaries, list) or not boundaries:
|
||||||
raise ValueError("boundaries must be a list and it should not be empty.")
|
raise ValueError("boundaries must be a non-empty list. "
|
||||||
|
"boundaries: {}".format(boundaries))
|
||||||
|
|
||||||
# We allow bucket boundaries to be monotonically increasing
|
# We allow bucket boundaries to be monotonically increasing
|
||||||
# (ie a[i+1] >= a[i]). When two bucket boundaries are the same, we
|
# (ie a[i+1] >= a[i]). When two bucket boundaries are the same, we
|
||||||
@ -986,7 +1039,8 @@ class _BucketizedColumn(_FeatureColumn, collections.namedtuple(
|
|||||||
elif boundaries[i] < boundaries[i + 1]:
|
elif boundaries[i] < boundaries[i + 1]:
|
||||||
sanitized_boundaries.append(boundaries[i])
|
sanitized_boundaries.append(boundaries[i])
|
||||||
else:
|
else:
|
||||||
raise ValueError("boundaries must be a sorted list")
|
raise ValueError("boundaries must be a sorted list. "
|
||||||
|
"boundaries: {}".format(boundaries))
|
||||||
sanitized_boundaries.append(boundaries[len(boundaries) - 1])
|
sanitized_boundaries.append(boundaries[len(boundaries) - 1])
|
||||||
|
|
||||||
return super(_BucketizedColumn, cls).__new__(cls, source_column,
|
return super(_BucketizedColumn, cls).__new__(cls, source_column,
|
||||||
@ -1067,7 +1121,7 @@ class _BucketizedColumn(_FeatureColumn, collections.namedtuple(
|
|||||||
initializer=init_ops.zeros_initializer,
|
initializer=init_ops.zeros_initializer,
|
||||||
combiner="sum",
|
combiner="sum",
|
||||||
trainable=trainable,
|
trainable=trainable,
|
||||||
name=self.name + "_weights")
|
name=self.name)
|
||||||
|
|
||||||
|
|
||||||
def bucketized_column(source_column, boundaries):
|
def bucketized_column(source_column, boundaries):
|
||||||
@ -1087,7 +1141,8 @@ def bucketized_column(source_column, boundaries):
|
|||||||
|
|
||||||
|
|
||||||
class _CrossedColumn(_FeatureColumn, collections.namedtuple(
|
class _CrossedColumn(_FeatureColumn, collections.namedtuple(
|
||||||
"_CrossedColumn", ["columns", "hash_bucket_size", "combiner"])):
|
"_CrossedColumn", ["columns", "hash_bucket_size", "combiner",
|
||||||
|
"ckpt_to_load_from", "tensor_name_in_ckpt"])):
|
||||||
"""Represents a cross transformation also known as composition or union.
|
"""Represents a cross transformation also known as composition or union.
|
||||||
|
|
||||||
Instances of this class are immutable. It crosses given `columns`. Crossed
|
Instances of this class are immutable. It crosses given `columns`. Crossed
|
||||||
@ -1124,13 +1179,19 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple(
|
|||||||
* "mean": do l1 normalization
|
* "mean": do l1 normalization
|
||||||
* "sqrtn": do l2 normalization
|
* "sqrtn": do l2 normalization
|
||||||
For more information: `tf.embedding_lookup_sparse`.
|
For more information: `tf.embedding_lookup_sparse`.
|
||||||
|
ckpt_to_load_from: (Optional). String representing checkpoint name/pattern
|
||||||
|
to restore the column weights. Required if `tensor_name_in_ckpt` is not
|
||||||
|
None.
|
||||||
|
tensor_name_in_ckpt: (Optional). Name of the `Tensor` in the provided
|
||||||
|
checkpoint from which to restore the column weights. Required if
|
||||||
|
`ckpt_to_load_from` is not None.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: if all items in columns are not an instance of _SparseColumn,
|
TypeError: if all items in columns are not an instance of _SparseColumn,
|
||||||
_CrossedColumn, or _BucketizedColumn or
|
_CrossedColumn, or _BucketizedColumn or
|
||||||
hash_bucket_size is not an int.
|
hash_bucket_size is not an int.
|
||||||
ValueError: if hash_bucket_size is not > 1 or
|
ValueError: if hash_bucket_size is not > 1 or len(columns) is not > 1. Also,
|
||||||
len(columns) is not > 1.
|
if only one of `ckpt_to_load_from` and `tensor_name_in_ckpt` is specified.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -1138,26 +1199,36 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple(
|
|||||||
return isinstance(column,
|
return isinstance(column,
|
||||||
(_SparseColumn, _CrossedColumn, _BucketizedColumn))
|
(_SparseColumn, _CrossedColumn, _BucketizedColumn))
|
||||||
|
|
||||||
def __new__(cls, columns, hash_bucket_size, combiner="sum"):
|
def __new__(cls, columns, hash_bucket_size, combiner="sum",
|
||||||
|
ckpt_to_load_from=None, tensor_name_in_ckpt=None):
|
||||||
for column in columns:
|
for column in columns:
|
||||||
if not _CrossedColumn._is_crossable(column):
|
if not _CrossedColumn._is_crossable(column):
|
||||||
raise TypeError("columns should be a set of "
|
raise TypeError("columns must be a set of _SparseColumn, "
|
||||||
"_SparseColumn, _CrossedColumn, or _BucketizedColumn. "
|
"_CrossedColumn, or _BucketizedColumn instances. "
|
||||||
"Column is {}".format(column))
|
"column: {}".format(column))
|
||||||
|
|
||||||
if len(columns) < 2:
|
if len(columns) < 2:
|
||||||
raise ValueError("columns should contain at least 2 elements.")
|
raise ValueError("columns must contain at least 2 elements. "
|
||||||
|
"columns: {}".format(columns))
|
||||||
|
|
||||||
if not isinstance(hash_bucket_size, int):
|
if not isinstance(hash_bucket_size, int):
|
||||||
raise TypeError("hash_bucket_size should be an int.")
|
raise TypeError("hash_bucket_size must be an int. "
|
||||||
|
"hash_bucket_size: {}".format(hash_bucket_size))
|
||||||
|
|
||||||
if hash_bucket_size < 2:
|
if hash_bucket_size < 2:
|
||||||
raise ValueError("hash_bucket_size should be at least 2.")
|
raise ValueError("hash_bucket_size must be at least 2. "
|
||||||
|
"hash_bucket_size: {}".format(hash_bucket_size))
|
||||||
|
|
||||||
|
if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
|
||||||
|
raise ValueError("Must specify both `ckpt_to_load_from` and "
|
||||||
|
"`tensor_name_in_ckpt` or none of them.")
|
||||||
|
|
||||||
sorted_columns = sorted([column for column in columns],
|
sorted_columns = sorted([column for column in columns],
|
||||||
key=lambda column: column.name)
|
key=lambda column: column.name)
|
||||||
return super(_CrossedColumn, cls).__new__(cls, tuple(sorted_columns),
|
return super(_CrossedColumn, cls).__new__(cls, tuple(sorted_columns),
|
||||||
hash_bucket_size, combiner)
|
hash_bucket_size, combiner,
|
||||||
|
ckpt_to_load_from,
|
||||||
|
tensor_name_in_ckpt)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
@ -1181,6 +1252,15 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple(
|
|||||||
"""Returns a string which will be used as a key when we do sorting."""
|
"""Returns a string which will be used as a key when we do sorting."""
|
||||||
return "{}".format(self)
|
return "{}".format(self)
|
||||||
|
|
||||||
|
def id_tensor(self, input_tensor):
|
||||||
|
"""Returns the id tensor from the given transformed input_tensor."""
|
||||||
|
return input_tensor
|
||||||
|
|
||||||
|
# pylint: disable=unused-argument
|
||||||
|
def weight_tensor(self, input_tensor):
|
||||||
|
"""Returns the weight tensor from the given transformed input_tensor."""
|
||||||
|
return None
|
||||||
|
|
||||||
def insert_transformed_feature(self, columns_to_tensors):
|
def insert_transformed_feature(self, columns_to_tensors):
|
||||||
"""Handles cross transformation."""
|
"""Handles cross transformation."""
|
||||||
|
|
||||||
@ -1215,15 +1295,15 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple(
|
|||||||
input_tensor,
|
input_tensor,
|
||||||
weight_collections=None,
|
weight_collections=None,
|
||||||
trainable=True):
|
trainable=True):
|
||||||
raise ValueError("Column {} is not supported in DNN. "
|
raise ValueError("CrossedColumn is not supported in DNN. "
|
||||||
"Please use embedding_column.".format(self))
|
"Please use embedding_column. column: {}".format(self))
|
||||||
|
|
||||||
def to_weighted_sum(self,
|
def to_weighted_sum(self,
|
||||||
input_tensor,
|
input_tensor,
|
||||||
num_outputs=1,
|
num_outputs=1,
|
||||||
weight_collections=None,
|
weight_collections=None,
|
||||||
trainable=True):
|
trainable=True):
|
||||||
return _create_embedding_lookup(
|
output, embedding_weights = _create_embedding_lookup(
|
||||||
input_tensor=input_tensor,
|
input_tensor=input_tensor,
|
||||||
weight_tensor=None,
|
weight_tensor=None,
|
||||||
vocab_size=self.length,
|
vocab_size=self.length,
|
||||||
@ -1232,10 +1312,20 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple(
|
|||||||
initializer=init_ops.zeros_initializer,
|
initializer=init_ops.zeros_initializer,
|
||||||
combiner=self.combiner,
|
combiner=self.combiner,
|
||||||
trainable=trainable,
|
trainable=trainable,
|
||||||
name=self.name + "_weights")
|
name=self.name)
|
||||||
|
if self.ckpt_to_load_from is not None:
|
||||||
|
weights_to_restore = embedding_weights
|
||||||
|
if len(embedding_weights) == 1:
|
||||||
|
weights_to_restore = embedding_weights[0]
|
||||||
|
checkpoint_utils.init_from_checkpoint(
|
||||||
|
self.ckpt_to_load_from,
|
||||||
|
{self.tensor_name_in_ckpt: weights_to_restore})
|
||||||
|
return output, embedding_weights
|
||||||
|
|
||||||
|
|
||||||
def crossed_column(columns, hash_bucket_size, combiner="sum"):
|
def crossed_column(columns, hash_bucket_size, combiner="sum",
|
||||||
|
ckpt_to_load_from=None,
|
||||||
|
tensor_name_in_ckpt=None):
|
||||||
"""Creates a _CrossedColumn.
|
"""Creates a _CrossedColumn.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1243,6 +1333,12 @@ def crossed_column(columns, hash_bucket_size, combiner="sum"):
|
|||||||
_SparseColumn, _CrossedColumn, or _BucketizedColumn.
|
_SparseColumn, _CrossedColumn, or _BucketizedColumn.
|
||||||
hash_bucket_size: An int that is > 1. The number of buckets.
|
hash_bucket_size: An int that is > 1. The number of buckets.
|
||||||
combiner: A combiner string, supports sum, mean, sqrtn.
|
combiner: A combiner string, supports sum, mean, sqrtn.
|
||||||
|
ckpt_to_load_from: (Optional). String representing checkpoint name/pattern
|
||||||
|
to restore the column weights. Required if `tensor_name_in_ckpt` is not
|
||||||
|
None.
|
||||||
|
tensor_name_in_ckpt: (Optional). Name of the `Tensor` in the provided
|
||||||
|
checkpoint from which to restore the column weights. Required if
|
||||||
|
`ckpt_to_load_from` is not None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A _CrossedColumn.
|
A _CrossedColumn.
|
||||||
@ -1254,12 +1350,14 @@ def crossed_column(columns, hash_bucket_size, combiner="sum"):
|
|||||||
ValueError: if hash_bucket_size is not > 1 or
|
ValueError: if hash_bucket_size is not > 1 or
|
||||||
len(columns) is not > 1.
|
len(columns) is not > 1.
|
||||||
"""
|
"""
|
||||||
return _CrossedColumn(columns, hash_bucket_size, combiner=combiner)
|
return _CrossedColumn(columns, hash_bucket_size, combiner=combiner,
|
||||||
|
ckpt_to_load_from=ckpt_to_load_from,
|
||||||
|
tensor_name_in_ckpt=tensor_name_in_ckpt)
|
||||||
|
|
||||||
|
|
||||||
class DataFrameColumn(_FeatureColumn,
|
class DataFrameColumn(_FeatureColumn,
|
||||||
collections.namedtuple("DataFrameColumn",
|
collections.namedtuple("DataFrameColumn",
|
||||||
["name", "series"])):
|
["column_name", "series"])):
|
||||||
"""Represents a feature column produced from a `DataFrame`.
|
"""Represents a feature column produced from a `DataFrame`.
|
||||||
|
|
||||||
Instances of this class are immutable. A `DataFrame` column may be dense or
|
Instances of this class are immutable. A `DataFrame` column may be dense or
|
||||||
@ -1267,13 +1365,17 @@ class DataFrameColumn(_FeatureColumn,
|
|||||||
batch_size.
|
batch_size.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: a name for this column
|
column_name: a name for this column
|
||||||
series: a `Series` to be wrapped, which has already had its base features
|
series: a `Series` to be wrapped, which has already had its base features
|
||||||
substituted with `PredefinedSeries`.
|
substituted with `PredefinedSeries`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls, name, series):
|
def __new__(cls, column_name, series):
|
||||||
return super(DataFrameColumn, cls).__new__(cls, name, series)
|
return super(DataFrameColumn, cls).__new__(cls, column_name, series)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return self.column_name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def config(self):
|
def config(self):
|
||||||
@ -1301,7 +1403,17 @@ class DataFrameColumn(_FeatureColumn,
|
|||||||
input_tensor,
|
input_tensor,
|
||||||
weight_collections=None,
|
weight_collections=None,
|
||||||
trainable=True):
|
trainable=True):
|
||||||
return input_tensor
|
# DataFrame typically provides Tensors of shape [batch_size],
|
||||||
|
# but Estimator requires shape [batch_size, 1]
|
||||||
|
dims = input_tensor.get_shape().ndims
|
||||||
|
if dims == 0:
|
||||||
|
raise ValueError(
|
||||||
|
"Can't build input layer from tensor of shape (): {}".format(
|
||||||
|
self.column_name))
|
||||||
|
elif dims == 1:
|
||||||
|
return array_ops.expand_dims(input_tensor, 1)
|
||||||
|
else:
|
||||||
|
return input_tensor
|
||||||
|
|
||||||
# TODO(soergel): This mirrors RealValuedColumn for now, but should become
|
# TODO(soergel): This mirrors RealValuedColumn for now, but should become
|
||||||
# better abstracted with less code duplication when we add other kinds.
|
# better abstracted with less code duplication when we add other kinds.
|
||||||
@ -1469,7 +1581,7 @@ def _create_embeddings(name, shape, dtype, initializer, trainable,
|
|||||||
with just one variable.
|
with just one variable.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: A string specifying the name of the embedding variable.
|
name: A string. The name of the embedding variable will be name + _weights.
|
||||||
shape: shape of the embeddding. Note this is not the shape of partitioned
|
shape: shape of the embeddding. Note this is not the shape of partitioned
|
||||||
variables.
|
variables.
|
||||||
dtype: type of the embedding. Also the shape of each partitioned variable.
|
dtype: type of the embedding. Also the shape of each partitioned variable.
|
||||||
@ -1531,7 +1643,7 @@ def _create_embedding_lookup(input_tensor, weight_tensor, vocab_size, dimension,
|
|||||||
A Tensor with shape [batch_size, dimension] and embedding Variable.
|
A Tensor with shape [batch_size, dimension] and embedding Variable.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
embeddings = _create_embeddings(name=name,
|
embeddings = _create_embeddings(name=name + "_weights",
|
||||||
shape=[vocab_size, dimension],
|
shape=[vocab_size, dimension],
|
||||||
dtype=dtypes.float32,
|
dtype=dtypes.float32,
|
||||||
initializer=initializer,
|
initializer=initializer,
|
||||||
@ -1543,4 +1655,4 @@ def _create_embedding_lookup(input_tensor, weight_tensor, vocab_size, dimension,
|
|||||||
sparse_weights=weight_tensor,
|
sparse_weights=weight_tensor,
|
||||||
default_id=0,
|
default_id=0,
|
||||||
combiner=combiner,
|
combiner=combiner,
|
||||||
name=name), embeddings
|
name=name + "_weights"), embeddings
|
||||||
|
@ -393,6 +393,24 @@ class InputLayerTest(tf.test.TestCase):
|
|||||||
tf.initialize_all_tables().run()
|
tf.initialize_all_tables().run()
|
||||||
self.assertAllEqual(output.eval().shape, [2, 10])
|
self.assertAllEqual(output.eval().shape, [2, 10])
|
||||||
|
|
||||||
|
def testEmbeddingColumnWitCrossedColumn(self):
|
||||||
|
a = tf.contrib.layers.sparse_column_with_hash_bucket("aaa",
|
||||||
|
hash_bucket_size=100)
|
||||||
|
b = tf.contrib.layers.sparse_column_with_hash_bucket("bbb",
|
||||||
|
hash_bucket_size=100)
|
||||||
|
crossed = tf.contrib.layers.crossed_column(
|
||||||
|
set([a, b]), hash_bucket_size=10000)
|
||||||
|
wire_tensor = tf.SparseTensor(values=["omar", "stringer", "marlo"],
|
||||||
|
indices=[[0, 0], [1, 0], [1, 1]],
|
||||||
|
shape=[2, 2])
|
||||||
|
features = {"aaa": wire_tensor, "bbb": wire_tensor}
|
||||||
|
embeded_sparse = tf.contrib.layers.embedding_column(crossed, 10)
|
||||||
|
output = tf.contrib.layers.input_from_feature_columns(features,
|
||||||
|
[embeded_sparse])
|
||||||
|
with self.test_session():
|
||||||
|
tf.initialize_all_variables().run()
|
||||||
|
self.assertAllEqual(output.eval().shape, [2, 10])
|
||||||
|
|
||||||
def testSparseColumn(self):
|
def testSparseColumn(self):
|
||||||
hashed_sparse = tf.contrib.layers.sparse_column_with_hash_bucket("wire", 10)
|
hashed_sparse = tf.contrib.layers.sparse_column_with_hash_bucket("wire", 10)
|
||||||
wire_tensor = tf.SparseTensor(values=["omar", "stringer", "marlo"],
|
wire_tensor = tf.SparseTensor(values=["omar", "stringer", "marlo"],
|
||||||
|
@ -19,6 +19,8 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
@ -58,14 +60,17 @@ class FeatureColumnTest(tf.test.TestCase):
|
|||||||
self.assertEqual(b.dimension, 10)
|
self.assertEqual(b.dimension, 10)
|
||||||
self.assertTrue(b.default_value is None)
|
self.assertTrue(b.default_value is None)
|
||||||
|
|
||||||
# dimension is an integer
|
with self.assertRaisesRegexp(TypeError, "dimension must be an integer"):
|
||||||
with self.assertRaises(TypeError):
|
|
||||||
tf.contrib.layers.real_valued_column("d3", dimension=1.0)
|
tf.contrib.layers.real_valued_column("d3", dimension=1.0)
|
||||||
|
|
||||||
# dimension is a positive integer
|
with self.assertRaisesRegexp(ValueError,
|
||||||
with self.assertRaises(ValueError):
|
"dimension must be greater than 0"):
|
||||||
tf.contrib.layers.real_valued_column("d3", dimension=0)
|
tf.contrib.layers.real_valued_column("d3", dimension=0)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError,
|
||||||
|
"dtype must be convertible to float"):
|
||||||
|
tf.contrib.layers.real_valued_column("d3", dtype=tf.string)
|
||||||
|
|
||||||
# default_value is an integer.
|
# default_value is an integer.
|
||||||
c1 = tf.contrib.layers.real_valued_column("c1", default_value=2)
|
c1 = tf.contrib.layers.real_valued_column("c1", default_value=2)
|
||||||
self.assertListEqual(list(c1.default_value), [2.])
|
self.assertListEqual(list(c1.default_value), [2.])
|
||||||
@ -90,15 +95,18 @@ class FeatureColumnTest(tf.test.TestCase):
|
|||||||
dimension=4,
|
dimension=4,
|
||||||
default_value=2.)
|
default_value=2.)
|
||||||
self.assertListEqual(list(d2.default_value), [2., 2., 2., 2.])
|
self.assertListEqual(list(d2.default_value), [2., 2., 2., 2.])
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaisesRegexp(TypeError,
|
||||||
|
"default_value must be compatible with dtype"):
|
||||||
tf.contrib.layers.real_valued_column("d3",
|
tf.contrib.layers.real_valued_column("d3",
|
||||||
default_value=2.,
|
default_value=2.,
|
||||||
dtype=tf.int32)
|
dtype=tf.int32)
|
||||||
|
|
||||||
# default_value is neither interger nor float.
|
# default_value is neither integer nor float.
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaisesRegexp(
|
||||||
|
TypeError, "default_value must be compatible with dtype"):
|
||||||
tf.contrib.layers.real_valued_column("e1", default_value="string")
|
tf.contrib.layers.real_valued_column("e1", default_value="string")
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaisesRegexp(
|
||||||
|
TypeError, "default_value must be compatible with dtype"):
|
||||||
tf.contrib.layers.real_valued_column("e1",
|
tf.contrib.layers.real_valued_column("e1",
|
||||||
dimension=3,
|
dimension=3,
|
||||||
default_value=[1, 3., "string"])
|
default_value=[1, 3., "string"])
|
||||||
@ -123,11 +131,13 @@ class FeatureColumnTest(tf.test.TestCase):
|
|||||||
dimension=3,
|
dimension=3,
|
||||||
default_value=[2., 2, 2])
|
default_value=[2., 2, 2])
|
||||||
self.assertListEqual(list(g2.default_value), [2., 2., 2.])
|
self.assertListEqual(list(g2.default_value), [2., 2., 2.])
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaisesRegexp(
|
||||||
|
TypeError, "default_value must be compatible with dtype"):
|
||||||
tf.contrib.layers.real_valued_column("g3",
|
tf.contrib.layers.real_valued_column("g3",
|
||||||
default_value=[2.],
|
default_value=[2.],
|
||||||
dtype=tf.int32)
|
dtype=tf.int32)
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, "The length of default_value must be equal to dimension"):
|
||||||
tf.contrib.layers.real_valued_column("g4",
|
tf.contrib.layers.real_valued_column("g4",
|
||||||
dimension=3,
|
dimension=3,
|
||||||
default_value=[2.])
|
default_value=[2.])
|
||||||
@ -138,11 +148,19 @@ class FeatureColumnTest(tf.test.TestCase):
|
|||||||
self.assertEqual(a.name, "aaa_BUCKETIZED")
|
self.assertEqual(a.name, "aaa_BUCKETIZED")
|
||||||
|
|
||||||
def testBucketizedColumnRequiresRealValuedColumn(self):
|
def testBucketizedColumnRequiresRealValuedColumn(self):
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaisesRegexp(
|
||||||
|
TypeError, "source_column must be an instance of _RealValuedColumn"):
|
||||||
tf.contrib.layers.bucketized_column("bbb", [0])
|
tf.contrib.layers.bucketized_column("bbb", [0])
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
TypeError, "source_column must be an instance of _RealValuedColumn"):
|
||||||
|
tf.contrib.layers.bucketized_column(
|
||||||
|
tf.contrib.layers.sparse_column_with_integerized_feature(
|
||||||
|
column_name="bbb", bucket_size=10),
|
||||||
|
[0])
|
||||||
|
|
||||||
def testBucketizedColumnRequiresSortedBuckets(self):
|
def testBucketizedColumnRequiresSortedBuckets(self):
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, "boundaries must be a sorted list"):
|
||||||
tf.contrib.layers.bucketized_column(
|
tf.contrib.layers.bucketized_column(
|
||||||
tf.contrib.layers.real_valued_column("ccc"), [5, 0, 4])
|
tf.contrib.layers.real_valued_column("ccc"), [5, 0, 4])
|
||||||
|
|
||||||
@ -171,7 +189,10 @@ class FeatureColumnTest(tf.test.TestCase):
|
|||||||
def testCrossedColumnNotSupportRealValuedColumn(self):
|
def testCrossedColumnNotSupportRealValuedColumn(self):
|
||||||
b = tf.contrib.layers.sparse_column_with_hash_bucket("bbb",
|
b = tf.contrib.layers.sparse_column_with_hash_bucket("bbb",
|
||||||
hash_bucket_size=100)
|
hash_bucket_size=100)
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaisesRegexp(
|
||||||
|
TypeError,
|
||||||
|
"columns must be a set of _SparseColumn, _CrossedColumn, "
|
||||||
|
"or _BucketizedColumn instances"):
|
||||||
tf.contrib.layers.crossed_column(
|
tf.contrib.layers.crossed_column(
|
||||||
set([b, tf.contrib.layers.real_valued_column("real")]),
|
set([b, tf.contrib.layers.real_valued_column("real")]),
|
||||||
hash_bucket_size=10000)
|
hash_bucket_size=10000)
|
||||||
@ -192,7 +213,8 @@ class FeatureColumnTest(tf.test.TestCase):
|
|||||||
"weights": tf.VarLenFeature(tf.int32)},
|
"weights": tf.VarLenFeature(tf.int32)},
|
||||||
weighted_ids.config)
|
weighted_ids.config)
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaisesRegexp(ValueError,
|
||||||
|
"dtype is not convertible to float"):
|
||||||
weighted_ids = tf.contrib.layers.weighted_sparse_column(ids, "weights",
|
weighted_ids = tf.contrib.layers.weighted_sparse_column(ids, "weights",
|
||||||
dtype=tf.string)
|
dtype=tf.string)
|
||||||
|
|
||||||
@ -209,7 +231,8 @@ class FeatureColumnTest(tf.test.TestCase):
|
|||||||
[1], dtype=tf.int32)},
|
[1], dtype=tf.int32)},
|
||||||
rvc.config)
|
rvc.config)
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaisesRegexp(ValueError,
|
||||||
|
"dtype must be convertible to float"):
|
||||||
tf.contrib.layers.real_valued_column("rvc", dtype=tf.string)
|
tf.contrib.layers.real_valued_column("rvc", dtype=tf.string)
|
||||||
|
|
||||||
def testSparseColumnDtypes(self):
|
def testSparseColumnDtypes(self):
|
||||||
@ -220,7 +243,8 @@ class FeatureColumnTest(tf.test.TestCase):
|
|||||||
"sc", 10, dtype=tf.int32)
|
"sc", 10, dtype=tf.int32)
|
||||||
self.assertDictEqual({"sc": tf.VarLenFeature(dtype=tf.int32)}, sc.config)
|
self.assertDictEqual({"sc": tf.VarLenFeature(dtype=tf.int32)}, sc.config)
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaisesRegexp(ValueError,
|
||||||
|
"dtype must be an integer"):
|
||||||
tf.contrib.layers.sparse_column_with_integerized_feature("sc",
|
tf.contrib.layers.sparse_column_with_integerized_feature("sc",
|
||||||
10,
|
10,
|
||||||
dtype=tf.float32)
|
dtype=tf.float32)
|
||||||
@ -323,6 +347,107 @@ class FeatureColumnTest(tf.test.TestCase):
|
|||||||
self.assertEqual(tf.float32, placeholder.dtype)
|
self.assertEqual(tf.float32, placeholder.dtype)
|
||||||
self.assertEqual([None, 1], placeholder.get_shape().as_list())
|
self.assertEqual([None, 1], placeholder.get_shape().as_list())
|
||||||
|
|
||||||
|
def testInitEmbeddingColumnWeightsFromCkpt(self):
|
||||||
|
sparse_col = tf.contrib.layers.sparse_column_with_hash_bucket(
|
||||||
|
column_name="object_in_image",
|
||||||
|
hash_bucket_size=4)
|
||||||
|
# Create _EmbeddingColumn which randomly initializes embedding of size
|
||||||
|
# [4, 16].
|
||||||
|
embedding_col = tf.contrib.layers.embedding_column(sparse_col, dimension=16)
|
||||||
|
|
||||||
|
# Creating a SparseTensor which has all the ids possible for the given
|
||||||
|
# vocab.
|
||||||
|
input_tensor = tf.SparseTensor(indices=[[0, 0], [1, 1], [2, 2], [3, 3]],
|
||||||
|
values=[0, 1, 2, 3],
|
||||||
|
shape=[4, 4])
|
||||||
|
|
||||||
|
# Invoking 'embedding_column.to_dnn_input_layer' will create the embedding
|
||||||
|
# variable. Creating under scope 'run_1' so as to prevent name conflicts
|
||||||
|
# when creating embedding variable for 'embedding_column_pretrained'.
|
||||||
|
with tf.variable_scope("run_1"):
|
||||||
|
# This will return a [4, 16] tensor which is same as embedding variable.
|
||||||
|
embeddings = embedding_col.to_dnn_input_layer(input_tensor)
|
||||||
|
|
||||||
|
save = tf.train.Saver()
|
||||||
|
checkpoint_path = os.path.join(self.get_temp_dir(), "model.ckpt")
|
||||||
|
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(tf.initialize_all_variables())
|
||||||
|
saved_embedding = embeddings.eval()
|
||||||
|
save.save(sess, checkpoint_path)
|
||||||
|
|
||||||
|
embedding_col_initialized = tf.contrib.layers.embedding_column(
|
||||||
|
sparse_id_column=sparse_col,
|
||||||
|
dimension=16,
|
||||||
|
ckpt_to_load_from=checkpoint_path,
|
||||||
|
tensor_name_in_ckpt="run_1/object_in_image_embedding_weights")
|
||||||
|
|
||||||
|
with tf.variable_scope("run_2"):
|
||||||
|
# This will initialize the embedding from provided checkpoint and return a
|
||||||
|
# [4, 16] tensor which is same as embedding variable. Since we didn't
|
||||||
|
# modify embeddings, this should be same as 'saved_embedding'.
|
||||||
|
pretrained_embeddings = embedding_col_initialized.to_dnn_input_layer(
|
||||||
|
input_tensor)
|
||||||
|
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(tf.initialize_all_variables())
|
||||||
|
loaded_embedding = pretrained_embeddings.eval()
|
||||||
|
|
||||||
|
self.assertAllClose(saved_embedding, loaded_embedding)
|
||||||
|
|
||||||
|
def testInitCrossedColumnWeightsFromCkpt(self):
|
||||||
|
sparse_col_1 = tf.contrib.layers.sparse_column_with_hash_bucket(
|
||||||
|
column_name="col_1", hash_bucket_size=4)
|
||||||
|
sparse_col_2 = tf.contrib.layers.sparse_column_with_hash_bucket(
|
||||||
|
column_name="col_2", hash_bucket_size=4)
|
||||||
|
|
||||||
|
crossed_col = tf.contrib.layers.crossed_column(
|
||||||
|
columns=[sparse_col_1, sparse_col_2],
|
||||||
|
hash_bucket_size=4)
|
||||||
|
|
||||||
|
input_tensor = tf.SparseTensor(indices=[[0, 0], [1, 1], [2, 2], [3, 3]],
|
||||||
|
values=[0, 1, 2, 3],
|
||||||
|
shape=[4, 4])
|
||||||
|
|
||||||
|
# Invoking 'crossed_col.to_weighted_sum' will create the crossed column
|
||||||
|
# weights variable.
|
||||||
|
with tf.variable_scope("run_1"):
|
||||||
|
# Returns looked up column weights which is same as crossed column weights
|
||||||
|
# as well as actual references to weights variables.
|
||||||
|
col_weights, weights = crossed_col.to_weighted_sum(input_tensor)
|
||||||
|
# Update the weights since default initializer initializes all weights to
|
||||||
|
# 0.0.
|
||||||
|
for weight in weights:
|
||||||
|
assign_op = tf.assign(weight, weight + 0.5)
|
||||||
|
|
||||||
|
save = tf.train.Saver()
|
||||||
|
checkpoint_path = os.path.join(self.get_temp_dir(), "model.ckpt")
|
||||||
|
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(tf.initialize_all_variables())
|
||||||
|
sess.run(assign_op)
|
||||||
|
saved_col_weights = col_weights.eval()
|
||||||
|
save.save(sess, checkpoint_path)
|
||||||
|
|
||||||
|
crossed_col_initialized = tf.contrib.layers.crossed_column(
|
||||||
|
columns=[sparse_col_1, sparse_col_2],
|
||||||
|
hash_bucket_size=4,
|
||||||
|
ckpt_to_load_from=checkpoint_path,
|
||||||
|
tensor_name_in_ckpt="run_1/col_1_X_col_2_weights")
|
||||||
|
|
||||||
|
with tf.variable_scope("run_2"):
|
||||||
|
# This will initialize the crossed column weights from provided checkpoint
|
||||||
|
# and return a [4, 1] tensor which is same as weights variable. Since we
|
||||||
|
# won't modify weights, this should be same as 'saved_col_weights'.
|
||||||
|
col_weights_from_ckpt, _ = crossed_col_initialized.to_weighted_sum(
|
||||||
|
input_tensor)
|
||||||
|
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(tf.initialize_all_variables())
|
||||||
|
loaded_col_weights = col_weights_from_ckpt.eval()
|
||||||
|
|
||||||
|
self.assertAllClose(saved_col_weights, loaded_col_weights)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
@ -102,12 +102,13 @@ def variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False,
|
|||||||
TypeError: if `mode` is not in ['FAN_IN', 'FAN_OUT', 'FAN_AVG'].
|
TypeError: if `mode` is not in ['FAN_IN', 'FAN_OUT', 'FAN_AVG'].
|
||||||
"""
|
"""
|
||||||
if not dtype.is_floating:
|
if not dtype.is_floating:
|
||||||
raise TypeError('Cannot create initializer for non-floating point '
|
raise TypeError('Cannot create initializer for non-floating point type.')
|
||||||
'type.')
|
|
||||||
if mode not in ['FAN_IN', 'FAN_OUT', 'FAN_AVG']:
|
if mode not in ['FAN_IN', 'FAN_OUT', 'FAN_AVG']:
|
||||||
raise TypeError('Unknow mode %s [FAN_IN, FAN_OUT, FAN_AVG]', mode)
|
raise TypeError('Unknow mode %s [FAN_IN, FAN_OUT, FAN_AVG]', mode)
|
||||||
def _initializer(shape, dtype=dtype):
|
def _initializer(shape, dtype=dtype):
|
||||||
"""Initializer function."""
|
"""Initializer function."""
|
||||||
|
if not dtype.is_floating:
|
||||||
|
raise TypeError('Cannot create initializer for non-floating point type.')
|
||||||
# Estimating fan_in and fan_out is not possible to do perfectly, but we try.
|
# Estimating fan_in and fan_out is not possible to do perfectly, but we try.
|
||||||
# This is the right thing for matrix multiply and convolutions.
|
# This is the right thing for matrix multiply and convolutions.
|
||||||
fan_in = float(shape[-2])
|
fan_in = float(shape[-2])
|
||||||
|
@ -64,6 +64,11 @@ class VarianceScalingInitializerTest(tf.test.TestCase):
|
|||||||
TypeError,
|
TypeError,
|
||||||
'Cannot create initializer for non-floating point type.'):
|
'Cannot create initializer for non-floating point type.'):
|
||||||
tf.contrib.layers.variance_scaling_initializer(dtype=tf.int32)
|
tf.contrib.layers.variance_scaling_initializer(dtype=tf.int32)
|
||||||
|
initializer = tf.contrib.layers.variance_scaling_initializer()
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
TypeError,
|
||||||
|
'Cannot create initializer for non-floating point type.'):
|
||||||
|
initializer([], dtype=tf.int32)
|
||||||
|
|
||||||
def _test_variance(self, initializer, shape, variance, factor, mode, uniform):
|
def _test_variance(self, initializer, shape, variance, factor, mode, uniform):
|
||||||
with tf.Graph().as_default() as g:
|
with tf.Graph().as_default() as g:
|
||||||
|
@ -75,25 +75,24 @@ def avg_pool2d(inputs,
|
|||||||
padding='VALID',
|
padding='VALID',
|
||||||
outputs_collections=None,
|
outputs_collections=None,
|
||||||
scope=None):
|
scope=None):
|
||||||
"""Adds a Avg Pooling op.
|
"""Adds a 2D average pooling op.
|
||||||
|
|
||||||
It is assumed by the wrapper that the pooling is only done per image and not
|
It is assumed that the pooling is done per image but not in batch or channels.
|
||||||
in depth or batch.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
inputs: a tensor of size [batch_size, height, width, depth].
|
inputs: A `Tensor` of size [batch_size, height, width, channels].
|
||||||
kernel_size: a list of length 2: [kernel_height, kernel_width] of the
|
kernel_size: A list of length 2: [kernel_height, kernel_width] of the
|
||||||
pooling kernel over which the op is computed. Can be an int if both
|
pooling kernel over which the op is computed. Can be an int if both
|
||||||
values are the same.
|
values are the same.
|
||||||
stride: a list of length 2: [stride_height, stride_width].
|
stride: A list of length 2: [stride_height, stride_width].
|
||||||
Can be an int if both strides are the same. Note that presently
|
Can be an int if both strides are the same. Note that presently
|
||||||
both strides must have the same value.
|
both strides must have the same value.
|
||||||
padding: the padding method, either 'VALID' or 'SAME'.
|
padding: The padding method, either 'VALID' or 'SAME'.
|
||||||
outputs_collections: collection to add the outputs.
|
outputs_collections: The collections to which the outputs are added.
|
||||||
scope: Optional scope for op_scope.
|
scope: Optional scope for op_scope.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
a tensor representing the results of the pooling operation.
|
A `Tensor` representing the results of the pooling operation.
|
||||||
"""
|
"""
|
||||||
with ops.op_scope([inputs], scope, 'AvgPool2D') as sc:
|
with ops.op_scope([inputs], scope, 'AvgPool2D') as sc:
|
||||||
inputs = ops.convert_to_tensor(inputs)
|
inputs = ops.convert_to_tensor(inputs)
|
||||||
@ -843,27 +842,27 @@ def max_pool2d(inputs,
|
|||||||
padding='VALID',
|
padding='VALID',
|
||||||
outputs_collections=None,
|
outputs_collections=None,
|
||||||
scope=None):
|
scope=None):
|
||||||
"""Adds a Max Pooling op.
|
"""Adds a 2D Max Pooling op.
|
||||||
|
|
||||||
It is assumed by the wrapper that the pooling is only done per image and not
|
It is assumed that the pooling is done per image but not in batch or channels.
|
||||||
in depth or batch.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
inputs: a tensor of size [batch_size, height, width, depth].
|
inputs: A `Tensor` of size [batch_size, height, width, channels].
|
||||||
kernel_size: a list of length 2: [kernel_height, kernel_width] of the
|
kernel_size: A list of length 2: [kernel_height, kernel_width] of the
|
||||||
pooling kernel over which the op is computed. Can be an int if both
|
pooling kernel over which the op is computed. Can be an int if both
|
||||||
values are the same.
|
values are the same.
|
||||||
stride: a list of length 2: [stride_height, stride_width].
|
stride: A list of length 2: [stride_height, stride_width].
|
||||||
Can be an int if both strides are the same. Note that presently
|
Can be an int if both strides are the same. Note that presently
|
||||||
both strides must have the same value.
|
both strides must have the same value.
|
||||||
padding: the padding method, either 'VALID' or 'SAME'.
|
padding: The padding method, either 'VALID' or 'SAME'.
|
||||||
outputs_collections: collection to add the outputs.
|
outputs_collections: The collections to which the outputs are added.
|
||||||
scope: Optional scope for op_scope.
|
scope: Optional scope for op_scope.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
a tensor representing the results of the pooling operation.
|
A `Tensor` representing the results of the pooling operation.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if 'kernel_size' is not a 2-D list
|
ValueError: If 'kernel_size' is not a 2-D list
|
||||||
"""
|
"""
|
||||||
with ops.op_scope([inputs], scope, 'MaxPool2D') as sc:
|
with ops.op_scope([inputs], scope, 'MaxPool2D') as sc:
|
||||||
inputs = ops.convert_to_tensor(inputs)
|
inputs = ops.convert_to_tensor(inputs)
|
||||||
@ -1037,6 +1036,7 @@ def separable_convolution2d(
|
|||||||
depthwise_weights = variables.model_variable(
|
depthwise_weights = variables.model_variable(
|
||||||
'depthwise_weights',
|
'depthwise_weights',
|
||||||
shape=depthwise_shape,
|
shape=depthwise_shape,
|
||||||
|
dtype=dtype,
|
||||||
initializer=weights_initializer,
|
initializer=weights_initializer,
|
||||||
regularizer=weights_regularizer,
|
regularizer=weights_regularizer,
|
||||||
trainable=trainable,
|
trainable=trainable,
|
||||||
@ -1049,6 +1049,7 @@ def separable_convolution2d(
|
|||||||
pointwise_weights = variables.model_variable(
|
pointwise_weights = variables.model_variable(
|
||||||
'pointwise_weights',
|
'pointwise_weights',
|
||||||
shape=pointwise_shape,
|
shape=pointwise_shape,
|
||||||
|
dtype=dtype,
|
||||||
initializer=weights_initializer,
|
initializer=weights_initializer,
|
||||||
regularizer=weights_regularizer,
|
regularizer=weights_regularizer,
|
||||||
trainable=trainable,
|
trainable=trainable,
|
||||||
|
@ -30,59 +30,52 @@ class AvgPool2DTest(tf.test.TestCase):
|
|||||||
|
|
||||||
def testCreateAvgPool(self):
|
def testCreateAvgPool(self):
|
||||||
height, width = 3, 3
|
height, width = 3, 3
|
||||||
with self.test_session():
|
images = np.random.uniform(size=(5, height, width, 3))
|
||||||
images = np.random.uniform(size=(5, height, width, 3))
|
output = tf.contrib.layers.avg_pool2d(images, [3, 3])
|
||||||
output = tf.contrib.layers.avg_pool2d(images, [3, 3])
|
self.assertEquals(output.op.name, 'AvgPool2D/AvgPool')
|
||||||
self.assertEquals(output.op.name, 'AvgPool2D/AvgPool')
|
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
|
||||||
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
|
|
||||||
|
|
||||||
def testCollectOutputs(self):
|
def testCollectOutputs(self):
|
||||||
height, width = 3, 3
|
height, width = 3, 3
|
||||||
with self.test_session():
|
images = tf.random_uniform((5, height, width, 3), seed=1)
|
||||||
images = tf.random_uniform((5, height, width, 3), seed=1)
|
output = tf.contrib.layers.avg_pool2d(images, [3, 3],
|
||||||
output = tf.contrib.layers.avg_pool2d(images, [3, 3],
|
outputs_collections='outputs')
|
||||||
outputs_collections='outputs')
|
output_collection = tf.get_collection('outputs')[0]
|
||||||
c_output = tf.get_collection('outputs')[0]
|
self.assertEquals(output_collection.name, 'AvgPool2D')
|
||||||
self.assertEquals(c_output.name, 'AvgPool2D')
|
self.assertEquals(output_collection.outputs, output)
|
||||||
self.assertEquals(c_output.outputs, output)
|
|
||||||
|
|
||||||
def testCreateSquareAvgPool(self):
|
def testCreateSquareAvgPool(self):
|
||||||
height, width = 3, 3
|
height, width = 3, 3
|
||||||
with self.test_session():
|
images = tf.random_uniform((5, height, width, 3), seed=1)
|
||||||
images = tf.random_uniform((5, height, width, 3), seed=1)
|
output = tf.contrib.layers.avg_pool2d(images, 3)
|
||||||
output = tf.contrib.layers.avg_pool2d(images, 3)
|
self.assertEquals(output.op.name, 'AvgPool2D/AvgPool')
|
||||||
self.assertEquals(output.op.name, 'AvgPool2D/AvgPool')
|
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
|
||||||
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
|
|
||||||
|
|
||||||
def testCreateAvgPoolWithScope(self):
|
def testCreateAvgPoolWithScope(self):
|
||||||
height, width = 3, 3
|
height, width = 3, 3
|
||||||
with self.test_session():
|
images = tf.random_uniform((5, height, width, 3), seed=1)
|
||||||
images = tf.random_uniform((5, height, width, 3), seed=1)
|
output = tf.contrib.layers.avg_pool2d(images, [3, 3], scope='pool1')
|
||||||
output = tf.contrib.layers.avg_pool2d(images, [3, 3], scope='pool1')
|
self.assertEquals(output.op.name, 'pool1/AvgPool')
|
||||||
self.assertEquals(output.op.name, 'pool1/AvgPool')
|
|
||||||
|
|
||||||
def testCreateAvgPoolSAME(self):
|
def testCreateAvgPoolWithSamePadding(self):
|
||||||
height, width = 3, 3
|
height, width = 3, 3
|
||||||
with self.test_session():
|
images = tf.random_uniform((5, height, width, 3), seed=1)
|
||||||
images = tf.random_uniform((5, height, width, 3), seed=1)
|
output = tf.contrib.layers.avg_pool2d(images, [3, 3], padding='SAME')
|
||||||
output = tf.contrib.layers.avg_pool2d(images, [3, 3], padding='SAME')
|
self.assertListEqual(output.get_shape().as_list(), [5, 2, 2, 3])
|
||||||
self.assertListEqual(output.get_shape().as_list(), [5, 2, 2, 3])
|
|
||||||
|
|
||||||
def testCreateAvgPoolStrideSAME(self):
|
def testCreateAvgPoolStrideWithSamePadding(self):
|
||||||
height, width = 3, 3
|
height, width = 3, 3
|
||||||
with self.test_session():
|
images = tf.random_uniform((5, height, width, 3), seed=1)
|
||||||
images = tf.random_uniform((5, height, width, 3), seed=1)
|
output = tf.contrib.layers.avg_pool2d(images, [3, 3], stride=1,
|
||||||
output = tf.contrib.layers.avg_pool2d(images, [3, 3], stride=1,
|
padding='SAME')
|
||||||
padding='SAME')
|
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])
|
||||||
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])
|
|
||||||
|
|
||||||
def testGlobalAvgPool(self):
|
def testGlobalAvgPool(self):
|
||||||
height, width = 3, 3
|
height, width = 3, 3
|
||||||
with self.test_session():
|
images = tf.random_uniform((5, height, width, 3), seed=1)
|
||||||
images = tf.random_uniform((5, height, width, 3), seed=1)
|
output = tf.contrib.layers.avg_pool2d(images, images.get_shape()[1:3],
|
||||||
output = tf.contrib.layers.avg_pool2d(images, images.get_shape()[1:3],
|
stride=1)
|
||||||
stride=1)
|
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
|
||||||
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
|
|
||||||
|
|
||||||
|
|
||||||
class BiasAddTest(tf.test.TestCase):
|
class BiasAddTest(tf.test.TestCase):
|
||||||
@ -825,7 +818,7 @@ class DropoutTest(tf.test.TestCase):
|
|||||||
with self.test_session():
|
with self.test_session():
|
||||||
images = np.random.uniform(size=(5, height, width, 3))
|
images = np.random.uniform(size=(5, height, width, 3))
|
||||||
output = tf.contrib.layers.dropout(images)
|
output = tf.contrib.layers.dropout(images)
|
||||||
self.assertEquals(output.op.name, 'Dropout/dropout/mul_1')
|
self.assertEquals(output.op.name, 'Dropout/dropout/mul')
|
||||||
output.get_shape().assert_is_compatible_with(
|
output.get_shape().assert_is_compatible_with(
|
||||||
tf.convert_to_tensor(images).get_shape())
|
tf.convert_to_tensor(images).get_shape())
|
||||||
|
|
||||||
@ -835,7 +828,7 @@ class DropoutTest(tf.test.TestCase):
|
|||||||
is_training = tf.constant(True)
|
is_training = tf.constant(True)
|
||||||
images = tf.random_uniform((5, height, width, 3), seed=1)
|
images = tf.random_uniform((5, height, width, 3), seed=1)
|
||||||
output = tf.contrib.layers.dropout(images, is_training=is_training)
|
output = tf.contrib.layers.dropout(images, is_training=is_training)
|
||||||
self.assertEquals(output.op.name, 'Dropout/dropout/mul_1')
|
self.assertEquals(output.op.name, 'Dropout/dropout/mul')
|
||||||
output.get_shape().assert_is_compatible_with(images.get_shape())
|
output.get_shape().assert_is_compatible_with(images.get_shape())
|
||||||
|
|
||||||
def testCreateDropoutWithConstantFalse(self):
|
def testCreateDropoutWithConstantFalse(self):
|
||||||
@ -1502,59 +1495,52 @@ class MaxPool2DTest(tf.test.TestCase):
|
|||||||
|
|
||||||
def testCreateMaxPool(self):
|
def testCreateMaxPool(self):
|
||||||
height, width = 3, 3
|
height, width = 3, 3
|
||||||
with self.test_session():
|
images = np.random.uniform(size=(5, height, width, 3)).astype(np.float32)
|
||||||
images = np.random.uniform(size=(5, height, width, 3)).astype(np.float32)
|
output = tf.contrib.layers.max_pool2d(images, [3, 3])
|
||||||
output = tf.contrib.layers.max_pool2d(images, [3, 3])
|
self.assertEquals(output.op.name, 'MaxPool2D/MaxPool')
|
||||||
self.assertEquals(output.op.name, 'MaxPool2D/MaxPool')
|
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
|
||||||
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
|
|
||||||
|
|
||||||
def testCollectOutputs(self):
|
def testCollectOutputs(self):
|
||||||
height, width = 3, 3
|
height, width = 3, 3
|
||||||
with self.test_session():
|
images = tf.random_uniform((5, height, width, 3), seed=1)
|
||||||
images = tf.random_uniform((5, height, width, 3), seed=1)
|
output = tf.contrib.layers.max_pool2d(images, [3, 3],
|
||||||
output = tf.contrib.layers.max_pool2d(images, [3, 3],
|
outputs_collections='outputs')
|
||||||
outputs_collections='outputs')
|
outputs_collection = tf.get_collection('outputs')[0]
|
||||||
c_output = tf.get_collection('outputs')[0]
|
self.assertEquals(outputs_collection.name, 'MaxPool2D')
|
||||||
self.assertEquals(c_output.name, 'MaxPool2D')
|
self.assertEquals(outputs_collection.outputs, output)
|
||||||
self.assertEquals(c_output.outputs, output)
|
|
||||||
|
|
||||||
def testCreateSquareMaxPool(self):
|
def testCreateSquareMaxPool(self):
|
||||||
height, width = 3, 3
|
height, width = 3, 3
|
||||||
with self.test_session():
|
images = tf.random_uniform((5, height, width, 3), seed=1)
|
||||||
images = tf.random_uniform((5, height, width, 3), seed=1)
|
output = tf.contrib.layers.max_pool2d(images, 3)
|
||||||
output = tf.contrib.layers.max_pool2d(images, 3)
|
self.assertEquals(output.op.name, 'MaxPool2D/MaxPool')
|
||||||
self.assertEquals(output.op.name, 'MaxPool2D/MaxPool')
|
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
|
||||||
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
|
|
||||||
|
|
||||||
def testCreateMaxPoolWithScope(self):
|
def testCreateMaxPoolWithScope(self):
|
||||||
height, width = 3, 3
|
height, width = 3, 3
|
||||||
with self.test_session():
|
images = tf.random_uniform((5, height, width, 3), seed=1)
|
||||||
images = tf.random_uniform((5, height, width, 3), seed=1)
|
output = tf.contrib.layers.max_pool2d(images, [3, 3], scope='pool1')
|
||||||
output = tf.contrib.layers.max_pool2d(images, [3, 3], scope='pool1')
|
self.assertEquals(output.op.name, 'pool1/MaxPool')
|
||||||
self.assertEquals(output.op.name, 'pool1/MaxPool')
|
|
||||||
|
|
||||||
def testCreateMaxPoolSAME(self):
|
def testCreateMaxPoolWithSamePadding(self):
|
||||||
height, width = 3, 3
|
height, width = 3, 3
|
||||||
with self.test_session():
|
images = tf.random_uniform((5, height, width, 3), seed=1)
|
||||||
images = tf.random_uniform((5, height, width, 3), seed=1)
|
output = tf.contrib.layers.max_pool2d(images, [3, 3], padding='SAME')
|
||||||
output = tf.contrib.layers.max_pool2d(images, [3, 3], padding='SAME')
|
self.assertListEqual(output.get_shape().as_list(), [5, 2, 2, 3])
|
||||||
self.assertListEqual(output.get_shape().as_list(), [5, 2, 2, 3])
|
|
||||||
|
|
||||||
def testCreateMaxPoolStrideSAME(self):
|
def testCreateMaxPoolStrideWithSamePadding(self):
|
||||||
height, width = 3, 3
|
height, width = 3, 3
|
||||||
with self.test_session():
|
images = tf.random_uniform((5, height, width, 3), seed=1)
|
||||||
images = tf.random_uniform((5, height, width, 3), seed=1)
|
output = tf.contrib.layers.max_pool2d(images, [3, 3], stride=1,
|
||||||
output = tf.contrib.layers.max_pool2d(images, [3, 3], stride=1,
|
padding='SAME')
|
||||||
padding='SAME')
|
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])
|
||||||
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])
|
|
||||||
|
|
||||||
def testGlobalMaxPool(self):
|
def testGlobalMaxPool(self):
|
||||||
height, width = 3, 3
|
height, width = 3, 3
|
||||||
with self.test_session():
|
images = tf.random_uniform((5, height, width, 3), seed=1)
|
||||||
images = tf.random_uniform((5, height, width, 3), seed=1)
|
output = tf.contrib.layers.max_pool2d(images, images.get_shape()[1:3],
|
||||||
output = tf.contrib.layers.max_pool2d(images, images.get_shape()[1:3],
|
stride=1)
|
||||||
stride=1)
|
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
|
||||||
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
|
|
||||||
|
|
||||||
|
|
||||||
class OneHotEncodingTest(tf.test.TestCase):
|
class OneHotEncodingTest(tf.test.TestCase):
|
||||||
@ -1618,10 +1604,28 @@ class RepeatTests(tf.test.TestCase):
|
|||||||
|
|
||||||
class SeparableConv2dTest(tf.test.TestCase):
|
class SeparableConv2dTest(tf.test.TestCase):
|
||||||
|
|
||||||
def testCreateConv(self):
|
def testCreateConvInt32(self):
|
||||||
height, width = 3, 3
|
height, width = 3, 3
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
images = tf.random_uniform((5, height, width, 3), seed=1)
|
images = tf.random_uniform(
|
||||||
|
(5, height, width, 3), seed=1, dtype=tf.int32, maxval=12345)
|
||||||
|
with self.assertRaisesRegexp(TypeError, 'non-floating point type'):
|
||||||
|
tf.contrib.layers.separable_conv2d(images, 32, [3, 3], 2)
|
||||||
|
|
||||||
|
def testCreateConvFloat32(self):
|
||||||
|
height, width = 3, 3
|
||||||
|
with self.test_session():
|
||||||
|
images = tf.random_uniform(
|
||||||
|
(5, height, width, 3), seed=1, dtype=tf.float32)
|
||||||
|
output = tf.contrib.layers.separable_conv2d(images, 32, [3, 3], 2)
|
||||||
|
self.assertEquals(output.op.name, 'SeparableConv2d/Relu')
|
||||||
|
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
|
||||||
|
|
||||||
|
def testCreateConvFloat64(self):
|
||||||
|
height, width = 3, 3
|
||||||
|
with self.test_session():
|
||||||
|
images = tf.random_uniform(
|
||||||
|
(5, height, width, 3), seed=1, dtype=tf.float64)
|
||||||
output = tf.contrib.layers.separable_conv2d(images, 32, [3, 3], 2)
|
output = tf.contrib.layers.separable_conv2d(images, 32, [3, 3], 2)
|
||||||
self.assertEquals(output.op.name, 'SeparableConv2d/Relu')
|
self.assertEquals(output.op.name, 'SeparableConv2d/Relu')
|
||||||
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
|
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
|
||||||
|
@ -31,6 +31,7 @@ from tensorflow.python.ops import logging_ops
|
|||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
from tensorflow.python.ops import variable_scope as vs
|
from tensorflow.python.ops import variable_scope as vs
|
||||||
from tensorflow.python.ops import variables as vars_
|
from tensorflow.python.ops import variables as vars_
|
||||||
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training import optimizer as optimizer_
|
from tensorflow.python.training import optimizer as optimizer_
|
||||||
from tensorflow.python.training import training as train
|
from tensorflow.python.training import training as train
|
||||||
|
|
||||||
@ -43,6 +44,13 @@ OPTIMIZER_CLS_NAMES = {
|
|||||||
"SGD": train.GradientDescentOptimizer,
|
"SGD": train.GradientDescentOptimizer,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OPTIMIZER_SUMMARIES = [
|
||||||
|
"learning_rate",
|
||||||
|
"loss",
|
||||||
|
"gradients",
|
||||||
|
"gradient_norm",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def optimize_loss(loss,
|
def optimize_loss(loss,
|
||||||
global_step,
|
global_step,
|
||||||
@ -51,11 +59,12 @@ def optimize_loss(loss,
|
|||||||
gradient_noise_scale=None,
|
gradient_noise_scale=None,
|
||||||
gradient_multipliers=None,
|
gradient_multipliers=None,
|
||||||
clip_gradients=None,
|
clip_gradients=None,
|
||||||
moving_average_decay=0.9,
|
moving_average_decay=None,
|
||||||
learning_rate_decay_fn=None,
|
learning_rate_decay_fn=None,
|
||||||
update_ops=None,
|
update_ops=None,
|
||||||
variables=None,
|
variables=None,
|
||||||
name=None):
|
name=None,
|
||||||
|
summaries=None):
|
||||||
"""Given loss and parameters for optimizer, returns a training op.
|
"""Given loss and parameters for optimizer, returns a training op.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -75,8 +84,8 @@ def optimize_loss(loss,
|
|||||||
If present, gradients for specified
|
If present, gradients for specified
|
||||||
variables will be multiplied by given constant.
|
variables will be multiplied by given constant.
|
||||||
clip_gradients: float or `None`, clips gradients by this value.
|
clip_gradients: float or `None`, clips gradients by this value.
|
||||||
moving_average_decay: float or None, takes into account previous loss
|
moving_average_decay: Deprecated. float or None, takes into account previous
|
||||||
to make learning smoother due to outliers.
|
loss to make learning smoother due to outliers.
|
||||||
learning_rate_decay_fn: function, takes `learning_rate` and `global_step`
|
learning_rate_decay_fn: function, takes `learning_rate` and `global_step`
|
||||||
`Tensor`s, returns `Tensor`.
|
`Tensor`s, returns `Tensor`.
|
||||||
Can be used to implement any learning rate decay
|
Can be used to implement any learning rate decay
|
||||||
@ -87,6 +96,9 @@ def optimize_loss(loss,
|
|||||||
variables: list of variables to optimize or
|
variables: list of variables to optimize or
|
||||||
`None` to use all trainable variables.
|
`None` to use all trainable variables.
|
||||||
name: The name for this operation is used to scope operations and summaries.
|
name: The name for this operation is used to scope operations and summaries.
|
||||||
|
summaries: List of internal quantities to visualize on tensorboard. If not
|
||||||
|
set only the loss and the learning rate will be reported. The
|
||||||
|
complete list is in OPTIMIZER_SUMMARIES.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Training op.
|
Training op.
|
||||||
@ -96,8 +108,8 @@ def optimize_loss(loss,
|
|||||||
"""
|
"""
|
||||||
with vs.variable_op_scope([loss, global_step], name, "OptimizeLoss"):
|
with vs.variable_op_scope([loss, global_step], name, "OptimizeLoss"):
|
||||||
# Update ops take UPDATE_OPS collection if not provided.
|
# Update ops take UPDATE_OPS collection if not provided.
|
||||||
update_ops = (set(update_ops or []) or
|
if update_ops is None:
|
||||||
set(ops.get_collection(ops.GraphKeys.UPDATE_OPS)))
|
update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
|
||||||
# Make sure update ops are ran before computing loss.
|
# Make sure update ops are ran before computing loss.
|
||||||
if update_ops:
|
if update_ops:
|
||||||
with ops.control_dependencies(update_ops):
|
with ops.control_dependencies(update_ops):
|
||||||
@ -105,7 +117,10 @@ def optimize_loss(loss,
|
|||||||
loss = control_flow_ops.with_dependencies([barrier], loss)
|
loss = control_flow_ops.with_dependencies([barrier], loss)
|
||||||
|
|
||||||
# Moving average of the loss with decay.
|
# Moving average of the loss with decay.
|
||||||
|
# TODO(b/30439864): moving_average_decay should be removed.
|
||||||
if moving_average_decay is not None:
|
if moving_average_decay is not None:
|
||||||
|
logging.warn("'moving_average_decay' is deprecated. Please use "
|
||||||
|
"tensorboard's builtin averaging instead.")
|
||||||
# Generate moving averages of the loss.
|
# Generate moving averages of the loss.
|
||||||
loss_averages = train.ExponentialMovingAverage(moving_average_decay,
|
loss_averages = train.ExponentialMovingAverage(moving_average_decay,
|
||||||
name="avg")
|
name="avg")
|
||||||
@ -125,9 +140,12 @@ def optimize_loss(loss,
|
|||||||
raise ValueError("Learning rate should be 0d Tensor or float. "
|
raise ValueError("Learning rate should be 0d Tensor or float. "
|
||||||
"Got %s of type %s" % (
|
"Got %s of type %s" % (
|
||||||
str(learning_rate), str(type(learning_rate))))
|
str(learning_rate), str(type(learning_rate))))
|
||||||
|
if summaries is None:
|
||||||
|
summaries = ["loss", "learning_rate"]
|
||||||
if learning_rate_decay_fn is not None:
|
if learning_rate_decay_fn is not None:
|
||||||
lr = learning_rate_decay_fn(lr, global_step)
|
lr = learning_rate_decay_fn(lr, global_step)
|
||||||
logging_ops.scalar_summary("learning_rate", lr)
|
if "learning_rate" in summaries:
|
||||||
|
logging_ops.scalar_summary("learning_rate", lr)
|
||||||
|
|
||||||
# Create optimizer, given specified parameters.
|
# Create optimizer, given specified parameters.
|
||||||
if isinstance(optimizer, six.string_types):
|
if isinstance(optimizer, six.string_types):
|
||||||
@ -167,7 +185,8 @@ def optimize_loss(loss,
|
|||||||
gradients = _clip_gradients_by_norm(gradients, clip_gradients)
|
gradients = _clip_gradients_by_norm(gradients, clip_gradients)
|
||||||
|
|
||||||
# Add scalar summary for loss.
|
# Add scalar summary for loss.
|
||||||
logging_ops.scalar_summary("loss", loss)
|
if "loss" in summaries:
|
||||||
|
logging_ops.scalar_summary("loss", loss)
|
||||||
|
|
||||||
# Add histograms for variables, gradients and gradient norms.
|
# Add histograms for variables, gradients and gradient norms.
|
||||||
for gradient, variable in gradients:
|
for gradient, variable in gradients:
|
||||||
@ -177,10 +196,12 @@ def optimize_loss(loss,
|
|||||||
grad_values = gradient
|
grad_values = gradient
|
||||||
|
|
||||||
if grad_values is not None:
|
if grad_values is not None:
|
||||||
logging_ops.histogram_summary(variable.name, variable)
|
if "gradients" in summaries:
|
||||||
logging_ops.histogram_summary(variable.name + "/gradients", grad_values)
|
logging_ops.histogram_summary(variable.name + "/gradients",
|
||||||
logging_ops.histogram_summary(variable.name + "/gradient_norm",
|
grad_values)
|
||||||
clip_ops.global_norm([grad_values]))
|
if "gradient_norm" in summaries:
|
||||||
|
logging_ops.histogram_summary(variable.name + "/gradient_norm",
|
||||||
|
clip_ops.global_norm([grad_values]))
|
||||||
|
|
||||||
# Create gradient updates.
|
# Create gradient updates.
|
||||||
grad_updates = opt.apply_gradients(gradients,
|
grad_updates = opt.apply_gradients(gradients,
|
||||||
|
@ -75,7 +75,8 @@ class OptimizersTest(tf.test.TestCase):
|
|||||||
tf.initialize_all_variables().run()
|
tf.initialize_all_variables().run()
|
||||||
session.run(train, feed_dict={x: 5})
|
session.run(train, feed_dict={x: 5})
|
||||||
var_value, global_step_value = session.run([var, global_step])
|
var_value, global_step_value = session.run([var, global_step])
|
||||||
self.assertAlmostEqual(var_value, 8.58150, 4)
|
# Due to randomness the following number may change if graph is different.
|
||||||
|
self.assertAlmostEqual(var_value, 8.5591021, 4)
|
||||||
self.assertEqual(global_step_value, 1)
|
self.assertEqual(global_step_value, 1)
|
||||||
|
|
||||||
def testGradientNoiseWithClipping(self):
|
def testGradientNoiseWithClipping(self):
|
||||||
|
@ -22,6 +22,7 @@ import inspect
|
|||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
from tensorflow.contrib import losses
|
||||||
from tensorflow.contrib import metrics as metrics_lib
|
from tensorflow.contrib import metrics as metrics_lib
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -29,7 +30,6 @@ from tensorflow.python.ops import array_ops
|
|||||||
from tensorflow.python.ops import logging_ops
|
from tensorflow.python.ops import logging_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import nn
|
from tensorflow.python.ops import nn
|
||||||
from tensorflow.python.ops import nn_ops
|
|
||||||
|
|
||||||
|
|
||||||
def regression_target(label_name=None,
|
def regression_target(label_name=None,
|
||||||
@ -70,7 +70,7 @@ def multi_class_target(n_classes, label_name=None, weight_column_name=None):
|
|||||||
will be multiplied by the loss of the example.
|
will be multiplied by the loss of the example.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An instance of _TargetColumn
|
An instance of _MultiClassTargetColumn.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if n_classes is < 2
|
ValueError: if n_classes is < 2
|
||||||
@ -297,8 +297,17 @@ class _BinarySvmTargetColumn(_MultiClassTargetColumn):
|
|||||||
"""_TargetColumn for binary classification using SVMs."""
|
"""_TargetColumn for binary classification using SVMs."""
|
||||||
|
|
||||||
def __init__(self, label_name, weight_column_name):
|
def __init__(self, label_name, weight_column_name):
|
||||||
|
def loss_fn(logits, target):
|
||||||
|
check_shape_op = logging_ops.Assert(
|
||||||
|
math_ops.less_equal(array_ops.rank(target), 2),
|
||||||
|
["target's shape should be either [batch_size, 1] or [batch_size]"])
|
||||||
|
with ops.control_dependencies([check_shape_op]):
|
||||||
|
target = array_ops.reshape(
|
||||||
|
target, shape=[array_ops.shape(target)[0], 1])
|
||||||
|
return losses.hinge_loss(logits, target)
|
||||||
|
|
||||||
super(_BinarySvmTargetColumn, self).__init__(
|
super(_BinarySvmTargetColumn, self).__init__(
|
||||||
loss_fn=_binary_hinge_loss,
|
loss_fn=loss_fn,
|
||||||
n_classes=2,
|
n_classes=2,
|
||||||
label_name=label_name,
|
label_name=label_name,
|
||||||
weight_column_name=weight_column_name)
|
weight_column_name=weight_column_name)
|
||||||
@ -331,22 +340,6 @@ def _log_loss_with_two_classes(logits, target):
|
|||||||
return loss_vec
|
return loss_vec
|
||||||
|
|
||||||
|
|
||||||
# TODO(sibyl-vie3Poto): Move this to contrib/losses/python/losses/loss_ops.py.
|
|
||||||
def _binary_hinge_loss(logits, target):
|
|
||||||
"""Method that returns the loss vector for binary hinge loss."""
|
|
||||||
check_shape_op = logging_ops.Assert(
|
|
||||||
math_ops.less_equal(
|
|
||||||
array_ops.rank(target), 2),
|
|
||||||
["target's shape should be either [batch_size, 1] or [batch_size]"])
|
|
||||||
with ops.control_dependencies([check_shape_op]):
|
|
||||||
target = array_ops.reshape(target, shape=[array_ops.shape(target)[0], 1])
|
|
||||||
# First need to convert binary labels to -1/1 labels (as floats).
|
|
||||||
all_ones = array_ops.ones_like(logits)
|
|
||||||
labels = math_ops.sub(2 * math_ops.to_float(target), all_ones)
|
|
||||||
loss_vec = nn_ops.relu(math_ops.sub(all_ones, math_ops.mul(labels, logits)))
|
|
||||||
return loss_vec
|
|
||||||
|
|
||||||
|
|
||||||
def _softmax_cross_entropy_loss(logits, target):
|
def _softmax_cross_entropy_loss(logits, target):
|
||||||
# sigmoid_cross_entropy_with_logits requires [batch_size, 1] target.
|
# sigmoid_cross_entropy_with_logits requires [batch_size, 1] target.
|
||||||
# Check that we got int32/int64 for classification.
|
# Check that we got int32/int64 for classification.
|
||||||
|
@ -36,6 +36,18 @@ py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "load_csv_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["python/learn/tests/load_csv_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":learn",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "data_feeder_test",
|
name = "data_feeder_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
@ -235,9 +247,9 @@ py_test(
|
|||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "compare_test",
|
name = "binary_transform_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["python/learn/tests/dataframe/compare_test.py"],
|
srcs = ["python/learn/tests/dataframe/binary_transform_test.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":learn",
|
":learn",
|
||||||
@ -625,19 +637,6 @@ py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
|
||||||
name = "checkpoints_test",
|
|
||||||
size = "small",
|
|
||||||
srcs = ["python/learn/utils/checkpoints_test.py"],
|
|
||||||
srcs_version = "PY2AND3",
|
|
||||||
deps = [
|
|
||||||
":learn",
|
|
||||||
"//tensorflow:tensorflow_py",
|
|
||||||
"//tensorflow/python:framework",
|
|
||||||
"//tensorflow/python:framework_test_lib",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "graph_io_test",
|
name = "graph_io_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
@ -56,6 +56,7 @@ Below are few simple examples of the API. For more examples, please see [example
|
|||||||
Simple linear classification:
|
Simple linear classification:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
import tensorflow.contrib.learn.python.learn as learn
|
||||||
from sklearn import datasets, metrics
|
from sklearn import datasets, metrics
|
||||||
|
|
||||||
iris = datasets.load_iris()
|
iris = datasets.load_iris()
|
||||||
@ -70,6 +71,7 @@ print("Accuracy: %f" % score)
|
|||||||
Simple linear regression:
|
Simple linear regression:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
import tensorflow.contrib.learn.python.learn as learn
|
||||||
from sklearn import datasets, metrics, preprocessing
|
from sklearn import datasets, metrics, preprocessing
|
||||||
|
|
||||||
boston = datasets.load_boston()
|
boston = datasets.load_boston()
|
||||||
@ -85,6 +87,7 @@ print ("MSE: %f" % score)
|
|||||||
Example of 3 layer network with 10, 20 and 10 hidden units respectively:
|
Example of 3 layer network with 10, 20 and 10 hidden units respectively:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
import tensorflow.contrib.learn.python.learn as learn
|
||||||
from sklearn import datasets, metrics
|
from sklearn import datasets, metrics
|
||||||
|
|
||||||
iris = datasets.load_iris()
|
iris = datasets.load_iris()
|
||||||
@ -99,6 +102,7 @@ print("Accuracy: %f" % score)
|
|||||||
Example of how to pass a custom model to the Estimator:
|
Example of how to pass a custom model to the Estimator:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
import tensorflow.contrib.learn.python.learn as learn
|
||||||
from sklearn import datasets, metrics
|
from sklearn import datasets, metrics
|
||||||
|
|
||||||
iris = datasets.load_iris()
|
iris = datasets.load_iris()
|
||||||
|
@ -33,6 +33,7 @@ from tensorflow.contrib.learn.python.learn import preprocessing
|
|||||||
from tensorflow.contrib.learn.python.learn import utils
|
from tensorflow.contrib.learn.python.learn import utils
|
||||||
from tensorflow.contrib.learn.python.learn.dataframe import *
|
from tensorflow.contrib.learn.python.learn.dataframe import *
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import *
|
from tensorflow.contrib.learn.python.learn.estimators import *
|
||||||
|
from tensorflow.contrib.learn.python.learn.evaluable import Evaluable
|
||||||
from tensorflow.contrib.learn.python.learn.experiment import Experiment
|
from tensorflow.contrib.learn.python.learn.experiment import Experiment
|
||||||
from tensorflow.contrib.learn.python.learn.graph_actions import evaluate
|
from tensorflow.contrib.learn.python.learn.graph_actions import evaluate
|
||||||
from tensorflow.contrib.learn.python.learn.graph_actions import infer
|
from tensorflow.contrib.learn.python.learn.graph_actions import infer
|
||||||
@ -41,4 +42,5 @@ from tensorflow.contrib.learn.python.learn.graph_actions import run_feeds
|
|||||||
from tensorflow.contrib.learn.python.learn.graph_actions import run_n
|
from tensorflow.contrib.learn.python.learn.graph_actions import run_n
|
||||||
from tensorflow.contrib.learn.python.learn.graph_actions import train
|
from tensorflow.contrib.learn.python.learn.graph_actions import train
|
||||||
from tensorflow.contrib.learn.python.learn.learn_io import *
|
from tensorflow.contrib.learn.python.learn.learn_io import *
|
||||||
|
from tensorflow.contrib.learn.python.learn.trainable import Trainable
|
||||||
# pylint: enable=wildcard-import
|
# pylint: enable=wildcard-import
|
||||||
|
@ -29,11 +29,14 @@ from tensorflow.contrib.learn.python.learn.dataframe.transform import Transform
|
|||||||
|
|
||||||
# Transforms
|
# Transforms
|
||||||
from tensorflow.contrib.learn.python.learn.dataframe.transforms.boolean_mask import BooleanMask
|
from tensorflow.contrib.learn.python.learn.dataframe.transforms.boolean_mask import BooleanMask
|
||||||
|
from tensorflow.contrib.learn.python.learn.dataframe.transforms.difference import Difference
|
||||||
|
from tensorflow.contrib.learn.python.learn.dataframe.transforms.hashes import HashFast
|
||||||
from tensorflow.contrib.learn.python.learn.dataframe.transforms.in_memory_source import NumpySource
|
from tensorflow.contrib.learn.python.learn.dataframe.transforms.in_memory_source import NumpySource
|
||||||
from tensorflow.contrib.learn.python.learn.dataframe.transforms.in_memory_source import PandasSource
|
from tensorflow.contrib.learn.python.learn.dataframe.transforms.in_memory_source import PandasSource
|
||||||
from tensorflow.contrib.learn.python.learn.dataframe.transforms.reader_source import ReaderSource
|
from tensorflow.contrib.learn.python.learn.dataframe.transforms.reader_source import ReaderSource
|
||||||
from tensorflow.contrib.learn.python.learn.dataframe.transforms.sum import Sum
|
from tensorflow.contrib.learn.python.learn.dataframe.transforms.sum import Sum
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=g-import-not-at-top,g-bad-import-order
|
# pylint: disable=g-import-not-at-top,g-bad-import-order
|
||||||
|
|
||||||
# Unary Transform registration
|
# Unary Transform registration
|
||||||
@ -42,9 +45,9 @@ for ut_def in _ut.UNARY_TRANSFORMS:
|
|||||||
_ut.register_unary_op(*ut_def)
|
_ut.register_unary_op(*ut_def)
|
||||||
|
|
||||||
# Comparison Transform registration
|
# Comparison Transform registration
|
||||||
from tensorflow.contrib.learn.python.learn.dataframe.transforms import compare as _cmp
|
from tensorflow.contrib.learn.python.learn.dataframe.transforms import binary_transforms as _bt
|
||||||
for ct_def in _cmp.COMPARISON_TRANSFORMS:
|
for bt_def in _bt.BINARY_TRANSFORMS:
|
||||||
_cmp.register_comparison_ops(*ct_def)
|
_bt.register_binary_op(*bt_def)
|
||||||
|
|
||||||
__all__ = ['DataFrame', 'Series', 'PredefinedSeries', 'TransformedSeries',
|
__all__ = ['DataFrame', 'Series', 'PredefinedSeries', 'TransformedSeries',
|
||||||
'TensorFlowDataFrame', 'parameter', 'Transform']
|
'TensorFlowDataFrame', 'parameter', 'Transform']
|
||||||
|
@ -117,10 +117,11 @@ class DataFrame(object):
|
|||||||
value = [value]
|
value = [value]
|
||||||
self.assign(**dict(zip(key, value)))
|
self.assign(**dict(zip(key, value)))
|
||||||
|
|
||||||
def build(self):
|
def build(self, **kwargs):
|
||||||
# We do not allow passing a cache here, because that would encourage
|
# We do not allow passing a cache here, because that would encourage
|
||||||
# working around the rule that DataFrames cannot be expected to be
|
# working around the rule that DataFrames cannot be expected to be
|
||||||
# synced with each other (e.g., they shuffle independently).
|
# synced with each other (e.g., they shuffle independently).
|
||||||
cache = {}
|
cache = {}
|
||||||
tensors = {name: c.build(cache) for name, c in self._columns.items()}
|
tensors = {name: c.build(cache, **kwargs)
|
||||||
|
for name, c in self._columns.items()}
|
||||||
return tensors
|
return tensors
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user