Merge branch 'master' into google-upstream-gpuprim

commit d25e90e5cf by ekuznetsov139, 2020-03-16 12:50:58 -07:00 (committed by GitHub)
3922 changed files with 156243 additions and 80613 deletions
.bazelrc
.bazelversion
.github/ISSUE_TEMPLATE
.gitignore
.pylintrc
README.md
WORKSPACE
configure.py
tensorflow


@ -46,7 +46,6 @@
# sycl_asan:
# sycl_trisycl:
# mkl: Enable full mkl support.
# mkl_open_source_only: Enable MKL support only using open source MKL libraries.
# tensorrt: Enable Tensorrt support.
# ngraph: Enable ngraph support.
# numa: Enable numa using hwloc.
@ -137,15 +136,9 @@ build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain
# environment variable "TF_MKL_ROOT" every time before build.
build:mkl --define=build_with_mkl=true --define=enable_mkl=true
build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl --define=build_with_mkl_dnn_v1_only=true
build:mkl -c opt
# This config option is used to enable MKL-DNN open source library only,
# without depending on MKL binary version.
build:mkl_open_source_only --define=build_with_mkl_dnn_only=true
build:mkl_open_source_only --define=build_with_mkl_dnn_v1_only=true
build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=0
# This config refers to building with CUDA available. It does not necessarily
# mean that we build CUDA op kernels.
build:using_cuda --define=using_cuda=true
@ -222,6 +215,11 @@ build --define=grpc_no_ares=true
# archives in -whole_archive -no_whole_archive.
build --noincompatible_remove_legacy_whole_archive
# These are bazel 2.0's incompatible flags. Tensorflow needs to use bazel 2.0.0
# to use cc_shared_library, as part of the Tensorflow Build Improvements RFC:
# https://github.com/tensorflow/community/pull/179
build --noincompatible_prohibit_aapt1
# Modular TF build options
build:dynamic_kernels --define=dynamic_loaded_kernels=true
build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS
@ -242,6 +240,7 @@ build:windows --copt=/w
# Tensorflow uses M_* math constants that only get defined by MSVC headers if
# _USE_MATH_DEFINES is defined.
build:windows --copt=/D_USE_MATH_DEFINES
build:windows --host_copt=/D_USE_MATH_DEFINES
# Default paths for TF_SYSTEM_LIBS
build:linux --define=PREFIX=/usr
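As an aside on the `_USE_MATH_DEFINES` lines in the hunk above: MSVC's `<cmath>`/`<math.h>` only declare the `M_*` constants when that macro is defined before inclusion, which is why the flag is now passed for both target and host compiles. A minimal sketch (not from this diff) of the failure mode:

```cpp
// Sketch only: why /D_USE_MATH_DEFINES is passed on Windows builds.
// Without the define, MSVC's <cmath> does not declare M_PI and this fails to
// compile; with `cl /D_USE_MATH_DEFINES sketch.cc` (or the bazel copts above)
// it builds. On GCC/Clang, M_PI is typically available either way.
#define _USE_MATH_DEFINES  // must come before the first <cmath>/<math.h> include
#include <cmath>
#include <cstdio>

int main() {
  std::printf("pi = %f\n", M_PI);  // one of the M_* math constants
  return 0;
}
```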
@ -314,22 +313,26 @@ build:xla --define=with_xla_support=true
# BEGIN TF REMOTE BUILD EXECUTION OPTIONS
# Options when using remote execution
# WARNING: THESE OPTIONS WONT WORK IF YOU DO NOT HAVE PROPER AUTHENTICATION AND PERMISSIONS
# Flag to enable remote config
common --experimental_repo_remote_exec
build:rbe --action_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1
build:rbe --auth_enabled=true
build:rbe --auth_scope=https://www.googleapis.com/auth/cloud-source-tools
build:rbe --google_default_credentials
build:rbe --bes_backend=buildeventservice.googleapis.com
build:rbe --bes_best_effort=false
build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations"
build:rbe --bes_timeout=600s
build:rbe --define=EXECUTOR=remote
build:rbe --distinct_host_configuration=false
build:rbe --flaky_test_attempts=3
build:rbe --jobs=200
build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com
build:rbe --remote_timeout=3600
build:rbe --spawn_strategy=remote,worker,standalone,local
test:rbe --test_env=USER=anon
build:rbe --distinct_host_configuration=false
# Attempt to minimize the amount of data transfer between bazel and the remote
# workers:
build:rbe --remote_download_toplevel
build:rbe_linux --config=rbe
build:rbe_linux --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin"
@ -354,13 +357,14 @@ build:rbe_cpu_linux --host_platform="@org_tensorflow//third_party/toolchains:rbe
build:rbe_cpu_linux --platforms="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
build:rbe_linux_cuda_nvcc --config=rbe_linux
build:rbe_linux_cuda_nvcc --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain"
build:rbe_linux_cuda_nvcc --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010,@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010-gpu"
build:rbe_linux_cuda_nvcc --host_platform="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
build:rbe_linux_cuda_nvcc --platforms="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/cuda10.1-cudnn7"
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/tensorrt6.0"
build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --host_platform="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_nvcc --repo_env=TF_NEED_TENSORRT=1
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_VERSION=10
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDNN_VERSION=7
@ -377,18 +381,17 @@ build:rbe_linux_py2 --python_path="/usr/bin/python2"
build:rbe_linux_py2 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py"
build:rbe_linux_py3 --config=rbe_linux
build:rbe_linux_py3 --repo_env=PYTHON_BIN_PATH="/usr/bin/python3"
build:rbe_linux_py3 --python_path="/usr/bin/python3"
build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py3"
build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-manylinux2010-py3_config_python"
build:rbe_win --config=rbe
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_121:toolchain"
build:rbe_win --extra_execution_platforms="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_121:cc-toolchain-x64_windows"
build:rbe_win --host_javabase="@org_tensorflow//third_party/toolchains/preconfig/win_1803:windows_jdk8"
build:rbe_win --host_platform="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win_1803:windows_jdk8"
build:rbe_win --platforms="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win/bazel_211:toolchain"
build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win/bazel_211:cc-toolchain-x64_windows"
build:rbe_win --host_javabase="@org_tensorflow//third_party/toolchains/preconfig/win:windows_jdk8"
build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win:windows_jdk8"
build:rbe_win --extra_execution_platforms="@org_tensorflow//third_party/toolchains/preconfig/win:rbe_windows_ltsc2019"
build:rbe_win --host_platform="@org_tensorflow//third_party/toolchains/preconfig/win:rbe_windows_ltsc2019"
build:rbe_win --platforms="@org_tensorflow//third_party/toolchains/preconfig/win:rbe_windows_ltsc2019"
build:rbe_win --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe
# TODO(gunan): Remove once we use MSVC 2019 with latest patches.
@ -396,9 +399,7 @@ build:rbe_win --define=override_eigen_strong_inline=true
build:rbe_win --jobs=500
build:rbe_win_py37 --config=rbe
build:rbe_win_py37 --repo_env=PYTHON_BIN_PATH=C:\\Python37\\python.exe
build:rbe_win_py37 --repo_env=PYTHON_LIB_PATH=C:\\Python37\\lib\\site-packages
build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py37
build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO="@windows_py37_config_python"
build:rbe_win_py37 --python_path=C:\\Python37\\python.exe
build:rbe_win_py38 --config=rbe
@ -416,7 +417,6 @@ build:tensorflow_testing_rbe_linux --config=rbe_linux
common:tensorflow_testing_rbe_win --remote_instance_name=projects/tensorflow-testing/instances/windows
build:tensorflow_testing_rbe_win --config=tensorflow_testing_rbe
build:tensorflow_testing_rbe_win --config=rbe_win
# END TF REMOTE BUILD EXECUTION OPTIONS
# Default options should come above this line


@ -1 +1 @@
1.2.1
2.0.0

.github/ISSUE_TEMPLATE/00-bug-issue.md (vendored, new file, 44 added lines)

@ -0,0 +1,44 @@
---
name: Bug Issue
about: Use this template for reporting a bug
labels: 'type:bug'
---
<em>Please make sure that this is a bug. As per our
[GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md),
we only address code/doc bugs, performance issues, feature requests and
build/installation issues on GitHub. tag:bug_template</em>
**System information**
- Have I written custom code (as opposed to using a stock
example script provided in TensorFlow):
- OS Platform and Distribution (e.g.,
Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
the issue happens on mobile device:
- TensorFlow installed from (source or
binary): - TensorFlow version (use command below):
- Python version: - Bazel
version (if compiling from source):
- GCC/Compiler version (if compiling from
source):
- CUDA/cuDNN version: - GPU model and memory:
You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
**Describe the current behavior**
**Describe the expected behavior**
**Standalone code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
**Other info / logs** Include any logs or source code that would be helpful to
diagnose the problem. If including tracebacks, please include the full
traceback. Large logs and files should be attached.


@ -1,35 +0,0 @@
---
name: Bug/Performance Issue
about: Use this template for reporting a bug or a performance issue.
---
<em>Please make sure that this is a bug. As per our [GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md), we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:bug_template</em>
**System information**
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary):
- TensorFlow version (use command below):
- Python version:
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version:
- GPU model and memory:
You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
**Describe the current behavior**
**Describe the expected behavior**
**Code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate the problem.
**Other info / logs**
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.


@ -1,7 +1,7 @@
--------------------------------------------------------------------------------
name: Build/Installation Issue about: Use this template for build/installation
issues labels: 'type:build/install'
---
name: Build/Installation Issue
about: Use this template for build/installation issues
labels: 'type:build/install'
---


@ -1,10 +1,11 @@
---
name: Documentation Issue
about: Use this template for documentation related
about: Use this template for documentation related issues
labels: 'type:docs'
---
Thank you for submitting a TensorFlow documentation issue. Per our GitHub
policy, we only address code/doc bugs, performance issues, feature requests, and
build/installation issues on GitHub.


@ -1,6 +1,7 @@
---
name: Feature Request
about: Use this template for raising a feature request
labels: 'type:feature'
---


@ -1,11 +1,10 @@
--------------------------------------------------------------------------------
name: TensorFlow Lite Op Request about: Use this template for reporting ops you
are using or missing. labels: 'comp:lite'
---
name: TensorFlow Lite Op Request
about: Use this template for reporting Lite ops you are using or missing
labels: 'comp:lite'
---
**System information**
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- TensorFlow installed from (source or binary):
@ -18,8 +17,14 @@ are using or missing. labels: 'comp:lite'
# Copy and paste here
```
**Standalone code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
Also, please include a link to a GraphDef or the model if possible.
**Any other info / logs**
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
Include any logs or source code that would be helpful to diagnose the problem.
If including tracebacks, please include the full traceback. Large logs and files
should be attached.


@ -1,7 +1,7 @@
--------------------------------------------------------------------------------
name: Other Issues about: Use this template for any other non-support related
issues labels: 'type:others'
---
name: Other Issues
about: Use this template for any other non-support related issues
labels: 'type:others'
---


@ -1,6 +1,7 @@
---
name: TensorFlow Lite New Converter Issue
about: Use this template for reporting issues during model conversion to TFLite.
about: Use this template for reporting issues during model conversion to TFLite
labels: 'TFLiteConverter'
---
@ -12,6 +13,7 @@ about: Use this template for reporting issues during model conversion to TFLite.
**Command used to run the converter or code if you're using the Python API**
If possible, please share a link to Colab/Jupyter/any notebook.
```
# Copy and paste here the exact command


@ -0,0 +1,45 @@
---
name: Performance Issue
about: Use this template for reporting a performance issue
labels: 'type:performance'
---
<em>Please make sure that this is an issue related to performance of TensorFlow.
As per our
[GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md),
we only address code/doc bugs, performance issues, feature requests and
build/installation issues on GitHub. tag:performance_template</em>
**System information**
- Have I written custom code (as opposed to using a stock
example script provided in TensorFlow):
- OS Platform and Distribution (e.g.,
Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
the issue happens on mobile device:
- TensorFlow installed from (source or
binary): - TensorFlow version (use command below):
- Python version: - Bazel
version (if compiling from source):
- GCC/Compiler version (if compiling from
source):
- CUDA/cuDNN version: - GPU model and memory:
You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
**Describe the current behavior**
**Describe the expected behavior**
**Standalone code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
**Other info / logs** Include any logs or source code that would be helpful to
diagnose the problem. If including tracebacks, please include the full
traceback. Large logs and files should be attached.

.gitignore (vendored, 5 changed lines)

@ -22,6 +22,7 @@ tensorflow/contrib/cmake/_build/
/tensorflow/python/framework/fast_tensor_util.cpp
/tensorflow/lite/gen/**
/tensorflow/lite/tools/make/downloads/**
/tensorflow/lite/tools/make/gen/**
/api_init_files_list.txt
/estimator_api_init_files_list.txt
*.whl
@ -37,7 +38,9 @@ gradleBuild
*.pbxproj
*.xcworkspace
/*.podspec
/tensorflow/lite/**/[ios|objc|swift]*/BUILD
/tensorflow/lite/**/ios/BUILD
/tensorflow/lite/**/objc/BUILD
/tensorflow/lite/**/swift/BUILD
/tensorflow/lite/examples/ios/simple/data/*.tflite
/tensorflow/lite/examples/ios/simple/data/*.txt
Podfile.lock

.pylintrc (symbolic link, 1 added line)

@ -0,0 +1 @@
tensorflow/tools/ci_build/pylintrc


@ -70,7 +70,7 @@ $ python
3
>>> hello = tf.constant('Hello, TensorFlow!')
>>> hello.numpy()
'Hello, TensorFlow!'
b'Hello, TensorFlow!'
```
For more examples, see the
@ -130,18 +130,20 @@ Build Type | Status
## Resources
* [TensorFlow.org](https://www.tensorflow.org)
* [TensorFlow tutorials](https://www.tensorflow.org/tutorials/)
* [TensorFlow official models](https://github.com/tensorflow/models/tree/master/official)
* [TensorFlow examples](https://github.com/tensorflow/examples)
* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/)
* [TensorFlow Official Models](https://github.com/tensorflow/models/tree/master/official)
* [TensorFlow Examples](https://github.com/tensorflow/examples)
* [TensorFlow in Practice from Coursera](https://www.coursera.org/specializations/tensorflow-in-practice)
* [TensorFlow: Data and Deployment from Coursera](https://www.coursera.org/specializations/tensorflow-data-and-deployment)
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
* [TensorFlow blog](https://blog.tensorflow.org)
* [TensorFlow Blog](https://blog.tensorflow.org)
* [Learn ML with TensorFlow](https://www.tensorflow.org/resources/learn-ml)
* [TensorFlow Twitter](https://twitter.com/tensorflow)
* [TensorFlow YouTube](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
* [TensorFlow roadmap](https://www.tensorflow.org/community/roadmap)
* [TensorFlow white papers](https://www.tensorflow.org/about/bib)
* [TensorBoard visualization toolkit](https://github.com/tensorflow/tensorboard)
* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap)
* [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
* [TensorBoard Visualization Toolkit](https://github.com/tensorflow/tensorboard)
Learn more about the
[TensorFlow community](https://www.tensorflow.org/community) and how to


@ -113,3 +113,28 @@ http_archive(
"https://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip",
],
)
# Required for dependency @com_github_grpc_grpc
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
grpc_deps()
load(
"@build_bazel_rules_apple//apple:repositories.bzl",
"apple_rules_dependencies",
)
apple_rules_dependencies()
load(
"@build_bazel_apple_support//lib:repositories.bzl",
"apple_support_dependencies",
)
apple_support_dependencies()
load("@upb//bazel:repository_defs.bzl", "bazel_version_repository")
bazel_version_repository(name = "bazel_version")


@ -49,8 +49,8 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
_TF_WORKSPACE_ROOT = ''
_TF_BAZELRC = ''
_TF_CURRENT_BAZEL_VERSION = None
_TF_MIN_BAZEL_VERSION = '1.2.1'
_TF_MAX_BAZEL_VERSION = '1.2.1'
_TF_MIN_BAZEL_VERSION = '2.0.0'
_TF_MAX_BAZEL_VERSION = '2.0.0'
NCCL_LIB_PATHS = [
'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
@ -1390,9 +1390,8 @@ def main():
else:
environ_cp['TF_CONFIGURE_IOS'] = '0'
xla_enabled_by_default = is_linux() or is_macos()
set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
xla_enabled_by_default, 'xla')
if environ_cp.get('TF_ENABLE_XLA', '1') == '1':
write_to_bazelrc('build --config=xla')
set_action_env_var(
environ_cp,


@ -187,6 +187,12 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "fuchsia",
values = {"cpu": "fuchsia"},
visibility = ["//visibility:public"],
)
config_setting(
name = "ios_x86_64",
values = {
@ -448,19 +454,66 @@ config_setting(
visibility = ["//visibility:public"],
)
# Specifies via a config setting if this is a mobile build or not, makes
# it easier to combine settings later.
selects.config_setting_group(
name = "mobile",
match_any = [
":android",
":chromiumos",
":emscripten",
":ios",
],
)
config_setting(
name = "lite_protos_legacy",
values = {"define": "TENSORFLOW_PROTOS=lite"},
visibility = ["//visibility:private"],
)
config_setting(
name = "full_protos",
values = {"define": "TENSORFLOW_PROTOS=full"},
visibility = ["//visibility:public"],
)
selects.config_setting_group(
name = "lite_protos",
match_any = [":lite_protos_legacy"],
)
selects.config_setting_group(
name = "mobile_lite_protos",
match_all = [
":lite_protos",
":mobile",
],
)
selects.config_setting_group(
name = "mobile_full_protos",
match_all = [
":full_protos",
":mobile",
],
)
# DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST!
# Instead, please use public APIs or public build rules TF provides.
# If you need functionality that is not exposed, we will work with you to expand our public APIs.
package_group(
name = "internal",
packages = [
# To pass open source testing in the pip Kokoros.
"//bazel_pip/tensorflow/...",
"//learning/brain/swift/x10/...",
"//perftools/accelerators/xprof/api/...",
"//third_party/py/autograph/...",
"//third_party/swift/tensorflow/x10/...",
"//tensorflow/...",
"//tensorflow_estimator/python/estimator/...",
"//tensorflow_models/official/...",
"//third_party/py/autograph/...",
"//third_party/swift/tensorflow/x10/...",
],
)
@ -494,8 +547,8 @@ cc_library(
name = "grpc",
visibility = ["//visibility:public"],
deps = select({
":linux_s390x": ["@grpc//:grpc_unsecure"],
"//conditions:default": ["@grpc"],
":linux_s390x": ["@com_github_grpc_grpc//:grpc_unsecure"],
"//conditions:default": ["@com_github_grpc_grpc//:grpc"],
}),
)
@ -503,8 +556,8 @@ cc_library(
name = "grpc++",
visibility = ["//visibility:public"],
deps = select({
":linux_s390x": ["@grpc//:grpc++_unsecure"],
"//conditions:default": ["@grpc//:grpc++"],
":linux_s390x": ["@com_github_grpc_grpc//:grpc++_unsecure"],
"//conditions:default": ["@com_github_grpc_grpc//:grpc++"],
}),
)
@ -909,7 +962,6 @@ py_library(
"//conditions:default": [":tf_python_api_gen_v1"],
}) + [
":root_init_gen",
":virtual_root_init_gen",
"//tensorflow/python/keras/api:keras_python_api_gen",
"//tensorflow/python/keras/api:keras_python_api_gen_compat_v1",
"//tensorflow/python/keras/api:keras_python_api_gen_compat_v2",


@ -35,9 +35,11 @@ import inspect as _inspect
import logging as _logging
import os as _os
import site as _site
import six as _six
import sys as _sys
from tensorflow.python.tools import module_util as _module_util
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
# API IMPORTS PLACEHOLDER
@ -69,13 +71,13 @@ except ImportError:
_logging.warning(
"Limited tf.summary API due to missing TensorBoard installation.")
try:
from tensorflow_estimator.python.estimator.api._v2 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
setattr(_current_module, "estimator", estimator)
except ImportError:
pass
# Lazy-load estimator.
_estimator_module = "tensorflow_estimator.python.estimator.api._v2.estimator"
estimator = _LazyLoader("estimator", globals(), _estimator_module)
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
if _module_dir:
_current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "estimator", estimator)
try:
from .python.keras.api._v2 import keras
@ -85,6 +87,13 @@ try:
except ImportError:
pass
# Explicitly import lazy-loaded modules to support autocompletion.
# pylint: disable=g-import-not-at-top
if not _six.PY2:
import typing as _typing
if _typing.TYPE_CHECKING:
from tensorflow_estimator.python.estimator.api._v2 import estimator
# pylint: enable=g-import-not-at-top
# Enable TF2 behaviors
from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top


@ -22,12 +22,14 @@ import distutils as _distutils
import inspect as _inspect
import os as _os
import site as _site
import six as _six
import sys as _sys
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
from tensorflow.python.tools import module_util as _module_util
from tensorflow.python.platform import tf_logging as _logging
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
# API IMPORTS PLACEHOLDER
@ -64,13 +66,14 @@ elif _tf_api_dir not in __path__:
# reexport_tf_summary can get compat from sys.modules. Only needed if using
# lazy loading.
_current_module.compat.v2 # pylint: disable=pointless-statement
try:
from tensorflow_estimator.python.estimator.api._v1 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
setattr(_current_module, "estimator", estimator)
except ImportError:
pass
# Lazy-load estimator.
_estimator_module = "tensorflow_estimator.python.estimator.api._v1.estimator"
estimator = _LazyLoader("estimator", globals(), _estimator_module)
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
if _module_dir:
_current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "estimator", estimator)
try:
from .python.keras.api._v1 import keras
@ -80,6 +83,13 @@ try:
except ImportError:
pass
# Explicitly import lazy-loaded modules to support autocompletion.
# pylint: disable=g-import-not-at-top
if not _six.PY2:
import typing as _typing
if _typing.TYPE_CHECKING:
from tensorflow_estimator.python.estimator.api._v1 import estimator
# pylint: enable=g-import-not-at-top
from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top
_CONTRIB_WARNING = """


@ -154,7 +154,10 @@ tf_cuda_library(
"c_api.h",
],
copts = tf_copts(),
visibility = ["//tensorflow/c:__subpackages__"],
visibility = [
"//tensorflow/c:__subpackages__",
"//third_party/llvm/llvm-project:__subpackages__",
],
deps = [
":c_api_internal",
":tf_attrtype",
@ -242,7 +245,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:android_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
"//tensorflow/core:framework",
@ -536,6 +539,7 @@ tf_cuda_cc_test(
"//tensorflow/core/kernels:array",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:math",
"//tensorflow/core/platform:resource_loader",
],
)
@ -697,4 +701,5 @@ tf_cuda_library(
# TODO(b/74620627): remove when _USE_C_SHAPES is removed
"//tensorflow/python:cpp_shape_inference_proto_cc",
],
alwayslink = 1,
)


@ -774,7 +774,7 @@ extern "C" {
static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
const char* op_type,
const char* oper_name)
EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
return new TF_OperationDescription(graph, op_type, oper_name);
}
@ -1032,7 +1032,7 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
TF_Status* status)
EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
TF_EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
Node* ret = nullptr;
if (desc->graph->name_map.count(desc->node_builder.node_name())) {
@ -1706,7 +1706,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
const TF_ImportGraphDefOptions* opts,
TF_ImportGraphDefResults* tf_results,
TF_Status* status)
EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
const int last_node_id = graph->graph.num_node_ids();
tensorflow::ImportGraphDefResults results;
status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,


@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/platform.h"
@ -816,12 +817,15 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
const int num_inputs = input_shapes->num_items;
NodeDef node_def;
node_def.set_name(tfe_op->operation.Name());
node_def.set_op(tfe_op->operation.Name());
node_def.set_name(tfe_op->operation->Name());
node_def.set_op(tfe_op->operation->Name());
for (int i = 0; i < num_inputs; ++i) {
node_def.add_input("dummy_input");
}
tfe_op->operation.Attrs().FillAttrValueMap(node_def.mutable_attr());
tensorflow::down_cast<tensorflow::OperationInterface*>(
tfe_op->operation.get())
->Attrs()
.FillAttrValueMap(node_def.mutable_attr());
const tensorflow::OpRegistrationData* op_reg_data;
status->status =


@ -51,7 +51,7 @@ Status ProcessInputs(
const TF_Graph* fn_body, const char* fn_name, int ninputs,
const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
std::unordered_map<const Node*, std::vector<int>>* input_nodes)
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
input_tensors->reserve(ninputs);
for (int i = 0; i < ninputs; ++i) {
Node* node = &inputs[i].oper->node;
@ -87,7 +87,7 @@ Status ProcessInputs(
Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
int noutputs, const TF_Output* outputs,
std::vector<OutputTensor>* output_tensors)
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
output_tensors->reserve(noutputs);
for (int i = 0; i < noutputs; ++i) {
Node* node = &outputs[i].oper->node;
@ -111,7 +111,7 @@ Status ComputeBodyNodes(
const TF_Operation* const* opers,
const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
std::vector<const Node*>* body_nodes)
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
if (num_opers == -1) {
for (const Node* node : fn_body->graph.op_nodes()) {
const auto& iter = input_nodes.find(node);


@ -71,14 +71,14 @@ struct TF_Graph {
TF_Graph();
tensorflow::mutex mu;
tensorflow::Graph graph GUARDED_BY(mu);
tensorflow::Graph graph TF_GUARDED_BY(mu);
// Runs shape inference.
tensorflow::ShapeRefiner refiner GUARDED_BY(mu);
tensorflow::ShapeRefiner refiner TF_GUARDED_BY(mu);
// Maps from name of an operation to the Node* in 'graph'.
std::unordered_map<tensorflow::string, tensorflow::Node*> name_map
GUARDED_BY(mu);
TF_GUARDED_BY(mu);
// The keys of this map are all the active sessions using this graph. Each
// value records whether the graph has been mutated since the corresponding
@ -94,8 +94,8 @@ struct TF_Graph {
// TODO(b/74949947): mutations currently trigger a warning instead of a bad
// status, this should be reverted when possible.
tensorflow::gtl::FlatMap<TF_Session*, tensorflow::string> sessions
GUARDED_BY(mu);
bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph
TF_GUARDED_BY(mu);
bool delete_requested TF_GUARDED_BY(mu); // set true by TF_DeleteGraph
// Used to link graphs contained in TF_WhileParams to the parent graph that
// will eventually contain the full while loop.
@ -123,7 +123,7 @@ struct TF_Session {
tensorflow::Session* session;
TF_Graph* const graph;
tensorflow::mutex mu ACQUIRED_AFTER(TF_Graph::mu);
tensorflow::mutex mu TF_ACQUIRED_AFTER(TF_Graph::mu);
int last_num_graph_nodes;
// If true, TF_SessionRun and similar methods will call
@ -169,9 +169,9 @@ struct TF_ApiDefMap {
}
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
tensorflow::ApiDefMap api_def_map GUARDED_BY(lock);
tensorflow::ApiDefMap api_def_map TF_GUARDED_BY(lock);
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
bool update_docs_called GUARDED_BY(lock);
bool update_docs_called TF_GUARDED_BY(lock);
tensorflow::mutex lock;
};
@ -210,10 +210,10 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
void RecordMutation(TF_Graph* graph, const TF_Operation& op,
const char* mutation_type)
EXCLUSIVE_LOCKS_REQUIRED(graph->mu);
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu);
bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status)
LOCKS_EXCLUDED(session->graph->mu, session->mu);
TF_LOCKS_EXCLUDED(session->graph->mu, session->mu);
std::string getTF_OutputDebugString(TF_Output node);
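The hunks above are part of a tree-wide rename of the thread-safety macros (GUARDED_BY → TF_GUARDED_BY, EXCLUSIVE_LOCKS_REQUIRED → TF_EXCLUSIVE_LOCKS_REQUIRED, LOCKS_EXCLUDED → TF_LOCKS_EXCLUDED, ACQUIRED_AFTER → TF_ACQUIRED_AFTER). Both spellings wrap Clang's thread-safety attributes, which let `-Wthread-safety` verify at compile time that annotated state is only touched with the right mutex held. A self-contained sketch of that mechanism, using raw Clang attributes instead of TensorFlow's headers (all names below are illustrative, not from this diff):

```cpp
// Sketch of the Clang thread-safety machinery behind TF_GUARDED_BY and
// TF_EXCLUSIVE_LOCKS_REQUIRED.
// Build with: clang++ -std=c++17 -Wthread-safety -c sketch.cc
#include <mutex>

#define CAPABILITY(x) __attribute__((capability(x)))
#define ACQUIRE(...) __attribute__((acquire_capability(__VA_ARGS__)))
#define RELEASE(...) __attribute__((release_capability(__VA_ARGS__)))
#define REQUIRES(...) __attribute__((requires_capability(__VA_ARGS__)))
#define GUARDED(x) __attribute__((guarded_by(x)))

// A mutex wrapper annotated as a lockable "capability".
class CAPABILITY("mutex") Mutex {
 public:
  void Lock() ACQUIRE() { mu_.lock(); }
  void Unlock() RELEASE() { mu_.unlock(); }

 private:
  std::mutex mu_;
};

class NodeCounter {
 public:
  // Analogous to TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu): callers must already
  // hold mu_, and Clang warns at any call site where that is not provable.
  int CountLocked() REQUIRES(mu_) { return count_; }

  int Count() {
    mu_.Lock();
    int n = CountLocked();  // OK: mu_ is held here.
    mu_.Unlock();
    return n;
  }

 private:
  Mutex mu_;
  // Analogous to TF_GUARDED_BY(mu): reads or writes without mu_ held are
  // reported by -Wthread-safety.
  int count_ GUARDED(mu_) = 0;
};
```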


@ -45,6 +45,8 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
@ -193,8 +195,9 @@ TEST(CAPI, LibraryLoadFunctions) {
{
// Load the library.
TF_Status* status = TF_NewStatus();
TF_Library* lib =
TF_LoadLibrary("tensorflow/c/test_op1.so", status);
string lib_path = tensorflow::GetDataDependencyFilepath(
tensorflow::io::JoinPath("tensorflow", "c", "test_op1.so"));
TF_Library* lib = TF_LoadLibrary(lib_path.c_str(), status);
TF_Code code = TF_GetCode(status);
string status_msg(TF_Message(status));
TF_DeleteStatus(status);
@ -1350,9 +1353,9 @@ TEST_F(CApiColocationTest, ClearViaProto) {
TEST(CAPI, SavedModel) {
// Load the saved model.
const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123";
const string saved_model_dir = tensorflow::io::JoinPath(
tensorflow::testing::TensorFlowSrcRoot(), kSavedModel);
const string saved_model_dir = tensorflow::GetDataDependencyFilepath(
tensorflow::io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
"half_plus_two", "00000123"));
TF_SessionOptions* opt = TF_NewSessionOptions();
TF_Buffer* run_options = TF_NewBufferFromString("", 0);
TF_Buffer* metagraph = TF_NewBuffer();
@ -1426,9 +1429,9 @@ TEST(CAPI, SavedModel) {
}
TEST(CAPI, SavedModelNullArgsAreValid) {
const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123";
const string saved_model_dir = tensorflow::io::JoinPath(
tensorflow::testing::TensorFlowSrcRoot(), kSavedModel);
const string saved_model_dir = tensorflow::GetDataDependencyFilepath(
tensorflow::io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
"half_plus_two", "00000123"));
TF_SessionOptions* opt = TF_NewSessionOptions();
TF_Status* s = TF_NewStatus();
const char* tags[] = {tensorflow::kSavedModelTagServe};


@ -28,6 +28,8 @@ tf_cuda_library(
"c_api_debug.cc",
"c_api_experimental.h",
"c_api_internal.h",
"operation_interface.cc",
"operation_interface.h",
"tensor_handle_interface.h",
],
hdrs = ["c_api.h"],
@ -56,6 +58,7 @@ tf_cuda_library(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/platform:casts",
"//tensorflow/core/platform:errors",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler/lib:traceme",
@ -92,6 +95,8 @@ filegroup(
srcs = [
"c_api_experimental.h",
"c_api_internal.h",
"dlpack.h",
"operation_interface.h",
"tensor_handle_interface.h",
],
visibility = [
@ -104,6 +109,7 @@ tf_cuda_library(
name = "c_api_internal",
srcs = [
"c_api_experimental.h",
"operation_interface.h",
"tensor_handle_interface.h",
],
hdrs = ["c_api_internal.h"],
@ -128,6 +134,7 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"@com_google_absl//absl/container:fixed_array",
],
)
@ -199,6 +206,7 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/platform:casts",
"@com_google_absl//absl/strings",
],
)
@ -256,8 +264,6 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime:remote_device",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/profiler/rpc:profiler_server",
"//tensorflow/core/profiler/rpc/client:capture_profile",
"//tensorflow/core:gpu_runtime",
],
alwayslink = 1,
@ -323,10 +329,34 @@ filegroup(
srcs = [
"c_api.h",
"c_api_experimental.h",
"dlpack.h",
],
visibility = ["//tensorflow:__subpackages__"],
)
cc_library(
name = "dlpack",
srcs = ["dlpack.cc"],
hdrs = ["dlpack.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
visibility = ["//tensorflow:__subpackages__"],
deps = [
":c_api",
":c_api_experimental",
":c_api_internal",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@dlpack",
],
alwayslink = 1,
)
# TODO(karllessard): only used by //tensorflow/core:mobile_srcs_only_runtime
# right now, remove this public rule when no longer needed (it should be
# replaced by TF Lite)
@ -340,6 +370,7 @@ filegroup(
exclude = [
"c_api_experimental.cc",
"*test*",
"*dlpack*",
],
),
visibility = ["//visibility:public"],


@ -27,7 +27,6 @@ limitations under the License.
// clang-format on
#include "absl/algorithm/container.h"
#include "absl/container/fixed_array.h"
#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
@ -95,14 +94,6 @@ using tensorflow::string;
namespace {
const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
const tensorflow::OpDef* op_def = op->operation.OpDef();
if (op_def) return op_def;
status->status =
tensorflow::OpDefForOp(op->operation.Name().c_str(), &op_def);
return op_def;
}
bool IsCPU(
absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> variant) {
if (VariantDeviceIsCustom(variant)) {
@ -883,12 +874,12 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
#endif // !IS_MOBILE_PLATFORM
}
TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
TF_Status* status) {
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
status->status = tensorflow::Status::OK();
#else // !defined(IS_MOBILE_PLATFORM)
status->status = ctx->context->ClearRemoteExecutors();
status->status = ctx->context->SyncExecutors();
#endif // !IS_MOBILE_PLATFORM
}
@ -1125,9 +1116,13 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
return retval;
} else {
tensorflow::Tensor tensor;
if (IsCPU(handle_->device())) {
if (IsCPU(handle_->device()) || handle_->HasLocalMirror(nullptr)) {
const tensorflow::Tensor* src = nullptr;
*status = handle_->Tensor(&src);
if (handle_->HasLocalMirror(nullptr)) {
*status = handle_->TensorFromDevice(nullptr, &src);
} else {
*status = handle_->Tensor(&src);
}
if (!status->ok()) return nullptr;
tensor = *src;
} else {
@ -1135,6 +1130,13 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
CHECK_NE(ctx, nullptr);
*status = handle_->CopyToDevice(*ctx, ctx->HostCPU(), &tensor);
if (!status->ok()) return nullptr;
if (handle_->ImplicitMirroring()) {
*status = handle_->AddEmptyLocalMirror(nullptr);
if (!status->ok()) return nullptr;
Tensor mirror = tensor;
*status = handle_->SetTensor(std::move(mirror), nullptr);
if (!status->ok()) return nullptr;
}
}
return tensorflow::TF_TensorFromTensor(tensor, status);
}
@ -1199,18 +1201,11 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
}
if (dtype == TF_STRING || dtype == TF_RESOURCE ||
!tensorflow::DataTypeCanUseMemcpy(
static_cast<tensorflow::DataType>(dtype))) {
status->status = tensorflow::errors::InvalidArgument(
"Trying to create a tensor with a pointer to non-pod memory.");
deallocator(data, len, deallocator_arg);
return nullptr;
}
// TODO(apassos) do we need to wrap the deallocator here to make sure to sync
// the device?
TF_ManagedBuffer* buf =
new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
/*owns_memory=*/false);
tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf);
@ -1218,10 +1213,10 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
tensorflow::TensorHandle* ret_handle;
if (custom_device == nullptr) {
status->status = tensorflow::TensorHandle::CreateLocalHandle(
t, device, context, &ret_handle);
std::move(t), device, device, context, &ret_handle);
} else {
status->status = tensorflow::TensorHandle::CreateLocalHandle(
t, custom_device, context, &ret_handle);
std::move(t), custom_device, context, &ret_handle);
}
if (!status->status.ok()) {
return nullptr;
@ -1261,9 +1256,8 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) {
std::unique_ptr<TFE_Op> new_op(
new TFE_Op{tensorflow::EagerOperation(ctx->context)});
status->status =
new_op->operation.Reset(op_or_function_name, nullptr, false, nullptr);
new TFE_Op{std::make_unique<tensorflow::OperationInterface>(ctx)});
status->status = new_op->operation->Reset(op_or_function_name, nullptr);
if (!status->status.ok()) {
new_op.reset();
}
@ -1273,49 +1267,51 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
void TFE_DeleteOp(TFE_Op* op) { delete op; }
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
status->status = op->operation.SetDeviceName(device_name);
status->status = op->operation->SetDeviceName(device_name);
}
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
tensorflow::Device* device = (op->operation.Device() == nullptr)
? op->operation.EagerContext().HostCPU()
: op->operation.Device();
return device->name().c_str();
return op->operation->DeviceName().c_str();
}
void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
op->operation.SetUseXla(enable);
#ifndef TENSORFLOW_EAGER_USE_XLA
#ifdef TENSORFLOW_EAGER_USE_XLA
tensorflow::Status s = op->operation->SetUseXla(enable);
if (!s.ok()) {
LOG(ERROR) << "Could not enable XLA compilation for op: " << s;
}
#else
LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
"built with XLA support.";
#endif // TENSORFLOW_EAGER_USE_XLA
}
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
tensorflow::TensorHandle* h =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
input->handle.get())
->Handle();
op->operation.AddInput(h);
status->status = op->operation.MaybeInferSingleInputAttrs(h);
status->status = op->operation->AddInput(input->handle);
}
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status) {
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>> handles(
num_inputs);
for (int i = 0; i < num_inputs; ++i) {
op->operation.AddInput(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
inputs[i]->handle.get())
->Handle());
handles[i].reset(inputs[i]->handle->Copy());
}
status->status = op->operation.InferInputListAttrs(num_inputs);
status->status = op->operation->AddInputList(handles);
}
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
unsigned char* is_list, TF_Status* status) {
TF_AttrType ret = TF_ATTR_INT;
status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(),
attr_name, &ret, is_list);
const tensorflow::AttrTypeMap* attr_types_;
bool is_function;
status->status = tensorflow::AttrTypeMapForOp(op->operation->Name().c_str(),
&attr_types_, &is_function);
if (!status->status.ok()) {
return ret;
}
status->status =
tensorflow::AttrTypeByName(*attr_types_, attr_name, &ret, is_list);
return ret;
}
@ -1336,221 +1332,169 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
size_t length) {
op->operation.MutableAttrs()->Set(
attr_name,
tensorflow::StringPiece(static_cast<const char*>(value), length));
auto s = op->operation->SetAttrString(
attr_name, static_cast<const char*>(value), length);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
op->operation.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
auto s = op->operation->SetAttrInt(attr_name, value);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
op->operation.MutableAttrs()->Set(attr_name, value);
auto s = op->operation->SetAttrFloat(attr_name, value);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
op->operation.MutableAttrs()->Set(attr_name, (value == 0) ? false : true);
auto s = op->operation->SetAttrBool(attr_name, (value == 0) ? false : true);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
op->operation.MutableAttrs()->Set(attr_name,
static_cast<tensorflow::DataType>(value));
auto s = op->operation->SetAttrType(attr_name, value);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
const int num_dims, TF_Status* out_status) {
if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat(
"Value specified for `", attr_name, "` has ", num_dims,
" dimensions which is over the limit of ",
tensorflow::TensorShape::MaxDimensions(), ".")
.c_str());
return;
}
tensorflow::TensorShapeProto proto;
if (num_dims < 0) {
proto.set_unknown_rank(true);
} else {
for (int d = 0; d < num_dims; ++d) {
proto.add_dim()->set_size(dims[d]);
}
}
op->operation.MutableAttrs()->Set(attr_name, proto);
out_status->status = op->operation->SetAttrShape(attr_name, dims, num_dims);
}
void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
const TFE_Op* value) {
tensorflow::AttrValue attr_value;
tensorflow::NameAttrList* func = attr_value.mutable_func();
func->set_name(value->operation.Name());
value->operation.Attrs().FillAttrValueMap(func->mutable_attr());
op->operation.MutableAttrs()->Set(attr_name, attr_value);
auto s = op->operation->SetAttrFunction(attr_name, value->operation);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
const char* data, size_t length) {
tensorflow::AttrValue attr_value;
tensorflow::NameAttrList* func = attr_value.mutable_func();
func->set_name(data, length);
op->operation.MutableAttrs()->Set(attr_name, attr_value);
auto s = op->operation->SetAttrFunctionName(attr_name, data, length);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
TF_Status* status) {
tensorflow::Tensor t;
status->status = TF_TensorToTensor(tensor, &t);
if (status->status.ok()) op->operation.MutableAttrs()->Set(attr_name, t);
status->status = op->operation->SetAttrTensor(attr_name, tensor);
}
void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
const void* const* values, const size_t* lengths,
int num_values) {
std::vector<tensorflow::StringPiece> v(num_values);
for (int i = 0; i < num_values; ++i) {
v[i] = tensorflow::StringPiece(static_cast<const char*>(values[i]),
lengths[i]);
auto s =
op->operation->SetAttrStringList(attr_name, values, lengths, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
op->operation.MutableAttrs()->Set(attr_name, v);
}
void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
const float* values, int num_values) {
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<const float>(values, num_values));
auto s = op->operation->SetAttrFloatList(attr_name, values, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
const int64_t* values, int num_values) {
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<const int64>(
reinterpret_cast<const int64*>(values), num_values));
auto s = op->operation->SetAttrIntList(attr_name, values, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
const TF_DataType* values, int num_values) {
op->operation.MutableAttrs()->Set(
attr_name,
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
reinterpret_cast<const tensorflow::DataType*>(values), num_values));
auto s = op->operation->SetAttrTypeList(attr_name, values, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
const unsigned char* values, int num_values) {
std::unique_ptr<bool[]> b(new bool[num_values]);
for (int i = 0; i < num_values; ++i) {
b[i] = values[i];
auto s = op->operation->SetAttrBoolList(attr_name, values, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
}
void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
const int64_t** dims, const int* num_dims,
int num_values, TF_Status* out_status) {
std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
new tensorflow::TensorShapeProto[num_values]);
for (int i = 0; i < num_values; ++i) {
const auto num_dims_i = num_dims[i];
if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat(
"Value specified for `", attr_name, "` has ", num_dims_i,
" dimensions which is over the limit of ",
tensorflow::TensorShape::MaxDimensions(), ".")
.c_str());
return;
}
if (num_dims_i < 0) {
proto[i].set_unknown_rank(true);
} else {
const int64_t* dims_i = dims[i];
auto proto_i = &proto[i];
for (int d = 0; d < num_dims_i; ++d) {
proto_i->add_dim()->set_size(dims_i[d]);
}
}
}
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
proto.get(), num_values));
out_status->status =
op->operation->SetAttrShapeList(attr_name, dims, num_dims, num_values);
}
void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
const TFE_Op** value, int num_values) {
std::unique_ptr<tensorflow::NameAttrList[]> funcs(
new tensorflow::NameAttrList[num_values]);
for (int i = 0; i < num_values; i++) {
funcs[i].set_name(value[i]->operation.Name());
value[i]->operation.Attrs().FillAttrValueMap(funcs[i].mutable_attr());
auto s = op->operation->SetAttrFunctionList(attr_name, value, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
funcs.get(), num_values));
}
void TFE_OpSetAttrValueProto(const TFE_Op* op, const char* attr_name,
const void* proto, size_t proto_len,
TF_Status* status) {
tensorflow::AttrValue attr_value;
if (!attr_value.ParseFromArray(proto, proto_len)) {
status->status =
tensorflow::errors::InvalidArgument("Unparseable AttrValue proto");
return;
}
if (op == nullptr || op->operation == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"Got a null or uninitialized `op` argument");
return;
}
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
op->operation.get());
operation->MutableAttrs()->Set(attr_name, attr_value);
}
TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
const char* input_name,
TF_Status* status) {
const tensorflow::OpDef* op_def = GetOpDef(op, status);
if (!status->status.ok()) {
return -1;
}
tensorflow::AttrValueMap attrs;
op->operation.Attrs().FillAttrValueMap(&attrs);
tensorflow::NameRangeMap name_ranges;
status->status = tensorflow::NameRangesForNode(
tensorflow::AttrSlice(&attrs), *op_def, &name_ranges, nullptr);
if (!status->status.ok()) {
return -1;
}
auto iter = name_ranges.find(input_name);
if (iter == name_ranges.end()) {
status->status = tensorflow::errors::InvalidArgument("Input '", input_name,
"' not found");
return -1;
}
return iter->second.second - iter->second.first;
int ret = -1;
status->status = op->operation->InputLength(input_name, &ret);
return ret;
}
TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
const char* output_name,
TF_Status* status) {
const tensorflow::OpDef* op_def = GetOpDef(op, status);
if (!status->status.ok()) {
return -1;
}
tensorflow::AttrValueMap attrs;
op->operation.Attrs().FillAttrValueMap(&attrs);
tensorflow::NameRangeMap name_ranges;
status->status = tensorflow::NameRangesForNode(
tensorflow::AttrSlice(&attrs), *op_def, nullptr, &name_ranges);
if (!status->status.ok()) {
return -1;
}
auto iter = name_ranges.find(output_name);
if (iter == name_ranges.end()) {
status->status = tensorflow::errors::InvalidArgument(
"Output '", output_name, "' not found");
return -1;
}
return iter->second.second - iter->second.first;
int ret = -1;
status->status = op->operation->OutputLength(output_name, &ret);
return ret;
}
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
VLOG(1) << "Calling TFE_Execute() on op " << op;
status->status = tensorflow::EagerExecute(&op->operation,
handle_retvals.data(), num_retvals);
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>> handles(
*num_retvals);
status->status = op->operation->Execute(&handles, num_retvals);
if (!status->status.ok()) {
return;
}
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle_retvals[i])};
retvals[i] = new TFE_TensorHandle{std::move(handles[i])};
}
}
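For context on the TFE_Execute rewrite above: the public C API surface is unchanged, only the implementation now delegates to the op's OperationInterface and wraps the returned handle interfaces. A hedged sketch of a minimal eager call through that entry point (illustrative only; error checking elided):

```cpp
// Sketch only: squares a scalar float through the eager C API whose internals
// the hunk above rewires. Assumes the public headers from this repo.
#include <cstdio>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"

int main() {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);

  // Wrap a scalar float tensor in an eager handle.
  TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, /*dims=*/nullptr, /*num_dims=*/0,
                                   sizeof(float));
  *static_cast<float*>(TF_TensorData(t)) = 3.0f;
  TFE_TensorHandle* x = TFE_NewTensorHandle(t, status);

  // Build and run the op; TFE_Execute fills in the returned handles.
  TFE_Op* op = TFE_NewOp(ctx, "Square", status);
  TFE_OpAddInput(op, x, status);
  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(op, retvals, &num_retvals, status);

  TF_Tensor* out = TFE_TensorHandleResolve(retvals[0], status);
  std::printf("square = %f\n", *static_cast<float*>(TF_TensorData(out)));

  // Cleanup.
  TF_DeleteTensor(out);
  TFE_DeleteTensorHandle(retvals[0]);
  TFE_DeleteOp(op);
  TFE_DeleteTensorHandle(x);
  TF_DeleteTensor(t);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
  return 0;
}
```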
@ -1678,6 +1622,31 @@ void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); }
void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); }
void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) {
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
op->operation.get());
*attrs = TFE_OpAttrs(&operation->Attrs(), op->operation->Name().c_str());
}
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
tensorflow::AttrValueMap m;
attrs->attributes->FillAttrValueMap(&m);
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
op->operation.get());
tensorflow::AttrBuilder* destination = operation->MutableAttrs();
for (auto attribute : m) {
destination->Set(attribute.first, attribute.second);
}
}
void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
TF_Status* status) {
tensorflow::NameAttrList name_and_attrs;
attrs->attributes->FillAttrValueMap(name_and_attrs.mutable_attr());
name_and_attrs.set_name(attrs->name);
status->status = MessageToBuffer(name_and_attrs, buf);
}
namespace tensorflow {
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
const tensorflow::AttrValue& default_value,
@ -1741,8 +1710,9 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
namespace {
class CustomDeviceAPI : public tensorflow::CustomDevice {
public:
CustomDeviceAPI(TFE_CustomDevice device, void* info, string name)
: device_(device), info_(info), name_(name) {}
CustomDeviceAPI(TFE_Context* context, TFE_CustomDevice device, void* info,
string name)
: context_(context), device_(device), info_(info), name_(name) {}
~CustomDeviceAPI() override { device_.delete_device(info_); }
@ -1756,7 +1726,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
std::make_unique<tensorflow::TensorHandleInterface>(tensor)};
TF_Status status;
TFE_TensorHandle* result_handle =
device_.copy_tensor_to_device(&tensor_handle, &status, info_);
device_.copy_tensor_to_device(context_, &tensor_handle, &status, info_);
if (!status.status.ok()) return status.status;
*result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
result_handle->handle.get())
@ -1775,7 +1745,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
TFE_TensorHandle tensor_handle{
std::make_unique<tensorflow::TensorHandleInterface>(tensor)};
TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
&tensor_handle, target_device_name.c_str(), &status, info_);
context_, &tensor_handle, target_device_name.c_str(), &status, info_);
if (!status.status.ok()) return status.status;
*result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
result_handle->handle.get())
@ -1797,10 +1767,10 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
op->Inputs()[i])});
}
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
// TODO(allenl): figure out how to get attrs from EagerOperation
TF_Status status;
device_.execute(inputs.size(), inputs.data(), op->Name().c_str(),
num_retvals, outputs.data(), &status, info_);
TFE_OpAttrs attributes(&op->Attrs(), op->Name().c_str());
device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(),
&attributes, num_retvals, outputs.data(), &status, info_);
if (status.status.ok()) {
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
@ -1818,6 +1788,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
}
private:
TFE_Context* context_;
TFE_CustomDevice device_;
void* info_;
string name_;
@ -1825,8 +1796,10 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
} // namespace
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
const char* device_name, void* device_info) {
const char* device_name, void* device_info,
TF_Status* status) {
auto custom_device =
std::make_unique<CustomDeviceAPI>(device, device_info, device_name);
ctx->context->RegisterCustomDevice(device_name, std::move(custom_device));
std::make_unique<CustomDeviceAPI>(ctx, device, device_info, device_name);
status->status =
ctx->context->RegisterCustomDevice(device_name, std::move(custom_device));
}


@ -25,34 +25,20 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
#include "tensorflow/core/profiler/rpc/profiler_server.h"
using tensorflow::string;
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status) {
if (op_to_reset) {
status->status = op_to_reset->operation.Reset(
op_or_function_name, raw_device_name, false, nullptr);
status->status =
op_to_reset->operation->Reset(op_or_function_name, raw_device_name);
} else {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"op_to_reset should not be nullptr");
}
}
void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
op->operation.ConsumeInput(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle());
}
void TFE_StartProfilerServer(int port) {
// Release child thread intentionally. The child thread can be terminated by
// terminating the main thread.
tensorflow::StartProfilerServer(port).release();
}
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
ctx->context->SetShouldStoreGraphs(true);
}
@ -61,46 +47,6 @@ void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
ctx->context->SetShouldStoreGraphs(false);
}
bool TFE_ProfilerClientStartTracing(const char* service_addr,
const char* logdir, const char* worker_list,
bool include_dataset_ops, int duration_ms,
int num_tracing_attempts,
TF_Status* status) {
tensorflow::Status s =
tensorflow::profiler::ValidateHostPortPair(service_addr);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return false;
}
s = tensorflow::profiler::Trace(service_addr, logdir, worker_list,
include_dataset_ops, duration_ms,
num_tracing_attempts);
tensorflow::Set_TF_Status_from_Status(status, s);
return s.ok();
}
void TFE_ProfilerClientMonitor(const char* service_addr, int duration_ms,
int monitoring_level, bool display_timestamp,
TF_Buffer* result, TF_Status* status) {
tensorflow::Status s =
tensorflow::profiler::ValidateHostPortPair(service_addr);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return;
}
string content;
s = tensorflow::profiler::Monitor(service_addr, duration_ms, monitoring_level,
display_timestamp, &content);
void* data = tensorflow::port::Malloc(content.length());
content.copy(static_cast<char*>(data), content.length(), 0);
result->data = data;
result->length = content.length();
result->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
tensorflow::Set_TF_Status_from_Status(status, s);
}
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
int64_t value) {
cell->cell.IncrementBy(value);
@ -568,8 +514,7 @@ void TFE_DeleteCancellationManager(
void TFE_OpSetCancellationManager(TFE_Op* op,
TFE_CancellationManager* cancellation_manager,
TF_Status* status) {
op->operation.SetCancellationManager(
&cancellation_manager->cancellation_manager);
status->status = op->operation->SetCancellationManager(cancellation_manager);
}
TFE_Executor* TFE_NewExecutor(bool is_async) {
@ -617,3 +562,22 @@ void TFE_TensorHandleEnableImplicitMirroring(TFE_TensorHandle* h,
h->handle->EnableImplicitMirroring();
status->status = tensorflow::Status::OK();
}
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
TF_Buffer* buf, TF_Status* status) {
auto* function_def = ctx->context->FindFunctionDef(function_name);
if (function_def == nullptr) {
status->status = tensorflow::errors::NotFound(
"Unable to find FunctionDef with name: ", function_name);
return;
}
string str = function_def->SerializeAsString();
void* data = tensorflow::port::Malloc(str.length());
str.copy(static_cast<char*>(data), str.length(), 0);
buf->data = data;
buf->length = str.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
status->status = tensorflow::Status::OK();
}


@ -34,18 +34,6 @@ TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset,
const char* raw_device_name,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
TF_Status* status);
// Start a profiler gRPC server that listens on the specified port. The server
// runs on its own thread and can be shut down by terminating TensorFlow. It can
// be used in both eager mode and graph mode. Creating multiple profiler servers
// is allowed. The service is defined in
// tensorflow/contrib/tpu/profiler/tpu_profiler.proto. Please use
// tensorflow/contrib/tpu/profiler/capture_tpu_profile to capture a trace file
// following https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
TF_CAPI_EXPORT extern void TFE_StartProfilerServer(int port);
// Enables only graph collection in RunMetadata on the functions executed from
// this context.
TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx);
@ -54,29 +42,6 @@ TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx);
// this context.
TF_CAPI_EXPORT extern void TFE_ContextDisableGraphCollection(TFE_Context* ctx);
// Send a gRPC request to the profiler server (service_addr) to perform
// on-demand profiling and save the result into logdir, which can be visualized
// by TensorBoard. worker_list is the list of worker TPUs separated by ','. Set
// include_dataset_ops to false to profile longer traces. This call blocks the
// caller thread until it receives the tracing result.
// This API is designed for TensorBoard; end users should use
// tensorflow/contrib/tpu/profiler/capture_tpu_profile instead, following
// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
TF_CAPI_EXPORT extern bool TFE_ProfilerClientStartTracing(
const char* service_addr, const char* logdir, const char* worker_list,
bool include_dataset_ops, int duration_ms, int num_tracing_attempts,
TF_Status* status);
// Send a gRPC request to the profiler server (service_addr) to perform
// on-demand monitoring and return the result in a string. This call blocks the
// caller thread until it receives the monitoring result.
// This API is designed for TensorBoard; end users should use
// tensorflow/contrib/tpu/profiler/capture_tpu_profile instead, following
// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
TF_CAPI_EXPORT extern void TFE_ProfilerClientMonitor(
const char* service_addr, int duration_ms, int monitoring_level,
bool display_timestamp, TF_Buffer* result, TF_Status* status);
// TODO(fishx): Move these monitoring APIs into a separate file.
// -----------------------------------------------------------------------------
// Monitoring Counter APIs.
@ -417,9 +382,11 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
const char* worker_name,
TF_Status* status);
// Clear pending streaming requests and error statuses on remote executors.
TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
TF_Status* status);
// Sync pending nodes in local executors (including the context default executor
// and thread executors) and streaming requests to remote executors, and get the
// combined status.
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
TF_Status* status);
// If the TensorHandle is copied to another device as part of an op execution,
// the copy is destroyed after the op has executed. Enabling implicit mirroring
@ -456,26 +423,63 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
TF_Buffer* buf);
#define TFE_CUSTOM_DEVICE_VERSION 0
// APIs for generically dealing with op attributes (e.g. when forwarding them
// through custom device implementations).
//
// TODO(allenl): Currently these are black boxes, but we should have some way to
// inspect values. This would let people e.g. copy over most attributes and then
// modify some based on their values.
// A reference to an op's name -> attribute mapping
typedef struct TFE_OpAttrs TFE_OpAttrs;
// Fetch a struct with a reference to information about attributes of `op`.
//
// The `attrs` struct does not own any memory, and `op` must outlive it.
TF_CAPI_EXPORT extern void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs);
// Add attributes in `attrs` to `op`.
//
// Does not overwrite or update existing attributes, but adds new ones.
TF_CAPI_EXPORT extern void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs);
// Serialize `attrs` as a tensorflow::NameAttrList protocol buffer (into `buf`),
// containing the op name and a map of its attributes.
TF_CAPI_EXPORT extern void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs,
TF_Buffer* buf,
TF_Status* status);
// Set an op's attribute from a serialized AttrValue protocol buffer.
//
// Analogous to TF_SetAttrValueProto for building graph operations.
TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op,
const char* attr_name,
const void* proto,
size_t proto_len,
TF_Status* status);
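// Illustrative sketch, not part of this header: forwarding the attributes of
// one op onto another with the functions above, mirroring the pattern in this
// change's TestTFE_OpGetAttrs test. Stack-allocating TFE_OpAttrs assumes its
// definition is visible (e.g. via c_api_internal.h); the helper name CloneAttrs
// is hypothetical.
inline void CloneAttrs(TFE_Op* source, TFE_Op* destination, TF_Status* status) {
  TFE_OpAttrs attrs;
  TFE_OpGetAttrs(source, &attrs);       // non-owning view; `source` must outlive it
  TFE_OpAddAttrs(destination, &attrs);  // adds missing attributes, keeps existing ones
  TF_Buffer* serialized = TF_NewBuffer();
  TFE_OpAttrsSerialize(&attrs, serialized, status);  // NameAttrList wire format
  // (the serialized bytes would normally be consumed here)
  TF_DeleteBuffer(serialized);
}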
#define TFE_CUSTOM_DEVICE_VERSION 2
// Struct to be filled in
typedef struct TFE_CustomDevice {
int version = TFE_CUSTOM_DEVICE_VERSION;
// Method to copy a tensor to the custom device.
TFE_TensorHandle* (*copy_tensor_to_device)(TFE_TensorHandle* tensor,
TFE_TensorHandle* (*copy_tensor_to_device)(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status,
void* device_info) = nullptr;
// Method to copy a tensor from the custom device to a target device.
TFE_TensorHandle* (*copy_tensor_from_device)(TFE_TensorHandle* tensor,
TFE_TensorHandle* (*copy_tensor_from_device)(TFE_Context* context,
TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info);
// Method to execute an operation.
// TODO(allenl) figure out a generic way of passing attrs here
void (*execute)(int num_inputs, TFE_TensorHandle** inputs,
const char* operation_name, int* num_outputs,
void (*execute)(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
// Method to delete a device.
@ -501,11 +505,26 @@ typedef struct TFE_CustomDevice {
// devices, so executing tf.functions which contain operations placed on custom
// devices will fail.
//
// `device_name` must not name an existing physical or custom device. It must
// follow the format:
//
// /job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num>
//
// If the device is successfully registered, `status` is set to TF_OK. Otherwise
// the device is not usable. In case of a bad status, `device.delete_device` is
// still called on `device_info` (i.e. the caller does not retain ownership).
//
// This API is highly experimental, and in particular is expected to change when
// it starts supporting operations with attributes and when tf.function support
// is added.
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
const char* device_name, void* device_info);
const char* device_name, void* device_info,
TF_Status* status);
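// Illustrative sketch, not part of this header: the smallest possible custom
// device, with unimplemented callbacks, registered under a hypothetical name.
// See the logging device in custom_device_test.cc in this change for a
// complete, working example.
static TFE_TensorHandle* NoopCopyToDevice(TFE_Context* context,
                                          TFE_TensorHandle* tensor,
                                          TF_Status* status, void* device_info) {
  TF_SetStatus(status, TF_UNIMPLEMENTED, "copy-to is not implemented");
  return nullptr;
}
static TFE_TensorHandle* NoopCopyFromDevice(TFE_Context* context,
                                            TFE_TensorHandle* tensor,
                                            const char* target_device_name,
                                            TF_Status* status,
                                            void* device_info) {
  TF_SetStatus(status, TF_UNIMPLEMENTED, "copy-from is not implemented");
  return nullptr;
}
static void NoopExecute(TFE_Context* context, int num_inputs,
                        TFE_TensorHandle** inputs, const char* operation_name,
                        const TFE_OpAttrs* attributes, int* num_outputs,
                        TFE_TensorHandle** outputs, TF_Status* s,
                        void* device_info) {
  TF_SetStatus(s, TF_UNIMPLEMENTED, "execute is not implemented");
}
static void NoopDeleteDevice(void* device_info) {}
static void RegisterNoopDevice(TFE_Context* ctx, TF_Status* status) {
  TFE_CustomDevice device;
  device.copy_tensor_to_device = &NoopCopyToDevice;
  device.copy_tensor_from_device = &NoopCopyFromDevice;
  device.execute = &NoopExecute;
  device.delete_device = &NoopDeleteDevice;
  TFE_RegisterCustomDevice(ctx, device,
                           "/job:localhost/replica:0/task:0/device:CUSTOM:0",
                           /*device_info=*/nullptr, status);
}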
TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
const char* function_name,
TF_Buffer* buf,
TF_Status* status);
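// Illustrative sketch, not part of this header: fetching a registered
// function's definition and parsing it back into a FunctionDef proto. The
// function name "my_fn" is hypothetical, and the sketch assumes function.pb.h
// is available for the proto type.
inline void DumpFunctionDef(TFE_Context* ctx, TF_Status* status) {
  TF_Buffer* buf = TF_NewBuffer();
  TFE_ContextGetFunctionDef(ctx, "my_fn", buf, status);
  if (TF_GetCode(status) == TF_OK) {
    tensorflow::FunctionDef fdef;
    fdef.ParseFromArray(buf->data, static_cast<int>(buf->length));  // serialized FunctionDef bytes
  }
  TF_DeleteBuffer(buf);
}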
#ifdef __cplusplus
} /* end extern "C" */


@ -27,12 +27,12 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
@ -89,7 +89,7 @@ struct TFE_TensorDebugInfo {
};
struct TFE_Op {
tensorflow::EagerOperation operation;
std::unique_ptr<AbstractOperationInterface> operation;
};
struct TFE_MonitoringCounterCell {
@ -236,4 +236,17 @@ struct TFE_Executor {
tensorflow::EagerExecutor* unowned_executor;
};
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
// that sometimes do not require serialization.
struct TFE_OpAttrs {
explicit TFE_OpAttrs() : name(nullptr), attributes(nullptr) {}
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value,
const char* op_name)
: name(op_name), attributes(value) {}
const char* name;
const tensorflow::AttrBuilder* attributes;
};
#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_


@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
@ -127,7 +128,7 @@ void TestRemoteExecute(bool async) {
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
void TestRemoteExecuteSilentCopies(bool async) {
void TestRemoteExecuteSilentCopies(bool async, bool remote) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
@ -166,10 +167,14 @@ void TestRemoteExecuteSilentCopies(bool async) {
auto* h1_task2 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandleEnableImplicitMirroring(h1_task2, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Handles are on task0 (local), and task2, but op is on task1.
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2);
TFE_OpSetDevice(matmul, task1_name, status);
if (remote) {
TFE_OpSetDevice(matmul, task1_name, status);
}
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
@ -177,6 +182,17 @@ void TestRemoteExecuteSilentCopies(bool async) {
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TODO(gjn): Add support for waiting on async local mirrors
if (!async) {
auto remote_arg = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
h1_task2->handle.get())
->Handle();
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
matmul->operation.get());
// The input handles should never change since they have been mirrored.
ASSERT_EQ(op->GetInput(1), remote_arg);
}
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@ -213,9 +229,17 @@ void TestRemoteExecuteSilentCopies(bool async) {
worker_server2.release();
}
TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); }
TEST(CAPI, RemoteExecuteSilentCopies) {
TestRemoteExecuteSilentCopies(false, true);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
TestRemoteExecuteSilentCopies(true);
TestRemoteExecuteSilentCopies(true, true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
TestRemoteExecuteSilentCopies(false, false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
TestRemoteExecuteSilentCopies(true, false);
}
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {


@ -17,6 +17,8 @@ limitations under the License.
#include <string.h>
#include <string>
#include "absl/strings/match.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
@ -367,7 +369,7 @@ TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevicesAsync) {
void TensorHandleSilentCopy(bool async,
TFE_ContextDevicePlacementPolicy global_policy,
TFE_ContextDevicePlacementPolicy thread_policy,
bool mirror, bool cpu_op) {
bool cpu_op) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -390,12 +392,6 @@ void TensorHandleSilentCopy(bool async,
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
if (mirror) {
TFE_TensorHandleEnableImplicitMirroring(hcpu, status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
TFE_TensorHandleEnableImplicitMirroring(hgpu, status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
}
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
if (cpu_op) {
@ -419,21 +415,13 @@ void TensorHandleSilentCopy(bool async,
auto arg1 = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
hgpu->handle.get())
->Handle();
if (mirror) {
// The input handles should never change since they have been mirrored.
ASSERT_EQ(matmul->operation.Inputs()[0], arg0);
ASSERT_EQ(matmul->operation.Inputs()[1], arg1);
} else {
if (cpu_op) {
ASSERT_EQ(matmul->operation.Inputs()[0], arg0);
// The GPU handle should be replaced with a CPU copy
ASSERT_NE(matmul->operation.Inputs()[1], arg1);
} else {
// The CPU handle should be replaced with a GPU copy
ASSERT_NE(matmul->operation.Inputs()[0], arg0);
ASSERT_EQ(matmul->operation.Inputs()[1], arg1);
}
}
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
matmul->operation.get());
// The input handles should never change since they have been mirrored.
EXPECT_EQ(op->GetInput(0), arg0);
EXPECT_EQ(op->GetInput(1), arg1);
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(retvals[0]);
@ -450,27 +438,19 @@ void TensorHandleSilentCopy(bool async,
}
TEST(CAPI, TensorHandleSilentCopy) {
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
TFE_DEVICE_PLACEMENT_SILENT, false, false);
TFE_DEVICE_PLACEMENT_SILENT, false);
}
TEST(CAPI, TensorHandleSilentCopyAsync) {
TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_SILENT,
TFE_DEVICE_PLACEMENT_SILENT, false, false);
TFE_DEVICE_PLACEMENT_SILENT, false);
}
TEST(CAPI, TensorHandleSilentCopyLocalPolicy) {
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_EXPLICIT,
TFE_DEVICE_PLACEMENT_SILENT, false, false);
TFE_DEVICE_PLACEMENT_SILENT, false);
}
TEST(CAPI, TensorHandleSilentCopyLocalPolicyAsync) {
TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_EXPLICIT,
TFE_DEVICE_PLACEMENT_SILENT, false, false);
}
TEST(CAPI, TensorHandleMirrorCopy) {
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
TFE_DEVICE_PLACEMENT_SILENT, true, false);
}
TEST(CAPI, TensorHandleMirrorCopyCpu) {
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
TFE_DEVICE_PLACEMENT_SILENT, true, true);
TFE_DEVICE_PLACEMENT_SILENT, false);
}
void SetAndGetOpDevices(bool async) {
@ -606,6 +586,91 @@ TEST(CAPI, TensorHandleDevices) {
TFE_DeleteContext(ctx);
}
void ExecuteAdd(bool async, bool forward_input) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* n = TestMatrixTensorHandle100x100();
// If a GPU exists, copy the handle to GPU so that we can exercise
// unprotecting a mirror.
std::string gpu_device_name;
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_TensorHandle* n_gpu =
TFE_TensorHandleCopyToDevice(n, ctx, gpu_device_name.c_str(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandleEnableImplicitMirroring(n_gpu, status);
TFE_DeleteTensorHandle(n);
n = n_gpu;
}
TFE_TensorHandle* m = TestMatrixTensorHandle100x100();
// Store pointer to raw buffer for validation of forwarding behaviour.
TF_Tensor* orig = TFE_TensorHandleResolve(n, status);
void* orig_ptr = TF_TensorData(orig);
TF_DeleteTensor(orig);
TFE_Op* add_op = AddOp(ctx, n, m);
std::string cpu_device_name;
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
TFE_OpSetDevice(add_op, cpu_device_name.c_str(), status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
if (forward_input) {
TFE_DeleteTensorHandle(n);
}
int num_retvals = 1;
if (async) {
// Enqueue dummy ops so we backlog async execution & actually test async.
for (int i = 0; i < 10000; ++i) {
TFE_TensorHandle* dummy = nullptr;
TFE_Execute(add_op, &dummy, &num_retvals, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(dummy);
}
}
TFE_TensorHandle* retval = nullptr;
TFE_Execute(add_op, &retval, &num_retvals, status);
EXPECT_EQ(1, num_retvals);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
if (!forward_input) {
TFE_DeleteTensorHandle(n);
}
TFE_DeleteOp(add_op);
TF_Tensor* t = TFE_TensorHandleResolve(retval, status);
if (forward_input || async) {
EXPECT_EQ(orig_ptr, TF_TensorData(t));
} else {
EXPECT_NE(orig_ptr, TF_TensorData(t));
}
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(m);
TFE_DeleteTensorHandle(retval);
TFE_DeleteContext(ctx);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float result[100 * 100] = {0};
EXPECT_EQ(sizeof(result), TF_TensorByteSize(t));
memcpy(&result[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
for (int i = 0; i < 100 * 100; ++i) {
EXPECT_EQ(2.0f, result[i]);
}
TF_DeleteStatus(status);
}
TEST(CAPI, ExecuteAdd) { ExecuteAdd(false, false); }
TEST(CAPI, ExecuteAddAsync) { ExecuteAdd(true, false); }
TEST(CAPI, ExecuteAddForward) { ExecuteAdd(false, true); }
TEST(CAPI, ExecuteAddForwardAsync) { ExecuteAdd(true, true); }
void Execute_MatMul_CPU(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -1244,6 +1309,14 @@ TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
TFE_DeleteTensorHandle(h_shares_tensor);
}
tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) {
tensorflow::AttrValueMap attr_values;
tensorflow::down_cast<tensorflow::OperationInterface*>(op->operation.get())
->Attrs()
.FillAttrValueMap(&attr_values);
return attr_values;
}
TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -1260,8 +1333,7 @@ TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
TFE_OpAddInput(minOp, axis, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
minOp->operation.Attrs().FillAttrValueMap(&attr_values);
tensorflow::AttrValueMap attr_values = ExtractAttrs(minOp);
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
EXPECT_NE(attr_found, attr_values.cend());
EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
@ -1300,8 +1372,7 @@ TEST(CAPI, TestTFE_OpInferSingleTypeInputListAttrs) {
TFE_OpAddInputList(concatOp, inputs, 2, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
concatOp->operation.Attrs().FillAttrValueMap(&attr_values);
tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp);
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
EXPECT_NE(attr_found, attr_values.cend());
EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
@ -1341,8 +1412,7 @@ TEST(CAPI, TestTFE_OpInferMixedTypeInputListAttrs) {
TFE_OpAddInputList(assertOp, data, 3, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
assertOp->operation.Attrs().FillAttrValueMap(&attr_values);
tensorflow::AttrValueMap attr_values = ExtractAttrs(assertOp);
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
EXPECT_NE(attr_found, attr_values.cend());
EXPECT_EQ(attr_found->second.list().type(0), tensorflow::DataType::DT_BOOL);
@ -1378,16 +1448,15 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
TFE_TensorHandle* inputs[] = {input1, input2};
TFE_OpAddInput(concatOp, dim, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CHECK(concatOp->operation.OpDef());
CHECK(concatOp->operation->OpDef());
TFE_OpAddInput(concatOp, inputs[0], status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_FALSE(concatOp->operation.OpDef())
EXPECT_FALSE(concatOp->operation->OpDef())
<< "Inference context is still present";
TFE_OpAddInput(concatOp, inputs[1], status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
concatOp->operation.Attrs().FillAttrValueMap(&attr_values);
tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp);
EXPECT_EQ(attr_values.find("T"), attr_values.end());
EXPECT_EQ(attr_values.find("N"), attr_values.end());
@ -1474,4 +1543,88 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
TFE_DeleteContext(ctx);
}
TEST(CAPI, TestTFE_OpGetAttrs) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
TFE_OpAttrs attributes;
TFE_OpGetAttrs(var_op, &attributes);
TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT);
TFE_OpAddAttrs(copy_op, &attributes);
unsigned char is_list = 0;
ASSERT_EQ(TF_ATTR_TYPE,
TFE_OpGetAttrType(copy_op, "dtype", &is_list, status));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(TF_ATTR_SHAPE,
TFE_OpGetAttrType(copy_op, "shape", &is_list, status));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
copy_op->operation.get());
op->Attrs().FillAttrValueMap(&attr_values);
EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type());
TF_DeleteStatus(status);
TFE_DeleteOp(var_op);
TFE_DeleteOp(copy_op);
TFE_DeleteContext(ctx);
}
TEST(CAPI, TestTFE_OpAttrsSerialize) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
TFE_OpAttrs attributes;
TFE_OpGetAttrs(var_op, &attributes);
TF_Buffer* serialized_attr_values = TF_NewBuffer();
TFE_OpAttrsSerialize(&attributes, serialized_attr_values, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::NameAttrList name_and_attrs;
ASSERT_TRUE(name_and_attrs.ParseFromArray(serialized_attr_values->data,
serialized_attr_values->length));
ASSERT_EQ("VarHandleOp", name_and_attrs.name());
ASSERT_EQ(tensorflow::DT_INT64,
name_and_attrs.attr().find("dtype")->second.type());
TF_DeleteBuffer(serialized_attr_values);
TFE_Op* second_var_op = TFE_NewOp(ctx, "VarHandleOp", status);
string serialized_dtype;
ASSERT_TRUE(name_and_attrs.attr().find("dtype")->second.SerializeToString(
&serialized_dtype));
TFE_OpSetAttrValueProto(
second_var_op, "dtype",
reinterpret_cast<const void*>(serialized_dtype.c_str()),
serialized_dtype.length(), status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
second_var_op->operation.get());
op->Attrs().FillAttrValueMap(&attr_values);
EXPECT_EQ(tensorflow::DT_INT64, attr_values.find("dtype")->second.type());
TF_DeleteStatus(status);
TFE_DeleteOp(var_op);
TFE_DeleteOp(second_var_op);
TFE_DeleteContext(ctx);
}
} // namespace


@ -131,6 +131,21 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2() {
return th;
}
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "AddV2", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, a, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, b, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
return op;
}
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
TF_Status* status = TF_NewStatus();


@ -42,6 +42,9 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2();
// Return a tensor handle containing a 3x2 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle3X2();
// Return an add op computing the sum of `a` and `b`.
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
// Return a matmul op multiplying `a` by `b`.
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);


@ -21,16 +21,18 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"
namespace {
struct LoggingDevice {
TFE_Context* ctx;
tensorflow::string device_name;
tensorflow::string underlying_device;
// Set to true whenever a TensorHandle is copied onto the device
bool* arrived_flag;
// Set to true whenever an operation is executed
bool* executed_flag;
};
struct LoggedTensor {
@ -45,7 +47,7 @@ void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
}
TFE_TensorHandle* MakeLoggedTensorHandle(
TFE_Context* ctx, const tensorflow::string& logging_device_name,
TFE_Context* context, const tensorflow::string& logging_device_name,
std::unique_ptr<LoggedTensor> t, TF_Status* status) {
std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
@ -55,23 +57,25 @@ TFE_TensorHandle* MakeLoggedTensorHandle(
}
auto dtype = TFE_TensorHandleDataType(t->tensor);
return TFE_NewTensorHandleFromDeviceMemory(
ctx, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
}
TFE_TensorHandle* CopyToLoggingDevice(TFE_TensorHandle* tensor,
TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, dev->ctx, dev->underlying_device.c_str(), status);
tensor, context, dev->underlying_device.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
auto dst = std::make_unique<LoggedTensor>(t);
*(dev->arrived_flag) = true;
return MakeLoggedTensorHandle(dev->ctx, dev->device_name, std::move(dst),
return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
status);
}
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_TensorHandle* tensor,
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info) {
@ -80,13 +84,15 @@ TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_TensorHandle* tensor,
return nullptr;
}
void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
const char* operation_name, int* num_outputs,
void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s,
void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_Op* op(TFE_NewOp(dev->ctx, operation_name, s));
TFE_Op* op(TFE_NewOp(context, operation_name, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddAttrs(op, attributes);
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
for (int j = 0; j < num_inputs; ++j) {
TFE_TensorHandle* input = inputs[j];
@ -112,9 +118,10 @@ void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
}
for (int i = 0; i < *num_outputs; ++i) {
auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
outputs[i] = MakeLoggedTensorHandle(dev->ctx, dev->device_name,
outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
std::move(logged_tensor), s);
}
*(dev->executed_flag) = true;
}
void DeleteLoggingDevice(void* device_info) {
@ -122,18 +129,19 @@ void DeleteLoggingDevice(void* device_info) {
}
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag) {
bool* arrived_flag, bool* executed_flag,
TF_Status* status) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
custom_device.delete_device = &DeleteLoggingDevice;
custom_device.execute = &LoggingDeviceExecute;
LoggingDevice* device = new LoggingDevice;
device->ctx = context;
device->arrived_flag = arrived_flag;
device->executed_flag = executed_flag;
device->device_name = name;
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_RegisterCustomDevice(context, custom_device, name, device);
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
}
TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
@ -144,13 +152,16 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
TFE_DeleteContextOptions(opts);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context, name, &arrived);
RegisterLoggingDevice(context, name, &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
ASSERT_FALSE(arrived);
TFE_TensorHandle* hdevice =
TFE_TensorHandleCopyToDevice(hcpu, context, name, status.get());
ASSERT_TRUE(arrived);
ASSERT_FALSE(executed);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> matmul(
MatMulOp(context, hcpu, hdevice), TFE_DeleteOp);
@ -160,6 +171,7 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
int num_retvals = 1;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
TFE_DeleteTensorHandle(retval);
TFE_DeleteTensorHandle(hcpu);
@ -167,4 +179,220 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
TFE_DeleteContext(context);
}
TEST(CUSTOM_DEVICE, ResetOperation) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts, status.get()), TFE_DeleteContext);
TFE_DeleteContextOptions(opts);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
const char* custom_device_name =
"/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> reused_op(
TFE_NewOp(context.get(), "Identity", status.get()), TFE_DeleteOp);
TFE_OpReset(reused_op.get(), "Identity", custom_device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(tensorflow::string(TFE_OpGetDevice(reused_op.get(), status.get())),
tensorflow::string(custom_device_name));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpReset(reused_op.get(), "Identity",
"/job:localhost/replica:0/task:0/device:CPU:0", status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(tensorflow::string(TFE_OpGetDevice(reused_op.get(), status.get())),
tensorflow::string("/job:localhost/replica:0/task:0/device:CPU:0"));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
TEST(CUSTOM_DEVICE, MakeVariable) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a variable handle placed on the custom device.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
TFE_OpSetAttrString(op.get(), "container", "", 0);
TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
executed = false;
TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
auto handle_cleaner = tensorflow::gtl::MakeCleanup(
[var_handle]() { TFE_DeleteTensorHandle(var_handle); });
// Assign to the variable, copying to the custom device.
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpAddInput(op.get(), one.get(), status.get());
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
executed = false;
num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
// Read the variable's value.
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpSetDevice(op.get(), name, status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
executed = false;
num_retvals = 1;
TFE_TensorHandle* var_value = nullptr;
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
auto value_cleaner = tensorflow::gtl::MakeCleanup(
[var_value]() { TFE_DeleteTensorHandle(var_value); });
ASSERT_EQ(tensorflow::string(name),
tensorflow::string(
TFE_TensorHandleBackingDeviceName(var_value, status.get())));
TFE_TensorHandle* var_value_unpacked =
reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(var_value, status.get()))
->tensor;
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value(
TFE_TensorHandleResolve(var_value_unpacked, status.get()),
TF_DeleteTensor);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(111., *static_cast<float*>(TF_TensorData(resolved_value.get())));
// Free the backing buffer for the variable.
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a variable handle placed on the custom device.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
TFE_OpSetAttrString(op.get(), "container", "", 0);
TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
executed = false;
TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
auto handle_cleaner = tensorflow::gtl::MakeCleanup(
[var_handle]() { TFE_DeleteTensorHandle(var_handle); });
// Assign to the variable, copying to the custom device.
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpAddInput(op.get(), one.get(), status.get());
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
executed = false;
num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
// Read the variable's value.
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
executed = false;
num_retvals = 1;
TFE_TensorHandle* var_value = nullptr;
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
EXPECT_FALSE(TF_GetCode(status.get()) == TF_OK)
<< "Execution should fail because the variable is being used on the "
"wrong device.";
// Free the backing buffer for the variable.
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
RegisterLoggingDevice(context.get(), "/device:CUSTOM:0", &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
<< TF_Message(status.get());
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
RegisterLoggingDevice(context.get(),
"/job:localhost/replica:0/task:0/device:CPU:0",
&arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
}
} // namespace


@ -0,0 +1,334 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/dlpack.h"
#include "include/dlpack/dlpack.h" // TF:dlpack
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
namespace {
// Context that manages the lifetime of a DLManagedTensor. When
// DLManagedTensor::deleter is called, it notifies the original framework of the
// destruction, and this context is deleted as well.
struct TfDlManagedTensorCtx {
TensorReference reference;
std::vector<int64_t> shape;
std::vector<int64_t> strides;
DLManagedTensor tensor;
explicit TfDlManagedTensorCtx(const TensorReference& ref) : reference(ref) {}
};
// Gets tensor from eager tensor handle.
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || !h->handle->IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::TensorHandle* handle =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle();
if (handle->IsRemote()) {
status->status = tensorflow::errors::InvalidArgument(
"DLPack doesn't support remote tensor");
return nullptr;
}
const tensorflow::Tensor* tensor;
status->status = handle->Tensor(&tensor);
if (!status->status.ok()) {
return nullptr;
}
return tensor;
}
// Deleter for DLManagedTensor
void DLManagedTensorDeleter(DLManagedTensor* arg) {
TfDlManagedTensorCtx* owner =
static_cast<TfDlManagedTensorCtx*>(arg->manager_ctx);
owner->reference.Unref();
delete owner;
}
// Converts a TF_DataType to the corresponding DLPack data type.
DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) {
DLDataType dtype;
dtype.lanes = 1;
dtype.bits = TF_DataTypeSize(data_type) * 8;
switch (data_type) {
case TF_DataType::TF_HALF:
case TF_DataType::TF_FLOAT:
case TF_DataType::TF_DOUBLE:
dtype.code = DLDataTypeCode::kDLFloat;
break;
case TF_DataType::TF_INT8:
case TF_DataType::TF_INT16:
case TF_DataType::TF_INT32:
case TF_DataType::TF_INT64:
dtype.code = DLDataTypeCode::kDLInt;
break;
case TF_DataType::TF_BOOL:
case TF_DataType::TF_UINT8:
case TF_DataType::TF_UINT16:
case TF_DataType::TF_UINT32:
case TF_DataType::TF_UINT64:
dtype.code = DLDataTypeCode::kDLUInt;
break;
case TF_DataType::TF_BFLOAT16:
dtype.code = DLDataTypeCode::kDLBfloat;
break;
default:
status->status = tensorflow::errors::InvalidArgument(
DataType_Name(static_cast<DataType>(data_type)),
" is not supported by dlpack");
break;
}
return dtype;
}
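// Worked example for the mapping above: TF_FLOAT has TF_DataTypeSize() == 4, so
// it becomes DLDataType{code = kDLFloat, bits = 32, lanes = 1}; TF_UINT8 becomes
// {kDLUInt, 8, 1}.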
// Gets DLPack's DLContext from eager tensor handle.
DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) {
DLContext ctx;
const char* device_name = h->handle->DeviceName(&status->status);
DeviceNameUtils::ParsedName parsed_name;
tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
std::string device_type = parsed_name.type;
int device_id = 0;
if (parsed_name.has_id) {
device_id = parsed_name.id;
}
ctx.device_id = device_id;
if (device_type == "CPU") {
ctx.device_type = DLDeviceType::kDLCPU;
} else if (device_type == "GPU") {
ctx.device_type = DLDeviceType::kDLGPU;
} else {
status->status = tensorflow::errors::InvalidArgument(
"Unsupported Device Type for dlpack");
}
return ctx;
}
// Converts DLContext to TF device name.
absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx,
TF_Status* status) {
switch (ctx.device_type) {
case DLDeviceType::kDLCPU:
return "CPU:0";
case DLDeviceType::kDLGPU:
return absl::StrCat("GPU:", ctx.device_id);
default:
return absl::nullopt;
}
}
// Converts a DLPack data type to a TF_DataType.
Status TfDataTypeFormDlDataType(const DLDataType& dtype,
TF_DataType* tf_dtype) {
switch (dtype.code) {
case DLDataTypeCode::kDLUInt:
switch (dtype.bits) {
case 8:
*tf_dtype = TF_DataType::TF_UINT8;
return Status::OK();
case 16:
*tf_dtype = TF_DataType::TF_UINT16;
return Status::OK();
case 32:
*tf_dtype = TF_DataType::TF_UINT32;
return Status::OK();
case 64:
*tf_dtype = TF_DataType::TF_UINT64;
return Status::OK();
default:
return tensorflow::errors::InvalidArgument("Unsupported UInt bits: ",
dtype.bits);
}
return Status::OK();
case DLDataTypeCode::kDLInt:
switch (dtype.bits) {
case 8:
*tf_dtype = TF_DataType::TF_INT8;
return Status::OK();
case 16:
*tf_dtype = TF_DataType::TF_INT16;
return Status::OK();
case 32:
*tf_dtype = TF_DataType::TF_INT32;
return Status::OK();
case 64:
*tf_dtype = TF_DataType::TF_INT64;
return Status::OK();
default:
return tensorflow::errors::InvalidArgument("Unsupported Int bits: ",
dtype.bits);
}
return Status::OK();
case DLDataTypeCode::kDLFloat:
switch (dtype.bits) {
case 16:
*tf_dtype = TF_DataType::TF_HALF;
return Status::OK();
case 32:
*tf_dtype = TF_DataType::TF_FLOAT;
return Status::OK();
case 64:
*tf_dtype = TF_DataType::TF_DOUBLE;
return Status::OK();
default:
return tensorflow::errors::InvalidArgument("Unsupported Float bits: ",
dtype.bits);
}
break;
case DLDataTypeCode::kDLBfloat:
switch (dtype.bits) {
case 16:
*tf_dtype = TF_DataType::TF_BFLOAT16;
return Status::OK();
default:
return tensorflow::errors::InvalidArgument(
"Unsupported BFloat bits: ", dtype.bits);
}
break;
default:
return tensorflow::errors::InvalidArgument("Unsupported Type Codes: ",
dtype.code);
}
}
// Wraps the deleter function of DLManagedTensor to match the function signature
// expected by TFE_NewTensorHandleFromDeviceMemory.
void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlmt_vptr);
dlmt->deleter(const_cast<DLManagedTensor*>(dlmt));
}
// Checks whether the stride array matches the layout of compact, row-major
// data.
bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
int ndim) {
if (ndim >= 1 && stride_arr[ndim - 1] != 1) {
return false;
}
for (int i = ndim - 2; i >= 0; --i) {
if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) {
return false;
}
}
return true;
}
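// Worked example for the check above: a compact row-major tensor of shape
// [2, 3, 4] has strides [12, 4, 1]; the innermost stride is 1 and each
// stride[i] == shape[i + 1] * stride[i + 1], so the check passes. A transposed
// layout with strides [1, 2, 6] for the same shape fails it.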
} // namespace
void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
DLManagedTensor* dlMTensor = static_cast<DLManagedTensor*>(dlm_ptr);
if (dlMTensor->deleter != nullptr) {
dlMTensor->deleter(dlMTensor);
}
}
void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
const Tensor* tensor = GetTensorFromHandle(h, status);
TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
TensorReference tensor_ref(*tensor); // This will call buf_->Ref()
auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
tf_dlm_tensor_ctx->reference = tensor_ref;
DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
dlm_tensor->deleter = &DLManagedTensorDeleter;
dlm_tensor->dl_tensor.ctx = GetDlContext(h, status);
int ndim = tensor->dims();
dlm_tensor->dl_tensor.ndim = ndim;
dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
dlm_tensor->dl_tensor.dtype = GetDlDataType(data_type, status);
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
shape_arr->resize(ndim);
stride_arr->resize(ndim, 1);
for (int i = 0; i < ndim; i++) {
(*shape_arr)[i] = tensor->dim_size(i);
}
for (int i = ndim - 2; i >= 0; --i) {
(*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
}
dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
  // There are two ways to represent compact row-major data:
  // 1) a nullptr strides array indicates the tensor is compact and row-major;
  // 2) the strides array is filled with the actual strides of the compact,
  //    row-major data.
  // Here we choose option 2, since some frameworks do not handle a nullptr
  // strides argument properly.
dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
  dlm_tensor->dl_tensor.byte_offset =
      0;  // TF tensors are always compact, so the byte offset is 0.
return static_cast<void*>(dlm_tensor);
}
TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
DLTensor* dl_tensor = &dlmt->dl_tensor;
absl::optional<std::string> device_name =
DeviceNameFromDlContext(dl_tensor->ctx, status);
if (!device_name.has_value()) {
status->status =
tensorflow::errors::InvalidArgument("Unsupported Device Type");
return nullptr;
}
TF_DataType dtype;
  Status s = TfDataTypeFromDlDataType(dl_tensor->dtype, &dtype);
if (!s.ok()) {
status->status = std::move(s);
return nullptr;
}
int num_dims = dl_tensor->ndim;
const int64_t* dims = dl_tensor->shape;
void* data = dl_tensor->data;
size_t total_bytes = dl_tensor->dtype.bits / 8;
for (int i = 0; i < num_dims; i++) {
total_bytes *= dims[i];
}
if (dl_tensor->strides != nullptr &&
!IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides,
num_dims)) {
status->status = tensorflow::errors::InvalidArgument(
"Invalid strides array from DLPack");
return nullptr;
}
TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
      total_bytes, &DeallocatorWrapperFunc, dlmt, status);
return handle;
}
} // namespace tensorflow


@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_DLPACK_H_
#define TENSORFLOW_C_EAGER_DLPACK_H_
#include "tensorflow/c/eager/c_api.h"
namespace tensorflow {
// PyCapsule name for DLPack Tensor
const char* const kDlTensorCapsuleName = "dltensor";
// Converts an eager tensor handle to DLPack (DLManagedTensor*), and returns
// the void* for further PyCapsule construction.
TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h,
TF_Status* status);
// Converts DLPack (DLManagedTensor*) to eager tensor handle.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm,
TF_Status* status);
// Calls the deleter of DLManagedTensor, used in the destructor of PyCapsule.
TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_DLPACK_H_
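// Editor's sketch (illustrative only, not part of this change): a minimal
// round trip through the three calls declared above. The handle `h` and the
// caller-provided `status` are assumed to exist; in Python bindings the
// returned pointer would normally be wrapped in a PyCapsule named
// kDlTensorCapsuleName.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/dlpack.h"

void RoundTripThroughDlPack(TFE_TensorHandle* h, TF_Status* status) {
  void* dlm = tensorflow::TFE_HandleToDLPack(h, status);  // DLManagedTensor*
  if (TF_GetCode(status) != TF_OK) return;
  TFE_TensorHandle* imported = tensorflow::TFE_HandleFromDLPack(dlm, status);
  if (TF_GetCode(status) != TF_OK) {
    // The DLManagedTensor was never consumed, so run its deleter explicitly.
    tensorflow::TFE_CallDLManagedTensorDeleter(dlm);
    return;
  }
  // Releasing the imported handle eventually runs the DLPack deleter.
  TFE_DeleteTensorHandle(imported);
}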


@ -0,0 +1,312 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/operation_interface.h"
#include "absl/container/fixed_array.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
OperationInterface::OperationInterface(TFE_Context* ctx)
: operation_(ctx->context) {}
const string& OperationInterface::DeviceName() const {
absl::variant<Device*, CustomDevice*> variant_device =
(operation_.Device() == kVariantDeviceNull)
? operation_.EagerContext().HostCPU()
: operation_.Device();
return absl::visit([](auto* d) -> const string& { return d->name(); },
variant_device);
}
Status OperationInterface::SetDeviceName(const char* name) {
return operation_.SetDeviceName(name);
}
Status OperationInterface::SetAttrString(const char* attr_name,
const char* data, size_t length) {
operation_.MutableAttrs()->Set(attr_name, StringPiece(data, length));
return Status::OK();
}
Status OperationInterface::SetAttrInt(const char* attr_name, int64_t value) {
operation_.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
return Status::OK();
}
Status OperationInterface::SetAttrFloat(const char* attr_name, float value) {
operation_.MutableAttrs()->Set(attr_name, value);
return Status::OK();
}
Status OperationInterface::SetAttrBool(const char* attr_name, bool value) {
operation_.MutableAttrs()->Set(attr_name, value);
return Status::OK();
}
Status OperationInterface::SetAttrType(const char* attr_name,
TF_DataType value) {
operation_.MutableAttrs()->Set(attr_name, static_cast<DataType>(value));
return Status::OK();
}
Status OperationInterface::SetAttrShape(const char* attr_name,
const int64_t* dims,
const int num_dims) {
if (num_dims > TensorShape::MaxDimensions()) {
return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
num_dims,
" dimensions which is over the limit of ",
TensorShape::MaxDimensions(), ".");
}
TensorShapeProto proto;
if (num_dims < 0) {
proto.set_unknown_rank(true);
} else {
for (int d = 0; d < num_dims; ++d) {
proto.add_dim()->set_size(dims[d]);
}
}
operation_.MutableAttrs()->Set(attr_name, proto);
return Status::OK();
}
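// Editor's note (hypothetical calls, not part of this change): passing a
// negative num_dims marks the shape attribute as unknown-rank, while a -1 in
// an individual dimension leaves only that dimension unknown, e.g.
//   op->SetAttrShape("shape", nullptr, -1);   // unknown rank
//   int64_t dims[] = {2, -1};
//   op->SetAttrShape("shape", dims, 2);       // rank 2, second dim unknown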
Status OperationInterface::SetAttrFunction(
const char* attr_name,
const std::unique_ptr<AbstractOperationInterface>& value) {
AttrValue attr_value;
NameAttrList* func = attr_value.mutable_func();
func->set_name(value->Name());
OperationInterface* value_operation =
tensorflow::down_cast<OperationInterface*>(value.get());
value_operation->operation_.Attrs().FillAttrValueMap(func->mutable_attr());
operation_.MutableAttrs()->Set(attr_name, attr_value);
return Status::OK();
}
Status OperationInterface::SetAttrFunctionName(const char* attr_name,
const char* data,
size_t length) {
AttrValue attr_value;
NameAttrList* func = attr_value.mutable_func();
func->set_name(data, length);
operation_.MutableAttrs()->Set(attr_name, attr_value);
return Status::OK();
}
Status OperationInterface::SetAttrTensor(const char* attr_name,
TF_Tensor* tensor) {
Tensor t;
TF_RETURN_IF_ERROR(TF_TensorToTensor(tensor, &t));
operation_.MutableAttrs()->Set(attr_name, t);
return Status::OK();
}
Status OperationInterface::SetAttrStringList(const char* attr_name,
const void* const* values,
const size_t* lengths,
int num_values) {
std::vector<StringPiece> v(num_values);
for (int i = 0; i < num_values; ++i) {
v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
}
operation_.MutableAttrs()->Set(attr_name, v);
return Status::OK();
}
Status OperationInterface::SetAttrFloatList(const char* attr_name,
const float* values,
int num_values) {
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const float>(values, num_values));
return Status::OK();
}
Status OperationInterface::SetAttrIntList(const char* attr_name,
const int64_t* values,
int num_values) {
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const int64>(
reinterpret_cast<const int64*>(values), num_values));
return Status::OK();
}
Status OperationInterface::SetAttrTypeList(const char* attr_name,
const TF_DataType* values,
int num_values) {
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const DataType>(
reinterpret_cast<const DataType*>(values), num_values));
return Status::OK();
}
Status OperationInterface::SetAttrBoolList(const char* attr_name,
const unsigned char* values,
int num_values) {
std::unique_ptr<bool[]> b(new bool[num_values]);
for (int i = 0; i < num_values; ++i) {
b[i] = values[i];
}
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const bool>(b.get(), num_values));
return Status::OK();
}
Status OperationInterface::SetAttrShapeList(const char* attr_name,
const int64_t** dims,
const int* num_dims,
int num_values) {
std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
for (int i = 0; i < num_values; ++i) {
const auto num_dims_i = num_dims[i];
if (num_dims_i > TensorShape::MaxDimensions()) {
return errors::InvalidArgument(
strings::StrCat("Value specified for `", attr_name, "` has ",
num_dims_i, " dimensions which is over the limit of ",
TensorShape::MaxDimensions(), "."));
}
if (num_dims_i < 0) {
proto[i].set_unknown_rank(true);
} else {
const int64_t* dims_i = dims[i];
auto proto_i = &proto[i];
for (int d = 0; d < num_dims_i; ++d) {
proto_i->add_dim()->set_size(dims_i[d]);
}
}
}
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
return Status::OK();
}
Status OperationInterface::SetAttrFunctionList(const char* attr_name,
const TFE_Op** value,
int num_values) {
std::unique_ptr<NameAttrList[]> funcs(new NameAttrList[num_values]);
for (int i = 0; i < num_values; i++) {
auto value_operation =
tensorflow::down_cast<OperationInterface*>(value[i]->operation.get());
funcs[i].set_name(value_operation->operation_.Name());
value_operation->operation_.Attrs().FillAttrValueMap(
funcs[i].mutable_attr());
}
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const NameAttrList>(funcs.get(), num_values));
return Status::OK();
}
const OpDef* OperationInterface::GetOpDef(Status* status) {
const tensorflow::OpDef* op_def = operation_.OpDef();
if (op_def) return op_def;
*status = OpDefForOp(Name(), &op_def);
return op_def;
}
Status OperationInterface::InputLength(const char* input_name, int* length) {
Status status;
const tensorflow::OpDef* op_def = GetOpDef(&status);
if (!status.ok()) {
return status;
}
AttrValueMap attrs;
operation_.Attrs().FillAttrValueMap(&attrs);
NameRangeMap name_ranges;
TF_RETURN_IF_ERROR(
NameRangesForNode(AttrSlice(&attrs), *op_def, &name_ranges, nullptr));
auto iter = name_ranges.find(input_name);
if (iter == name_ranges.end()) {
return errors::InvalidArgument("Input '", input_name, "' not found");
}
*length = iter->second.second - iter->second.first;
return Status::OK();
}
Status OperationInterface::OutputLength(const char* output_name, int* length) {
Status status;
const tensorflow::OpDef* op_def = GetOpDef(&status);
if (!status.ok()) {
return status;
}
AttrValueMap attrs;
operation_.Attrs().FillAttrValueMap(&attrs);
NameRangeMap name_ranges;
TF_RETURN_IF_ERROR(
NameRangesForNode(AttrSlice(&attrs), *op_def, nullptr, &name_ranges));
auto iter = name_ranges.find(output_name);
if (iter == name_ranges.end()) {
return errors::InvalidArgument("Output '", output_name, "' not found");
}
*length = iter->second.second - iter->second.first;
return Status::OK();
}
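// Editor's note (not part of this change): NameRangesForNode returns half-open
// [start, end) index ranges, so the two lookups above report a length of
// end - start. For example, an "AddN" op whose op def declares "inputs: N * T"
// with N == 3 yields the range {0, 3} for "inputs" and an InputLength of 3.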
Status OperationInterface::AddInput(
const std::unique_ptr<AbstractTensorHandleInterface>& input) {
TensorHandle* h =
tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
operation_.AddInput(h);
return operation_.MaybeInferSingleInputAttrs(h);
}
Status OperationInterface::AddInputList(
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
inputs) {
for (auto& input : inputs) {
TensorHandle* h =
tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
operation_.AddInput(h);
}
return operation_.InferInputListAttrs(inputs.size());
}
Status OperationInterface::Execute(
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
int* num_retvals) {
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
TF_RETURN_IF_ERROR(
EagerExecute(&operation_, handle_retvals.data(), num_retvals));
for (int i = 0; i < *num_retvals; ++i) {
retvals->at(i).reset(
new tensorflow::TensorHandleInterface(handle_retvals[i]));
}
return Status::OK();
}
Status OperationInterface::SetCancellationManager(
TFE_CancellationManager* cancellation_manager) {
operation_.SetCancellationManager(
&cancellation_manager->cancellation_manager);
return Status::OK();
}
Status OperationInterface::SetUseXla(bool enable) {
operation_.SetUseXla(enable);
return Status::OK();
}
} // namespace tensorflow


@ -0,0 +1,188 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
#define TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
#include <memory>
#include "absl/container/fixed_array.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
// Abstract interface to an operation.
class AbstractOperationInterface {
public:
virtual ~AbstractOperationInterface() {}
virtual void Clear() = 0;
virtual tensorflow::Status Reset(const char* op,
const char* raw_device_name) = 0;
virtual const tensorflow::string& Name() const = 0;
virtual const tensorflow::string& DeviceName() const = 0;
virtual tensorflow::Status SetDeviceName(const char* name) = 0;
virtual tensorflow::Status AddInput(
const std::unique_ptr<AbstractTensorHandleInterface>& input) = 0;
virtual tensorflow::Status AddInputList(
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
inputs) = 0;
virtual tensorflow::Status Execute(
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
int* num_retvals) = 0;
virtual const tensorflow::OpDef* OpDef() const = 0;
virtual tensorflow::Status SetAttrString(const char* attr_name,
const char* data, size_t length) = 0;
virtual tensorflow::Status SetAttrInt(const char* attr_name,
int64_t value) = 0;
virtual tensorflow::Status SetAttrFloat(const char* attr_name,
float value) = 0;
virtual tensorflow::Status SetAttrBool(const char* attr_name, bool value) = 0;
virtual tensorflow::Status SetAttrType(const char* attr_name,
TF_DataType value) = 0;
virtual tensorflow::Status SetAttrShape(const char* attr_name,
const int64_t* dims,
const int num_dims) = 0;
virtual tensorflow::Status SetAttrFunction(
const char* attr_name,
const std::unique_ptr<AbstractOperationInterface>& value) = 0;
virtual tensorflow::Status SetAttrFunctionName(const char* attr_name,
const char* value,
size_t length) = 0;
virtual tensorflow::Status SetAttrTensor(const char* attr_name,
TF_Tensor* tensor) = 0;
virtual tensorflow::Status SetAttrStringList(const char* attr_name,
const void* const* values,
const size_t* lengths,
int num_values) = 0;
virtual tensorflow::Status SetAttrFloatList(const char* attr_name,
const float* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrIntList(const char* attr_name,
const int64_t* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrTypeList(const char* attr_name,
const TF_DataType* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrBoolList(const char* attr_name,
const unsigned char* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrShapeList(const char* attr_name,
const int64_t** dims,
const int* num_dims,
int num_values) = 0;
virtual tensorflow::Status SetAttrFunctionList(const char* attr_name,
const TFE_Op** value,
int num_values) = 0;
virtual tensorflow::Status InputLength(const char* input_name,
int* length) = 0;
virtual tensorflow::Status OutputLength(const char* output_name,
int* length) = 0;
// Experimental
virtual tensorflow::Status SetUseXla(bool enable) {
return tensorflow::errors::Unimplemented("SetUseXla not implemented");
}
virtual tensorflow::Status SetCancellationManager(
TFE_CancellationManager* cancellation_manager) {
return tensorflow::errors::Unimplemented(
"SetCancellationManager not implemented");
}
};
namespace tensorflow {
class OpDef;
class OperationInterface : public AbstractOperationInterface {
public:
explicit OperationInterface(TFE_Context* ctx);
  ~OperationInterface() override {}
void Clear() override { operation_.Clear(); }
Status Reset(const char* op, const char* raw_device_name) override {
return operation_.Reset(op, raw_device_name, false, nullptr);
}
const string& Name() const override { return operation_.Name(); }
const string& DeviceName() const override;
Status SetDeviceName(const char* name) override;
Status AddInput(
const std::unique_ptr<AbstractTensorHandleInterface>& input) override;
Status AddInputList(
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
inputs) override;
Status Execute(
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
int* num_retvals) override;
const tensorflow::OpDef* OpDef() const override {
return operation_.OpDef();
};
Status SetAttrString(const char* attr_name, const char* data,
size_t length) override;
Status SetAttrInt(const char* attr_name, int64_t value) override;
Status SetAttrFloat(const char* attr_name, float value) override;
Status SetAttrBool(const char* attr_name, bool value) override;
Status SetAttrType(const char* attr_name, TF_DataType value) override;
Status SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) override;
Status SetAttrFunction(
const char* attr_name,
const std::unique_ptr<AbstractOperationInterface>& value) override;
Status SetAttrFunctionName(const char* attr_name, const char* data,
size_t length) override;
Status SetAttrTensor(const char* attr_name, TF_Tensor* tensor) override;
Status SetAttrStringList(const char* attr_name, const void* const* values,
const size_t* lengths, int num_values) override;
Status SetAttrFloatList(const char* attr_name, const float* values,
int num_values) override;
Status SetAttrIntList(const char* attr_name, const int64_t* values,
int num_values) override;
Status SetAttrTypeList(const char* attr_name, const TF_DataType* values,
int num_values) override;
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
int num_values) override;
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
const int* num_dims, int num_values) override;
Status SetAttrFunctionList(const char* attr_name, const TFE_Op** value,
int num_values) override;
Status InputLength(const char* input_name, int* length) override;
Status OutputLength(const char* output_name, int* length) override;
Status SetUseXla(bool enable) override;
Status SetCancellationManager(
TFE_CancellationManager* cancellation_manager) override;
// TODO(gjn): Remove once TFE_InferShapes is removed
const tensorflow::AttrBuilder& Attrs() const { return operation_.Attrs(); }
tensorflow::AttrBuilder* MutableAttrs() { return operation_.MutableAttrs(); }
const TensorHandle* GetInput(int i) const { return operation_.Inputs()[i]; }
private:
const tensorflow::OpDef* GetOpDef(Status* status);
EagerOperation operation_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
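// Editor's sketch (illustrative only, not part of this change): roughly how a
// caller could drive OperationInterface to run a single "Identity" op. The
// context `ctx` and the `input` handle are assumed to come from elsewhere, and
// TF_RETURN_IF_ERROR keeps the error handling short.
#include <memory>
#include <utility>

#include "absl/container/fixed_array.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/core/platform/errors.h"

tensorflow::Status RunIdentityOnce(
    TFE_Context* ctx,
    const std::unique_ptr<AbstractTensorHandleInterface>& input,
    std::unique_ptr<AbstractTensorHandleInterface>* output) {
  tensorflow::OperationInterface op(ctx);
  TF_RETURN_IF_ERROR(op.Reset("Identity", /*raw_device_name=*/nullptr));
  TF_RETURN_IF_ERROR(op.SetAttrType("T", TF_FLOAT));
  TF_RETURN_IF_ERROR(op.AddInput(input));
  int num_retvals = 1;
  absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>> retvals(
      num_retvals);
  TF_RETURN_IF_ERROR(op.Execute(&retvals, &num_retvals));
  *output = std::move(retvals[0]);
  return tensorflow::Status::OK();
}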


@ -27,14 +27,10 @@ namespace {
class DummyDevice : public DeviceBase {
public:
DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
bool RequiresRecordingAccessedTensors() const override { return save_; }
explicit DummyDevice(Env* env) : DeviceBase(env) {}
Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
return cpu_allocator();
}
private:
bool save_;
};
void TestBitcastOp(Tensor* input_tensor, DataType out_type,
@ -61,7 +57,7 @@ void TestBitcastOp(Tensor* input_tensor, DataType out_type,
ASSERT_TRUE(status.ok()) << status.ToString();
OpKernelContext::Params params;
DummyDevice dummy_device(nullptr, false);
DummyDevice dummy_device(nullptr);
params.device = &dummy_device;
params.op_kernel = kernel.get();
gtl::InlinedVector<TensorValue, 4> inputs;


@ -155,14 +155,10 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
class DummyDevice : public DeviceBase {
public:
DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
bool RequiresRecordingAccessedTensors() const override { return save_; }
explicit DummyDevice(Env* env) : DeviceBase(env) {}
Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
return cpu_allocator();
}
private:
bool save_;
};
TEST(TestKernel, TestInputAndOutputCount) {
@ -223,7 +219,7 @@ TEST(TestKernel, TestInputAndOutputCount) {
{
OpKernelContext::Params p;
DummyDevice dummy_device(nullptr, false);
DummyDevice dummy_device(nullptr);
p.device = &dummy_device;
p.step_id = 43;


@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h"
#include <memory>
#include <vector>
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
@ -64,25 +65,41 @@ void deallocate_buffer(void* data, size_t len, void* arg) {
}
} // namespace tensorflow
namespace {
TF_Tensor* CreateTensor(TF_ManagedBuffer* buf, TF_DataType dtype,
const int64_t* dims, int num_dims, size_t len) {
std::vector<tensorflow::int64> dimvec(num_dims);
for (int i = 0; i < num_dims; ++i) {
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
}
// TODO(gjn): Make the choice of interface a compile-time configuration.
tensorflow::TensorInterface ret(
Tensor(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf));
buf->Unref();
size_t elem_size = TF_DataTypeSize(dtype);
if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
return nullptr;
}
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
}
} // namespace
TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
int num_dims, size_t len) {
void* data = tensorflow::allocate_tensor("TF_AllocateTensor", len,
tensorflow::cpu_allocator());
return TF_NewTensor(dtype, dims, num_dims, data, len,
tensorflow::deallocate_buffer,
tensorflow::cpu_allocator());
TF_ManagedBuffer* buf =
new TF_ManagedBuffer(data, len, tensorflow::deallocate_buffer,
tensorflow::cpu_allocator(), /*owns_memory=*/true);
return CreateTensor(buf, dtype, dims, num_dims, len);
}
TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
void* data, size_t len,
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg) {
std::vector<tensorflow::int64> dimvec(num_dims);
for (int i = 0; i < num_dims; ++i) {
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
}
TF_ManagedBuffer* buf = nullptr;
if (dtype != TF_STRING && dtype != TF_RESOURCE &&
tensorflow::DataTypeCanUseMemcpy(
@ -97,24 +114,17 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
// Other types have the same representation, so copy only if it is safe to
// do so.
buf = new TF_ManagedBuffer(tensorflow::allocate_tensor("TF_NewTensor", len),
len, tensorflow::deallocate_buffer, nullptr);
len, tensorflow::deallocate_buffer, nullptr,
/*owns_memory=*/true);
std::memcpy(buf->data(), data, len);
// Free the original buffer.
deallocator(data, len, deallocator_arg);
} else {
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
/*owns_memory=*/false);
}
// TODO(gjn): Make the choice of interface a compile-time configuration.
tensorflow::TensorInterface ret(
Tensor(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf));
buf->Unref();
size_t elem_size = TF_DataTypeSize(dtype);
if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
return nullptr;
}
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
return CreateTensor(buf, dtype, dims, num_dims, len);
}
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* t) {


@ -58,9 +58,9 @@ extern "C" {
// start_offset: array[uint64]
// data: byte[...]
//
// The string length (as a varint), followed by the contents of the string
// is encoded at data[start_offset[i]]]. TF_StringEncode and TF_StringDecode
// facilitate this encoding.
// The string length (as a varint, start_offset[i + 1] - start_offset[i]),
// followed by the contents of the string is encoded at data[start_offset[i]].
// TF_StringEncode and TF_StringDecode facilitate this encoding.
typedef struct TF_Tensor TF_Tensor;
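// Editor's illustration (not part of this change): a TF_STRING tensor holding
// the two elements "hi" and "tensorflow" would be laid out as
//   start_offset: [0, 3]
//   data:         0x02 'h' 'i' 0x0a 't' 'e' 'n' 's' 'o' 'r' 'f' 'l' 'o' 'w'
// where 0x02 and 0x0a are the varint-encoded lengths 2 and 10, and element i
// begins at data[start_offset[i]].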


@ -38,11 +38,12 @@ class TF_ManagedBuffer : public tensorflow::TensorBuffer {
public:
TF_ManagedBuffer(void* data, size_t len,
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg)
void* deallocator_arg, bool owns_memory)
: TensorBuffer(data),
len_(len),
deallocator_(deallocator),
deallocator_arg_(deallocator_arg) {}
deallocator_arg_(deallocator_arg),
owns_memory_(owns_memory) {}
~TF_ManagedBuffer() override {
(*deallocator_)(data(), len_, deallocator_arg_);
@ -57,13 +58,13 @@ class TF_ManagedBuffer : public tensorflow::TensorBuffer {
proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
}
// Prevents input forwarding from mutating this buffer.
bool OwnsMemory() const override { return false; }
bool OwnsMemory() const override { return owns_memory_; }
private:
const size_t len_;
void (*const deallocator_)(void* data, size_t len, void* arg);
void* const deallocator_arg_;
bool owns_memory_;
};
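// Editor's note (illustrative names, not part of this change): the two
// construction modes introduced above look roughly like
//   new TF_ManagedBuffer(user_data, len, user_deallocator, user_arg,
//                        /*owns_memory=*/false);  // wraps caller-owned memory
//   new TF_ManagedBuffer(fresh_data, len, tensorflow::deallocate_buffer,
//                        nullptr, /*owns_memory=*/true);  // adopts the buffer
// Returning false from OwnsMemory() is what prevents input forwarding from
// mutating a wrapped, caller-owned buffer.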
namespace tensorflow {


@ -41,7 +41,7 @@ class ClientSession::Impl {
std::shared_ptr<Graph> graph_;
mutable mutex mu_;
mutable int last_num_graph_nodes_ GUARDED_BY(mu_) = 0;
mutable int last_num_graph_nodes_ TF_GUARDED_BY(mu_) = 0;
};
ClientSession::ClientSession(const Scope& scope, const string& target)


@ -68,6 +68,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/platform:resource_loader",
],
)
@ -224,3 +225,15 @@ filegroup(
"testdata/VarsAndArithmeticObjectGraph/**",
]),
)
exports_files(
glob([
"testdata/half_plus_two_pbtxt/**",
"testdata/half_plus_two_main_op/**",
"testdata/half_plus_two/**",
"testdata/half_plus_two_v2/**",
"testdata/x_plus_y_v2_debuginfo/**",
"testdata/CyclicModule/**",
"testdata/VarsAndArithmeticObjectGraph/**",
]),
)


@ -21,15 +21,22 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
constexpr char kTestDataPbTxt[] =
"cc/saved_model/testdata/half_plus_two_pbtxt/00000123";
constexpr char kTestDataSharded[] =
"cc/saved_model/testdata/half_plus_two/00000123";
string TestDataPbTxt() {
return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
"half_plus_two_pbtxt", "00000123");
}
string TestDataSharded() {
return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
"half_plus_two", "00000123");
}
class ReaderTest : public ::testing::Test {
protected:
@ -49,8 +56,7 @@ class ReaderTest : public ::testing::Test {
TEST_F(ReaderTest, TagMatch) {
MetaGraphDef meta_graph_def;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
&meta_graph_def));
CheckMetaGraphDef(meta_graph_def);
@ -59,8 +65,7 @@ TEST_F(ReaderTest, TagMatch) {
TEST_F(ReaderTest, NoTagMatch) {
MetaGraphDef meta_graph_def;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"},
&meta_graph_def);
EXPECT_FALSE(st.ok());
@ -73,8 +78,7 @@ TEST_F(ReaderTest, NoTagMatch) {
TEST_F(ReaderTest, NoTagMatchMultiple) {
MetaGraphDef meta_graph_def;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
Status st = ReadMetaGraphDefFromSavedModel(
export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);
EXPECT_FALSE(st.ok());
@ -87,8 +91,7 @@ TEST_F(ReaderTest, NoTagMatchMultiple) {
TEST_F(ReaderTest, PbtxtFormat) {
MetaGraphDef meta_graph_def;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPbTxt);
const string export_dir = GetDataDependencyFilepath(TestDataPbTxt());
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
&meta_graph_def));
CheckMetaGraphDef(meta_graph_def);
@ -97,8 +100,7 @@ TEST_F(ReaderTest, PbtxtFormat) {
TEST_F(ReaderTest, InvalidExportPath) {
MetaGraphDef meta_graph_def;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path");
const string export_dir = GetDataDependencyFilepath("missing-path");
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
&meta_graph_def);
EXPECT_FALSE(st.ok());


@ -114,14 +114,14 @@ class Coordinator {
condition_variable wait_for_stop_;
mutex mu_;
bool should_stop_ GUARDED_BY(mu_);
bool should_stop_ TF_GUARDED_BY(mu_);
mutex status_lock_;
Status status_ GUARDED_BY(status_lock_);
Status status_ TF_GUARDED_BY(status_lock_);
mutable mutex runners_lock_;
std::vector<std::unique_ptr<RunnerInterface>> runners_
GUARDED_BY(runners_lock_);
TF_GUARDED_BY(runners_lock_);
TF_DISALLOW_COPY_AND_ASSIGN(Coordinator);
};


@ -119,8 +119,8 @@ class QueueRunner : public RunnerInterface {
std::unique_ptr<thread::ThreadPool> thread_pool_;
mutex mu_;
int runs_ = 0;
Status status_ GUARDED_BY(mu_);
Status enqueue_status_ GUARDED_BY(mu_);
Status status_ TF_GUARDED_BY(mu_);
Status enqueue_status_ TF_GUARDED_BY(mu_);
std::unique_ptr<BlockingCounter> counter_;
Coordinator* coord_;
@ -131,7 +131,7 @@ class QueueRunner : public RunnerInterface {
std::vector<std::function<void(Status)>> callbacks_;
mutable std::unique_ptr<mutex> cg_mu_;
std::unique_ptr<CostGraphDef> cost_graph_ GUARDED_BY(cg_mu_);
std::unique_ptr<CostGraphDef> cost_graph_ TF_GUARDED_BY(cg_mu_);
RunOptions run_options_;
};


@ -20,9 +20,11 @@ from __future__ import print_function as _print_function
import logging as _logging
import os as _os
import six as _six
import sys as _sys
from tensorflow.python.tools import module_util as _module_util
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
# pylint: disable=g-bad-import-order
@ -36,20 +38,19 @@ try:
from tensorboard.summary._tf import summary
_current_module.__path__ = (
[_module_util.get_parent_dir(summary)] + _current_module.__path__)
# Make sure we get the correct summary module with lazy loading
setattr(_current_module, "summary", summary)
except ImportError:
_logging.warning(
"Limited tf.compat.v2.summary API due to missing TensorBoard "
"installation.")
try:
from tensorflow_estimator.python.estimator.api._v2 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
setattr(_current_module, "estimator", estimator)
except ImportError:
pass
# Lazy-load estimator.
_estimator_module = "tensorflow_estimator.python.estimator.api._v2.estimator"
estimator = _LazyLoader("estimator", globals(), _estimator_module)
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
if _module_dir:
_current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "estimator", estimator)
try:
from tensorflow.python.keras.api._v2 import keras
@ -59,6 +60,13 @@ try:
except ImportError:
pass
# Explicitly import lazy-loaded modules to support autocompletion.
# pylint: disable=g-import-not-at-top
if not _six.PY2:
import typing as _typing
if _typing.TYPE_CHECKING:
from tensorflow_estimator.python.estimator.api._v2 import estimator
# pylint: enable=g-import-not-at-top
# We would like the following to work for fully enabling 2.0 in a 1.0 install:
#


@ -20,8 +20,10 @@ from __future__ import print_function as _print_function
import os as _os
import sys as _sys
import six as _six
from tensorflow.python.tools import module_util as _module_util
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
# pylint: disable=g-bad-import-order
@ -31,13 +33,14 @@ from tensorflow.python.tools import module_util as _module_util
# Hook external TensorFlow modules.
_current_module = _sys.modules[__name__]
try:
from tensorflow_estimator.python.estimator.api._v1 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
setattr(_current_module, "estimator", estimator)
except ImportError:
pass
# Lazy-load estimator.
_estimator_module = "tensorflow_estimator.python.estimator.api._v1.estimator"
estimator = _LazyLoader("estimator", globals(), _estimator_module)
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
if _module_dir:
_current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "estimator", estimator)
try:
from tensorflow.python.keras.api._v1 import keras
@ -47,6 +50,14 @@ try:
except ImportError:
pass
# Explicitly import lazy-loaded modules to support autocompletion.
# pylint: disable=g-import-not-at-top
if not _six.PY2:
import typing as _typing
if _typing.TYPE_CHECKING:
from tensorflow_estimator.python.estimator.api._v1 import estimator
# pylint: enable=g-import-not-at-top
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
_current_module.app.flags = flags # pylint: disable=undefined-variable


@ -37,6 +37,7 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"//tensorflow/compiler/mlir/lite/quantization/xla:quantize",
"//tensorflow/compiler/tf2xla",
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
@ -64,6 +65,7 @@ cc_library(
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
"//tensorflow/core:regexp_internal",
] + if_llvm_aarch64_available([
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
]),
@ -84,6 +86,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:resource_loader",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support", # fixdeps: keep
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep


@ -26,6 +26,7 @@ limitations under the License.
#include "absl/strings/str_split.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/service/compiler.h"
@ -288,8 +289,8 @@ Status GenVariableMethods(const tf2xla::Config& config,
}
// Generates code implementing {Arg,Result}Names(), where T is one of
// tf2xla::{Feed,Fetch}. Each feed or fetch name results in a C-style string
// literal in the array, with nullptr terminating the array.
// tf2xla::{Feed,Fetch,Variable}. Each feed or fetch name results in a C-style
// string literal in the array, with nullptr terminating the array.
template <typename T>
string GenNameToIndexCode(const T& entries, bool generate) {
// No need for a static array if we're not supposed to generate the data.
@ -419,6 +420,16 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
// Generate metadata.
const string arg_names_code =
GenNameToIndexCode(config.feed(), opts.gen_name_to_index);
auto variable_copy = config.variable();
for (auto& var : variable_copy) {
if (var.name().empty()) {
var.set_name(var.node_name());
}
}
const string variable_names_code =
GenNameToIndexCode(variable_copy, opts.gen_name_to_index);
const string result_names_code =
GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
const string include_xla_data_proto =
@ -507,6 +518,9 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
// Number of input arguments for the compiled computation.
static constexpr size_t kNumArgs = {{ARG_NUM}};
// Number of variables for the compiled computation.
static constexpr size_t kNumVariables = {{VARIABLE_NUM}};
// Byte size of each argument buffer. There are kNumArgs entries.
static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
return BufferInfos()[ArgIndexToBufferIndex()[index]].size();
@ -522,8 +536,10 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
set_static_data_num_buffers(data, kNumBuffers);
set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
set_static_data_num_args(data, kNumArgs);
set_static_data_num_variables(data, kNumVariables);
set_static_data_result_index(data, kResultIndex);
set_static_data_arg_names(data, StaticArgNames());
set_static_data_variable_names(data, StaticVariableNames());
set_static_data_result_names(data, StaticResultNames());
set_static_data_program_shape(data, StaticProgramShape());
set_static_data_hlo_profile_printer_data(
@ -626,6 +642,9 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
// Array of names of each positional argument, terminated by nullptr.
static const char** StaticArgNames() {{ARG_NAMES_CODE}}
// Array of names of each positional variable, terminated by nullptr.
static const char** StaticVariableNames() {{VARIABLE_NAMES_CODE}}
// Array of names of each positional result, terminated by nullptr.
static const char** StaticResultNames() {{RESULT_NAMES_CODE}}
@ -654,6 +673,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
{"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)},
{"{{ARG_NAMES_CODE}}", arg_names_code},
{"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())},
{"{{VARIABLE_NUM}}", absl::StrCat(config.variable_size())},
{"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")},
{"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
{"{{CLASS}}", opts.class_name},
@ -673,6 +693,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))},
{"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
metadata_result.program_shape_access_shim},
{"{{VARIABLE_NAMES_CODE}}", variable_names_code},
{"{{RESULT_INDEX}}", absl::StrCat(result_index)},
{"{{RESULT_NAMES_CODE}}", result_names_code},
{"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)},


@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/codegen.h"
#include <algorithm>
#include <string>
#include <vector>
@ -29,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@ -139,14 +141,23 @@ TEST_F(ParseCppClassTest, ParseFail) {
static void CompareWithGoldenFile(
const string& tensorflow_relative_golden_file_name,
const string& expected_contents) {
const string& expected_contents, bool ignore_cr) {
  // Get rid of all CR characters; we may be running under Windows.
string sanitized_expected_contents(expected_contents);
if (ignore_cr) {
sanitized_expected_contents.erase(
std::remove(sanitized_expected_contents.begin(),
sanitized_expected_contents.end(), '\r'),
sanitized_expected_contents.end());
}
// To update the golden file, flip update_golden to true and run the
// following:
// bazel test --test_strategy=local \
// third_party/tensorflow/compiler/aot:codegen_test
const bool update_golden = false;
const string golden_file_name = io::JoinPath(
testing::TensorFlowSrcRoot(), tensorflow_relative_golden_file_name);
string golden_file_name =
GetDataDependencyFilepath(tensorflow_relative_golden_file_name);
if (update_golden) {
TF_EXPECT_OK(
@ -156,6 +167,11 @@ static void CompareWithGoldenFile(
string golden_file_contents;
TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name,
&golden_file_contents));
if (ignore_cr) {
golden_file_contents.erase(std::remove(golden_file_contents.begin(),
golden_file_contents.end(), '\r'),
golden_file_contents.end());
}
EXPECT_EQ(golden_file_contents, expected_contents);
}
@ -201,10 +217,16 @@ TEST(CodegenTest, Golden) {
{},
{BufferInfo::MakeTempBuffer(1),
BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0),
BufferInfo::MakeTempBuffer(2),
BufferInfo::MakeTempBuffer(1),
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)},
5, {}));
BufferInfo::MakeTempBuffer(1),
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/2),
BufferInfo::MakeTempBuffer(1),
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/3),
BufferInfo::MakeTempBuffer(1),
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/4),
BufferInfo::MakeTempBuffer(1), BufferInfo::MakeTempBuffer(120)},
11, {}));
compile_result.program_shape =
xla::ShapeUtil::MakeProgramShape(
{
@ -229,14 +251,18 @@ TEST(CodegenTest, Golden) {
// The other fields in metadata_result are tested as part of the generated
// header test.
CompareWithGoldenFile("compiler/aot/codegen_test_o.golden",
metadata_result.object_file_data);
// This specific golden test checks a binary file. It can potentially run into
// issues due to ABIs not being stable, but has not so far.
// If we see any ABI issues, we should reconsider this specific test case.
CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_o.golden",
metadata_result.object_file_data, false);
string header;
TF_ASSERT_OK(
GenerateHeader(opts, config, compile_result, metadata_result, &header));
CompareWithGoldenFile("compiler/aot/codegen_test_h.golden", header);
CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_h.golden", header,
true);
}
} // namespace
} // namespace tfcompile


@ -55,14 +55,17 @@ namespace bar {
// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): f32[1], (unknown): f32[1], (unknown): s32[5]) -> (u32[5,6], f32[1], s32[5])
//
// Memory stats:
// arg bytes total: 104
// arg bytes aligned: 192
// arg bytes total: 392
// arg bytes aligned: 576
// temp bytes total: 126
// temp bytes aligned: 320
// temp bytes aligned: 512
class MyClass final : public tensorflow::XlaCompiledCpuFunction {
public:
// Number of input arguments for the compiled computation.
static constexpr size_t kNumArgs = 2;
static constexpr size_t kNumArgs = 5;
// Number of variables for the compiled computation.
static constexpr size_t kNumVariables = 3;
// Byte size of each argument buffer. There are kNumArgs entries.
static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
@ -79,8 +82,10 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
set_static_data_num_buffers(data, kNumBuffers);
set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
set_static_data_num_args(data, kNumArgs);
set_static_data_num_variables(data, kNumVariables);
set_static_data_result_index(data, kResultIndex);
set_static_data_arg_names(data, StaticArgNames());
set_static_data_variable_names(data, StaticVariableNames());
set_static_data_result_names(data, StaticResultNames());
set_static_data_program_shape(data, StaticProgramShape());
set_static_data_hlo_profile_printer_data(
@ -295,16 +300,22 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
private:
// Number of buffers for the compiled computation.
static constexpr size_t kNumBuffers = 6;
static constexpr size_t kNumBuffers = 12;
static const ::xla::cpu_function_runtime::BufferInfo* BufferInfos() {
static const ::xla::cpu_function_runtime::BufferInfo
kBufferInfos[kNumBuffers] = {
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({34ULL, 0ULL}),
::xla::cpu_function_runtime::BufferInfo({9ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({386ULL, 1ULL}),
::xla::cpu_function_runtime::BufferInfo({13ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({386ULL, 2ULL}),
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({386ULL, 3ULL}),
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({386ULL, 4ULL}),
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({481ULL, ~0ULL})
};
return kBufferInfos;
@ -312,13 +323,13 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
static const ::tensorflow::int32* ArgIndexToBufferIndex() {
static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
1, 3
1, 3, 5, 7, 9
};
return kArgIndexToBufferIndex;
}
// The 0-based index of the result tuple in the temporary buffers.
static constexpr size_t kResultIndex = 5;
static constexpr size_t kResultIndex = 11;
// Array of names of each positional argument, terminated by nullptr.
static const char** StaticArgNames() {
@ -326,6 +337,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return kNames;
}
// Array of names of each positional variable, terminated by nullptr.
static const char** StaticVariableNames() {
static const char* kNames[] = {"myvar_readonly", "myvar", "myvar2", nullptr};
return kNames;
}
// Array of names of each positional result, terminated by nullptr.
static const char** StaticResultNames() {
static const char* kNames[] = {"myfetch", nullptr};


@ -24,6 +24,7 @@ limitations under the License.
#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
@ -39,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@ -105,14 +107,18 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
.ValueOrDie();
xla::XlaComputation computation;
if (flags.mlir_components == "Bridge") {
TF_RETURN_IF_ERROR(
ConvertGraphDefToXlaViaMlir(graph_def, config, &computation));
TF_RETURN_IF_ERROR(ConvertGraphDefToXlaViaMlir(
graph_def, config, &computation, flags.debug_info,
flags.debug_info_path_begin_marker));
} else if (flags.mlir_components.empty() || flags.mlir_components == "None") {
TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
client, &computation));
} else {
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
}
if (flags.quantize) {
TF_RETURN_IF_ERROR(mlir::xla_hlo::XlaQuantize(config, &computation));
}
if (!flags.out_session_module.empty()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
computation.Snapshot());
@ -166,6 +172,23 @@ static void InitializeTargets() {
LLVMInitializeX86AsmPrinter();
}
// Replaces {{tag.type tag.name}} in the error message with tag.name.
// TODO(bixia): We currently only handle tag.type == "node".
//
// In the error message, a graph node is represented as {{tag.type tag.name}}
// to allow a Python debugger to insert source information about the graph node.
// For example, a Python add expression may be represented as
// {{node x_y_sum}} = Add(x, y) in the error message. See routine interpolate
// in tensorflow/python/framework/error_interpolation.py for more detail.
static std::string InterpolateErrorMessage(std::string message) {
// See _NAME_REGEX in tensorflow/python/framework/error_interpolation.py
// Change "prefix {{node tag.name}} suffix" to "prefix tag.name suffix".
static LazyRE2 pattern{"(.*){{node (.*)}}(.*)"};
RE2::GlobalReplace(&message, *pattern, "\\1\\2\\3");
return message;
}
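// Editor's illustration (not part of this change): the replacement rewrites
//   "Dimensions must be equal for '{{node x_y_sum}} = Add(x, y)'"
// into
//   "Dimensions must be equal for 'x_y_sum = Add(x, y)'".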
Status Main(const MainFlags& flags) {
absl::call_once(targets_init, &InitializeTargets);
@ -192,8 +215,13 @@ Status Main(const MainFlags& flags) {
GraphDef graph_def;
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
CompileResult compile_result;
TF_RETURN_IF_ERROR(
CompileGraph(std::move(graph_def), config, flags, &compile_result));
Status status =
CompileGraph(std::move(graph_def), config, flags, &compile_result);
if (!status.ok()) {
return Status(status.code(),
InterpolateErrorMessage(status.error_message()));
}
// Write output files.
Env* env = Env::Default();


@ -24,6 +24,13 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
"Input GraphDef file. If the file ends in '.pbtxt' it is expected to "
"be in the human-readable proto text format, otherwise it is expected "
"to be in the proto binary format."},
{"debug_info", &flags->debug_info,
"Graph debug info file. If the file ends in '.pbtxt' it is expected to "
"be in the human-readable proto text format, otherwise it is expected "
"to be in the proto binary format."},
{"debug_info_path_begin_marker", &flags->debug_info_path_begin_marker,
"If not none, only keep the file path in the debug information after the"
" marker. The default value is empty"},
{"config", &flags->config,
"Input file containing Config proto. If the file ends in '.pbtxt' it "
"is expected to be in the human-readable proto text format, otherwise "
@ -70,6 +77,8 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
"Output session module proto."},
{"mlir_components", &flags->mlir_components,
"The MLIR components to enable. Currently only Bridge is supported."},
{"quantize", &flags->quantize,
"If set, quantization will be applied before HLO code generation."},
{"gen_name_to_index", &flags->gen_name_to_index,
"Generate name-to-index data for Lookup{Arg,Result}Index methods."},
{"gen_program_shape", &flags->gen_program_shape,

View File

@ -28,6 +28,8 @@ namespace tfcompile {
struct MainFlags {
string graph;
string debug_info;
string debug_info_path_begin_marker;
string config;
bool dump_fetch_nodes = false;
string target_triple;
@ -40,6 +42,7 @@ struct MainFlags {
string out_header;
string out_session_module;
string mlir_components;
bool quantize = false;
// C++ codegen options
bool gen_name_to_index = false;


@ -1,11 +1,37 @@
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
package(
default_visibility = ["//visibility:private"],
licenses = ["notice"], # Apache 2.0
)
glob_lit_tests(
data = [":filecheck_test_utilities"],
driver = "@llvm-project//mlir:run_lit.sh",
tags_override = {
"test_error_message.lit.pbtxt": ["no_oss"], # TODO(b/150957738): to be fixed on oss.
},
test_file_exts = ["lit.pbtxt"],
)
# Bundle together all of the test utilities that are used by tests.
filegroup(
name = "filecheck_test_utilities",
testonly = True,
srcs = [
"test_error_message.lit.pbtxt.config.pbtxt",
"test_error_message.lit.pbtxt.debug.pbtxt",
"test_error_message.lit.pbtxt.fake_py.debug",
],
data = [
"//tensorflow/compiler/aot:tfcompile",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:not",
],
)
# We disable some tfcompile tests in the open source build with the
# "manual" tag to avoid making our OSS users build LLVM twice
# (once for host and once for target).
@ -60,6 +86,7 @@ genrule(
testonly = 1,
outs = [
"test_graph_tfadd.pb",
"test_debuginfo_tfadd.pb",
"test_graph_tfadd_with_ckpt.ckpt",
"test_graph_tfadd_with_ckpt.pb",
"test_graph_tfadd_with_ckpt_saver.ckpt",
@ -317,6 +344,7 @@ tf_library(
testonly = 1,
config = "test_graph_tfadd.config.pbtxt",
cpp_class = "AddComp",
debug_info = "test_debuginfo_tfadd.pb",
graph = "test_graph_tfadd.pb",
include_standard_runtime_deps = False,
mlir_components = "Bridge",


@ -30,6 +30,7 @@ from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@ -184,7 +185,22 @@ def tfvariable_sequential_updates(_):
array_ops.identity(updates, name='result')
def write_graph(build_graph, out_dir):
def export_debug_info(exported_graph):
"""Exports debug information from a graph.
Args:
exported_graph: A Graph that has been created by tracing a saveable view.
Returns:
Corresponding GraphDebugInfo with traces for all ops in exported_graph.
"""
exported_operations = []
for op in exported_graph.get_operations():
exported_operations.append(('', op))
return error_interpolation.create_graph_debug_info_def(exported_operations)
def write_graph(build_graph, out_dir, debug_info=False):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
with g.as_default():
@ -193,10 +209,19 @@ def write_graph(build_graph, out_dir):
with open(filename, 'wb') as f:
f.write(six.ensure_binary(g.as_graph_def().SerializeToString()))
if debug_info:
filename_debuginfo = os.path.join(
out_dir, 'test_debuginfo_%s.pb' % build_graph.__name__)
test_debuginfo = export_debug_info(g)
with open(filename_debuginfo, 'wb') as f:
f.write(
six.ensure_binary(
test_debuginfo.SerializeToString(deterministic=True)))
def main(_):
control_flow_util.enable_control_flow_v2()
write_graph(tfadd, FLAGS.out_dir)
write_graph(tfadd, FLAGS.out_dir, debug_info=True)
write_graph(tfadd_with_ckpt, FLAGS.out_dir)
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
write_graph(tfassert_eq, FLAGS.out_dir)


@ -0,0 +1,69 @@
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=Bridge --debug_info=%s.debug.pbtxt 2>&1 | FileCheck %s -dump-input-on-failure
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=None 2>&1 | FileCheck -check-prefix=OLD %s -dump-input-on-failure
# Checks the error message produced by tfcompile with mlir_components=Bridge.
# Checks that source debug information is used in the output error message for
# the node x_y_sum = Add.
# CHECK: INVALID ARGUMENTS: Dimensions must be equal, but are 2 and 3 for 'x_y_sum = Add[T=DT_INT32](aot_feed_0/x, aot_feed_0/y)'
# CHECK: math_ops.add(x, y, name='x_y_sum')
# CHECK: build_graph(out_dir)
# Checks the error message produced by tfcompile with mlir_components=None.
# OLD: INVALID ARGUMENTS: Incompatible shapes: [2] vs. [3]
# OLD: x_y_sum
node: {
name: "x"
op: "Placeholder"
attr: {
key: "shape"
value: {
shape: {
dim: {
size: -1
}
}
}
}
attr: {
key: "dtype"
value: {
type: DT_INT32
}
}
}
node: {
name: "y"
op: "Placeholder"
attr: {
key: "shape"
value: {
shape: {
dim: {
size: -1
}
}
}
}
attr: {
key: "dtype"
value: {
type: DT_INT32
}
}
}
node: {
name: "x_y_sum"
op: "Add"
input: "x"
input: "y"
attr: {
key: "T"
value: {
type: DT_INT32
}
}
}
versions: {
producer: 321
}


@ -0,0 +1,16 @@
# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x" }
shape {
dim { size: 2 }
}
}
feed {
id { node_name: "y" }
shape {
dim { size: 3 }
}
}
fetch {
id { node_name: "x_y_sum" }
}


@ -0,0 +1,28 @@
files: "org_tensorflow/tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt.fake_py.debug"
traces: {
key: "x@"
value: {
file_line_cols: {
line: 1
}
}
}
traces: {
key: "x_y_sum@"
value: {
file_line_cols: {
line: 3
}
file_line_cols: {
line: 4
}
}
}
traces: {
key: "y@"
value: {
file_line_cols: {
line: 2
}
}
}

View File

@ -0,0 +1,4 @@
x = value
y = value
math_ops.add(x, y, name='x_y_sum')
build_graph(out_dir)

View File

@ -26,6 +26,7 @@ def tf_library(
name,
graph,
config,
debug_info = None,
freeze_checkpoint = None,
freeze_saver = None,
cpp_class = None,
@ -191,12 +192,15 @@ def tf_library(
mlir_flag = "--mlir_components=" + mlir_components
srcs = [tfcompile_graph, config]
debug_info_flag = ""
if debug_info:
srcs.append(debug_info)
debug_info_flag = " --debug_info=$(location " + debug_info + ")"
native.genrule(
name = ("gen_" + name),
srcs = [
tfcompile_graph,
config,
],
srcs = srcs,
outs = [
header_file,
metadata_object_file,
@ -206,6 +210,7 @@ def tf_library(
"CUDA_VISIBLE_DEVICES='' " +
"$(location " + tfcompile_tool + ")" +
" --graph=$(location " + tfcompile_graph + ")" +
debug_info_flag +
" --config=$(location " + config + ")" +
" --entry_point=" + ep +
" --cpp_class=" + cpp_class +
@ -237,10 +242,7 @@ def tf_library(
session_module_pb = name + "_session_module.pb"
native.genrule(
name = (name + "_session_module"),
srcs = [
tfcompile_graph,
config,
],
srcs = srcs,
outs = [
session_module_pb,
],
@ -248,6 +250,7 @@ def tf_library(
"CUDA_VISIBLE_DEVICES='' " +
"$(location " + tfcompile_tool + ")" +
" --graph=$(location " + tfcompile_graph + ")" +
debug_info_flag +
" --config=$(location " + config + ")" +
" --entry_point=" + ep +
" --cpp_class=" + cpp_class +
@ -407,5 +410,6 @@ def target_llvm_triple():
"//tensorflow:ios_x86_64": "x86_64-apple-ios",
"//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
"//tensorflow:macos": "x86_64-none-darwin",
"//tensorflow:windows": "x86_64-none-windows",
"//conditions:default": "x86_64-pc-linux",
})
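A hypothetical BUILD (Starlark) sketch of how a tf_library target could use the new debug_info attribute wired up above; the target name, cpp_class, and file names are illustrative only.
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
tf_library(
    name = "test_graph_tfadd",
    config = "test_graph_tfadd.config.pbtxt",
    cpp_class = "AddComp",  # hypothetical class name
    # Passed to tfcompile as --debug_info=..., enabling source-level error messages.
    debug_info = "test_debuginfo_tfadd.pb",
    graph = "test_graph_tfadd.pb",
    mlir_components = "Bridge",
)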

View File

@ -65,6 +65,7 @@ int main(int argc, char** argv) {
flags.out_metadata_object = "out_helper.o";
flags.out_header = "out.h";
flags.entry_point = "entry";
flags.debug_info_path_begin_marker = "";
std::vector<tensorflow::Flag> flag_list;
AppendMainFlags(&flag_list, &flags);
@ -81,12 +82,10 @@ int main(int argc, char** argv) {
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
QCHECK(argc == 1) << "\nERROR: This command does not take any arguments "
"other than flags\n\n"
<< usage;
"other than flags. See --help.\n\n";
tensorflow::Status status = tensorflow::tfcompile::Main(flags);
if (status.code() == tensorflow::error::INVALID_ARGUMENT) {
std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n"
<< usage;
std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n";
return 1;
} else {
TF_QCHECK_OK(status);

View File

@ -14,6 +14,10 @@ package_group(
includes = [
"//tensorflow/compiler/tf2xla:internal",
],
packages = [
"//tensorflow/compiler/tests/...",
"//tensorflow/python/...",
],
)
package_group(
@ -180,6 +184,7 @@ XLA_DEVICE_DEPS = [
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:functional_ops_op_lib",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",

View File

@ -108,7 +108,7 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
"(LRN, LRNGrad)."
" BN: TF FusedBatchNorm* operations."
" FUSIBLE: All TF operations that XLA can fuse (All the above). "
"You can also put any TF operation name, e.g. 'FUSIBLE,Matmul'."),
"You can also put any TF operation name, e.g. 'FUSIBLE,MatMul'."),
Flag("tf_xla_clustering_debug",
&mark_for_compilation_flags->tf_xla_clustering_debug,
"Dump graphs during XLA compilation."),

View File

@ -20,6 +20,7 @@ XLA_OPS_DEPS = [
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@ -41,6 +42,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/profiler/lib/traceme.h"
@ -170,8 +172,9 @@ class XlaExecutableClosureStore {
private:
mutex mutex_;
int64 key_counter_ GUARDED_BY(mutex_);
absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_ GUARDED_BY(mutex_);
int64 key_counter_ TF_GUARDED_BY(mutex_);
absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_
TF_GUARDED_BY(mutex_);
TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
};
@ -206,12 +209,14 @@ se::DeviceMemoryAllocator* GetAllocator(
XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
const std::vector<int>& constants,
const std::vector<int>& resources,
const NameAttrList& function)
const NameAttrList& function,
bool has_ref_vars)
: OpKernel(ctx),
constants_(constants),
resources_(resources),
function_(function),
platform_info_(PlatformInfoFromContext(ctx)) {}
platform_info_(PlatformInfoFromContext(ctx)),
has_ref_vars_(has_ref_vars) {}
static Status BuildCompilationCache(OpKernelContext* ctx,
const XlaPlatformInfo& platform_info,
@ -350,8 +355,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
{
Status s = CompileToLocalExecutable(
ctx, function_, /*has_ref_vars=*/true, platform_info_, resources_,
constants_, /*lazy=*/false, &client, &variables, &kernel, &executable);
ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_,
resources_, constants_, /*lazy=*/false, &client, &variables, &kernel,
&executable);
if (!s.ok() && (platform_info_.device_type().type_string() == DEVICE_CPU ||
platform_info_.device_type().type_string() == DEVICE_GPU)) {
// Suggest auto jit if the failure was with GPU or CPU.
@ -384,6 +390,18 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
run_options.set_allocator(allocator);
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
run_options.set_rng_seed(GetXLARandomSeed());
xla::ThenExecuteFunction then_execute;
if (ctx->op_device_context()) {
then_execute = [&](se::Stream* stream, std::function<void()> fn) {
Status status = ctx->op_device_context()->ThenExecute(
down_cast<Device*>(ctx->device()), stream, std::move(fn));
if (!status.ok()) {
// This should never happen.
LOG(ERROR) << "ThenExecute failed " << status;
}
};
run_options.set_then_execute_function(&then_execute);
}
Env* env = Env::Default();
auto start_time = env->NowMicros();
@ -462,7 +480,7 @@ bool HasRefVars(OpKernelConstruction* ctx) {
XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
: XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
FunctionAttr(ctx)) {}
FunctionAttr(ctx), /*has_ref_vars=*/true) {}
XlaLocalLaunchOp::~XlaLocalLaunchOp() {
VLOG(1) << "XlaLocalLaunchOp destroyed";
@ -592,6 +610,18 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
run_options.set_allocator(allocator);
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
run_options.set_rng_seed(GetXLARandomSeed());
xla::ThenExecuteFunction then_execute;
if (ctx->op_device_context()) {
then_execute = [&](se::Stream* stream, std::function<void()> fn) {
Status status = ctx->op_device_context()->ThenExecute(
down_cast<Device*>(ctx->device()), stream, std::move(fn));
if (!status.ok()) {
// This should never happen.
LOG(ERROR) << "ThenExecute failed " << status;
}
};
run_options.set_then_execute_function(&then_execute);
}
Env* env = Env::Default();
auto start_time = env->NowMicros();

View File

@ -95,12 +95,15 @@ class XlaPlatformInfo {
// in the GraphDef.
// Currently, it is used by eager runtime. FunctionLibraryRuntime creates
// this kernel when asked to create a kernel for an XLA-compiled function.
//
// `has_ref_vars`: whether the input computation can have reference variables.
// TODO(cheshire): instead derive this information from the input graph.
class XlaLocalLaunchBase : public OpKernel {
public:
XlaLocalLaunchBase(OpKernelConstruction* ctx,
const std::vector<int>& constants,
const std::vector<int>& resources,
const NameAttrList& function);
const NameAttrList& function, bool has_ref_vars);
XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
~XlaLocalLaunchBase() override = default;
@ -115,6 +118,8 @@ class XlaLocalLaunchBase : public OpKernel {
const NameAttrList function_;
const XlaPlatformInfo platform_info_;
bool has_ref_vars_;
};
// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
@ -160,7 +165,8 @@ class XlaCompileOp : public OpKernel {
// error when compiling the cluster this _XlaCompile is supposed to compile.
// If `cannot_compile_cluster_` is true then we avoid compiling this cluster
// on any future calls to _XlaCompile.
bool cannot_compile_cluster_ GUARDED_BY(cannot_compile_cluster_mu_) = false;
bool cannot_compile_cluster_ TF_GUARDED_BY(cannot_compile_cluster_mu_) =
false;
mutex cannot_compile_cluster_mu_;
};

View File

@ -963,6 +963,22 @@ absl::optional<string> MarkForCompilationPassImpl::GetXlaScope(Node* node) {
return absl::nullopt;
}
// Returns true iff the attribute `attr_name` is attached to either the node or
// to its callee.

static bool GetNodeOrFuncAttr(Node* node, FunctionLibraryDefinition* flib_def,
const char* attr_name) {
bool out = false;
bool attr_value;
if (TryGetNodeAttr(node->attrs(), attr_name, &attr_value)) {
out |= attr_value;
}
if (flib_def->GetAttr(*node, attr_name, &attr_value).ok()) {
out |= attr_value;
}
return out;
}
Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
auto ignore_resource_ops = [&](const Node& n, bool* ignore) {
return IgnoreResourceOpForSafetyAnalysis(&device_info_cache_, n, ignore);
@ -1016,16 +1032,9 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
resource_var_operation_node_id = node->id();
}
bool is_xla_compile_attr_true = false;
bool xla_compile_attr;
if (TryGetNodeAttr(node->attrs(), kXlaCompileAttr, &xla_compile_attr)) {
is_xla_compile_attr_true |= xla_compile_attr;
}
if (flib_def_->GetAttr(*node, kXlaCompileAttr, &xla_compile_attr).ok()) {
is_xla_compile_attr_true |= xla_compile_attr;
}
bool is_xla_compile_attr_true =
GetNodeOrFuncAttr(node, flib_def_, kXlaCompileAttr) ||
GetNodeOrFuncAttr(node, flib_def_, kXlaMustCompileAttr);
DeviceSet devices;
devices.Insert(device);
@ -1874,6 +1883,8 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"EmptyTensorList",
"ExtractImagePatches",
"Igamma",
"IgammaGradA",
"RandomGammaGrad",
"Igammac",
"FFT",
"FFT2D",
@ -1900,6 +1911,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"LinSpace",
"ListDiff",
"LogMatrixDeterminant",
"LowerBound",
"MatMul",
"MatrixBandPart",
"MatrixDiag",
@ -1996,6 +2008,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"StatelessRandomNormal",
"StatelessRandomUniform",
"StatelessRandomUniformInt",
"StatelessRandomUniformFullInt",
"StatelessTruncatedNormal",
"StatelessWhile",
"Svd",
@ -2025,6 +2038,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"TensorScatterUpdate",
"TridiagonalSolve",
"TruncatedNormal",
"UpperBound",
"UnsortedSegmentMax",
"UnsortedSegmentMin",
"UnsortedSegmentProd",

View File

@ -18,13 +18,15 @@ limitations under the License.
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/thread_annotations.h"
namespace tensorflow {
namespace {
// The list of all registered `XlaActivityListener`s.
struct XlaActivityListenerList {
absl::Mutex mutex;
std::vector<std::unique_ptr<XlaActivityListener>> listeners GUARDED_BY(mutex);
std::vector<std::unique_ptr<XlaActivityListener>> listeners
TF_GUARDED_BY(mutex);
};
void FlushAllListeners();

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/util.h"
@ -33,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/hash/hash.h"
@ -202,6 +204,52 @@ static bool ShouldBeMegamorphic(int64 compile_count, int64 execution_count) {
execution_count < kMinExecutionsPerCompile * compile_count;
}
// Creates a simple graph using the specified op as the only op apart from the
// arg and retval nodes.
static xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
absl::Span<const DataType> result_types) {
// TODO(b/74182462): We implement this by creating a new dummy Graph including
// _Arg nodes, and let CompileGraph walk it. This could be optimized.
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
Status status;
// First create the actual node we care about computing.
Node* main_node = graph->AddNode(node_def, &status);
TF_RETURN_IF_ERROR(status);
// Create dummy _Arg nodes. Link these to `node` and also via a control
// dependency edge to the _SOURCE node.
for (int64 i = 0; i < args.size(); ++i) {
Node* node;
string arg_name = absl::StrCat("_arg", i);
Status status =
NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp)
.ControlInput(graph->source_node())
.Attr("T", args[i].kind == XlaCompiler::Argument::kResource
? DT_RESOURCE
: args[i].type)
.Attr("index", i)
.Finalize(graph.get(), &node);
TF_RETURN_IF_ERROR(status);
graph->AddEdge(node, 0, main_node, i);
}
// Similarly with return values, create dummy _Retval nodes fed by `node`.
for (int64 i = 0; i < result_types.size(); ++i) {
Node* node;
string retval_name = absl::StrCat("_retval", i);
Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp)
.Input(main_node, i)
.Attr("T", result_types[i])
.Attr("index", i)
.Finalize(graph.get(), &node);
TF_RETURN_IF_ERROR(status);
}
FixupSourceAndSinkEdges(graph.get());
return graph;
}
Status XlaCompilationCache::CompileSingleOp(
const XlaCompiler::Options& options,
absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,
@ -222,8 +270,11 @@ Status XlaCompilationCache::CompileSingleOp(
for (int i = 0; i < result_dtypes.size(); ++i) {
result_dtypes[i] = ctx->expected_output_dtype(i);
}
return compiler->CompileSingleOp(compile_options, ctx->op_kernel().def(),
args, result_dtypes, result);
const NodeDef& node_def = ctx->op_kernel().def();
TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));
return compiler->CompileGraph(compile_options, node_def.name(),
std::move(graph), args, result);
};
return CompileImpl(options, name, args, compile_op,
/*compile_threshold=*/absl::nullopt,

View File

@ -151,19 +151,19 @@ class XlaCompilationCache : public ResourceBase {
int64 request_count = 0;
// Did compilation succeed?
Status compilation_status GUARDED_BY(mu);
Status compilation_status TF_GUARDED_BY(mu);
// Output of the XlaCompiler.
XlaCompiler::CompilationResult compilation_result GUARDED_BY(mu);
XlaCompiler::CompilationResult compilation_result TF_GUARDED_BY(mu);
// The XLA executable compiled from <computation>. May be null if no
// executable has been built.
std::unique_ptr<xla::LocalExecutable> executable GUARDED_BY(mu);
std::unique_ptr<xla::LocalExecutable> executable TF_GUARDED_BY(mu);
};
mutex compile_cache_mu_;
absl::flat_hash_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
GUARDED_BY(compile_cache_mu_);
TF_GUARDED_BY(compile_cache_mu_);
struct ClusterCompileStats {
// Number of times the cluster has been (re-)compiled.
@ -185,7 +185,7 @@ class XlaCompilationCache : public ResourceBase {
// Maps cluster names to compilation statistics for said cluster.
absl::flat_hash_map<string, ClusterCompileStats> cluster_compile_stats_
GUARDED_BY(cluster_compile_stats_mu_);
TF_GUARDED_BY(cluster_compile_stats_mu_);
// The number of times a lazy compilation must be requested for a specific
// signature before we attempt to compile it.

View File

@ -83,7 +83,7 @@ class XlaDeviceAllocatorState {
std::unordered_map<std::pair<const xla::Backend*, int>,
std::unique_ptr<XlaDeviceAllocator>,
hash<std::pair<const xla::Backend*, int>>>
allocators_ GUARDED_BY(allocator_mutex_);
allocators_ TF_GUARDED_BY(allocator_mutex_);
TF_DISALLOW_COPY_AND_ASSIGN(XlaDeviceAllocatorState);
};

View File

@ -137,7 +137,7 @@ class XlaDevice : public LocalDevice {
~XlaDevice() override;
Allocator* GetAllocator(AllocatorAttributes attr) override
LOCKS_EXCLUDED(mu_);
TF_LOCKS_EXCLUDED(mu_);
void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) override;
@ -145,18 +145,18 @@ class XlaDevice : public LocalDevice {
void Sync(const DoneCallback& done) override;
Status TryGetDeviceContext(DeviceContext** out_context) override
LOCKS_EXCLUDED(mu_);
TF_LOCKS_EXCLUDED(mu_);
Status MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor* tensor) override LOCKS_EXCLUDED(mu_);
Tensor* tensor) override TF_LOCKS_EXCLUDED(mu_);
// Allocate tensor on fast memory space. This is only applied to the new TPU
// hardware which has faster read/write memory. If the hardware doesn't
// have such memory space, we fallback to the ordinary memory space.
Status MakeFastMemTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor* tensor) LOCKS_EXCLUDED(mu_);
Tensor* tensor) TF_LOCKS_EXCLUDED(mu_);
const Metadata& metadata() { return xla_metadata_; }
@ -166,34 +166,35 @@ class XlaDevice : public LocalDevice {
//
// TODO(b/111859745): The Eager context needs to call this method to recover
// from failures.
Status EnsureDeviceContextOk() LOCKS_EXCLUDED(mu_);
Status EnsureDeviceContextOk() TF_LOCKS_EXCLUDED(mu_);
// Instructs this XlaDevice to set a GpuDeviceInfo, which holds extra
// information for GPU and TPU devices.
Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_);
Status UseGpuDeviceInfo() TF_LOCKS_EXCLUDED(mu_);
// Instructs this XlaDevice to return 'sync_on_completion' for
// AllowsSyncOnCompletion().
void SetAllowsSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_);
bool AllowsSyncOnCompletion() const override LOCKS_EXCLUDED(mu_);
void SetAllowsSyncOnCompletion(bool sync_on_completion)
TF_LOCKS_EXCLUDED(mu_);
bool AllowsSyncOnCompletion() const override TF_LOCKS_EXCLUDED(mu_);
// Installs an error handling callback when RefreshStatus sees !status.ok().
void SetHandleDeviceErrorCallback(std::function<Status()> callback);
Status RefreshStatus() override LOCKS_EXCLUDED(mu_);
Status RefreshStatus() override TF_LOCKS_EXCLUDED(mu_);
private:
xla::StatusOr<xla::LocalClient*> GetOrCreateClient() const;
Allocator* GetAllocatorLocked(AllocatorAttributes attr)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
Status EnsureStreamOkLocked(xla::Backend* backend, const string& name,
std::shared_ptr<se::Stream>* stream,
bool* stream_was_changed)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Return a pair of device context, the second one is fast_mem device context.
xla::StatusOr<std::pair<XlaDeviceContext*, XlaDeviceContext*>>
GetDeviceContextLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_);
GetDeviceContextLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
static Status GetMetadataFromDevice(DeviceBase* device,
const XlaDevice::Metadata** metadata);
@ -218,13 +219,13 @@ class XlaDevice : public LocalDevice {
// Intra-op threads to spawn (from SessionOptions).
const int intra_op_parallelism_threads_;
// Memory allocator associated with this device.
Allocator* xla_allocator_ GUARDED_BY(mu_) = nullptr; // Not owned.
Allocator* xla_allocator_ TF_GUARDED_BY(mu_) = nullptr; // Not owned.
// Stream associated with this device. Operations enqueued on this
// stream are executed on the device. Operations include data
// copying back and forth between CPU and the device, and
// computations enqueued by XLA.
std::shared_ptr<se::Stream> stream_ GUARDED_BY(mu_);
std::shared_ptr<se::Stream> stream_ TF_GUARDED_BY(mu_);
// If false, only stream_ is valid and all computation and transfers use
// stream_. If true, computation is performed by stream_ and transfers are
// performed by host_to_device/device_to_device stream or borrowing a stream
@ -232,36 +233,36 @@ class XlaDevice : public LocalDevice {
const bool use_multiple_streams_;
// If use_multiple_streams_, host to device transfers are performed using this
// stream.
std::shared_ptr<se::Stream> host_to_device_stream_ GUARDED_BY(mu_);
std::shared_ptr<se::Stream> host_to_device_stream_ TF_GUARDED_BY(mu_);
// If use_multiple_streams_, transfers between different devices are performed
// using these streams.
std::vector<std::shared_ptr<se::Stream>> device_to_device_streams_
GUARDED_BY(mu_);
TF_GUARDED_BY(mu_);
const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
// The device context accessed by all users of the XlaDevice, set by calls to
// EnsureDeviceContextOk. If gpu_device_info_ is non-null, this pointer is
// also filled in to that struct. XlaDeviceContext is a ref-counted object.
XlaDeviceContext* device_context_ GUARDED_BY(mu_) = nullptr;
XlaDeviceContext* device_context_ TF_GUARDED_BY(mu_) = nullptr;
// The device context will allocate memory on fast memory space on TPU.
// XlaDeviceContext is a ref-counted object.
XlaDeviceContext* fast_mem_device_context_ GUARDED_BY(mu_) = nullptr;
XlaDeviceContext* fast_mem_device_context_ TF_GUARDED_BY(mu_) = nullptr;
// Holds extra information for GPU and TPU devices, e.g. the device context.
bool use_gpu_device_info_ GUARDED_BY(mu_) = false;
std::unique_ptr<GpuDeviceInfo> gpu_device_info_ GUARDED_BY(mu_);
bool use_gpu_device_info_ TF_GUARDED_BY(mu_) = false;
std::unique_ptr<GpuDeviceInfo> gpu_device_info_ TF_GUARDED_BY(mu_);
// Thread pool used for running closures
std::unique_ptr<thread::ThreadPool> thread_pool_;
// True if the device allows XlaDevice::Sync to be called on completion
// regardless of status.
bool sync_on_completion_ GUARDED_BY(mu_) = true;
bool sync_on_completion_ TF_GUARDED_BY(mu_) = true;
// A callback that will be invoked when RefreshStatus sees a status error.
std::function<Status()> device_error_callback_ GUARDED_BY(mu_);
std::function<Status()> device_error_callback_ TF_GUARDED_BY(mu_);
// Set of devices to use. This controls which of the devices on the given
// platform will have resources allocated. For GPUs this will be

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/stream_executor/platform/port.h"

View File

@ -117,7 +117,7 @@ class XlaDeviceContext : public DeviceContext {
bool use_fast_mem_;
absl::Mutex mu_;
int next_stream_ GUARDED_BY(mu_) = 0;
int next_stream_ TF_GUARDED_BY(mu_) = 0;
};
} // namespace tensorflow

View File

@ -20,15 +20,17 @@ limitations under the License.
namespace tensorflow {
bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr,
const NodeDef& node_def) const {
return CanCreateXlaKernel(node_def);
bool XlaKernelCreator::CanCreateKernel(
const FunctionLibraryRuntime& flr,
const std::shared_ptr<const NodeProperties>& props) const {
return CanCreateXlaKernel(props->node_def);
}
Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr,
const NodeDef& node_def,
std::unique_ptr<OpKernel>* kernel) const {
return CreateXlaKernel(flr, node_def, kernel);
Status XlaKernelCreator::CreateKernel(
FunctionLibraryRuntime* flr,
const std::shared_ptr<const NodeProperties>& props,
std::unique_ptr<OpKernel>* kernel) const {
return CreateXlaKernel(flr, props->node_def, kernel);
}
namespace {

View File

@ -29,11 +29,13 @@ class XlaKernelCreator : public CustomKernelCreator {
// Given a NodeDef 'node_def' and the function library runtime 'flr', returns
// true if 'node_def' is a call to a compilable function defined in 'flr',
// with the kXlaCompileAttr set.
bool CanCreateKernel(const FunctionLibraryRuntime& flr,
const NodeDef& node_def) const override;
bool CanCreateKernel(
const FunctionLibraryRuntime& flr,
const std::shared_ptr<const NodeProperties>& props) const override;
// Given a supported NodeDef, returns a XlaLaunchOp that computes the node.
Status CreateKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
Status CreateKernel(FunctionLibraryRuntime* flr,
const std::shared_ptr<const NodeProperties>& props,
std::unique_ptr<OpKernel>* kernel) const override;
};

View File

@ -30,10 +30,12 @@ limitations under the License.
namespace tensorflow {
NodeDef ToNodeDef(const string& text) {
std::shared_ptr<NodeProperties> ToNodeProperties(const string& text) {
NodeDef node_def;
DataTypeVector dummy;
EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
return node_def;
return std::make_shared<NodeProperties>(nullptr, std::move(node_def), dummy,
dummy);
}
// Create a FunctionDef that takes one resource and one regular param
@ -98,11 +100,11 @@ TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) {
(*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
Init({fdef});
XlaKernelCreator xla_kernel_creator;
NodeDef callsite =
ToNodeDef(R"pb(
auto callsite =
ToNodeProperties(R"pb(
name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
)pb");
(*callsite.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
(*(callsite->node_def.mutable_attr()))["_XlaMustCompile"] = BoolAttr(true);
// Note: need to set attribute on the created node.
Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_);
@ -127,13 +129,14 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) {
Init({fdef});
XlaKernelCreator xla_kernel_creator;
Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto(
name: 'XTimesY'
op: 'XTimesY'
input: 'a'
input: 'b'
)proto"),
&kernel_);
Status status =
xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto(
name: 'XTimesY'
op: 'XTimesY'
input: 'a'
input: 'b'
)proto"),
&kernel_);
EXPECT_TRUE(errors::IsInternal(status)) << status.ToString();
}
@ -143,13 +146,14 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) {
Init({fdef});
XlaKernelCreator xla_kernel_creator;
Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto(
name: 'XTimesY'
op: 'XTimesY'
input: 'a'
input: 'b'
)proto"),
&kernel_);
Status status =
xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto(
name: 'XTimesY'
op: 'XTimesY'
input: 'a'
input: 'b'
)proto"),
&kernel_);
EXPECT_TRUE(errors::IsInternal(status)) << status.ToString();
}

View File

@ -104,7 +104,7 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
BackwardsConstAnalysis(*((*fbody)->graph), &const_args,
/*compile_time_const_nodes=*/nullptr, flr));
for (int i = 0; i < const_args.size(); ++i) {
for (size_t i = 0; i < const_args.size(); ++i) {
if (const_args[i]) {
constant_arg_indices->push_back(i);
}
@ -113,7 +113,7 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
// There can be hundreds of resource variables. Reserve the space for them.
// We don't reserve for constants above as they are usually few.
resource_arg_indices->reserve(arg_types.size());
for (int i = 0; i < arg_types.size(); ++i) {
for (size_t i = 0; i < arg_types.size(); ++i) {
if (arg_types[i] == DT_RESOURCE) {
resource_arg_indices->push_back(i);
}
@ -177,7 +177,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
// 214 variables and a similar number of activations.
SinglePassSearch constants_search(&constant_arg_indices);
SinglePassSearch resources_search(&resource_arg_indices);
for (int i = 0; i < fbody->arg_types.size(); ++i) {
for (size_t i = 0; i < fbody->arg_types.size(); ++i) {
if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
// Compile-time constants and resource handles are expected to be in
// host memory.
@ -207,7 +207,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
// XlaLaunch kernel keeps all outputs (including constants, which it copies),
// in device memory except for resources.
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
for (int i = 0; i < fbody->ret_types.size(); ++i) {
for (size_t i = 0; i < fbody->ret_types.size(); ++i) {
if (fbody->ret_types[i] == DT_RESOURCE) {
output_memory_types[i] = HOST_MEMORY;
}
@ -218,15 +218,17 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
Device* dev = flr->device();
Status s;
OpKernelConstruction construction(
DeviceType(dev->device_type()), dev,
dev->GetAllocator(AllocatorAttributes()), &node_def,
&fbody->fdef.signature(), flr, dev->resource_manager(), fbody->arg_types,
input_memory_types, fbody->ret_types, output_memory_types,
flr->graph_def_version(), &s);
auto props = std::make_shared<NodeProperties>(
&fbody->fdef.signature(), node_def, fbody->arg_types, fbody->ret_types);
OpKernelConstruction construction(DeviceType(dev->device_type()), dev,
dev->GetAllocator(AllocatorAttributes()),
flr, dev->resource_manager(), props,
input_memory_types, output_memory_types,
flr->graph_def_version(), &s);
*kernel = absl::make_unique<XlaLocalLaunchBase>(
&construction, constant_arg_indices, resource_arg_indices, function);
&construction, constant_arg_indices, resource_arg_indices, function,
/*has_ref_vars=*/false);
return s;
}
} // namespace tensorflow

View File

@ -18,7 +18,6 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_
#define TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_
#include "absl/base/thread_annotations.h"
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@ -30,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
namespace tensorflow {
@ -102,7 +102,7 @@ class VariableInfo {
// `variables` is allowed to contain instances that don't track a resource
// variable (i.e. variables[i].var() can be null for some i).
Status LockVariables(absl::Span<VariableInfo> variables)
EXCLUSIVE_LOCK_FUNCTION();
TF_EXCLUSIVE_LOCK_FUNCTION();
// Helper class to perform the marshalling of TensorFlow inputs and outputs to
// ShapedBuffers suitable for passing to an XLA computation.

View File

@ -122,7 +122,7 @@ class XlaTensor {
std::shared_ptr<se::Event> definition_event_;
// A list of all streams for which the tensor's content is defined for any
// newly enqueued command.
absl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_);
absl::InlinedVector<se::Stream*, 2> streams_defined_on_ TF_GUARDED_BY(mu_);
mutex mu_;
};

View File

@ -44,11 +44,9 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"@llvm-project//llvm:support",
"@llvm-project//mlir:AffineDialectRegistration",
"@llvm-project//mlir:LoopDialectRegistration",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOpsDialectRegistration",
"@llvm-project//mlir:Support",
"@llvm-project//mlir/test:TestTransforms",
],
@ -76,12 +74,15 @@ cc_library(
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo",
"//tensorflow/compiler/mlir/xla:lhlo",
"//tensorflow/compiler/mlir/xla:lhlo_copy_removal",
"//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_parallel_loops",
"//tensorflow/compiler/mlir/xla:xla_dialect_registration",
"//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
"//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg",
"//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
"//tensorflow/compiler/mlir/xla:xla_lower",
@ -102,11 +103,45 @@ cc_library(
],
)
cc_library(
name = "mlir_graph_optimization_pass",
srcs = ["mlir_graph_optimization_pass.cc"],
hdrs = ["mlir_graph_optimization_pass.h"],
deps = [
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:device_util",
"//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/core:core_cpu",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
],
alwayslink = 1,
)
cc_library(
name = "mlir_graph_optimization_pass_registration",
srcs = [
"mlir_graph_optimization_pass_registration.cc",
],
deps = [
":mlir_graph_optimization_pass",
"//tensorflow/core:core_cpu",
],
alwayslink = 1,
)
tf_cc_binary(
name = "tf-opt",
deps = [
":tf_mlir_opt_main",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
],
)
@ -116,8 +151,10 @@ tf_cc_binary(
srcs = ["tf_mlir_translate_main.cc"],
deps = [
":init_mlir",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/compiler/mlir/tensorflow:translate_registration",
@ -129,6 +166,7 @@ tf_cc_binary(
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TranslateClParser",

View File

@ -0,0 +1,3 @@
# TensorFlow MLIR
These are the docs for: https://www.tensorflow.org/mlir

View File

@ -0,0 +1,26 @@
upper_tabs:
# Tabs left of dropdown menu
- include: /_upper_tabs_left.yaml
- include: /api_docs/_upper_tabs_api.yaml
# Dropdown menu
- name: Resources
path: /resources
is_default: true
menu:
- include: /resources/_menu_toc.yaml
lower_tabs:
# Subsite tabs
other:
- name: Guide
contents:
- title: Overview
path: /mlir/overview
- heading: Dialects
- title: Overview
path: /mlir/dialects
- title: TensorFlow
path: /mlir/tf_ops
- title: TensorFlow Lite
path: /mlir/tfl_ops
- include: /_upper_tabs_right.yaml

View File

@ -0,0 +1,54 @@
book_path: /mlir/_book.yaml
project_path: /mlir/_project.yaml
description: <!--no description-->
landing_page:
custom_css_path: /site-assets/css/style.css
rows:
- heading: MLIR unifies the infrastructure for high-performance ML models in TensorFlow.
items:
- description: >
The <a href="https://mlir.llvm.org/" class="external">MLIR</a> project defines a common
intermediate representation (IR) that unifies the infrastructure required to execute high
performance machine learning models in TensorFlow and similar ML frameworks. This project
will include the application of HPC techniques, along with integration of
search algorithms like reinforcement learning. MLIR aims to reduce the
cost to bring up new hardware, and improve usability for existing
TensorFlow users.
- code_block: |
<pre class = "prettyprint">
// Syntactically similar to LLVM:
func @testFunction(%arg0: i32) {
%x = call @thingToCall(%arg0) : (i32) -> i32
br ^bb1
^bb1:
%y = addi %x, %x : i32
return %y : i32
}
</pre>
- classname: devsite-landing-row-cards
items:
- heading: "Multi-Level Intermediate Representation for Compiler Infrastructure"
youtube_id: qzljG6DKgic
buttons:
- label: Watch the video
path: https://www.youtube.com/watch?v=qzljG6DKgic
- heading: "A new intermediate representation and compiler framework"
image_path: /resources/images/tf-logo-card-16x9.png
path: https://blog.tensorflow.org/2019/04/mlir-new-intermediate-representation.html
buttons:
- label: Read on TensorFlow blog
path: https://blog.tensorflow.org/2019/04/mlir-new-intermediate-representation.html
- heading: MLIR on GitHub
image_path: /resources/images/github-card-16x9.png
path: https://github.com/llvm/llvm-project/tree/master/mlir
buttons:
- label: View on GitHub
path: https://github.com/llvm/llvm-project/tree/master/mlir
- heading: TensorFlow MLIR on GitHub
image_path: /resources/images/github-card-16x9.png
path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir
buttons:
- label: View on GitHub
path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir

View File

@ -0,0 +1,37 @@
# MLIR dialects
## Overview
To separate different hardware and software targets, MLIR has “dialects”,
including:
* TensorFlow IR, which represents all things possible in TensorFlow graphs.
* XLA HLO IR, which is designed to take advantage of XLA's compilation
abilities (with output to, among other things, TPUs).
* An experimental affine dialect, which focuses on
[polyhedral representations](https://en.wikipedia.org/wiki/Polytope_model)
and optimizations.
* LLVM IR, which has a 1:1 mapping between it and LLVM's own representation,
allowing MLIR to emit GPU and CPU code through LLVM.
* TensorFlow Lite, which will translate to running code on mobile platforms.
Each dialect consists of a set of defined operations which have invariants
placed on them, like: “This is a binary operator, and the inputs and outputs
have the same types.”
## Adding to MLIR
MLIR has no fixed/built-in list of globally known operations (no “intrinsics”).
Dialects can define entirely custom types, which is how MLIR can model things
like the LLVM IR type system (which has first class aggregates), domain
abstractions important for ML-optimized accelerators like quantized types, and
even the Swift or Clang type systems (which are built around Swift/Clang
declaration nodes) in the future.
If you want to connect a new low-level compiler, you would create a new dialect
and the lowerings between the TensorFlow Graph dialect and your dialect.
This smooths the path for hardware and compiler makers. You can even target
dialects at different levels in the same model; the higher-level optimizers
will respect the unfamiliar parts of the IR and wait for a lower level to handle
it.

File diff suppressed because one or more lines are too long

(image file added, 148 KiB)

View File

@ -0,0 +1,36 @@
# MLIR
## Overview
MLIR, or Multi-Level Intermediate Representation, is a representation format
and library of compiler utilities that sits between the model representation
and low-level compilers/executors that generate hardware-specific code.
MLIR is, at its heart, a flexible infrastructure for modern optimizing
compilers. This means it consists of a specification for intermediate
representations (IR) and a code toolkit to perform transformations on that
representation. (In compiler parlance, as you move from higher-level
representations to lower-level representations, these transformations can be
called “lowerings”.)
MLIR is highly influenced by [LLVM](https://llvm.org/) and unabashedly reuses
many great ideas from it. It has a flexible type system, and allows
representing, analyzing and transforming graphs combining multiple levels of
abstraction in the same compilation unit. These abstractions include TensorFlow
operations, nested polyhedral loop regions, and even LLVM instructions and fixed
hardware operations and types.
We expect MLIR to be of interest to many groups, including:
* Compiler researchers and implementers looking to optimize performance and
memory consumption of machine learning models
* Hardware makers looking for a way to connect their hardware to TensorFlow,
such as TPUs, portable neural hardware in phones, and other custom ASICs
* People writing language bindings that want to take advantage of optimizing
compilers and hardware acceleration.
The TensorFlow ecosystem contains a number of compilers and optimizers that
operate at multiple levels of the software and hardware stack. We expect the
gradual adoption of MLIR to simplify every aspect of this stack.
<img alt="MLIR overview diagram" src="./images/mlir-infra.svg"/>

View File

@ -48,10 +48,11 @@ def _run_lit_test(name, data, size, tags, driver, features):
" the driver parameter when running this test. If you require" +
" custom driver support, please file an issue to request it.")
# Disable tests on Windows for now, to enable testing the rest of XLA and MLIR.
native.py_test(
name = name,
srcs = ["@llvm-project//llvm:lit"],
tags = tags,
tags = tags + ["no_windows"],
args = [
"tensorflow/compiler/mlir/" + paths.basename(data[-1]) + " --config-prefix=runlit -v",
] + features,

View File

@ -30,6 +30,7 @@ filegroup(
"ir/tfl_ops.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
"@llvm-project//mlir:include/mlir/Transforms/LoopLikeInterface.td",
],
)
@ -208,6 +209,7 @@ cc_library(
"ir/tfl_ops.h.inc",
"ir/tfl_ops_interface.cc.inc",
"ir/tfl_ops_interface.h.inc",
"runtime_verifiers.inc",
"utils/attribute_utils.cc",
],
hdrs = [
@ -222,15 +224,18 @@ cc_library(
":validators",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:DerivedAttributeOpInterface",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:SideEffects",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
# TODO(jpienaar): Move this out after splitting out LoopLikeOpInterface.
"@llvm-project//mlir:Transforms",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/lite/schema:schema_fbs",
],
alwayslink = 1,
@ -302,15 +307,15 @@ cc_library(
"transforms/optimize_functional_ops.cc",
"transforms/prepare_composite_functions_tf.cc",
"transforms/prepare_tf.cc",
"transforms/runtime_type_verify.cc",
"transforms/split_merged_operands.cc",
"transforms/trim_functions_tf.cc",
"transforms/unroll_batch_matmul.cc",
"transforms/while_loop_outline.cc",
],
hdrs = [
"ir/tfl_ops_interface.h.inc",
"transforms/dilated_conv.h",
"transforms/passes.h",
"transforms/unroll_batch_matmul.h",
],
deps = [
":common",
@ -323,6 +328,8 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/mlir/tensorflow:unroll_batch_matmul_pass",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:framework",
@ -459,9 +466,9 @@ cc_library(
)
tf_native_cc_binary(
name = "operator-converter-gen",
name = "converter-gen",
srcs = [
"operator_converter_gen.cc",
"converter_gen.cc",
],
deps = [
"@llvm-project//llvm:support",
@ -471,14 +478,18 @@ tf_native_cc_binary(
)
gentbl(
name = "operator_converter_inc",
name = "converter_inc",
tbl_outs = [
(
"", # This driver has no options.
"--gen-operator-converters",
"operator_converters.inc",
),
(
"--gen-runtime-verifiers",
"runtime_verifiers.inc",
),
],
tblgen = ":operator-converter-gen",
tblgen = ":converter-gen",
td_file = "ir/tfl_ops.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
@ -508,6 +519,7 @@ cc_library(
"@com_google_absl//absl/strings",
"@flatbuffers",
"@llvm-project//llvm:support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:TransformUtils",
],
@ -561,6 +573,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@ -571,7 +584,7 @@ cc_library(
"//tensorflow/lite/delegates/flex:whitelisted_flex_ops_lib",
"//tensorflow/lite/kernels/internal:kernel_utils",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/tools/versioning:op_version",
"//tensorflow/lite/tools/versioning",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
@ -581,8 +594,6 @@ cc_library(
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:QuantOpsDialectRegistration",
"@llvm-project//mlir:StandardDialectRegistration",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Translation",
@ -594,6 +605,7 @@ tf_cc_binary(
name = "flatbuffer_translate",
deps = [
":flatbuffer_translate_lib",
"@llvm-project//mlir:LoopOpsTransforms",
"@llvm-project//mlir:MlirTranslateMain",
],
)
@ -643,12 +655,14 @@ tf_cc_binary(
"//tensorflow/compiler/mlir:init_mlir",
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:errors",
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
],
)
@ -687,16 +701,16 @@ cc_library(
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"//tensorflow/compiler/mlir/lite/quantization:quantization_passes",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:decode_constant_pass",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:QuantOpsDialectRegistration",
"@llvm-project//mlir:Transforms",
],
)
@ -716,6 +730,7 @@ cc_library(
":tensorflow_lite_quantize",
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:decode_constant_pass",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
@ -725,12 +740,10 @@ cc_library(
"//tensorflow/lite/tools/optimize:quantize_weights",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:QuantOpsDialectRegistration",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],

View File

@ -1,9 +1,9 @@
# Experimental code for the new TF-Lite convertor, and MLIR dialects and utilities for TensorFlow Lite.
# The new [MLIR](https://github.com/llvm/llvm-project/tree/master/mlir) based
TensorFlow to TensorFlow Lite converter
This directory contains:
1. Experimental code for the new TF-Lite convertor.
2. Code for the TF-lite dialect [MLIR](https://github.com/tensorflow/mlir).
1. MLIR dialects, transformation passes and utilities for TensorFlow Lite.
## API:
@ -11,7 +11,8 @@ The API for converting TensorFlow models to TensorFlow Lite will be through
`tf.lite.TFLiteConverter`. All the conversion code is open sourced, and
the API will be integrated soon.
### The conversion process from TensorFlow to TensorFlow Lite includes the following major passes:
### The conversion process from TensorFlow to TensorFlow Lite includes the
following major passes:
- Import from GraphDef, in .pb or .pbtxt format, into MLIR.
- Raise to Control-flow-graph. Converts TF Control Flow dialect to TF dialect.
@ -28,3 +29,6 @@ TensorFlow Lite models).
- The Export pass writes out TensorFlow Lite FlatBuffer format. This pass
operates on MLIR TensorFlow Lite dialect and is simple/direct translation.
See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
for the full list of MLIR passes for conversion from TensorFlow to
TensorFlow Lite.
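As an illustration of the `tf.lite.TFLiteConverter` entry point mentioned above, a minimal conversion sketch (not part of this change; the experimental_new_converter attribute name is an assumption for this TF version):
import tensorflow as tf
def convert_saved_model(saved_model_dir, out_path):
  converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
  # Opt in to the MLIR-based conversion pipeline described in this README
  # (attribute name assumed for this point in time).
  converter.experimental_new_converter = True
  tflite_model = converter.convert()
  with open(out_path, 'wb') as f:
    f.write(tflite_model)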

View File

@ -34,8 +34,9 @@ struct PassConfig {
quant_specs(std::move(specs)),
skip_control_dialect(false),
form_clusters(false),
inline_functions(true),
unfold_batch_matmul(true) {}
unfold_batch_matmul(true),
legalize_tf_while(true),
shape_inference(false) {}
// If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
// added, which produces TF Lite ops.
@ -55,12 +56,15 @@ struct PassConfig {
// are formed by grouping consecutive ops of the same device, under a
// `tf_device.launch` op.
bool form_clusters;
// Inline function calls within the main function in the MLIR module, prior
// to legalization to TFLite.
bool inline_functions;
// if `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set
// of tfl.fully_connected ops.
bool unfold_batch_matmul;
// Whether to legalize TF While to TFL While.
// Note: This is staging step and will be removed.
// TODO(b/137395003): Remove post switching legalization.
bool legalize_tf_while;
// Whether to do shape inference.
bool shape_inference;
};
} // namespace TFL

View File

@ -28,6 +28,9 @@ limitations under the License.
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include "mlir/TableGen/Attribute.h" // TF:llvm-project
#include "mlir/TableGen/Format.h" // TF:llvm-project
#include "mlir/TableGen/Operator.h" // TF:llvm-project
#include "mlir/TableGen/Predicate.h" // TF:llvm-project
using llvm::DefInit;
using llvm::dyn_cast;
@ -41,6 +44,19 @@ using llvm::SmallVector;
using llvm::StringInit;
using llvm::StringRef;
enum ActionType {
OpConv,
RuntimeVerify,
};
// NOLINTNEXTLINE
llvm::cl::opt<ActionType> action(
llvm::cl::desc("Action to perform:"),
llvm::cl::values(clEnumValN(OpConv, "gen-operator-converters",
"Generate operator converters"),
clEnumValN(RuntimeVerify, "gen-runtime-verifiers",
"Generate TFLite runtime verifiers")));
// Returns the associated option name for the given op definition.
static inline std::string GetOperatorOptionName(const Record &def) {
assert(def.getName().startswith("TFL_") && "unexpected op prefix");
@ -103,6 +119,12 @@ static void EmitOptionBuilders(const RecordKeeper &record_keeper,
// conversion generation and so the simplicity was chosen over the
// flexibility.
StringRef arg_name = arg_values->getArgNameStr(i);
// Skip any "intermiadiateXXX" attribute as they are specially handled
// in the exporter. They are special because though they are attributes
// in the MLIR they are expressed as tensors in the flatbuffer instead
// of option.
if (op_name == "LSTMOp" && arg_name.take_back(12) == "intermediate")
continue;
os << formatv(
" auto {0} = Convert{1}ForOptionWriter(op.{0}(), fbb);\n",
arg_name, mlir::tblgen::Attribute(arg_def).getAttrDefName());
@ -148,17 +170,24 @@ static void EmitOperatorBuilders(const std::vector<Record *> &defs,
for (const auto *def : defs) {
StringRef op_name = def->getName().drop_front(4);
const bool has_intermediates = op_name == "LSTMOp";
// Signature
os << "static flatbuffers::Offset<tflite::Operator> "
<< GetOperatorBuilderName(def->getName()) << "(mlir::TFL::" << op_name
<< " tflOp, uint32_t opcode_index, "
<< "const std::vector<int32_t>& operands,"
<< "const std::vector<int32_t>& results,"
<< (has_intermediates ? "const std::vector<int32_t>& intermediate_index,"
: "")
<< "flatbuffers::FlatBufferBuilder *fbb) {\n";
// Inputs & outputs
os << " auto inputs = fbb->CreateVector(operands);\n"
" auto outputs = fbb->CreateVector(results);\n\n";
// Intermediates for LSTM.
if (has_intermediates) {
os << " auto intermediates = fbb->CreateVector(intermediate_index);\n";
}
// Build the FlatBuffer operator
os << " return tflite::CreateOperator(\n"
@ -175,9 +204,9 @@ static void EmitOperatorBuilders(const std::vector<Record *> &defs,
// Only builtin ops' builders are auto-generated. custom_options are only
// used by custom or flex ops and those ops are handled manually.
os << " /*custom_options=*/0, "
"tflite::CustomOptionsFormat_FLEXBUFFERS,\n"
" /*mutating_variable_inputs=*/0);\n"
"}\n\n";
<< "tflite::CustomOptionsFormat_FLEXBUFFERS,\n"
<< " /*mutating_variable_inputs=*/0"
<< (has_intermediates ? ", intermediates" : "") << ");\n}\n\n";
}
}
@ -228,6 +257,7 @@ static void EmitGetBuiltinOpCode(const std::vector<Record *> &defs,
// uint32_t opcode_index,
// const std::vector<int32_t>& operands,
// const std::vector<int32_t>& results,
// const std::vector<int32_t>& intermediates,
// flatbuffers::FlatBufferBuilder *fbb);
static void EmitBuildOperator(const std::vector<Record *> &defs,
raw_ostream *ostream) {
@ -239,6 +269,7 @@ static void EmitBuildOperator(const std::vector<Record *> &defs,
"uint32_t opcode_index, "
"const std::vector<int32_t>& operands,"
"const std::vector<int32_t>& results,"
"const std::vector<int32_t>& intermediates,"
"flatbuffers::FlatBufferBuilder *fbb) {\n";
for (const auto *def : defs) {
@ -248,7 +279,8 @@ static void EmitBuildOperator(const std::vector<Record *> &defs,
os << " if (auto tflOp = llvm::dyn_cast<mlir::TFL::" << op_name
<< ">(op))\n"
<< " return " << GetOperatorBuilderName(def->getName())
<< "(tflOp, opcode_index, operands, results, fbb);\n";
<< "(tflOp, opcode_index, operands, results, "
<< (op_name == "LSTMOp" ? "intermediates, " : "") << "fbb);\n";
}
os << " return llvm::None;\n"
@ -291,6 +323,10 @@ static void EmitBuiltinOptionsToAttributes(const RecordKeeper &record_keeper,
if (!arg_def) continue;
if (arg_def->getDef()->isSubClassOf(attr_type)) {
StringRef arg_name = arg_values->getArgNameStr(i);
// Already handle this case in flatbuffer_import.cc.
if (option_name == "LSTMOptions" &&
arg_name.take_back(12) == "intermediate")
continue;
StringRef attr_type = mlir::tblgen::Attribute(arg_def).getAttrDefName();
os << formatv(
" attributes.emplace_back(builder.getNamedAttr(\"{0}\","
@ -342,8 +378,101 @@ static bool OperatorWritersMain(raw_ostream &os, RecordKeeper &records) {
return false;
}
static void GenOperandResultVerifier(raw_ostream &os,
llvm::ArrayRef<llvm::Init *> values,
StringRef valueKind) {
mlir::tblgen::FmtContext fctx;
bool first = true;
for (auto static_value : llvm::enumerate(values)) {
auto *definit = llvm::cast<llvm::DefInit>(static_value.value());
auto *val = definit->getDef()->getValue("tflRuntimeTypePredicate");
if (!val) continue;
// Create code block on first type to verify.
if (first) {
os << " {\n";
os << " unsigned index = " << static_value.index() << ";\n";
first = false;
}
mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
auto desc =
definit->getDef()->getValueAsString("tflRuntimeTypeDescription");
// Emit a loop to check all the dynamic values in the pack.
os << formatv(" for (Value v : top.getODS{0}{1}s({2})) {{\n",
// Capitalize the first letter to match the function name
valueKind.substr(0, 1).upper(), valueKind.substr(1),
static_value.index());
os << " (void)v;\n"
<< " if (!("
<< tgfmt(pred.getCondition(), &fctx.withSelf("v.getType()")) << ")) {\n"
<< formatv(
" return op->emitOpError(\"{0} #\") << index "
"<< \" must be {1}, but got \" << v.getType();\n",
valueKind, desc)
<< " }\n" // if
<< " ++index;\n"
<< " }\n"; // for
}
// Emit closing brace if needed.
if (!first) os << " }\n";
}
// NOLINTNEXTLINE
static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
emitSourceFileHeader("MLIR TFLite Runtime Verifiers", os);
// Retrieve all the definitions derived from TFL_Op and sort by record name.
std::vector<Record *> defs = records.getAllDerivedDefinitions("Op");
llvm::sort(defs, LessRecord());
// Iterate through all the ops defined.
for (const auto *def : defs) {
mlir::tblgen::Operator op(*def);
if (!op.getTrait("TflRuntimeVerifyOpInterface::Trait")) continue;
mlir::tblgen::FmtContext verify_ctx;
os << "::mlir::LogicalResult " << op.getCppClassName()
<< "::VerifyTflRuntimeTypes(::mlir::Operation *op) {\n";
os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
verify_ctx.withOp("top");
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
auto &value = op.getOperand(i);
// Skip from the first variadic operand for now. Else the getOperand index
// used below doesn't match.
if (value.isVariadic()) break;
if (!value.name.empty())
verify_ctx.addSubst(value.name, formatv("op->getOperand({0})", i));
}
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
auto &value = op.getResult(i);
// Skip from the first variadic result for now. Else the getResult index
// used below doesn't match.
if (value.isVariadic()) break;
if (!value.name.empty())
verify_ctx.addSubst(value.name, formatv("op->getResult({0})", i));
}
GenOperandResultVerifier(os, def->getValueAsDag("arguments")->getArgs(),
"operand");
GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(),
"result");
os << " return mlir::success();\n}\n";
}
return false;
}
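// For reference, a hedged sketch of the verifier that RuntimeVerifierWriterMain
// and GenOperandResultVerifier emit for a hypothetical op with one
// type-constrained operand ("SomeOp", the predicate, and the description are
// illustrative placeholders):
//
//   ::mlir::LogicalResult SomeOp::VerifyTflRuntimeTypes(::mlir::Operation *op) {
//     auto top = cast<SomeOp>(op); (void)top;
//     {
//       unsigned index = 0;
//       for (Value v : top.getODSOperands(0)) {
//         (void)v;
//         if (!(/* tflRuntimeTypePredicate applied to v.getType() */)) {
//           return op->emitOpError("operand #") << index
//                  << " must be <tflRuntimeTypeDescription>, but got "
//                  << v.getType();
//         }
//         ++index;
//       }
//     }
//     return mlir::success();
//   }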
int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
llvm::cl::ParseCommandLineOptions(argc, argv);
return TableGenMain(argv[0], &OperatorWritersMain);
if (action == ActionType::OpConv)
return TableGenMain(argv[0], &OperatorWritersMain);
return TableGenMain(argv[0], &RuntimeVerifierWriterMain);
}

View File

@ -46,7 +46,7 @@ limitations under the License.
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Diagnostics.h" // TF:llvm-project
@ -76,6 +76,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
@ -124,6 +125,20 @@ static opt<bool, true> experimental_prune_unreachable_nodes_unconditionally_flg(
llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally),
llvm::cl::init(false));
// NOLINTNEXTLINE
static opt<std::string> input_arrays_flag(
"input-arrays",
llvm::cl::desc(
"List of input tensors, if different from the default inputs"),
llvm::cl::init(""));
// NOLINTNEXTLINE
static opt<std::string> output_arrays_flag(
"output-arrays",
llvm::cl::desc(
"List of output tensors, if different from the default outputs"),
llvm::cl::init(""));
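// With these flags, an import invocation might look like the following (the
// tool name and the -tflite-flatbuffer-to-mlir translation flag are assumed
// here; only -input-arrays and -output-arrays are defined in this file):
//
//   flatbuffer_translate -tflite-flatbuffer-to-mlir model.tflite \
//       -input-arrays=input_a,input_b -output-arrays=logits -o model.mlir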
namespace {
bool IsScalar(const TensorT& tensor) {
// TODO(b/138222071) We can't distinguish scalars and unranked tensors
@ -532,6 +547,7 @@ bool IsCustomOp(const std::string& op_name) {
// TODO(krzysd) Handle function calls
StatusOr<Operation*> ConvertOp(
const tflite::OperatorT& op, const std::vector<Value>& vals_map,
const std::vector<mlir::TensorType>& intermediate_types,
Value optional_arg_marker, const std::vector<std::string>& op_names,
const std::vector<std::string>& func_names,
const std::vector<std::unique_ptr<tflite::TensorT>>& tensors, Location loc,
@ -590,6 +606,33 @@ StatusOr<Operation*> ConvertOp(
op_state.addTypes({type});
}
if (op_name == "tfl.lstm") {
// TODO(b/147587779): add the right region if region is empty.
op_state.addRegion();
if (!op.intermediates.empty()) {
if (op.intermediates.size() != 5) {
auto err = errors::InvalidArgument(
"operator has intermediate tensors but the number of them is not "
"five.");
return emitError(loc, err.ToString()), err;
}
// Attach the intermediate tensor types as named attributes.
const llvm::SmallVector<llvm::StringRef, 5> kIntermediateNames = {
"input_to_input_intermediate", "input_to_forget_intermediate",
"input_to_cell_intermediate", "input_to_output_intermediate",
"effective_hidden_scale_intermediate"};
for (auto type_and_name :
llvm::zip(intermediate_types, kIntermediateNames)) {
mlir::TypeAttr type_attr =
mlir::TypeAttr::get(std::get<0>(type_and_name));
auto named_attr =
builder.getNamedAttr(std::get<1>(type_and_name), type_attr);
op_state.addAttribute(named_attr.first, named_attr.second);
}
}
}
llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
if (IsCustomOp(op_name)) {
auto status = mlir::CustomOptionsToAttributes(op_name, op.custom_options,
@ -610,43 +653,30 @@ StatusOr<Operation*> ConvertOp(
return builder.createOperation(op_state);
}
// Returns the output tensor indices for the given subgraph. If
// ordered_output_arrays is provided, then return the tensor indices in
// ordered_output_arrays.
StatusOr<llvm::SmallVector<int32_t, 4>> GetOutputTensorIndices(
const tflite::SubGraphT& subgraph, Location base_loc,
const std::vector<std::string>& ordered_output_arrays) {
if (ordered_output_arrays.empty()) {
return llvm::SmallVector<int32_t, 4>(subgraph.outputs.begin(),
subgraph.outputs.end());
// Returns the indices of the given tensors in the subgraph. Returns an error
// if a tensor name cannot be found in the subgraph.
StatusOr<std::vector<int>> GetTensorIndices(
const tflite::SubGraphT& subgraph,
const std::vector<std::string>& tensor_names) {
absl::flat_hash_map<std::string, int> name_to_index;
for (auto index_and_tensor : llvm::enumerate(subgraph.tensors)) {
name_to_index[index_and_tensor.value()->name] = index_and_tensor.index();
}
llvm::SmallVector<int32_t, 4> outputs;
outputs.resize(ordered_output_arrays.size());
absl::flat_hash_map<std::string, int> output_order_map;
for (auto output : llvm::enumerate(ordered_output_arrays)) {
output_order_map[output.value()] = output.index();
}
std::vector<int> indices;
indices.reserve(tensor_names.size());
int tensor_index = 0;
int found_output_tensors = 0;
for (const auto& tensor : subgraph.tensors) {
auto found = output_order_map.find(tensor->name);
if (found != output_order_map.end()) {
const int output_index = found->second;
outputs[output_index] = tensor_index;
++found_output_tensors;
for (const auto& name : tensor_names) {
auto found = name_to_index.find(name);
if (found != name_to_index.end()) {
indices.push_back(found->second);
} else {
return errors::InvalidArgument("could not find tensor in subgraph: ",
name);
}
++tensor_index;
}
if (found_output_tensors != ordered_output_arrays.size()) {
auto err = errors::InvalidArgument(
"cannot find all nodes in ordered_output_arrays");
return emitError(base_loc, err.ToString()), err;
}
return outputs;
return indices;
}
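// A hedged usage note: GetTensorIndices(subgraph, {"input_a", "logits"}) would
// return the positions of those tensors in subgraph.tensors (e.g. {3, 17}; the
// names and indices here are illustrative), and it fails with InvalidArgument
// if any name is missing.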
// Given a list of tensor indices, returns a string of concatenated tensor names
@ -661,15 +691,18 @@ mlir::NamedAttribute BuildTFEntryFunctionAttribute(
name, builder->getStringAttr(llvm::join(tensor_names, ",")));
}
// Given a list of output indices, traverses the subgraph and returns the set of
// ops that are ancestors of the output tensors.
// Traverses the subgraph from output_indices to input_indices and returns the
// set of ops that are visited.
StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
const tflite::SubGraphT& subgraph, ArrayRef<int32_t> output_indices) {
const tflite::SubGraphT& subgraph, ArrayRef<int32_t> input_indices,
ArrayRef<int32_t> output_indices) {
// Create a map from tensor index to defining op.
absl::flat_hash_map<int32_t, const tflite::OperatorT*> defining_op;
for (const auto& op : subgraph.operators) {
for (int32_t output : op->outputs) {
defining_op[output] = op.get();
if (!llvm::is_contained(input_indices, output)) {
defining_op[output] = op.get();
}
}
}
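// A hedged sketch of how the traversal can proceed from this map (the rest of
// PruneSubgraph is not shown in this hunk): seed a worklist with the defining
// ops of output_indices and walk backwards through op inputs. Because tensors
// listed in input_indices were skipped above, the walk stops at the requested
// inputs.
//
//   std::queue<const tflite::OperatorT*> queue;
//   absl::flat_hash_set<const tflite::OperatorT*> visited;
//   for (int32_t output : output_indices) {
//     auto it = defining_op.find(output);
//     if (it != defining_op.end()) queue.push(it->second);
//   }
//   while (!queue.empty()) {
//     const tflite::OperatorT* op = queue.front();
//     queue.pop();
//     if (!visited.insert(op).second) continue;
//     for (int32_t input : op->inputs) {
//       auto it = defining_op.find(input);
//       if (it != defining_op.end()) queue.push(it->second);
//     }
//   }
//   return visited;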
@ -718,18 +751,40 @@ StatusOr<FuncOp> ConvertSubgraph(
const std::vector<std::string>& op_names,
const std::vector<std::string>& func_names,
const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
Location base_loc, Builder builder,
const std::vector<std::string>& ordered_output_arrays, bool is_entry_point,
Location base_loc, Builder builder, bool is_entry_point,
bool use_external_constant,
const std::vector<std::string>& ordered_input_arrays,
const std::vector<std::string>& ordered_output_arrays,
bool experimental_prune_unreachable_nodes_unconditionally) {
llvm::SmallVector<mlir::Type, 2> ret_types;
llvm::SmallVector<mlir::Type, 4> input_types;
auto func_loc = mlir::NameLoc::get(builder.getIdentifier(name), base_loc);
// Construct function type
for (auto input : subgraph.inputs) {
auto& tensor = *subgraph.tensors.at(input);
std::vector<int> func_inputs = subgraph.inputs;
if (is_entry_point && !ordered_input_arrays.empty()) {
if (!experimental_prune_unreachable_nodes_unconditionally) {
// TODO(b/149922113): Resolve input-arrays/pruning flags interaction.
return errors::InvalidArgument(
"input-arrays should be used with experimental pruning flag");
}
TF_ASSIGN_OR_RETURN(func_inputs,
GetTensorIndices(subgraph, ordered_input_arrays));
}
// Add state variables to inputs.
absl::flat_hash_set<int32_t> input_index_set(func_inputs.begin(),
func_inputs.end());
for (int i = 0; i < subgraph.tensors.size(); i++) {
auto& tensor = *subgraph.tensors.at(i);
if (tensor.is_variable && !input_index_set.contains(i)) {
func_inputs.emplace_back(i);
input_index_set.insert(i);
}
}
for (auto input_or_variable : func_inputs) {
auto& tensor = *subgraph.tensors.at(input_or_variable);
// TODO(b/138222071) Graph inputs must have static shape per the exporter,
// but we cannot differentiate scalars from unranked tensors.
// Here we reverse the default assumption that shape = [] means unranked.
@ -753,9 +808,11 @@ StatusOr<FuncOp> ConvertSubgraph(
}
}
TF_ASSIGN_OR_RETURN(
auto func_outputs,
GetOutputTensorIndices(subgraph, base_loc, ordered_output_arrays));
std::vector<int> func_outputs = subgraph.outputs;
if (is_entry_point && !ordered_output_arrays.empty()) {
TF_ASSIGN_OR_RETURN(func_outputs,
GetTensorIndices(subgraph, ordered_output_arrays));
}
for (auto output : func_outputs) {
bool is_constant = !is_op_output[output];
@ -782,8 +839,8 @@ StatusOr<FuncOp> ConvertSubgraph(
Value maybe_optional_arg_marker = nullptr;
// Get or construct MLIR values for each input
for (int i = 0, e = subgraph.inputs.size(); i < e; i++) {
auto input_tensor = subgraph.inputs[i];
for (int i = 0, e = func_inputs.size(); i < e; i++) {
auto input_tensor = func_inputs[i];
const auto& tensor = *subgraph.tensors.at(input_tensor);
auto loc = TensorLoc(tensor, builder, base_loc);
if (vals_map[input_tensor]) {
@ -806,9 +863,9 @@ StatusOr<FuncOp> ConvertSubgraph(
// Set tf.entry_function attribute
if (is_entry_point) {
llvm::SmallVector<mlir::NamedAttribute, 2> attributes;
if (!subgraph.inputs.empty()) {
if (!func_inputs.empty()) {
attributes.push_back(BuildTFEntryFunctionAttribute(
subgraph, &builder, "inputs", subgraph.inputs));
subgraph, &builder, "inputs", func_inputs));
}
if (!func_outputs.empty()) {
attributes.push_back(BuildTFEntryFunctionAttribute(
@ -820,7 +877,7 @@ StatusOr<FuncOp> ConvertSubgraph(
absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
if (experimental_prune_unreachable_nodes_unconditionally) {
TF_ASSIGN_OR_RETURN(pruned_subgraph_ops,
PruneSubgraph(subgraph, func_outputs));
PruneSubgraph(subgraph, func_inputs, func_outputs));
}
// Construct MLIR operators from TFLite operators
@ -859,6 +916,18 @@ StatusOr<FuncOp> ConvertSubgraph(
}
}
// Intermediate tensors for tfl.lstm are used to carry quantization range
// in their types, so we only need to extract their types.
std::vector<mlir::TensorType> intermediate_types;
intermediate_types.reserve(5);
for (auto intermediate : op->intermediates) {
TF_ASSIGN_OR_RETURN(
auto type, GetTensorType(*subgraph.tensors[intermediate], builder,
/*shapeless_are_scalars=*/true,
/*is_constant=*/true));
intermediate_types.emplace_back(type);
}
// The NameLoc corresponding to the name of the first output tensor
auto op_loc =
op->outputs.empty()
@ -868,8 +937,8 @@ StatusOr<FuncOp> ConvertSubgraph(
// to a valid Value
TF_ASSIGN_OR_RETURN(
auto* mlir_op,
ConvertOp(*op, vals_map, maybe_optional_arg_marker, op_names,
func_names, subgraph.tensors, op_loc, op_builder));
ConvertOp(*op, vals_map, intermediate_types, maybe_optional_arg_marker,
op_names, func_names, subgraph.tensors, op_loc, op_builder));
// Add the results to the value maps. There are two cases: 1. the result
// tensor does not have min/max values, the original op result is used
@ -931,8 +1000,9 @@ std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
OwningModuleRef tflite::FlatBufferToMlir(
absl::string_view buffer, MLIRContext* context, Location base_loc,
const std::vector<std::string>& ordered_output_arrays,
bool use_external_constant,
const std::vector<std::string>& ordered_input_arrays,
const std::vector<std::string>& ordered_output_arrays,
bool experimental_prune_unreachable_nodes_unconditionally) {
auto model_ptr =
FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
@ -971,33 +1041,25 @@ OwningModuleRef tflite::FlatBufferToMlir(
builder.getStringAttr(model->description));
}
if (!ordered_output_arrays.empty() && model->subgraphs.size() > 1) {
// TODO(b/141485522): support more than one subgraph.
return emitError(base_loc,
"ordered_output_arrays does not support more than one "
"subgraph yet"),
nullptr;
}
for (auto e : llvm::enumerate(model->subgraphs)) {
auto& subgraph = e.value();
std::string name = SubgraphName(e.index(), *subgraph);
auto func_or_error = ConvertSubgraph(
*subgraph, name, operator_names, func_names, model->buffers, base_loc,
// Only the entry point needs pseudo_input_ops
builder,
// TODO(b/131175224,b/132239787) Support multiple entry points
builder, ordered_output_arrays,
/*is_entry_point=*/e.index() == 0,
/*use_external_constant=*/use_external_constant,
/*use_external_constant=*/use_external_constant, ordered_input_arrays,
ordered_output_arrays,
experimental_prune_unreachable_nodes_unconditionally);
if (!func_or_error.ok()) {
return emitError(base_loc, "could not translate function ")
<< subgraph->name,
<< subgraph->name << ": "
<< func_or_error.status().error_message(),
nullptr;
}
module.push_back(func_or_error.ConsumeValueOrDie());
}
// TFLite subgraphs do not necessarily have names,
return OwningModuleRef(module);
}
@ -1012,17 +1074,24 @@ static OwningModuleRef FlatBufferFileToMlirTrans(
auto loc =
mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context);
// Parses output_arrays_order from command line option.
// Parses input/output names from command line options.
std::vector<std::string> inputs;
std::vector<std::string> outputs;
if (!tensorflow::ParseOutputArrayInfo(output_arrays_string, &outputs).ok()) {
// Use output parser since we only have tensor names.
if (!tensorflow::ParseOutputArrayInfo(input_arrays_flag, &inputs).ok()) {
return emitError(loc, "parsing input array info failed ")
<< input_arrays_flag,
nullptr;
}
if (!tensorflow::ParseOutputArrayInfo(output_arrays_flag, &outputs).ok()) {
return emitError(loc, "parsing output array info failed ")
<< output_arrays_string,
<< output_arrays_flag,
nullptr;
}
return tflite::FlatBufferToMlir(
absl::string_view(input->getBufferStart(), input->getBufferSize()),
context, loc, outputs, use_external_constant,
context, loc, use_external_constant, inputs, outputs,
experimental_prune_unreachable_nodes_unconditionally);
}

View File

@ -35,9 +35,9 @@ namespace tflite {
// are not ancestors of the output nodes will be pruned.
mlir::OwningModuleRef FlatBufferToMlir(
absl::string_view buffer, mlir::MLIRContext* context,
mlir::Location base_loc,
const std::vector<std::string>& ordered_output_arrays,
bool use_external_constant = false,
mlir::Location base_loc, bool use_external_constant = false,
const std::vector<std::string>& ordered_input_arrays = {},
const std::vector<std::string>& ordered_output_arrays = {},
bool experimental_prune_unreachable_nodes_unconditionally = false);
} // namespace tflite
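// A hedged usage sketch of the reordered signature (buffer, context, and
// base_loc are assumed to be set up by the caller):
//
//   mlir::OwningModuleRef module = tflite::FlatBufferToMlir(
//       buffer, &context, base_loc,
//       /*use_external_constant=*/false,
//       /*ordered_input_arrays=*/{"input"},
//       /*ordered_output_arrays=*/{"output"});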

View File

@ -44,6 +44,7 @@ llvm::Optional<tflite::BuiltinOperator> GetBuiltinOpCode(Operation *mlir_op);
llvm::Optional<flatbuffers::Offset<tflite::Operator>> CreateFlatBufferOperator(
Operation *mlir_op, uint32_t opcode_index,
const std::vector<int32_t> &operands, const std::vector<int32_t> &results,
const std::vector<int32_t> &intermediates,
flatbuffers::FlatBufferBuilder *fbb);
// Populates the array of mlir::NamedAttributes corresponding to the given

Some files were not shown because too many files have changed in this diff.