Merge pull request #1 from tensorflow/master

Merge to head
This commit is contained in:
Todd Wang 2017-06-20 14:12:11 -07:00 committed by GitHub
commit 3e2093844c
6022 changed files with 520508 additions and 342174 deletions

10
.gitignore vendored
View File

@ -1,15 +1,17 @@
.DS_Store
.ipynb_checkpoints
node_modules
/.bazelrc
/.tf_configure.bazelrc
/bazel-*
/third_party/py/numpy/numpy_include
/tools/bazel.rc
/bazel_pip
/third_party/eigen3/mkl_include
/third_party/mkl/*
/tools/python_bin_path.sh
/tools/git/gen
/util/python/python_include
/util/python/python_lib
/pip_test
/_python_build
*.pyc
__pycache__
*.swp
.vscode/

View File

@ -1,11 +0,0 @@
{
"maxReviewers": 2,
"numFilesToCheck": 10,
"userBlacklist": ["tensorflower-gardener"],
"requiredOrgs": ["tensorflow"],
"skipAlreadyAssignedPR": true,
"skipAlreadyMentionedPR": true,
"skipTitle": "Branch",
"delayed": true,
"delayedUntil": "10m"
}

View File

@ -21,9 +21,151 @@ If you have improvements to TensorFlow, send us your pull requests! For those
just getting started, Github has a [howto](https://help.github.com/articles/using-pull-requests/).
If you want to contribute but you're not sure where to start, take a look at the
[issues with the "contributions welcome" label](https://github.com/tensorflow/tensorflow/labels/contributions%20welcome).
[issues with the "contributions welcome" label](https://github.com/tensorflow/tensorflow/labels/stat%3Acontributions%20welcome).
These are issues that we believe are particularly well suited for outside
contributions, often because we probably won't get to them right now. If you
decide to start on an issue, leave a comment so that other people know that
you're working on it. If you want to help out, but not alone, use the issue
comment thread to coordinate.
### Contribution guidelines and standards
Before sending your pull request for
[review](https://github.com/tensorflow/tensorflow/pulls),
make sure your changes are consistent with the guidelines and follow the
TensorFlow coding style.
#### General guidelines and philosophy for contribution
* Include unit tests when you contribute new features, as they help to
a) prove that your code works correctly, b) guard against future breaking
changes to lower the maintenance cost.
* Bug fixes also generally require unit tests, because the presence of bugs
usually indicates insufficient test coverage.
* Keep API compatibility in mind when you change code in core TensorFlow,
e.g., code in [tensorflow/core](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core) and [tensorflow/python](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python).
TensorFlow has reached version 1 and hence cannot make
non-backward-compatible API changes without a major release. Reviewers of your
pull request will comment on any API compatibility issues.
* When you contribute a new feature to TensorFlow, the maintenance burden is (by
default) transferred to the TensorFlow team. This means that benefit of
contribution must be compared against the cost of maintaining the feature.
* Full new features (e.g., a new op implementing a cutting-edge algorithm)
typically will live in
[tensorflow/contrib](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib)
to get some airtime before decision is made regarding whether they are to be
migrated to the core.
#### License
Include a license at the top of new files.
* [C/C++ license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/op.cc#L1)
* [Python license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/nn.py#L1)
* [Java license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/Graph.java#L1)
* [Go license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/go/operation.go#L1)
* [Bash license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/ci_build/ci_sanity.sh#L2)
* [HTML license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/dist/index.html#L2)
* [JavaScript/TypeScript license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/components/tf_backend/backend.ts#L1)
Bazel BUILD files also need to include a license section, e.g.,
[BUILD example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/BUILD#L61).
#### C++ coding style
Changes to TensorFlow C++ code should conform to
[Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html).
Use `clang-tidy` to check your C/C++ changes. To install clang-tidy on ubuntu:16.04, do:
```bash
apt-get install -y clang-tidy
```
You can check a C/C++ file by doing:
```bash
clang-format <my_cc_file> --style=google > /tmp/my_cc_file.cc
diff <my_cc_file> /tmp/my_cc_file.cc
```
#### Python coding style
Changes to TensorFlow Python code should conform to
[Google Python Style Guide](https://google.github.io/styleguide/pyguide.html)
Use `pylint` to check your Python changes. To install `pylint` and
retrieve TensorFlow's custom style definition:
```bash
pip install pylint
wget -O /tmp/pylintrc https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/tools/ci_build/pylintrc
```
To check a file with `pylint`:
```bash
pylint --rcfile=/tmp/pylintrc myfile.py
```
#### Coding style for other languages
* [Google Java Style Guide](https://google.github.io/styleguide/javaguide.html)
* [Google JavaScript Style Guide](https://google.github.io/styleguide/jsguide.html)
* [Google Shell Style Guide](https://google.github.io/styleguide/shell.xml)
#### Running sanity check
If you have Docker installed on your system, you can perform a sanity check on
your changes by running the command:
```bash
tensorflow/tools/ci_build/ci_build.sh CPU tensorflow/tools/ci_build/ci_sanity.sh
```
This will catch most license, Python coding style and BUILD file issues that
may exist in your changes.
#### Running unit tests
There are two ways to run TensorFlow unit tests.
1. Using tools and libraries installed directly on your system.
Refer to the
[CPU-only developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel) and
[GPU developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel-gpu)
for the required packages. Alternatively, use the said
[Docker images](https://hub.docker.com/r/tensorflow/tensorflow/tags/), e.g.,
`tensorflow/tensorflow:nightly-devel` and `tensorflow/tensorflow:nightly-devel-gpu`
for development to avoid installing the packages directly on your system.
Once you have the packages installed, you can run a specific unit test in
bazel by doing as follows:
If the tests are to be run on GPU, add CUDA paths to LD_LIBRARY_PATH and add
the `cuda` option flag
```bash
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH"
export flags="--config=opt --config=cuda -k"
```
For example, to run all tests under tensorflow/python, do:
```bash
bazel test ${flags} //tensorflow/python/...
```
2. Using [Docker](www.docker.com) and TensorFlow's CI scripts.
```bash
# Install Docker first, then this will build and run cpu tests
tensorflow/tools/ci_build/ci_build.sh CPU bazel test //tensorflow/...
```
See
[TensorFlow Builds](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/ci_build) for details.

View File

@ -1,36 +1,37 @@
NOTE: Only file GitHub issues for bugs and feature requests. All other topics will be closed.
Please go to Stack Overflow for help and support:
For general support from the community, see [StackOverflow](https://stackoverflow.com/questions/tagged/tensorflow).
To make bugs and feature requests more easy to find and organize, we close issues that are deemed
out of scope for GitHub Issues and point people to StackOverflow.
http://stackoverflow.com/questions/tagged/tensorflow
For bugs or installation issues, please provide the following information.
The more information you provide, the more easily we will be able to offer
help and advice.
If you open a GitHub issue, here is our policy:
### What related GitHub issues or StackOverflow threads have you found by searching the web for your problem?
1. It must be a bug or a feature request.
2. The form below must be filled out.
3. It shouldn't be a TensorBoard issue. Those go [here](https://github.com/tensorflow/tensorflow/issues).
### Environment info
Operating System:
**Here's why we have that policy**: TensorFlow developers respond to issues. We want to focus on work that benefits the whole community, e.g., fixing bugs and adding features. Support only helps individuals. GitHub also notifies thousands of people when issues are filed. We want them to see you communicating an interesting problem, rather than being redirected to Stack Overflow.
Installed version of CUDA and cuDNN:
(please attach the output of `ls -l /path/to/cuda/lib/libcud*`):
------------------------
If installed from binary pip package, provide:
### System information
- **Have I written custom code (as opposed to using a stock example script provided in TensorFlow)**:
- **OS Platform and Distribution (e.g., Linux Ubuntu 16.04)**:
- **TensorFlow installed from (source or binary)**:
- **TensorFlow version (use command below)**:
- **Bazel version (if compiling from source)**:
- **CUDA/cuDNN version**:
- **GPU model and memory**:
- **Exact command to reproduce**:
1. A link to the pip package you installed:
2. The output from `python -c "import tensorflow; print(tensorflow.__version__)"`.
You can collect some of this information using our environment capture script:
If installed from source, provide
https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh
1. The commit hash (`git rev-parse HEAD`)
2. The output of `bazel version`
You can obtain the TensorFlow version with
### If possible, provide a minimal reproducible example (We usually don't have time to read hundreds of lines of your code)
python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"
### Describe the problem
Describe the problem clearly here. Be sure to convey here why it's a bug in TensorFlow or a feature request.
### What other attempted solutions have you tried?
### Logs or other output that would be helpful
(If logs are large, please upload as attachment or provide link).
### Source code / logs
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached. Try to provide a reproducible test case that is the bare minimum necessary to generate the problem.

View File

@ -1,6 +1,7 @@
<div align="center">
<img src="https://www.tensorflow.org/images/tf_logo_transp.png"><br><br>
</div>
-----------------
| **`Linux CPU`** | **`Linux GPU`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** |
@ -25,19 +26,20 @@ guidelines](CONTRIBUTING.md).**
**We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for
tracking requests and bugs, but please see
[Community](tensorflow/g3doc/resources/index.md#community) for general questions
[Community](https://www.tensorflow.org/community/) for general questions
and discussion.**
## Installation
*See [Download and Setup](tensorflow/g3doc/get_started/os_setup.md) for instructions on how to install our release binaries or how to build from source.*
*See [Installing TensorFlow](https://www.tensorflow.org/install/) for instructions on how to install our release binaries or how to build from source.*
People who are a little more adventurous can also try our nightly binaries:
* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.0.0rc1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.0.0rc1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.0.0rc1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.0.0rc1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.0.0rc1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.0.0rc1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.0.0rc1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.0.0rc1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.0.0rc1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.0.0rc1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/))
* [Android](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-android/TF_BUILD_CONTAINER_TYPE=ANDROID,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=NO_PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=android-slave/lastSuccessfulBuild/artifact/bazel-out/local_linux/bin/tensorflow/examples/android/tensorflow_demo.apk) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-android/TF_BUILD_CONTAINER_TYPE=ANDROID,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=NO_PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=android-slave/))
* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.2.0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.2.0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.2.0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.2.0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.2.0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.2.0-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.2.0-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.2.0-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.2.0-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/))
* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.2.0-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.2.0-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/))
* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.2.0-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.2.0-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/))
* Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/)
([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/))
@ -50,7 +52,7 @@ $ python
>>> hello = tf.constant('Hello, TensorFlow!')
>>> sess = tf.Session()
>>> sess.run(hello)
Hello, TensorFlow!
'Hello, TensorFlow!'
>>> a = tf.constant(10)
>>> b = tf.constant(32)
>>> sess.run(a+b)
@ -58,11 +60,11 @@ Hello, TensorFlow!
>>>
```
##For more information
## For more information
* [TensorFlow website](http://tensorflow.org)
* [TensorFlow website](https://tensorflow.org)
* [TensorFlow whitepaper](http://download.tensorflow.org/paper/whitepaper2015.pdf)
* [TensorFlow Model Zoo](https://github.com/tensorflow/models)
* [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730)
The TensorFlow community has created amazing things with TensorFlow, please see the [resources section of tensorflow.org](https://www.tensorflow.org/versions/master/resources#community) for an incomplete list.
The TensorFlow community has created amazing things with TensorFlow, please see the [resources section of tensorflow.org](https://www.tensorflow.org/about/#community) for an incomplete list.

View File

@ -1,3 +1,312 @@
# Release 1.2.0
## Major Features and Improvements
* Python 3.6 support on Windows.
* Added `tf.layers.conv3d_transpose` layer for spatio temporal deconvolution.
* Added `tf.Session.make_callable()`, which provides a lower overhead means of running a similar step multiple times.
* Added libverbs-based RDMA support to contrib (courtesy @junshi15 from Yahoo).
* Bring `tf.feature_column.*` into the API. Non-deprecated functionality from `tf.contrib.layers.*` is moved to `tf.feature_column.*` with cosmetic changes.
* `RNNCell` objects now subclass `tf.layers.Layer`. The strictness described
in the TensorFlow 1.1 release is gone: The first time an RNNCell is used,
it caches its scope. All future uses of the RNNCell will reuse variables from
that same scope. This is a breaking change from the behavior of RNNCells
in TensorFlow versions <= 1.0.1. TensorFlow 1.1 had checks in place to
ensure old code works correctly with the new semantics; this version
allows more flexible uses of RNNCell but can lead to subtle errors if
using code meant for TensorFlow <= 1.0.1. For example, writing:
`MultiRNNCell([lstm] * 5)` will now build a 5-layer LSTM stack where each
layer shares the **same** parameters. To get 5 layers each with their own
parameters, write: `MultiRNNCell([LSTMCell(...) for _ in range(5)])`.
If at all unsure, first test your code with TF 1.1; ensure it raises no
errors, and then upgrade to TF 1.2.
* RNNCells' variable names have been renamed for consistency with Keras layers.
Specifically, the previous variable names "weights" and "biases" have
been changed to "kernel" and "bias", respectively.
This may cause backward incompatibility with regard to your old
checkpoints containing such RNN cells, in which case you can use the tool
[checkpoint_convert script](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py)
to convert the variable names in your old checkpoints.
* Many of the RNN functions and classes that were in the `tf.nn` namespace
before the 1.0 release and which were moved to `tf.contrib.rnn` have now
been moved back to the core namespace. This includes
`RNNCell`, `LSTMCell`, `GRUCell`, and a number of other cells. These
now reside in `tf.nn.rnn_cell` (with aliases in `tf.contrib.rnn` for backwards
compatibility). The original `tf.nn.rnn` function is now `tf.nn.static_rnn`,
and the bidirectional static and state saving static rnn functions are also
now back in the `tf.nn` namespace.
Notable exceptions are the `EmbeddingWrapper`, `InputProjectionWrapper` and
`OutputProjectionWrapper`, which will slowly be moved to deprecation
in `tf.contrib.rnn`. These are inefficient wrappers that should often
be replaced by calling `embedding_lookup` or `layers.dense` as pre- or post-
processing of the rnn. For RNN decoding, this functionality has been replaced
with an alternative API in `tf.contrib.seq2seq`.
* Intel MKL Integration (https://software.intel.com/en-us/articles/tensorflow-optimizations-on-modern-intel-architecture). Intel developed a number of
optimized deep learning primitives: In addition to matrix multiplication and
convolution, these building blocks include:
Direct batched convolution
Pooling: maximum, minimum, average
Normalization: LRN, batch normalization
Activation: rectified linear unit (ReLU)
Data manipulation: multi-dimensional transposition (conversion), split,
concat, sum and scale.
* TensorForest Estimator now supports SavedModel export for serving.
* Support client-provided ClusterSpec's and propagate them to all workers to enable the creation of dynamic TensorFlow clusters.
* TensorFlow C library now available for Windows.
* We released a new open-source version of TensorBoard.
* [`SavedModel CLI`](https://www.tensorflow.org/versions/master/programmers_guide/saved_model_cli) tool available to inspect and execute MetaGraph in SavedModel
* Android releases of TensorFlow are now pushed to jcenter for easier
integration into apps. See
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/android/README.md
for more details.
* RNNCells' variable names have been renamed for consistency with Keras layers.
Specifically, the previous variable names "weights" and "biases" have
been changed to "kernel" and "bias", respectively.
This may cause backward incompatibility with regard to your old
checkpoints containing such RNN cells, in which case you can use the tool
[checkpoint_convert script](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py)
to convert the variable names in your old checkpoints.
* Many of the RNN functions and classes that were in the `tf.nn` namespace
before the 1.0 release and which were moved to `tf.contrib.rnn` have now
been moved back to the core namespace. This includes
`RNNCell`, `LSTMCell`, `GRUCell`, and a number of other cells. These
now reside in `tf.nn.rnn_cell` (with aliases in `tf.contrib.rnn` for backwards
compatibility). The original `tf.nn.rnn` function is now `tf.nn.static_rnn`,
and the bidirectional static and state saving static rnn functions are also
now back in the `tf.nn` namespace.
Notable exceptions are the `EmbeddingWrapper`, `InputProjectionWrapper` and
`OutputProjectionWrapper`, which will slowly be moved to deprecation
in `tf.contrib.rnn`. These are inefficient wrappers that should often
be replaced by calling `embedding_lookup` or `layers.dense` as pre- or post-
processing of the rnn. For RNN decoding, this functionality has been replaced
with an alternative API in `tf.contrib.seq2seq`.
* Intel MKL Integration (https://software.intel.com/en-us/articles/tensorflow-optimizations-on-modern-intel-architecture). Intel developed a number of
optimized deep learning primitives: In addition to matrix multiplication and
convolution, these building blocks include:
Direct batched convolution
Pooling: maximum, minimum, average
Normalization: LRN, batch normalization
Activation: rectified linear unit (ReLU)
Data manipulation: multi-dimensional transposition (conversion), split,
concat, sum and scale.
## Deprecations
* TensorFlow 1.2 may be the last time we build with cuDNN 5.1. Starting with
TensorFlow 1.3, we will try to build all our prebuilt binaries with cuDNN 6.0.
While we will try to keep our source code compatible with cuDNN 5.1, it will
be best effort.
## Breaking Changes to the API
* `org.tensorflow.contrib.android.TensorFlowInferenceInterface` now throws exceptions where possible and has simplified method signatures.
## Changes to contrib APIs
* Added `tf.contrib.util.create_example`.
* Added bilinear interpolation to `tf.contrib.image`.
* Add `tf.contrib.stateless` for random ops with custom seed control.
* MultivariateNormalFullCovariance added to contrib/distributions/
* tensorflow/contrib/rnn undergoes RNN cell variable renaming for
consistency with Keras layers. Specifically, the previous variable names
"weights" and "biases" are changed to "kernel" and "bias", respectively.
This may cause backward incompatibility with regard to your old
checkpoints containing such RNN cells, in which case you can use the
[checkpoint_convert script](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py)
to convert the variable names in your old checkpoints.
* Added `tf.contrib.kernel_methods` module with Ops and estimators for primal
(explicit) kernel methods in TensorFlow.
## Bug Fixes and Other Changes
* In python, `Operation.get_attr` on type attributes returns the Python DType
version of the type to match expected get_attr documentation rather than the
protobuf enum.
* tensorflow/contrib/rnn undergoes RNN cell variable renaming for
consistency with Keras layers. Specifically, the previous variable names
"weights" and "biases" are changed to "kernel" and "bias", respectively.
* Changed MIN_SDK version to 8.0 when building iOS libraries.
* Fixed LIBXSMM integration.
* Make decode_jpeg/decode_png/decode_gif handle all formats, since users frequently try to decode an image as the wrong type.
* Improve implicit broadcasting lowering.
* Improving stability of GCS/Bigquery clients by a faster retrying of stale transmissions.
* Remove OpKernelConstruction::op_def() as part of minimizing proto dependencies.
* VectorLaplaceDiag distribution added.
* Android demo no longer requires libtensorflow_demo.so to run (libtensorflow_inference.so still required)
* Added `categorical_column_with_vocabulary_file`.
* Introduce ops for batching/unbatching tensors across Session::Run() calls.
* Add tf.log_sigmoid(x) = tf.log(tf.sigmoid(x)) = -tf.nn.softplus(-x).
* Changed hooks lists to immutable tuples, and now allow any iterable for the associated arguments.
* Introduce TFDecorator.
* Added an Mfcc op for speech feature generation.
* Improved DirectSession::Run() overhead and error checking. Feeding a value of the wrong type will now synchronously raise an INVALID_ARGUMENT error instead of asynchronously raising an INTERNAL error. Code that depends on the (undefined) behavior when feeding a tensor of the wrong type may need to be updated.
* Added unreduced NONE, and reduced MEAN options for losses. Removed "WEIGHTED_" prefix from other Reduction constants.
* assertAllClose now handles dicts.
* Added Gmock matcher for HloInstructions.
* Add var name to errors on variable restore.
* Added an AudioSpectrogram op for audio feature generation.
* Added `reduction` arg to losses.
* `tf.placeholder` can represent scalar shapes and partially known.
* Remove estimator_spec(mode) argument.
* Added an AudioSpectrogram op for audio feature generation.
* TensorBoard disables all runs by default if there are more than 40 runs.
* Removed old doc generator code.
* GCS file system integration now supports domain buckets, e.g gs://bucket.domain.com/path.
* Add `tf.summary.text` for outputting text to TensorBoard.
* The "run" command of tfdbg's command-line interface now supports filtering of tensors by node name, op type and tensor dtype.
* `tf.string_to_number` now supports int64 and float64 outputs.
## Thanks to our Contributors
This release contains contributions from many people at Google, as well as:
4F2E4A2E, Aaron Schumacher, Abhi Agg, admcrae, Adriano Carmezim, Adrià Arrufat,
agramesh1, Akimitsu Seo, Alan Mosca, Alex Egg, Alex Rothberg, Alexander Heinecke,
Alexander Matyasko, Alexandr Baranezky, Alexandre Caulier, Ali Siddiqui, Anand Venkat,
Andrew Hundt, Androbin, Anmol Sharma, Arie, Arno Leist, Arron Cao, AuréLien Geron, Bairen Yi,
Beomsu Kim, Carl Thomé, cfperez, Changming Sun, Corey Wharton, critiqjo, Dalei Li, Daniel
Rasmussen, Daniel Trebbien, DaríO Hereñú, David Eng, David Norman, David Y. Zhang, Davy Song, ddurham2,
Deepak Subburam, Dmytro Kyrychuk, Dominic Rossi, Dominik SchlöSser, Dustin Tran,
Eduardo Pinho, Egil Martinsson, Elliot Saba, Eric Bigelow, Erik Smistad, Evan Klitzke,
Fabrizio Milo, Falcon Dai, Fei Gao, FloopCZ, Fung Lam, Gautam, GBLin5566, Greg Peatfield,
Gu Wang, Guenther Schmuelling, Hans Pabst, Harun Gunaydin, Huaizheng, Ido Shamay, Ikaro
Silva, Ilya Edrenkin, Immexxx, James Mishra, Jamie Cooke, Jay Young, Jayaram Bobba,
Jianfei Wang, jinghua2, Joey Meyer, John Maidens, Jonghoon Jin, Julian Villella,
Jun Kim, Jun Shi, Junwei Pan, jyegerlehner, Karan Desai, Karel Van De Plassche,
Kb Sriram, KhabarlakKonstantin, Koan-Sin Tan, krivard, Kwotsin, Leandro Gracia Gil,
Li Chen, Liangliang He, Louie Helm, lspvic, Luiz Henrique Soares, LáSzló Csomor,
Mark Wong, Mathew Wicks, Matthew Rahtz, Maxwell Paul Brickner, Michael Hofmann, Miguel
Flores Ruiz De Eguino, MikeTam1021, Mortada Mehyar, Mycosynth, Namnamseo,
Nate Harada, Neven Miculinic, Nghia Tran, Nick Lyu, Niranjan Hasabnis, Nishidha, Oleksii
Kuchaiev, Oyesh Mann Singh, Panmari, Patrick, Paul Van Eck, Piyush Chaudhary, Quim Llimona,
Raingo, Richard Davies, Ruben Vereecken, Sahit Chintalapudi, Sam Abrahams, Santiago Castro,
Scott Sievert, Sean O'Keefe, Sebastian Schlecht, Shane, Shubhankar Deshpande, Spencer Schaber,
Sunyeop Lee, t13m, td2014, Thomas H. P. Andersen, Toby Petty, Umang Mehta,
Vadim Markovtsev, Valentin Iovene, Vincent Zhao, Vit Stepanovs, Vivek Rane, Vu Pham, wannabesrevenge,
weipingpku, wuhaixutab, wydwww, Xiang Gao, Xiaolin Lin, xiaoyaozhuzi, Yaroslav Bulatov, Yi Liu,
Yoshihiro Sugi, Yuan (Terry) Tang, Yuming Wang, Yuxin Wu, Zader Zheng, Zhaojun Zhang, zhengjiajin,
ZhipengShen, Ziming Dong, zjj2wry
We are also grateful to all who filed issues or helped resolve them, asked and
answered questions, and were part of inspiring discussions.
# Release 1.1.0
## Major Features and Improvements
* Added Java API support for Windows.
* Added `tf.spectral` module. Moved existing FFT ops to `tf.spectral` while
keeping an alias in the old location (`tf.*`).
* Added 1D, 2D and 3D Fourier transform ops for real signals to `tf.spectral`.
* Added a `tf.bincount` function.
* Added Keras 2 API to contrib.
* Added a new lightweight queue-like object - `RecordInput`.
* Added `tf.contrib.image.compose_transforms` function.
* Bring `tf.estimator.*` into the API. Non-deprecated functionality from `tf.contrib.learn.Estimator` is moved to `tf.estimator.Estimator` with cosmetic changes.
* Docker images: TF images on gcr.io and Docker Hub are upgraded to ubuntu:16.04.
* Added the following features to TensorFlow Debugger (tfdbg):
* Ability to inspect Python source file against TF ops and tensors (command `print_source` / `ps`)
* New navigation bar in Curses-based UI
* NodeStepper (command `invoke_stepper`) now uses intermediate tensor dumps. It also uses `TensorHandles` as direct feeds during successive `cont` calls for improved performance and reduced memory consumption.
* Initial release of installation guides for Java, C, and Go.
* Added Text Dashboard to TensorBoard.
## Deprecations
* TensorFlow 1.1.0 will be the last time we release a binary with Mac GPU support. Going forward, we will stop testing on Mac GPU systems. We continue to welcome patches that maintain Mac GPU support, and we will try to keep the Mac GPU build working.
## Changes to contrib APIs
* The behavior of RNNCells is now stricter due to the transition towards making RNNCells act more like Keras layers.
* If an RNNCell is used twice in two different variable scopes, an error is raised describing how to avoid this behavior.
* If an RNNCell is used in a variable scope with existing conflicting variables, an error is raised showing that the RNNCell must be constructed with argument `reuse=True`.
* Deprecated contrib/distributions `pmf`, `pdf`, `log_pmf`, `log_pdf`.
* Moved `bayesflow.special_math` to distributions.
* `tf.contrib.tensor_forest.python.tensor_forest.RandomForestDeviceAssigner` removed.
* Changed some MVN classes and parameters:
* `tf.contrib.distributions.MultivariateNormalFull` replaced by `tf.contrib.distributions.MultivariateNormalTriL`.
* `tf.contrib.distributions.MultivariateNormalCholesky` replaced by `tf.contrib.distributions.MultivariateNormalTriL`
* `tf.contrib.distributions.MultivariateNormalDiagWithSoftplusStDev` replaced
by `tf.contrib.distributions.MultivariateNormalDiagWithSoftplusScale`
* `tf.contrib.distributions.MultivariateNormalDiag` arguments changed from `mu`, `diag_stddev` to `log`, `scale_diag`.
* `tf.contrib.distributions.MultivariateNormalDiagPlusVDVT` removed.
* `tf.contrib.distributions.MultivariateNormalDiagPlusLowRank` added.
## Bug Fixes and Other Changes
* Java: Support for loading models exported using the SavedModel API (courtesy @EronWright).
* Go: Added support for incremental graph execution.
* Fix a bug in the WALS solver when single-threaded.
* Added support for integer sparse feature values in `tf.contrib.layers.sparse_column_with_keys`.
* Fixed `tf.set_random_seed(0)` to be deterministic for all ops.
* Stability improvements for the GCS file system support.
* Improved TensorForest performance.
* Added support for multiple filename globs in `tf.matching_files`.
* `LogMessage` now includes a timestamp as beginning of a message.
* Added MultiBox person detector example standalone binary.
* Android demo: Makefile build functionality added to build.gradle to fully support building TensorFlow demo in Android on Windows.
* Android demo: read MultiBox priors from txt file rather than protobuf.
* Added colocation constraints to `StagingArea`.
* `sparse_matmul_op` reenabled for Android builds.
* Restrict weights rank to be the same as the broadcast target, to avoid ambiguity on broadcast rules.
* Upgraded libxsmm to 1.7.1 and applied other changes for performance and memory usage.
* Fixed bfloat16 integration of LIBXSMM sparse mat-mul.
* Improved performance and reduce memory usage by allowing ops to forward input buffers to output buffers and perform computations in-place.
* Improved the performance of CPU assignment for strings.
* Speed up matrix * vector multiplication and matrix * matrix with unknown shapes.
* C API: Graph imports now support input remapping, control dependencies, and returning imported nodes (see `TF_GraphImportGraphDefWithReturnOutputs()`)
* Multiple C++ API updates.
* Multiple TensorBoard updates including:
* Users can now view image summaries at various sampled steps (instead of just the last step).
* Bugs involving switching runs as well as the image dashboard are fixed.
* Removed data download links from TensorBoard.
* TensorBoard uses a relative data directory, for easier embedding.
* TensorBoard automatically ignores outliers for domain calculation, and formats proportional values consistently.
* Multiple tfdbg bug fixes:
* Fixed Windows compatibility issues.
* Command history now persists across runs.
* Bug fix in graph validation related to `tf.while_loops`.
* Java Maven fixes for bugs with Windows installation.
* Backport fixes and improvements from external keras.
* Keras config file handling fix.
## Thanks to our Contributors
This release contains contributions from many people at Google, as well as:
A. Besir Kurtulmus, Adal Chiriliuc, @akash, Alec-Desouza, Alex Rothberg, Alex
Sergeev, Alexander Heinecke, Allen Guo, Andreas Madsen, Ankesh Anand, Anton
Loss, @Aravind, @Arie, Ashutosh Das, AuréLien Geron, Bairen Yi, @bakunyo, Ben
Visser, Brady Zhou, Calpa Liu, Changming Sun, Chih Cheng Liang, Christopher
Berner, Clark Zinzow, @Conchylicultor, Dan Ellis, Dan J, Dan Jarvis, Daniel
Ylitalo, Darren Garvey, David Norman, David Truong, @DavidNorman, Dimitar
Pavlov, Dmitry Persiyanov, @Eddie, @elirex, Erfan Noury, Eron Wright, Evgeny
Mazovetskiy, Fabrizio (Misto) Milo, @fanlu, Fisher Coder, Florian Courtial,
Franck Dernoncourt, Gagan Goel, Gao, Xiang, @Gautam, Gefu Tang, @guilherme,
@guschmue, Hannah Provenza, Hans Pabst, @hartb, Hsiao Yi, Huazuo Gao, Igor
ChorążEwicz, Ivan Smirnov, Jakub Kolodziejczyk, Jason Gavris, Jason Morton, Jay
Young, Jayaram Bobba, Jeremy Sawruk, Jiaming Liu, Jihun Choi, @jiqiu, Joan Thibault,
John C F, Jojy George Varghese, Jon Malmaud, Julian Berman, Julian Niedermeier,
Junpeng Lao, Kai Sasaki, @Kankroc, Karl Lessard, Kyle Bostelmann, @Lezcano, Li
Yi, Luo Yun, @lurker, Mahmoud-Abuzaina, Mandeep Singh, Marek Kolodziej, Mark
Szepieniec, Martial Hue, Medhat Omr, Memo Akten, Michael Gharbi, MichaëL Defferrard,
Milan Straka, @MircoT, @mlucool, Muammar Ibn Faisal, Nayana Thorat, @nghiattran,
Nicholas Connor, Nikolaas Steenbergen, Niraj Patel, Niranjan Hasabnis, @Panmari,
Pavel Bulanov, Philip Pries Henningsen, Philipp Jund, @polonez, Prayag Verma, Rahul
Kavi, Raphael Gontijo Lopes, @rasbt, Raven Iqqe, Reid Pryzant, Richard Shin, Rizwan
Asif, Russell Kaplan, Ryo Asakura, RüDiger Busche, Saisai Shao, Sam Abrahams, @sanosay,
Sean Papay, @seaotterman, @selay01, Shaurya Sharma, Sriram Narayanamoorthy, Stefano
Probst, @taknevski, @tbonza, @teldridge11, Tim Anglade, Tomas Reimers, Tomer Gafner,
Valentin Iovene, Vamsi Sripathi, Viktor Malyi, Vit Stepanovs, Vivek Rane, Vlad Firoiu,
@wangg12, @will, Xiaoyu Tao, Yaroslav Bulatov, Yi Liu, Yuan (Terry) Tang, @Yufeng,
Yuming Wang, Yuxin Wu, Zafar Takhirov, Ziming Dong
We are also grateful to all who filed issues or helped resolve them, asked and
answered questions, and were part of inspiring discussions.
# Release 1.0.1
## Bug Fixes and Other Changes
* Change GraphConstructor to not increase the version when importing, but instead take the min of all versions.
* Google Cloud Storage fixes.
* Removed `tf.core` and `tf.python` modules from the API. These were never intended to be exposed. Please use the same objects through top-level `tf` module instead.
# Release 1.0.0
## Major Features and Improvements
@ -87,8 +396,12 @@ To help you upgrade your existing TensorFlow Python code to match the API change
* In the C++ API (in tensorflow/cc), Input, Output, etc. have moved
from the tensorflow::ops namespace to tensorflow.
* Change arg order for `{softmax,sparse_softmax,sigmoid}_cross_entropy_with_logits` to be (labels, predictions), and force use of named args.
* tf.nn.rnn_cell.* and most functions in tf.nn.rnn.* (with the exception of dynamic_rnn and raw_rnn) are temporarily in tf.contrib.rnn. They will be moved back into core for TF 1.2.
* `tf.nn.sampled_softmax_loss` and `tf.nn.nce_loss` have both changed their API such that you need to switch the `inputs, labels` to `labels, inputs` parameters.
* The shape keyword argument of the `SparseTensor` constructor changes its name to `dense_shape` between Tensorflow 0.12 and Tensorflow 1.0.
## Bug Fixes and Other Changes
* Numerous C++ API updates.
* New op: `parallel_stack`.
* Introducing common tf io compression options constants for
RecordReader/RecordWriter.
@ -127,6 +440,7 @@ To help you upgrade your existing TensorFlow Python code to match the API change
* `tf.divide` now honors the name field.
* Make metrics weight broadcasting more strict.
* Add new queue-like `StagingArea` and new ops: `stage` and `unstage`.
* Enable inplace update ops for strings on CPU. Speed up string concat.
## Thanks to our Contributors
@ -193,7 +507,7 @@ answered questions, and were part of inspiring discussions.
indexing now starts from 1 instead of 0, and `bus_id==0` is used where
previously `BUS_ANY` was used.
* `Env::FileExists` and `FileSystem::FileExists` now return a tensorflow::Status
intead of a bool. Any callers to this function can be converted to a bool
instead of a bool. Any callers to this function can be converted to a bool
by adding .ok() to the call.
* The C API type `TF_SessionWithGraph` has been renamed to `TF_Session`,
indicating its preferred use in language bindings for TensorFlow.
@ -212,7 +526,7 @@ answered questions, and were part of inspiring discussions.
* `SparseTensor.shape` has been renamed to `SparseTensor.dense_shape`. Same for
`SparseTensorValue.shape`.
* `Env::FileExists` and `FileSystem::FileExists` now return a
`tensorflow::Status` intead of a bool. Any callers to this function can be
`tensorflow::Status` instead of a bool. Any callers to this function can be
converted to a bool by adding `.ok()` to the call.
* C API: Type `TF_SessionWithGraph` has been renamed to `TF_Session`, indicating
its preferred use in language bindings for TensorFlow. What was previously

520
WORKSPACE
View File

@ -2,11 +2,11 @@ workspace(name = "org_tensorflow")
http_archive(
name = "io_bazel_rules_closure",
sha256 = "60fc6977908f999b23ca65698c2bb70213403824a84f7904310b6000d78be9ce",
strip_prefix = "rules_closure-5ca1dab6df9ad02050f7ba4e816407f88690cf7d",
sha256 = "bc41b80486413aaa551860fc37471dbc0666e1dbb5236fb6177cb83b0c105846",
strip_prefix = "rules_closure-dec425a4ff3faf09a56c85d082e4eed05d8ce38f",
urls = [
"http://bazel-mirror.storage.googleapis.com/github.com/bazelbuild/rules_closure/archive/5ca1dab6df9ad02050f7ba4e816407f88690cf7d.tar.gz", # 2017-02-03
"https://github.com/bazelbuild/rules_closure/archive/5ca1dab6df9ad02050f7ba4e816407f88690cf7d.tar.gz",
"http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/dec425a4ff3faf09a56c85d082e4eed05d8ce38f.tar.gz", # 2017-06-02
"https://github.com/bazelbuild/rules_closure/archive/dec425a4ff3faf09a56c85d082e4eed05d8ce38f.tar.gz",
],
)
@ -14,510 +14,56 @@ load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories")
closure_repositories()
load("//tensorflow:workspace.bzl", "check_version", "tf_workspace")
# We must check the bazel version before trying to parse any other BUILD files,
# in case the parsing of those build files depends on the bazel version we
# require here.
check_version("0.4.2")
load("//tensorflow:workspace.bzl", "tf_workspace")
# Uncomment and update the paths in these entries to build the Android demo.
#android_sdk_repository(
# name = "androidsdk",
# api_level = 23,
# build_tools_version = "23.0.1",
# # Ensure that you have the build_tools_version below installed in the
# # SDK manager as it updates periodically.
# build_tools_version = "25.0.2",
# # Replace with path to Android SDK on your system
# path = "<PATH_TO_SDK>",
#)
#
# Android NDK r12b is recommended (higher may cause issues with Bazel)
#android_ndk_repository(
# name="androidndk",
# path="<PATH_TO_NDK>",
# api_level=21)
# # This needs to be 14 or higher to compile TensorFlow.
# # Note that the NDK version is not the API level.
# api_level=14)
# Please add all new TensorFlow dependencies in workspace.bzl.
tf_workspace()
new_http_archive(
name = "inception5h",
build_file = "models.BUILD",
url = "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip",
sha256 = "d13569f6a98159de37e92e9c8ec4dae8f674fbf475f69fe6199b514f756d4364"
name = "inception5h",
build_file = "models.BUILD",
sha256 = "d13569f6a98159de37e92e9c8ec4dae8f674fbf475f69fe6199b514f756d4364",
urls = [
"http://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip",
"http://download.tensorflow.org/models/inception5h.zip",
],
)
new_http_archive(
name = "mobile_multibox",
build_file = "models.BUILD",
url = "https://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1a.zip",
sha256 = "859edcddf84dddb974c36c36cfc1f74555148e9c9213dedacf1d6b613ad52b96"
name = "mobile_multibox",
build_file = "models.BUILD",
sha256 = "859edcddf84dddb974c36c36cfc1f74555148e9c9213dedacf1d6b613ad52b96",
urls = [
"http://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1a.zip",
"http://download.tensorflow.org/models/mobile_multibox_v1a.zip",
],
)
new_http_archive(
name = "stylize",
build_file = "models.BUILD",
url = "https://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip",
sha256 = "3d374a730aef330424a356a8d4f04d8a54277c425e274ecb7d9c83aa912c6bfa"
)
# TENSORBOARD_BOWER_AUTOGENERATED_BELOW_THIS_LINE_DO_NOT_EDIT
new_http_archive(
name = "d3",
build_file = "bower.BUILD",
url = "https://github.com/mbostock-bower/d3-bower/archive/v3.5.15.tar.gz",
strip_prefix = "d3-bower-3.5.15",
)
new_http_archive(
name = "dagre",
build_file = "bower.BUILD",
url = "https://github.com/cpettitt/dagre/archive/v0.7.4.tar.gz",
strip_prefix = "dagre-0.7.4",
)
new_http_archive(
name = "es6_promise",
build_file = "bower.BUILD",
url = "https://github.com/components/es6-promise/archive/v2.1.0.tar.gz",
strip_prefix = "es6-promise-2.1.0",
)
new_http_archive(
name = "font_roboto",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/font-roboto/archive/v1.0.1.tar.gz",
strip_prefix = "font-roboto-1.0.1",
)
new_http_archive(
name = "graphlib",
build_file = "bower.BUILD",
url = "https://github.com/cpettitt/graphlib/archive/v1.0.7.tar.gz",
strip_prefix = "graphlib-1.0.7",
)
new_http_archive(
name = "iron_a11y_announcer",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-a11y-announcer/archive/v1.0.5.tar.gz",
strip_prefix = "iron-a11y-announcer-1.0.5",
)
new_http_archive(
name = "iron_a11y_keys_behavior",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-a11y-keys-behavior/archive/v1.1.8.tar.gz",
strip_prefix = "iron-a11y-keys-behavior-1.1.8",
)
new_http_archive(
name = "iron_ajax",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-ajax/archive/v1.2.0.tar.gz",
strip_prefix = "iron-ajax-1.2.0",
)
new_http_archive(
name = "iron_autogrow_textarea",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-autogrow-textarea/archive/v1.0.12.tar.gz",
strip_prefix = "iron-autogrow-textarea-1.0.12",
)
new_http_archive(
name = "iron_behaviors",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-behaviors/archive/v1.0.17.tar.gz",
strip_prefix = "iron-behaviors-1.0.17",
)
new_http_archive(
name = "iron_checked_element_behavior",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-checked-element-behavior/archive/v1.0.4.tar.gz",
strip_prefix = "iron-checked-element-behavior-1.0.4",
)
new_http_archive(
name = "iron_collapse",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-collapse/archive/v1.0.8.tar.gz",
strip_prefix = "iron-collapse-1.0.8",
)
new_http_archive(
name = "iron_dropdown",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-dropdown/archive/v1.4.0.tar.gz",
strip_prefix = "iron-dropdown-1.4.0",
)
new_http_archive(
name = "iron_fit_behavior",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-fit-behavior/archive/v1.2.5.tar.gz",
strip_prefix = "iron-fit-behavior-1.2.5",
)
new_http_archive(
name = "iron_flex_layout",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-flex-layout/archive/v1.3.0.tar.gz",
strip_prefix = "iron-flex-layout-1.3.0",
)
new_http_archive(
name = "iron_form_element_behavior",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-form-element-behavior/archive/v1.0.6.tar.gz",
strip_prefix = "iron-form-element-behavior-1.0.6",
)
new_http_archive(
name = "iron_icon",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-icon/archive/v1.0.11.tar.gz",
strip_prefix = "iron-icon-1.0.11",
)
new_http_archive(
name = "iron_icons",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-icons/archive/v1.1.3.tar.gz",
strip_prefix = "iron-icons-1.1.3",
)
new_http_archive(
name = "iron_iconset_svg",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-iconset-svg/archive/v1.1.0.tar.gz",
strip_prefix = "iron-iconset-svg-1.1.0",
)
new_http_archive(
name = "iron_input",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-input/archive/1.0.10.tar.gz",
strip_prefix = "iron-input-1.0.10",
)
new_http_archive(
name = "iron_list",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-list/archive/v1.3.9.tar.gz",
strip_prefix = "iron-list-1.3.9",
)
new_http_archive(
name = "iron_menu_behavior",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-menu-behavior/archive/v1.1.10.tar.gz",
strip_prefix = "iron-menu-behavior-1.1.10",
)
new_http_archive(
name = "iron_meta",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-meta/archive/v1.1.1.tar.gz",
strip_prefix = "iron-meta-1.1.1",
)
new_http_archive(
name = "iron_overlay_behavior",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-overlay-behavior/archive/v1.10.1.tar.gz",
strip_prefix = "iron-overlay-behavior-1.10.1",
)
new_http_archive(
name = "iron_range_behavior",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-range-behavior/archive/v1.0.4.tar.gz",
strip_prefix = "iron-range-behavior-1.0.4",
)
new_http_archive(
name = "iron_resizable_behavior",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-resizable-behavior/archive/v1.0.3.tar.gz",
strip_prefix = "iron-resizable-behavior-1.0.3",
)
new_http_archive(
name = "iron_scroll_target_behavior",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-scroll-target-behavior/archive/v1.0.3.tar.gz",
strip_prefix = "iron-scroll-target-behavior-1.0.3",
)
new_http_archive(
name = "iron_selector",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-selector/archive/v1.5.2.tar.gz",
strip_prefix = "iron-selector-1.5.2",
)
new_http_archive(
name = "iron_validatable_behavior",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-validatable-behavior/archive/v1.1.1.tar.gz",
strip_prefix = "iron-validatable-behavior-1.1.1",
)
new_http_archive(
name = "lodash",
build_file = "bower.BUILD",
url = "https://github.com/lodash/lodash/archive/3.8.0.tar.gz",
strip_prefix = "lodash-3.8.0",
)
new_http_archive(
name = "neon_animation",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/neon-animation/archive/v1.2.2.tar.gz",
strip_prefix = "neon-animation-1.2.2",
)
http_file(
name = "numericjs_numeric_min_js",
url = "https://cdnjs.cloudflare.com/ajax/libs/numeric/1.2.6/numeric.min.js",
)
new_http_archive(
name = "paper_behaviors",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-behaviors/archive/v1.0.12.tar.gz",
strip_prefix = "paper-behaviors-1.0.12",
)
new_http_archive(
name = "paper_button",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-button/archive/v1.0.11.tar.gz",
strip_prefix = "paper-button-1.0.11",
)
new_http_archive(
name = "paper_checkbox",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-checkbox/archive/v1.4.0.tar.gz",
strip_prefix = "paper-checkbox-1.4.0",
)
new_http_archive(
name = "paper_dialog",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-dialog/archive/v1.0.4.tar.gz",
strip_prefix = "paper-dialog-1.0.4",
)
new_http_archive(
name = "paper_dialog_behavior",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-dialog-behavior/archive/v1.2.5.tar.gz",
strip_prefix = "paper-dialog-behavior-1.2.5",
)
new_http_archive(
name = "paper_dialog_scrollable",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-dialog-scrollable/archive/1.1.5.tar.gz",
strip_prefix = "paper-dialog-scrollable-1.1.5",
)
new_http_archive(
name = "paper_dropdown_menu",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-dropdown-menu/archive/v1.4.0.tar.gz",
strip_prefix = "paper-dropdown-menu-1.4.0",
)
new_http_archive(
name = "paper_header_panel",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-header-panel/archive/v1.1.4.tar.gz",
strip_prefix = "paper-header-panel-1.1.4",
)
new_http_archive(
name = "paper_icon_button",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-icon-button/archive/v1.1.3.tar.gz",
strip_prefix = "paper-icon-button-1.1.3",
)
new_http_archive(
name = "paper_input",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-input/archive/v1.1.18.tar.gz",
strip_prefix = "paper-input-1.1.18",
)
new_http_archive(
name = "paper_item",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-item/archive/v1.1.4.tar.gz",
strip_prefix = "paper-item-1.1.4",
)
new_http_archive(
name = "paper_listbox",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-listbox/archive/v1.1.2.tar.gz",
strip_prefix = "paper-listbox-1.1.2",
)
new_http_archive(
name = "paper_material",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-material/archive/v1.0.6.tar.gz",
strip_prefix = "paper-material-1.0.6",
)
new_http_archive(
name = "paper_menu",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-menu/archive/v1.2.2.tar.gz",
strip_prefix = "paper-menu-1.2.2",
)
new_http_archive(
name = "paper_menu_button",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-menu-button/archive/v1.5.1.tar.gz",
strip_prefix = "paper-menu-button-1.5.1",
)
new_http_archive(
name = "paper_progress",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-progress/archive/v1.0.9.tar.gz",
strip_prefix = "paper-progress-1.0.9",
)
new_http_archive(
name = "paper_radio_button",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-radio-button/archive/v1.1.2.tar.gz",
strip_prefix = "paper-radio-button-1.1.2",
)
new_http_archive(
name = "paper_radio_group",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-radio-group/archive/v1.0.9.tar.gz",
strip_prefix = "paper-radio-group-1.0.9",
)
new_http_archive(
name = "paper_ripple",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-ripple/archive/v1.0.5.tar.gz",
strip_prefix = "paper-ripple-1.0.5",
)
new_http_archive(
name = "paper_slider",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-slider/archive/v1.0.10.tar.gz",
strip_prefix = "paper-slider-1.0.10",
)
new_http_archive(
name = "paper_spinner",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-spinner/archive/v1.1.1.tar.gz",
strip_prefix = "paper-spinner-1.1.1",
)
new_http_archive(
name = "paper_styles",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-styles/archive/v1.1.4.tar.gz",
strip_prefix = "paper-styles-1.1.4",
)
new_http_archive(
name = "paper_tabs",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-tabs/archive/v1.7.0.tar.gz",
strip_prefix = "paper-tabs-1.7.0",
)
new_http_archive(
name = "paper_toast",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-toast/archive/v1.3.0.tar.gz",
strip_prefix = "paper-toast-1.3.0",
)
new_http_archive(
name = "paper_toggle_button",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-toggle-button/archive/v1.2.0.tar.gz",
strip_prefix = "paper-toggle-button-1.2.0",
)
new_http_archive(
name = "paper_toolbar",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-toolbar/archive/v1.1.4.tar.gz",
strip_prefix = "paper-toolbar-1.1.4",
)
new_http_archive(
name = "paper_tooltip",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-tooltip/archive/v1.1.2.tar.gz",
strip_prefix = "paper-tooltip-1.1.2",
)
new_http_archive(
name = "plottable",
build_file = "bower.BUILD",
url = "https://github.com/palantir/plottable/archive/v1.16.1.tar.gz",
strip_prefix = "plottable-1.16.1",
)
new_http_archive(
name = "polymer_archive",
build_file = "bower.BUILD",
url = "https://github.com/polymer/polymer/archive/v1.7.0.tar.gz",
strip_prefix = "polymer-1.7.0",
)
new_http_archive(
name = "promise_polyfill",
build_file = "bower.BUILD",
url = "https://github.com/polymerlabs/promise-polyfill/archive/v1.0.0.tar.gz",
strip_prefix = "promise-polyfill-1.0.0",
)
http_file(
name = "three_js_three_min_js",
url = "https://raw.githubusercontent.com/mrdoob/three.js/r77/build/three.min.js",
)
http_file(
name = "three_js_orbitcontrols_js",
url = "https://raw.githubusercontent.com/mrdoob/three.js/r77/examples/js/controls/OrbitControls.js",
)
new_http_archive(
name = "web_animations_js",
build_file = "bower.BUILD",
url = "https://github.com/web-animations/web-animations-js/archive/2.2.1.tar.gz",
strip_prefix = "web-animations-js-2.2.1",
)
new_http_archive(
name = "webcomponentsjs",
build_file = "bower.BUILD",
url = "https://github.com/webcomponents/webcomponentsjs/archive/v0.7.22.tar.gz",
strip_prefix = "webcomponentsjs-0.7.22",
)
http_file(
name = "weblas_weblas_js",
url = "https://raw.githubusercontent.com/waylonflinn/weblas/v0.9.0/dist/weblas.js",
name = "stylize",
build_file = "models.BUILD",
sha256 = "3d374a730aef330424a356a8d4f04d8a54277c425e274ecb7d9c83aa912c6bfa",
urls = [
"http://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip",
"http://download.tensorflow.org/models/stylize_v1.zip",
],
)

View File

@ -1,645 +0,0 @@
# AUTOGENERATED FILE by tensorboard_bower_dependency_sync.py
package(default_visibility = ["//visibility:public"])
filegroup(
name = "d3",
srcs = [
"d3.js",
"d3.min.js",
"package.js",
],
)
filegroup(
name = "dagre",
srcs = [
"dist/dagre.core.js",
"dist/dagre.core.min.js",
],
)
filegroup(
name = "es6_promise",
srcs = [
"promise.js",
"promise.min.js",
],
)
filegroup(
name = "font_roboto",
srcs = ["roboto.html"],
)
filegroup(
name = "graphlib",
srcs = [
"dist/graphlib.core.js",
"dist/graphlib.core.min.js",
],
)
filegroup(
name = "iron_a11y_announcer",
srcs = [
"index.html",
"iron-a11y-announcer.html",
],
)
filegroup(
name = "iron_a11y_keys_behavior",
srcs = [
"index.html",
"iron-a11y-keys-behavior.html",
],
)
filegroup(
name = "iron_ajax",
srcs = [
"index.html",
"iron-ajax.html",
"iron-request.html",
],
)
filegroup(
name = "iron_autogrow_textarea",
srcs = [
"index.html",
"iron-autogrow-textarea.html",
],
)
filegroup(
name = "iron_behaviors",
srcs = [
"index.html",
"iron-button-state.html",
"iron-control-state.html",
],
)
filegroup(
name = "iron_checked_element_behavior",
srcs = [
"index.html",
"iron-checked-element-behavior.html",
],
)
filegroup(
name = "iron_collapse",
srcs = [
"index.html",
"iron-collapse.html",
],
)
filegroup(
name = "iron_dropdown",
srcs = [
"index.html",
"iron-dropdown.html",
"iron-dropdown-scroll-manager.html",
],
)
filegroup(
name = "iron_fit_behavior",
srcs = [
"index.html",
"iron-fit-behavior.html",
],
)
filegroup(
name = "iron_flex_layout",
srcs = [
"classes/iron-flex-layout.html",
"classes/iron-shadow-flex-layout.html",
"index.html",
"iron-flex-layout.html",
"iron-flex-layout-classes.html",
],
)
filegroup(
name = "iron_form_element_behavior",
srcs = [
"index.html",
"iron-form-element-behavior.html",
],
)
filegroup(
name = "iron_icon",
srcs = [
"index.html",
"iron-icon.html",
],
)
filegroup(
name = "iron_icons",
srcs = [
"av-icons.html",
"communication-icons.html",
"device-icons.html",
"editor-icons.html",
"hardware-icons.html",
"image-icons.html",
"index.html",
"iron-icons.html",
"maps-icons.html",
"notification-icons.html",
"places-icons.html",
"social-icons.html",
],
)
filegroup(
name = "iron_iconset_svg",
srcs = [
"index.html",
"iron-iconset-svg.html",
],
)
filegroup(
name = "iron_input",
srcs = [
"index.html",
"iron-input.html",
],
)
filegroup(
name = "iron_list",
srcs = [
"index.html",
"iron-list.html",
"test/smoke/avg-worst-case.html",
"test/smoke/dummy-data.html",
"test/smoke/index.html",
"test/smoke/physical-count.html",
],
)
filegroup(
name = "iron_menu_behavior",
srcs = [
"index.html",
"iron-menu-behavior.html",
"iron-menubar-behavior.html",
],
)
filegroup(
name = "iron_meta",
srcs = [
"index.html",
"iron-meta.html",
],
)
filegroup(
name = "iron_overlay_behavior",
srcs = [
"index.html",
"iron-focusables-helper.html",
"iron-overlay-backdrop.html",
"iron-overlay-behavior.html",
"iron-overlay-manager.html",
],
)
filegroup(
name = "iron_range_behavior",
srcs = [
"index.html",
"iron-range-behavior.html",
],
)
filegroup(
name = "iron_resizable_behavior",
srcs = [
"demo/src/x-app.html",
"index.html",
"iron-resizable-behavior.html",
],
)
filegroup(
name = "iron_scroll_target_behavior",
srcs = [
"index.html",
"iron-scroll-target-behavior.html",
],
)
filegroup(
name = "iron_selector",
srcs = [
"index.html",
"iron-multi-selectable.html",
"iron-selectable.html",
"iron-selection.html",
"iron-selector.html",
],
)
filegroup(
name = "iron_validatable_behavior",
srcs = [
"index.html",
"iron-validatable-behavior.html",
],
)
filegroup(
name = "lodash",
srcs = [
"lodash.js",
"lodash.min.js",
],
)
filegroup(
name = "neon_animation",
srcs = [
"animations/cascaded-animation.html",
"animations/fade-in-animation.html",
"animations/fade-out-animation.html",
"animations/hero-animation.html",
"animations/opaque-animation.html",
"animations/reverse-ripple-animation.html",
"animations/ripple-animation.html",
"animations/scale-down-animation.html",
"animations/scale-up-animation.html",
"animations/slide-down-animation.html",
"animations/slide-from-bottom-animation.html",
"animations/slide-from-left-animation.html",
"animations/slide-from-right-animation.html",
"animations/slide-from-top-animation.html",
"animations/slide-left-animation.html",
"animations/slide-right-animation.html",
"animations/slide-up-animation.html",
"animations/transform-animation.html",
"demo/card/index.html",
"demo/card/x-card.html",
"demo/card/x-cards-list.html",
"demo/declarative/index.html",
"demo/doc/index.html",
"demo/doc/my-animatable.html",
"demo/doc/my-dialog.html",
"demo/dropdown/animated-dropdown.html",
"demo/dropdown/index.html",
"demo/grid/animated-grid.html",
"demo/grid/fullsize-page-with-card.html",
"demo/grid/index.html",
"demo/list/full-view.html",
"demo/list/index.html",
"demo/list/list-demo.html",
"demo/list/list-view.html",
"demo/load/animated-grid.html",
"demo/load/full-page.html",
"demo/load/index.html",
"demo/reprojection/animated-grid.html",
"demo/reprojection/fullsize-page-with-card.html",
"demo/reprojection/index.html",
"demo/reprojection/reprojected-pages.html",
"demo/tiles/circles-page.html",
"demo/tiles/index.html",
"demo/tiles/squares-page.html",
"index.html",
"neon-animatable.html",
"neon-animatable-behavior.html",
"neon-animated-pages.html",
"neon-animation.html",
"neon-animation-behavior.html",
"neon-animation-runner-behavior.html",
"neon-animations.html",
"neon-shared-element-animatable-behavior.html",
"neon-shared-element-animation-behavior.html",
"web-animations.html",
],
)
filegroup(
name = "paper_behaviors",
srcs = [
"index.html",
"paper-button-behavior.html",
"paper-checked-element-behavior.html",
"paper-inky-focus-behavior.html",
"paper-ripple-behavior.html",
],
)
filegroup(
name = "paper_button",
srcs = [
"index.html",
"paper-button.html",
],
)
filegroup(
name = "paper_checkbox",
srcs = [
"index.html",
"paper-checkbox.html",
],
)
filegroup(
name = "paper_dialog",
srcs = [
"index.html",
"paper-dialog.html",
],
)
filegroup(
name = "paper_dialog_behavior",
srcs = [
"index.html",
"paper-dialog-behavior.html",
"paper-dialog-common.css",
"paper-dialog-shared-styles.html",
],
)
filegroup(
name = "paper_dialog_scrollable",
srcs = [
"index.html",
"paper-dialog-scrollable.html",
],
)
filegroup(
name = "paper_dropdown_menu",
srcs = [
"index.html",
"paper-dropdown-menu.html",
"paper-dropdown-menu-icons.html",
"paper-dropdown-menu-light.html",
"paper-dropdown-menu-shared-styles.html",
],
)
filegroup(
name = "paper_header_panel",
srcs = [
"index.html",
"paper-header-panel.html",
],
)
filegroup(
name = "paper_icon_button",
srcs = [
"index.html",
"paper-icon-button.html",
"paper-icon-button-light.html",
],
)
filegroup(
name = "paper_input",
srcs = [
"all-imports.html",
"index.html",
"paper-input.html",
"paper-input-addon-behavior.html",
"paper-input-behavior.html",
"paper-input-char-counter.html",
"paper-input-container.html",
"paper-input-error.html",
"paper-textarea.html",
],
)
filegroup(
name = "paper_item",
srcs = [
"all-imports.html",
"index.html",
"paper-icon-item.html",
"paper-item.html",
"paper-item-behavior.html",
"paper-item-body.html",
"paper-item-shared-styles.html",
],
)
filegroup(
name = "paper_listbox",
srcs = [
"index.html",
"paper-listbox.html",
],
)
filegroup(
name = "paper_material",
srcs = [
"index.html",
"paper-material.html",
"paper-material-shared-styles.html",
],
)
filegroup(
name = "paper_menu",
srcs = [
"index.html",
"paper-menu.html",
"paper-menu-shared-styles.html",
"paper-submenu.html",
],
)
filegroup(
name = "paper_menu_button",
srcs = [
"index.html",
"paper-menu-button.html",
"paper-menu-button-animations.html",
],
)
filegroup(
name = "paper_progress",
srcs = [
"index.html",
"paper-progress.html",
],
)
filegroup(
name = "paper_radio_button",
srcs = [
"index.html",
"paper-radio-button.html",
],
)
filegroup(
name = "paper_radio_group",
srcs = [
"index.html",
"paper-radio-group.html",
],
)
filegroup(
name = "paper_ripple",
srcs = [
"index.html",
"paper-ripple.html",
],
)
filegroup(
name = "paper_slider",
srcs = [
"index.html",
"paper-slider.html",
],
)
filegroup(
name = "paper_spinner",
srcs = [
"index.html",
"paper-spinner.html",
"paper-spinner-behavior.html",
"paper-spinner-lite.html",
"paper-spinner-styles.html",
],
)
filegroup(
name = "paper_styles",
srcs = [
"classes/global.html",
"classes/shadow.html",
"classes/shadow-layout.html",
"classes/typography.html",
"color.html",
"default-theme.html",
"demo.css",
"demo-pages.html",
"index.html",
"paper-styles.html",
"paper-styles-classes.html",
"shadow.html",
"typography.html",
],
)
filegroup(
name = "paper_tabs",
srcs = [
"index.html",
"paper-tab.html",
"paper-tabs.html",
"paper-tabs-icons.html",
],
)
filegroup(
name = "paper_toast",
srcs = [
"index.html",
"paper-toast.html",
],
)
filegroup(
name = "paper_toggle_button",
srcs = [
"index.html",
"paper-toggle-button.html",
],
)
filegroup(
name = "paper_toolbar",
srcs = [
"index.html",
"paper-toolbar.html",
],
)
filegroup(
name = "paper_tooltip",
srcs = [
"index.html",
"paper-tooltip.html",
],
)
filegroup(
name = "plottable",
srcs = [
"plottable.css",
"plottable.js",
"plottable.min.js",
],
)
filegroup(
name = "polymer",
srcs = [
"polymer.html",
"polymer-micro.html",
"polymer-mini.html",
],
)
filegroup(
name = "promise_polyfill",
srcs = [
"Gruntfile.js",
"Promise.js",
"Promise.min.js",
"Promise-Statics.js",
"promise-polyfill.html",
"promise-polyfill-lite.html",
],
)
filegroup(
name = "web_animations_js",
srcs = [
"web-animations.html",
"web-animations.min.js",
"web-animations-next.min.js",
"web-animations-next-lite.min.js",
],
)
filegroup(
name = "webcomponentsjs",
srcs = [
"CustomElements.js",
"CustomElements.min.js",
"HTMLImports.js",
"HTMLImports.min.js",
"MutationObserver.js",
"MutationObserver.min.js",
"ShadowDOM.js",
"ShadowDOM.min.js",
"webcomponents.js",
"webcomponents.min.js",
"webcomponents-lite.js",
"webcomponents-lite.min.js",
],
)

578
configure vendored
View File

@ -3,6 +3,8 @@
set -e
set -o pipefail
MIN_BAZEL_VERSION=0.4.5
# Find out the absolute path to where ./configure resides
pushd `dirname $0` > /dev/null
SOURCE_BASE_DIR=`pwd -P`
@ -11,40 +13,179 @@ popd > /dev/null
PLATFORM="$(uname -s | tr 'A-Z' 'a-z')"
function is_linux() {
if [[ "${PLATFORM}" == "linux" ]]; then
true
else
false
fi
[[ "${PLATFORM}" == "linux" ]]
}
function is_macos() {
if [[ "${PLATFORM}" == "darwin" ]]; then
true
else
false
fi
[[ "${PLATFORM}" == "darwin" ]]
}
function is_windows() {
# On windows, the shell script is actually running in msys
if [[ "${PLATFORM}" =~ msys_nt*|mingw*|cygwin*|uwin* ]]; then
true
else
false
fi
[[ "${PLATFORM}" =~ msys_nt*|mingw*|cygwin*|uwin* ]]
}
function bazel_clean_and_fetch() {
# bazel clean --expunge currently doesn't work on Windows
# TODO(pcloudy): Re-enable it after bazel clean --expunge is fixed.
if ! is_windows; then
bazel clean --expunge
fi
bazel fetch "//tensorflow/... -//tensorflow/contrib/nccl/... \
-//tensorflow/examples/android/..."
function sed_in_place() {
sed -e $1 $2 > "$2.bak"
mv "$2.bak" $2
}
function write_to_bazelrc() {
echo "$1" >> .tf_configure.bazelrc
}
function write_action_env_to_bazelrc() {
write_to_bazelrc "build --action_env $1=\"$2\""
}
function python_path {
"$PYTHON_BIN_PATH" - <<END
from __future__ import print_function
import site
import os
try:
input = raw_input
except NameError:
pass
python_paths = []
if os.getenv('PYTHONPATH') is not None:
python_paths = os.getenv('PYTHONPATH').split(':')
try:
library_paths = site.getsitepackages()
except AttributeError:
from distutils.sysconfig import get_python_lib
library_paths = [get_python_lib()]
all_paths = set(python_paths + library_paths)
paths = []
for path in all_paths:
if os.path.isdir(path):
paths.append(path)
print(",".join(paths))
END
}
function setup_python {
## Set up python-related environment settings:
while true; do
fromuser=""
if [ -z "$PYTHON_BIN_PATH" ]; then
default_python_bin_path=$(which python || which python3 || true)
read -p "Please specify the location of python. [Default is $default_python_bin_path]: " PYTHON_BIN_PATH
fromuser="1"
if [ -z "$PYTHON_BIN_PATH" ]; then
PYTHON_BIN_PATH=$default_python_bin_path
fi
fi
if [ -e "$PYTHON_BIN_PATH" ]; then
break
fi
echo "Invalid python path. ${PYTHON_BIN_PATH} cannot be found" 1>&2
if [ -z "$fromuser" ]; then
exit 1
fi
PYTHON_BIN_PATH=""
# Retry
done
if [ -z "$PYTHON_LIB_PATH" ]; then
# Split python_path into an array of paths, this allows path containing spaces
IFS=',' read -r -a python_lib_path <<< "$(python_path)"
if [ 1 = "$USE_DEFAULT_PYTHON_LIB_PATH" ]; then
PYTHON_LIB_PATH=${python_lib_path[0]}
echo "Using python library path: $PYTHON_LIB_PATH"
else
echo "Found possible Python library paths:"
for x in "${python_lib_path[@]}"; do
echo " $x"
done
set -- "${python_lib_path[@]}"
echo "Please input the desired Python library path to use. Default is [$1]"
read b || true
if [ "$b" == "" ]; then
PYTHON_LIB_PATH=${python_lib_path[0]}
echo "Using python library path: $PYTHON_LIB_PATH"
else
PYTHON_LIB_PATH="$b"
fi
fi
fi
if [ ! -x "$PYTHON_BIN_PATH" ] || [ -d "$PYTHON_BIN_PATH" ]; then
echo "PYTHON_BIN_PATH is not executable. Is it the python binary?"
exit 1
fi
local python_major_version
python_major_version=$("${PYTHON_BIN_PATH}" -c 'from __future__ import print_function; import sys; print(sys.version_info[0]);' | head -c1)
if [ -z "$python_major_version" ]; then
echo -e "\n\nERROR: Problem getting python version. Is $PYTHON_BIN_PATH the correct python binary?"
exit 1
fi
# Convert python path to Windows style before writing into bazel.rc
if is_windows; then
PYTHON_BIN_PATH="$(cygpath -m "$PYTHON_BIN_PATH")"
PYTHON_LIB_PATH="$(cygpath -m "$PYTHON_LIB_PATH")"
fi
# Set-up env variables used by python_configure.bzl
write_action_env_to_bazelrc "PYTHON_BIN_PATH" "$PYTHON_BIN_PATH"
write_action_env_to_bazelrc "PYTHON_LIB_PATH" "$PYTHON_LIB_PATH"
write_to_bazelrc "build --define PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\""
write_to_bazelrc "build --define PYTHON_LIB_PATH=\"$PYTHON_LIB_PATH\""
write_to_bazelrc "build --force_python=py$python_major_version"
write_to_bazelrc "build --host_force_python=py$python_major_version"
write_to_bazelrc "build --python${python_major_version}_path=\"$PYTHON_BIN_PATH\""
write_to_bazelrc "test --force_python=py$python_major_version"
write_to_bazelrc "test --host_force_python=py$python_major_version"
write_to_bazelrc "test --define PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\""
write_to_bazelrc "test --define PYTHON_LIB_PATH=\"$PYTHON_LIB_PATH\""
write_to_bazelrc "run --define PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\""
write_to_bazelrc "run --define PYTHON_LIB_PATH=\"$PYTHON_LIB_PATH\""
# Write tools/python_bin_path.sh
echo "export PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\"" > tools/python_bin_path.sh
}
function version {
echo "$@" | awk -F. '{ printf("%03d%03d%03d\n", $1,$2,$3); }';
}
bazel version > bazel.version
curr_bazel_version=$(head -n 1 bazel.version | cut -d ' ' -f3)
rm -f bazel.version
echo "You have bazel $curr_bazel_version installed."
if [ -z "$curr_bazel_version" ]; then
echo "WARNING: current bazel installation is not a release version."
echo "Make sure you are running at least bazel $MIN_BAZEL_VERSION."
elif [ "$(version "$MIN_BAZEL_VERSION")" -gt "$(version "$curr_bazel_version")" ]; then
echo "Please upgrade your bazel installation to version $MIN_BAZEL_VERSION or higher to build TensorFlow!"
echo "Exiting..."
exit 1
fi
# This file contains customized config settings.
rm -f .tf_configure.bazelrc
touch .tf_configure.bazelrc
if [[ ! -e .bazelrc ]]; then
if [[ -e "${HOME}/.bazelrc" ]]; then
echo "import ${HOME}/.bazelrc" >.bazelrc
else
touch .bazelrc
fi
fi
sed_in_place "/tf_configure/d" .bazelrc
echo "import %workspace%/.tf_configure.bazelrc" >> .bazelrc
# Delete any leftover BUILD files from the Makefile build, which would interfere
# with Bazel parsing.
MAKEFILE_DOWNLOAD_DIR=tensorflow/contrib/makefile/downloads
@ -52,58 +193,65 @@ if [ -d "${MAKEFILE_DOWNLOAD_DIR}" ]; then
find ${MAKEFILE_DOWNLOAD_DIR} -type f -name '*BUILD' -delete
fi
## Set up python-related environment settings
while true; do
fromuser=""
if [ -z "$PYTHON_BIN_PATH" ]; then
default_python_bin_path=$(which python || which python3 || true)
read -p "Please specify the location of python. [Default is $default_python_bin_path]: " PYTHON_BIN_PATH
fromuser="1"
if [ -z "$PYTHON_BIN_PATH" ]; then
PYTHON_BIN_PATH=$default_python_bin_path
fi
fi
if [ -e "$PYTHON_BIN_PATH" ]; then
break
fi
echo "Invalid python path. ${PYTHON_BIN_PATH} cannot be found" 1>&2
if [ -z "$fromuser" ]; then
exit 1
fi
PYTHON_BIN_PATH=""
# Retry
done
setup_python
## Set up MKL related environment settings
if false; then # Disable building with MKL for now
while [ "$TF_NEED_MKL" == "" ]; do
while [ "$TF_NEED_MKL" == "" ]; do
fromuser=""
read -p "Do you wish to build TensorFlow with MKL support? [y/N] " INPUT
fromuser="1"
case $INPUT in
[Yy]* ) echo "MKL support will be enabled for TensorFlow"; TF_NEED_MKL=1;;
[Nn]* ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
"" ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
* ) echo "Invalid selection: " $INPUT;;
esac
done
OSNAME=`uname -s`
if [ "$TF_NEED_MKL" == "1" ]; then # TF_NEED_MKL
while [ "$TF_DOWNLOAD_MKL" == "" ]; do
fromuser=""
read -p "Do you wish to build TensorFlow with MKL support? [y/N] " INPUT
read -p "Do you wish to download MKL LIB from the web? [Y/n] " INPUT
fromuser="1"
case $INPUT in
[Yy]* ) echo "MKL support will be enabled for TensorFlow"; TF_NEED_MKL=1;;
[Nn]* ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
"" ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
* ) echo "Invalid selection: " $INPUT;;
[Yy]* ) TF_DOWNLOAD_MKL=1;;
[Nn]* ) TF_DOWNLOAD_MKL=0;;
"" ) TF_DOWNLOAD_MKL=1;;
* ) echo "Invalid selection: " $INPUT; exit 1;;
esac
done
OSNAME=`uname -s`
if [ "$TF_NEED_MKL" == "1" ]; then # TF_NEED_MKL
if [[ "$TF_DOWNLOAD_MKL" == "1" ]]; then
DST=`dirname $0`
ARCHIVE_BASENAME=mklml_lnx_2017.0.2.20170110.tgz
GITHUB_RELEASE_TAG=v0.3
ARCHIVE_BASENAME=mklml_lnx_2018.0.20170425.tgz
GITHUB_RELEASE_TAG=v0.7
MKLURL="https://github.com/01org/mkl-dnn/releases/download/$GITHUB_RELEASE_TAG/$ARCHIVE_BASENAME"
if ! [ -e "$DST/third_party/mkl/$ARCHIVE_BASENAME" ]; then
wget --no-check-certificate -P $DST/third_party/mkl/ $MKLURL
if ! [ -e "${DST}/third_party/mkl/${ARCHIVE_BASENAME}" ]; then
curl -fSsL -o "${DST}/third_party/mkl/${ARCHIVE_BASENAME}" "${MKLURL}"
fi
tar -xzf $DST/third_party/mkl/$ARCHIVE_BASENAME -C $DST/third_party/mkl/
extracted_dir_name="${ARCHIVE_BASENAME%.*}"
MKL_INSTALL_PATH=$DST/third_party/mkl/$extracted_dir_name
MKL_INSTALL_PATH=`${PYTHON_BIN_PATH} -c "import os; print(os.path.realpath(os.path.expanduser('${MKL_INSTALL_PATH}')))"`
if [ "$OSNAME" == "Linux" ]; then
else
default_mkl_path=/opt/intel/mklml
fromuser=""
if [ -z "$MKL_INSTALL_PATH" ]; then
read -p "Please specify the location where MKL is installed. [Default is $default_mkl_path]: " MKL_INSTALL_PATH
fromuser="1"
fi
if [ -z "$MKL_INSTALL_PATH" ]; then
MKL_INSTALL_PATH=$default_mkl_path
fi
# Result returned from "read" will be used unexpanded. That make "~" unusable.
# Going through one more level of expansion to handle that.
MKL_INSTALL_PATH=`${PYTHON_BIN_PATH} -c "import os; print(os.path.realpath(os.path.expanduser('${MKL_INSTALL_PATH}')))"`
fi
if [ "$OSNAME" == "Linux" ]; then
# Full MKL configuration
MKL_RT_LIB_PATH="lib/intel64/libmkl_rt.so" #${TF_MKL_EXT}#TODO version?
MKL_RT_OMP_LIB_PATH="../compiler/lib/intel64/libiomp5.so" #TODO VERSION?
@ -111,24 +259,29 @@ if false; then # Disable building with MKL for now
# MKL-ML configuration
MKL_ML_LIB_PATH="lib/libmklml_intel.so" #${TF_MKL_EXT}#TODO version?
MKL_ML_OMP_LIB_PATH="lib/libiomp5.so" #TODO VERSION?
elif [ "$OSNAME" == "Darwin" ]; then
elif [ "$OSNAME" == "Darwin" ]; then
echo "Darwin is unsupported yet";
exit 1
fi
fi
if [ -e "$MKL_INSTALL_PATH/${MKL_ML_LIB_PATH}" ]; then
if [ -e "$MKL_INSTALL_PATH/${MKL_ML_LIB_PATH}" ]; then
ln -sf $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} third_party/mkl/
ln -sf $MKL_INSTALL_PATH/${MKL_ML_OMP_LIB_PATH} third_party/mkl/
ln -sf $MKL_INSTALL_PATH/include third_party/mkl/
ln -sf $MKL_INSTALL_PATH/include third_party/eigen3/mkl_include
else
echo "ERROR: $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} does not exist";
loc=$(locate -e libdl.so.2 | sed -n 1p)
ln -sf $loc third_party/mkl/libdl.so.2
elif [ -e "$MKL_INSTALL_PATH/${MKL_RT_LIB_PATH}" ]; then
ln -sf $MKL_INSTALL_PATH/${MKL_RT_LIB_PATH} third_party/mkl/
ln -sf $MKL_INSTALL_PATH/${MKL_RT_OMP_LIB_PATH} third_party/mkl/
ln -sf $MKL_INSTALL_PATH/include third_party/mkl/
ln -sf $MKL_INSTALL_PATH/include third_party/eigen3/mkl_include
loc=$(locate -e libdl.so.2 | sed -n 1p)
ln -sf $loc third_party/mkl/libdl.so.2
else
echo "ERROR: $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} nor $MKL_INSTALL_PATH/${MKL_RT_LIB_PATH} exists";
exit 1
fi
if [ -z "$fromuser" ]; then
exit 1
fi
fi
cat > third_party/mkl/mkl.config <<EOF
# MKL_INSTALL_PATH refers to the location of MKL root folder. The MKL header and library
@ -136,9 +289,8 @@ cat > third_party/mkl/mkl.config <<EOF
MKL_INSTALL_PATH=$MKL_INSTALL_PATH
EOF
fi # TF_NEED_MKL
################## MKL
fi # Disable building with MKL for now
fi # TF_NEED_MKL
## End MKL setup
## Set up architecture-dependent optimization flags.
if [ -z "$CC_OPT_FLAGS" ]; then
@ -155,6 +307,7 @@ if is_windows; then
TF_NEED_HDFS=0
TF_NEED_JEMALLOC=0
TF_NEED_OPENCL=0
TF_CUDA_CLANG=0
fi
if is_linux; then
@ -172,13 +325,11 @@ else
TF_NEED_JEMALLOC=0
fi
if [ "$TF_NEED_JEMALLOC" == "1" ]; then
sed -i -e "s/WITH_JEMALLOC = False/WITH_JEMALLOC = True/" tensorflow/core/platform/default/build_config.bzl
else
sed -i -e "s/WITH_JEMALLOC = True/WITH_JEMALLOC = False/" tensorflow/core/platform/default/build_config.bzl
if [[ "$TF_NEED_JEMALLOC" == "1" ]]; then
write_to_bazelrc 'build --define with_jemalloc=true'
fi
while [ "$TF_NEED_GCP" == "" ]; do
while [[ "$TF_NEED_GCP" == "" ]]; do
read -p "Do you wish to build TensorFlow with "\
"Google Cloud Platform support? [y/N] " INPUT
case $INPUT in
@ -192,23 +343,11 @@ while [ "$TF_NEED_GCP" == "" ]; do
esac
done
if [ "$TF_NEED_GCP" == "1" ]; then
## Verify that libcurl header files are available.
# Only check Linux, since on MacOS the header files are installed with XCode.
if is_linux && [[ ! -f "/usr/include/curl/curl.h" ]]; then
echo "ERROR: It appears that the development version of libcurl is not "\
"available. Please install the libcurl3-dev package."
exit 1
fi
# Update Bazel build configuration.
sed -i -e "s/WITH_GCP_SUPPORT = False/WITH_GCP_SUPPORT = True/" tensorflow/core/platform/default/build_config.bzl
else
# Update Bazel build configuration.
sed -i -e "s/WITH_GCP_SUPPORT = True/WITH_GCP_SUPPORT = False/" tensorflow/core/platform/default/build_config.bzl
if [[ "$TF_NEED_GCP" == "1" ]]; then
write_to_bazelrc 'build --define with_gcp_support=true'
fi
while [ "$TF_NEED_HDFS" == "" ]; do
while [[ "$TF_NEED_HDFS" == "" ]]; do
read -p "Do you wish to build TensorFlow with "\
"Hadoop File System support? [y/N] " INPUT
case $INPUT in
@ -222,16 +361,12 @@ while [ "$TF_NEED_HDFS" == "" ]; do
esac
done
if [ "$TF_NEED_HDFS" == "1" ]; then
# Update Bazel build configuration.
sed -i -e "s/WITH_HDFS_SUPPORT = False/WITH_HDFS_SUPPORT = True/" tensorflow/core/platform/default/build_config.bzl
else
# Update Bazel build configuration.
sed -i -e "s/WITH_HDFS_SUPPORT = True/WITH_HDFS_SUPPORT = False/" tensorflow/core/platform/default/build_config.bzl
if [[ "$TF_NEED_HDFS" == "1" ]]; then
write_to_bazelrc 'build --define with_hdfs_support=true'
fi
## Enable XLA.
while [ "$TF_ENABLE_XLA" == "" ]; do
while [[ "$TF_ENABLE_XLA" == "" ]]; do
read -p "Do you wish to build TensorFlow with the XLA just-in-time compiler (experimental)? [y/N] " INPUT
case $INPUT in
[Yy]* ) echo "XLA JIT support will be enabled for TensorFlow"; TF_ENABLE_XLA=1;;
@ -241,22 +376,32 @@ while [ "$TF_ENABLE_XLA" == "" ]; do
esac
done
if [ "$TF_ENABLE_XLA" == "1" ]; then
# Update Bazel build configuration.
sed -i -e "s/^WITH_XLA_SUPPORT = [FT].*/WITH_XLA_SUPPORT = True/" tensorflow/core/platform/default/build_config_root.bzl
else
# Update Bazel build configuration.
sed -i -e "s/^WITH_XLA_SUPPORT = [FT].*/WITH_XLA_SUPPORT = False/" tensorflow/core/platform/default/build_config_root.bzl
if [[ "$TF_ENABLE_XLA" == "1" ]]; then
write_to_bazelrc 'build --define with_xla_support=true'
fi
# Verbs configuration
while [ "$TF_NEED_VERBS" == "" ]; do
read -p "Do you wish to build TensorFlow with "\
"VERBS support? [y/N] " INPUT
case $INPUT in
[Yy]* ) echo "VERBS support will be enabled for "\
"TensorFlow"; TF_NEED_VERBS=1;;
[Nn]* ) echo "No VERBS support will be enabled for "\
"TensorFlow"; TF_NEED_VERBS=0;;
"" ) echo "No VERBS support will be enabled for "\
"TensorFlow"; TF_NEED_VERBS=0;;
* ) echo "Invalid selection: " $INPUT;;
esac
done
# Invoke python_config and set up symlinks to python includes
./util/python/python_config.sh --setup "$PYTHON_BIN_PATH"
if [[ "$TF_NEED_VERBS" == "1" ]]; then
write_to_bazelrc 'build --define with_verbs_support=true'
fi
# Append CC optimization flags to bazel.rc
echo >> tools/bazel.rc
for opt in $CC_OPT_FLAGS; do
echo "build:opt --cxxopt=$opt --copt=$opt" >> tools/bazel.rc
write_to_bazelrc "build:opt --cxxopt=$opt --copt=$opt"
done
# Run the gen_git_source to create links where bazel can track dependencies for
@ -289,35 +434,46 @@ while [ "$TF_NEED_CUDA" == "" ]; do
done
export TF_NEED_CUDA
write_action_env_to_bazelrc "TF_NEED_CUDA" "$TF_NEED_CUDA"
export TF_NEED_OPENCL
if [[ "$TF_NEED_CUDA" == "0" ]] && [[ "$TF_NEED_OPENCL" == "0" ]]; then
echo "Configuration finished"
bazel_clean_and_fetch
exit
fi
write_action_env_to_bazelrc "TF_NEED_OPENCL" "$TF_NEED_OPENCL"
if [ "$TF_NEED_CUDA" == "1" ]; then
# Set up which gcc nvcc should use as the host compiler
# No need to set this on Windows
while ! is_windows && true; do
while [[ "$TF_CUDA_CLANG" == "" ]]; do
read -p "Do you want to use clang as CUDA compiler? [y/N] " INPUT
case $INPUT in
[Yy]* ) echo "Clang will be used as CUDA compiler"; TF_CUDA_CLANG=1;;
[Nn]* ) echo "nvcc will be used as CUDA compiler"; TF_CUDA_CLANG=0;;
"" ) echo "nvcc will be used as CUDA compiler"; TF_CUDA_CLANG=0;;
* ) echo "Invalid selection: " $INPUT;;
esac
done
export TF_CUDA_CLANG
write_action_env_to_bazelrc "TF_CUDA_CLANG" "$TF_CUDA_CLANG"
# Set up which clang we should use as the cuda / host compiler.
while [[ "$TF_CUDA_CLANG" == "1" ]] && true; do
fromuser=""
if [ -z "$GCC_HOST_COMPILER_PATH" ]; then
default_gcc_host_compiler_path=$(which gcc || true)
read -p "Please specify which gcc should be used by nvcc as the host compiler. [Default is $default_gcc_host_compiler_path]: " GCC_HOST_COMPILER_PATH
if [ -z "$CLANG_CUDA_COMPILER_PATH" ]; then
default_clang_host_compiler_path=$(which clang || true)
read -p "Please specify which clang should be used as device and host compiler. [Default is $default_clang_host_compiler_path]: " CLANG_CUDA_COMPILER_PATH
fromuser="1"
if [ -z "$GCC_HOST_COMPILER_PATH" ]; then
GCC_HOST_COMPILER_PATH="$default_gcc_host_compiler_path"
if [ -z "$CLANG_CUDA_COMPILER_PATH" ]; then
CLANG_CUDA_COMPILER_PATH="$default_clang_host_compiler_path"
fi
fi
if [ -e "$GCC_HOST_COMPILER_PATH" ]; then
export GCC_HOST_COMPILER_PATH
if [ -e "$CLANG_CUDA_COMPILER_PATH" ]; then
export CLANG_CUDA_COMPILER_PATH
write_action_env_to_bazelrc "CLANG_CUDA_COMPILER_PATH" "$CLANG_CUDA_COMPILER_PATH"
break
fi
echo "Invalid gcc path. ${GCC_HOST_COMPILER_PATH} cannot be found" 1>&2
echo "Invalid clang path. ${CLANG_CUDA_COMPILER_PATH} cannot be found" 1>&2
if [ -z "$fromuser" ]; then
exit 1
fi
GCC_HOST_COMPILER_PATH=""
CLANG_CUDA_COMPILER_PATH=""
# Retry
done
@ -325,7 +481,7 @@ done
while true; do
# Configure the Cuda SDK version to use.
if [ -z "$TF_CUDA_VERSION" ]; then
read -p "Please specify the CUDA SDK version you want to use, e.g. 7.0. [Leave empty to use system default]: " TF_CUDA_VERSION
read -p "Please specify the CUDA SDK version you want to use, e.g. 7.0. [Leave empty to default to CUDA 8.0]: " TF_CUDA_VERSION
fi
fromuser=""
@ -337,6 +493,11 @@ while true; do
else
default_cuda_path="$(cygpath -m "$CUDA_PATH")"
fi
elif is_linux; then
# If the default doesn't exist, try an alternative default.
if [ ! -d $default_cuda_path ] && [ -d /opt/cuda ]; then
default_cuda_path=/opt/cuda
fi
fi
read -p "Please specify the location where CUDA $TF_CUDA_VERSION toolkit is installed. Refer to README.md for more details. [Default is $default_cuda_path]: " CUDA_TOOLKIT_PATH
fromuser="1"
@ -361,6 +522,7 @@ while true; do
if [ -e "${CUDA_TOOLKIT_PATH}/${CUDA_RT_LIB_PATH}" ]; then
export CUDA_TOOLKIT_PATH
write_action_env_to_bazelrc "CUDA_TOOLKIT_PATH" "$CUDA_TOOLKIT_PATH"
export TF_CUDA_VERSION
break
fi
@ -374,11 +536,47 @@ while true; do
CUDA_TOOLKIT_PATH=""
done
# Set default CUDA version if not set
if [ -z "$TF_CUDA_VERSION" ]; then
TF_CUDA_VERSION="8.0"
export TF_CUDA_VERSION
fi
write_action_env_to_bazelrc "TF_CUDA_VERSION" "$TF_CUDA_VERSION"
# Set up which gcc nvcc should use as the host compiler
# No need to set this on Windows
while [[ "$TF_CUDA_CLANG" != "1" ]] && ! is_windows && true; do
fromuser=""
if [ -z "$GCC_HOST_COMPILER_PATH" ]; then
default_gcc_host_compiler_path=$(which gcc || true)
cuda_bin_symlink="$CUDA_TOOLKIT_PATH/bin/gcc"
if [ -L "$cuda_bin_symlink" ]; then
default_gcc_host_compiler_path=$(readlink $cuda_bin_symlink)
fi
read -p "Please specify which gcc should be used by nvcc as the host compiler. [Default is $default_gcc_host_compiler_path]: " GCC_HOST_COMPILER_PATH
fromuser="1"
if [ -z "$GCC_HOST_COMPILER_PATH" ]; then
GCC_HOST_COMPILER_PATH="$default_gcc_host_compiler_path"
fi
fi
if [ -e "$GCC_HOST_COMPILER_PATH" ]; then
export GCC_HOST_COMPILER_PATH
write_action_env_to_bazelrc "GCC_HOST_COMPILER_PATH" "$GCC_HOST_COMPILER_PATH"
break
fi
echo "Invalid gcc path. ${GCC_HOST_COMPILER_PATH} cannot be found" 1>&2
if [ -z "$fromuser" ]; then
exit 1
fi
GCC_HOST_COMPILER_PATH=""
# Retry
done
# Find out where the cuDNN library is installed
while true; do
# Configure the Cudnn version to use.
# Configure the cuDNN version to use.
if [ -z "$TF_CUDNN_VERSION" ]; then
read -p "Please specify the Cudnn version you want to use. [Leave empty to use system default]: " TF_CUDNN_VERSION
read -p "Please specify the cuDNN version you want to use. [Leave empty to default to cuDNN 6.0]: " TF_CUDNN_VERSION
fi
fromuser=""
@ -389,7 +587,7 @@ while true; do
if [ -z "$CUDNN_INSTALL_PATH" ]; then
CUDNN_INSTALL_PATH=$default_cudnn_path
fi
# Result returned from "read" will be used unexpanded. That make "~" unuseable.
# Result returned from "read" will be used unexpanded. That make "~" unusable.
# Going through one more level of expansion to handle that.
CUDNN_INSTALL_PATH=`"${PYTHON_BIN_PATH}" -c "import os; print(os.path.realpath(os.path.expanduser('${CUDNN_INSTALL_PATH}')))"`
fi
@ -411,17 +609,26 @@ while true; do
CUDA_DNN_LIB_ALT_PATH="libcudnn${TF_CUDNN_EXT}.dylib"
fi
if [ -e "$CUDNN_INSTALL_PATH/${CUDA_DNN_LIB_ALT_PATH}" -o -e "$CUDNN_INSTALL_PATH/${CUDA_DNN_LIB_PATH}" ]; then
if [ -e "$CUDNN_INSTALL_PATH/${CUDA_DNN_LIB_ALT_PATH}" ] || [ -e "$CUDNN_INSTALL_PATH/${CUDA_DNN_LIB_PATH}" ]; then
export TF_CUDNN_VERSION
write_action_env_to_bazelrc "TF_CUDNN_VERSION" "$TF_CUDNN_VERSION"
export CUDNN_INSTALL_PATH
write_action_env_to_bazelrc "CUDNN_INSTALL_PATH" "$CUDNN_INSTALL_PATH"
break
fi
if is_linux; then
CUDNN_PATH_FROM_LDCONFIG="$(ldconfig -p | sed -n 's/.*libcudnn.so .* => \(.*\)/\1/p')"
if ! type ldconfig > /dev/null 2>&1; then
LDCONFIG_BIN=/sbin/ldconfig
else
LDCONFIG_BIN=ldconfig
fi
CUDNN_PATH_FROM_LDCONFIG="$($LDCONFIG_BIN -p | sed -n 's/.*libcudnn.so .* => \(.*\)/\1/p')"
if [ -e "${CUDNN_PATH_FROM_LDCONFIG}${TF_CUDNN_EXT}" ]; then
export TF_CUDNN_VERSION
export CUDNN_INSTALL_PATH="$(dirname ${CUDNN_PATH_FROM_LDCONFIG})"
export CUDNN_INSTALL_PATH
CUDNN_INSTALL_PATH="$(dirname ${CUDNN_PATH_FROM_LDCONFIG})"
write_action_env_to_bazelrc "CUDNN_INSTALL_PATH" "$CUDNN_INSTALL_PATH"
break
fi
fi
@ -440,6 +647,13 @@ while true; do
CUDNN_INSTALL_PATH=""
done
# Set default CUDNN version if not set
if [ -z "$TF_CUDNN_VERSION" ]; then
TF_CUDNN_VERSION="6"
export TF_CUDNN_VERSION
fi
write_action_env_to_bazelrc "TF_CUDNN_VERSION" "$TF_CUDNN_VERSION"
# Configure the compute capabilities that TensorFlow builds for.
# Since Cuda toolkit is not backward-compatible, this is not guaranteed to work.
while true; do
@ -473,6 +687,7 @@ EOF
fi
else
export TF_CUDA_COMPUTE_CAPABILITIES
write_action_env_to_bazelrc "TF_CUDA_COMPUTE_CAPABILITIES" "$TF_CUDA_COMPUTE_CAPABILITIES"
break
fi
TF_CUDA_COMPUTE_CAPABILITIES=""
@ -483,9 +698,15 @@ if is_windows; then
export CUDA_PATH="$CUDA_TOOLKIT_PATH"
export CUDA_COMPUTE_CAPABILITIES="$TF_CUDA_COMPUTE_CAPABILITIES"
export NO_WHOLE_ARCHIVE_OPTION=1
# Set GCC_HOST_COMPILER_PATH to keep cuda_configure.bzl happy
export GCC_HOST_COMPILER_PATH="/usr/bin/dummy_compiler"
write_action_env_to_bazelrc "CUDA_PATH" "$CUDA_PATH"
write_action_env_to_bazelrc "CUDA_COMPUTE_CAPABILITIES" "$CUDA_COMPUTE_CAPABILITIES"
write_action_env_to_bazelrc "NO_WHOLE_ARCHIVE_OPTION" "1"
write_to_bazelrc "build --config=win-cuda"
write_to_bazelrc "test --config=win-cuda"
else
# If CUDA is enabled, always use GPU during build and test.
write_to_bazelrc "build --config=cuda"
write_to_bazelrc "test --config=cuda"
fi
# end of if "$TF_NEED_CUDA" == "1"
@ -499,7 +720,7 @@ if [ "$TF_NEED_OPENCL" == "1" ]; then
while true; do
fromuser=""
if [ -z "$HOST_CXX_COMPILER" ]; then
default_cxx_host_compiler=$(which clang++-3.6 || true)
default_cxx_host_compiler=$(which g++ || true)
read -p "Please specify which C++ compiler should be used as the host C++ compiler. [Default is $default_cxx_host_compiler]: " HOST_CXX_COMPILER
fromuser="1"
if [ -z "$HOST_CXX_COMPILER" ]; then
@ -508,6 +729,7 @@ while true; do
fi
if [ -e "$HOST_CXX_COMPILER" ]; then
export HOST_CXX_COMPILER
write_action_env_to_bazelrc "HOST_CXX_COMPILER" "$HOST_CXX_COMPILER"
break
fi
echo "Invalid C++ compiler path. ${HOST_CXX_COMPILER} cannot be found" 1>&2
@ -522,7 +744,7 @@ done
while true; do
fromuser=""
if [ -z "$HOST_C_COMPILER" ]; then
default_c_host_compiler=$(which clang-3.6 || true)
default_c_host_compiler=$(which gcc || true)
read -p "Please specify which C compiler should be used as the host C compiler. [Default is $default_c_host_compiler]: " HOST_C_COMPILER
fromuser="1"
if [ -z "$HOST_C_COMPILER" ]; then
@ -531,6 +753,7 @@ while true; do
fi
if [ -e "$HOST_C_COMPILER" ]; then
export HOST_C_COMPILER
write_action_env_to_bazelrc "HOST_C_COMPILER" "$HOST_C_COMPILER"
break
fi
echo "Invalid C compiler path. ${HOST_C_COMPILER} cannot be found" 1>&2
@ -561,6 +784,7 @@ while true; do
if [ -e "${COMPUTECPP_TOOLKIT_PATH}/${SYCL_RT_LIB_PATH}" ]; then
export COMPUTECPP_TOOLKIT_PATH
write_action_env_to_bazelrc "COMPUTECPP_TOOLKIT_PATH" "$COMPUTECPP_TOOLKIT_PATH"
break
fi
echo "Invalid SYCL $TF_OPENCL_VERSION library path. ${COMPUTECPP_TOOLKIT_PATH}/${SYCL_RT_LIB_PATH} cannot be found"
@ -576,6 +800,82 @@ done
# end of if "$TF_NEED_OPENCL" == "1"
fi
bazel_clean_and_fetch
while [ "$TF_NEED_MPI" == "" ]; do
read -p "Do you wish to build TensorFlow with "\
"MPI support? [y/N] " INPUT
case $INPUT in
[Yy]* ) echo "MPI support will be enabled for "\
"TensorFlow"; TF_NEED_MPI=1;;
[Nn]* ) echo "MPI support will not be enabled for "\
"TensorFlow"; TF_NEED_MPI=0;;
"" ) echo "MPI support will not be enabled for "\
"TensorFlow"; TF_NEED_MPI=0;;
* ) echo "Invalid selection: " $INPUT;;
esac
done
# Find out where the MPI toolkit is installed
while true; do
if [ "$TF_NEED_MPI" == "0" ]; then
break;
fi
fromuser=""
if [ -z "$MPI_HOME" ]; then
#Get the base folder by removing the bin path
default_mpi_path=$(dirname $(dirname $(which mpirun)) || dirname $(dirname $(which mpiexec)) || true)
read -p "Please specify the MPI toolkit folder. [Default is $default_mpi_path]: " MPI_HOME
fromuser="1"
if [ -z "$MPI_HOME" ]; then
MPI_HOME=$default_mpi_path
fi
fi
#Check that the include and library folders are where we expect them to be
if [ -e "$MPI_HOME/include" ] && [ -e "$MPI_HOME/lib" ]; then
break
fi
echo "Invalid path to the MPI Toolkit. ${MPI_HOME}/include or ${MPI_HOME}/lib cannot be found."
if [ -z "$fromuser" ]; then
exit 1
fi
# Retry
MPI_HOME=""
done
if [ "$TF_NEED_MPI" == "1" ]; then
write_to_bazelrc 'build --define with_mpi_support=true'
#Link the MPI header files
ln -sf "${MPI_HOME}/include/mpi.h" third_party/mpi/mpi.h
#Determine if we use OpenMPI or MVAPICH, these require different header files
#to be included here to make bazel dependency checker happy
if [ -e "${MPI_HOME}/include/mpi_portable_platform.h" ]; then
#OpenMPI
ln -sf "${MPI_HOME}/include/mpi_portable_platform.h" third_party/mpi/
sed -i -e "s/MPI_LIB_IS_OPENMPI=False/MPI_LIB_IS_OPENMPI=True/" third_party/mpi/mpi.bzl
else
#MVAPICH / MPICH
ln -sf "${MPI_HOME}/include/mpio.h" third_party/mpi/
ln -sf "${MPI_HOME}/include/mpicxx.h" third_party/mpi/
sed -i -e "s/MPI_LIB_IS_OPENMPI=True/MPI_LIB_IS_OPENMPI=False/" third_party/mpi/mpi.bzl
fi
if [ -e "${MPI_HOME}/lib/libmpi.so" ]; then
ln -sf "${MPI_HOME}/lib/libmpi.so" third_party/mpi/
else
echo "Cannot find the MPI library file in ${MPI_HOME}/lib "
exit 1
fi
fi
echo "Configuration finished"

View File

@ -14,8 +14,33 @@ exports_files([
# Config setting for determining if we are building for Android.
config_setting(
name = "android",
values = {"crosstool_top": "//external:android/crosstool"},
visibility = ["//visibility:public"],
)
config_setting(
name = "android_x86",
values = {
"crosstool_top": "//external:android/crosstool",
"cpu": "x86",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "android_x86_64",
values = {
"crosstool_top": "//external:android/crosstool",
"cpu": "x86_64",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "android_armeabi",
values = {
"cc_target_os": "android",
"cpu": "armeabi",
},
visibility = ["//visibility:public"],
)
@ -46,6 +71,12 @@ config_setting(
config_setting(
name = "windows",
values = {"cpu": "x64_windows"},
visibility = ["//visibility:public"],
)
config_setting(
name = "windows_msvc",
values = {"cpu": "x64_windows_msvc"},
visibility = ["//visibility:public"],
)
@ -58,9 +89,7 @@ config_setting(
config_setting(
name = "ios",
values = {
"crosstool_top": "//tools/osx/crosstool:crosstool",
},
values = {"crosstool_top": "//tools/osx/crosstool:crosstool"},
visibility = ["//visibility:public"],
)
@ -70,6 +99,12 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "linux_ppc64le",
values = {"cpu": "ppc"},
visibility = ["//visibility:public"],
)
config_setting(
name = "debug",
values = {
@ -86,6 +121,61 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "freebsd",
values = {"cpu": "freebsd"},
visibility = ["//visibility:public"],
)
# TODO(jhseu): Enable on other platforms other than Linux.
config_setting(
name = "with_jemalloc_linux_x86_64",
values = {
"cpu": "k8",
"define": "with_jemalloc=true",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "with_jemalloc_linux_ppc64le",
values = {
"cpu": "ppc",
"define": "with_jemalloc=true",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "with_gcp_support",
values = {"define": "with_gcp_support=true"},
visibility = ["//visibility:public"],
)
config_setting(
name = "with_hdfs_support",
values = {"define": "with_hdfs_support=true"},
visibility = ["//visibility:public"],
)
config_setting(
name = "with_xla_support",
values = {"define": "with_xla_support=true"},
visibility = ["//visibility:public"],
)
config_setting(
name = "with_verbs_support",
values = {"define": "with_verbs_support=true"},
visibility = ["//visibility:public"],
)
config_setting(
name = "with_mpi_support",
values = {"define": "with_mpi_support=true"},
visibility = ["//visibility:public"],
)
package_group(
name = "internal",
packages = ["//tensorflow/..."],
@ -123,7 +213,9 @@ filegroup(
"//tensorflow/compiler/aot/tests:all_files",
"//tensorflow/compiler/jit:all_files",
"//tensorflow/compiler/jit/graphcycles:all_files",
"//tensorflow/compiler/jit/kernels:all_files",
"//tensorflow/compiler/jit/legacy_flags:all_files",
"//tensorflow/compiler/jit/ops:all_files",
"//tensorflow/compiler/tests:all_files",
"//tensorflow/compiler/tf2xla:all_files",
"//tensorflow/compiler/tf2xla/kernels:all_files",
@ -131,7 +223,6 @@ filegroup(
"//tensorflow/compiler/xla/client:all_files",
"//tensorflow/compiler/xla/client/lib:all_files",
"//tensorflow/compiler/xla/legacy_flags:all_files",
"//tensorflow/compiler/xla/port:all_files",
"//tensorflow/compiler/xla/service:all_files",
"//tensorflow/compiler/xla/service/cpu:all_files",
"//tensorflow/compiler/xla/service/gpu:all_files",
@ -141,11 +232,28 @@ filegroup(
"//tensorflow/compiler/xla/tools:all_files",
"//tensorflow/contrib:all_files",
"//tensorflow/contrib/android:all_files",
"//tensorflow/contrib/batching:all_files",
"//tensorflow/contrib/batching/kernels:all_files",
"//tensorflow/contrib/batching/test_util:all_files",
"//tensorflow/contrib/batching/util:all_files",
"//tensorflow/contrib/bayesflow:all_files",
"//tensorflow/contrib/boosted_trees:all_files",
"//tensorflow/contrib/boosted_trees/lib:all_files",
"//tensorflow/contrib/boosted_trees/proto:all_files",
"//tensorflow/contrib/boosted_trees/resources:all_files",
"//tensorflow/contrib/cloud:all_files",
"//tensorflow/contrib/cloud/kernels:all_files",
"//tensorflow/contrib/cluster_resolver:all_files",
"//tensorflow/contrib/compiler:all_files",
"//tensorflow/contrib/copy_graph:all_files",
"//tensorflow/contrib/crf:all_files",
"//tensorflow/contrib/cudnn_rnn:all_files",
"//tensorflow/contrib/data:all_files",
"//tensorflow/contrib/data/python/framework:all_files",
"//tensorflow/contrib/data/python/kernel_tests:all_files",
"//tensorflow/contrib/data/python/ops:all_files",
"//tensorflow/contrib/data/python/util:all_files",
"//tensorflow/contrib/decision_trees:all_files",
"//tensorflow/contrib/distributions:all_files",
"//tensorflow/contrib/factorization:all_files",
"//tensorflow/contrib/factorization/kernels:all_files",
@ -154,10 +262,15 @@ filegroup(
"//tensorflow/contrib/framework:all_files",
"//tensorflow/contrib/graph_editor:all_files",
"//tensorflow/contrib/grid_rnn:all_files",
"//tensorflow/contrib/hooks:all_files",
"//tensorflow/contrib/hvx/hvx_ops_support_checker:all_files",
"//tensorflow/contrib/image:all_files",
"//tensorflow/contrib/imperative:all_files",
"//tensorflow/contrib/input_pipeline:all_files",
"//tensorflow/contrib/input_pipeline/kernels:all_files",
"//tensorflow/contrib/integrate:all_files",
"//tensorflow/contrib/keras:all_files",
"//tensorflow/contrib/kernel_methods:all_files",
"//tensorflow/contrib/labeled_tensor:all_files",
"//tensorflow/contrib/layers:all_files",
"//tensorflow/contrib/layers/kernels:all_files",
@ -172,30 +285,44 @@ filegroup(
"//tensorflow/contrib/nn:all_files",
"//tensorflow/contrib/opt:all_files",
"//tensorflow/contrib/rnn:all_files",
"//tensorflow/contrib/saved_model:all_files",
"//tensorflow/contrib/saved_model/cc/saved_model:all_files",
"//tensorflow/contrib/seq2seq:all_files",
"//tensorflow/contrib/session_bundle:all_files",
"//tensorflow/contrib/session_bundle/example:all_files",
"//tensorflow/contrib/signal:all_files",
"//tensorflow/contrib/slim:all_files",
"//tensorflow/contrib/slim/python/slim/data:all_files",
"//tensorflow/contrib/slim/python/slim/nets:all_files",
"//tensorflow/contrib/solvers:all_files",
"//tensorflow/contrib/sparsemax:all_files",
"//tensorflow/contrib/specs:all_files",
"//tensorflow/contrib/staging:all_files",
"//tensorflow/contrib/stat_summarizer:all_files",
"//tensorflow/contrib/stateless:all_files",
"//tensorflow/contrib/tensor_forest:all_files",
"//tensorflow/contrib/tensor_forest/hybrid:all_files",
"//tensorflow/contrib/tensorboard:all_files",
"//tensorflow/contrib/testing:all_files",
"//tensorflow/contrib/text:all_files",
"//tensorflow/contrib/tfprof/python/tools/tfprof:all_files",
"//tensorflow/contrib/training:all_files",
"//tensorflow/contrib/util:all_files",
"//tensorflow/contrib/verbs:all_files",
"//tensorflow/contrib/xla_tf_graph:all_files",
"//tensorflow/core:all_files",
"//tensorflow/core/debug:all_files",
"//tensorflow/core/distributed_runtime:all_files",
"//tensorflow/core/distributed_runtime/rpc:all_files",
"//tensorflow/core/grappler:all_files",
"//tensorflow/core/grappler/clusters:all_files",
"//tensorflow/core/grappler/costs:all_files",
"//tensorflow/core/grappler/inputs:all_files",
"//tensorflow/core/grappler/optimizers:all_files",
"//tensorflow/core/grappler/utils:all_files",
"//tensorflow/core/kernels:all_files",
"//tensorflow/core/kernels/cloud:all_files",
"//tensorflow/core/kernels/hexagon:all_files",
"//tensorflow/core/kernels/neon:all_files",
"//tensorflow/core/ops/compat:all_files",
"//tensorflow/core/platform/cloud:all_files",
"//tensorflow/core/platform/default/build_config:all_files",
@ -203,6 +330,7 @@ filegroup(
"//tensorflow/core/util/ctc:all_files",
"//tensorflow/core/util/tensor_bundle:all_files",
"//tensorflow/examples/android:all_files",
"//tensorflow/examples/benchmark:all_files",
"//tensorflow/examples/how_tos/reading_data:all_files",
"//tensorflow/examples/image_retraining:all_files",
"//tensorflow/examples/label_image:all_files",
@ -211,30 +339,87 @@ filegroup(
"//tensorflow/examples/tutorials/estimators:all_files",
"//tensorflow/examples/tutorials/mnist:all_files",
"//tensorflow/examples/tutorials/word2vec:all_files",
"//tensorflow/g3doc/how_tos/adding_an_op:all_files",
"//tensorflow/g3doc/tutorials:all_files",
"//tensorflow/examples/wav_to_spectrogram:all_files",
"//tensorflow/go:all_files",
"//tensorflow/java:all_files",
"//tensorflow/java/src/main/java/org/tensorflow/examples:all_files",
"//tensorflow/java/src/main/native:all_files",
"//tensorflow/python:all_files",
"//tensorflow/python/debug:all_files",
"//tensorflow/python/estimator:all_files",
"//tensorflow/python/feature_column:all_files",
"//tensorflow/python/kernel_tests:all_files",
"//tensorflow/python/kernel_tests/distributions:all_files",
"//tensorflow/python/ops/distributions:all_files",
"//tensorflow/python/saved_model:all_files",
"//tensorflow/python/tools:all_files",
"//tensorflow/tensorboard:all_files",
"//tensorflow/tensorboard/app:all_files",
"//tensorflow/tensorboard/backend:all_files",
"//tensorflow/tensorboard/backend/event_processing:all_files",
"//tensorflow/tensorboard/components:all_files",
"//tensorflow/tensorboard/components/vz_data_summary:all_files",
"//tensorflow/tensorboard/components/tf_audio_dashboard:all_files",
"//tensorflow/tensorboard/components/tf_audio_dashboard/test:all_files",
"//tensorflow/tensorboard/components/tf_backend:all_files",
"//tensorflow/tensorboard/components/tf_backend/test:all_files",
"//tensorflow/tensorboard/components/tf_color_scale:all_files",
"//tensorflow/tensorboard/components/tf_color_scale/test:all_files",
"//tensorflow/tensorboard/components/tf_dashboard_common:all_files",
"//tensorflow/tensorboard/components/tf_dashboard_common/test:all_files",
"//tensorflow/tensorboard/components/tf_distribution_dashboard:all_files",
"//tensorflow/tensorboard/components/tf_globals:all_files",
"//tensorflow/tensorboard/components/tf_graph:all_files",
"//tensorflow/tensorboard/components/tf_graph/demo:all_files",
"//tensorflow/tensorboard/components/tf_graph_app:all_files",
"//tensorflow/tensorboard/components/tf_graph_app/demo:all_files",
"//tensorflow/tensorboard/components/tf_graph_board:all_files",
"//tensorflow/tensorboard/components/tf_graph_board/demo:all_files",
"//tensorflow/tensorboard/components/tf_graph_common:all_files",
"//tensorflow/tensorboard/components/tf_graph_controls:all_files",
"//tensorflow/tensorboard/components/tf_graph_controls/demo:all_files",
"//tensorflow/tensorboard/components/tf_graph_dashboard:all_files",
"//tensorflow/tensorboard/components/tf_graph_dashboard/demo:all_files",
"//tensorflow/tensorboard/components/tf_graph_debugger_data_card:all_files",
"//tensorflow/tensorboard/components/tf_graph_debugger_data_card/demo:all_files",
"//tensorflow/tensorboard/components/tf_graph_info:all_files",
"//tensorflow/tensorboard/components/tf_graph_info/demo:all_files",
"//tensorflow/tensorboard/components/tf_graph_loader:all_files",
"//tensorflow/tensorboard/components/tf_graph_loader/demo:all_files",
"//tensorflow/tensorboard/components/tf_histogram_dashboard:all_files",
"//tensorflow/tensorboard/components/tf_image_dashboard:all_files",
"//tensorflow/tensorboard/components/tf_imports:all_files",
"//tensorflow/tensorboard/components/tf_option_selector:all_files",
"//tensorflow/tensorboard/components/tf_profile_dashboard:all_files",
"//tensorflow/tensorboard/components/tf_profile_dashboard/demo:all_files",
"//tensorflow/tensorboard/components/tf_runs_selector:all_files",
"//tensorflow/tensorboard/components/tf_scalar_dashboard:all_files",
"//tensorflow/tensorboard/components/tf_scalar_dashboard/demo:all_files",
"//tensorflow/tensorboard/components/tf_storage:all_files",
"//tensorflow/tensorboard/components/tf_storage/test:all_files",
"//tensorflow/tensorboard/components/tf_tensorboard:all_files",
"//tensorflow/tensorboard/components/tf_text_dashboard:all_files",
"//tensorflow/tensorboard/components/tf_trace_viewer:all_files",
"//tensorflow/tensorboard/components/vz_distribution_chart:all_files",
"//tensorflow/tensorboard/components/vz_histogram_timeseries:all_files",
"//tensorflow/tensorboard/components/vz_line_chart:all_files",
"//tensorflow/tensorboard/components/vz_line_chart/demo:all_files",
"//tensorflow/tensorboard/components/vz_projector:all_files",
"//tensorflow/tensorboard/components/vz_projector/test:all_files",
"//tensorflow/tensorboard/components/vz_sorting:all_files",
"//tensorflow/tensorboard/components/vz_sorting/test:all_files",
"//tensorflow/tensorboard/lib:all_files",
"//tensorflow/tensorboard/lib/python:all_files",
"//tensorflow/tensorboard/demo:all_files",
"//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:all_files",
"//tensorflow/tensorboard/plugins:all_files",
"//tensorflow/tensorboard/plugins/audio:all_files",
"//tensorflow/tensorboard/plugins/distributions:all_files",
"//tensorflow/tensorboard/plugins/graphs:all_files",
"//tensorflow/tensorboard/plugins/histograms:all_files",
"//tensorflow/tensorboard/plugins/images:all_files",
"//tensorflow/tensorboard/plugins/projector:all_files",
"//tensorflow/tensorboard/plugins/scalars:all_files",
"//tensorflow/tensorboard/plugins/text:all_files",
"//tensorflow/tensorboard/scripts:all_files",
"//tensorflow/tools/api/golden:all_files",
"//tensorflow/tools/api/lib:all_files",
"//tensorflow/tools/api/tests:all_files",
"//tensorflow/tools/common:all_files",
"//tensorflow/tools/compatibility:all_files",
"//tensorflow/tools/dist_test/server:all_files",
@ -247,6 +432,7 @@ filegroup(
"//tensorflow/tools/test:all_files",
"//tensorflow/tools/tfprof:all_files",
"//tensorflow/tools/tfprof/internal:all_files",
"//tensorflow/tools/tfprof/internal/advisor:all_files",
"//tensorflow/user_ops:all_files",
"//third_party/hadoop:all_files",
"//third_party/sycl:all_files",
@ -269,23 +455,35 @@ filegroup(
),
)
filegroup(
name = "docs_src",
data = glob(["docs_src/**/*.md"]),
)
# -------------------------------------------
# New rules should be added above this target.
# -------------------------------------------
cc_binary(
name = "libtensorflow.so",
linkopts = select({
"//tensorflow:darwin": [
"-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file
"//tensorflow/c:exported_symbols.lds",
],
"//tensorflow:windows": [],
"//tensorflow:windows_msvc": [],
"//conditions:default": [
"-z defs",
"-s",
"-Wl,--version-script", # This line must be directly followed by the version_script.lds file
"//tensorflow/c:version_script.lds",
],
}),
linkshared = 1,
deps = [
"//tensorflow/c:c_api",
"//tensorflow/core:tensorflow",
],
)
cc_binary(
name = "libtensorflow_c.so",
linkshared = 1,
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c:exported_symbols.lds",
"//tensorflow/c:version_script.lds",
"//tensorflow/core:tensorflow",
],
)
@ -296,6 +494,8 @@ cc_binary(
deps = [
"//tensorflow/c:c_api",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:client_session",
"//tensorflow/cc:scope",
"//tensorflow/core:tensorflow",
],
)

View File

@ -24,16 +24,19 @@ from __future__ import print_function
from tensorflow.python import *
# pylint: enable=wildcard-import
# Lazily import the `tf.contrib` module. This avoids loading all of the
# dependencies of `tf.contrib` at `import tensorflow` time.
class _LazyContribLoader(object):
from tensorflow.python.util.lazy_loader import LazyLoader
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
del LazyLoader
def __getattr__(self, item):
global contrib
# Replace the lazy loader with the imported module itself.
import importlib # pylint: disable=g-import-not-at-top
contrib = importlib.import_module('tensorflow.contrib')
return getattr(contrib, item)
del absolute_import
del division
del print_function
contrib = _LazyContribLoader()
# These symbols appear because we import the python package which
# in turn imports from tensorflow.core and tensorflow.python. They
# must come from this module. So python adds these symbols for the
# resolution to succeed.
# pylint: disable=undefined-variable
del python
del core
# pylint: enable=undefined-variable

View File

@ -26,6 +26,22 @@ filegroup(
visibility = ["//tensorflow:__subpackages__"],
)
tf_cuda_library(
name = "c_api_internal",
srcs = ["c_api.h"],
hdrs = ["c_api_internal.h"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
}),
)
tf_cuda_library(
name = "c_api",
srcs = ["c_api.cc"],
@ -34,10 +50,16 @@ tf_cuda_library(
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
":c_api_internal",
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
":c_api_internal",
"//tensorflow/cc/saved_model:loader",
"//tensorflow/cc:gradients",
"//tensorflow/cc:ops",
"//tensorflow/cc:grad_ops",
"//tensorflow/cc:scope_internal",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@ -45,6 +67,14 @@ tf_cuda_library(
}),
)
exports_files(
[
"version_script.lds",
"exported_symbols.lds",
],
visibility = ["//visibility:public"],
)
tf_cuda_library(
name = "tf_status_helper",
srcs = ["tf_status_helper.cc"],
@ -89,20 +119,22 @@ tf_cc_test(
# linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":c_api",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:grad_ops",
"//tensorflow/cc/saved_model:signature_constants",
"//tensorflow/cc/saved_model:tag_constants",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:direct_session",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:proto_text",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:array",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:math",
"//third_party/eigen3",
],
)

View File

@ -21,8 +21,12 @@ limitations under the License.
#include <vector>
#ifndef __ANDROID__
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/cc/saved_model/loader.h"
#endif
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def_util.h"
@ -52,12 +56,11 @@ limitations under the License.
using tensorflow::error::Code;
using tensorflow::errors::InvalidArgument;
using tensorflow::gtl::ArraySlice;
using tensorflow::strings::StrCat;
using tensorflow::AllocationDescription;
using tensorflow::DataType;
using tensorflow::Env;
using tensorflow::Graph;
using tensorflow::GraphDef;
using tensorflow::mutex;
using tensorflow::mutex_lock;
using tensorflow::NameRangeMap;
using tensorflow::NameRangesForNode;
@ -68,11 +71,9 @@ using tensorflow::NodeBuilder;
using tensorflow::OpDef;
using tensorflow::OpRegistry;
using tensorflow::PartialTensorShape;
using tensorflow::Reset;
using tensorflow::RunMetadata;
using tensorflow::RunOptions;
using tensorflow::Session;
using tensorflow::SessionOptions;
using tensorflow::Status;
using tensorflow::Tensor;
using tensorflow::TensorBuffer;
@ -92,9 +93,6 @@ size_t TF_DataTypeSize(TF_DataType dt) {
}
// --------------------------------------------------------------------------
struct TF_Status {
Status status;
};
TF_Status* TF_NewStatus() { return new TF_Status; }
@ -134,6 +132,9 @@ class TF_ManagedBuffer : public TensorBuffer {
proto->set_requested_bytes(rb);
proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
}
// Prevents input forwarding from mutating this buffer.
bool OwnsMemory() const override { return false; }
};
void* allocate_tensor(const char* operation, size_t len) {
@ -175,12 +176,6 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in,
} // namespace
struct TF_Tensor {
TF_DataType dtype;
TensorShape shape;
TensorBuffer* buffer;
};
TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
int num_dims, size_t len) {
void* data = allocate_tensor("TF_AllocateTensor", len);
@ -216,6 +211,18 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
return new TF_Tensor{dtype, TensorShape(dimvec), buf};
}
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) {
// It is safe to move the Tensor if and only if we own the unique reference to
// it. In that case, we might as well not delete and reallocate, but a future
// implementation might need to do so.
if (tensor->buffer->RefCountIsOne() &&
tensor->buffer->root_buffer()->RefCountIsOne() &&
tensor->buffer->OwnsMemory()) {
return tensor;
}
return nullptr;
}
void TF_DeleteTensor(TF_Tensor* t) {
t->buffer->Unref();
delete t;
@ -273,9 +280,6 @@ size_t TF_StringEncodedSize(size_t len) {
}
// --------------------------------------------------------------------------
struct TF_SessionOptions {
SessionOptions options;
};
TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; }
void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; }
@ -316,9 +320,6 @@ void TF_DeleteBuffer(TF_Buffer* buffer) {
TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; }
// --------------------------------------------------------------------------
struct TF_DeprecatedSession {
Session* session;
};
TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt,
TF_Status* status) {
@ -328,7 +329,7 @@ TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt,
return new TF_DeprecatedSession({session});
} else {
DCHECK_EQ(nullptr, session);
return NULL;
return nullptr;
}
}
@ -502,7 +503,7 @@ static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs,
TF_Status* status) {
status->status = Status::OK();
for (int i = 0; i < noutputs; ++i) {
c_outputs[i] = NULL;
c_outputs[i] = nullptr;
}
}
@ -542,9 +543,8 @@ static void TF_Run_Helper(
if (handle == nullptr) {
RunOptions run_options_proto;
if (run_options != nullptr &&
!run_options_proto.ParseFromArray(run_options->data,
run_options->length)) {
if (run_options != nullptr && !run_options_proto.ParseFromArray(
run_options->data, run_options->length)) {
status->status = InvalidArgument("Unparseable RunOptions proto");
return;
}
@ -651,6 +651,7 @@ void TF_PRunSetup(TF_DeprecatedSession* s,
memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
*handle = buf;
} else {
*handle = nullptr;
status->status = result;
}
}
@ -682,11 +683,6 @@ void TF_PRun(TF_DeprecatedSession* s, const char* handle,
c_outputs, target_oper_names, nullptr, status);
}
struct TF_Library {
void* lib_handle;
TF_Buffer op_list;
};
TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
TF_Library* lib_handle = new TF_Library;
status->status = tensorflow::LoadLibrary(
@ -714,68 +710,58 @@ TF_Buffer* TF_GetAllOpList() {
*(op_list.add_op()) = op;
}
TF_Buffer* ret = TF_NewBuffer();
MessageToBuffer(op_list, ret);
TF_CHECK_OK(MessageToBuffer(op_list, ret));
return ret;
}
// --------------------------------------------------------------------------
// ListDevices & SessionListDevices API
void TF_DeleteDeviceList(TF_DeviceList* s) { delete s; }
TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) {
TF_DeviceList* response = new TF_DeviceList;
status->status = session->session->ListDevices(&response->response);
return response;
}
TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session,
TF_Status* status) {
TF_DeviceList* response = new TF_DeviceList;
status->status = session->session->ListDevices(&response->response);
return response;
}
int TF_DeviceListCount(const TF_DeviceList* list) {
return list->response.size();
}
#define TF_DEVICELIST_METHOD(return_type, method_name, accessor, err_val) \
return_type method_name(const TF_DeviceList* list, const int index, \
TF_Status* status) { \
if (list == nullptr) { \
status->status = InvalidArgument("list is null!"); \
return err_val; \
} \
if (index < 0 || index >= list->response.size()) { \
status->status = InvalidArgument("index out of bounds"); \
return err_val; \
} \
return list->response[index].accessor; \
}
TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
nullptr);
TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);
#undef TF_DEVICELIST_METHOD
} // end extern "C"
// --------------------------------------------------------------------------
// New Graph and Session API
// Structures -----------------------------------------------------------------
extern "C" {
struct TF_Graph {
TF_Graph()
: graph(OpRegistry::Global()),
refiner(graph.op_registry()),
num_sessions(0),
delete_requested(false) {}
mutex mu;
Graph graph GUARDED_BY(mu);
// Runs shape inference.
tensorflow::ShapeRefiner refiner GUARDED_BY(mu);
// Maps from name of an operation to the Node* in 'graph'.
std::unordered_map<tensorflow::string, Node*> name_map GUARDED_BY(mu);
// TF_Graph may only / must be deleted when
// num_sessions == 0 && delete_requested == true
// num_sessions incremented by TF_NewSession, and decremented by
// TF_DeleteSession.
int num_sessions GUARDED_BY(mu);
bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph
};
struct TF_OperationDescription {
TF_OperationDescription(TF_Graph* g, const char* op_type,
const char* node_name)
: node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {}
NodeBuilder node_builder;
TF_Graph* graph;
std::vector<tensorflow::string> colocation_constraints;
};
struct TF_Operation {
Node node;
};
struct TF_Session {
TF_Session(Session* s, TF_Graph* g)
: session(s), graph(g), last_num_graph_nodes(0) {}
Session* session;
TF_Graph* graph;
mutex mu;
int last_num_graph_nodes;
};
} // end extern "C"
// Helper functions -----------------------------------------------------------
namespace {
@ -785,15 +771,13 @@ TF_Operation* ToOperation(Node* node) {
}
tensorflow::string OutputName(const TF_Output& output) {
return tensorflow::strings::StrCat(output.oper->node.name(), ":",
output.index);
return StrCat(output.oper->node.name(), ":", output.index);
}
const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
const char* attr_name,
TF_Status* status) {
const tensorflow::AttrValue* attr =
tensorflow::AttrSlice(oper->node.def()).Find(attr_name);
const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name);
if (attr == nullptr) {
status->status =
InvalidArgument("Operation has no attr named '", attr_name, "'.");
@ -821,6 +805,7 @@ void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
}
std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
dim_vec.reserve(num_dims);
for (int i = 0; i < num_dims; ++i) {
dim_vec.push_back(ic->MakeDim(dims[i]));
}
@ -899,10 +884,17 @@ void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims,
extern "C" {
static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
const char* op_type,
const char* oper_name)
EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
return new TF_OperationDescription(graph, op_type, oper_name);
}
TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type,
const char* oper_name) {
mutex_lock l(graph->mu);
return new TF_OperationDescription(graph, op_type, oper_name);
return TF_NewOperationLocked(graph, op_type, oper_name);
}
void TF_SetDevice(TF_OperationDescription* desc, const char* device) {
@ -928,8 +920,8 @@ void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) {
}
void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) {
desc->colocation_constraints.emplace_back(tensorflow::strings::StrCat(
tensorflow::kColocationGroupPrefix, op->node.name()));
desc->colocation_constraints.emplace_back(
StrCat(tensorflow::kColocationGroupPrefix, op->node.name()));
}
void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name,
@ -1131,10 +1123,10 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
}
}
TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
TF_Status* status) {
static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
TF_Status* status)
EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
Node* ret = nullptr;
mutex_lock l(desc->graph->mu);
if (desc->graph->name_map.count(desc->node_builder.node_name())) {
status->status = InvalidArgument("Duplicate node name in graph: '",
@ -1148,14 +1140,14 @@ TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
if (status->status.ok()) {
// Run shape inference function for newly added node.
//
// TODO(b/28152992): Enable returning the result of this
// code-path once we have converted all python shape functions
// to call their C++ versions.
desc->graph->refiner.AddNode(ret);
status->status = desc->graph->refiner.AddNode(ret);
}
if (status->status.ok()) {
// Add the node to the name-to-node mapping.
desc->graph->name_map[ret->name()] = ret;
} else if (ret != nullptr) {
desc->graph->graph.RemoveNode(ret);
ret = nullptr;
}
}
@ -1164,6 +1156,12 @@ TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
return ToOperation(ret);
}
TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
TF_Status* status) {
mutex_lock l(desc->graph->mu);
return TF_FinishOperationLocked(desc, status);
}
// TF_Operation functions
// ----------------------------------------------------------
@ -1176,7 +1174,7 @@ const char* TF_OperationOpType(TF_Operation* oper) {
}
const char* TF_OperationDevice(TF_Operation* oper) {
return oper->node.def().device().c_str();
return oper->node.requested_device().c_str();
}
int TF_OperationNumOutputs(TF_Operation* oper) {
@ -1191,8 +1189,8 @@ TF_DataType TF_OperationOutputType(TF_Output oper_out) {
int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
TF_Status* status) {
NameRangeMap name_ranges;
status->status = NameRangesForNode(oper->node.def(), oper->node.op_def(),
nullptr, &name_ranges);
status->status =
NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges);
if (!status->status.ok()) return -1;
auto iter = name_ranges.find(arg_name);
if (iter == name_ranges.end()) {
@ -1213,8 +1211,8 @@ TF_DataType TF_OperationInputType(TF_Input oper_in) {
int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
TF_Status* status) {
NameRangeMap name_ranges;
status->status = NameRangesForNode(oper->node.def(), oper->node.op_def(),
&name_ranges, nullptr);
status->status =
NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr);
if (!status->status.ok()) return -1;
auto iter = name_ranges.find(arg_name);
if (iter == name_ranges.end()) {
@ -1452,26 +1450,27 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
}
}
#define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \
void func(TF_Operation* oper, const char* attr_name, c_type* value, \
TF_Status* status) { \
cpp_type v; \
status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &v); \
*value = static_cast<c_type>(v); \
} \
void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
int max_values, TF_Status* status) { \
const auto* attr = GetAttrValue(oper, attr_name, status); \
if (!status->status.ok()) return; \
if (attr->value_case() != tensorflow::AttrValue::kList) { \
status->status = \
InvalidArgument("Value for '", attr_name, "' is not a list."); \
return; \
} \
const auto len = std::min(max_values, attr->list().list_field##_size()); \
for (int i = 0; i < len; ++i) { \
values[i] = static_cast<c_type>(attr->list().list_field(i)); \
} \
#define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \
void func(TF_Operation* oper, const char* attr_name, c_type* value, \
TF_Status* status) { \
cpp_type v; \
status->status = \
tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v); \
*value = static_cast<c_type>(v); \
} \
void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
int max_values, TF_Status* status) { \
const auto* attr = GetAttrValue(oper, attr_name, status); \
if (!status->status.ok()) return; \
if (attr->value_case() != tensorflow::AttrValue::kList) { \
status->status = \
InvalidArgument("Value for '", attr_name, "' is not a list."); \
return; \
} \
const auto len = std::min(max_values, attr->list().list_field##_size()); \
for (int i = 0; i < len; ++i) { \
values[i] = static_cast<c_type>(attr->list().list_field(i)); \
} \
}
DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i);
DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f);
@ -1482,7 +1481,8 @@ DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type);
void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
int64_t* value, int num_dims, TF_Status* status) {
PartialTensorShape shape;
status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &shape);
status->status =
tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape);
if (!status->status.ok()) return;
auto len = std::min(shape.dims(), num_dims);
for (int i = 0; i < len; ++i) {
@ -1496,7 +1496,7 @@ void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
int storage_size, TF_Status* status) {
std::vector<PartialTensorShape> shapes;
status->status =
tensorflow::GetNodeAttr(oper->node.def(), attr_name, &shapes);
tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
if (!status->status.ok()) return;
auto len = std::min(static_cast<int>(shapes.size()), max_values);
int64_t* p = storage;
@ -1563,7 +1563,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
TF_Tensor** value, TF_Status* status) {
*value = nullptr;
Tensor t;
status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &t);
status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
if (!status->status.ok()) return;
*value = new TF_Tensor{static_cast<TF_DataType>(t.dtype()), t.shape(),
tensorflow::TensorCApi::Buffer(t)};
@ -1574,7 +1574,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
TF_Tensor** values, int max_values,
TF_Status* status) {
std::vector<Tensor> ts;
status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &ts);
status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts);
if (!status->status.ok()) return;
const auto len = std::min(max_values, static_cast<int>(ts.size()));
for (int i = 0; i < len; ++i) {
@ -1653,10 +1653,6 @@ void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
status->status = MessageToBuffer(def, output_graph_def);
}
struct TF_ImportGraphDefOptions {
tensorflow::ImportGraphDefOptions opts;
};
TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() {
return new TF_ImportGraphDefOptions;
}
@ -1682,6 +1678,12 @@ void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts,
opts->opts.input_map[TensorId(src_name, src_index)] = ToTensorId(dst);
}
void TF_ImportGraphDefOptionsRemapControlDependency(
TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst) {
opts->opts.input_map[TensorId(src_name, tensorflow::Graph::kControlSlot)] =
TensorId(dst->node.name(), tensorflow::Graph::kControlSlot);
}
extern void TF_ImportGraphDefOptionsAddControlDependency(
TF_ImportGraphDefOptions* opts, TF_Operation* oper) {
opts->opts.control_dependencies.push_back(oper->node.name());
@ -1750,6 +1752,398 @@ void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
status);
}
// While loop functions -------------------------------------------------------
namespace {
bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
TF_Output* input, TF_Status* status) {
TF_OperationDescription* desc = TF_NewOperation(g, "Placeholder", name);
TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input));
// TODO(skyewm): set placeholder shape
TF_Operation* oper = TF_FinishOperation(desc, status);
if (!status->status.ok()) return false;
*input = {oper, 0};
return true;
}
bool CreateEnter(TF_Graph* g, const char* node_name, const char* frame_name,
const TF_Output& input, TF_Output* enter, TF_Status* status)
EXCLUSIVE_LOCKS_REQUIRED(g->mu) {
TF_OperationDescription* desc = TF_NewOperationLocked(g, "Enter", node_name);
TF_AddInput(desc, input);
TF_SetAttrString(desc, "frame_name", frame_name, strlen(frame_name));
TF_Operation* oper = TF_FinishOperationLocked(desc, status);
if (!status->status.ok()) return false;
*enter = {oper, 0};
return true;
}
bool CreateMerge(TF_Graph* g, const char* name, const TF_Output& input,
const char* backedge_name, int backedge_index,
TF_Output* merge, TF_Status* status)
EXCLUSIVE_LOCKS_REQUIRED(g->mu) {
TF_OperationDescription* desc = TF_NewOperationLocked(g, "Merge", name);
// The merge nodes accept the while loop's back edges as an input. Use the
// underlying NodeBuilder API directly to create an input to the
// not-yet-created back edge.
std::vector<NodeBuilder::NodeOut> input_list;
input_list.push_back(NodeBuilder::NodeOut(&input.oper->node, input.index));
// All merge inputs must have same type
DataType type = input.oper->node.output_type(input.index);
input_list.push_back(
NodeBuilder::NodeOut(backedge_name, backedge_index, type));
desc->node_builder.Input(input_list);
TF_Operation* oper = TF_FinishOperationLocked(desc, status);
if (!status->status.ok()) return false;
*merge = {oper, 0};
return true;
}
bool CreateSwitch(TF_Graph* g, const char* name, const TF_Output& input,
const TF_Output& predicate, TF_Output* switch_true,
TF_Output* switch_false, TF_Status* status)
EXCLUSIVE_LOCKS_REQUIRED(g->mu) {
TF_OperationDescription* desc = TF_NewOperationLocked(g, "Switch", name);
TF_AddInput(desc, input);
TF_AddInput(desc, predicate);
TF_Operation* oper = TF_FinishOperationLocked(desc, status);
if (!status->status.ok()) return false;
*switch_false = {oper, 0};
*switch_true = {oper, 1};
return true;
}
bool CreateNext(TF_Graph* g, const char* name, const TF_Output& input,
TF_Output* next, TF_Status* status)
EXCLUSIVE_LOCKS_REQUIRED(g->mu) {
TF_OperationDescription* desc =
TF_NewOperationLocked(g, "NextIteration", name);
TF_AddInput(desc, input);
TF_Operation* oper = TF_FinishOperationLocked(desc, status);
if (!status->status.ok()) return false;
*next = {oper, 0};
return true;
}
bool CreateExit(TF_Graph* g, const char* name, const TF_Output& input,
TF_Output* exit, TF_Status* status)
EXCLUSIVE_LOCKS_REQUIRED(g->mu) {
TF_OperationDescription* desc = TF_NewOperationLocked(g, "Exit", name);
TF_AddInput(desc, input);
TF_Operation* oper = TF_FinishOperationLocked(desc, status);
if (!status->status.ok()) return false;
*exit = {oper, 0};
return true;
}
class ScopedImportGraphDefOptions {
public:
ScopedImportGraphDefOptions() { opts_ = TF_NewImportGraphDefOptions(); }
~ScopedImportGraphDefOptions() { TF_DeleteImportGraphDefOptions(opts_); }
TF_ImportGraphDefOptions* get() const { return opts_; }
private:
TF_ImportGraphDefOptions* opts_;
TF_DISALLOW_COPY_AND_ASSIGN(ScopedImportGraphDefOptions);
};
// Copies `src_graph` into `dst_graph`. Any node in `src_graph` with input
// `src_inputs[i]` will have that input replaced with `dst_inputs[i]`.
// `prefix` will be prepended to copied node names. `return_nodes` are nodes
// in `src_graph`, and the new corresponding nodes in `dst_graph` will be
// returned. `return_nodes` should be preallocated to size `nreturn_nodes`.
bool CopyGraph(TF_Graph* src_graph, TF_Graph* dst_graph,
const TF_Output* src_inputs,
const std::vector<TF_Output>& dst_inputs, const char* prefix,
const TF_Output* nodes_to_return, int nreturn_nodes,
TF_Output* return_nodes, TF_Status* s)
EXCLUSIVE_LOCKS_REQUIRED(dst_graph->mu) {
GraphDef gdef;
src_graph->graph.ToGraphDef(&gdef);
ScopedImportGraphDefOptions opts;
TF_ImportGraphDefOptionsSetPrefix(opts.get(), prefix);
for (int i = 0; i < dst_inputs.size(); ++i) {
TensorId src = ToTensorId(src_inputs[i]);
TF_ImportGraphDefOptionsAddInputMapping(opts.get(), src.first.data(),
src.second, dst_inputs[i]);
}
// We use the pivot node to control constants in `src_graph`
TF_Operation* pivot = dst_inputs[0].oper;
TF_ImportGraphDefOptionsAddControlDependency(opts.get(), pivot);
for (int i = 0; i < nreturn_nodes; ++i) {
TF_ImportGraphDefOptionsAddReturnOutput(
opts.get(), nodes_to_return[i].oper->node.name().c_str(),
nodes_to_return[i].index);
}
GraphImportGraphDefLocked(dst_graph, gdef, opts.get(), return_nodes,
nreturn_nodes, s);
if (TF_GetCode(s) != TF_OK) return false;
return true;
}
bool ValidateConstWhileParams(const TF_WhileParams& params, TF_Status* s) {
if (params.cond_graph == nullptr || params.body_graph == nullptr ||
params.cond_graph->parent == nullptr ||
params.cond_graph->parent != params.body_graph->parent ||
params.cond_graph->parent_inputs != params.body_graph->parent_inputs ||
params.ninputs <= 0 || params.cond_inputs == nullptr ||
params.body_inputs == nullptr || params.body_outputs == nullptr) {
s->status = InvalidArgument(
"TF_WhileParams must be created by successful TF_NewWhile() call");
return false;
}
return true;
}
bool ValidateInputWhileParams(const TF_WhileParams& params, TF_Status* s) {
if (params.cond_output.oper == nullptr) {
s->status = InvalidArgument("TF_WhileParams `cond_output` field isn't set");
return false;
}
for (int i = 0; i < params.ninputs; ++i) {
if (params.body_outputs[i].oper == nullptr) {
s->status = InvalidArgument("TF_WhileParams `body_outputs[", i, "]` ",
"field isn't set");
return false;
}
}
if (params.name == nullptr) {
s->status = InvalidArgument("TF_WhileParams `name` field is null");
return false;
}
return true;
}
void FreeWhileResources(const TF_WhileParams* params) {
TF_DeleteGraph(params->cond_graph);
TF_DeleteGraph(params->body_graph);
delete[] params->cond_inputs;
delete[] params->body_inputs;
delete[] params->body_outputs;
}
TF_WhileParams EmptyWhileParams() {
return {0, nullptr, nullptr, {nullptr, 0},
nullptr, nullptr, nullptr, nullptr};
}
} // namespace
TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
TF_Status* status) {
if (ninputs == 0) {
status->status =
InvalidArgument("TF_NewWhile() must be passed at least one input");
return EmptyWhileParams();
}
TF_Graph* cond_graph = TF_NewGraph();
TF_Graph* body_graph = TF_NewGraph();
cond_graph->parent = g;
cond_graph->parent_inputs = inputs;
body_graph->parent = g;
body_graph->parent_inputs = inputs;
TF_Output* cond_inputs = new TF_Output[ninputs];
TF_Output cond_output = {nullptr, -1};
TF_Output* body_inputs = new TF_Output[ninputs];
TF_Output* body_outputs = new TF_Output[ninputs];
for (int i = 0; i < ninputs; ++i) body_outputs[i] = {nullptr, -1};
const char* name = nullptr;
for (int i = 0; i < ninputs; ++i) {
// TODO(skyewm): prefix names with underscore (requires some plumbing)
if (!CreateInput(inputs[i], cond_graph, StrCat("cond_input", i).c_str(),
&cond_inputs[i], status)) {
break;
}
if (!CreateInput(inputs[i], body_graph, StrCat("body_input", i).c_str(),
&body_inputs[i], status)) {
break;
}
}
TF_WhileParams params = {ninputs, cond_graph, cond_inputs, cond_output,
body_graph, body_inputs, body_outputs, name};
if (!status->status.ok()) {
FreeWhileResources(&params);
return EmptyWhileParams();
}
return params;
}
namespace {
// TODO(skyewm): make nodes in while loop unfetchable like in Python version
void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status,
TF_Output* outputs) {
if (!ValidateInputWhileParams(*params, status)) return;
TF_Graph* parent = params->cond_graph->parent;
TF_Output* parent_inputs = params->cond_graph->parent_inputs;
int n = params->ninputs;
mutex_lock l(parent->mu);
// Create Enter nodes
std::vector<TF_Output> enter_nodes(n);
for (int i = 0; i < n; ++i) {
if (!CreateEnter(parent, StrCat(params->name, "/enter", i).c_str(),
params->name, parent_inputs[i], &enter_nodes[i], status)) {
return;
}
}
// Create Merge nodes
std::vector<TF_Output> merge_nodes(n);
for (int i = 0; i < n; ++i) {
if (!CreateMerge(parent, StrCat(params->name, "/merge", i).c_str(),
enter_nodes[i], StrCat(params->name, "/next", i).c_str(),
0, &merge_nodes[i], status)) {
return;
}
}
// Copy cond_graph to parent and replace input placeholders with merge node
// outputs, and get handle to new cond output
tensorflow::string cond_prefix = StrCat(params->name, "/cond");
TF_Output cond_output;
if (!CopyGraph(params->cond_graph, parent, params->cond_inputs, merge_nodes,
cond_prefix.c_str(), &params->cond_output, 1, &cond_output,
status)) {
return;
}
// Create Switch nodes
std::vector<TF_Output> switch_trues(n);
std::vector<TF_Output> switch_falses(n);
for (int i = 0; i < n; ++i) {
if (!CreateSwitch(parent, StrCat(params->name, "/switch", i).c_str(),
merge_nodes[i], cond_output, &switch_trues[i],
&switch_falses[i], status)) {
return;
}
}
// Copy body_graph to parent, replace input placeholders with switch node
// true outputs, and get handles to new body outputs
tensorflow::string body_prefix = StrCat(params->name, "/body");
std::vector<TF_Output> body_outputs(n);
if (!CopyGraph(params->body_graph, parent, params->body_inputs, switch_trues,
body_prefix.c_str(), params->body_outputs, n,
body_outputs.data(), status)) {
return;
}
// Create Next nodes
std::vector<TF_Output> next_nodes(n);
for (int i = 0; i < n; ++i) {
if (!CreateNext(parent, StrCat(params->name, "/next", i).c_str(),
body_outputs[i], &next_nodes[i], status)) {
return;
}
}
// Create Exit nodes (which are the outputs of the while loop)
for (int i = 0; i < n; ++i) {
if (!CreateExit(parent, StrCat(params->name, "/exit", i).c_str(),
switch_falses[i], &outputs[i], status)) {
return;
}
}
}
} // namespace
void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
TF_Output* outputs) {
// If it appears the caller created or modified `params`, don't free resources
if (!ValidateConstWhileParams(*params, status)) return;
TF_FinishWhileHelper(params, status, outputs);
FreeWhileResources(params);
}
void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
#ifndef __ANDROID__
namespace {
void OutputsFromTFOutputs(TF_Output* tf_outputs, int n, TF_Status* status,
std::vector<tensorflow::Output>* outputs) {
outputs->resize(n);
for (int i = 0; i < n; i++) {
const TF_Output& tf_output = tf_outputs[i];
(*outputs)[i] = tensorflow::Output(&tf_output.oper->node, tf_output.index);
}
}
void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
TF_Output* tf_outputs) {
for (int i = 0; i < outputs.size(); i++) {
tf_outputs[i].oper = ToOperation(outputs[i].node());
tf_outputs[i].index = outputs[i].index();
}
}
} // namespace
#endif // __ANDROID__
void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
TF_Output* dx, TF_Status* status, TF_Output* dy) {
#ifdef __ANDROID__
status->status = tensorflow::errors::Unimplemented(
"Adding gradients is not supported in Android. File a bug at "
"https://github.com/tensorflow/tensorflow/issues if this feature is "
"important to you");
#else
std::vector<tensorflow::Output> y_arg;
std::vector<tensorflow::Output> x_arg;
std::vector<tensorflow::Output> dy_arg;
OutputsFromTFOutputs(y, ny, status, &y_arg);
OutputsFromTFOutputs(x, nx, status, &x_arg);
{
// We need to hold on to the lock while we have a scope that uses TF_Graph.
mutex_lock graph_lock(g->mu);
const int max_node_id_before = g->graph.num_node_ids();
tensorflow::Scope scope =
NewInternalScope(&g->graph, &status->status, &g->refiner);
if (dx != nullptr) {
std::vector<tensorflow::Output> dx_arg;
OutputsFromTFOutputs(dx, ny, status, &dx_arg);
status->status =
AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg);
} else {
status->status = AddSymbolicGradients(scope, y_arg, x_arg, &dy_arg);
}
// Update g->name_map with the name_map from the scope, which will contain
// the new gradient ops.
for (int i = max_node_id_before; i < g->graph.num_node_ids(); ++i) {
Node* n = g->graph.FindNodeId(i);
if (n == nullptr) continue;
g->name_map[n->name()] = n;
}
}
// Unpack the results from grad_outputs_arg.
TFOutputsFromOutputs(dy_arg, dy);
#endif // __ANDROID__
}
// TF_Session functions ----------------------------------------------
TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
@ -1764,15 +2158,23 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
return new TF_Session(session, graph);
} else {
DCHECK_EQ(nullptr, session);
return NULL;
return nullptr;
}
}
#ifndef __ANDROID__
TF_Session* TF_LoadSessionFromSavedModel(
const TF_SessionOptions* session_options, const TF_Buffer* run_options,
const char* export_dir, const char* const* tags, int tags_len,
TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status) {
// TODO(ashankar): Remove the __ANDROID__ guard. This will require ensuring that
// the tensorflow/cc/saved_model:loader build target is Android friendly.
#ifdef __ANDROID__
status->status = tensorflow::errors::Unimplemented(
"Loading a SavedModel is not supported in Android. File a bug at "
"https://github.com/tensorflow/tensorflow/issues if this feature is "
"important to you");
return nullptr;
#else
mutex_lock l(graph->mu);
if (!graph->name_map.empty()) {
@ -1781,9 +2183,8 @@ TF_Session* TF_LoadSessionFromSavedModel(
}
RunOptions run_options_proto;
if (run_options != nullptr &&
!run_options_proto.ParseFromArray(run_options->data,
run_options->length)) {
if (run_options != nullptr && !run_options_proto.ParseFromArray(
run_options->data, run_options->length)) {
status->status = InvalidArgument("Unparseable RunOptions proto");
return nullptr;
}
@ -1821,8 +2222,8 @@ TF_Session* TF_LoadSessionFromSavedModel(
graph->num_sessions += 1;
session->last_num_graph_nodes = graph->graph.num_node_ids();
return session;
}
#endif // __ANDROID__
}
void TF_CloseSession(TF_Session* s, TF_Status* status) {
status->status = s->session->Close();
@ -1853,7 +2254,7 @@ static bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
const auto num_nodes = graph.num_node_ids();
if (session->last_num_graph_nodes < num_nodes) {
GraphDef graph_def;
graph_def.mutable_versions()->CopyFrom(graph.versions());
*graph_def.mutable_versions() = graph.versions();
// Fill graph_def with nodes with ids in the range
// [session->last_num_graph_nodes, num_nodes), that is the nodes
// added since the last TF_SessionRun() call.
@ -1954,6 +2355,11 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
}
}
void TF_DeletePRunHandle(const char* handle) {
delete[] handle;
// TODO(suharshs): Free up any resources held by the partial run state.
}
void TF_SessionPRun(TF_Session* session, const char* handle,
const TF_Output* inputs, TF_Tensor* const* input_values,
int ninputs, const TF_Output* outputs,

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,119 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/c_api.h"
#include <vector>
#include <unordered_map>
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
// Internal structures used by the C API. These are likely to change and should
// not be depended on.
struct TF_Status {
tensorflow::Status status;
};
struct TF_Tensor {
TF_DataType dtype;
tensorflow::TensorShape shape;
tensorflow::TensorBuffer* buffer;
};
struct TF_SessionOptions {
tensorflow::SessionOptions options;
};
struct TF_DeprecatedSession {
tensorflow::Session* session;
};
struct TF_Library {
void* lib_handle;
TF_Buffer op_list;
};
struct TF_Graph {
TF_Graph()
: graph(tensorflow::OpRegistry::Global()),
refiner(graph.versions().producer(), graph.op_registry()),
num_sessions(0),
delete_requested(false),
parent(nullptr),
parent_inputs(nullptr) {}
tensorflow::mutex mu;
tensorflow::Graph graph GUARDED_BY(mu);
// Runs shape inference.
tensorflow::ShapeRefiner refiner GUARDED_BY(mu);
// Maps from name of an operation to the Node* in 'graph'.
std::unordered_map<tensorflow::string, tensorflow::Node*> name_map
GUARDED_BY(mu);
// TF_Graph may only / must be deleted when
// num_sessions == 0 && delete_requested == true
// num_sessions incremented by TF_NewSession, and decremented by
// TF_DeleteSession.
int num_sessions GUARDED_BY(mu);
bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph
// Used to link graphs contained in TF_WhileParams to the parent graph that
// will eventually contain the full while loop.
TF_Graph* parent;
TF_Output* parent_inputs;
};
struct TF_OperationDescription {
TF_OperationDescription(TF_Graph* g, const char* op_type,
const char* node_name)
: node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {}
tensorflow::NodeBuilder node_builder;
TF_Graph* graph;
std::vector<tensorflow::string> colocation_constraints;
};
struct TF_Operation {
tensorflow::Node node;
};
struct TF_Session {
TF_Session(tensorflow::Session* s, TF_Graph* g)
: session(s), graph(g), last_num_graph_nodes(0) {}
tensorflow::Session* session;
TF_Graph* graph;
tensorflow::mutex mu;
int last_num_graph_nodes;
};
struct TF_ImportGraphDefOptions {
tensorflow::ImportGraphDefOptions opts;
};
struct TF_DeviceList {
std::vector<tensorflow::DeviceAttributes> response;
};

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/graph.pb_text.h"
#include "tensorflow/core/framework/node_def.pb_text.h"
#include "tensorflow/core/framework/node_def_util.h"
@ -38,6 +39,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/util/equal_graph_def.h"
using tensorflow::int32;
using tensorflow::string;
@ -105,6 +107,22 @@ TEST(CAPI, AllocateTensor) {
TF_DeleteTensor(t);
}
TEST(CAPI, MaybeMove) {
const int num_bytes = 6 * sizeof(float);
float* values =
reinterpret_cast<float*>(tensorflow::cpu_allocator()->AllocateRaw(
EIGEN_MAX_ALIGN_BYTES, num_bytes));
int64_t dims[] = {2, 3};
bool deallocator_called = false;
TF_Tensor* t = TF_NewTensor(TF_FLOAT, dims, 2, values, num_bytes,
&Deallocator, &deallocator_called);
TF_Tensor* o = TF_TensorMaybeMove(t);
ASSERT_TRUE(o == nullptr); // It is unsafe to move memory TF might not own.
TF_DeleteTensor(t);
EXPECT_TRUE(deallocator_called);
}
TEST(CAPI, LibraryLoadFunctions) {
// Load the library.
TF_Status* status = TF_NewStatus();
@ -261,6 +279,19 @@ static void Int32Deallocator(void* data, size_t, void* arg) {
delete[] static_cast<int32*>(data);
}
// Create a tensor with values of type TF_INT8 provided by `values`.
static TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims,
const char* values) {
int64_t num_values = 1;
for (int i = 0; i < num_dims; ++i) {
num_values *= dims[i];
}
TF_Tensor* t =
TF_AllocateTensor(TF_INT8, dims, num_dims, sizeof(char) * num_values);
memcpy(TF_TensorData(t), values, sizeof(char) * num_values);
return t;
}
static TF_Tensor* Int32Tensor(int32 v) {
const int num_bytes = sizeof(int32);
int32* values = new int32[1];
@ -269,29 +300,44 @@ static TF_Tensor* Int32Tensor(int32 v) {
&Int32Deallocator, nullptr);
}
TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", "feed");
TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s,
const char* name = "feed") {
TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
TF_SetAttrType(desc, "dtype", TF_INT32);
return TF_FinishOperation(desc, s);
}
TF_Operation* ScalarConst(int32 v, TF_Graph* graph, TF_Status* s) {
unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor);
TF_OperationDescription* desc = TF_NewOperation(graph, "Const", "scalar");
TF_SetAttrTensor(desc, "value", tensor.get(), s);
TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
const char* name = "const") {
TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
TF_SetAttrTensor(desc, "value", t, s);
if (TF_GetCode(s) != TF_OK) return nullptr;
TF_SetAttrType(desc, "dtype", TF_INT32);
TF_SetAttrType(desc, "dtype", TF_TensorType(t));
return TF_FinishOperation(desc, s);
}
TF_Operation* ScalarConst(int32 v, TF_Graph* graph, TF_Status* s,
const char* name = "scalar") {
unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor);
return Const(tensor.get(), graph, s, name);
}
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s) {
TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", "add");
TF_Status* s, const char* name = "add") {
TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
TF_AddInputList(desc, add_inputs, 2);
return TF_FinishOperation(desc, s);
}
TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
const char* name = "add") {
TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
TF_Output inputs[2] = {l, r};
TF_AddInputList(desc, inputs, 2);
return TF_FinishOperation(desc, s);
}
TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Neg", "neg");
TF_Output neg_input = {n, 0};
@ -299,6 +345,14 @@ TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) {
return TF_FinishOperation(desc, s);
}
TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph,
TF_Status* s) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Less", "less_than");
TF_AddInput(desc, l);
TF_AddInput(desc, r);
return TF_FinishOperation(desc, s);
}
bool IsPlaceholder(const NodeDef& node_def) {
if (node_def.op() != "Placeholder" || node_def.name() != "feed") {
return false;
@ -667,6 +721,28 @@ TEST(CAPI, Graph) {
TF_DeleteStatus(s);
}
/*
TODO(skyewm): this test currently DCHECKs, change to bad status
TEST(CAPI, InputFromDifferentGraphError) {
TF_Status* s = TF_NewStatus();
TF_Graph* g1 = TF_NewGraph();
TF_Graph* g2 = TF_NewGraph();
TF_Operation* feed = Placeholder(g1, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Attempt to create node in g2 with input from g1
Neg(feed, g2, s);
EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s));
EXPECT_STREQ("foo", TF_Message(s));
TF_DeleteGraph(g1);
TF_DeleteGraph(g2);
TF_DeleteStatus(s);
}
*/
TEST(CAPI, ImportGraphDef) {
TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();
@ -765,6 +841,33 @@ TEST(CAPI, ImportGraphDef) {
EXPECT_EQ(feed, control_inputs[0]);
EXPECT_EQ(feed2, control_inputs[1]);
// Export to a graph def so we can import a graph with control dependencies
TF_DeleteBuffer(graph_def);
graph_def = TF_NewBuffer();
TF_GraphToGraphDef(graph, graph_def, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Import again, with remapped control dependency, into the same graph
TF_DeleteImportGraphDefOptions(opts);
opts = TF_NewImportGraphDefOptions();
TF_ImportGraphDefOptionsSetPrefix(opts, "imported4");
TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed);
TF_GraphImportGraphDef(graph, graph_def, opts, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* scalar4 =
TF_GraphOperationByName(graph, "imported4/imported3/scalar");
TF_Operation* feed4 =
TF_GraphOperationByName(graph, "imported4/imported2/feed");
// Check that imported `imported3/scalar` has remapped control dep from
// original graph and imported control dep
num_control_inputs = TF_OperationGetControlInputs(
scalar4, control_inputs, TF_OperationNumControlInputs(scalar4));
ASSERT_EQ(2, num_control_inputs);
EXPECT_EQ(feed, control_inputs[0]);
EXPECT_EQ(feed4, control_inputs[1]);
TF_DeleteImportGraphDefOptions(opts);
TF_DeleteBuffer(graph_def);
@ -784,7 +887,7 @@ class CSession {
TF_DeleteSessionOptions(opts);
}
CSession(TF_Session* session) { session_ = session; }
explicit CSession(TF_Session* session) : session_(session) {}
~CSession() {
TF_Status* s = TF_NewStatus();
@ -793,8 +896,7 @@ class CSession {
TF_DeleteStatus(s);
}
void SetInputs(
std::initializer_list<std::pair<TF_Operation*, TF_Tensor*>> inputs) {
void SetInputs(std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs) {
DeleteInputValues();
inputs_.clear();
for (const auto& p : inputs) {
@ -811,6 +913,11 @@ class CSession {
}
}
void SetOutputs(const std::vector<TF_Output>& outputs) {
ResetOutputValues();
outputs_ = outputs;
}
void SetTargets(std::initializer_list<TF_Operation*> targets) {
targets_.clear();
for (TF_Operation* t : targets) {
@ -937,6 +1044,103 @@ TEST(CAPI, Session) {
TF_DeleteStatus(s);
}
TEST(CAPI, SessionPRun) {
TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();
// Construct the graph: A + 2 + B
TF_Operation* a = Placeholder(graph, s, "A");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* b = Placeholder(graph, s, "B");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* two = ScalarConst(2, graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* plus2 = Add(a, two, graph, s, "plus2");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* plusB = Add(plus2, b, graph, s, "plusB");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Setup a session and a partial run handle. The partial run will allow
// computation of A + 2 + B in two phases (calls to TF_SessionPRun):
// 1. Feed A and get (A+2)
// 2. Feed B and get (A+2)+B
TF_SessionOptions* opts = TF_NewSessionOptions();
TF_Session* sess = TF_NewSession(graph, opts, s);
TF_DeleteSessionOptions(opts);
TF_Output feeds[] = {TF_Output{a, 0}, TF_Output{b, 0}};
TF_Output fetches[] = {TF_Output{plus2, 0}, TF_Output{plusB, 0}};
const char* handle = nullptr;
TF_SessionPRunSetup(sess, feeds, TF_ARRAYSIZE(feeds), fetches,
TF_ARRAYSIZE(fetches), nullptr, 0, &handle, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Feed A and fetch A + 2.
TF_Output feeds1[] = {TF_Output{a, 0}};
TF_Output fetches1[] = {TF_Output{plus2, 0}};
TF_Tensor* feedValues1[] = {Int32Tensor(1)};
TF_Tensor* fetchValues1[1];
TF_SessionPRun(sess, handle, feeds1, feedValues1, 1, fetches1, fetchValues1,
1, nullptr, 0, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
EXPECT_EQ(3, *(static_cast<int32*>(TF_TensorData(fetchValues1[0]))));
TF_DeleteTensor(feedValues1[0]);
TF_DeleteTensor(fetchValues1[0]);
// Feed B and fetch (A + 2) + B.
TF_Output feeds2[] = {TF_Output{b, 0}};
TF_Output fetches2[] = {TF_Output{plusB, 0}};
TF_Tensor* feedValues2[] = {Int32Tensor(4)};
TF_Tensor* fetchValues2[1];
TF_SessionPRun(sess, handle, feeds2, feedValues2, 1, fetches2, fetchValues2,
1, nullptr, 0, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
EXPECT_EQ(7, *(static_cast<int32*>(TF_TensorData(fetchValues2[0]))));
TF_DeleteTensor(feedValues2[0]);
TF_DeleteTensor(fetchValues2[0]);
// Clean up.
TF_DeletePRunHandle(handle);
TF_DeleteSession(sess, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteGraph(graph);
TF_DeleteStatus(s);
}
TEST(CAPI, ShapeInferenceError) {
// TF_FinishOperation should fail if the shape of the added operation cannot
// be inferred.
TF_Status* status = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();
// Create this failure by trying to add two nodes with incompatible shapes
// (A tensor with shape [2] and a tensor with shape [3] cannot be added).
const char data[] = {1, 2, 3};
const int64_t vec2_dims[] = {2};
unique_tensor_ptr vec2_tensor(
Int8Tensor(vec2_dims, TF_ARRAYSIZE(vec2_dims), data), TF_DeleteTensor);
TF_Operation* vec2 = Const(vec2_tensor.get(), graph, status, "vec2");
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const int64_t vec3_dims[] = {3};
unique_tensor_ptr vec3_tensor(
Int8Tensor(vec3_dims, TF_ARRAYSIZE(vec3_dims), data), TF_DeleteTensor);
TF_Operation* vec3 = Const(vec3_tensor.get(), graph, status, "vec3");
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_Operation* add = Add(vec2, vec3, graph, status);
ASSERT_NE(TF_OK, TF_GetCode(status));
ASSERT_TRUE(add == nullptr);
TF_DeleteGraph(graph);
TF_DeleteStatus(status);
}
TEST(CAPI, ColocateWith) {
TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();
@ -1068,16 +1272,582 @@ TEST(CAPI, SavedModelNullArgsAreValid) {
TF_DeleteStatus(s);
}
// Create a tensor with values of type TF_INT8 provided by `values`.
TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) {
int64_t num_values = 1;
for (int i = 0; i < num_dims; ++i) {
num_values *= dims[i];
class CApiWhileLoopTest : public ::testing::Test {
protected:
CApiWhileLoopTest() : s_(TF_NewStatus()), graph_(TF_NewGraph()) {}
~CApiWhileLoopTest() override {
TF_DeleteGraph(graph_);
TF_DeleteStatus(s_);
}
TF_Tensor* t =
TF_AllocateTensor(TF_INT8, dims, num_dims, sizeof(char) * num_values);
memcpy(TF_TensorData(t), values, sizeof(char) * num_values);
return t;
void Init(int ninputs) {
DCHECK(inputs_.empty());
DCHECK_GT(ninputs, 0);
for (int i = 0; i < ninputs; ++i) {
TF_Operation* placeholder = Placeholder(
graph_, s_, ::tensorflow::strings::StrCat("p", i).c_str());
DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
inputs_.push_back({placeholder, 0});
}
original_graph_description_ = GraphDebugString();
params_.reset(new TF_WhileParams(
TF_NewWhile(graph_, &inputs_[0], inputs_.size(), s_)));
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
ASSERT_EQ(original_graph_description_, GraphDebugString())
<< "TF_NewWhile() altered graph";
params_->name = "test_loop";
// Initialize outputs_ so we can easily detect errors/bugs
outputs_.resize(ninputs, {nullptr, -1});
}
void ExpectOK() {
TF_FinishWhile(params_.get(), s_, &outputs_[0]);
EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
}
void ExpectError(TF_Code expected_code, const string& expected_msg) {
TF_FinishWhile(params_.get(), s_, &outputs_[0]);
EXPECT_EQ(expected_code, TF_GetCode(s_));
EXPECT_EQ(expected_msg, TF_Message(s_));
// TODO(skyewm): this assert is currently broken. Fix or remove guarantee.
// ASSERT_EQ(original_graph_description_, GraphDebugString()) <<
// "TF_FinishWhile() altered graph on error";
}
void Run(std::initializer_list<int> input_values) {
DCHECK_EQ(inputs_.size(), input_values.size());
std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs(inputs_.size());
int i = 0;
for (int v : input_values) {
inputs[i] = {inputs_[i].oper, Int32Tensor(v)};
++i;
}
csession_.reset(new CSession(graph_, s_));
csession_->SetInputs(inputs);
csession_->SetOutputs(outputs_);
csession_->Run(s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
}
void ExpectOutputValue(int idx, int expected_value) {
TF_Tensor* out = csession_->output_tensor(idx);
ASSERT_TRUE(out != nullptr);
EXPECT_EQ(TF_INT32, TF_TensorType(out));
EXPECT_EQ(0, TF_NumDims(out));
ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
int32* data = static_cast<int32*>(TF_TensorData(out));
EXPECT_EQ(expected_value, *data);
}
// Create a valid conditional graph. Useful for testing unrelated errors.
void CreateCondGraph() {
TF_Operation* one = ScalarConst(1, params_->cond_graph, s_);
TF_Operation* less_than =
LessThan(params_->cond_inputs[0], {one, 0}, params_->cond_graph, s_);
DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
params_->cond_output = {less_than, 0};
}
string GraphDebugString() const {
TF_Buffer* buf = TF_NewBuffer();
TF_GraphToGraphDef(graph_, buf, s_);
DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
GraphDef def;
bool success = def.ParseFromArray(buf->data, buf->length);
DCHECK(success);
TF_DeleteBuffer(buf);
return def.DebugString();
}
TF_Status* s_;
TF_Graph* graph_;
std::vector<TF_Output> inputs_; // The inputs to the while loop
std::vector<TF_Output> outputs_; // The final outputs of the while loop
std::unique_ptr<TF_WhileParams> params_;
std::unique_ptr<CSession> csession_;
private:
// Used to verify that errors don't change graph_
string original_graph_description_;
};
TEST_F(CApiWhileLoopTest, BasicLoop) {
Init(2);
// Validate TF_WhileParams returned by TF_NewWhile()
EXPECT_TRUE(params_->body_graph != nullptr);
EXPECT_TRUE(params_->cond_graph != nullptr);
EXPECT_EQ(params_->ninputs, 2);
ASSERT_TRUE(params_->cond_inputs != nullptr);
ASSERT_TRUE(params_->cond_inputs[0].oper != nullptr);
EXPECT_TRUE(params_->cond_inputs[1].oper != nullptr);
ASSERT_TRUE(params_->body_inputs != nullptr);
EXPECT_TRUE(params_->body_inputs[0].oper != nullptr);
EXPECT_TRUE(params_->body_inputs[1].oper != nullptr);
ASSERT_TRUE(params_->body_outputs != nullptr);
// Create loop: while (input1 < input2) input1 += input2 + 1
TF_Operation* less_than =
LessThan(params_->cond_inputs[0], params_->cond_inputs[1],
params_->cond_graph, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
params_->cond_output = {less_than, 0};
TF_Operation* add1 = Add(params_->body_inputs[0], params_->body_inputs[1],
params_->body_graph, s_, "add1");
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_Operation* one = ScalarConst(1, params_->body_graph, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_Operation* add2 = Add(add1, one, params_->body_graph, s_, "add2");
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
params_->body_outputs[0] = {add2, 0};
params_->body_outputs[1] = params_->body_inputs[1];
// Finalize while loop
ExpectOK();
// Validate while loop outputs returned by TF_FinishWhile()
EXPECT_TRUE(outputs_[0].oper != nullptr);
EXPECT_GE(outputs_[0].index, 0);
EXPECT_TRUE(outputs_[1].oper != nullptr);
EXPECT_GE(outputs_[1].index, 0);
// Run the graph
Run({-9, 2});
ExpectOutputValue(0, 3);
ExpectOutputValue(1, 2);
}
TEST_F(CApiWhileLoopTest, NestedLoop) {
Init(2);
// Create nested loop:
// while (input1 < 6) {
// inner_input1 = input1
// while (inner_input1 < 3) {
// input2 += 1
// inner_input1 += 2
// }
// input1 += input2
// }
//
// Expected execution with initial values input1 = input2 = 0:
//
// outer inner inner_
// step# step# input1 input2 input1
// ------------------------------------
// 0 0 0 0 0
// 0 1 0 1 2
// 0 2 0 2 4
// 0 - 2 2 -
// 1 0 2 2 2
// 1 1 2 3 4
// 1 - 5 3 -
// 2 0 5 3 5
// 2 - 8 3 -
// Create outer cond graph
TF_Operation* six = ScalarConst(6, params_->cond_graph, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_Operation* less_than =
LessThan(params_->cond_inputs[0], {six, 0}, params_->cond_graph, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
params_->cond_output = {less_than, 0};
// Create outer body graph
// Init inner graph
TF_Output inner_inputs[] = {params_->body_inputs[0], params_->body_inputs[1]};
TF_WhileParams inner_params =
TF_NewWhile(params_->body_graph, inner_inputs, 2, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
inner_params.name = "inner_loop";
// Create inner cond graph
TF_Operation* three = ScalarConst(3, inner_params.cond_graph, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_Operation* inner_less_than = LessThan(
inner_params.cond_inputs[0], {three, 0}, inner_params.cond_graph, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
inner_params.cond_output = {inner_less_than, 0};
// Create inner body graph
TF_Operation* one = ScalarConst(1, inner_params.body_graph, s_, "one");
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_Operation* two = ScalarConst(2, inner_params.body_graph, s_, "two");
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_Operation* input2_add =
Add(inner_params.body_inputs[1].oper, one, inner_params.body_graph, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
inner_params.body_outputs[1] = {input2_add, 0};
TF_Operation* inner_input1_add = Add(inner_params.body_inputs[0].oper, two,
inner_params.body_graph, s_, "add2");
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
inner_params.body_outputs[0] = {inner_input1_add, 0};
// Finalize inner graph
TF_Output inner_outputs[2] = {{nullptr, -1}};
TF_FinishWhile(&inner_params, s_, inner_outputs);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_Operation* input1_add =
Add(params_->body_inputs[0], inner_outputs[1], params_->body_graph, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
params_->body_outputs[0] = {input1_add, 0};
params_->body_outputs[1] = inner_outputs[1];
// Finalize outer graph
ExpectOK();
// Check for a few expected nodes
const char* node_name = "test_loop/cond/scalar";
EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
node_name = "test_loop/body/add";
EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
node_name = "test_loop/body/inner_loop/body/one";
EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
node_name = "test_loop/body/inner_loop/cond/less_than";
EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
// Run the graph
Run({0, 0});
ExpectOutputValue(0, 8);
ExpectOutputValue(1, 3);
}
TEST_F(CApiWhileLoopTest, BadCondOutput) {
Init(1);
params_->body_outputs[0] = params_->body_inputs[0];
ExpectError(TF_INVALID_ARGUMENT,
"TF_WhileParams `cond_output` field isn't set");
}
TEST_F(CApiWhileLoopTest, BadBodyOutput) {
Init(1);
CreateCondGraph();
ExpectError(TF_INVALID_ARGUMENT,
"TF_WhileParams `body_outputs[0]` field isn't set");
}
TEST_F(CApiWhileLoopTest, NullName) {
Init(1);
CreateCondGraph();
params_->body_outputs[0] = params_->body_inputs[0];
params_->name = nullptr;
ExpectError(TF_INVALID_ARGUMENT, "TF_WhileParams `name` field is null");
}
TEST_F(CApiWhileLoopTest, WrongGraph) {
Init(1);
CreateCondGraph();
// Set body output to output from outer graph
params_->body_outputs[0] = inputs_[0];
// TODO(skyewm): improve error message
ExpectError(TF_INVALID_ARGUMENT,
"Requested return node 'p0' not found in graph def");
}
TEST_F(CApiWhileLoopTest, BadTypes) {
Init(1);
CreateCondGraph();
// Op that has a float input + output
TF_OperationDescription* desc = TF_NewOperation(
params_->body_graph, "FakeQuantWithMinMaxArgs", "float_op");
TF_AddInput(desc, params_->body_inputs[0]);
TF_FinishOperation(desc, s_);
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
string msg(TF_Message(s_));
EXPECT_NE(msg.find("Input 'inputs' passed int32 expected float while "
"building NodeDef 'float_op'"),
msg.npos);
TF_AbortWhile(params_.get());
}
REGISTER_OP("TestOpWithNoGradient")
.Input("x: T")
.Output("y: T")
.Attr("T: {float, double}")
.Doc(R"doc(
Test op with no grad registered.
x: input
y: output
)doc")
.SetShapeFn(tensorflow::shape_inference::UnknownShape);
class CApiGradientsTest : public ::testing::Test {
protected:
CApiGradientsTest()
: s_(TF_NewStatus()),
graph_(TF_NewGraph()),
expected_graph_(TF_NewGraph()) {}
~CApiGradientsTest() override {
TF_DeleteGraph(graph_);
TF_DeleteGraph(expected_graph_);
TF_DeleteStatus(s_);
}
void TestGradientsSuccess(bool grad_inputs_provided) {
TF_Output inputs[2];
TF_Output outputs[1];
TF_Output grad_outputs[2];
TF_Output expected_grad_outputs[2];
BuildSuccessGraph(inputs, outputs);
BuildExpectedGraph(grad_inputs_provided, expected_grad_outputs);
AddGradients(grad_inputs_provided, inputs, 2, outputs, 1, grad_outputs);
EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
// Compare that the graphs match.
GraphDef expected_gdef;
GraphDef gdef;
EXPECT_TRUE(GetGraphDef(expected_graph_, &expected_gdef));
EXPECT_TRUE(GetGraphDef(graph_, &gdef));
TF_EXPECT_GRAPH_EQ(expected_gdef, gdef);
// Compare that the output of the gradients of both graphs match.
RunGraphsAndCompareOutputs(grad_outputs, expected_grad_outputs);
}
void TestGradientsError(bool grad_inputs_provided) {
TF_Output inputs[1];
TF_Output outputs[1];
TF_Output grad_outputs[1];
BuildErrorGraph(inputs, outputs);
AddGradients(grad_inputs_provided, inputs, 1, outputs, 1, grad_outputs);
string expected_msg =
"No gradient defined for op: TestOpWithNoGradient. Please see "
"https://www.tensorflow.org/code/"
"tensorflow/cc/gradients/README.md"
" for instructions on how to add C++ gradients.";
EXPECT_EQ(expected_msg, TF_Message(s_));
}
// Run the graph and ensure that the gradient values are as expected.
void RunGraphsAndCompareOutputs(TF_Output* grad_outputs,
TF_Output* expected_grad_outputs) {
std::unique_ptr<CSession> csession(new CSession(graph_, s_));
std::unique_ptr<CSession> expected_csession(
new CSession(expected_graph_, s_));
std::vector<TF_Output> grad_outputs_vec;
grad_outputs_vec.assign(grad_outputs, grad_outputs + 2);
csession->SetOutputs(grad_outputs_vec);
csession->Run(s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_Tensor* out0 = csession->output_tensor(0);
TF_Tensor* out1 = csession->output_tensor(1);
std::vector<TF_Output> expected_grad_outputs_vec;
expected_grad_outputs_vec.assign(expected_grad_outputs,
expected_grad_outputs + 2);
expected_csession->SetOutputs(expected_grad_outputs_vec);
expected_csession->Run(s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_Tensor* expected_out0 = expected_csession->output_tensor(0);
TF_Tensor* expected_out1 = expected_csession->output_tensor(1);
CompareTensors(out0, expected_out0);
CompareTensors(out1, expected_out1);
}
void CompareTensors(TF_Tensor* a, TF_Tensor* b) {
float* a_data = static_cast<float*>(TF_TensorData(a));
float* b_data = static_cast<float*>(TF_TensorData(b));
EXPECT_EQ(*a_data, *b_data);
}
void AddGradients(bool grad_inputs_provided, TF_Output* inputs, int ninputs,
TF_Output* outputs, int noutputs, TF_Output* grad_outputs) {
if (grad_inputs_provided) {
TF_Output grad_inputs[1];
const float grad_inputs_val[] = {1.0, 1.0, 1.0, 1.0};
TF_Operation* grad_inputs_op =
FloatConst2x2(graph_, s_, grad_inputs_val, "GradInputs");
grad_inputs[0] = TF_Output{grad_inputs_op, 0};
TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, grad_inputs,
s_, grad_outputs);
} else {
TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, nullptr, s_,
grad_outputs);
}
}
void BuildErrorGraph(TF_Output* inputs, TF_Output* outputs) {
const float const0_val[] = {1.0, 2.0, 3.0, 4.0};
TF_Operation* const0 = FloatConst2x2(graph_, s_, const0_val, "Const_0");
TF_Operation* nograd = NoGradientOp(graph_, s_, const0, "NoGrad");
inputs[0] = TF_Output{const0, 0};
outputs[0] = TF_Output{nograd, 0};
EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
}
void BuildSuccessGraph(TF_Output* inputs, TF_Output* outputs) {
// Construct the following graph:
// |
// z|
// |
// MatMul
// / \
// ^ ^
// | |
// x| y|
// | |
// | |
// Const_0 Const_1
//
const float const0_val[] = {1.0, 2.0, 3.0, 4.0};
const float const1_val[] = {1.0, 0.0, 0.0, 1.0};
TF_Operation* const0 = FloatConst2x2(graph_, s_, const0_val, "Const_0");
TF_Operation* const1 = FloatConst2x2(graph_, s_, const1_val, "Const_1");
TF_Operation* matmul = MatMul(graph_, s_, const0, const1, "MatMul");
inputs[0] = TF_Output{const0, 0};
inputs[1] = TF_Output{const1, 0};
outputs[0] = TF_Output{matmul, 0};
EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
}
void BuildExpectedGraph(bool grad_inputs_provided,
TF_Output* expected_grad_outputs) {
// The expected graph looks like this if grad_inputs_provided.
// If grad_inputs_provided is false, Const_0 will be a OnesLike op.
// ^ ^
// dy| dx| // MatMul Gradient Graph
// | |
// MatMul_2 MatMul_1
// ^ ^ ^ ^
// | |----------| |
// | ^ |
// | dz| |
// | | |
// | Const_3 |
// | |
// | ^ |
// | z| | // MatMul Forward Graph
// | | |
// | MatMul |
// | / \ |
// | ^ ^ |
// | | | |
// |---x| y|----|
// | |
// | |
// Const_0 Const_1
//
const float const0_val[] = {1.0, 2.0, 3.0, 4.0};
const float const1_val[] = {1.0, 0.0, 0.0, 1.0};
TF_Operation* const0 =
FloatConst2x2(expected_graph_, s_, const0_val, "Const_0");
TF_Operation* const1 =
FloatConst2x2(expected_graph_, s_, const1_val, "Const_1");
TF_Operation* matmul =
MatMul(expected_graph_, s_, const0, const1, "MatMul");
TF_Operation* const3;
if (grad_inputs_provided) {
const float const3_val[] = {1.0, 1.0, 1.0, 1.0};
const3 = FloatConst2x2(expected_graph_, s_, const3_val, "GradInputs");
} else {
const3 = OnesLike(expected_graph_, s_, matmul, "OnesLike");
}
TF_Operation* matmul1 =
MatMul(expected_graph_, s_, const3, const1, "MatMul_1", false, true);
TF_Operation* matmul2 =
MatMul(expected_graph_, s_, const0, const3, "MatMul_2", true, false);
expected_grad_outputs[0] = {matmul1, 0};
expected_grad_outputs[1] = {matmul2, 0};
}
TF_Tensor* FloatTensor2x2(const float* values) {
const int64_t dims[2] = {2, 2};
TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, sizeof(float) * 4);
memcpy(TF_TensorData(t), values, sizeof(float) * 4);
return t;
}
TF_Operation* FloatConst2x2(TF_Graph* graph, TF_Status* s,
const float* values, const char* name) {
unique_tensor_ptr tensor(FloatTensor2x2(values), TF_DeleteTensor);
TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
TF_SetAttrTensor(desc, "value", tensor.get(), s);
if (TF_GetCode(s) != TF_OK) return nullptr;
TF_SetAttrType(desc, "dtype", TF_FLOAT);
TF_Operation* op = TF_FinishOperation(desc, s);
EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
return op;
}
TF_Operation* MatMul(TF_Graph* graph, TF_Status* s, TF_Operation* l,
TF_Operation* r, const char* name,
bool transpose_a = false, bool transpose_b = false) {
TF_OperationDescription* desc = TF_NewOperation(graph, "MatMul", name);
if (transpose_a) {
TF_SetAttrBool(desc, "transpose_a", 1);
}
if (transpose_b) {
TF_SetAttrBool(desc, "transpose_b", 1);
}
TF_AddInput(desc, {l, 0});
TF_AddInput(desc, {r, 0});
TF_Operation* op = TF_FinishOperation(desc, s);
EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
return op;
}
TF_Operation* OnesLike(TF_Graph* graph, TF_Status* s, TF_Operation* in,
const char* name) {
TF_OperationDescription* desc = TF_NewOperation(graph, "OnesLike", name);
TF_AddInput(desc, {in, 0});
TF_Operation* op = TF_FinishOperation(desc, s);
EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
return op;
}
TF_Operation* NoGradientOp(TF_Graph* graph, TF_Status* s, TF_Operation* in,
const char* name) {
TF_OperationDescription* desc =
TF_NewOperation(graph, "TestOpWithNoGradient", name);
TF_AddInput(desc, {in, 0});
TF_Operation* op = TF_FinishOperation(desc, s);
EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
return op;
}
TF_Status* s_;
TF_Graph* graph_;
TF_Graph* expected_graph_;
};
TEST_F(CApiGradientsTest, Gradients_GradInputs) { TestGradientsSuccess(true); }
TEST_F(CApiGradientsTest, Gradients_NoGradInputs) {
TestGradientsSuccess(false);
}
TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_GradInputs) {
TestGradientsError(true);
}
TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) {
TestGradientsError(false);
}
void StringVectorToArrays(const std::vector<string>& v,
@ -1095,9 +1865,13 @@ void StringVectorToArrays(const std::vector<string>& v,
// Registers two ops, each with a single attribute called 'v'.
// The attribute in one op will have a type 'type', the other
// will have list(type).
#define ATTR_TEST_REGISTER_OP(type) \
REGISTER_OP("CApiAttributesTestOp" #type).Attr("v: " #type); \
REGISTER_OP("CApiAttributesTestOpList" #type).Attr("v: list(" #type ")")
#define ATTR_TEST_REGISTER_OP(type) \
REGISTER_OP("CApiAttributesTestOp" #type) \
.Attr("v: " #type) \
.SetShapeFn(tensorflow::shape_inference::UnknownShape); \
REGISTER_OP("CApiAttributesTestOpList" #type) \
.Attr("v: list(" #type ")") \
.SetShapeFn(tensorflow::shape_inference::UnknownShape)
ATTR_TEST_REGISTER_OP(string);
ATTR_TEST_REGISTER_OP(int);
ATTR_TEST_REGISTER_OP(float);
@ -1504,8 +2278,8 @@ TEST_F(CApiAttributesTest, TensorList) {
EXPECT_EQ(TF_INT8, TF_TensorType(v)) << i;
EXPECT_EQ(tensor_ndims[i], TF_NumDims(v)) << i;
for (int j = 0; j < TF_NumDims(v); ++j) {
EXPECT_EQ(tensor_dims[i][j], TF_Dim(v, j)) << "Tensor #" << i
<< ", dimension #" << j;
EXPECT_EQ(tensor_dims[i][j], TF_Dim(v, j))
<< "Tensor #" << i << ", dimension #" << j;
}
EXPECT_EQ(sizeof(char) * tensor_size[i], TF_TensorByteSize(v)) << i;
EXPECT_EQ(0,

View File

@ -58,6 +58,7 @@ CheckpointReader::CheckpointReader(const string& filename,
CheckpointReader::~CheckpointReader() {
delete var_to_shape_map_ptr_;
delete reader_;
delete v2_reader_;
}
bool CheckpointReader::HasTensor(const string& name) const {

View File

@ -0,0 +1 @@
_TF_*

67
tensorflow/c/generate-pc.sh Executable file
View File

@ -0,0 +1,67 @@
#!/usr/bin/env bash
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
TF_PREFIX='/usr/local'
usage() {
echo "Usage: $0 OPTIONS"
echo -e "-p, --prefix\tset installation prefix (default: /usr/local)"
echo -e "-v, --version\tset TensorFlow version"
echo -e "-h, --help\tdisplay this message"
}
[ $# == 0 ] && usage && exit 0
# read the options
ARGS=$(getopt -o p:v:h --long prefix:,version:,help -n $0 -- "$@")
eval set -- "$ARGS"
# extract options and their arguments into variables.
while true ; do
case "$1" in
-h|--help) usage ; exit ;;
-p|--prefix)
case "$2" in
"") shift 2 ;;
*) TF_PREFIX=$2 ; shift 2 ;;
esac ;;
-v|--version)
case "$2" in
"") shift 2 ;;
*) TF_VERSION=$2 ; shift 2 ;;
esac ;;
--) shift ; break ;;
*) echo "Internal error! Try '$0 --help' for more information." ; exit 1 ;;
esac
done
[ -z $TF_VERSION ] && echo "Specify a version using -v or --version" && exit 1
echo "Generating pkgconfig file for TensorFlow $TF_VERSION in $TF_PREFIX"
cat << EOF > tensorflow.pc
prefix=${TF_PREFIX}
exec_prefix=\${prefix}
libdir=\${exec_prefix}/lib
includedir=\${prefix}/include
Name: TensorFlow
Version: ${TF_VERSION}
Description: Library for computation using data flow graphs for scalable machine learning
Requires:
Libs: -L\${libdir} -ltensorflow
Cflags: -I\${includedir}
EOF

View File

@ -0,0 +1,9 @@
VERS_1.0 {
# Export symbols in c_api.h.
global:
TF_*;
# Hide everything else.
local:
*;
};

View File

@ -8,8 +8,6 @@ package(
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
@ -36,6 +34,7 @@ cc_library(
tf_cc_test(
name = "framework_gradients_test",
size = "small",
srcs = ["framework/gradients_test.cc"],
deps = [
":cc_ops",
@ -44,8 +43,8 @@ tf_cc_test(
":gradients",
":testutil",
"//tensorflow/core:all_kernels",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
@ -59,7 +58,6 @@ cc_library(
deps = [
":cc_ops",
":client_session",
":grad_op_registry",
":gradients",
":ops",
":scope",
@ -72,6 +70,7 @@ cc_library(
tf_cc_test(
name = "framework_gradient_checker_test",
size = "small",
srcs = ["framework/gradient_checker_test.cc"],
deps = [
":cc_ops",
@ -80,8 +79,8 @@ tf_cc_test(
":gradient_checker",
":testutil",
"//tensorflow/core:all_kernels",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
@ -93,6 +92,7 @@ cc_library(
deps = [
":array_grad",
":math_grad",
":nn_grad",
],
)
@ -124,7 +124,10 @@ cc_library_with_android_deps(
cc_library_with_android_deps(
name = "scope",
srcs = ["framework/scope.cc"],
srcs = [
"framework/scope.cc",
"framework/scope_internal.h",
],
hdrs = ["framework/scope.h"],
android_deps = ["//tensorflow/core:android_tensorflow_lib"],
common_deps = [
@ -138,8 +141,18 @@ cc_library_with_android_deps(
],
)
cc_library_with_android_deps(
name = "scope_internal",
hdrs = ["framework/scope_internal.h"],
common_deps = [
":scope",
],
deps = [],
)
tf_cc_test(
name = "framework_scope_test",
size = "small",
srcs = ["framework/scope_test.cc"],
deps = [
":ops",
@ -169,6 +182,7 @@ cc_library_with_android_deps(
tf_cc_test(
name = "client_client_session_test",
size = "small",
srcs = ["client/client_session_test.cc"],
deps = [
":cc_ops",
@ -203,6 +217,7 @@ cc_library_with_android_deps(
tf_cc_test(
name = "ops_const_op_test",
size = "small",
srcs = ["ops/const_op_test.cc"],
deps = [
":const_op",
@ -231,11 +246,13 @@ cc_library(
":cc_ops_internal",
":grad_op_registry",
":gradients",
"//tensorflow/core:lib_proto_parsing",
],
)
tf_cc_test(
name = "gradients_array_grad_test",
size = "small",
srcs = ["gradients/array_grad_test.cc"],
deps = [
":array_grad",
@ -266,6 +283,7 @@ cc_library(
tf_cc_test(
name = "gradients_math_grad_test",
size = "small",
srcs = ["gradients/math_grad_test.cc"],
deps = [
":cc_ops",
@ -296,6 +314,7 @@ cc_library(
tf_cc_test(
name = "gradients_nn_grad_test",
size = "small",
srcs = ["gradients/nn_grad_test.cc"],
deps = [
":cc_ops",
@ -315,6 +334,7 @@ tf_gen_op_wrappers_cc(
name = "cc_ops",
op_lib_names = [
"array_ops",
"audio_ops",
"candidate_sampling_ops",
"control_flow_ops",
"data_flow_ops",
@ -343,6 +363,7 @@ tf_gen_op_wrappers_cc(
tf_cc_test(
name = "framework_cc_ops_test",
size = "small",
srcs = ["framework/cc_ops_test.cc"],
deps = [
":cc_ops",
@ -376,6 +397,34 @@ tf_gen_op_wrappers_cc(
visibility = ["//tensorflow:internal"],
)
tf_gen_op_wrappers_cc(
name = "functional_ops",
include_internal_ops = 1,
op_lib_names = [
"functional_ops",
],
pkg = "//tensorflow/core",
visibility = ["//tensorflow:internal"],
)
tf_gen_op_wrappers_cc(
name = "resource_variable_ops",
include_internal_ops = 1,
op_lib_names = [
"resource_variable_ops",
],
pkg = "//tensorflow/core",
visibility = ["//tensorflow:internal"],
)
tf_gen_op_wrappers_cc(
name = "remote_fused_graph_ops",
op_lib_names = [
"remote_fused_graph_ops",
],
pkg = "//tensorflow/core",
)
cc_library_with_android_deps(
name = "cc_op_gen_main",
srcs = [
@ -414,7 +463,6 @@ cc_library(
":client_session",
":ops",
":scope",
"//tensorflow/core:all_kernels",
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib_internal",
"//tensorflow/core:tensorflow",
@ -433,13 +481,25 @@ cc_binary(
name = "tutorials_example_trainer",
srcs = ["tutorials/example_trainer.cc"],
copts = tf_copts(),
linkopts = [
"-lpthread",
"-lm",
],
linkopts = select({
"//tensorflow:windows": [],
"//tensorflow:windows_msvc": [],
"//tensorflow:darwin": [
"-lm",
"-lpthread",
],
"//tensorflow:ios": [
"-lm",
"-lpthread",
],
"//conditions:default": [
"-lm",
"-lpthread",
"-lrt",
],
}),
deps = [
":cc_ops",
"//tensorflow/core:all_kernels",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@ -471,7 +531,6 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/core/kernels:ops_util",
],
)
@ -512,6 +571,7 @@ cc_library(
tf_cc_test(
name = "coordinator_test",
size = "small",
srcs = ["training/coordinator_test.cc"],
deps = [
":cc_ops",

View File

@ -16,32 +16,55 @@ limitations under the License.
#include "tensorflow/cc/client/client_session.h"
#include <unordered_map>
#include <utility>
#include <vector>
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
class ClientSession::Impl {
private:
friend class ClientSession;
Impl(Session* session, std::shared_ptr<Graph> graph)
: session_(session), graph_(std::move(graph)) {}
static SessionOptions MakeDefaultSessionOptions(const string& target);
Status MaybeExtendGraph() const;
std::unique_ptr<Session> session_;
std::shared_ptr<Graph> graph_;
mutable mutex mu_;
mutable int last_num_graph_nodes_ GUARDED_BY(mu_) = 0;
};
ClientSession::ClientSession(const Scope& scope, const string& target)
: ClientSession(scope, MakeDefaultSessionOptions(target)) {}
: ClientSession(scope, Impl::MakeDefaultSessionOptions(target)) {}
ClientSession::ClientSession(const Scope& scope) : ClientSession(scope, "") {}
ClientSession::ClientSession(const Scope& scope,
const SessionOptions& session_options)
: graph_(scope.graph_as_shared_ptr()) {
const SessionOptions& session_options) {
Session* new_session;
Status status = NewSession(session_options, &new_session);
TF_CHECK_OK(status) << status;
session_.reset(new_session);
CHECK_NOTNULL(session_.get());
impl_.reset(new Impl(new_session, scope.graph_as_shared_ptr()));
CHECK_NOTNULL(impl()->session_.get());
}
SessionOptions ClientSession::MakeDefaultSessionOptions(
const string& target) const {
// Define destructor here so we can forward declare `Impl` in client_session.h.
// If we define a dtor in the header file or use the default dtor,
// unique_ptr<Impl> needs the complete type.
ClientSession::~ClientSession() {}
SessionOptions ClientSession::Impl::MakeDefaultSessionOptions(
const string& target) {
SessionOptions options;
options.env = Env::Default();
options.target = target;
@ -67,7 +90,7 @@ Status ClientSession::Run(const FeedType& inputs,
nullptr);
}
Status ClientSession::MaybeExtendGraph() const {
Status ClientSession::Impl::MaybeExtendGraph() const {
mutex_lock l(mu_);
int num_nodes = graph_->num_node_ids();
if (num_nodes > last_num_graph_nodes_) {
@ -90,16 +113,18 @@ Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs,
feeds.emplace_back(feed.first.name(), feed.second.tensor);
}
std::vector<string> output_tensor_names;
output_tensor_names.reserve(fetch_outputs.size());
for (auto const& output : fetch_outputs) {
output_tensor_names.push_back(output.name());
}
std::vector<string> target_node_names;
target_node_names.reserve(run_outputs.size());
for (auto const& output : run_outputs) {
target_node_names.push_back(output.node()->name());
}
TF_RETURN_IF_ERROR(MaybeExtendGraph());
return session_->Run(run_options, feeds, output_tensor_names,
target_node_names, outputs, run_metadata);
TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph());
return impl()->session_->Run(run_options, feeds, output_tensor_names,
target_node_names, outputs, run_metadata);
}
} // end namespace tensorflow

View File

@ -23,14 +23,13 @@ limitations under the License.
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
/// @addtogroup core
/// @{
/// A `ClientSession` object lets the caller drive the evaluation of the
/// TensorFlow graph constructed with the C++ API.
///
@ -64,6 +63,8 @@ class ClientSession {
/// Create a new session, configuring it with `session_options`.
ClientSession(const Scope& scope, const SessionOptions& session_options);
~ClientSession();
/// Evaluate the tensors in `fetch_outputs`. The values are returned as
/// `Tensor` objects in `outputs`. The number and order of `outputs` will
/// match `fetch_outputs`.
@ -89,18 +90,14 @@ class ClientSession {
// TODO(keveman): Add support for partial run.
private:
SessionOptions MakeDefaultSessionOptions(const string& target) const;
Status MaybeExtendGraph() const;
std::unique_ptr<Session> session_;
std::shared_ptr<Graph> graph_;
mutable mutex mu_;
mutable int last_num_graph_nodes_ GUARDED_BY(mu_) = 0;
TF_DISALLOW_COPY_AND_ASSIGN(ClientSession);
class Impl;
std::unique_ptr<Impl> impl_;
Impl* impl() { return impl_.get(); }
const Impl* impl() const { return impl_.get(); }
};
/// @}
} // end namespace tensorflow
#endif // TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_

View File

@ -49,7 +49,7 @@ TEST(ClientSessionTest, Feed) {
TEST(ClientSessionTest, Extend) {
Scope root = Scope::NewRootScope();
auto a = Placeholder(root, DT_INT32);
auto a = Placeholder(root, DT_INT32, Placeholder::Shape({2}));
auto c = Add(root, a, {2, 2});
ClientSession session(root);
std::vector<Tensor> outputs;

View File

@ -57,6 +57,16 @@ string GetPath(const string& dot_h_fname) {
return result;
}
// Converts: some/path/to/file.xx
// to: file
// (note that suffix is removed)
string GetFilename(const string& path) {
size_t slash_pos = path.rfind('/');
if (slash_pos == path.npos) slash_pos = -1;
size_t dot_pos = path.rfind('.');
return path.substr(slash_pos + 1, dot_pos - (slash_pos + 1));
}
// Converts:
// cc/ops/gen_foo_ops.h
// to:
@ -77,6 +87,17 @@ string ToGuard(const string& path) {
return guard;
}
// Converts: some_name_xyz
// to: Some Name Xyz
string ToTitle(const string& name) {
string title = name;
for (int i = 0; i < title.size(); ++i) {
if (title[i] == '_') title[i] = ' ';
}
str_util::TitlecaseString(&title, " ");
return title;
}
// Change: Into:
// ABC /// ABC
// ///
@ -105,7 +126,11 @@ string PrintString(const string& str) {
return strings::StrCat("\"", str_util::CEscape(str), "\"");
}
string PrintTensorShape(const TensorShape& shape) {
string PrintTensorShape(const TensorShapeProto& shape_proto) {
PartialTensorShape shape(shape_proto);
if (shape.IsIdenticalTo(PartialTensorShape())) {
return "::tensorflow::PartialTensorShape() /* unknown */";
}
string ret = "{";
for (int d = 0; d < shape.dims(); ++d) {
if (d > 0) strings::StrAppend(&ret, ", ");
@ -167,7 +192,13 @@ string PrintTensor(const TensorProto& tensor_proto) {
}
}
string PrintAttrValue(string op, const AttrValue& attr_value) {
string PrintTensorProto(const TensorProto& proto) {
return strings::StrCat("Input::Initializer(", "{", PrintTensor(proto), "}, ",
PrintTensorShape(proto.tensor_shape()),
").AsTensorProto()");
}
string PrintAttrValue(const string& op, const AttrValue& attr_value) {
switch (attr_value.value_case()) {
case AttrValue::kS:
return PrintString(attr_value.s());
@ -182,12 +213,9 @@ string PrintAttrValue(string op, const AttrValue& attr_value) {
case AttrValue::kType:
return EnumName_DataType(attr_value.type());
case AttrValue::kShape:
return PrintTensorShape(TensorShape(attr_value.shape()));
return PrintTensorShape(attr_value.shape());
case AttrValue::kTensor:
return strings::StrCat(
"Input::Initializer(", "{", PrintTensor(attr_value.tensor()), "}, ",
PrintTensorShape(TensorShape(attr_value.tensor().tensor_shape())),
").AsTensorProto()");
return PrintTensorProto(attr_value.tensor());
case AttrValue::kList: {
string ret = "{";
if (attr_value.list().s_size() > 0) {
@ -220,8 +248,14 @@ string PrintAttrValue(string op, const AttrValue& attr_value) {
} else if (attr_value.list().shape_size() > 0) {
for (int i = 0; i < attr_value.list().shape_size(); ++i) {
if (i > 0) strings::StrAppend(&ret, ", ");
strings::StrAppend(
&ret, PrintTensorShape(TensorShape(attr_value.list().shape(i))));
strings::StrAppend(&ret,
PrintTensorShape(attr_value.list().shape(i)));
}
} else if (attr_value.list().tensor_size() > 0) {
for (int i = 0; i < attr_value.list().tensor_size(); ++i) {
if (i > 0) strings::StrAppend(&ret, ", ");
strings::StrAppend(&ret,
PrintTensorProto(attr_value.list().tensor(i)));
}
}
strings::StrAppend(&ret, "}");
@ -271,8 +305,8 @@ std::pair<const char*, bool> AttrTypeName(StringPiece attr_type) {
{"list(bool)", {"gtl::ArraySlice<bool>", true}},
{"type", {"DataType", false}},
{"list(type)", {"DataTypeSlice", true}},
{"shape", {"TensorShape", false}},
{"list(shape)", {"gtl::ArraySlice<TensorShape>", true}},
{"shape", {"PartialTensorShape", false}},
{"list(shape)", {"gtl::ArraySlice<PartialTensorShape>", true}},
{"tensor", {"TensorProto", true}},
{"list(tensor)", {"gtl::ArraySlice<TensorProto>", true}},
{"func", {"NameAttrList", true}},
@ -416,6 +450,7 @@ OpInfo::OpInfo(const OpDef& g_op_def, const OpDef& i_op_def,
}
strings::StrAppend(&comment, "\nArguments:\n* scope: A Scope object\n");
// Process inputs
for (int i = 0; i < op_def.input_arg_size(); ++i) {
const auto& arg(op_def.input_arg(i));
arg_types.push_back(strings::StrCat(
@ -430,30 +465,45 @@ OpInfo::OpInfo(const OpDef& g_op_def, const OpDef& i_op_def,
arg.description(), "\n");
}
}
// Process attrs
string required_attrs_comment;
string optional_attrs_comment;
for (int i = 0; i < op_def.attr_size(); ++i) {
const auto& attr(op_def.attr(i));
// If the attr is going to be inferred or is optional, don't add it as a
// required argument.
if ((inferred_input_attrs.find(attr.name()) !=
inferred_input_attrs.end()) ||
attr.has_default_value()) {
continue;
}
// Skip inferred arguments
if (inferred_input_attrs.count(attr.name()) > 0) continue;
const auto entry = AttrTypeName(attr.type());
const auto attr_type_name = entry.first;
const bool use_const = entry.second;
string attr_name = AvoidCPPKeywords(attr.name());
arg_types.push_back(strings::StrCat(use_const ? "const " : "",
attr_type_name, use_const ? "&" : ""));
arg_names.push_back(AvoidCPPKeywords(attr.name()));
string attr_comment;
if (!attr.description().empty()) {
strings::StrAppend(&comment, "* ", AvoidCPPKeywords(attr.name()), ":\n");
// TODO(keveman): Word wrap and indent this, to handle multi-line
// descriptions.
strings::StrAppend(&comment, " ", attr.description(), "\n");
strings::StrAppend(&attr_comment, "* ", attr_name, ": ",
attr.description(), "\n");
}
if (attr.has_default_value()) {
strings::StrAppend(&optional_attrs_comment, attr_comment);
} else {
strings::StrAppend(&required_attrs_comment, attr_comment);
arg_types.push_back(strings::StrCat(
use_const ? "const " : "", attr_type_name, use_const ? "&" : ""));
arg_names.push_back(attr_name);
}
}
strings::StrAppend(&comment, required_attrs_comment);
if (!optional_attrs_comment.empty()) {
strings::StrAppend(&comment, "\nOptional attributes (see `Attrs`):\n");
strings::StrAppend(&comment, optional_attrs_comment);
}
// Process outputs
for (int i = 0; i < op_def.output_arg_size(); ++i) {
const auto& arg = op_def.output_arg(i);
bool is_list = ArgIsList(arg);
@ -509,8 +559,6 @@ OpInfo::OpInfo(const OpDef& g_op_def, const OpDef& i_op_def,
string OpInfo::GetOpAttrStruct() const {
string struct_fields;
string setters;
string attrs_comment =
strings::StrCat("Optional attribute setters for ", op_name, " :\n\n");
for (int i = 0; i < op_def.attr_size(); ++i) {
const auto& attr(op_def.attr(i));
@ -531,13 +579,15 @@ string OpInfo::GetOpAttrStruct() const {
strings::StrCat(camel_case_name, suffix, "(", use_const ? "const " : "",
attr_type_name, use_const ? "&" : "");
strings::StrAppend(&attrs_comment, attr_func_def, "): Defaults to ",
SummarizeAttrValue(attr.default_value()), "\n");
string attr_comment;
if (!attr.description().empty()) {
// TODO(keveman): Word wrap and indent this to handle multi-line
// description.
strings::StrAppend(&attrs_comment, " ", attr.description(), "\n");
strings::StrAppend(&attr_comment, attr.description(), "\n\n");
}
strings::StrAppend(&attr_comment, "Defaults to ",
SummarizeAttrValue(attr.default_value()), "\n");
attr_comment = MakeComment(attr_comment, " ");
strings::StrAppend(&setters, attr_comment);
strings::StrAppend(&setters, " Attrs ", attr_func_def, " x) {\n");
strings::StrAppend(&setters, " Attrs ret = *this;\n");
strings::StrAppend(&setters, " ret.", attr.name(), "_ = x;\n");
@ -552,6 +602,8 @@ string OpInfo::GetOpAttrStruct() const {
return "";
}
string attrs_comment =
strings::StrCat("Optional attribute setters for ", op_name, "\n");
string struct_decl = MakeComment(attrs_comment, " ");
strings::StrAppend(&struct_decl, " struct Attrs {\n");
strings::StrAppend(&struct_decl, setters, struct_fields);
@ -678,7 +730,7 @@ void OpInfo::GetOutput(string* out) const {
// One output, no need for NameRangeMap
if (is_list_output[0]) {
strings::StrAppend(out,
" for (int64 i = 0; i < ret->num_outputs(); ++i)\n");
" for (int32 i = 0; i < ret->num_outputs(); ++i)\n");
strings::StrAppend(out, " this->", output_names[0],
".push_back(Output(ret, i));\n");
} else {
@ -688,11 +740,10 @@ void OpInfo::GetOutput(string* out) const {
return;
}
strings::StrAppend(out, " ::tensorflow::NameRangeMap _outputs_range;\n");
strings::StrAppend(
out,
" ::tensorflow::Status _status_ = "
"::tensorflow::NameRangesForNode(ret->def(), ret->op_def(), "
"nullptr, &_outputs_range);\n");
strings::StrAppend(out,
" ::tensorflow::Status _status_ = "
"::tensorflow::NameRangesForNode(*ret, ret->op_def(), "
"nullptr, &_outputs_range);\n");
strings::StrAppend(out, " if (!_status_.ok()) {\n", " ", scope_str,
".UpdateStatus(_status_);\n", " return;\n");
strings::StrAppend(out, " }\n\n");
@ -701,7 +752,7 @@ void OpInfo::GetOutput(string* out) const {
const string arg_range = strings::StrCat(
"_outputs_range[\"", graph_op_def.output_arg(i).name(), "\"]");
if (is_list_output[i]) {
strings::StrAppend(out, " for (int64 i = ", arg_range, ".first; i < ",
strings::StrAppend(out, " for (int32 i = ", arg_range, ".first; i < ",
arg_range, ".second; ++i)\n");
strings::StrAppend(out, " this->", output_names[i],
".push_back(Output(ret, i));\n");
@ -841,6 +892,10 @@ namespace ops {
)include",
"#include \"", op_header, "\"\n", namespace_begin);
const string filename = GetFilename(dot_h_fname);
const string doxygen = strings::StrCat("/// @defgroup ", filename, " ",
ToTitle(filename), "\n", "/// @{\n\n");
TF_CHECK_OK(h->Append(
strings::StrCat("// This file is MACHINE GENERATED! Do not edit.\n\n"
"#ifndef ",
@ -850,6 +905,7 @@ namespace ops {
*op_header_guard, "\n\n")));
TF_CHECK_OK(h->Append(header));
TF_CHECK_OK(h->Append(namespace_begin));
TF_CHECK_OK(h->Append(doxygen));
TF_CHECK_OK(cc->Append(cc_header));
}
@ -860,7 +916,9 @@ void FinishFiles(bool internal, WritableFile* h, WritableFile* cc,
} // namespace tensorflow
)footer"
:
R"footer(} // namespace ops
R"footer(/// @}
} // namespace ops
} // namespace tensorflow
)footer";
@ -892,7 +950,7 @@ void WriteCCOps(const OpList& ops, const string& dot_h_fname,
// Load the override map.
OpGenOverrideMap override_map;
if (!overrides_fnames.empty()) {
override_map.LoadFileList(env, overrides_fnames);
TF_CHECK_OK(override_map.LoadFileList(env, overrides_fnames));
}
// Write the initial boilerplate to the .h and .cc files.

View File

@ -32,10 +32,11 @@ Output Linear(const Scope& scope, Input x, Input w, Input b) {
return BiasAdd(cop_scopes.last, m, b);
}
void GetColocationConstraints(Output tensor, std::vector<string>* constraints) {
void GetColocationConstraints(const Output& tensor,
std::vector<string>* constraints) {
constraints->clear();
TF_EXPECT_OK(
GetNodeAttr(tensor.op().node()->def(), kColocationAttrName, constraints));
TF_EXPECT_OK(GetNodeAttr(tensor.op().node()->attrs(), kColocationAttrName,
constraints));
}
} // namespace
@ -158,11 +159,11 @@ TEST(CCOpTest, KernelLabel) {
Scope root = Scope::NewRootScope();
auto add = Add(root.WithKernelLabel("AddWithKernelLabel"), 1.0f, 2.0f);
TF_EXPECT_OK(root.status());
const auto& attrs = add.z.op().node()->def().attr();
ASSERT_TRUE(attrs.find("_kernel") != attrs.end());
auto kernel_attr = attrs.find("_kernel")->second;
TF_EXPECT_OK(AttrValueHasType(kernel_attr, "string"));
EXPECT_EQ(kernel_attr.s(), "AddWithKernelLabel");
AttrSlice attrs = add.z.op().node()->attrs();
const auto* kernel_attr = attrs.Find("_kernel");
ASSERT_TRUE(kernel_attr);
TF_EXPECT_OK(AttrValueHasType(*kernel_attr, "string"));
EXPECT_EQ(kernel_attr->s(), "AddWithKernelLabel");
}
TEST(CCOpTest, ColocateWith) {
@ -189,8 +190,7 @@ TEST(CCOpTest, ColocateWith) {
Scope with_colocate = root.ColocateWith(c3).ColocateWith(c4);
auto c6 = Const(with_colocate.WithOpName("c6").ClearColocation(), 7);
const auto& attrs = c6.op().node()->def().attr();
EXPECT_TRUE(attrs.find("_class") == attrs.end());
EXPECT_FALSE(c6.op().node()->attrs().Find("_class"));
}
TEST(CCOpTest, TemplatedConst) {

View File

@ -32,7 +32,13 @@ bool GradOpRegistry::Register(const string& op, GradFunc func) {
Status GradOpRegistry::Lookup(const string& op, GradFunc* func) const {
auto iter = registry_.find(op);
if (iter == registry_.end()) {
return errors::NotFound("No gradient defined for op: ", op);
const string error_msg =
"No gradient defined for op: " + op +
". Please see "
"https://www.tensorflow.org/code/"
"tensorflow/cc/gradients/README.md"
" for instructions on how to add C++ gradients.";
return errors::NotFound(error_msg);
}
*func = iter->second;
return Status::OK();

View File

@ -22,8 +22,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
using namespace ops; // NOLINT(build/namespaces)
namespace {
// TODO(andydavis) Support returning relative error (as opposed to max error)
@ -39,14 +37,16 @@ Status ComputeTheoreticalJacobianTranspose(
const std::vector<TensorShape>& x_shapes,
const std::vector<Tensor>& x_datas, const OutputList& ys,
const std::vector<TensorShape>& y_shapes,
std::vector<Tensor>& jacobian_ts) {
int y_num = y_shapes.size();
int x_num = x_shapes.size();
std::vector<Tensor>* jacobian_ts) {
size_t y_num = y_shapes.size();
size_t x_num = x_shapes.size();
// Call AddSymbolicGradients to get 'dxs' (we will feed 'dys').
OutputList dys;
dys.reserve(y_shapes.size());
for (const auto& y_shape : y_shapes) {
// TODO(suharshs): This currently assumes that all x's are the same type.
dys.push_back(Cast(scope, Const(scope, 1.0, y_shape), xs[0].type()));
dys.push_back(
ops::Cast(scope, ops::Const(scope, 1.0, y_shape), xs[0].type()));
}
OutputList dxs;
TF_RETURN_IF_ERROR(AddSymbolicGradients(scope, ys, xs, dys, &dxs));
@ -84,7 +84,7 @@ Status ComputeTheoreticalJacobianTranspose(
for (int x_idx = 0; x_idx < x_num; x_idx++) {
const int64 x_size = x_shapes[x_idx].num_elements();
auto jacobian = jacobian_ts[x_idx * y_num + y_idx].matrix<T>();
auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix<T>();
auto dx_flat = dxout[x_idx].flat<T>();
for (int r = 0; r < x_size; ++r) {
jacobian(r, c) = dx_flat(r);
@ -97,20 +97,20 @@ Status ComputeTheoreticalJacobianTranspose(
return Status::OK();
}
Status EvaluateGraph(ClientSession& session, const OutputList& xs,
const OutputList& ys, std::vector<Tensor>& x_datas,
Status EvaluateGraph(ClientSession* session, const OutputList& xs,
const OutputList& ys, std::vector<Tensor>* x_datas,
std::vector<Tensor>* y_datas) {
// Create the feed list.
ClientSession::FeedType feed_list;
for (int i = 0; i < x_datas.size(); i++) {
feed_list.insert({xs[i], x_datas[i]});
for (int i = 0; i < x_datas->size(); i++) {
feed_list.insert({xs[i], (*x_datas)[i]});
}
TF_RETURN_IF_ERROR(session.Run(feed_list, ys, y_datas));
TF_RETURN_IF_ERROR(session->Run(feed_list, ys, y_datas));
for (int y_idx = 0; y_idx < y_datas->size(); y_idx++) {
for (int x_idx = 0; x_idx < x_datas.size(); x_idx++) {
for (int x_idx = 0; x_idx < x_datas->size(); x_idx++) {
Tensor y_data = (*y_datas)[y_idx];
if (y_data.SharesBufferWith(x_datas[x_idx])) {
if (y_data.SharesBufferWith((*x_datas)[x_idx])) {
// Create copies of outputs that share a buffer with any inputs since
// the underlying buffer of the input Tensors are not copied for some
// operations (i.e. Identity), which can lead to incorrect results for
@ -128,14 +128,14 @@ Status ComputeNumericJacobianTranspose(const Scope& scope, const OutputList& xs,
const OutputList& ys,
const std::vector<TensorShape>& y_shapes,
const T delta,
std::vector<Tensor>& x_datas,
std::vector<Tensor>& jacobian_ts) {
int y_num = y_shapes.size();
int x_num = x_shapes.size();
std::vector<Tensor>* x_datas,
std::vector<Tensor>* jacobian_ts) {
size_t y_num = y_shapes.size();
size_t x_num = x_shapes.size();
ClientSession session(scope);
for (int x_idx = 0; x_idx < x_num; x_idx++) {
auto x_data_flat = x_datas[x_idx].flat<T>();
auto x_data_flat = (*x_datas)[x_idx].flat<T>();
const int64 x_size = x_shapes[x_idx].num_elements();
// Compute the numeric Jacobian one column at a time by perturbing each
@ -147,11 +147,11 @@ Status ComputeNumericJacobianTranspose(const Scope& scope, const OutputList& xs,
// Evaluate at positive delta.
x_data_flat(r) = v + delta;
std::vector<Tensor> y_pos;
TF_RETURN_IF_ERROR(EvaluateGraph(session, xs, ys, x_datas, &y_pos));
TF_RETURN_IF_ERROR(EvaluateGraph(&session, xs, ys, x_datas, &y_pos));
// Evaluate at negative delta.
x_data_flat(r) = v - delta;
std::vector<Tensor> y_neg;
TF_RETURN_IF_ERROR(EvaluateGraph(session, xs, ys, x_datas, &y_neg));
TF_RETURN_IF_ERROR(EvaluateGraph(&session, xs, ys, x_datas, &y_neg));
for (int y_idx = 0; y_idx < y_num; y_idx++) {
// Compute element-wise centered difference and store in each Jacobian.
@ -159,7 +159,7 @@ Status ComputeNumericJacobianTranspose(const Scope& scope, const OutputList& xs,
auto y_neg_flat = y_neg[y_idx].flat<T>();
const int64 y_size = y_shapes[y_idx].num_elements();
const T scale = 2 * delta;
auto jacobian = jacobian_ts[x_idx * y_num + y_idx].matrix<T>();
auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix<T>();
for (int c = 0; c < y_size; ++c) {
jacobian(r, c) = (y_pos_flat(c) - y_neg_flat(c)) / scale;
}
@ -175,11 +175,11 @@ template <typename T>
void InitJacobians(const OutputList& xs,
const std::vector<TensorShape>& x_shapes,
const std::vector<TensorShape>& y_shapes,
std::vector<Tensor>& jacobians) {
int y_num = y_shapes.size();
int x_num = x_shapes.size();
std::vector<Tensor>* jacobians) {
size_t y_num = y_shapes.size();
size_t x_num = x_shapes.size();
jacobians.resize(y_num * x_num);
jacobians->resize(y_num * x_num);
for (int x_idx = 0; x_idx < x_num; x_idx++) {
const int64 x_size = x_shapes[x_idx].num_elements();
for (int y_idx = 0; y_idx < y_num; y_idx++) {
@ -187,7 +187,7 @@ void InitJacobians(const OutputList& xs,
Tensor jacobian_t(xs[x_idx].type(), {x_size, y_size});
auto jacobian_t_flat = jacobian_t.flat<T>();
jacobian_t_flat.setZero();
jacobians[x_idx * y_num + y_idx] = std::move(jacobian_t);
(*jacobians)[x_idx * y_num + y_idx] = std::move(jacobian_t);
}
}
}
@ -197,23 +197,23 @@ Status ComputeGradientErrorInternal(const Scope& scope, const OutputList& xs,
const std::vector<TensorShape>& x_shapes,
const OutputList& ys,
const std::vector<TensorShape>& y_shapes,
std::vector<Tensor>& x_datas,
std::vector<Tensor>* x_datas,
T* max_error) {
// Initialize theoretical Jacobians to zeros.
std::vector<Tensor> jacobian_ts;
InitJacobians<T>(xs, x_shapes, y_shapes, jacobian_ts);
InitJacobians<T>(xs, x_shapes, y_shapes, &jacobian_ts);
// Compute theoretical Jacobian.
TF_RETURN_IF_ERROR(ComputeTheoreticalJacobianTranspose<T>(
scope, xs, x_shapes, x_datas, ys, y_shapes, jacobian_ts));
scope, xs, x_shapes, *x_datas, ys, y_shapes, &jacobian_ts));
// Initialize numeric Jacobian to zeros.
std::vector<Tensor> jacobian_ns;
InitJacobians<T>(xs, x_shapes, y_shapes, jacobian_ns);
InitJacobians<T>(xs, x_shapes, y_shapes, &jacobian_ns);
// Compute numeric Jacobian.
TF_RETURN_IF_ERROR(ComputeNumericJacobianTranspose<T>(
scope, xs, x_shapes, ys, y_shapes, 1e-3, x_datas, jacobian_ns));
scope, xs, x_shapes, ys, y_shapes, 1e-3, x_datas, &jacobian_ns));
for (int i = 0; i < jacobian_ts.size(); i++) {
// Compute the maximum error between theoretical and numeric Jacobians.
@ -256,7 +256,7 @@ Status ComputeGradientError(const Scope& scope, const OutputList& xs,
}
// Compute gradient error.
return ComputeGradientErrorInternal(scope, xs, x_shapes, ys, y_shapes,
x_datas, max_error);
&x_datas, max_error);
}
template <typename T>
@ -267,7 +267,7 @@ Status ComputeGradientError(const Scope& scope, const Output& x,
std::vector<Tensor> x_datas(1, Tensor(x_init_value));
// Compute gradient error.
return ComputeGradientErrorInternal(scope, {x}, {x_datas[0].shape()}, {y},
{y_shape}, x_datas, max_error);
{y_shape}, &x_datas, max_error);
}
#define INSTANTIATE_GRAD_ERR_TYPE(T) \

View File

@ -19,9 +19,9 @@ limitations under the License.
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/equal_graph_def.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow {
using namespace ops; // NOLINT(build/namespaces)

View File

@ -210,8 +210,8 @@ Status SymbolicGradientBuilder::Initialize() {
{
// Initialize backprop with `grad_inputs_`.
const int num_dy = grad_inputs_.size();
for (int i = 0; i < num_dy; ++i) {
const size_t num_dy = grad_inputs_.size();
for (size_t i = 0; i < num_dy; ++i) {
TF_RETURN_IF_ERROR(BackpropAlongEdge(grad_inputs_[i], outputs_[i]));
}
}
@ -308,7 +308,7 @@ Status SymbolicGradientBuilder::AddGradients() {
continue;
}
const int num_no_grad = no_grad_dy_indices.size();
const size_t num_no_grad = no_grad_dy_indices.size();
if (IsPrimitiveOpWithNoGrad(n->type_string()) || num_no_grad == num_y) {
// No grad defined for this op, or all outputs returned 'NoGradient':
// Backprop 'NoGradient' along the in edges.
@ -367,6 +367,19 @@ Status AddSymbolicGradients(const Scope& scope,
return builder.AddGradients();
}
Status AddSymbolicGradients(const Scope& scope,
const std::vector<Output>& outputs,
const std::vector<Output>& inputs,
std::vector<Output>* grad_outputs) {
std::vector<Output> grad_inputs;
grad_inputs.reserve(outputs.size());
for (const Output& output : outputs) {
grad_inputs.emplace_back(ops::OnesLike(scope, output));
}
return AddSymbolicGradients(scope, outputs, inputs, grad_inputs,
grad_outputs);
}
Output NoGradient() { return SymbolicGradientBuilder::NoGradient(); }
} // end namespace tensorflow

View File

@ -27,16 +27,19 @@ namespace tensorflow {
/// derivatives of some loss function 'L' w.r.t 'outputs'), adds gradient nodes
/// to the graph associated with 'scope', which compute (and return in
/// 'grad_outputs') the symbolic partial derivatives of 'L' w.r.t 'inputs'.
///
// TODO(andydavis) Add overload of this function with no 'grad_inputs' arg.
// Implementation will fill in 'OnesLike' for all shapes in 'outputs'.
Status AddSymbolicGradients(const Scope& scope,
const std::vector<Output>& outputs,
const std::vector<Output>& inputs,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs);
// Same as above, but uses 'OnesLike' for all shapes in
// 'outputs' as grad_inputs.
Status AddSymbolicGradients(const Scope& scope,
const std::vector<Output>& outputs,
const std::vector<Output>& inputs,
std::vector<Output>* grad_outputs);
/// Returns a sentinel Output that represents 'no gradient' (i.e. no gradient
/// flows along some graph edge during backpropagation).
/// Can be returned in 'grad_outputs' by an invocation of 'AddSymbolicGradients'

View File

@ -19,9 +19,9 @@ limitations under the License.
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/equal_graph_def.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow {
using namespace ops; // NOLINT(build/namespaces)
@ -40,7 +40,7 @@ class GradientsTest : public ::testing::Test {
TF_ASSERT_OK(scope_test_.ToGraphDef(&gdef_test));
GraphDef gdef_exp;
TF_ASSERT_OK(scope_expected_.ToGraphDef(&gdef_exp));
TF_EXPECT_GRAPH_EQ(gdef_test, gdef_exp);
TF_EXPECT_GRAPH_EQ(gdef_exp, gdef_test);
}
Scope scope_expected_;
@ -98,6 +98,32 @@ TEST_F(GradientsTest, OneMatMul) {
CompareTestAndExpectedGraphs();
}
TEST_F(GradientsTest, OneMatMul_InferGradInputs) {
for (const bool expected : {false, true}) {
const Scope& scope = expected ? scope_expected_ : scope_test_;
// Construct forward graph.
auto x = Const(scope, {{1.0, 2.0}, {3.0, 4.0}});
auto y = Const(scope, {{1.0, 0.0}, {0.0, 1.0}});
auto z = MatMul(scope, x, y);
TF_ASSERT_OK(scope.status());
CHECK_NOTNULL(z.node());
if (expected) {
// Construct backward graph.
// The gradients function adds a OnesLike to create a dz of ones with the
// shape of z.
auto dz = OnesLike(scope, z);
auto dx = MatMul(scope, dz, y, MatMul::TransposeB(true));
auto dy = MatMul(scope, x, dz, MatMul::TransposeA(true));
} else {
// Call AddSymbolicGradients.
std::vector<Output> grad_outputs;
TF_ASSERT_OK(AddSymbolicGradients(scope, {z}, {x, y}, &grad_outputs));
}
}
CompareTestAndExpectedGraphs();
}
TEST_F(GradientsTest, TwoMatMuls_Chained) {
for (const bool expected : {false, true}) {
const Scope& scope = expected ? scope_expected_ : scope_test_;
@ -234,7 +260,7 @@ TEST_F(GradientsTest, StackUnstack_StopBackprop) {
}
TEST_F(GradientsTest, DependentGradOutputs) {
// Tests that dependant gradients (in this case the gradients w.r.t to the
// Tests that dependent gradients (in this case the gradients w.r.t to the
// output and one input of MatMul) are computed properly.
// Create two chained MatMul ops.

View File

@ -20,7 +20,7 @@ namespace tensorflow {
Operation::Operation(Node* n) : inputs_(GetInputs(n)), node_(n) {}
Output Operation::input(int i) const {
Output Operation::input(int32 i) const {
CHECK_NOTNULL(node_);
CHECK_GE(i, 0);
CHECK_LT(i, node_->num_inputs());
@ -37,14 +37,14 @@ Output Operation::input(int i) const {
return Output(inputs_[i].first, inputs_[i].second);
}
Output Operation::output(int i) const {
Output Operation::output(int32 i) const {
CHECK_NOTNULL(node_);
CHECK_GE(i, 0);
CHECK_LT(i, node_->num_outputs());
return Output(node_, i);
}
uint64 Operation::hash(int64 index) const {
uint64 Operation::hash(int32 index) const {
return ::tensorflow::Hash64(reinterpret_cast<const char*>(&node_),
sizeof(Node*), index);
}

View File

@ -26,30 +26,35 @@ limitations under the License.
namespace tensorflow {
/// @defgroup core Core Tensorflow API
class Output;
/// @addtogroup core
/// @{
/// Represents a node in the computation graph.
class Operation {
public:
Operation() : node_(nullptr) {}
explicit Operation(Node* n);
int num_inputs() const { return node_->num_inputs(); }
DataType input_type(int o) const { return node_->input_type(o); }
Output input(int i) const;
int32 num_inputs() const { return node_->num_inputs(); }
DataType input_type(int32 o) const { return node_->input_type(o); }
Output input(int32 i) const;
int num_outputs() const { return node_->num_outputs(); }
DataType output_type(int o) const { return node_->output_type(o); }
Output output(int i) const;
int32 num_outputs() const { return node_->num_outputs(); }
DataType output_type(int32 o) const { return node_->output_type(o); }
Output output(int32 i) const;
Node* node() const { return node_; }
uint64 hash(int64 index) const;
uint64 hash(int32 index) const;
bool operator==(const Operation& other) const { return node_ == other.node_; }
private:
typedef std::vector<std::pair<Node*, int64>> Inputs;
typedef std::vector<std::pair<Node*, int32>> Inputs;
static Inputs GetInputs(Node* node);
Inputs inputs_;
@ -61,12 +66,12 @@ class Output {
public:
Output() = default;
explicit Output(Node* n) : op_(n) {}
Output(Node* n, int64 index) : op_(n), index_(index) {}
Output(const Operation& op, int64 index) : op_(op), index_(index) {}
Output(Node* n, int32 index) : op_(n), index_(index) {}
Output(const Operation& op, int32 index) : op_(op), index_(index) {}
Operation op() const { return op_; }
Node* node() const { return op().node(); }
int64 index() const { return index_; }
int32 index() const { return index_; }
DataType type() const { return op_.output_type(index_); }
string name() const { return strings::StrCat(node()->name(), ":", index()); }
bool operator==(const Output& other) const {
@ -77,13 +82,14 @@ class Output {
private:
Operation op_ = Operation(nullptr);
int64 index_ = 0;
int32 index_ = 0;
};
/// Hash class that can be used for e.g. storing Outputs in an unordered_map
struct OutputHash {
std::size_t operator()(const Output& output) const {
return Hash64Combine(std::hash<Node*>()(output.node()),
std::hash<int64>()(output.index()));
std::hash<int32>()(output.index()));
}
};
@ -161,6 +167,7 @@ class Input {
/// initializer list is indeed a valid multi-dimensional tensor.
Initializer(const std::initializer_list<Initializer>& v);
// START_SKIP_DOXYGEN
template <typename T, bool = std::is_convertible<T, string>::value>
struct RealType {
typedef string type;
@ -170,6 +177,7 @@ class Input {
struct RealType<T, false> {
typedef T type;
};
// END_SKIP_DOXYGEN
TensorProto AsTensorProto() {
TensorProto tensor_proto;
@ -222,12 +230,12 @@ class Input {
/// Constructor specifying a node name, index and datatype. This should only
/// be used for specifying a backward edge, needed by control flow.
Input(const string& name, int i, DataType dt)
Input(const string& name, int32 i, DataType dt)
: node_name_(name), index_(i), data_type_(dt) {}
Node* node() const { return output_.node(); }
string node_name() const { return node_name_; }
int index() const { return node_name_.empty() ? output_.index() : index_; }
int32 index() const { return node_name_.empty() ? output_.index() : index_; }
DataType data_type() const { return data_type_; }
Status status() const { return status_; }
const Tensor& tensor() const { return tensor_; }
@ -237,7 +245,7 @@ class Input {
Output output_ = Output(Operation(nullptr), 0);
Tensor tensor_;
const string node_name_ = "";
int index_ = 0;
int32 index_ = 0;
DataType data_type_ = DT_INVALID;
};
@ -284,6 +292,8 @@ class InputList {
std::vector<Input> inputs_;
};
/// @}
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_OPS_H_

View File

@ -16,15 +16,116 @@ limitations under the License.
#include <algorithm>
#include <vector>
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
namespace tensorflow {
Scope::Scope(Graph* graph, Status* status, Scope::NameMap* name_map,
ShapeRefiner* refiner)
class Scope::Impl {
public:
// A NameMap is used to keep track of suffixes for names used in a scope. A
// name that has not been used so far in a scope will get no suffix. Later
// uses of the same name will get suffixes _1, _2, _3, etc. Multiple scopes
// can share the same NameMap. For instance, a new scope created using
// WithControlDependencies() should would share the same NameMap with the
// parent.
typedef std::unordered_map<string, int> NameMap;
Impl(const std::shared_ptr<Graph>& graph,
const std::shared_ptr<Status>& status,
const std::shared_ptr<NameMap>& name_map,
const std::shared_ptr<ShapeRefiner>& refiner);
private:
friend class Scope;
// Tag types to choose the constructor to dispatch.
struct Tags {
enum class ScopeName;
enum class OpName;
enum class ControlDeps;
enum class Device;
enum class SingleUseScope;
enum class ExitOnError;
enum class KernelLabel;
enum class Colocate;
};
Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner);
Impl(const Scope& other, Tags::ScopeName, const string& name,
bool copy_names);
Impl(const Scope& other, Tags::OpName, const string& name,
const string& op_name);
Impl(const Scope& other, Tags::ControlDeps,
std::vector<Operation> control_deps, bool clear_control_deps);
Impl(const Scope& other, Tags::Device, const string& device);
Impl(const Scope& other, Tags::SingleUseScope, const string& op_name);
Impl(const Scope& other, Tags::ExitOnError);
Impl(const Scope& other, Tags::KernelLabel, const string& kernel_label);
Impl(const Scope& other, Tags::Colocate, const Operation& colocate_with_op,
bool clear_colocations);
std::unordered_set<string> GetColocationConstraints(
const Operation& colocate_with_op) const;
// Helper functions to get a unique names.
string GetUniqueName(const string& prefix, bool check_single_use) const;
string GetNameForOp(const string& default_name) const;
bool single_use_scope() const { return scope_used_ != nullptr; }
// The graph, status, and name maps are shared by all child scopes
// created from a single 'root' scope. A root scope is created by calling the
// Scope::NewRootScope function, which creates a new graph, a new status and
// the name maps.
std::shared_ptr<Graph> graph_ = nullptr;
std::shared_ptr<Status> status_ = nullptr;
std::shared_ptr<NameMap> name_map_ = nullptr;
std::shared_ptr<ShapeRefiner> refiner_ = nullptr;
// If scope_used_ is not nullptr, op_name_ should be empty and
// GetUniqueNameForOp can only be called once on this scope. More calls to
// GetUniqueNameForOp will cause an error status to be set on this scope.
std::shared_ptr<bool> scope_used_ = nullptr;
const std::vector<Operation> control_deps_;
const string name_ = "";
const string op_name_ = "";
const bool exit_on_error_ = false;
const string kernel_label_ = "";
const string device_ = "";
const std::unordered_set<string> colocation_constraints_;
};
Scope::Scope(Impl* impl) : impl_(impl) {}
Scope::Scope(const Scope& other) : impl_(new Impl(*other.impl())) {}
Scope::~Scope() {}
Scope& Scope::operator=(const Scope& other) {
// We can't copy Impls because of the const members, use copy ctor instead
impl_.reset(new Impl(*other.impl_));
return *this;
}
Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map,
ShapeRefiner* refiner)
: graph_(graph),
status_(status),
name_map_(name_map),
refiner_(refiner),
scope_used_(nullptr),
colocation_constraints_() {}
Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
const std::shared_ptr<Status>& status,
const std::shared_ptr<NameMap>& name_map,
const std::shared_ptr<ShapeRefiner>& refiner)
: graph_(graph),
status_(status),
name_map_(name_map),
@ -34,143 +135,145 @@ Scope::Scope(Graph* graph, Status* status, Scope::NameMap* name_map,
Scope Scope::NewRootScope() {
Graph* graph = new Graph(OpRegistry::Global());
ShapeRefiner* refiner = new ShapeRefiner(graph->op_registry());
return Scope(graph, new Status, new Scope::NameMap, refiner);
ShapeRefiner* refiner =
new ShapeRefiner(graph->versions().producer(), graph->op_registry());
return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner));
}
Scope::Scope(const Scope& other, Scope::Tags::ScopeName, const string& name,
bool copy_names)
: graph_(other.graph_),
status_(other.status_),
name_map_(copy_names ? other.name_map_
Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
bool copy_names)
: graph_(other.impl()->graph_),
status_(other.impl()->status_),
name_map_(copy_names ? other.impl()->name_map_
: std::shared_ptr<NameMap>(new NameMap)),
refiner_(other.refiner_),
refiner_(other.impl()->refiner_),
scope_used_(nullptr),
control_deps_(other.control_deps_),
control_deps_(other.impl()->control_deps_),
name_(name),
op_name_(""),
exit_on_error_(other.exit_on_error_),
kernel_label_(other.kernel_label_),
device_(other.device_),
colocation_constraints_(other.colocation_constraints_) {}
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
colocation_constraints_(other.impl()->colocation_constraints_) {}
Scope::Scope(const Scope& other, Scope::Tags::OpName, const string& name,
const string& op_name)
: graph_(other.graph_),
status_(other.status_),
name_map_(other.name_map_),
refiner_(other.refiner_),
scope_used_(other.scope_used_),
control_deps_(other.control_deps_),
Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
const string& op_name)
: graph_(other.impl()->graph_),
status_(other.impl()->status_),
name_map_(other.impl()->name_map_),
refiner_(other.impl()->refiner_),
scope_used_(other.impl()->scope_used_),
control_deps_(other.impl()->control_deps_),
name_(name),
op_name_(op_name),
exit_on_error_(other.exit_on_error_),
kernel_label_(other.kernel_label_),
device_(other.device_),
colocation_constraints_(other.colocation_constraints_) {}
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
colocation_constraints_(other.impl()->colocation_constraints_) {}
Scope::Scope(const Scope& other, Scope::Tags::ControlDeps,
std::vector<Operation> control_deps, bool clear_control_deps)
: graph_(other.graph_),
status_(other.status_),
name_map_(other.name_map_),
refiner_(other.refiner_),
scope_used_(other.scope_used_),
control_deps_(clear_control_deps
? std::vector<Operation>()
: (control_deps.insert(control_deps.begin(),
other.control_deps_.begin(),
other.control_deps_.end()),
control_deps)),
name_(other.name_),
op_name_(other.op_name_),
exit_on_error_(other.exit_on_error_),
kernel_label_(other.kernel_label_),
device_(other.device_),
colocation_constraints_(other.colocation_constraints_) {}
Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
std::vector<Operation> control_deps, bool clear_control_deps)
: graph_(other.impl()->graph_),
status_(other.impl()->status_),
name_map_(other.impl()->name_map_),
refiner_(other.impl()->refiner_),
scope_used_(other.impl()->scope_used_),
control_deps_(
clear_control_deps
? std::vector<Operation>()
: (control_deps.insert(control_deps.begin(),
other.impl()->control_deps_.begin(),
other.impl()->control_deps_.end()),
control_deps)),
name_(other.impl()->name_),
op_name_(other.impl()->op_name_),
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
colocation_constraints_(other.impl()->colocation_constraints_) {}
Scope::Scope(const Scope& other, Scope::Tags::Device, const string& device)
: graph_(other.graph_),
status_(other.status_),
name_map_(other.name_map_),
refiner_(other.refiner_),
scope_used_(other.scope_used_),
control_deps_(other.control_deps_),
name_(other.name_),
op_name_(other.op_name_),
exit_on_error_(other.exit_on_error_),
kernel_label_(other.kernel_label_),
Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device)
: graph_(other.impl()->graph_),
status_(other.impl()->status_),
name_map_(other.impl()->name_map_),
refiner_(other.impl()->refiner_),
scope_used_(other.impl()->scope_used_),
control_deps_(other.impl()->control_deps_),
name_(other.impl()->name_),
op_name_(other.impl()->op_name_),
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(device),
colocation_constraints_(other.colocation_constraints_) {}
colocation_constraints_(other.impl()->colocation_constraints_) {}
Scope::Scope(const Scope& other, Scope::Tags::SingleUseScope,
const string& op_name)
: graph_(other.graph_),
status_(other.status_),
name_map_(other.name_map_),
refiner_(other.refiner_),
Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
const string& op_name)
: graph_(other.impl()->graph_),
status_(other.impl()->status_),
name_map_(other.impl()->name_map_),
refiner_(other.impl()->refiner_),
scope_used_(new bool(false)),
control_deps_(other.control_deps_),
name_(other.name_),
control_deps_(other.impl()->control_deps_),
name_(other.impl()->name_),
op_name_(op_name),
exit_on_error_(other.exit_on_error_),
kernel_label_(other.kernel_label_),
device_(other.device_),
colocation_constraints_(other.colocation_constraints_) {}
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
colocation_constraints_(other.impl()->colocation_constraints_) {}
Scope::Scope(const Scope& other, Scope::Tags::ExitOnError)
: graph_(other.graph_),
status_(other.status_),
name_map_(other.name_map_),
refiner_(other.refiner_),
scope_used_(other.scope_used_),
control_deps_(other.control_deps_),
name_(other.name_),
op_name_(other.op_name_),
Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
: graph_(other.impl()->graph_),
status_(other.impl()->status_),
name_map_(other.impl()->name_map_),
refiner_(other.impl()->refiner_),
scope_used_(other.impl()->scope_used_),
control_deps_(other.impl()->control_deps_),
name_(other.impl()->name_),
op_name_(other.impl()->op_name_),
exit_on_error_(true),
kernel_label_(other.kernel_label_),
device_(other.device_),
colocation_constraints_(other.colocation_constraints_) {}
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
colocation_constraints_(other.impl()->colocation_constraints_) {}
Scope::Scope(const Scope& other, Scope::Tags::KernelLabel,
const string& kernel_label)
: graph_(other.graph_),
status_(other.status_),
name_map_(other.name_map_),
refiner_(other.refiner_),
scope_used_(other.scope_used_),
control_deps_(other.control_deps_),
name_(other.name_),
op_name_(other.op_name_),
exit_on_error_(other.exit_on_error_),
Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
const string& kernel_label)
: graph_(other.impl()->graph_),
status_(other.impl()->status_),
name_map_(other.impl()->name_map_),
refiner_(other.impl()->refiner_),
scope_used_(other.impl()->scope_used_),
control_deps_(other.impl()->control_deps_),
name_(other.impl()->name_),
op_name_(other.impl()->op_name_),
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(kernel_label),
device_(other.device_),
colocation_constraints_(other.colocation_constraints_) {}
device_(other.impl()->device_),
colocation_constraints_(other.impl()->colocation_constraints_) {}
Scope::Scope(const Scope& other, Scope::Tags::Colocate,
const Operation& colocate_with_op, bool clear_colocations)
: graph_(other.graph_),
status_(other.status_),
name_map_(other.name_map_),
refiner_(other.refiner_),
scope_used_(other.scope_used_),
control_deps_(other.control_deps_),
name_(other.name_),
op_name_(other.op_name_),
exit_on_error_(other.exit_on_error_),
kernel_label_(other.kernel_label_),
device_(other.device_),
Scope::Impl::Impl(const Scope& other, Tags::Colocate,
const Operation& colocate_with_op, bool clear_colocations)
: graph_(other.impl()->graph_),
status_(other.impl()->status_),
name_map_(other.impl()->name_map_),
refiner_(other.impl()->refiner_),
scope_used_(other.impl()->scope_used_),
control_deps_(other.impl()->control_deps_),
name_(other.impl()->name_),
op_name_(other.impl()->op_name_),
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
colocation_constraints_(
clear_colocations
? std::unordered_set<string>()
: other.GetColocationConstraints(colocate_with_op)) {}
: other.impl()->GetColocationConstraints(colocate_with_op)) {}
std::unordered_set<string> Scope::GetColocationConstraints(
std::unordered_set<string> Scope::Impl::GetColocationConstraints(
const Operation& colocate_with_op) const {
std::unordered_set<string> current_constraints(colocation_constraints_);
const NodeDef& node_def = colocate_with_op.node()->def();
const AttrSlice attrs = colocate_with_op.node()->attrs();
std::vector<string> node_constraints;
if (GetNodeAttr(node_def, kColocationAttrName, &node_constraints).ok()) {
if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) {
for (const string& entry : node_constraints) {
StringPiece s(entry);
if (s.Consume(kColocationGroupPrefix)) {
@ -183,45 +286,59 @@ std::unordered_set<string> Scope::GetColocationConstraints(
return current_constraints;
}
bool Scope::ok() const { return impl()->status_->ok(); }
Graph* Scope::graph() const { return impl()->graph_.get(); }
std::shared_ptr<Graph> Scope::graph_as_shared_ptr() const {
return impl()->graph_;
}
Status Scope::status() const { return *impl()->status_; }
const std::vector<Operation>& Scope::control_deps() const {
return impl()->control_deps_;
}
void Scope::UpdateStatus(const Status s) const {
status_->Update(s);
if (exit_on_error_ && !status_->ok()) {
LOG(FATAL) << *status_;
impl()->status_->Update(s);
if (impl()->exit_on_error_ && !ok()) {
LOG(FATAL) << *impl()->status_;
}
}
Status Scope::ToGraphDef(GraphDef* gdef) const {
if (!status_->ok()) {
return *status_;
if (!ok()) {
return *impl()->status_;
}
graph()->ToGraphDef(gdef);
return Status::OK();
}
Status Scope::ToGraph(Graph* g) const {
if (status_->ok()) {
if (ok()) {
GraphDef graph_def;
graph()->ToGraphDef(&graph_def);
GraphConstructorOptions opts;
UpdateStatus(ConvertGraphDefToGraph(opts, graph_def, g));
}
return *status_;
return *impl()->status_;
}
void Scope::UpdateBuilder(NodeBuilder* builder) const {
std::vector<Node*> control_inputs;
for (const auto& op : control_deps_) {
for (const auto& op : impl()->control_deps_) {
control_inputs.push_back(op.node());
}
builder->ControlInputs(control_inputs);
if (!kernel_label_.empty()) {
builder->Attr("_kernel", kernel_label_);
if (!impl()->kernel_label_.empty()) {
builder->Attr("_kernel", impl()->kernel_label_);
}
if (!colocation_constraints_.empty()) {
std::vector<string> constraints(colocation_constraints_.begin(),
colocation_constraints_.end());
if (!impl()->colocation_constraints_.empty()) {
std::vector<string> constraints(impl()->colocation_constraints_.begin(),
impl()->colocation_constraints_.end());
// Sort the set.
std::sort(constraints.begin(), constraints.end());
// Add loc:@ prefix
@ -231,12 +348,13 @@ void Scope::UpdateBuilder(NodeBuilder* builder) const {
});
builder->Attr(kColocationAttrName, constraints);
}
if (!device_.empty()) {
builder->Device(device_);
if (!impl()->device_.empty()) {
builder->Device(impl()->device_);
}
}
string Scope::GetUniqueName(const string& prefix, bool check_single_use) const {
string Scope::Impl::GetUniqueName(const string& prefix,
bool check_single_use) const {
if (check_single_use && single_use_scope()) {
if (*scope_used_) {
*status_ =
@ -256,7 +374,7 @@ string Scope::GetUniqueName(const string& prefix, bool check_single_use) const {
return unique_name;
}
string Scope::GetNameForOp(const string& default_name) const {
string Scope::Impl::GetNameForOp(const string& default_name) const {
const string unique_name =
GetUniqueName(default_name, true /* check_single_use */);
const string sep = name_.empty() || unique_name.empty() ? "" : "/";
@ -264,96 +382,125 @@ string Scope::GetNameForOp(const string& default_name) const {
}
string Scope::GetUniqueNameForOp(const string& default_name) const {
if (single_use_scope()) {
if (op_name_.empty() || *scope_used_) {
*status_ =
if (impl()->single_use_scope()) {
if (impl()->op_name_.empty() || *impl()->scope_used_) {
*impl()->status_ =
errors::InvalidArgument("Cannot get a unique name in this scope");
return "";
}
*scope_used_ = true;
return op_name_;
*impl()->scope_used_ = true;
return impl()->op_name_;
}
return op_name_.empty() ? GetNameForOp(default_name) : GetNameForOp(op_name_);
return impl()->op_name_.empty() ? impl()->GetNameForOp(default_name)
: impl()->GetNameForOp(impl()->op_name_);
}
Scope Scope::NewSubScope(const string& child_scope_name) const {
if (child_scope_name.empty()) {
return Scope(*this, Scope::Tags::ScopeName(), name_, true /* copy_names */);
return Scope(new Impl(*this, Impl::Tags::ScopeName(), impl()->name_,
true /* copy_names */));
}
const string unique_name =
GetUniqueName(child_scope_name, false /* check_single_use */);
const string sep = name_.empty() || unique_name.empty() ? "" : "/";
return Scope(*this, Scope::Tags::ScopeName(),
strings::StrCat(name_, sep, unique_name),
false /* copy_names */);
impl()->GetUniqueName(child_scope_name, false /* check_single_use */);
const string sep = impl()->name_.empty() || unique_name.empty() ? "" : "/";
return Scope(new Impl(*this, Impl::Tags::ScopeName(),
strings::StrCat(impl()->name_, sep, unique_name),
false /* copy_names */));
}
Scope Scope::WithOpName(const string& op_name) const {
if (single_use_scope()) {
if (impl()->single_use_scope()) {
UpdateStatus(errors::InvalidArgument("Cannot set op name ", op_name,
" on this scope"));
return *this;
}
return Scope(*this, Scope::Tags::OpName(), name_, op_name);
return Scope(new Impl(*this, Impl::Tags::OpName(), impl()->name_, op_name));
}
Scope Scope::WithControlDependencies(
const gtl::ArraySlice<Operation>& control_deps) const {
return Scope(*this, Scope::Tags::ControlDeps(),
return Scope(
new Impl(*this, Impl::Tags::ControlDeps(),
std::vector<Operation>(control_deps.begin(), control_deps.end()),
/* clear_control_deps */ false);
/* clear_control_deps */ false));
}
Scope Scope::WithControlDependencies(const Output& control_dep) const {
return Scope(*this, Scope::Tags::ControlDeps(),
std::vector<Operation>(1, control_dep.op()),
/* clear_control_deps */ false);
return Scope(new Impl(*this, Impl::Tags::ControlDeps(),
std::vector<Operation>(1, control_dep.op()),
/* clear_control_deps */ false));
}
Scope Scope::WithNoControlDependencies() const {
return Scope(*this, Scope::Tags::ControlDeps(), std::vector<Operation>(),
/* clear_control_deps */ true);
return Scope(new Impl(*this, Impl::Tags::ControlDeps(),
std::vector<Operation>(),
/* clear_control_deps */ true));
}
Scope Scope::WithDevice(const string& device) const {
return Scope(*this, Scope::Tags::Device(), device);
return Scope(new Impl(*this, Impl::Tags::Device(), device));
}
Scope Scope::ColocateWith(const Operation& op) const {
return Scope(*this, Scope::Tags::Colocate(), op,
/* clear_colocations */ false);
return Scope(new Impl(*this, Impl::Tags::Colocate(), op,
/* clear_colocations */ false));
}
Scope Scope::ClearColocation() const {
return Scope(*this, Scope::Tags::Colocate(), Operation(),
/* clear_colocations */ true);
return Scope(new Impl(*this, Impl::Tags::Colocate(), Operation(),
/* clear_colocations */ true));
}
Scope Scope::ExitOnError() const {
return Scope(*this, Scope::Tags::ExitOnError());
return Scope(new Impl(*this, Impl::Tags::ExitOnError()));
}
Scope Scope::WithKernelLabel(const string& kernel_label) const {
return Scope(*this, Scope::Tags::KernelLabel(), kernel_label);
return Scope(new Impl(*this, Impl::Tags::KernelLabel(), kernel_label));
}
CompositeOpScopes Scope::GetCompositeOpScopes(
const string& composite_op_name) const {
if (op_name_.empty() && composite_op_name.empty()) {
if (impl()->op_name_.empty() && composite_op_name.empty()) {
UpdateStatus(errors::InvalidArgument(
"Cannot create composite op scopes with empty name"));
return {*this, *this};
}
if (!single_use_scope()) {
Scope child = NewSubScope(op_name_.empty() ? composite_op_name : op_name_);
const string child_op_sep = name_.empty() ? "" : "_";
return {child, Scope(child, Scope::Tags::SingleUseScope(),
strings::StrCat(name_, child_op_sep, child.name_))};
if (!impl()->single_use_scope()) {
Scope child = NewSubScope(impl()->op_name_.empty() ? composite_op_name
: impl()->op_name_);
const string child_op_sep = impl()->name_.empty() ? "" : "_";
const string child_name =
strings::StrCat(impl()->name_, child_op_sep, child.impl()->name_);
return {child,
Scope(new Impl(child, Impl::Tags::SingleUseScope(), child_name))};
} else {
return {
Scope(*this, Scope::Tags::ScopeName(), op_name_, true /* copy_names */),
*this};
return {Scope(new Impl(*this, Impl::Tags::ScopeName(), impl()->op_name_,
true /* copy_names */)),
*this};
}
}
class InternalScope {
public:
// NewScope doesn't take ownership of the inputs.
static Scope NewScope(Graph* graph, Status* status, ShapeRefiner* refiner) {
Scope::Impl::NameMap* name_map = new Scope::Impl::NameMap;
for (const Node* node : graph->nodes()) {
(*name_map)[node->name()] = 0;
}
// We provide null destructors for these shared ptrs (except for name_map)
// since the caller owns them and doesn't want the scope to destroy them.
return Scope(new Scope::Impl(
std::shared_ptr<Graph>(graph, [](Graph*) {}),
std::shared_ptr<Status>(status, [](Status*) {}),
std::shared_ptr<Scope::Impl::NameMap>(name_map),
std::shared_ptr<ShapeRefiner>(refiner, [](ShapeRefiner*) {})));
}
};
Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner) {
return InternalScope::NewScope(graph, status, refiner);
}
} // namespace tensorflow

View File

@ -23,16 +23,19 @@ limitations under the License.
#include <vector>
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
class Graph;
class GraphDef;
class NodeBuilder;
struct CompositeOpScopes;
/// @addtogroup core
/// @{
/// A `Scope` object represents a set of related TensorFlow ops that have the
/// same properties such as a common name prefix.
///
@ -91,6 +94,10 @@ struct CompositeOpScopes;
/// op-constructor functions on the same `Scope` object.
class Scope {
public:
Scope(const Scope& other);
~Scope();
Scope& operator=(const Scope& other);
// The following functions are for users making graphs. They return brand new
// scopes, or scopes derived from an existing scope object.
@ -161,20 +168,21 @@ class Scope {
// START_SKIP_DOXYGEN
/// Update the builder with properties accumulated in this scope.
// TODO(skyewm): NodeBuilder is not part of public API
void UpdateBuilder(NodeBuilder* builder) const;
// END_SKIP_DOXYGEN
CompositeOpScopes GetCompositeOpScopes(const string& composite_op_name) const;
bool ok() const { return status_->ok(); }
bool ok() const;
Graph* graph() const { return graph_.get(); }
// TODO(skyewm): Graph is not part of public API
Graph* graph() const;
ShapeRefiner* refiner() const { return refiner_.get(); }
// TODO(skyewm): Graph is not part of public API
std::shared_ptr<Graph> graph_as_shared_ptr() const;
std::shared_ptr<Graph> graph_as_shared_ptr() const { return graph_; }
Status status() const { return *status_; }
Status status() const;
/// If status() is Status::OK(), convert the Graph object stored in this scope
/// to a GraphDef proto and return Status::OK(). Otherwise, return the error
@ -193,74 +201,15 @@ class Scope {
Status ToGraph(Graph* g) const;
// END_SKIP_DOXYGEN
const std::vector<Operation>& control_deps() const { return control_deps_; }
const std::vector<Operation>& control_deps() const;
private:
// Tag types to choose the constructor to dispatch.
struct Tags {
enum class ScopeName;
enum class OpName;
enum class ControlDeps;
enum class Device;
enum class SingleUseScope;
enum class ExitOnError;
enum class KernelLabel;
enum class Colocate;
};
// A NameMap is used to keep track of suffixes for names used in a scope. A
// name that has not been used so far in a scope will get no suffix. Later
// uses of the same name will get suffixes _1, _2, _3, etc. Multiple scopes
// can share the same NameMap. For instance, a new scope created using
// WithControlDependencies() should would share the same NameMap with the
// parent.
typedef std::unordered_map<string, int> NameMap;
Scope(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner);
Scope(const Scope& other, Tags::ScopeName, const string& name,
bool copy_names);
Scope(const Scope& other, Tags::OpName, const string& name,
const string& op_name);
Scope(const Scope& other, Tags::ControlDeps,
std::vector<Operation> control_deps, bool clear_control_deps);
Scope(const Scope& other, Tags::Device, const string& device);
Scope(const Scope& other, Tags::SingleUseScope, const string& op_name);
Scope(const Scope& other, Tags::ExitOnError);
Scope(const Scope& other, Tags::KernelLabel, const string& kernel_label);
Scope(const Scope& other, Tags::Colocate, const Operation& colocate_with_op,
bool clear_colocations);
std::unordered_set<string> GetColocationConstraints(
const Operation& colocate_with_op) const;
// Helper functions to get a unique names.
string GetUniqueName(const string& prefix, bool check_single_use) const;
string GetNameForOp(const string& default_name) const;
bool single_use_scope() const { return scope_used_ != nullptr; }
// The graph, status, and name maps are shared by all child scopes
// created from a single 'root' scope. A root scope is created by calling the
// Scope::NewRootScope function, which creates a new graph, a new status and
// the name maps.
std::shared_ptr<Graph> graph_ = nullptr;
std::shared_ptr<Status> status_ = nullptr;
std::shared_ptr<NameMap> name_map_ = nullptr;
std::shared_ptr<ShapeRefiner> refiner_ = nullptr;
// If scope_used_ is not nullptr, op_name_ should be empty and
// GetUniqueNameForOp can only be called once on this scope. More calls to
// GetUniqueNameForOp will cause an error status to be set on this scope.
std::shared_ptr<bool> scope_used_ = nullptr;
const std::vector<Operation> control_deps_;
const string name_ = "";
const string op_name_ = "";
const bool exit_on_error_ = false;
const string kernel_label_ = "";
const string device_ = "";
const std::unordered_set<string> colocation_constraints_;
friend class InternalScope;
class Impl;
std::unique_ptr<Impl> impl_;
Impl* impl() { return impl_.get(); }
const Impl* impl() const { return impl_.get(); }
explicit Scope(Impl*);
};
/// A helper struct to hold the scopes that would be used by a function
@ -273,6 +222,8 @@ struct CompositeOpScopes {
Scope last;
};
/// @}
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_H_

View File

@ -0,0 +1,33 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_
#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_
#include "tensorflow/cc/framework/scope.h"
namespace tensorflow {
class ShapeRefiner;
// NewInternalScope returns a new scope which doesn't take ownership of
// graph, status, name_map, and refiner.
// This is intended to enable the C API (which are used by other language
// bindings) to create a Scope and access C++ functionality (i.e. gradients).
Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner);
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/cc/framework/testutil.h"
#include <utility>
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/default_device.h"
@ -30,7 +32,7 @@ void GetTensors(const Scope& scope, OutputList tensors,
void GetTensor(const Scope& scope, Output tensor, Tensor* out) {
std::vector<Tensor> outputs;
GetTensors(scope, {tensor}, &outputs);
GetTensors(scope, {std::move(tensor)}, &outputs);
*out = outputs[0];
}

View File

@ -0,0 +1,52 @@
# C++ gradients
Gradients are currently being ported from
[python](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/ops)
to C++ (in this directory).
Contributions are welcome and much appreciated; please follow the instructions
below.
1. Create the op gradient function in `foo_grad.cc` corresponding to the
`foo_grad.py` file where the op originated (i.e. `array_grad.py` op
gradients should be written in `array_grad.cc`).
2. Write the op gradient with the following naming scheme:
Status OpNameGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
...
return scope.status();
}
REGISTER_GRADIENT_OP("OpName", OpNameGrad);
3. Ops gradients are implemented by using the [C++
API](https://www.tensorflow.org/api_docs/cc/).
4. Tests should be included in `foo_grad_test.cc`. Please see
[`array_grad_test.cc`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/gradients/array_grad_test.cc)
for an many examples. Tests are as simple as, creating a placeholder input
for the op's inputs and calling `RunTest` (`RunTest` uses a [gradient
checker](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/framework/gradient_checker.cc)
to verify that the theoretical gradient matches the numeric gradient). For
example:
TEST_F(ArrayGradTest, IdentityGrad) {
TensorShape shape({5, 2});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
auto y = Identity(scope_, x);
RunTest(x, shape, y, shape);
}
NOTE: There are some ops that require features from the C++ API that are not yet
implemented.
* Ops that require PartialTensorShape information cannot yet be implemented.
* Ops that require SparseTensor or IndexSlices (currently only in python)
cannot yet be implemented.
* Maybe more.
For questions: Please create an issue assigned to suharshs.

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"
@ -42,9 +43,9 @@ Status PackGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
int N;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "N", &N));
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "N", &N));
int axis;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "axis", &axis));
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
grad_outputs->reserve(N);
auto grad_op = Unstack(scope, grad_inputs[0], N, Unstack::Axis(axis));
@ -59,7 +60,7 @@ Status UnpackGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
int axis;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "axis", &axis));
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
grad_outputs->push_back(Stack(scope, grad_inputs, Stack::Axis(axis)));
return scope.status();
}
@ -89,6 +90,16 @@ Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);
Status QuantizeAndDequantizeV2Grad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
grad_outputs->push_back(Identity(scope, grad_inputs[0]));
grad_outputs->push_back(NoGradient());
grad_outputs->push_back(NoGradient());
return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", QuantizeAndDequantizeV2Grad);
Status SplitGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
@ -150,9 +161,12 @@ REGISTER_GRADIENT_OP("GatherNd", GatherNdGrad);
Status CheckNumericsGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
grad_outputs->push_back(CheckNumerics(
scope, grad_inputs[0],
"Not a number (NaN) or infinity (Inf) values detected in gradient."));
string message;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "message", &message));
string err_msg = strings::StrCat(
"Not a number (NaN) or infinity (Inf) values detected in gradient. ",
message);
grad_outputs->push_back(CheckNumerics(scope, grad_inputs[0], err_msg));
return scope.status();
}
REGISTER_GRADIENT_OP("CheckNumerics", CheckNumericsGrad);
@ -201,9 +215,9 @@ Status ReverseSequenceGrad(const Scope& scope, const Operation& op,
std::vector<Output>* grad_outputs) {
auto seq_lengths = op.input(1);
int batch_dim;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "batch_dim", &batch_dim));
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "batch_dim", &batch_dim));
int seq_dim;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "seq_dim", &seq_dim));
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "seq_dim", &seq_dim));
grad_outputs->push_back(
ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim,
ReverseSequence::BatchDim(batch_dim)));
@ -253,7 +267,8 @@ Status SpaceToBatchGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
int block_size;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
TF_RETURN_IF_ERROR(
GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
grad_outputs->push_back(
BatchToSpace(scope, grad_inputs[0], op.input(1), block_size));
grad_outputs->push_back(NoGradient());
@ -276,7 +291,8 @@ Status BatchToSpaceGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
int block_size;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
TF_RETURN_IF_ERROR(
GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
grad_outputs->push_back(
SpaceToBatch(scope, grad_inputs[0], op.input(1), block_size));
grad_outputs->push_back(NoGradient());
@ -299,7 +315,8 @@ Status SpaceToDepthGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
int block_size;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
TF_RETURN_IF_ERROR(
GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
grad_outputs->push_back(DepthToSpace(scope, grad_inputs[0], block_size));
return scope.status();
}
@ -309,7 +326,8 @@ Status DepthToSpaceGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
int block_size;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
TF_RETURN_IF_ERROR(
GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
grad_outputs->push_back(SpaceToDepth(scope, grad_inputs[0], block_size));
return scope.status();
}
@ -319,7 +337,7 @@ Status MirrorPadGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
string mode;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode));
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad(
scope, grad_inputs[0], op.input(1), mode));
grad_outputs->push_back(NoGradient());
@ -332,7 +350,7 @@ Status MirrorPadGradGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
string mode;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode));
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode));
grad_outputs->push_back(NoGradient());
return scope.status();

View File

@ -21,6 +21,17 @@ namespace tensorflow {
namespace ops {
namespace {
// Conjugate helper function returns the conjugate of an Output if it
// is complex valued.
Output ConjugateHelper(const Scope& scope, const Output& out) {
DataType dtype = out.type();
if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
return Conj(scope, out);
} else {
return out;
}
}
// TODO(andydavis) Add control dependencies to gradient functions (as needed).
Status AbsGrad(const Scope& scope, const Operation& op,
@ -44,9 +55,11 @@ REGISTER_GRADIENT_OP("Neg", NegGrad);
Status InvGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
// dx = dy * (-1 * (y * y))
// dy/dx = -1/x^2 = -y^2
auto dydx = Neg(scope, Square(scope, op.output(0)));
// grad(x) = grad(y) * conj(dy/dx)
grad_outputs->push_back(
Mul(scope, grad_inputs[0], Neg(scope, Square(scope, op.output(0)))));
Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
return scope.status();
}
REGISTER_GRADIENT_OP("Inv", InvGrad);
@ -55,10 +68,12 @@ REGISTER_GRADIENT_OP("Reciprocal", InvGrad);
Status SquareGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
// dx = dy * (2 * x)
// dy/dx = (2 * x)
auto two = Cast(scope, Const(scope, 2), op.input(0).type());
auto dydx = Mul(scope, two, op.input(0));
// grad(x) = grad(y) * conj(dy/dx)
grad_outputs->push_back(
Mul(scope, grad_inputs[0], Mul(scope, two, op.input(0))));
Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
return scope.status();
}
REGISTER_GRADIENT_OP("Square", SquareGrad);
@ -68,11 +83,12 @@ Status SqrtGrad(const Scope& scope, const Operation& op,
std::vector<Output>* grad_outputs) {
// y = sqrt(x)
// dy/dx = 0.5 * (1 / sqrt(x)) = 0.5 * (1 / y)
// dx = dy * (0.5 * (1 / y))
auto y_inv = Reciprocal(scope, op.output(0));
auto half = Cast(scope, Const(scope, 0.5), op.input(0).type());
auto dx = Mul(scope, grad_inputs[0], Mul(scope, half, y_inv));
grad_outputs->push_back(dx);
auto dydx = Mul(scope, half, y_inv);
// grad(x) = grad(y) * conj(dy/dx)
grad_outputs->push_back(
Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
return scope.status();
}
REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);
@ -82,14 +98,14 @@ Status RsqrtGrad(const Scope& scope, const Operation& op,
std::vector<Output>* grad_outputs) {
// y = 1/x^1/2 = x^-1/2
// dy/dx = -1/2 * x^-3/2 = -1/2 * x^-1/2 * x^-1 = -1/2 * y * x^-1
// dx = dy * (-1/2 * y * x^-1)
auto x_inv = Reciprocal(scope, op.input(0));
auto y = op.output(0);
auto neghalf = Cast(scope, Const(scope, -0.5), op.input(0).type());
auto a = Mul(scope, neghalf, x_inv);
auto b = Mul(scope, a, y);
auto dx = Mul(scope, grad_inputs[0], b);
grad_outputs->push_back(dx);
auto dydx = Mul(scope, a, y);
// grad(x) = grad(y) * conj(dy/dx)
grad_outputs->push_back(
Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
return scope.status();
}
REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);
@ -97,10 +113,11 @@ REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);
Status ExpGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
// y = exp(x)
// dy/dx = exp(x)
// dx = dy * y
grad_outputs->push_back(Mul(scope, grad_inputs[0], op.output(0)));
// dy/dx = exp(x) = y
// grad(x) = grad(y) * conj(dy/dx)
// = grad(y) * conj(y)
grad_outputs->push_back(
Mul(scope, grad_inputs[0], ConjugateHelper(scope, op.output(0))));
return scope.status();
}
REGISTER_GRADIENT_OP("Exp", ExpGrad);
@ -108,10 +125,12 @@ REGISTER_GRADIENT_OP("Exp", ExpGrad);
Status Expm1Grad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
// f(x) = expm1(x)
// df/dx = exp(x)
// dx = dy * exp(x)
grad_outputs->push_back(Mul(scope, grad_inputs[0], Exp(scope, op.input(0))));
// y = expm1(x)
// dy/dx = exp(x)
auto dydx = Exp(scope, op.input(0));
// grad(x) = grad(y) * conj(dy/dx)
grad_outputs->push_back(
Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
return scope.status();
}
REGISTER_GRADIENT_OP("Expm1", Expm1Grad);
@ -119,11 +138,12 @@ REGISTER_GRADIENT_OP("Expm1", Expm1Grad);
Status LogGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
// f(x) = log(x) = y
// df/dx = 1 / x
// dx = dy * (1 / x)
// y = log(x)
// dy/dx = 1 / x
auto dydx = Reciprocal(scope, op.input(0));
// grad(x) = grad(y) * conj(dy/dx)
grad_outputs->push_back(
Mul(scope, grad_inputs[0], Reciprocal(scope, op.input(0))));
Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
return scope.status();
}
REGISTER_GRADIENT_OP("Log", LogGrad);
@ -131,26 +151,54 @@ REGISTER_GRADIENT_OP("Log", LogGrad);
Status Log1pGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
// f(x) = log1p(x) = y
// df/dx = 1 / (1 + x)
// dx = dy * (1 / (1 + x))
// y = log1p(x)
// dy/dx = 1 / (1 + x)
auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
auto dydx = Reciprocal(scope, Add(scope, one, op.input(0)));
// grad(x) = grad(y) * conj(dy/dx)
grad_outputs->push_back(
Div(scope, grad_inputs[0], Add(scope, one, op.input(0))));
Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
return scope.status();
}
REGISTER_GRADIENT_OP("Log1p", Log1pGrad);
Status SinhGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
// y = sinh(x)
// dy/dx = cosh(x)
auto dydx = Cosh(scope, op.input(0));
// grad(x) = grad(y) * conj(dy/dx)
grad_outputs->push_back(
Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
return scope.status();
}
REGISTER_GRADIENT_OP("Sinh", SinhGrad);
Status CoshGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
// y = cosh(x)
// dy/dx = sinh(x)
auto dydx = Sinh(scope, op.input(0));
// grad(x) = grad(y) * conj(dy/dx)
grad_outputs->push_back(
Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
return scope.status();
}
REGISTER_GRADIENT_OP("Cosh", CoshGrad);
Status TanhGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
// y = tanh(x)
// dy/dx = 1 - (tanh(x))^2 = 1 - y^2
// dx = dy * (1 - y^2)
auto y2 = Square(scope, op.output(0));
auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
auto dx = Mul(scope, grad_inputs[0], Sub(scope, one, y2));
grad_outputs->push_back(dx);
auto dydx = Sub(scope, one, y2);
// grad(x) = grad(y) * conj(dy/dx)
grad_outputs->push_back(
Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
return scope.status();
}
REGISTER_GRADIENT_OP("Tanh", TanhGrad);
@ -160,11 +208,13 @@ Status SigmoidGrad(const Scope& scope, const Operation& op,
std::vector<Output>* grad_outputs) {
// y = 1 / (1 + exp(-x))
// dy/dx = y * (1 - y)
// dx = dy * y * (1 - y)
auto y = op.output(0);
auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
auto dx = Mul(scope, grad_inputs[0], Mul(scope, y, Sub(scope, one, y)));
grad_outputs->push_back(dx);
auto dydx = Mul(scope, y, Sub(scope, one, y));
// dx = dy * y * (1 - y)
// grad(x) = grad(y) * conj(dy/dx)
grad_outputs->push_back(
Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
return scope.status();
}
REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);
@ -185,9 +235,10 @@ Status SinGrad(const Scope& scope, const Operation& op,
std::vector<Output>* grad_outputs) {
// y = sin(x)
// dy/dx = cos(x)
// dx = dy * cos(x)
auto dx = Mul(scope, grad_inputs[0], Cos(scope, op.input(0)));
grad_outputs->push_back(dx);
auto dydx = Cos(scope, op.input(0));
// grad(x) = grad(y) * conj(dy/dx)
grad_outputs->push_back(
Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
return scope.status();
}
REGISTER_GRADIENT_OP("Sin", SinGrad);
@ -197,9 +248,10 @@ Status CosGrad(const Scope& scope, const Operation& op,
std::vector<Output>* grad_outputs) {
// y = cos(x)
// dy/dx = -sin(x)
// dx = dy * -sin(x)
auto dx = Mul(scope, grad_inputs[0], Neg(scope, Sin(scope, op.input(0))));
grad_outputs->push_back(dx);
auto dydx = Neg(scope, Sin(scope, op.input(0)));
// grad(x) = grad(y) * conj(dy/dx)
grad_outputs->push_back(
Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
return scope.status();
}
REGISTER_GRADIENT_OP("Cos", CosGrad);
@ -208,12 +260,12 @@ Status AsinGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
// y = asin(x)
// dy/dx = 1 / (1 - x * x)^1/2
// dx = dy * (1 / (1 - x * x)^1/2)
// dy/dx = 1 / sqrt(1 - x^2)
auto x2 = Square(scope, op.input(0));
auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
auto dx = Mul(scope, grad_inputs[0], dydx);
// grad(x) = grad(y) * conj(dy/dx)
auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
grad_outputs->push_back(dx);
return scope.status();
}
@ -239,9 +291,9 @@ Status TanGrad(const Scope& scope, const Operation& op,
std::vector<Output>* grad_outputs) {
// y = tan(x)
// dy/dx = sec(x)^2 = 1 / cos(x)^2
// dx = dy * (1 / cos(x)^2)
auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
auto dx = Mul(scope, grad_inputs[0], dydx);
// grad(x) = grad(y) * conj(dy/dx)
auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
grad_outputs->push_back(dx);
return scope.status();
}
@ -324,7 +376,7 @@ Status MatMulGradCommon(const Scope& scope, const Operation& op,
const string& attr_adj_x, const string& attr_adj_y,
std::vector<Output>* grad_outputs) {
DataType dtype;
TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), "T", &dtype));
TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->attrs(), "T", &dtype));
if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
return errors::Unimplemented(
"MatMul gradient for complex data type is not supported yet.");
@ -332,8 +384,10 @@ Status MatMulGradCommon(const Scope& scope, const Operation& op,
bool ta;
bool tb;
TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_x, &ta));
TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_y, &tb));
TF_RETURN_IF_ERROR(
GetNodeAttr(op.output(0).node()->attrs(), attr_adj_x, &ta));
TF_RETURN_IF_ERROR(
GetNodeAttr(op.output(0).node()->attrs(), attr_adj_y, &tb));
if (!ta && !tb) {
return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),

View File

@ -45,6 +45,8 @@ class CWiseUnaryGradTest : public ::testing::Test {
EXPM1,
LOG,
LOG1P,
SINH,
COSH,
TANH,
SIGMOID,
SIGN,
@ -56,23 +58,25 @@ class CWiseUnaryGradTest : public ::testing::Test {
ATAN
};
void TestCWiseGrad(UnaryOpType op_type, std::function<float(int)> x_fn,
std::function<float(float)> dy_fn,
std::function<float(float, float)> dx_fn) {
Tensor x(DT_FLOAT, {2, 3, 2});
auto x_flat = x.flat<float>();
template <typename T>
void TestCWiseGrad(UnaryOpType op_type, const std::function<T(int)>& x_fn,
const std::function<T(const T&)>& dy_fn,
const std::function<T(const T&, const T&)>& dx_fn) {
DataType dtype = DataTypeToEnum<T>::v();
Tensor x(dtype, {2, 3, 2});
auto x_flat = x.flat<T>();
for (int i = 0; i < x_flat.size(); ++i) {
x_flat(i) = x_fn(i);
}
Tensor dy(DT_FLOAT, {2, 3, 2});
auto dy_flat = dy.flat<float>();
Tensor dy(dtype, {2, 3, 2});
auto dy_flat = dy.flat<T>();
for (int i = 0; i < dy_flat.size(); ++i) {
dy_flat(i) = dy_fn(x_flat(i));
}
Tensor dx(DT_FLOAT, {2, 3, 2});
auto dx_flat = dx.flat<float>();
Tensor dx(dtype, {2, 3, 2});
auto dx_flat = dx.flat<T>();
for (int i = 0; i < dx_flat.size(); ++i) {
dx_flat(i) = dx_fn(x_flat(i), dy_flat(i));
}
@ -109,6 +113,12 @@ class CWiseUnaryGradTest : public ::testing::Test {
case LOG1P:
y = Log1p(scope_, x);
break;
case SINH:
y = Sinh(scope_, x);
break;
case COSH:
y = Cosh(scope_, x);
break;
case TANH:
y = Tanh(scope_, x);
break;
@ -146,7 +156,19 @@ class CWiseUnaryGradTest : public ::testing::Test {
test::ExpectClose(output, dx);
}
float RV(std::vector<float> v) { return v[random::New64() % v.size()]; }
float RV(const std::vector<float>& v) {
return v[random::New64() % v.size()];
}
complex64 CRV(const std::vector<complex64>& v) {
return v[random::New64() % v.size()];
}
complex64 conjugate(const complex64& val) {
return complex64(val.real(), -val.imag());
}
const complex64 one_{1.0, 0};
Scope scope_;
};
@ -155,14 +177,14 @@ TEST_F(CWiseUnaryGradTest, Abs) {
auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
auto dx_fn = [this](const float x, const float dy) { return x * dy; };
TestCWiseGrad(ABS, x_fn, dy_fn, dx_fn);
TestCWiseGrad<float>(ABS, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Neg) {
auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
auto dx_fn = [this](const float x, const float dy) { return -dy; };
TestCWiseGrad(NEG, x_fn, dy_fn, dx_fn);
TestCWiseGrad<float>(NEG, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Reciprocal) {
@ -171,14 +193,36 @@ TEST_F(CWiseUnaryGradTest, Reciprocal) {
auto dx_fn = [this](const float x, const float dy) {
return -(1 / (x * x)) * dy;
};
TestCWiseGrad(INV, x_fn, dy_fn, dx_fn);
TestCWiseGrad<float>(INV, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Reciprocal_Complex) {
auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
auto dy_fn = [this](const complex64 x) {
return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
};
auto dx_fn = [this](const complex64 x, const complex64 dy) {
return -conjugate(one_ / (x * x)) * dy;
};
TestCWiseGrad<complex64>(INV, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Square) {
auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
auto dy_fn = [this](const float x) { return RV({0, -7, 7, -8, 8, -9, 9}); };
auto dx_fn = [this](const float x, const float dy) { return 2 * x * dy; };
TestCWiseGrad(SQUARE, x_fn, dy_fn, dx_fn);
TestCWiseGrad<float>(SQUARE, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Square_Complex) {
auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
auto dy_fn = [this](const complex64& x) {
return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
};
auto dx_fn = [this](const complex64& x, const complex64& dy) {
return conjugate(complex64(2, 0) * x) * dy;
};
TestCWiseGrad<complex64>(SQUARE, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Sqrt) {
@ -187,7 +231,18 @@ TEST_F(CWiseUnaryGradTest, Sqrt) {
auto dx_fn = [this](const float x, const float dy) {
return dy * 0.5 * (1.0 / std::sqrt(x));
};
TestCWiseGrad(SQRT, x_fn, dy_fn, dx_fn);
TestCWiseGrad<float>(SQRT, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Sqrt_Complex) {
auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
auto dy_fn = [this](const complex64& x) {
return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
};
auto dx_fn = [this](const complex64& x, const complex64& dy) {
return conjugate(complex64(0.5, 0) / std::sqrt(x)) * dy;
};
TestCWiseGrad<complex64>(SQRT, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Rsqrt) {
@ -196,7 +251,18 @@ TEST_F(CWiseUnaryGradTest, Rsqrt) {
auto dx_fn = [this](const float x, const float dy) {
return dy * -0.5 * (1 / std::sqrt(x)) * (1 / x);
};
TestCWiseGrad(RSQRT, x_fn, dy_fn, dx_fn);
TestCWiseGrad<float>(RSQRT, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Rsqrt_Complex) {
auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
auto dy_fn = [this](const complex64& x) {
return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
};
auto dx_fn = [this](const complex64& x, const complex64& dy) {
return conjugate(complex64(-0.5, 0) / std::sqrt(x) / x) * dy;
};
TestCWiseGrad<complex64>(RSQRT, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Exp) {
@ -205,7 +271,18 @@ TEST_F(CWiseUnaryGradTest, Exp) {
auto dx_fn = [this](const float x, const float dy) {
return dy * std::exp(x);
};
TestCWiseGrad(EXP, x_fn, dy_fn, dx_fn);
TestCWiseGrad<float>(EXP, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Exp_Complex) {
auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
auto dy_fn = [this](const complex64& x) {
return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
};
auto dx_fn = [this](const complex64& x, const complex64& dy) {
return dy * conjugate(std::exp(x));
};
TestCWiseGrad<complex64>(EXP, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Expm1) {
@ -214,14 +291,36 @@ TEST_F(CWiseUnaryGradTest, Expm1) {
auto dx_fn = [this](const float x, const float dy) {
return dy * std::exp(x);
};
TestCWiseGrad(EXPM1, x_fn, dy_fn, dx_fn);
TestCWiseGrad<float>(EXPM1, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Expm1_Complex) {
auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
auto dy_fn = [this](const complex64& x) {
return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
};
auto dx_fn = [this](const complex64& x, const complex64& dy) {
return dy * conjugate(std::exp(x));
};
TestCWiseGrad<complex64>(EXPM1, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Log) {
auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
auto dx_fn = [this](const float x, const float dy) { return dy * (1.0 / x); };
TestCWiseGrad(LOG, x_fn, dy_fn, dx_fn);
TestCWiseGrad<float>(LOG, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Log_Complex) {
auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
auto dy_fn = [this](const complex64& x) {
return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
};
auto dx_fn = [this](const complex64& x, const complex64& dy) {
return dy * conjugate(one_ / x);
};
TestCWiseGrad<complex64>(LOG, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Log1p) {
@ -230,7 +329,64 @@ TEST_F(CWiseUnaryGradTest, Log1p) {
auto dx_fn = [this](const float x, const float dy) {
return dy * (1.0 / (1.0 + x));
};
TestCWiseGrad(LOG1P, x_fn, dy_fn, dx_fn);
TestCWiseGrad<float>(LOG1P, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Log1p_Complex) {
auto x_fn = [this](const int i) {
return CRV({{0, 0}, {1e-6, 0}, {2, -1}, {1, 2}, {3, 4}});
};
auto dy_fn = [this](const complex64& x) {
return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
};
auto dx_fn = [this](const complex64& x, const complex64& dy) {
return dy / (one_ + conjugate(x));
};
TestCWiseGrad<complex64>(LOG1P, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Sinh) {
auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
auto dx_fn = [this](const float x, const float dy) {
return dy * std::cosh(x);
};
TestCWiseGrad<float>(SINH, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Sinh_Complex) {
auto x_fn = [this](const int i) {
return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
};
auto dy_fn = [this](const complex64& x) {
return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
};
auto dx_fn = [this](const complex64& x, const complex64& dy) {
return dy * conjugate(std::cosh(x));
};
TestCWiseGrad<complex64>(SINH, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Cosh) {
auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
auto dx_fn = [this](const float x, const float dy) {
return dy * std::sinh(x);
};
TestCWiseGrad<float>(COSH, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Cosh_Complex) {
auto x_fn = [this](const int i) {
return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
};
auto dy_fn = [this](const complex64& x) {
return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
};
auto dx_fn = [this](const complex64& x, const complex64& dy) {
return dy * conjugate(std::sinh(x));
};
TestCWiseGrad<complex64>(COSH, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Tanh) {
@ -240,7 +396,21 @@ TEST_F(CWiseUnaryGradTest, Tanh) {
const float y = std::tanh(x);
return dy * (1.0 - y * y);
};
TestCWiseGrad(TANH, x_fn, dy_fn, dx_fn);
TestCWiseGrad<float>(TANH, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Tanh_Complex) {
auto x_fn = [this](const int i) {
return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
};
auto dy_fn = [this](const complex64& x) {
return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
};
auto dx_fn = [this](const complex64& x, const complex64& dy) {
const complex64 y = std::tanh(x);
return dy * conjugate((one_ - y * y));
};
TestCWiseGrad<complex64>(TANH, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Sigmoid) {
@ -250,14 +420,28 @@ TEST_F(CWiseUnaryGradTest, Sigmoid) {
const float y = 1.0 / (1.0 + std::exp(-x));
return dy * y * (1.0 - y);
};
TestCWiseGrad(SIGMOID, x_fn, dy_fn, dx_fn);
TestCWiseGrad<float>(SIGMOID, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Sigmoid_Complex) {
auto x_fn = [this](const int i) {
return CRV({{1, 0}, {0, 0}, {2, -1}, {1, 2}, {3, 4}});
};
auto dy_fn = [this](const complex64& x) {
return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
};
auto dx_fn = [this](const complex64& x, const complex64& dy) {
const complex64 y = one_ / (one_ + std::exp(-x));
return dy * conjugate(y * (one_ - y));
};
TestCWiseGrad<complex64>(SIGMOID, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Sign) {
auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
auto dx_fn = [this](const float x, const float dy) { return 0.0; };
TestCWiseGrad(SIGN, x_fn, dy_fn, dx_fn);
TestCWiseGrad<float>(SIGN, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Sin) {
@ -266,7 +450,20 @@ TEST_F(CWiseUnaryGradTest, Sin) {
auto dx_fn = [this](const float x, const float dy) {
return dy * std::cos(x);
};
TestCWiseGrad(SIN, x_fn, dy_fn, dx_fn);
TestCWiseGrad<float>(SIN, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Sin_Complex) {
auto x_fn = [this](const int i) {
return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
};
auto dy_fn = [this](const complex64& x) {
return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
};
auto dx_fn = [this](const complex64& x, const complex64& dy) {
return dy * conjugate(std::cos(x));
};
TestCWiseGrad<complex64>(SIN, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Cos) {
@ -275,7 +472,20 @@ TEST_F(CWiseUnaryGradTest, Cos) {
auto dx_fn = [this](const float x, const float dy) {
return dy * -1.0 * std::sin(x);
};
TestCWiseGrad(COS, x_fn, dy_fn, dx_fn);
TestCWiseGrad<float>(COS, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Cos_Complex) {
auto x_fn = [this](const int i) {
return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
};
auto dy_fn = [this](const complex64& x) {
return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
};
auto dx_fn = [this](const complex64& x, const complex64& dy) {
return dy * conjugate(-std::sin(x));
};
TestCWiseGrad<complex64>(COS, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Asin) {
@ -284,7 +494,24 @@ TEST_F(CWiseUnaryGradTest, Asin) {
auto dx_fn = [this](const float x, const float dy) {
return dy * (1.0 / std::sqrt(1.0 - x * x));
};
TestCWiseGrad(ASIN, x_fn, dy_fn, dx_fn);
TestCWiseGrad<float>(ASIN, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Asin_Complex) {
auto x_fn = [this](const int i) {
return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
};
auto dy_fn = [this](const complex64& x) {
return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
};
auto dx_fn = [this](const complex64& x, const complex64& dy) {
return dy / conjugate(std::sqrt(one_ - x * x));
};
// TODO(kbsriram)
// Enable test when the asin kernel supports complex numbers
if (false) {
TestCWiseGrad<complex64>(ASIN, x_fn, dy_fn, dx_fn);
}
}
TEST_F(CWiseUnaryGradTest, Acos) {
@ -293,7 +520,24 @@ TEST_F(CWiseUnaryGradTest, Acos) {
auto dx_fn = [this](const float x, const float dy) {
return dy * (-1.0 / std::sqrt(1.0 - x * x));
};
TestCWiseGrad(ACOS, x_fn, dy_fn, dx_fn);
TestCWiseGrad<float>(ACOS, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Acos_Complex) {
auto x_fn = [this](const int i) {
return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
};
auto dy_fn = [this](const complex64& x) {
return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
};
auto dx_fn = [this](const complex64& x, const complex64& dy) {
return dy / -conjugate(std::sqrt(one_ - x * x));
};
// TODO(kbsriram)
// Add test when the acos kernel supports complex numbers
if (false) {
TestCWiseGrad<complex64>(ACOS, x_fn, dy_fn, dx_fn);
}
}
TEST_F(CWiseUnaryGradTest, Tan) {
@ -303,7 +547,25 @@ TEST_F(CWiseUnaryGradTest, Tan) {
const float cosx = std::cos(x);
return dy * (1 / (cosx * cosx));
};
TestCWiseGrad(TAN, x_fn, dy_fn, dx_fn);
TestCWiseGrad<float>(TAN, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Tan_Complex) {
auto x_fn = [this](const int i) {
return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
};
auto dy_fn = [this](const complex64& x) {
return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
};
auto dx_fn = [this](const complex64& x, const complex64& dy) {
const complex64 cosx = std::cos(x);
return dy / conjugate(cosx * cosx);
};
// TODO(kbsriram)
// Enable when tan kernel supports complex inputs
if (false) {
TestCWiseGrad<complex64>(TAN, x_fn, dy_fn, dx_fn);
}
}
TEST_F(CWiseUnaryGradTest, Atan) {
@ -312,7 +574,24 @@ TEST_F(CWiseUnaryGradTest, Atan) {
auto dx_fn = [this](const float x, const float dy) {
return dy * (1 / (1 + x * x));
};
TestCWiseGrad(ATAN, x_fn, dy_fn, dx_fn);
TestCWiseGrad<float>(ATAN, x_fn, dy_fn, dx_fn);
}
TEST_F(CWiseUnaryGradTest, Atan_Complex) {
auto x_fn = [this](const int i) {
return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
};
auto dy_fn = [this](const complex64& x) {
return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
};
auto dx_fn = [this](const complex64& x, const complex64& dy) {
return dy / (one_ + x * x);
};
// TODO(kbsriram)
// Add test when the atan kernel supports complex numbers
if (false) {
TestCWiseGrad<complex64>(ATAN, x_fn, dy_fn, dx_fn);
}
}
class CWiseUnaryComplexGradTest : public ::testing::Test {

View File

@ -23,6 +23,9 @@ limitations under the License.
namespace tensorflow {
namespace ops {
/// @defgroup const_op Const Op
/// @{
Output Const(const Scope& scope, const Input::Initializer& val);
NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp);
@ -70,6 +73,8 @@ Output Const(const Scope& scope, const std::initializer_list<T>& v,
std::vector<NodeBuilder::NodeOut> AsNodeOutList(const Scope& scope,
const InputList& inp);
/// }@
} // namespace ops
} // namespace tensorflow

View File

@ -28,9 +28,9 @@ void ExpectNodeEqual(const Node* n, gtl::ArraySlice<T> values,
TensorShape shape) {
EXPECT_TRUE(n->IsConstant());
Tensor tensor;
TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor));
TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor));
DataType dtype;
TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype));
TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype));
EXPECT_EQ(tensor.dtype(), dtype);
test::ExpectTensorEqual<T>(tensor, test::AsTensor(values, shape));
}
@ -39,9 +39,9 @@ void ExpectTypeAndShape(const Node* n, DataType expected_dtype,
TensorShape expected_shape) {
EXPECT_TRUE(n->IsConstant());
Tensor tensor;
TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor));
TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor));
DataType dtype;
TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype));
TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype));
EXPECT_EQ(dtype, expected_dtype);
EXPECT_EQ(expected_shape, TensorShape(tensor.shape()));
}

View File

@ -22,7 +22,7 @@ op { name: "Where" input_rename: { from: "input" to: "condition" } }
op { name: "ThreadUnsafeUnigramCandidateSampler", skip: true }
# control_flow_ops
# TODO(josh11b): Hide Switch and Merge once we write and migrate users to
# TODO(joshl): Hide Switch and Merge once we write and migrate users to
# a Cond() API.
#op { name: "Switch" hide: true }
#op { name: "Merge" hide: true }
@ -150,6 +150,12 @@ op { name: "Neg" rename_to: "Negate" alias: "Neg" }
op { name: "Prod" alias: "ReduceProd" input_rename: { from: "reduction_indices" to: "axis" } }
op { name: "Sub" rename_to: "Subtract" alias: "Sub" }
op { name: "Sum" alias: "ReduceSum" input_rename: { from: "reduction_indices" to: "axis" } }
op { name: "SigmoidGrad" hide: true }
op { name: "TanhGrad" hide: true }
op { name: "InvGrad" hide: true }
op { name: "ReciprocalGrad" hide: true }
op { name: "SqrtGrad" hide: true }
op { name: "RsqrtGrad" hide: true }
# *Grad ops get hidden, only for use by the gradient code.
op { name: "SigmoidGrad" hide: true }

View File

@ -9,7 +9,14 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load(
"//tensorflow:tensorflow.bzl",
"if_android",
"if_ios",
"if_mobile",
"if_not_mobile",
"tf_cc_test",
)
cc_library(
name = "constants",
@ -28,17 +35,33 @@ cc_library(
cc_library(
name = "loader",
hdrs = ["loader.h"],
deps = [
":loader_lite",
] + if_not_mobile([
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
]) + if_android([
"//tensorflow/core:android_tensorflow_lib",
]),
)
cc_library(
name = "loader_lite",
srcs = ["loader.cc"],
hdrs = ["loader.h"],
deps = [
":constants",
] + if_not_mobile([
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/core/util/tensor_bundle:naming",
],
# mobile not supported yet
]),
)
tf_cc_test(
@ -66,6 +89,7 @@ filegroup(
name = "saved_model_half_plus_two",
srcs = glob([
"testdata/half_plus_two_pbtxt/**",
"testdata/half_plus_two_main_op/**",
"testdata/half_plus_two/**",
]),
)

View File

@ -33,6 +33,9 @@ constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt";
/// SavedModel legacy init op key.
constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op";
/// SavedModel main op key.
constexpr char kSavedModelMainOpKey[] = "saved_model_main_op";
/// Directory in which to save the SavedModel variables.
constexpr char kSavedModelVariablesDirectory[] = "variables";

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/tensor_bundle/naming.h"
@ -36,7 +37,7 @@ auto* load_attempt_count = monitoring::Counter<2>::New(
"status");
auto* load_latency = monitoring::Counter<1>::New(
"/tensorflow/cc/saved_model/load_latency",
"Latency in microseconds for SavedModels that were succesfully loaded.",
"Latency in microseconds for SavedModels that were successfully loaded.",
"model_path");
constexpr char kLoadAttemptFail[] = "fail";
constexpr char kLoadAttemptSuccess[] = "success";
@ -106,6 +107,37 @@ void AddAssetsTensorsToInputs(const StringPiece export_dir,
}
}
bool HasMainOp(const MetaGraphDef& meta_graph_def) {
const auto& collection_def_map = meta_graph_def.collection_def();
if (collection_def_map.find(kSavedModelMainOpKey) !=
collection_def_map.end()) {
return true;
}
return false;
}
Status RunMainOp(const RunOptions& run_options, const string& export_dir,
const MetaGraphDef& meta_graph_def,
const std::vector<AssetFileDef>& asset_file_defs,
Session* session) {
LOG(INFO) << "Running MainOp on SavedModel bundle.";
const auto& collection_def_map = meta_graph_def.collection_def();
const auto main_op_it = collection_def_map.find(kSavedModelMainOpKey);
if (main_op_it != collection_def_map.end()) {
if (main_op_it->second.node_list().value_size() != 1) {
return errors::FailedPrecondition(
strings::StrCat("Expected exactly one main op in : ", export_dir));
}
std::vector<std::pair<string, Tensor>> inputs;
AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
RunMetadata run_metadata;
const StringPiece main_op_name = main_op_it->second.node_list().value(0);
return session->Run(run_options, inputs, {}, {main_op_name.ToString()},
nullptr /* outputs */, &run_metadata);
}
return Status::OK();
}
Status RunRestore(const RunOptions& run_options, const string& export_dir,
const StringPiece restore_op_name,
const StringPiece variable_filename_const_op_name,
@ -121,8 +153,9 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
const string variables_index_path = io::JoinPath(
variables_directory, MetaFilename(kSavedModelVariablesFilename));
if (!Env::Default()->FileExists(variables_index_path).ok()) {
return errors::NotFound(
"Checkpoint index file not found in SavedModel directory.");
LOG(INFO) << "The specified SavedModel has no variables; no checkpoints "
"were restored.";
return Status::OK();
}
const string variables_path =
io::JoinPath(variables_directory, kSavedModelVariablesFilename);
@ -210,11 +243,15 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
bundle->meta_graph_def.saver_def().restore_op_name(),
bundle->meta_graph_def.saver_def().filename_tensor_name(),
asset_file_defs, bundle->session.get()));
// TODO(sukritiramesh): Add support for a single main op to run upon load,
// which will supersede the legacy_init_op and separate RunRestore.
TF_RETURN_IF_ERROR(RunLegacyInitOp(run_options, export_dir,
bundle->meta_graph_def, asset_file_defs,
bundle->session.get()));
if (HasMainOp(bundle->meta_graph_def)) {
TF_RETURN_IF_ERROR(RunMainOp(run_options, export_dir,
bundle->meta_graph_def, asset_file_defs,
bundle->session.get()));
} else {
TF_RETURN_IF_ERROR(RunLegacyInitOp(run_options, export_dir,
bundle->meta_graph_def, asset_file_defs,
bundle->session.get()));
}
return Status::OK();
}

View File

@ -36,7 +36,7 @@ struct SavedModelBundle {
/// resource leaks, we explicitly call Close on Sessions that we create.
~SavedModelBundle() {
if (session) {
session->Close();
session->Close().IgnoreError();
}
}

View File

@ -31,6 +31,8 @@ namespace {
constexpr char kTestDataPbTxt[] =
"cc/saved_model/testdata/half_plus_two_pbtxt/00000123";
constexpr char kTestDataMainOp[] =
"cc/saved_model/testdata/half_plus_two_main_op/00000123";
constexpr char kTestDataSharded[] =
"cc/saved_model/testdata/half_plus_two/00000123";
@ -165,6 +167,18 @@ TEST_F(LoaderTest, PbtxtFormat) {
CheckSavedModelBundle(export_dir, bundle);
}
TEST_F(LoaderTest, MainOpFormat) {
SavedModelBundle bundle;
SessionOptions session_options;
RunOptions run_options;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataMainOp);
TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe}, &bundle));
CheckSavedModelBundle(export_dir, bundle);
}
TEST_F(LoaderTest, InvalidExportPath) {
SavedModelBundle bundle;
RunOptions run_options;

View File

@ -0,0 +1 @@
asset-file-contents

View File

@ -284,6 +284,7 @@ meta_graphs {
type: "shape"
default_value {
shape {
unknown_rank: true
}
}
}
@ -447,7 +448,7 @@ meta_graphs {
}
}
tags: "serve"
tensorflow_version: "0.12.head"
tensorflow_version: "1.1.0-rc2"
tensorflow_git_version: "unknown"
}
graph_def {
@ -885,6 +886,7 @@ meta_graphs {
key: "shape"
value {
shape {
unknown_rank: true
}
}
}
@ -1714,7 +1716,7 @@ meta_graphs {
dtype: DT_STRING
tensor_shape {
}
string_val: "_temp_aeab824a1fc94305a10a2504f5995de2/part"
string_val: "_temp_80e928f1e0c844239d136d1ea966099d/part"
}
}
}
@ -2444,7 +2446,7 @@ meta_graphs {
input: "^save/restore_shard"
}
versions {
producer: 21
producer: 23
}
}
saver_def {
@ -2503,6 +2505,42 @@ meta_graphs {
}
}
}
signature_def {
key: "classify_x2_to_y3"
value {
inputs {
key: "inputs"
value {
name: "x2:0"
dtype: DT_FLOAT
tensor_shape {
dim {
size: -1
}
dim {
size: 1
}
}
}
}
outputs {
key: "scores"
value {
name: "y3:0"
dtype: DT_FLOAT
tensor_shape {
dim {
size: -1
}
dim {
size: 1
}
}
}
}
method_name: "tensorflow/serving/classify"
}
}
signature_def {
key: "classify_x_to_y"
value {

View File

@ -31,8 +31,8 @@ Coordinator::Coordinator(const std::vector<error::Code>& clean_stop_errors)
}
Coordinator::~Coordinator() {
RequestStop();
Join();
RequestStop().IgnoreError();
Join().IgnoreError();
}
Status Coordinator::RegisterRunner(std::unique_ptr<RunnerInterface> runner) {
@ -115,4 +115,15 @@ void Coordinator::WaitForStop() {
}
}
} // namespace
Status Coordinator::ExportCostGraph(CostGraphDef* cost_graph) const {
mutex_lock l(runners_lock_);
for (auto& t : runners_) {
Status s = t->ExportCostGraph(cost_graph);
if (!s.ok()) {
return s;
}
}
return Status::OK();
}
} // namespace tensorflow

View File

@ -21,19 +21,24 @@ limitations under the License.
#include <unordered_set>
#include <vector>
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
/// The abstract interface for runners which must implement the Join function.
/// The abstract interface for runners which must implement the Join and the
/// IsRunning function.
class RunnerInterface {
public:
virtual ~RunnerInterface() {}
virtual Status Join() = 0;
virtual Status ExportCostGraph(CostGraphDef* cost_graph) const {
return Status(error::INVALID_ARGUMENT, "No cost model to export.");
}
/// Returns true iff the runner is running, i.e. if it is trying to populate
/// its queue.
virtual bool IsRunning() const = 0;
@ -101,6 +106,9 @@ class Coordinator {
/// RequestStop() is called.
void WaitForStop();
// Returns the cost graph from stored run metadata in registered runners.
Status ExportCostGraph(CostGraphDef* cost_graph) const;
private:
std::unordered_set<int> clean_stop_errors_;
condition_variable wait_for_stop_;
@ -111,12 +119,10 @@ class Coordinator {
mutex status_lock_;
Status status_ GUARDED_BY(status_lock_);
mutex runners_lock_;
mutable mutex runners_lock_;
std::vector<std::unique_ptr<RunnerInterface>> runners_
GUARDED_BY(runners_lock_);
std::atomic<int> num_runners_to_cancel_;
TF_DISALLOW_COPY_AND_ASSIGN(Coordinator);
};

View File

@ -29,9 +29,10 @@ namespace {
using error::Code;
void WaitForStopThread(Coordinator* coord, bool* stopped, Notification* done) {
void WaitForStopThread(Coordinator* coord, Notification* about_to_wait,
Notification* done) {
about_to_wait->Notify();
coord->WaitForStop();
*stopped = true;
done->Notify();
}
@ -39,22 +40,22 @@ TEST(CoordinatorTest, TestStopAndWaitOnStop) {
Coordinator coord;
EXPECT_EQ(coord.ShouldStop(), false);
bool stopped = false;
Notification about_to_wait;
Notification done;
Env::Default()->SchedClosure(
std::bind(&WaitForStopThread, &coord, &stopped, &done));
Env::Default()->SleepForMicroseconds(10000000);
EXPECT_EQ(stopped, false);
std::bind(&WaitForStopThread, &coord, &about_to_wait, &done));
about_to_wait.WaitForNotification();
Env::Default()->SleepForMicroseconds(1000 * 1000);
EXPECT_FALSE(done.HasBeenNotified());
coord.RequestStop();
TF_EXPECT_OK(coord.RequestStop());
done.WaitForNotification();
EXPECT_EQ(stopped, true);
EXPECT_EQ(coord.ShouldStop(), true);
EXPECT_TRUE(coord.ShouldStop());
}
class MockQueueRunner : public RunnerInterface {
public:
MockQueueRunner(Coordinator* coord) {
explicit MockQueueRunner(Coordinator* coord) {
coord_ = coord;
join_counter_ = nullptr;
thread_pool_.reset(new thread::ThreadPool(Env::Default(), "test-pool", 10));
@ -66,17 +67,19 @@ class MockQueueRunner : public RunnerInterface {
join_counter_ = join_counter;
}
void StartCounting(std::atomic<int>* counter, int until) {
void StartCounting(std::atomic<int>* counter, int until,
Notification* start = nullptr) {
thread_pool_->Schedule(
std::bind(&MockQueueRunner::CountThread, this, counter, until));
std::bind(&MockQueueRunner::CountThread, this, counter, until, start));
}
void StartSettingStatus(const Status& status, BlockingCounter* counter) {
thread_pool_->Schedule(
std::bind(&MockQueueRunner::SetStatusThread, this, status, counter));
void StartSettingStatus(const Status& status, BlockingCounter* counter,
Notification* start) {
thread_pool_->Schedule(std::bind(&MockQueueRunner::SetStatusThread, this,
status, counter, start));
}
Status Join() {
Status Join() override {
if (join_counter_ != nullptr) {
(*join_counter_)++;
}
@ -93,15 +96,17 @@ class MockQueueRunner : public RunnerInterface {
void Stop() { stopped_ = true; }
private:
void CountThread(std::atomic<int>* counter, int until) {
void CountThread(std::atomic<int>* counter, int until, Notification* start) {
if (start != nullptr) start->WaitForNotification();
while (!coord_->ShouldStop() && counter->load() < until) {
(*counter)++;
Env::Default()->SleepForMicroseconds(100000);
Env::Default()->SleepForMicroseconds(10 * 1000);
}
coord_->RequestStop();
coord_->RequestStop().IgnoreError();
}
void SetStatusThread(const Status& status, BlockingCounter* counter) {
Env::Default()->SleepForMicroseconds(100000);
void SetStatusThread(const Status& status, BlockingCounter* counter,
Notification* start) {
start->WaitForNotification();
SetStatus(status);
counter->DecrementCount();
}
@ -118,19 +123,19 @@ TEST(CoordinatorTest, TestRealStop) {
std::unique_ptr<MockQueueRunner> qr1(new MockQueueRunner(&coord));
qr1->StartCounting(&counter, 100);
coord.RegisterRunner(std::move(qr1));
TF_ASSERT_OK(coord.RegisterRunner(std::move(qr1)));
std::unique_ptr<MockQueueRunner> qr2(new MockQueueRunner(&coord));
qr2->StartCounting(&counter, 100);
coord.RegisterRunner(std::move(qr2));
TF_ASSERT_OK(coord.RegisterRunner(std::move(qr2)));
// Wait until the counting has started
while (counter.load() == 0)
;
coord.RequestStop();
TF_EXPECT_OK(coord.RequestStop());
int temp_counter = counter.load();
Env::Default()->SleepForMicroseconds(10000000);
Env::Default()->SleepForMicroseconds(1000 * 1000);
EXPECT_EQ(temp_counter, counter.load());
TF_EXPECT_OK(coord.Join());
}
@ -138,12 +143,14 @@ TEST(CoordinatorTest, TestRealStop) {
TEST(CoordinatorTest, TestRequestStop) {
Coordinator coord;
std::atomic<int> counter(0);
Notification start;
std::unique_ptr<MockQueueRunner> qr;
for (int i = 0; i < 10; i++) {
qr.reset(new MockQueueRunner(&coord));
qr->StartCounting(&counter, 10);
coord.RegisterRunner(std::move(qr));
qr->StartCounting(&counter, 10, &start);
TF_ASSERT_OK(coord.RegisterRunner(std::move(qr)));
}
start.Notify();
coord.WaitForStop();
EXPECT_EQ(coord.ShouldStop(), true);
@ -156,41 +163,43 @@ TEST(CoordinatorTest, TestJoin) {
int join_counter = 0;
std::unique_ptr<MockQueueRunner> qr1(
new MockQueueRunner(&coord, &join_counter));
coord.RegisterRunner(std::move(qr1));
TF_ASSERT_OK(coord.RegisterRunner(std::move(qr1)));
std::unique_ptr<MockQueueRunner> qr2(
new MockQueueRunner(&coord, &join_counter));
coord.RegisterRunner(std::move(qr2));
TF_ASSERT_OK(coord.RegisterRunner(std::move(qr2)));
coord.RequestStop();
TF_EXPECT_OK(coord.RequestStop());
TF_EXPECT_OK(coord.Join());
EXPECT_EQ(join_counter, 2);
}
TEST(CoordinatorTest, StatusReporting) {
Coordinator coord({Code::CANCELLED, Code::OUT_OF_RANGE});
Notification start;
BlockingCounter counter(3);
std::unique_ptr<MockQueueRunner> qr1(new MockQueueRunner(&coord));
qr1->StartSettingStatus(Status(Code::CANCELLED, ""), &counter);
coord.RegisterRunner(std::move(qr1));
qr1->StartSettingStatus(Status(Code::CANCELLED, ""), &counter, &start);
TF_ASSERT_OK(coord.RegisterRunner(std::move(qr1)));
std::unique_ptr<MockQueueRunner> qr2(new MockQueueRunner(&coord));
qr2->StartSettingStatus(Status(Code::INVALID_ARGUMENT, ""), &counter);
coord.RegisterRunner(std::move(qr2));
qr2->StartSettingStatus(Status(Code::INVALID_ARGUMENT, ""), &counter, &start);
TF_ASSERT_OK(coord.RegisterRunner(std::move(qr2)));
std::unique_ptr<MockQueueRunner> qr3(new MockQueueRunner(&coord));
qr3->StartSettingStatus(Status(Code::OUT_OF_RANGE, ""), &counter);
coord.RegisterRunner(std::move(qr3));
qr3->StartSettingStatus(Status(Code::OUT_OF_RANGE, ""), &counter, &start);
TF_ASSERT_OK(coord.RegisterRunner(std::move(qr3)));
start.Notify();
counter.Wait();
coord.RequestStop();
TF_EXPECT_OK(coord.RequestStop());
EXPECT_EQ(coord.Join().code(), Code::INVALID_ARGUMENT);
}
TEST(CoordinatorTest, JoinWithoutStop) {
Coordinator coord;
std::unique_ptr<MockQueueRunner> qr(new MockQueueRunner(&coord));
coord.RegisterRunner(std::move(qr));
TF_ASSERT_OK(coord.RegisterRunner(std::move(qr)));
EXPECT_EQ(coord.Join().code(), Code::FAILED_PRECONDITION);
}
@ -198,7 +207,7 @@ TEST(CoordinatorTest, JoinWithoutStop) {
TEST(CoordinatorTest, AllRunnersStopped) {
Coordinator coord;
MockQueueRunner* qr = new MockQueueRunner(&coord);
coord.RegisterRunner(std::unique_ptr<RunnerInterface>(qr));
TF_ASSERT_OK(coord.RegisterRunner(std::unique_ptr<RunnerInterface>(qr)));
EXPECT_FALSE(coord.AllRunnersStopped());
qr->Stop();

View File

@ -49,7 +49,12 @@ Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) {
enqueue_op_names_.insert(enqueue_op_names_.end(),
queue_runner_def.enqueue_op_name().begin(),
queue_runner_def.enqueue_op_name().end());
runs_ = enqueue_op_names_.size();
size_t op_names_size = enqueue_op_names_.size();
if (op_names_size > kint32max) {
return Status(error::INVALID_ARGUMENT,
"Enqueue ops to run cannot exceed kint32max");
}
runs_ = static_cast<int>(op_names_size);
if (runs_ == 0) {
return Status(error::INVALID_ARGUMENT, "Empty enqueue ops to run.");
}
@ -77,11 +82,17 @@ Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) {
QueueRunner::~QueueRunner() {
// Cannot run Stop() here because the session might already be closed or
// destroyed.
Join();
Join().IgnoreError();
}
Status QueueRunner::Start(Session* sess) { return Start(sess, 0); }
Status QueueRunner::StartAndCollectCostGraph(Session* sess,
const RunOptions* run_options) {
SetRunArgumentsAndCostGraph(run_options);
return Start(sess, 0);
}
Status QueueRunner::Start(Session* sess, int wait_for) {
counter_.reset(new BlockingCounter(runs_));
for (const string& enqueue_op : enqueue_op_names_) {
@ -109,12 +120,18 @@ Status QueueRunner::Start(Session* sess, int wait_for) {
return Status::OK();
}
Status QueueRunner::StartAndCollectCostGraph(Session* session, int wait_for_ms,
const RunOptions* run_options) {
SetRunArgumentsAndCostGraph(run_options);
return Start(session, wait_for_ms);
}
void QueueRunner::Stop(Session* sess) {
if (coord_ != nullptr) {
coord_->WaitForStop();
}
if (!cancel_op_name_.empty()) {
UpdateStatus(sess->Run({}, {}, {cancel_op_name_}, nullptr));
UpdateStatus(RealRun(sess, cancel_op_name_, false));
}
stopped_ = true;
}
@ -149,7 +166,7 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) {
if (coord_ && coord_->ShouldStop()) {
break;
}
status = sess->Run({}, {}, {enqueue_op}, nullptr);
status = RealRun(sess, enqueue_op, true);
if (first_iteration) {
if (!status.ok()) {
mutex_lock l(mu_);
@ -170,12 +187,14 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) {
// will be run anway in this case.
if (IsQueueClosed(status) && (!coord_ || !coord_->ShouldStop())) {
if (last_run && !close_op_name_.empty()) {
UpdateStatus(sess->Run({}, {}, {close_op_name_}, nullptr));
UpdateStatus(RealRun(sess, close_op_name_, false));
}
} else if (!status.ok()) {
LOG(ERROR) << "Queue runner thread got a failure status: "
<< status.ToString();
UpdateStatus(status);
if (coord_) {
coord_->RequestStop();
coord_->RequestStop().IgnoreError();
}
}
}
@ -185,4 +204,39 @@ Status QueueRunner::GetStatus() {
return status_;
}
Status QueueRunner::ExportCostGraph(CostGraphDef* cost_graph) const {
if (!cg_mu_) {
return Status(error::FAILED_PRECONDITION,
"This QueueRunner doesn't collect a cost graph.");
}
mutex_lock l(*cg_mu_);
cost_graph->MergeFrom(*cost_graph_);
return Status::OK();
}
void QueueRunner::SetRunArgumentsAndCostGraph(const RunOptions* run_options) {
cg_mu_.reset(new mutex());
{
mutex_lock l(*cg_mu_);
cost_graph_.reset(new CostGraphDef());
}
if (run_options) {
run_options_ = *run_options;
}
}
Status QueueRunner::RealRun(Session* sess, const string& op,
bool update_costs) {
Status s;
if (update_costs && cg_mu_) {
RunMetadata metadata;
s = sess->Run(run_options_, {}, {}, {op}, nullptr, &metadata);
mutex_lock l(*cg_mu_);
cost_graph_->Swap(metadata.mutable_cost_graph());
} else {
s = sess->Run({}, {}, {op}, nullptr);
}
return s;
}
} // namespace tensorflow

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/queue_runner.pb.h"
#include "tensorflow/core/public/session.h"
@ -58,9 +59,16 @@ class QueueRunner : public RunnerInterface {
/// Starts the queue runner with the given session.
Status Start(Session* sess);
/// Starts the queue runner with the given session and sets the run arguments
/// for sess->Run. It also collects and stores the cost model.
Status StartAndCollectCostGraph(Session* sess,
const RunOptions* run_options = nullptr);
/// Starts the queue runner with the given session, and wait for up to the
/// specified time (in milliseconds) for the queues to start to fill up.
Status Start(Session* sess, int wait_for_ms);
Status StartAndCollectCostGraph(Session* session, int wait_for_ms,
const RunOptions* run_options = nullptr);
/// Requests to stop and runs the cancel op. It would be called in a separate
/// thread when coordinator is set. If there is no coordinator it should be
@ -74,8 +82,11 @@ class QueueRunner : public RunnerInterface {
/// Returns the latest status.
Status GetStatus();
// Returns the stored cost model.
Status ExportCostGraph(CostGraphDef* cost_graph) const override;
private:
QueueRunner() : coord_(nullptr), stopped_(false) {}
QueueRunner() : coord_(nullptr), stopped_(false), cg_mu_(nullptr) {}
// Initializes the instance with the QueueRunnerDef proto.
Status Init(const QueueRunnerDef& queue_runner_def);
@ -94,6 +105,10 @@ class QueueRunner : public RunnerInterface {
bool IsRunning() const override { return !stopped_; }
void SetRunArgumentsAndCostGraph(const RunOptions* run_options);
Status RealRun(Session* sess, const string& op, bool update_costs);
string queue_name_;
std::vector<string> enqueue_op_names_;
string close_op_name_;
@ -114,6 +129,10 @@ class QueueRunner : public RunnerInterface {
mutex cb_mu_;
std::vector<std::function<void(Status)>> callbacks_;
mutable std::unique_ptr<mutex> cg_mu_;
std::unique_ptr<CostGraphDef> cost_graph_ GUARDED_BY(cg_mu_);
RunOptions run_options_;
};
} // namespace tensorflow

View File

@ -44,6 +44,7 @@ using ops::FIFOQueue;
using ops::QueueClose;
using ops::QueueDequeue;
using ops::QueueEnqueue;
using ops::RandomNormal;
using ops::Square;
using ops::Variable;
@ -84,7 +85,7 @@ QueueRunnerDef BuildQueueRunnerDef(
const std::string& close_op, const std::string& cancel_op,
const std::vector<Code>& queue_closed_error_codes) {
QueueRunnerDef queue_runner_def;
*queue_runner_def.mutable_queue_name() = kQueueName;
*queue_runner_def.mutable_queue_name() = queue_name;
for (const std::string& enqueue_op : enqueue_ops) {
*queue_runner_def.mutable_enqueue_op_name()->Add() = enqueue_op;
}
@ -293,7 +294,7 @@ TEST(QueueRunnerTest, StartTimeout) {
// This will timeout since queue0 is not fed and queue1 is fetching data from
// queue0.
EXPECT_EQ(qr->Start(session.get(), 1).code(), Code::DEADLINE_EXCEEDED);
session->Close();
TF_EXPECT_OK(session->Close());
}
TEST(QueueRunnerTest, TestCoordinatorStop) {
@ -317,8 +318,8 @@ TEST(QueueRunnerTest, TestCoordinatorStop) {
TF_EXPECT_OK(QueueRunner::New(queue_runner1, &coord, &qr1));
TF_CHECK_OK(qr1->Start(session.get()));
coord.RegisterRunner(std::move(qr0));
coord.RegisterRunner(std::move(qr1));
TF_EXPECT_OK(coord.RegisterRunner(std::move(qr0)));
TF_EXPECT_OK(coord.RegisterRunner(std::move(qr1)));
std::vector<Tensor> dq;
TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq));
@ -340,9 +341,70 @@ TEST(QueueRunnerTest, CallbackCalledOnError) {
bool error_caught = false;
qr->AddErrorCallback([&error_caught](const Status&) { error_caught = true; });
TF_EXPECT_OK(qr->Start(session.get()));
qr->Join();
EXPECT_FALSE(qr->Join().ok());
EXPECT_TRUE(error_caught);
}
TEST(QueueRunnerTest, RunMetaDataTest) {
Scope root = Scope::NewRootScope();
auto q0 = FIFOQueue(root.WithOpName(kQueueName), {DataType::DT_FLOAT});
Output rnd = RandomNormal(root.WithOpName("rnd"), {1, 1}, DataType::DT_FLOAT);
Output square = Square(root.WithOpName(kSquareOpName), rnd);
auto enqueue0 = QueueEnqueue(root.WithOpName(kEnqueueOp0), q0, {square});
auto close0 = QueueClose(root.WithOpName(kCloseOp0), q0);
auto cancel0 = QueueClose(root.WithOpName(kCancelOp0), q0,
QueueClose::CancelPendingEnqueues(true));
auto dequeue0 =
QueueDequeue(root.WithOpName(kDequeueOp0), q0, {DataType::DT_FLOAT});
GraphDef graph_def;
TF_EXPECT_OK(root.ToGraphDef(&graph_def));
for (auto& node : *graph_def.mutable_node()) {
node.set_device("/cpu:0");
}
SessionOptions sess_options;
sess_options.config.mutable_graph_options()->set_build_cost_model(1);
std::unique_ptr<Session> session(NewSession(sess_options));
TF_CHECK_OK(session->Create(graph_def));
QueueRunnerDef queue_runner_def =
BuildQueueRunnerDef(kQueueName, {kEnqueueOp0}, kCloseOp0, kCancelOp0, {});
std::unique_ptr<QueueRunner> qr;
TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
RunOptions run_options;
TF_CHECK_OK(qr->StartAndCollectCostGraph(session.get(), &run_options));
// Make sure there was at least one element enqueued in q0: this prevents a
// race condition where we close the queue before it was populated.
std::vector<Tensor> dq0;
TF_EXPECT_OK(session->Run({}, {kDequeueOp0}, {}, &dq0));
// Second call to run dequeue op is to make sure the cost graph has been
// stored.
TF_EXPECT_OK(session->Run({}, {kDequeueOp0}, {}, &dq0));
CostGraphDef cost_graph;
TF_CHECK_OK(qr->ExportCostGraph(&cost_graph));
EXPECT_TRUE(cost_graph.node_size() > 0);
qr->Stop(session.get());
}
TEST(QueueRunnerTest, NoRunMetaDataTest) {
GraphDef graph_def = BuildSimpleGraph();
auto session = BuildSessionAndInitVariable(graph_def);
QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
kQueueName, {kCountUpToOpName}, kSquareOpName, "", {});
std::unique_ptr<QueueRunner> qr;
TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
TF_CHECK_OK(qr->Start(session.get()));
TF_EXPECT_OK(qr->Join());
CostGraphDef cost_graph;
EXPECT_EQ(qr->ExportCostGraph(&cost_graph).code(),
error::FAILED_PRECONDITION);
}
} // namespace
} // namespace tensorflow

View File

@ -227,7 +227,7 @@ int main(int argc, char* argv[]) {
argv[dst++] = f;
}
argv[dst++] = nullptr;
argc = unknown_flags.size() + 1;
argc = static_cast<int>(unknown_flags.size() + 1);
tensorflow::port::InitMain(argv[0], &argc, &argv);
tensorflow::example::ConcurrentSessions(opts);
}

View File

@ -20,6 +20,7 @@ cc_library(
cc_test(
name = "runtime_test",
size = "small",
srcs = ["runtime_test.cc"],
deps = [
":runtime",
@ -73,7 +74,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:compile_only_client",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
"//tensorflow/core:core_cpu",
@ -88,6 +89,7 @@ cc_library(
cc_test(
name = "codegen_test",
size = "small",
srcs = ["codegen_test.cc"],
data = ["codegen_test_h.golden"],
deps = [
@ -101,6 +103,7 @@ cc_test(
cc_test(
name = "tfcompile_util_test",
size = "small",
srcs = ["tfcompile_util_test.cc"],
deps = [
":tfcompile_lib",
@ -123,9 +126,16 @@ cc_library(
deps = [
":tfcompile_lib",
":tfcompile_proto",
"//tensorflow/compiler/xla/legacy_flags:alias_analysis_flags",
"//tensorflow/compiler/xla/legacy_flags:buffer_assignment_flags",
"//tensorflow/compiler/xla/legacy_flags:compiler_functor_flags",
"//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
"//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/legacy_flags:hlo_graph_dumper_flags",
"//tensorflow/compiler/xla/legacy_flags:llvm_util_flags",
"//tensorflow/compiler/xla/legacy_flags:service_flags",
"//tensorflow/compiler/xla/legacy_flags:util_flags",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",

View File

@ -40,7 +40,7 @@ namespace benchmark {
// the implementation without pulling in all of the Env dependencies.
static double NowMicros() {
struct timeval tv;
gettimeofday(&tv, NULL);
gettimeofday(&tv, nullptr);
return static_cast<uint64>(tv.tv_sec) * 1000000 + tv.tv_usec;
}

View File

@ -152,8 +152,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
string RewriteWithName(const string& name, string code,
const std::vector<std::pair<string, string>>& rewrites) {
str_util::ReplaceAllPairs(&code, rewrites);
str_util::ReplaceAll(&code, "{{NAME}}", name);
return code;
return str_util::StringReplace(code, "{{NAME}}", name, /*replace_all=*/true);
}
// Generate methods for args (inputs).
@ -366,7 +365,7 @@ Status GenerateHeader(const HeaderOpts& opts, const Config& config,
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace Eigen { class ThreadPoolDevice; }
namespace Eigen { struct ThreadPoolDevice; }
// (Implementation detail) Entry point to the function in the object file.
extern "C" void {{ENTRY}}(

View File

@ -15,7 +15,7 @@
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace Eigen { class ThreadPoolDevice; }
namespace Eigen { struct ThreadPoolDevice; }
// (Implementation detail) Entry point to the function in the object file.
extern "C" void entry_point(

View File

@ -25,8 +25,9 @@ limitations under the License.
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/aot/tfcompile_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/shape_util.h"
@ -199,17 +200,17 @@ Status RewriteAndPruneGraph(Graph* graph, const Config& config,
for (const Fetch& fetch : config.fetch()) {
missing_fetches.insert(TensorIdToString(fetch.id()));
}
for (const Node* n : graph->nodes()) {
for (const Node* n : graph->op_nodes()) {
if (n->type_string() == kArgOp) {
string feed_id;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFeedIdAttr, &feed_id));
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFeedIdAttr, &feed_id));
if (missing_feeds.erase(feed_id) == 0) {
return errors::Aborted(kArgOp,
" node found with unknown feed id: ", feed_id);
}
} else if (n->type_string() == kRetvalOp) {
string fetch_id;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFetchIdAttr, &fetch_id));
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFetchIdAttr, &fetch_id));
if (missing_fetches.erase(fetch_id) == 0) {
return errors::Aborted(kRetvalOp,
" node found with unknown fetch id: ", fetch_id);
@ -233,7 +234,7 @@ Status CollectArgNodes(const Graph& graph, std::vector<Node*>* arg_nodes) {
for (Node* n : graph.nodes()) {
if (n->type_string() == kArgOp) {
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
auto insert_result = indexed_arg_nodes.insert({index, n});
if (!insert_result.second) {
const Node* dup = insert_result.first->second;
@ -262,10 +263,10 @@ Status CreateXlaArgs(const Graph& graph,
TF_RETURN_IF_ERROR(CollectArgNodes(graph, &arg_nodes));
for (const Node* node : arg_nodes) {
XlaCompiler::Argument arg;
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "T", &arg.type));
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "index", &arg.parameter));
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), kShapeAttr, &arg.shape));
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), kDebugNameAttr, &arg.name));
arg.kind = XlaCompiler::Argument::kParameter;
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type));
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &arg.shape));
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name));
xla_args->push_back(arg);
}
return Status::OK();
@ -273,11 +274,11 @@ Status CreateXlaArgs(const Graph& graph,
// Converts the TensorFlow graph into an XLA computation, by executing the
// graph symbolically, with each op building up the XLA HLO.
Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr<Graph> graph,
const FunctionLibraryDefinition* flib_def,
Status ConvertGraphToXla(xla::CompileOnlyClient* client,
std::unique_ptr<Graph> graph,
xla::Computation* computation, bool* has_context_arg) {
// Create a device and context to convert the graph into an XLA computation.
XlaOpRegistry::RegisterJitKernels();
XlaOpRegistry::RegisterCompilationKernels();
// Populate the context with args from the graph.
for (Node* node : graph->nodes()) {
node->set_assigned_device_name(DEVICE_CPU_XLA_JIT);
@ -288,19 +289,19 @@ Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr<Graph> graph,
// Compile the graph into an XLA computation.
XlaCompiler::Options compiler_options;
compiler_options.client = client;
compiler_options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
DeviceType device_type(DEVICE_CPU_XLA_JIT);
compiler_options.device_type = &device_type;
compiler_options.flib_def = &graph->flib_def();
compiler_options.graph_def_version = graph->versions().producer();
compiler_options.allow_cpu_custom_calls = true;
XlaCompiler compiler(compiler_options);
std::unique_ptr<FunctionLibraryRuntime> flib_run(NewFunctionLibraryRuntime(
compiler.device_mgr(), Env::Default(), compiler.device(),
graph->versions().producer(), flib_def, OptimizerOptions()));
XlaCompiler::CompilationResult result;
TF_RETURN_IF_ERROR(compiler.CompileGraph("tfcompile", std::move(graph),
flib_run.get(), xla_args,
false /* use_tuple_arg */, &result));
TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(),
"tfcompile", std::move(graph),
xla_args, &result));
*has_context_arg = result.requires_runtime_context;
*computation = std::move(result.computation);
*computation = std::move(*result.computation);
int num_const_results = 0;
for (int i = 0; i < result.outputs.size(); ++i) {
@ -334,7 +335,8 @@ Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr<Graph> graph,
}
// Compiles the XLA computation into executable code.
Status CompileXla(xla::LocalClient* client, const xla::Computation& computation,
Status CompileXla(xla::CompileOnlyClient* client,
const xla::Computation& computation,
const xla::cpu::CpuAotCompilationOptions& aot_opts,
CompileResult* compile_result) {
// Retrieves arg and result layouts from the computation.
@ -348,10 +350,11 @@ Status CompileXla(xla::LocalClient* client, const xla::Computation& computation,
compile_result->program_shape = *pshape_or.ValueOrDie();
xla::ProgramShape* pshape = &compile_result->program_shape;
std::vector<const xla::Shape*> arg_layouts;
arg_layouts.reserve(pshape->parameters_size());
for (int i = 0; i < pshape->parameters_size(); ++i) {
arg_layouts.push_back(pshape->mutable_parameters(i));
}
xla::LocalClient::AheadOfTimeComputationInstance instance;
xla::CompileOnlyClient::AotComputationInstance instance;
instance.computation = &computation;
instance.argument_layouts = std::move(arg_layouts);
instance.result_layout = &pshape->result();
@ -366,17 +369,17 @@ Status CompileXla(xla::LocalClient* client, const xla::Computation& computation,
std::move(aot_or.ValueOrDie().back()));
compile_result->entry_point = aot_opts.entry_point_name();
compile_result->pointer_size =
xla::LocalClient::PointerSizeForTriple(aot_opts.triple());
xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple());
return Status::OK();
}
} // namespace
Status InitGraph(const GraphDef& graph_def, const Config& config,
const MainFlags& flags, const FunctionLibraryDefinition* flib,
std::unique_ptr<Graph>* graph) {
const MainFlags& flags, std::unique_ptr<Graph>* graph) {
TF_RETURN_IF_ERROR(ValidateConfig(config));
std::unique_ptr<Graph> g(new Graph(flib));
FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library());
std::unique_ptr<Graph> g(new Graph(flib_def));
GraphDef copy_def(graph_def);
TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&copy_def, *g->op_registry(),
0 /*node_offset*/));
@ -388,7 +391,6 @@ Status InitGraph(const GraphDef& graph_def, const Config& config,
}
Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
const FunctionLibraryDefinition* flib,
CompileResult* compile_result) {
// Converts the graph into an XLA computation, and compiles the
// computation.
@ -396,11 +398,11 @@ Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
namespace gpu = perftools::gputools;
gpu::Platform* cpu_platform =
gpu::MultiPlatformManager::PlatformWithName("Host").ValueOrDie();
xla::LocalClient* client =
xla::ClientLibrary::GetOrCreateLocalClient(cpu_platform).ValueOrDie();
xla::CompileOnlyClient* client =
xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform)
.ValueOrDie();
xla::Computation computation;
TF_RETURN_IF_ERROR(ConvertGraphToXla(client, std::move(graph), flib,
&computation,
TF_RETURN_IF_ERROR(ConvertGraphToXla(client, std::move(graph), &computation,
&compile_result->has_context_arg));
if (!flags.debug_dir.empty()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::SessionModule> module,

View File

@ -56,8 +56,7 @@ extern const char* const kDebugNameAttr;
// compute the outputs. If dump_graphs is true, graph rewrites will be dumped
// for debugging.
Status InitGraph(const GraphDef& graph_def, const Config& config,
const MainFlags& flags, const FunctionLibraryDefinition* flib,
std::unique_ptr<Graph>* graph);
const MainFlags& flags, std::unique_ptr<Graph>* graph);
// CompileResult describes the output of CompileGraph, where the object file
// data and meta-information is available in aot.
@ -83,7 +82,6 @@ struct CompileResult {
//
// The XLA compilation options are specified in the flags.
Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
const FunctionLibraryDefinition* flib,
CompileResult* result);
} // namespace tfcompile

View File

@ -31,6 +31,8 @@ namespace {
inline void* aligned_malloc(size_t size, int minimum_alignment) {
#if defined(__ANDROID__) || defined(OS_ANDROID) || defined(OS_CYGWIN)
return memalign(minimum_alignment, size);
#elif defined(COMPILER_MSVC)
return _aligned_malloc(size, minimum_alignment);
#else // !__ANDROID__ && !OS_ANDROID && !OS_CYGWIN
void* ptr = nullptr;
// posix_memalign requires that the requested alignment be at least
@ -45,7 +47,13 @@ inline void* aligned_malloc(size_t size, int minimum_alignment) {
#endif
}
inline void aligned_free(void* aligned_memory) { free(aligned_memory); }
inline void aligned_free(void* aligned_memory) {
#if defined(COMPILER_MSVC)
_aligned_free(aligned_memory);
#else
free(aligned_memory);
#endif
}
size_t align_to(size_t n, size_t align) {
return (((n - 1) / align) + 1) * align;

View File

@ -43,14 +43,16 @@ genrule(
testonly = 1,
outs = [
"test_graph_tfadd.pb",
"test_graph_tfadd_with_ckpt.pb",
"test_graph_tfadd_with_ckpt.ckpt",
"test_graph_tfadd_with_ckpt_saver.pb",
"test_graph_tfadd_with_ckpt.pb",
"test_graph_tfadd_with_ckpt_saver.ckpt",
"test_graph_tfadd_with_ckpt_saver.pb",
"test_graph_tfadd_with_ckpt_saver.saver",
"test_graph_tffunction.pb",
"test_graph_tfgather.pb",
"test_graph_tfmatmul.pb",
"test_graph_tfmatmulandadd.pb",
"test_graph_tfsplits.pb",
],
cmd = "$(location :make_test_graphs) --out_dir $(@D)",
tags = ["manual"],
@ -114,6 +116,24 @@ tf_library(
tags = ["manual"],
)
tf_library(
name = "test_graph_tffunction",
testonly = 1,
config = "test_graph_tffunction.config.pbtxt",
cpp_class = "FunctionComp",
graph = "test_graph_tffunction.pb",
tags = ["manual"],
)
tf_library(
name = "test_graph_tfsplits",
testonly = 1,
config = "test_graph_tfsplits.config.pbtxt",
cpp_class = "SplitsComp",
graph = "test_graph_tfsplits.pb",
tags = ["manual"],
)
cc_test(
name = "tfcompile_test",
srcs = ["tfcompile_test.cc"],
@ -122,9 +142,11 @@ cc_test(
":test_graph_tfadd",
":test_graph_tfadd_with_ckpt",
":test_graph_tfadd_with_ckpt_saver",
":test_graph_tffunction",
":test_graph_tfgather",
":test_graph_tfmatmul",
":test_graph_tfmatmulandadd",
":test_graph_tfsplits",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//third_party/eigen3",

View File

@ -25,6 +25,7 @@ from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@ -71,7 +72,7 @@ def tfadd_with_ckpt_saver(out_dir):
saver.save(sess, ckpt_file)
# Without the SaverDef, the restore op won't be named correctly.
saver_file = '%s/test_graph_tfadd_with_ckpt_saver.saver' % out_dir
with open(saver_file, 'w') as f:
with open(saver_file, 'wb') as f:
f.write(saver.as_saver_def().SerializeToString())
@ -95,13 +96,41 @@ def tfmatmulandadd(_):
math_ops.add(x, y, name='x_y_sum')
def tffunction(_):
@function.Defun(dtypes.int32, dtypes.int32)
def test_func(a, b):
return a + b
x = constant_op.constant([1], name='x_const')
y = constant_op.constant([2], name='y_const')
test_func(x, y, name='func_call') # pylint: disable=unexpected-keyword-arg
def tfsplits(_):
"""A more complex graph, including splits."""
x = array_ops.placeholder(dtypes.float32, shape=[2, 2], name='x')
y = array_ops.placeholder(dtypes.float32, shape=[2, 2], name='y')
for _ in range(3):
x0, x1 = array_ops.split(x, 2, 0)
y0, y1 = array_ops.split(y, 2, 0)
x0 += 1
y0 += 1
z = math_ops.matmul(x, y, name='x_y_prod')
a = array_ops.concat([x0, y1], axis=0, name='concat_x0_y1')
b = array_ops.concat([y0, x1], axis=0, name='concat_y0_x1')
x = math_ops.matmul(a, b, name='a_b')
y = math_ops.add(x, z)
array_ops.identity(y, name='result')
def write_graph(build_graph, out_dir):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
with g.as_default():
build_graph(out_dir)
filename = '%s/test_graph_%s.pb' % (out_dir, build_graph.__name__)
with open(filename, 'w') as f:
with open(filename, 'wb') as f:
f.write(g.as_graph_def().SerializeToString())
@ -112,6 +141,8 @@ def main(_):
write_graph(tfgather, FLAGS.out_dir)
write_graph(tfmatmul, FLAGS.out_dir)
write_graph(tfmatmulandadd, FLAGS.out_dir)
write_graph(tffunction, FLAGS.out_dir)
write_graph(tfsplits, FLAGS.out_dir)
if __name__ == '__main__':
@ -121,7 +152,6 @@ if __name__ == '__main__':
'--out_dir',
type=str,
default='',
help='Output directory for graphs, checkpoints and savers.'
)
help='Output directory for graphs, checkpoints and savers.')
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -0,0 +1,16 @@
# Text form of tensorflow.tfcompile.Config proto.
feed {
id { node_name: "x_const" }
shape {
dim { size: 1 }
}
}
feed {
id { node_name: "y_const" }
shape {
dim { size: 1 }
}
}
fetch {
id { node_name: "func_call" }
}

View File

@ -0,0 +1,18 @@
# Text form of tensorflow.tfcompile.Config proto.
feed {
id { node_name: "x" }
shape {
dim { size: 2 }
dim { size: 2 }
}
}
feed {
id { node_name: "y" }
shape {
dim { size: 2 }
dim { size: 2 }
}
}
fetch {
id { node_name: "result" }
}

View File

@ -20,9 +20,11 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
#include "tensorflow/compiler/aot/tests/test_graph_tffunction.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@ -376,6 +378,49 @@ TEST(TFCompileTest, MatMulAndAdd1) {
}
}
TEST(TFCompileTest, Function) {
// The function is equivalent to an addition
FunctionComp add_fn;
EXPECT_EQ(add_fn.arg0_data(), add_fn.args()[0]);
EXPECT_EQ(add_fn.arg1_data(), add_fn.args()[1]);
add_fn.arg0() = 1;
add_fn.arg1() = 2;
EXPECT_TRUE(add_fn.Run());
EXPECT_EQ(add_fn.error_msg(), "");
EXPECT_EQ(add_fn.result0(), 3);
EXPECT_EQ(add_fn.result0_data()[0], 3);
EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
}
TEST(TFCompileTest, Splits) {
Eigen::ThreadPool tp(1);
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
SplitsComp fn;
fn.set_thread_pool(&device);
// x = [[1, 2], [3, 4]]
fn.arg0(0, 0) = 1;
fn.arg0(0, 1) = 2;
fn.arg0(1, 0) = 3;
fn.arg0(1, 1) = 4;
// y = [[10, 20], [30, 40]]
fn.arg1(0, 0) = 10;
fn.arg1(0, 1) = 20;
fn.arg1(1, 0) = 30;
fn.arg1(1, 1) = 40;
EXPECT_TRUE(fn.Run());
EXPECT_EQ(fn.error_msg(), "");
const float expected[] = {7.86375557e+10, 1.34274679e+11, 1.92741717e+12,
3.29964742e+12};
EXPECT_NEAR(expected[0], fn.result0(0, 0), 1e4);
EXPECT_NEAR(expected[1], fn.result0(0, 1), 1e4);
EXPECT_NEAR(expected[2], fn.result0(1, 0), 1e4);
EXPECT_NEAR(expected[3], fn.result0(1, 1), 1e4);
}
} // namespace
} // namespace tfcompile
} // namespace tensorflow

View File

@ -279,7 +279,11 @@ def target_llvm_triple():
# TODO(toddw): Add target_triple for other targets. For details see:
# http://llvm.org/docs/doxygen/html/Triple_8h_source.html
return select({
"//tensorflow:android_armeabi": "armv5-none-android",
"//tensorflow:android_arm": "armv7-none-android",
"//tensorflow:android_arm64": "aarch64-none-android",
"//tensorflow:android_x86": "i686-none-android",
"//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
"//tensorflow:darwin": "x86_64-none-darwin",
"//conditions:default": "x86_64-pc-linux",
})

View File

@ -23,9 +23,16 @@ limitations under the License.
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/aot/tfcompile.pb.h"
#include "tensorflow/compiler/aot/tfcompile_util.h"
#include "tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/service_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/util_flags.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
@ -52,7 +59,8 @@ const char kUsageHeader[] =
"header file that gives access to the functionality in the object file.\n"
"A typical invocation looks like this:\n"
"\n"
" $ tfcompile --graph=mygraph.pb --config=myfile.pbtxt\n"
" $ tfcompile --graph=mygraph.pb --config=myfile.pbtxt "
"--cpp_class=\"mynamespace::MyComputation\"\n"
"\n";
Status ReadProtoFile(const string& kind, const string& fname,
@ -73,6 +81,9 @@ void ParseTensorId(const string& name, TensorId* id) {
Status Main(const MainFlags& flags) {
// Process config.
Config config;
if (flags.config.empty()) {
return errors::InvalidArgument("Must specify --config");
}
TF_RETURN_IF_ERROR(ReadProtoFile("config", flags.config, &config));
TF_RETURN_IF_ERROR(ValidateConfig(config));
if (flags.dump_fetch_nodes) {
@ -85,15 +96,16 @@ Status Main(const MainFlags& flags) {
}
// Read and initialize the graph.
if (flags.graph.empty()) {
return errors::InvalidArgument("Must specify --graph");
}
GraphDef graph_def;
TF_RETURN_IF_ERROR(ReadProtoFile("graph", flags.graph, &graph_def));
std::unique_ptr<Graph> graph;
FunctionLibraryDefinition flib(OpRegistry::Global(), graph_def.library());
TF_RETURN_IF_ERROR(InitGraph(graph_def, config, flags, &flib, &graph));
TF_RETURN_IF_ERROR(InitGraph(graph_def, config, flags, &graph));
CompileResult compile_result;
TF_RETURN_IF_ERROR(
CompileGraph(std::move(graph), flags, &flib, &compile_result));
TF_RETURN_IF_ERROR(CompileGraph(std::move(graph), flags, &compile_result));
// Write output files.
Env* env = Env::Default();
@ -101,6 +113,9 @@ Status Main(const MainFlags& flags) {
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_object,
StringPiece(obj.data(), obj.size())));
HeaderOpts header_opts;
if (flags.cpp_class.empty()) {
return errors::InvalidArgument("Must specify --cpp_class");
}
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &header_opts.class_name,
&header_opts.namespaces));
string header;
@ -121,9 +136,16 @@ int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
AppendMainFlags(&flag_list, &flags);
xla::legacy_flags::AppendAliasAnalysisFlags(&flag_list);
xla::legacy_flags::AppendBufferAssignmentFlags(&flag_list);
xla::legacy_flags::AppendCompilerFunctorFlags(&flag_list);
xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
xla::legacy_flags::AppendCpuRuntimeFlags(&flag_list);
xla::legacy_flags::AppendHloGraphDumperFlags(&flag_list);
xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
xla::legacy_flags::AppendLlvmUtilFlags(&flag_list);
xla::legacy_flags::AppendServiceFlags(&flag_list);
xla::legacy_flags::AppendUtilFlags(&flag_list);
tensorflow::string usage = tensorflow::tfcompile::kUsageHeader;
usage += tensorflow::Flags::Usage(argv[0], flag_list);
@ -131,12 +153,16 @@ int main(int argc, char** argv) {
QCHECK(parsed_flags_ok) << "\n" << usage;
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
QCHECK(argc == 1 && !flags.config.empty() &&
(flags.dump_fetch_nodes ||
(!flags.graph.empty() && !flags.entry_point.empty())))
<< "\n"
<< usage;
TF_QCHECK_OK(tensorflow::tfcompile::Main(flags));
QCHECK(argc == 1) << "\nERROR: This command does not take any arguments "
"other than flags\n\n"
<< usage;
tensorflow::Status status = tensorflow::tfcompile::Main(flags);
if (status.code() == tensorflow::error::INVALID_ARGUMENT) {
std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n"
<< usage;
return 1;
} else {
TF_QCHECK_OK(status);
}
return 0;
}

View File

@ -24,7 +24,7 @@ namespace tensorflow {
namespace tfcompile {
namespace {
void ExpectErrorContains(Status status, StringPiece str) {
void ExpectErrorContains(const Status& status, StringPiece str) {
EXPECT_NE(Status::OK(), status);
EXPECT_TRUE(StringPiece(status.error_message()).contains(str))
<< "expected error: " << status.error_message() << " to contain: " << str;

View File

@ -18,7 +18,23 @@ package(
default_visibility = [":internal"],
)
load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
# This target can be used by XLA device plugins to prevent circular
# dependencies, and provides access to all of the required headers
# for building a device library.
cc_header_only_library(
name = "xla_jit_headers_lib",
visibility = ["//visibility:public"],
deps = [
":xla_cpu_device",
":xla_cpu_jit",
":xla_gpu_device",
":xla_gpu_jit",
],
)
# Target that bundles up the XLA CPU and GPU JIT devices.
cc_library(
@ -29,6 +45,7 @@ cc_library(
":xla_cpu_jit",
":xla_gpu_device",
":xla_gpu_jit",
"//tensorflow/compiler/plugin",
],
alwayslink = 1,
)
@ -38,7 +55,7 @@ cc_library(
visibility = [":friends"],
deps = [
":jit_compilation_passes",
":xla_local_launch_op",
"//tensorflow/compiler/jit/kernels:xla_local_launch_op",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:cpu_plugin",
],
@ -48,12 +65,12 @@ cc_library(
cc_library(
name = "xla_gpu_jit",
visibility = [":friends"],
deps = [
deps = if_cuda([
":jit_compilation_passes",
":xla_local_launch_op",
"//tensorflow/compiler/jit/kernels:xla_local_launch_op",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:gpu_plugin",
],
]),
alwayslink = 1,
)
@ -64,8 +81,10 @@ cc_library(
deps = [
":jit_compilation_passes",
":xla_device",
"//tensorflow/compiler/jit/kernels:xla_device_launch_op",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
],
@ -79,8 +98,10 @@ cc_library(
deps = [
":jit_compilation_passes",
":xla_device",
"//tensorflow/compiler/jit/kernels:xla_device_launch_op",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla/service:gpu_plugin",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
],
@ -105,26 +126,22 @@ cc_library(
srcs = [
"xla_device.cc",
"xla_device_context.cc",
"xla_device_launch_op.cc",
"xla_device_ops.cc",
],
hdrs = [
"xla_device.h",
"xla_device_context.h",
"xla_device_launch_op.h",
"xla_device_ops.h",
],
deps = [
":common",
":jit_compilation_passes",
":xla_compilation_cache",
":xla_local_launch_op",
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
@ -132,9 +149,9 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:tensorflow_opensource",
"//tensorflow/core/kernels:assign_op",
"//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:identity_op",
@ -142,7 +159,6 @@ cc_library(
"//tensorflow/core/kernels:sendrecv_ops",
"//tensorflow/core/kernels:variable_ops",
],
alwayslink = 1,
)
cc_library(
@ -155,13 +171,13 @@ cc_library(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:variable_ops",
],
)
@ -175,27 +191,41 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "graph_to_functiondef",
srcs = ["graph_to_functiondef.cc"],
hdrs = ["graph_to_functiondef.h"],
deps = [
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "compilation_passes",
srcs = [
"build_xla_launch_ops_pass.cc",
"encapsulate_subgraphs_pass.cc",
"graph_to_functiondef.cc",
"mark_for_compilation_pass.cc",
],
hdrs = [
"build_xla_launch_ops_pass.h",
"encapsulate_subgraphs_pass.h",
"graph_to_functiondef.h",
"mark_for_compilation_pass.h",
],
deps = [
":common",
":parallel_check_op",
":xla_local_launch_op",
":graph_to_functiondef",
":union_find",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/jit/kernels:parallel_check_op",
"//tensorflow/compiler/jit/kernels:xla_local_launch_op",
"//tensorflow/compiler/jit/legacy_flags:encapsulate_subgraphs_pass_flags",
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
"//tensorflow/compiler/jit/ops:parallel_check_op",
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:const_analysis",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:xla_compiler",
@ -208,6 +238,11 @@ cc_library(
],
)
cc_library(
name = "union_find",
hdrs = ["union_find.h"],
)
cc_test(
name = "compilation_passes_test",
size = "small",
@ -217,8 +252,9 @@ cc_test(
"mark_for_compilation_pass_test.cc",
],
deps = [
":common",
":compilation_passes",
":xla_local_launch_op",
":graph_to_functiondef",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
@ -226,48 +262,14 @@ cc_test(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
cc_library(
name = "xla_local_launch_op",
srcs = ["xla_local_launch_op.cc"],
hdrs = ["xla_local_launch_op.h"],
deps = [
":common",
":xla_compilation_cache",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla:xla_local_runtime_context",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:tensorflow_opensource",
],
alwayslink = 1,
)
tf_kernel_library(
name = "parallel_check_op",
srcs = ["parallel_check_op.cc"],
visibility = [":friends"],
deps = [
"//tensorflow/compiler/jit/legacy_flags:parallel_check_op_flags",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
alwayslink = 1,
)
# -----------------------------------------------------------------------------
filegroup(

View File

@ -16,14 +16,13 @@ limitations under the License.
#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/kernels/xla_local_launch_op.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/xla_local_launch_op.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
@ -32,7 +31,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
@ -40,14 +38,16 @@ namespace tensorflow {
static Status BuildLaunchNode(
const string& nodename, const string& function_name,
const AttrValueMap& function_attr, const string& device_name,
const DataTypeVector& constant_dtypes, const DataTypeVector& arg_dtypes,
const DataTypeVector& result_dtypes, Graph* graph, Node** node) {
const DataTypeVector& constant_dtypes, int num_resources,
const DataTypeVector& arg_dtypes, const DataTypeVector& result_dtypes,
Graph* graph, Node** node) {
NodeDef def;
def.set_name(graph->NewName(nodename));
def.set_op("_XlaLaunch");
def.set_device(device_name);
AddNodeAttr("Tconstants", constant_dtypes, &def);
AddNodeAttr("Targs", arg_dtypes, &def);
AddNodeAttr("Nresources", num_resources, &def);
AddNodeAttr("Tresults", result_dtypes, &def);
NameAttrList function;
function.set_name(function_name);
@ -62,25 +62,32 @@ static Status BuildLaunchNode(
static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) {
VLOG(2) << "Replacing " << node->name() << " with XlaLaunch";
int num_constant_args;
int num_constant_args, num_resource_args;
TF_RETURN_IF_ERROR(
GetNodeAttr(node->def(), kXlaNumConstantArgsAttr, &num_constant_args));
GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, &num_constant_args));
TF_RETURN_IF_ERROR(
GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, &num_resource_args));
if (num_constant_args < 0 || num_constant_args > node->input_types().size()) {
if (num_constant_args < 0 || num_resource_args < 0 ||
num_constant_args + num_resource_args > node->num_inputs()) {
return errors::InvalidArgument(
"Invalid number of constant arguments to XLA kernel");
"Invalid number of constant/resource arguments to XLA kernel.");
}
const int num_nonconst_args =
node->num_inputs() - num_constant_args - num_resource_args;
DataTypeVector const_dtypes(node->input_types().begin(),
node->input_types().begin() + num_constant_args);
DataTypeVector arg_dtypes(node->input_types().begin() + num_constant_args,
node->input_types().end());
DataTypeVector arg_dtypes(
node->input_types().begin() + num_constant_args,
node->input_types().begin() + num_constant_args + num_nonconst_args);
// Build a _XlaLaunch operator to execute the function body.
Node* launch_node;
TF_RETURN_IF_ERROR(
BuildLaunchNode(graph->NewName(node->name()), node->type_string(),
node->def().attr(), node->def().device(), const_dtypes,
arg_dtypes, node->output_types(), graph, &launch_node));
TF_RETURN_IF_ERROR(BuildLaunchNode(
graph->NewName(node->name()), node->type_string(), node->def().attr(),
node->requested_device(), const_dtypes, num_resource_args, arg_dtypes,
node->output_types(), graph, &launch_node));
launch_node->set_assigned_device_name(node->assigned_device_name());
// Copy incoming edges to the launch node.
@ -116,9 +123,9 @@ static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) {
Status BuildXlaLaunchOpsPass::Run(const GraphOptimizationPassOptions& options) {
Graph* graph = options.graph->get();
for (Node* n : graph->nodes()) {
for (Node* n : graph->op_nodes()) {
// In all cases, only try to compile computational nodes.
if (!n->IsOp() || n->IsSend() || n->IsRecv() || n->IsControlFlow()) {
if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) {
continue;
}
@ -128,6 +135,11 @@ Status BuildXlaLaunchOpsPass::Run(const GraphOptimizationPassOptions& options) {
TF_RETURN_IF_ERROR(ReplaceNodeWithXlaLaunch(graph, n));
}
}
if (VLOG_IS_ON(1)) {
dump_graph::DumpGraphToFile("build_xla_launch_ops", *graph,
options.flib_def);
}
return Status::OK();
}
@ -151,7 +163,7 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef,
return errors::InvalidArgument("No ", kXlaCompileAttr, " for ", ndef.op());
}
// Make sure that kernels have been registered on the JIT device.
XlaOpRegistry::RegisterJitKernels();
XlaOpRegistry::RegisterCompilationKernels();
if (!IsCompilable(flr, ndef)) {
// ndef is calling a function that XLA can't compile.
return errors::InvalidArgument("Not compilable: ", ndef.ShortDebugString());
@ -159,7 +171,8 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef,
FunctionLibraryRuntime::Handle handle;
// If ndef is not instantiable, e.g., the function does not exist,
// simply bail out.
TF_RETURN_IF_ERROR(flr->Instantiate(ndef.op(), ndef.attr(), &handle));
TF_RETURN_IF_ERROR(
flr->Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle));
const FunctionBody* fbody = flr->GetFunctionBody(handle);
CHECK(fbody); // Can't be nullptr since we just instantiated it.
std::vector<bool> const_args(fbody->arg_types.size());
@ -179,6 +192,7 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef,
launch_def.set_op("_XlaLaunch");
launch_def.set_device(flr->device()->name());
AddNodeAttr("Tconstants", DataTypeVector{}, &launch_def);
AddNodeAttr("Nresources", 0, &launch_def);
AddNodeAttr("Targs", fbody->arg_types, &launch_def);
AddNodeAttr("Tresults", fbody->ret_types, &launch_def);
NameAttrList func;

View File

@ -18,5 +18,6 @@ limitations under the License.
namespace tensorflow {
const char* const kXlaCompileAttr = "_XlaCompile";
const char* const kXlaScopeAttr = "_XlaScope";
} // namespace tensorflow

View File

@ -23,6 +23,7 @@ namespace tensorflow {
// Name of attribute used to tag operators for compilation with XLA
extern const char* const kXlaCompileAttr; // "_XlaCompile"
extern const char* const kXlaScopeAttr; // "_XlaScope"
} // namespace tensorflow

View File

@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
@ -46,6 +45,7 @@ namespace tensorflow {
const char* const kXlaCompiledKernelAttr = "_XlaCompiledKernel";
const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs";
const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs";
namespace {
@ -87,9 +87,12 @@ class Encapsulator {
// Build a FunctionDef for each subgraph, and add it 'library'. The values of
// the 'group_attribute' annotations become the function names.
// If 'reuse_existing_functions' is set, use an existing function with the
// same name, if any.
// If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
// function conversion.
Status BuildFunctionDefs(const RewriteSubgraphFn& rewrite_subgraph_fn,
bool reuse_existing_functions,
FunctionLibraryDefinition* library);
// Write a copy of the input graph to 'graph_out', where the subgraphs are
@ -109,8 +112,8 @@ class Encapsulator {
// returned by _Retval nodes.
std::unique_ptr<Graph> graph;
// Which device are these nodes on? Used both to check that all nodes
// are assigned to the same device, and to assign a device to the call node.
// Which device are these nodes on? Used to assign a device to the call
// node.
string device;
// NodeDef for the function call node.
@ -161,7 +164,7 @@ static const char* const kRetValOp = "_Retval";
// none.
string Encapsulator::GetFunctionNameAttr(Node const* node) const {
string attr;
if (!GetNodeAttr(node->def(), group_attribute_, &attr).ok()) {
if (!GetNodeAttr(node->attrs(), group_attribute_, &attr).ok()) {
attr.clear();
}
return attr;
@ -174,8 +177,7 @@ Status Encapsulator::SplitIntoSubgraphs() {
std::unordered_map<Node*, Node*> node_images;
// Copy all marked nodes to a subgraph. Do nothing for unmarked nodes.
for (Node* node : graph_in_->nodes()) {
if (node->IsSource() || node->IsSink()) continue;
for (Node* node : graph_in_->op_nodes()) {
string func_id = GetFunctionNameAttr(node);
if (func_id.empty()) continue;
@ -189,16 +191,10 @@ Status Encapsulator::SplitIntoSubgraphs() {
image->ClearAttr(group_attribute_);
node_images[node] = image;
// Check the device matches any existing device.
string device = node->assigned_device_name().empty()
? node->def().device()
: node->assigned_device_name();
if (subgraph.device.empty()) {
subgraph.device = device;
} else if (subgraph.device != device) {
s.Update(errors::InvalidArgument(
"Mismatched devices for nodes to be grouped by Encapsulator"));
subgraph.device = node->assigned_device_name().empty()
? node->requested_device()
: node->assigned_device_name();
}
}
@ -235,9 +231,16 @@ Status Encapsulator::SplitIntoSubgraphs() {
// Create a new _Retval node
DataType dtype = edge->src()->output_type(edge->src_output());
if (IsRefType(dtype)) {
return errors::InvalidArgument(
"Ref Tensors (e.g., Variables) are not supported: tensor ",
edge->src()->name(), ":", edge->src_output());
}
NodeDef ret_def;
ret_def.set_op(kRetValOp);
ret_def.set_name(src_subgraph.graph->NewName("output"));
ret_def.set_name(strings::StrCat(edge->src()->name(), "_",
edge->src_output(), "_retval"));
AddNodeAttr("T", dtype, &ret_def);
AddNodeAttr("index", ret_index, &ret_def);
Node* ret = src_subgraph.graph->AddNode(ret_def, &s);
@ -262,8 +265,16 @@ Status Encapsulator::SplitIntoSubgraphs() {
// This is the first time we have seen this tensor. Create an _Arg node.
DataType dtype = edge->dst()->input_type(edge->dst_input());
if (IsRefType(dtype)) {
return errors::InvalidArgument(
"Ref Tensors (e.g., Variables) are not supported: tensor ",
edge->src()->name(), ":", edge->src_output());
}
NodeDef arg_def;
NodeDefBuilder builder(dst_subgraph.graph->NewName("input"), kArgOp);
NodeDefBuilder builder(strings::StrCat(edge->src()->name(), "_",
edge->src_output(), "_arg"),
kArgOp);
builder.Attr("T", dtype);
builder.Attr("index", arg_index);
s = builder.Finalize(&arg_def);
@ -290,11 +301,11 @@ Status Encapsulator::SplitIntoSubgraphs() {
}
Status Encapsulator::BuildFunctionDefs(
const RewriteSubgraphFn& rewrite_subgraph_fn,
const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions,
FunctionLibraryDefinition* library) {
// For each subgraph, build a FunctionDef.
for (auto& subgraph_entry : subgraphs_) {
const string& name = subgraph_entry.first;
string name = subgraph_entry.first;
Subgraph& subgraph = subgraph_entry.second;
subgraph.call_node_def.set_op(name);
@ -331,6 +342,8 @@ Status Encapsulator::BuildFunctionDefs(
for (auto& result : subgraph.results) {
result.second = output_permutation[result.second];
}
name = subgraph.call_node_def.op();
}
FunctionDef fdef;
@ -345,7 +358,9 @@ Status Encapsulator::BuildFunctionDefs(
strings::StrCat("encapsulate_fdef_", name), fdef);
}
TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
if (!reuse_existing_functions || library->Find(name) == nullptr) {
TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
}
}
return Status::OK();
}
@ -422,8 +437,7 @@ Status Encapsulator::BuildOutputGraph(bool parallel_checking,
std::unordered_map<const Node*, Node*> node_images;
// Copy all unmarked nodes to the output graph.
for (Node* node : graph_in_->nodes()) {
if (node->IsSource() || node->IsSink()) continue;
for (Node* node : graph_in_->op_nodes()) {
string func_id = GetFunctionNameAttr(node);
// Don't copy nodes that going to be encapsulated, unless parallel checking
@ -544,14 +558,16 @@ Status Encapsulator::BuildOutputGraph(bool parallel_checking,
Status EncapsulateSubgraphsInFunctions(
string group_attribute, const Graph& graph_in,
const RewriteSubgraphFn& rewrite_subgraph_fn, bool parallel_checking,
std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library) {
bool reuse_existing_functions, std::unique_ptr<Graph>* graph_out,
FunctionLibraryDefinition* library) {
Status s;
Encapsulator encapsulator(std::move(group_attribute), &graph_in);
s = encapsulator.SplitIntoSubgraphs();
if (!s.ok()) return s;
s = encapsulator.BuildFunctionDefs(rewrite_subgraph_fn, library);
s = encapsulator.BuildFunctionDefs(rewrite_subgraph_fn,
reuse_existing_functions, library);
if (!s.ok()) return s;
std::unique_ptr<Graph> out(new Graph(library));
@ -563,14 +579,29 @@ Status EncapsulateSubgraphsInFunctions(
return s;
}
// Finds the types of the _Arg nodes, indexed by position.
static Status GetArgTypes(const Graph& graph, DataTypeVector* types) {
for (Node* n : graph.op_nodes()) {
if (n->type_string() == kArgOp) {
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
if (index < 0 || index >= types->size()) {
return errors::InvalidArgument("Invalid argument number");
}
(*types)[index] = n->output_type(0);
}
}
return Status::OK();
}
// Renumber the indices of _Arg nodes in a graph, according to
// 'permutation' that maps old indices to new indices.
static Status RenumberArguments(Graph* graph,
const std::vector<int>& permutation) {
for (Node* n : graph->nodes()) {
for (Node* n : graph->op_nodes()) {
if (n->type_string() == kArgOp) {
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
if (index < 0 || index >= permutation.size()) {
return errors::InvalidArgument("Invalid argument number");
}
@ -604,19 +635,40 @@ Status EncapsulateSubgraphsPass::Run(
// Optimize the subgraph.
OptimizeGraph(flr.get(), subgraph);
std::vector<bool> const_args(input_permutation->size());
const int num_args = input_permutation->size();
std::vector<bool> const_args(num_args);
TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args));
DataTypeVector arg_types(num_args);
TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types));
// Compute a permutation of the arguments such that the constant arguments
// are first.
const int num_consts =
std::count(const_args.begin(), const_args.end(), true);
const int num_resources =
std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE);
const int num_nonconsts = num_args - num_resources - num_consts;
if (num_nonconsts < 0) {
return errors::Internal("num_nonconsts should be >= 0, was ",
num_nonconsts);
}
int const_pos = 0;
int arg_pos = num_consts;
for (int i = 0; i < const_args.size(); ++i) {
int resource_pos = num_consts + num_nonconsts;
for (int i = 0; i < num_args; ++i) {
if (const_args[i]) {
if (arg_types[i] == DT_RESOURCE) {
return errors::Internal(
"Resource arguments cannot be constant (argument ", i, ")");
}
(*input_permutation)[i] = const_pos;
++const_pos;
} else if (arg_types[i] == DT_RESOURCE) {
(*input_permutation)[i] = resource_pos;
++resource_pos;
} else {
(*input_permutation)[i] = arg_pos;
++arg_pos;
@ -631,12 +683,14 @@ Status EncapsulateSubgraphsPass::Run(
AddNodeAttr(kXlaCompiledKernelAttr, true, node);
AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node);
AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node);
return Status::OK();
};
TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
kXlaClusterAttr, **options.graph, rewrite_subgraph,
flags->tf_xla_parallel_checking, &graph_out, library));
flags->tf_xla_parallel_checking, /*reuse_existing_functions=*/false,
&graph_out, library));
if (VLOG_IS_ON(1)) {
dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out,
@ -650,7 +704,7 @@ Status EncapsulateSubgraphsPass::Run(
bool IsXlaCompiledKernel(const Node& node) {
bool is_compiled = false;
bool has_compilation_attr =
GetNodeAttr(node.def(), kXlaCompiledKernelAttr, &is_compiled).ok() &&
GetNodeAttr(node.attrs(), kXlaCompiledKernelAttr, &is_compiled).ok() &&
is_compiled;
return has_compilation_attr ? is_compiled : false;
}

View File

@ -34,6 +34,8 @@ namespace tensorflow {
// 'input_permutation' and 'output_permutation' are initialized to the identity
// permutation. 'nodedef' is the NodeDef for the call to the function under
// construction, provided to allow additional attributes to be set.
// The rewrite may also change the NodeDef's operator name, and that
// name will be used as the name of the generated function.
typedef std::function<Status(
std::unique_ptr<Graph>* graph, std::vector<int>* input_permutation,
std::vector<int>* output_permutation, NodeDef* node_def)>
@ -53,6 +55,9 @@ typedef std::function<Status(
// output graph, together with a "ParallelCheck" operator, that verifies that
// the original and encapsulated subgraphs produce similar results.
//
// If 'reuse_existing_functions' is set, use an existing function with the
// same name, if any.
//
// TODO(phawkins): currently, some information in control edges
// is not preserved. Suppose you have A and B in the main
// graph, C and D in a subgraph. B and C have control deps from A, D has control
@ -61,7 +66,8 @@ typedef std::function<Status(
Status EncapsulateSubgraphsInFunctions(
string group_attribute, const Graph& graph_in,
const RewriteSubgraphFn& rewrite_subgraph_fn, bool parallel_checking,
std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library);
bool reuse_existing_functions, std::unique_ptr<Graph>* graph_out,
FunctionLibraryDefinition* library);
// The attribute that marks function calls produced by the encapsulate
// subgraphs pass and that should in turn be compiled via _XlaLaunch operators.
@ -70,12 +76,22 @@ extern const char* const kXlaCompiledKernelAttr;
// Does `node` have the kXlaCompiledKernelAttr attribute?
bool IsXlaCompiledKernel(const Node& node);
// Functions produce by the EncapsulateSubgraphs pass have their arguments
// ordered such that compile-time constant arguments are first in the argument
// order. The functions are annotated with the following attribute giving the
// number of constant arguments.
// Functions produced by the EncapsulateSubgraphs pass have their arguments in
// the order:
// 1) compile-time constant arguments, in host memory,
// 2) other arguments, in device memory.
// 3) resource variable arguments, in host memory. Note that only the resource
// Tensor itself is in host memory; the underlying value may be in device
// memory.
// The functions are annotated with the following attributes that describe how
// many constant and resource arguments there are:
// Name of the attribute containing the number of constant arguments.
extern const char* const kXlaNumConstantArgsAttr;
// Name of the attribute containing the number of resource variable arguments.
extern const char* const kXlaNumResourceArgsAttr;
class EncapsulateSubgraphsPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;

View File

@ -13,16 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <utility>
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/graph/equal_graph_def.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow {
namespace {
@ -76,7 +78,7 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected,
#define TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(expected, actual) \
do { \
string diff; \
EXPECT_TRUE(EqualFunctionDefLibrary(actual, expected, &diff)) \
EXPECT_TRUE(EqualFunctionDefLibrary(expected, actual, &diff)) \
<< diff << "\nActual: " << actual.DebugString(); \
} while (false)
@ -101,15 +103,15 @@ Node* Input(const GraphDefBuilder::Options& opts) {
}
Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) {
return ops::UnaryOp("UnaryTest", a, opts);
return ops::UnaryOp("UnaryTest", std::move(a), opts);
}
Node* Binary(ops::NodeOut a, ops::NodeOut b,
const GraphDefBuilder::Options& opts) {
return ops::BinaryOp("BinaryTest", a, b, opts);
return ops::BinaryOp("BinaryTest", std::move(a), std::move(b), opts);
}
Node* AddNLike(std::vector<ops::NodeOut> inputs,
Node* AddNLike(const std::vector<ops::NodeOut>& inputs,
const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
NodeBuilder node_builder(opts.GetNameForOp("AddN"), "AddNLikeTest",
@ -127,7 +129,7 @@ Node* RetOp(int index, ops::NodeOut a, const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
NodeBuilder node_builder(opts.GetNameForOp("Retval"), "_Retval",
opts.op_registry());
node_builder.Input(a).Attr("index", index);
node_builder.Input(std::move(a)).Attr("index", index);
return opts.FinalizeBuilder(&node_builder);
}
@ -144,8 +146,9 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) {
std::unique_ptr<Graph> graph_out;
s = EncapsulateSubgraphsInFunctions("_encapsulate", *graph,
/* rewrite_subgraph_fn= */ {},
/* parallel_checking= */ false,
/*rewrite_subgraph_fn=*/{},
/*parallel_checking=*/false,
/*reuse_existing_functions=*/false,
&graph_out, lib_def.get());
if (!s.ok()) return s;
@ -168,7 +171,7 @@ TEST(EncapsulateSubgraphsTest, NoFunctions) {
GraphDef graphdef_in;
FunctionDefLibrary library_in;
builder.ToGraphDef(&graphdef_in);
TF_EXPECT_OK(builder.ToGraphDef(&graphdef_in));
*library_in.add_function() = test::function::XTimesTwo();
GraphDef graphdef_out = graphdef_in;
@ -195,7 +198,7 @@ TEST(EncapsulateSubgraphsTest, OneFunction) {
Node* d = Binary(b, c, b1.opts().WithName("c").WithControlInput(c).WithAttr(
"_encapsulate", "F1"));
Binary(a, d, b1.opts().WithName("E"));
b1.ToGraphDef(&graphdef);
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
@ -205,12 +208,12 @@ TEST(EncapsulateSubgraphsTest, OneFunction) {
*library_expected.add_function() = test::function::XTimesTwo();
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"input__0:float", "input__1:float"}, {"output__2:float"}, {},
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"c_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"input__0"}},
{{"c"}, "BinaryTest", {"input__1", "C:o:0"}, {}, {"C"}},
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}},
},
{{"output__2", "c:o:0"}});
{{"c_0_retval", "c:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
@ -224,7 +227,7 @@ TEST(EncapsulateSubgraphsTest, OneFunction) {
Node* call = b2.opts().FinalizeBuilder(&node_builder);
Binary(a, call, b2.opts().WithName("E"));
b2.ToGraphDef(&graphdef_expected);
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
// If there are no marked nodes, funcification should be a no-op.
@ -251,7 +254,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctions) {
Binary(b, c, b1.opts().WithName("D").WithControlInput(control).WithAttr(
"_encapsulate", "F2"));
Binary(a, d, b1.opts().WithName("E"));
b1.ToGraphDef(&graphdef);
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
@ -261,17 +264,17 @@ TEST(EncapsulateSubgraphsTest, TwoFunctions) {
*library_expected.add_function() = test::function::XTimesTwo();
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"input__0:float"}, {"output__1:float"}, {},
"F1", {"a_0_arg:float"}, {"c_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"input__0"}},
{{"C"}, "UnaryTest", {"a_0_arg"}},
},
{{"output__1", "C:o:0"}});
{{"c_0_retval", "C:o:0"}});
*library_expected.add_function() = FunctionDefHelper::Create(
"F2", {"input__0:float", "input__1:float"}, {"output__2:float"}, {},
"F2", {"b_0_arg:float", "c_0_arg:float"}, {"d_0_retval:float"}, {},
{
{{"D"}, "BinaryTest", {"input__0", "input__1"}},
{{"D"}, "BinaryTest", {"b_0_arg", "c_0_arg"}},
},
{{"output__2", "D:o:0"}});
{{"d_0_retval", "D:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
@ -290,7 +293,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctions) {
Node* call2 = b2.opts().FinalizeBuilder(&nb2);
Binary(a, call2, b2.opts().WithName("E"));
b2.ToGraphDef(&graphdef_expected);
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
// If there are no marked nodes, funcification should be a no-op.
@ -340,7 +343,8 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) {
std::unique_ptr<Graph> graph;
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
"_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{},
/*parallel_checking=*/false, &graph, &library));
/*parallel_checking=*/false, /*reuse_existing_functions=*/false, &graph,
&library));
std::vector<string> expected_nodes = {"cluster1", "cluster2", "mul", "x"};
EXPECT_EQ(expected_nodes, GraphNodes(*graph));
@ -371,7 +375,8 @@ TEST(EncapsulateSubgraphsTest, ParallelChecking) {
std::unique_ptr<Graph> graph;
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
"_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{},
/*parallel_checking=*/true, &graph, &library));
/*parallel_checking=*/true, /*reuse_existing_functions=*/false, &graph,
&library));
std::vector<string> expected_nodes = {
"add1", "add2", "cluster1", "cluster1_parallel_check/_0",

View File

@ -120,14 +120,12 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
std::unordered_map<string, string> return_values;
NodeNameMapping node_names;
for (Node const* node : graph.nodes()) {
if (!node->IsOp()) continue;
for (Node const* node : graph.op_nodes()) {
if (node->type_string() == kArgOp) {
int index;
DataType type;
GetNodeAttr(node->def(), "T", &type);
GetNodeAttr(node->def(), "index", &index);
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &type));
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
while (fdef->signature().input_arg_size() <= index) {
fdef->mutable_signature()->add_input_arg();
}
@ -143,8 +141,8 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
if (node->type_string() == kRetValOp) {
int index;
DataType type;
GetNodeAttr(node->def(), "T", &type);
GetNodeAttr(node->def(), "index", &index);
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &type));
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
while (fdef->signature().output_arg_size() <= index) {
fdef->mutable_signature()->add_output_arg();
}
@ -161,9 +159,11 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
}
NodeDef* node_def = fdef->add_node_def();
node_def->CopyFrom(node->def());
*node_def = node->def();
if (!node->assigned_device_name().empty()) {
node_def->set_device(node->assigned_device_name());
}
node_def->set_name(node_names.Uniquify(node->name()));
node_def->clear_device();
// Reset input names based on graph rather than the NodeDef.
node_def->clear_input();
@ -185,7 +185,7 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
}
// Add regular inputs
for (int i = 0; i < in_edges.size(); ++i) {
for (std::vector<const Edge*>::size_type i = 0; i < in_edges.size(); ++i) {
const Edge* edge = in_edges[i];
if (edge == nullptr) {
return errors::InvalidArgument(
@ -204,8 +204,8 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
// Populate tensor_renaming.
NameRangeMap output_ranges;
TF_RETURN_IF_ERROR(NameRangesForNode(node->def(), node->op_def(), nullptr,
&output_ranges));
TF_RETURN_IF_ERROR(
NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges));
for (const auto& output : output_ranges) {
for (int i = output.second.first; i < output.second.second; ++i) {
const string tensor_name = strings::StrCat(

View File

@ -19,11 +19,11 @@ limitations under the License.
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/graph/equal_graph_def.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow {
namespace {
@ -54,7 +54,7 @@ TEST(GraphToFunctionDefTest, Basics) {
auto h = ops::_Retval(root.WithOpName("H"), g, 0);
GraphDef graph_def;
root.ToGraphDef(&graph_def);
TF_EXPECT_OK(root.ToGraphDef(&graph_def));
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphConstructorOptions options;

View File

@ -76,7 +76,7 @@ struct GraphCycles::Rep {
GraphCycles::GraphCycles() : rep_(new Rep) {}
GraphCycles::~GraphCycles() {
for (int i = 0; i < rep_->nodes_.size(); i++) {
for (Vec<Node*>::size_type i = 0; i < rep_->nodes_.size(); i++) {
delete rep_->nodes_[i];
}
delete rep_;
@ -85,7 +85,7 @@ GraphCycles::~GraphCycles() {
bool GraphCycles::CheckInvariants() const {
Rep* r = rep_;
NodeSet ranks; // Set of ranks seen so far.
for (int32 x = 0; x < r->nodes_.size(); x++) {
for (Vec<Node*>::size_type x = 0; x < r->nodes_.size(); x++) {
Node* nx = r->nodes_[x];
if (nx->visited) {
LOG(FATAL) << "Did not clear visited marker on node " << x;
@ -108,7 +108,7 @@ int32 GraphCycles::NewNode() {
if (rep_->free_nodes_.empty()) {
Node* n = new Node;
n->visited = false;
n->data = NULL;
n->data = nullptr;
n->rank = rep_->nodes_.size();
rep_->nodes_.push_back(n);
return n->rank;
@ -116,7 +116,7 @@ int32 GraphCycles::NewNode() {
// Preserve preceding rank since the set of ranks in use must be
// a permutation of [0,rep_->nodes_.size()-1].
int32 r = rep_->free_nodes_.back();
rep_->nodes_[r]->data = NULL;
rep_->nodes_[r]->data = nullptr;
rep_->free_nodes_.pop_back();
return r;
}
@ -259,7 +259,7 @@ static void Reorder(GraphCycles::Rep* r) {
r->deltaf_.end(), r->merged_.begin());
// Assign the ranks in order to the collected list.
for (int32 i = 0; i < r->list_.size(); i++) {
for (Vec<int32>::size_type i = 0; i < r->list_.size(); i++) {
r->nodes_[r->list_[i]]->rank = r->merged_[i];
}
}
@ -277,7 +277,7 @@ static void Sort(const Vec<Node*>& nodes, Vec<int32>* delta) {
}
static void MoveToList(GraphCycles::Rep* r, Vec<int32>* src, Vec<int32>* dst) {
for (int32 i = 0; i < src->size(); i++) {
for (Vec<int32>::size_type i = 0; i < src->size(); i++) {
int32 w = (*src)[i];
(*src)[i] = r->nodes_[w]->rank; // Replace src entry with its rank
r->nodes_[w]->visited = false; // Prepare for future DFS calls
@ -286,7 +286,7 @@ static void MoveToList(GraphCycles::Rep* r, Vec<int32>* src, Vec<int32>* dst) {
}
static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32>& nodes) {
for (int32 i = 0; i < nodes.size(); i++) {
for (Vec<int32>::size_type i = 0; i < nodes.size(); i++) {
r->nodes_[nodes[i]]->visited = false;
}
}
@ -332,7 +332,7 @@ int GraphCycles::FindPath(int32 x, int32 y, int max_path_len,
}
bool GraphCycles::IsReachable(int32 x, int32 y) const {
return FindPath(x, y, 0, NULL) > 0;
return FindPath(x, y, 0, nullptr) > 0;
}
bool GraphCycles::IsReachableNonConst(int32 x, int32 y) {

View File

@ -230,7 +230,7 @@ TEST(GraphCycles, RandomizedTest) {
int new_node = graph_cycles.NewNode();
ASSERT_NE(-1, new_node);
VLOG(1) << "adding node " << new_node;
ASSERT_EQ(0, graph_cycles.GetNodeData(new_node));
ASSERT_EQ(nullptr, graph_cycles.GetNodeData(new_node));
graph_cycles.SetNodeData(
new_node, reinterpret_cast<void *>(
static_cast<intptr_t>(new_node + kDataOffset)));
@ -243,7 +243,7 @@ TEST(GraphCycles, RandomizedTest) {
break;
case 1: // Remove a node
if (nodes.size() > 0) {
if (!nodes.empty()) {
int node_index = RandomNode(&rnd, &nodes);
int node = nodes[node_index];
nodes[node_index] = nodes.back();
@ -263,7 +263,7 @@ TEST(GraphCycles, RandomizedTest) {
break;
case 2: // Add an edge
if (nodes.size() > 0) {
if (!nodes.empty()) {
int from = RandomNode(&rnd, &nodes);
int to = RandomNode(&rnd, &nodes);
if (EdgeIndex(&edges, nodes[from], nodes[to]) == -1) {
@ -282,7 +282,7 @@ TEST(GraphCycles, RandomizedTest) {
break;
case 3: // Remove an edge
if (edges.size() > 0) {
if (!edges.empty()) {
int i = RandomEdge(&rnd, &edges);
int from = edges[i].from;
int to = edges[i].to;
@ -296,7 +296,7 @@ TEST(GraphCycles, RandomizedTest) {
break;
case 4: // Check a path
if (nodes.size() > 0) {
if (!nodes.empty()) {
int from = RandomNode(&rnd, &nodes);
int to = RandomNode(&rnd, &nodes);
int32 path[2 * kMaxNodes];
@ -343,7 +343,7 @@ TEST(GraphCycles, RandomizedTest) {
ASSERT_NE(-1, new_node);
VLOG(1) << "adding node " << new_node;
ASSERT_GE(new_node, 0);
ASSERT_EQ(0, graph_cycles.GetNodeData(new_node));
ASSERT_EQ(nullptr, graph_cycles.GetNodeData(new_node));
graph_cycles.SetNodeData(
new_node, reinterpret_cast<void *>(
static_cast<intptr_t>(new_node + kDataOffset)));

View File

@ -0,0 +1,74 @@
licenses(["notice"]) # Apache 2.0
package(
default_visibility = [
"//tensorflow/compiler/tf2xla:internal",
],
)
cc_library(
name = "xla_local_launch_op",
srcs = ["xla_local_launch_op.cc"],
hdrs = ["xla_local_launch_op.h"],
deps = [
"//tensorflow/compiler/jit:common",
"//tensorflow/compiler/jit:xla_compilation_cache",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla:xla_local_runtime_context",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:tensorflow_opensource",
],
alwayslink = 1,
)
cc_library(
name = "xla_device_launch_op",
srcs = ["xla_device_launch_op.cc"],
hdrs = ["xla_device_launch_op.h"],
deps = [
"//tensorflow/compiler/jit:common",
"//tensorflow/compiler/jit:xla_compilation_cache",
"//tensorflow/compiler/jit:xla_device",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:tensorflow_opensource",
"//tensorflow/core/kernels:variable_ops",
],
)
cc_library(
name = "parallel_check_op",
srcs = ["parallel_check_op.cc"],
visibility = ["//tensorflow/compiler/jit:friends"],
deps = [
"//tensorflow/compiler/jit/legacy_flags:parallel_check_op_flags",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
alwayslink = 1,
)
# -----------------------------------------------------------------------------
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

View File

@ -0,0 +1,144 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
namespace {
// Inputs 2*N tensors, outputs the first N inputs.
// Logs errors if input tensor i and i + N are not (near) identical
// in any position.
class ParallelCheckOp : public OpKernel {
public:
explicit ParallelCheckOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
template <typename T>
int CompareTensors(DataType dtype, const char* v0, const char* v1,
int64 num_elts, int input_idx) {
int failed = 0;
const T* p0 = reinterpret_cast<const T*>(v0);
const T* p1 = reinterpret_cast<const T*>(v1);
double rtol;
legacy_flags::ParallelCheckOpFlags* flags =
legacy_flags::GetParallelCheckOpFlags();
if (!tensorflow::strings::safe_strtod(flags->parallel_check_rtol.c_str(),
&rtol)) {
LOG(ERROR) << "can't convert parallel_check_rtol "
<< flags->parallel_check_rtol << " to double";
}
double atol;
if (!tensorflow::strings::safe_strtod(flags->parallel_check_atol.c_str(),
&atol)) {
LOG(ERROR) << "can't convert parallel_check_atol "
<< flags->parallel_check_atol << " to double";
}
for (int i = 0; i < num_elts; ++i) {
bool ok = (p0[i] == p1[i]);
VLOG(2) << "output " << input_idx << " element " << i << ": " << p0[i];
if (!ok) {
if (std::is_same<T, float>::value || std::is_same<T, double>::value) {
float tolerance =
std::max(atol, std::max(fabs(rtol * p0[i]), fabs(rtol * p1[i])));
T diff = p0[i] - p1[i];
if (diff < 0) diff = 0 - diff;
ok = (diff <= tolerance);
}
if (ok) continue;
LOG(ERROR) << "Op " << def().name() << " fails equality at output "
<< input_idx << " type " << DataTypeString(dtype)
<< " element " << i << ": std_val=" << p0[i]
<< " test_val=" << p1[i] << " diff=" << (p0[i] - p1[i]);
if (++failed > 10) break;
}
}
return failed;
}
void Compute(OpKernelContext* ctx) override {
VLOG(1) << "Compute " << def().name();
const int num_pairs = ctx->num_inputs() / 2;
for (int i = 0; i < num_pairs; ++i) {
CHECK_EQ(ctx->input_dtype(i), ctx->input_dtype(i + num_pairs));
Tensor t0 = ctx->input(i);
Tensor t1 = ctx->input(i + num_pairs);
int64 num_elts = t0.NumElements();
CHECK_EQ(num_elts, t1.NumElements());
// Compare inputs elementwise for near-exact equality.
const char* v0 = t0.tensor_data().data();
const char* v1 = t1.tensor_data().data();
int failed = 0;
switch (ctx->input_dtype(i)) {
case DT_INT32:
failed =
CompareTensors<int32>(ctx->input_dtype(i), v0, v1, num_elts, i);
break;
case DT_INT64:
failed =
CompareTensors<int64>(ctx->input_dtype(i), v0, v1, num_elts, i);
break;
case DT_FLOAT:
failed =
CompareTensors<float>(ctx->input_dtype(i), v0, v1, num_elts, i);
break;
case DT_DOUBLE:
failed =
CompareTensors<double>(ctx->input_dtype(i), v0, v1, num_elts, i);
break;
case DT_BOOL:
failed =
CompareTensors<bool>(ctx->input_dtype(i), v0, v1, num_elts, i);
break;
default:
LOG(FATAL) << "unimpl: " << ctx->input_dtype(i);
}
if (failed > 0) {
LOG(ERROR) << "check failed for " << def().name() << " output " << i
<< " num_elts: " << num_elts;
legacy_flags::ParallelCheckOpFlags* flags =
legacy_flags::GetParallelCheckOpFlags();
if (flags->parallel_check_failfast) {
LOG(QFATAL) << "failfast on first parallel-check failure";
}
} else {
VLOG(1) << "check passed for " << def().name() << " output " << i
<< " num_elts: " << num_elts;
}
// Propagate the std value.
if (IsRefType(ctx->input_dtype(i))) {
ctx->forward_ref_input_to_ref_output(i, i);
} else {
ctx->set_output(i, ctx->input(i));
}
}
}
TF_DISALLOW_COPY_AND_ASSIGN(ParallelCheckOp);
};
REGISTER_KERNEL_BUILDER(Name("ParallelCheck").Device(DEVICE_CPU),
ParallelCheckOp);
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,253 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/kernels/xla_device_launch_op.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/env.h"
namespace tensorflow {
namespace {
Status BuildCompilationCache(ResourceMgr* rm, XlaCompilationCache** cache) {
XlaDevice::Metadata* metadata;
Status s = rm->Lookup<XlaDevice::Metadata>(rm->default_container(),
"xla_metadata", &metadata);
if (!s.ok()) {
return s;
}
core::ScopedUnref metadata_ref(metadata);
*cache =
new XlaCompilationCache(metadata->client(), metadata->jit_device_type());
return Status::OK();
}
} // namespace
XlaDeviceLaunchOp::XlaDeviceLaunchOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {
const NameAttrList* func;
OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func));
function_ = *func;
VLOG(1) << "XlaDeviceLaunch created function="
<< Canonicalize(function_.name(), AttrSlice(&function_.attr()));
DataTypeVector constant_types;
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types));
num_constant_args_ = constant_types.size();
OP_REQUIRES_OK(ctx, ctx->GetAttr("Nresources", &num_resource_args_));
}
std::vector<OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
int num_variables) {
std::vector<OptionalTensor> snapshot(num_variables);
int first_variable = ctx->num_inputs() - num_variables;
for (int i = 0; i < num_variables; ++i) {
Var* variable = nullptr;
ResourceHandle handle = HandleFromInput(ctx, first_variable + i);
if (LookupResource(ctx, handle, &variable).ok()) {
mutex_lock lock(*variable->mu());
snapshot[i].name = handle.name();
snapshot[i].present = true;
snapshot[i].value = *variable->tensor();
}
}
return snapshot;
}
void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "XlaDeviceLaunch::Compute "
<< Canonicalize(function_.name(), AttrSlice(&function_.attr()));
// We store information about the JIT-compiled XLA computation
// in the ResourceMgr.
ResourceMgr* rm = ctx->resource_manager();
OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
XlaCompilationCache* cache;
OP_REQUIRES_OK(ctx, rm->LookupOrCreate<XlaCompilationCache>(
rm->default_container(), "xla_compiler", &cache,
[rm](XlaCompilationCache** cache) {
return BuildCompilationCache(rm, cache);
}));
// Holds the reference to the JIT during evaluation. (We could probably
// free it sooner because the ResourceMgr will retain a reference, but
// this is more obviously correct.)
core::ScopedUnref cache_ref(cache);
std::vector<OptionalTensor> variables =
SnapshotResourceVariables(ctx, num_resource_args_);
XlaCompiler::Options options;
options.client = cache->client();
options.device_type = &cache->device_type();
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
options.graph_def_version = ctx->function_library()->graph_def_version();
options.allow_cpu_custom_calls = false;
options.local_executable_has_hybrid_result = false;
const XlaCompiler::CompilationResult* kernel;
OP_REQUIRES_OK(ctx, cache->Compile(options, function_, num_constant_args_,
variables, ctx, &kernel, nullptr));
VLOG(1) << "XLA compilation complete...";
OP_REQUIRES(ctx, ctx->num_outputs() == kernel->outputs.size(),
errors::Internal("Unexpected number of outputs"));
// Runs the computation, if any. There might not be a computation if all
// outputs were compile-time constants.
std::vector<std::unique_ptr<xla::GlobalData>> outputs;
if (!kernel->computation->IsNull()) {
auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape();
// Builds the inputs to the computation.
std::vector<std::shared_ptr<xla::GlobalData>> arg_handles(
kernel->input_mapping.size());
std::vector<xla::GlobalData*> arg_ptrs(kernel->input_mapping.size());
// Adds the argument tensors.
const int first_variable_arg = ctx->num_inputs() - num_resource_args_;
for (int i = 0; i < kernel->input_mapping.size(); ++i) {
int op_input_num = kernel->input_mapping[i];
if (op_input_num >= first_variable_arg) {
arg_handles[i] = XlaTransferManager::GetTensorGlobalData(
variables[op_input_num - first_variable_arg].value);
} else {
arg_handles[i] =
XlaTransferManager::GetTensorGlobalData(ctx->input(op_input_num));
}
arg_ptrs[i] = arg_handles[i].get();
}
// Execute the computation.
xla::ExecutionProfile profile;
xla::ExecutionOptions execution_options;
*execution_options.mutable_shape_with_output_layout() =
kernel->xla_output_shape;
Env* env = Env::Default();
auto start_time = env->NowMicros();
VLOG(1) << "Executing XLA Computation...";
auto result = cache->client()->Execute(*kernel->computation, arg_ptrs,
&execution_options, &profile);
auto elapsed = env->NowMicros() - start_time;
OP_REQUIRES(ctx, result.ok(), result.status());
VLOG(1) << "Elapsed time: " << elapsed << "us";
VLOG(1) << "ExecutionProfile: " << profile.DebugString();
if (xla::ShapeUtil::IsTuple(kernel->xla_output_shape)) {
auto outputs_or_error =
cache->client()->DeconstructTuple(*result.ValueOrDie());
OP_REQUIRES(ctx, outputs_or_error.ok(), outputs_or_error.status());
outputs = outputs_or_error.ConsumeValueOrDie();
} else {
outputs.push_back(result.ConsumeValueOrDie());
}
}
XlaDeviceContext* device_context = ctx->op_device_context<XlaDeviceContext>();
// Copy XLA outputs to the operator's outputs.
VLOG(2) << "Setting operator output";
int output_num = 0;
for (int i = 0; i < ctx->num_outputs(); ++i) {
Tensor* output;
OP_REQUIRES_OK(ctx,
ctx->allocate_output(i, kernel->outputs[i].shape, &output));
if (kernel->outputs[i].is_constant) {
// TODO(phawkins): mark constant _XlaLaunch outputs as HostMemory and
// remove the copy from this code.
Status status;
device_context->CopyCPUTensorToDevice(
&kernel->outputs[i].constant_value, nullptr, output,
[&status](const Status& s) { status = s; });
if (!status.ok()) {
ctx->SetStatus(status);
return;
}
} else {
CHECK_LT(output_num, outputs.size());
XlaTransferManager::SetTensorGlobalData(
std::shared_ptr<xla::GlobalData>(std::move(outputs[output_num])),
output);
++output_num;
}
}
// Apply variable updates, if any.
VLOG(2) << "Applying variable updates";
for (int i = 0; i < kernel->variable_updates.size(); ++i) {
const XlaCompiler::VariableUpdate& write = kernel->variable_updates[i];
OP_REQUIRES(ctx,
write.input_index >= 0 && write.input_index < ctx->num_inputs(),
errors::Internal("Invalid input index for variable write."));
// This code is very close to being a clone of AssignVariableOp, but the
// key difference is that the contents of an XLA device tensor cannot be
// copied safely; instead we must use
// XlaTransferManager::SetTensorGlobalData.
Var* variable = nullptr;
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, not
// a Tensor.
OP_REQUIRES_OK(ctx, LookupOrCreateResource<Var>(
ctx, HandleFromInput(ctx, write.input_index),
&variable, [this, ctx, &write](Var** ptr) {
*ptr = new Var(write.type);
PersistentTensor unused;
Tensor* tmp;
TF_RETURN_IF_ERROR(ctx->allocate_persistent(
write.type, write.shape, &unused, &tmp));
*(*ptr)->tensor() = *tmp;
return Status::OK();
}));
core::ScopedUnref s(variable);
mutex_lock ml(*variable->mu());
OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type,
errors::Internal("Mismatched type in variable write"));
if (!variable->tensor()->shape().IsSameSize(write.shape)) {
PersistentTensor unused;
Tensor* tmp;
OP_REQUIRES_OK(ctx, ctx->allocate_persistent(write.type, write.shape,
&unused, &tmp));
*variable->tensor() = *tmp;
}
XlaTransferManager::SetTensorGlobalData(
std::shared_ptr<xla::GlobalData>(std::move(outputs[output_num])),
variable->tensor());
++output_num;
}
VLOG(1) << "Done";
}
XlaDeviceLaunchOp::~XlaDeviceLaunchOp() {
VLOG(1) << "XlaDeviceLaunch destroyed";
}
} // namespace tensorflow

View File

@ -0,0 +1,65 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_DEVICE_LAUNCH_OP_H_
#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_DEVICE_LAUNCH_OP_H_
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
// Takes a snapshot of the values of resource variable arguments, which are
// the last `num_variables` arguments. We snapshot tensors that back
// resource variables since concurrent updates may modify the shape, and it is
// important that the shapes used for compilation match the true shapes of the
// buffers.
std::vector<OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
int num_variables);
// The XlaDeviceLaunchOp is used to replace a region of the TensorFlow graph
// which will be compiled and executed using XLA. The XlaDeviceLaunchOp is
// responsible for handling interactions with the TensorFlow executor.
// Once all inputs are present, and their shapes are known, the op can
// use a 'TlaJit' to compile and execute code which is specific
// to the shapes of input Tensors.
class XlaDeviceLaunchOp : public OpKernel {
public:
explicit XlaDeviceLaunchOp(OpKernelConstruction* ctx);
~XlaDeviceLaunchOp() override;
void Compute(OpKernelContext* ctx) override;
private:
NameAttrList function_;
// Number of compile-time constant arguments.
int num_constant_args_;
// Number of resource variable arguments.
int num_resource_args_;
Tensor dummy_tensor_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaDeviceLaunchOp);
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_DEVICE_LAUNCH_OP_H_

View File

@ -13,11 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_local_launch_op.h"
#include "tensorflow/compiler/jit/kernels/xla_local_launch_op.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/statusor.h"
@ -37,20 +38,9 @@ namespace gpu = perftools::gputools;
namespace tensorflow {
REGISTER_OP("_XlaLaunch")
.Input("constants: Tconstants")
.Attr("Tconstants: list(type) >= 0")
.Input("args: Targs")
.Attr("Targs: list(type) >= 0")
.Output("results: Tresults")
.Attr("Tresults: list(type) >= 0")
.Attr("function: func")
// XLA random-number generation ops are stateful.
// TODO(phawkins): create stateful and non-stateful variants of _XlaLaunch.
.SetIsStateful()
.Doc("XLA Launch Op. For use by the XLA JIT only.");
// Adapter class that wraps a Tensorflow allocator as an XLA allocator.
// Assumes that the Tensorflow allocator permits asynchronous deallocation:
// see comment on `AllowsAsynchronousDeallocation()`.
class XlaAllocator : public xla::DeviceMemoryAllocator {
public:
XlaAllocator(const perftools::gputools::Platform* platform,
@ -66,6 +56,15 @@ class XlaAllocator : public xla::DeviceMemoryAllocator {
Status MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer, DataType dtype,
const TensorShape& shape, Tensor* tensor) const;
// The Tensorflow BFC allocator used on GPU allows host-side deallocation
// before GPU execution takes place. Tensorflow uses the ordering of the main
// compute stream to enforce a happens-before relationship between a memory
// allocation and code that reuses the same memory. If Tensorflow adds
// support for multiple GPU streams or allocators with different ordering
// requirements, this code may need to change.
// (This attribute has no effect on CPU.)
bool AllowsAsynchronousDeallocation() const override { return true; }
private:
OpKernelContext* const op_context_;
@ -143,45 +142,51 @@ XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
DataTypeVector constant_types;
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types));
num_constant_args_ = constant_types.size();
int num_resource_args;
OP_REQUIRES_OK(ctx, ctx->GetAttr("Nresources", &num_resource_args));
OP_REQUIRES(ctx, num_resource_args == 0,
errors::Unimplemented(
"XlaLocalLaunchOp does not support resource variables"));
if (device_type_ == DeviceType(DEVICE_CPU)) {
platform_id_ = gpu::host::kHostPlatformId;
} else if (device_type_ == DeviceType(DEVICE_GPU)) {
platform_id_ = gpu::cuda::kCudaPlatformId;
} else {
ctx->SetStatus(
errors::InvalidArgument("Unknown device type for local _XlaLaunch"));
return;
}
}
Status XlaLocalLaunchOp::BuildCompilationCache(XlaCompilationCache** compiler) {
gpu::Platform::Id platform_id;
if (device_type_ == DeviceType(DEVICE_CPU)) {
platform_id = gpu::host::kHostPlatformId;
} else if (device_type_ == DeviceType(DEVICE_GPU)) {
platform_id = gpu::cuda::kCudaPlatformId;
} else {
return errors::InvalidArgument("Unknown device type for local _XlaLaunch");
}
auto platform = gpu::MultiPlatformManager::PlatformWithId(platform_id);
Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx,
XlaCompilationCache** cache) {
auto platform = gpu::MultiPlatformManager::PlatformWithId(platform_id_);
if (!platform.ok()) {
return StreamExecutorUtil::ConvertStatus(platform.status());
}
auto client =
xla::ClientLibrary::GetOrCreateLocalClient(platform.ValueOrDie());
xla::LocalClientOptions client_options;
client_options.set_platform(platform.ValueOrDie());
client_options.set_intra_op_parallelism_threads(
ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
if (!client.ok()) {
return client.status();
}
const string* compiler_device;
if (!XlaOpRegistry::GetJitDevice(device_type_.type(), &compiler_device,
/*requires_jit=*/nullptr)) {
const XlaOpRegistry::DeviceRegistration* registration;
if (!XlaOpRegistry::GetCompilationDevice(device_type_.type(),
&registration)) {
return errors::InvalidArgument("No JIT device registered for ",
device_type_.type());
}
XlaCompiler::Options options;
options.device_type = DeviceType(*compiler_device);
options.client = client.ValueOrDie();
options.allow_cpu_custom_calls = (platform_id == gpu::host::kHostPlatformId);
options.local_executable_has_hybrid_result = true;
*compiler = new XlaCompilationCache(options);
*cache = new XlaCompilationCache(
client.ValueOrDie(), DeviceType(registration->compilation_device_name));
return Status::OK();
}
void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "XlaLocalLaunchOp::Compute "
<< Canonicalize(function_.name(), function_.attr());
<< Canonicalize(function_.name(), AttrSlice(&function_.attr()));
// We store information about the JIT-compiled XLA computation
// in the ResourceMgr.
ResourceMgr* rm = ctx->resource_manager();
@ -190,25 +195,31 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
gpu::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
XlaCompilationCache* compiler;
OP_REQUIRES_OK(ctx,
rm->LookupOrCreate<XlaCompilationCache>(
rm->default_container(), "xla_compiler", &compiler,
[this](XlaCompilationCache** compiler) {
return BuildCompilationCache(compiler);
}));
XlaCompilationCache* cache;
OP_REQUIRES_OK(ctx, rm->LookupOrCreate<XlaCompilationCache>(
rm->default_container(), "xla_cache", &cache,
[this, ctx](XlaCompilationCache** cache) {
return BuildCompilationCache(ctx, cache);
}));
// Hold the reference to the JIT during evaluation. (We could probably
// free it sooner because the ResourceMgr will retain a reference, but
// this is more obviously correct.)
core::ScopedUnref compiler_ref(compiler);
core::ScopedUnref cache_ref(cache);
xla::LocalClient* client = static_cast<xla::LocalClient*>(compiler->client());
xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
XlaCompiler::Options options;
options.client = client;
options.device_type = &cache->device_type();
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
options.graph_def_version = ctx->function_library()->graph_def_version();
options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId);
options.local_executable_has_hybrid_result = true;
const XlaCompiler::CompilationResult* kernel;
xla::LocalExecutable* executable;
OP_REQUIRES_OK(ctx,
compiler->Compile(function_, num_constant_args_, ctx, &kernel,
&executable));
OP_REQUIRES_OK(ctx, cache->Compile(options, function_, num_constant_args_, {},
ctx, &kernel, &executable));
VLOG(1) << "Executing XLA Computation...";
@ -218,7 +229,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
std::unique_ptr<xla::ShapedBuffer> output;
bool output_is_tuple;
if (!kernel->computation.IsNull()) {
if (!kernel->computation->IsNull()) {
// Build xla::ShapedBuffers that point directly to the Tensor buffers.
std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers;
arg_buffers.reserve(kernel->xla_input_shapes.size() + 1);
@ -227,8 +238,8 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
// Pass remaining parameters.
for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
int arg_num = kernel->xla_input_shapes[i].first;
const xla::Shape& shape = kernel->xla_input_shapes[i].second;
int arg_num = kernel->input_mapping[i];
const xla::Shape& shape = kernel->xla_input_shapes[i];
gpu::DeviceMemoryBase dmem(
const_cast<char*>(ctx->input(arg_num).tensor_data().data()),
ctx->input(arg_num).tensor_data().size());
@ -316,10 +327,9 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
}
Tensor output_tensor;
// Looks up the owning Tensor by buffer address.
OP_REQUIRES_OK(
ctx,
xla_allocator.MakeTensorFromBuffer(
buffer, ctx->expected_output_dtype(i), shape, &output_tensor));
OP_REQUIRES_OK(ctx, xla_allocator.MakeTensorFromBuffer(
buffer, ctx->expected_output_dtype(i), shape,
&output_tensor));
ctx->set_output(i, output_tensor);
++output_num;
}

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_XLA_LOCAL_LAUNCH_OP_H_
#define TENSORFLOW_COMPILER_JIT_XLA_LOCAL_LAUNCH_OP_H_
#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LOCAL_LAUNCH_OP_H_
#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LOCAL_LAUNCH_OP_H_
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/core/framework/allocator.h"
@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/stream_executor_util.h"
namespace tensorflow {
@ -31,8 +32,9 @@ namespace tensorflow {
// Once all inputs are present, and their shapes are known, the op can
// use a 'XlaCompilationCache' to compile and execute code which is specific
// to the shapes of input Tensors.
// XlaLocalLaunchOp uses xla::LocalClient::ExecuteLocally and passes
// arguments into/out of XLA in device memory.
// XlaLocalLaunchOp uses xla::LocalClient::Compile() and
// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
// memory.
class XlaLocalLaunchOp : public OpKernel {
public:
explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
@ -42,14 +44,18 @@ class XlaLocalLaunchOp : public OpKernel {
private:
// Builds a XlaCompilationCache class suitable for the current device.
Status BuildCompilationCache(XlaCompilationCache** compiler);
Status BuildCompilationCache(OpKernelContext* ctx,
XlaCompilationCache** compiler);
DeviceType device_type_;
NameAttrList function_;
int num_constant_args_;
perftools::gputools::Platform::Id platform_id_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_LOCAL_LAUNCH_OP_H_
#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LOCAL_LAUNCH_OP_H_

View File

@ -24,8 +24,9 @@ limitations under the License.
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
@ -50,22 +51,24 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
}
// Make sure we don't recurse infinitely on recursive functions.
const int kMaxRecursionDepth = 5;
const int kMaxRecursionDepth = 10;
bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type,
int depth, FunctionLibraryRuntime* lib_runtime);
bool IsCompilableCall(const NodeDef& call_def,
const DeviceType& jit_device_type, int depth,
FunctionLibraryRuntime* lib_runtime);
// Tests whether 'while_def' is a completely compilable loop.
// Tests whether 'while_node' is a completely compilable loop.
// Every operator in the condition and body functions must be compilable for a
// while loop to be compilable.
bool IsCompilableWhile(const NodeDef& while_def, DeviceType jit_device_type,
int depth, FunctionLibraryRuntime* lib_runtime) {
VLOG(2) << "Loop marking: " << while_def.op();
bool IsCompilableWhile(const Node& while_node,
const DeviceType& jit_device_type, int depth,
FunctionLibraryRuntime* lib_runtime) {
VLOG(2) << "Loop marking: " << while_node.type_string();
const NameAttrList* name_attr;
NodeDef call;
Status status;
status = GetNodeAttr(while_def, "cond", &name_attr);
status = GetNodeAttr(while_node.attrs(), "cond", &name_attr);
if (!status.ok()) {
VLOG(2) << "Missing 'cond' attribute on While node.";
return false;
@ -78,7 +81,7 @@ bool IsCompilableWhile(const NodeDef& while_def, DeviceType jit_device_type,
VLOG(2) << "Can't compile loop condition: " << cond_func;
return false;
}
status = GetNodeAttr(while_def, "body", &name_attr);
status = GetNodeAttr(while_node.attrs(), "body", &name_attr);
if (!status.ok()) {
VLOG(2) << "Missing 'body' attribute on While node.";
return false;
@ -98,8 +101,9 @@ bool IsCompilableWhile(const NodeDef& while_def, DeviceType jit_device_type,
// Tests whether 'call_def' is a call to a completely compilable function.
// Every operator in the function must be compilable for a function to be
// compilable.
bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type,
int depth, FunctionLibraryRuntime* lib_runtime) {
bool IsCompilableCall(const NodeDef& call_def,
const DeviceType& jit_device_type, int depth,
FunctionLibraryRuntime* lib_runtime) {
VLOG(2) << "Function marking: " << call_def.op();
if (depth > kMaxRecursionDepth) {
@ -109,21 +113,32 @@ bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type,
FunctionLibraryRuntime::Handle handle;
Status status =
lib_runtime->Instantiate(call_def.op(), call_def.attr(), &handle);
lib_runtime->Instantiate(call_def.op(), AttrSlice(call_def), &handle);
if (!status.ok()) {
VLOG(2) << "Could not instantiate " << call_def.op() << ": " << status;
return false;
}
const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
CHECK(fbody);
const FunctionDef& fdef = fbody->fdef;
bool noinline = false;
if (GetNodeAttr(AttrSlice(&fdef.attr()), "_noinline", &noinline).ok() &&
noinline) {
// The underlying mechanism that calls non-inlined functions uses
// LocalExecutor, which interacts poorly with the LocalExecutor used by
// tf2xla to translate the TF graph into XLA. So we avoid this for now.
//
// TODO(b/36139787): Create a mechanism to set inlining hints.
VLOG(2) << "Can't compile noinline function: " << fdef.DebugString();
return false;
}
for (Node* node : fbody->graph->nodes()) {
if (node->IsSource() || node->IsSink()) continue;
if (node->def().op() == "_Arg" || node->def().op() == "_Retval") continue;
if (node->def().op() == "While") {
for (Node* node : fbody->graph->op_nodes()) {
if (node->type_string() == "_Arg" || node->type_string() == "_Retval")
continue;
if (node->type_string() == "While") {
// Handle functional While loop (not in open source build).
return IsCompilableWhile(node->def(), jit_device_type, depth + 1,
lib_runtime);
return IsCompilableWhile(*node, jit_device_type, depth + 1, lib_runtime);
}
if (!HasXLAKernel(*node, jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, depth + 1,
@ -147,6 +162,12 @@ Status DeviceTypeOfDevice(const string& device, DeviceType* device_type) {
return Status::OK();
}
// Does `node` have a DT_RESOURCE typed argument?
bool HasResourceArgument(const Node& node) {
return std::find(node.input_types().begin(), node.input_types().end(),
DT_RESOURCE) != node.input_types().end();
}
Status FindCompilationCandidates(
const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env,
const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn,
@ -155,28 +176,30 @@ Status FindCompilationCandidates(
std::unique_ptr<FunctionLibraryRuntime> lib_runtime(NewFunctionLibraryRuntime(
nullptr, env, nullptr, TF_GRAPH_DEF_VERSION, flib_def, opts));
for (Node* node : graph.nodes()) {
if (node->IsSource() || node->IsSink()) continue;
for (Node* node : graph.op_nodes()) {
DeviceType device_type("");
TF_RETURN_IF_ERROR(
DeviceTypeOfDevice(node->assigned_device_name(), &device_type));
if (is_compilable_fn && !is_compilable_fn(node, device_type)) continue;
const string* jit_device_name;
CHECK(XlaOpRegistry::GetJitDevice(device_type.type(), &jit_device_name,
/*requires_jit=*/nullptr));
DeviceType jit_device_type(*jit_device_name);
const XlaOpRegistry::DeviceRegistration* registration;
CHECK(
XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration));
DeviceType jit_device_type(registration->compilation_device_name);
if (!HasXLAKernel(*node, jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime.get())) {
VLOG(2) << "Compilation rejected node: unsupported op " << node->name()
<< ": " << node->def().op();
<< ": " << node->type_string();
continue;
}
if (node->def().op() == "While" &&
!IsCompilableWhile(node->def(), jit_device_type, 0,
lib_runtime.get())) {
if (!registration->compile_resource_ops && HasResourceArgument(*node)) {
VLOG(2) << "Compilation rejected node: resource argument " << node->name()
<< ": " << node->type_string();
continue;
}
if (node->type_string() == "While" &&
!IsCompilableWhile(*node, jit_device_type, 0, lib_runtime.get())) {
continue;
}
candidates->insert(node);
@ -184,85 +207,27 @@ Status FindCompilationCandidates(
return Status::OK();
}
// Union-Find data structure used to compute clusters. We use our own
// implementation because we want one key feature: when merging clusters, we
// need to know which value becomes the representative of the merged clusters.
// We use the representatives to name nodes in a cycle detection graph, and we
// need to control which node is named.
// TODO(phawkins): consider merging this code with union-find implementations
// in Tensorflow, e.g., in SimplePlacer.
class Cluster {
public:
Cluster();
int Size() { return FindRoot()->size_; }
// Merges this cluster with 'other'. This cluster's representative becomes
// the representative of the merged cluster; the representative of 'other'
// is ignored.
void Merge(Cluster* other);
// Each cluster has an associated integer 'representative', initialized to -1
// by default.
int GetRepresentative() { return FindRoot()->representative_; }
void SetRepresentative(int representative) {
FindRoot()->representative_ = representative;
}
private:
// Finds the root element of the cluster. Performs path compression.
Cluster* FindRoot();
int representative_;
int rank_;
int size_; // Size of the cluster.
Cluster* parent_;
struct Cluster {
// Identifies the node that represents this cluster in the cycle detection
// graph.
int representative = -1;
};
Cluster::Cluster()
: representative_(-1), rank_(0), size_(1), parent_(nullptr) {}
void Cluster::Merge(Cluster* other) {
Cluster* a = FindRoot();
Cluster* b = other->FindRoot();
if (a == b) return;
if (a->rank_ > b->rank_) {
b->parent_ = a;
a->size_ += b->size_;
return;
}
a->parent_ = b;
if (a->rank_ == b->rank_) {
b->rank_++;
}
b->representative_ = a->representative_;
b->size_ += a->size_;
}
Cluster* Cluster::FindRoot() {
if (!parent_) return this;
// Path compression: update intermediate nodes to point to the root of the
// equivalence class.
parent_ = parent_->FindRoot();
return parent_;
}
} // anonymous namespace
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
Device* device = flr->device();
const string* jit_device_name;
CHECK(XlaOpRegistry::GetJitDevice(device->device_type(), &jit_device_name,
/*requires_jit=*/nullptr));
DeviceType jit_device_type(*jit_device_name);
const XlaOpRegistry::DeviceRegistration* registration;
CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
&registration));
DeviceType jit_device_type(registration->compilation_device_name);
return IsCompilableCall(ndef, jit_device_type, 0, flr);
}
Status MarkForCompilationPass::Run(
const GraphOptimizationPassOptions& options) {
// TODO(phawkins): precompute the "GetJitDevice" properties each device ahead
// of time.
// TODO(phawkins): precompute the "GetCompilationDevice" properties of each
// device ahead of time.
OptimizerOptions::GlobalJitLevel global_jit_level =
options.session_options->config.graph_options()
.optimizer_options()
@ -283,25 +248,24 @@ Status MarkForCompilationPass::Run(
const FunctionLibraryDefinition* fld = options.flib_def;
auto is_compilable = [global_jit_level, fld](const Node* node,
const DeviceType& device_type) {
const string* jit_device;
bool requires_jit;
if (!XlaOpRegistry::GetJitDevice(device_type.type(), &jit_device,
&requires_jit)) {
const XlaOpRegistry::DeviceRegistration* registration;
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
&registration)) {
return false;
}
// If this device requires a JIT, we must say yes.
if (requires_jit) return true;
if (registration->requires_compilation) return true;
// If there is a _XlaCompile annotation, use its value.
bool compile = false;
Status status = GetNodeAttr(node->def(), kXlaCompileAttr, &compile);
Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
if (status.ok()) return compile;
status = fld->GetAttr(node->def(), kXlaCompileAttr, &compile);
status = fld->GetAttr(*node, kXlaCompileAttr, &compile);
if (status.ok()) return compile;
// Otherwise use the value of global_jit_level.
return global_jit_level > 0;
return registration->enable_jit_by_default && global_jit_level > 0;
};
return RunImpl(options, is_compilable);
}
@ -323,7 +287,7 @@ Status MarkForCompilationPass::RunImpl(
VLOG(1) << "MarkForCompilationPass::Run";
// Make sure that kernels have been registered on the JIT device.
XlaOpRegistry::RegisterJitKernels();
XlaOpRegistry::RegisterCompilationKernels();
Graph* graph = options.graph->get();
@ -411,10 +375,11 @@ Status MarkForCompilationPass::RunImpl(
// Each compilation candidate belongs to a cluster. The cluster's
// representative
// names the node in the 'cycles' graph that represents the cluster.
std::vector<Cluster> clusters(graph->num_node_ids());
std::deque<Cluster*> worklist;
std::vector<UnionFind<Cluster>> clusters(graph->num_node_ids());
std::deque<UnionFind<Cluster>*> worklist;
for (Node* node : compilation_candidates) {
clusters[node->id()].SetRepresentative(node->id());
Cluster& cluster = clusters[node->id()].Get();
cluster.representative = node->id();
worklist.push_back(&clusters[node->id()]);
}
@ -424,15 +389,19 @@ Status MarkForCompilationPass::RunImpl(
// Repeatedly contract edges between clusters that are on the same device,
// provided the contraction would not create a cycle.
while (!worklist.empty()) {
int from = worklist.front()->GetRepresentative();
int from = worklist.front()->Get().representative;
worklist.pop_front();
Node* node_from = graph->FindNodeId(from);
if (node_from->IsControlFlow()) {
// Control flow nodes aren't compilation candidates and should never
// appear.
return errors::Internal("Found control flow node in clustering worklist");
return errors::Internal(
"Found control flow node in clustering worklist: ",
node_from->type_string());
}
string from_scope;
string to_scope;
for (int to : cycles.Successors(from)) {
if (to >= graph->num_node_ids()) {
// Node is a "frame" node that is present only in the cycle detection
@ -440,10 +409,27 @@ Status MarkForCompilationPass::RunImpl(
continue;
}
Node* node_to = graph->FindNodeId(to);
if (compilation_candidates.find(node_to) == compilation_candidates.cend())
if (compilation_candidates.find(node_to) ==
compilation_candidates.cend()) {
continue;
if (node_from->assigned_device_name() != node_to->assigned_device_name())
}
if (node_from->assigned_device_name() !=
node_to->assigned_device_name()) {
continue;
}
// Look for an _XlaScope on both nodes. If both nodes have a
// scope and the scopes do not match, do not cluster along this
// edge. If even one of the nodes lacks an _XlaScope attribute,
// then it is treated as a "bridge" and a cluster may be created
// along it. We may want to restrict this behavior to require
// all nodes marked with _XlaCompile=true to also have a
// _XlaScope property set (and raise an error otherwise); but
// for now we don't do this.
if (GetNodeAttr(node_from->attrs(), kXlaScopeAttr, &from_scope).ok() &&
GetNodeAttr(node_to->attrs(), kXlaScopeAttr, &to_scope).ok() &&
from_scope != to_scope) {
continue;
}
// Ops that consume shapes cannot be the root of a cluster. This is an
// optimization.
@ -476,7 +462,7 @@ Status MarkForCompilationPass::RunImpl(
// Count the number of elements in each cluster.
std::vector<int> cluster_sizes(graph->num_node_ids());
for (const Node* n : compilation_candidates) {
int cluster = clusters[n->id()].GetRepresentative();
int cluster = clusters[n->id()].Get().representative;
cluster_sizes[cluster]++;
}
@ -490,32 +476,30 @@ Status MarkForCompilationPass::RunImpl(
// if compilation is enabled, otherwise there will be no such candidates).
const int min_cluster_size = flags->tf_xla_min_cluster_size;
for (Node* n : compilation_candidates) {
int cluster = clusters[n->id()].GetRepresentative();
int cluster = clusters[n->id()].Get().representative;
// Compile if the user marked this node _XlaCompile=true
bool compile_attr = false;
bool marked_for_compilation = false;
if (GetNodeAttr(n->def(), kXlaCompileAttr, &compile_attr).ok()) {
if (GetNodeAttr(n->attrs(), kXlaCompileAttr, &compile_attr).ok()) {
marked_for_compilation = compile_attr;
} else if (options.flib_def
->GetAttr(n->def(), kXlaCompileAttr, &compile_attr)
} else if (options.flib_def->GetAttr(*n, kXlaCompileAttr, &compile_attr)
.ok()) {
marked_for_compilation = compile_attr;
}
// Compile if this operator is placed on a device that requires
// compilation.
bool requires_jit = false;
DeviceType device_type("");
TF_RETURN_IF_ERROR(
DeviceTypeOfDevice(n->assigned_device_name(), &device_type));
XlaOpRegistry::GetJitDevice(device_type.type(),
/*jit_device_name=*/nullptr, &requires_jit);
const XlaOpRegistry::DeviceRegistration* registration;
XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration);
// Or compile if this is a cluster of >= min_cluster_size compilable
// operators.
if (cluster_sizes[cluster] >= min_cluster_size || marked_for_compilation ||
requires_jit) {
registration->requires_compilation) {
string& name = cluster_names[cluster];
if (name.empty()) {
name = strings::StrCat("cluster_", cluster_sequence_num++);

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
@ -56,7 +57,7 @@ std::unordered_map<string, string> GetClusters(const Graph& graph) {
std::unordered_map<string, string> ids;
for (Node* node : graph.nodes()) {
string cluster;
if (GetNodeAttr(node->def(), kXlaClusterAttr, &cluster).ok()) {
if (GetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster).ok()) {
CHECK(!cluster.empty());
ids[node->name()] = cluster;
}
@ -77,7 +78,7 @@ TEST(XlaCompilationTest, Chains) {
ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D"));
Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
ops::UnaryOp("Relu", e, builder.opts().WithName("F"));
builder.ToGraph(graph.get());
TF_EXPECT_OK(builder.ToGraph(graph.get()));
}
MarkForCompilation(&graph);
@ -102,7 +103,7 @@ TEST(XlaCompilationTest, UncompilableCycles) {
Node* b =
ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B"));
ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
builder.ToGraph(graph.get());
TF_EXPECT_OK(builder.ToGraph(graph.get()));
}
MarkForCompilation(&graph);
@ -122,7 +123,7 @@ TEST(XlaCompilationTest, CompilableCycles) {
.WithAttr("value", Tensor()));
Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
builder.ToGraph(graph.get());
TF_EXPECT_OK(builder.ToGraph(graph.get()));
}
MarkForCompilation(&graph);
@ -145,7 +146,7 @@ TEST(XlaCompilationTest, UnsupportedTypes) {
.WithAttr("value", Tensor(DT_COMPLEX64, TensorShape())));
Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
builder.ToGraph(graph.get());
TF_EXPECT_OK(builder.ToGraph(graph.get()));
}
MarkForCompilation(&graph);
@ -174,7 +175,7 @@ TEST(XlaCompilationTest, ConcatWithConstArg) {
concat_builder.Input(dim).Input({a, a}).Attr("N", 2);
builder.opts().FinalizeBuilder(&concat_builder);
builder.ToGraph(graph.get());
TF_EXPECT_OK(builder.ToGraph(graph.get()));
}
MarkForCompilation(&graph);
@ -183,13 +184,20 @@ TEST(XlaCompilationTest, ConcatWithConstArg) {
}
TEST(XlaCompilationTest, FunctionCalls) {
FunctionDefLibrary flib;
*flib.add_function() = FunctionDefHelper::Define(
FunctionDef compilable = FunctionDefHelper::Define(
"CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {},
{{{"n_c"}, "Add", {"n_a", "n_b"}, {{"T", DT_FLOAT}}}});
*flib.add_function() =
FunctionDef uncompilable =
FunctionDefHelper::Define("UncompilableFn", {"n_a:float"}, {"n_c:float"},
{}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}});
FunctionDef noinline = compilable;
noinline.mutable_signature()->set_name("NoInlineFn");
AddAttr("_noinline", bool(true), noinline.mutable_attr());
FunctionDefLibrary flib;
*flib.add_function() = compilable;
*flib.add_function() = uncompilable;
*flib.add_function() = noinline;
FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
std::unique_ptr<Graph> graph(new Graph(&flib_def));
@ -201,7 +209,8 @@ TEST(XlaCompilationTest, FunctionCalls) {
Node* b = ops::BinaryOp("CompilableFn", a, a, builder.opts().WithName("B"));
Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
ops::UnaryOp("UncompilableFn", c, builder.opts().WithName("D"));
builder.ToGraph(graph.get());
ops::BinaryOp("NoInlineFn", c, c, builder.opts().WithName("E"));
TF_EXPECT_OK(builder.ToGraph(graph.get()));
}
MarkForCompilation(&graph, &flib_def);
@ -212,6 +221,7 @@ TEST(XlaCompilationTest, FunctionCalls) {
EXPECT_EQ(clusters["B"], clusters["C"]);
EXPECT_TRUE(clusters.find("A") == clusters.cend());
EXPECT_TRUE(clusters.find("D") == clusters.cend());
EXPECT_TRUE(clusters.find("E") == clusters.cend());
}
// Metadata-only operators such as Shape/Rank/Size may not be the root of a
@ -231,8 +241,8 @@ TEST(XlaCompilationTest, MetadataOpsDontStartClusters) {
Node* b = ops::UnaryOp("Shape", a, builder.opts().WithName("B"));
Node* c = ops::UnaryOp("Rank", b, builder.opts().WithName("C"));
Node* d = ops::UnaryOp("Size", c, builder.opts().WithName("D"));
ops::UnaryOp("Shape", d, builder.opts().WithName("C"));
builder.ToGraph(graph.get());
ops::UnaryOp("Shape", d, builder.opts().WithName("E"));
TF_EXPECT_OK(builder.ToGraph(graph.get()));
}
MarkForCompilation(&graph);
auto clusters = GetClusters(*graph);
@ -318,7 +328,7 @@ TEST(XlaCompilationTest, SymbolicGradients) {
d_builder.Input({c, c});
builder.opts().FinalizeBuilder(&d_builder);
builder.ToGraph(graph.get());
TF_EXPECT_OK(builder.ToGraph(graph.get()));
}
MarkForCompilation(&graph);
@ -344,7 +354,7 @@ TEST(XlaCompilationTest, Loops) {
auto d = ops::Add(root.WithOpName("D"), c, exit);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
root.ToGraph(graph.get());
TF_EXPECT_OK(root.ToGraph(graph.get()));
MarkForCompilation(&graph);
auto clusters = GetClusters(*graph);
@ -354,5 +364,96 @@ TEST(XlaCompilationTest, Loops) {
EXPECT_EQ(0, clusters.size());
}
TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
.WithName("A")
.WithAttr("dtype", DT_FLOAT)
.WithAttr("value", Tensor())
.WithAttr(kXlaScopeAttr, "ScopeA"));
Node* b = ops::UnaryOp(
"Relu", a,
builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB"));
ops::BinaryOp(
"MatMul", a, b,
builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC"));
TF_CHECK_OK(builder.ToGraph(graph.get()));
}
MarkForCompilation(&graph);
auto clusters = GetClusters(*graph);
// The computation is: C = A + relu(A)
// where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC.
// In this case, we cannot fuse anything, and there are no clusters.
EXPECT_EQ(0, clusters.size());
}
TEST(XlaCompilationTest, CyclesWithSplittingScopes) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
.WithName("A")
.WithAttr("dtype", DT_FLOAT)
.WithAttr("value", Tensor())
.WithAttr(kXlaScopeAttr, "Scope1"));
Node* b = ops::UnaryOp(
"Relu", a,
builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "Scope1"));
Node* c = ops::BinaryOp(
"MatMul", a, b,
builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "Scope2"));
ops::BinaryOp(
"Add", b, c,
builder.opts().WithName("D").WithAttr(kXlaScopeAttr, "Scope2"));
TF_CHECK_OK(builder.ToGraph(graph.get()));
}
MarkForCompilation(&graph);
auto clusters = GetClusters(*graph);
// The computation is: D = relu(A) + (A @ relu(A))
// where A and relu(A) are in Scope1, and the @, + ops are in Scope2.
// In this case, we can fuse the A and relu(A), and we can fuse the
// second half of the operations; there are two clusters.
EXPECT_EQ(4, clusters.size());
EXPECT_EQ(clusters["A"], clusters["B"]);
EXPECT_NE(clusters["A"], clusters["C"]);
EXPECT_EQ(clusters["C"], clusters["D"]);
}
TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
.WithName("A")
.WithAttr("dtype", DT_FLOAT)
.WithAttr("value", Tensor())
.WithAttr(kXlaScopeAttr, "ScopeA"));
Node* b = ops::UnaryOp(
"Relu", a,
builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB"));
ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
TF_CHECK_OK(builder.ToGraph(graph.get()));
}
MarkForCompilation(&graph);
auto clusters = GetClusters(*graph);
// The computation is: C = A @ relu(A)
// where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC.
// In this case, we cannot fuse anything.
EXPECT_EQ(2, clusters.size());
EXPECT_NE(clusters["A"], clusters["B"]);
EXPECT_EQ(clusters["B"], clusters["C"]);
}
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,45 @@
licenses(["notice"]) # Apache 2.0
package(
default_visibility = [
"//tensorflow/compiler/tf2xla:internal",
],
)
cc_library(
name = "xla_ops",
srcs = [
"xla_ops.cc",
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
alwayslink = 1,
)
cc_library(
name = "parallel_check_op",
srcs = ["parallel_check_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
alwayslink = 1,
)
# -----------------------------------------------------------------------------
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

View File

@ -0,0 +1,30 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
REGISTER_OP("ParallelCheck")
.Attr("T: list(type) >= 0")
.Input("expected: T")
.Input("actual: T")
.Output("result: T")
.Doc(R"doc(
Op that compares two sets of inputs for near-identity, and propagates the first.
Inequality is logged to ERROR log.
)doc");
} // namespace tensorflow

View File

@ -0,0 +1,35 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
REGISTER_OP("_XlaLaunch")
.Input("constants: Tconstants")
.Attr("Tconstants: list(type) >= 0")
.Input("args: Targs")
.Attr("Targs: list(type) >= 0")
.Input("resources: Nresources * resource")
.Attr("Nresources: int >= 0")
.Output("results: Tresults")
.Attr("Tresults: list(type) >= 0")
.Attr("function: func")
// XLA random-number generation ops are stateful.
// TODO(phawkins): create stateful and non-stateful variants of _XlaLaunch.
.SetIsStateful()
.Doc("XLA Launch Op. For use by the XLA JIT only.");
} // namespace tensorflow

Some files were not shown because too many files have changed in this diff Show More