Merge branch 'master' into google-upstream-gpuprim
This commit is contained in:
commit
d25e90e5cf
.bazelrc.bazelversion
.github/ISSUE_TEMPLATE
00-bug-issue.md00-bug-performance-issue.md10-build-installation-issue.md20-documentation-issue.md30-feature-request.md40-tflite-op-request.md50-other-issues.md60-tflite-converter-issue.md80-performance-issue.md
.gitignore.pylintrcREADME.mdWORKSPACEconfigure.pytensorflow
BUILDapi_template.__init__.pyapi_template_v1.__init__.py
c
BUILDc_api.ccc_api_experimental.ccc_api_function.ccc_api_internal.hc_api_test.cc
eager
BUILDc_api.ccc_api_experimental.ccc_api_experimental.hc_api_internal.hc_api_remote_test.ccc_api_test.ccc_api_test_util.ccc_api_test_util.hcustom_device_test.ccdlpack.ccdlpack.hoperation_interface.ccoperation_interface.h
kernels
kernels_test.cctf_tensor.cctf_tensor.htf_tensor_internal.hcc
compat_template.__init__.pycompat_template_v1.__init__.pycompiler
aot
BUILDcodegen.cccodegen_test.cccodegen_test_h.goldencompile.ccflags.ccflags.h
tests
BUILDmake_test_graphs.pytest_error_message.lit.pbtxttest_error_message.lit.pbtxt.config.pbtxttest_error_message.lit.pbtxt.debug.pbtxttest_error_message.lit.pbtxt.fake_py.debug
tfcompile.bzltfcompile_main.ccjit
BUILDflags.cc
kernels
mark_for_compilation_pass.ccxla_activity_listener.ccxla_compilation_cache.ccxla_compilation_cache.hxla_device.ccxla_device.hxla_device_context.ccxla_device_context.hxla_kernel_creator.ccxla_kernel_creator.hxla_kernel_creator_test.ccxla_kernel_creator_util.ccxla_launch_util.hxla_tensor.hmlir
66
.bazelrc
66
.bazelrc
@ -46,7 +46,6 @@
|
||||
# sycl_asan:
|
||||
# sycl_trisycl:
|
||||
# mkl: Enable full mkl support.
|
||||
# mkl_open_source_only: Enable MKL support only using open source MKL libraries.
|
||||
# tensorrt: Enable Tensorrt support.
|
||||
# ngraph: Enable ngraph support.
|
||||
# numa: Enable numa using hwloc.
|
||||
@ -137,15 +136,9 @@ build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain
|
||||
# environment variable "TF_MKL_ROOT" every time before build.
|
||||
build:mkl --define=build_with_mkl=true --define=enable_mkl=true
|
||||
build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
|
||||
build:mkl --define=build_with_mkl_dnn_v1_only=true
|
||||
build:mkl -c opt
|
||||
|
||||
# This config option is used to enable MKL-DNN open source library only,
|
||||
# without depending on MKL binary version.
|
||||
build:mkl_open_source_only --define=build_with_mkl_dnn_only=true
|
||||
build:mkl_open_source_only --define=build_with_mkl_dnn_v1_only=true
|
||||
build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true
|
||||
build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=0
|
||||
|
||||
# This config refers to building with CUDA available. It does not necessarily
|
||||
# mean that we build CUDA op kernels.
|
||||
build:using_cuda --define=using_cuda=true
|
||||
@ -222,6 +215,11 @@ build --define=grpc_no_ares=true
|
||||
# archives in -whole_archive -no_whole_archive.
|
||||
build --noincompatible_remove_legacy_whole_archive
|
||||
|
||||
# These are bazel 2.0's incompatible flags. Tensorflow needs to use bazel 2.0.0
|
||||
# to use cc_shared_library, as part of the Tensorflow Build Improvements RFC:
|
||||
# https://github.com/tensorflow/community/pull/179
|
||||
build --noincompatible_prohibit_aapt1
|
||||
|
||||
# Modular TF build options
|
||||
build:dynamic_kernels --define=dynamic_loaded_kernels=true
|
||||
build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS
|
||||
@ -242,6 +240,7 @@ build:windows --copt=/w
|
||||
# Tensorflow uses M_* math constants that only get defined by MSVC headers if
|
||||
# _USE_MATH_DEFINES is defined.
|
||||
build:windows --copt=/D_USE_MATH_DEFINES
|
||||
build:windows --host_copt=/D_USE_MATH_DEFINES
|
||||
|
||||
# Default paths for TF_SYSTEM_LIBS
|
||||
build:linux --define=PREFIX=/usr
|
||||
@ -314,22 +313,26 @@ build:xla --define=with_xla_support=true
|
||||
# BEGIN TF REMOTE BUILD EXECUTION OPTIONS
|
||||
# Options when using remote execution
|
||||
# WARNING: THESE OPTIONS WONT WORK IF YOU DO NOT HAVE PROPER AUTHENTICATION AND PERMISSIONS
|
||||
|
||||
# Flag to enable remote config
|
||||
common --experimental_repo_remote_exec
|
||||
|
||||
build:rbe --action_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1
|
||||
build:rbe --auth_enabled=true
|
||||
build:rbe --auth_scope=https://www.googleapis.com/auth/cloud-source-tools
|
||||
build:rbe --google_default_credentials
|
||||
build:rbe --bes_backend=buildeventservice.googleapis.com
|
||||
build:rbe --bes_best_effort=false
|
||||
build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations"
|
||||
build:rbe --bes_timeout=600s
|
||||
build:rbe --define=EXECUTOR=remote
|
||||
build:rbe --distinct_host_configuration=false
|
||||
build:rbe --flaky_test_attempts=3
|
||||
build:rbe --jobs=200
|
||||
build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com
|
||||
build:rbe --remote_timeout=3600
|
||||
build:rbe --spawn_strategy=remote,worker,standalone,local
|
||||
test:rbe --test_env=USER=anon
|
||||
|
||||
build:rbe --distinct_host_configuration=false
|
||||
# Attempt to minimize the amount of data transfer between bazel and the remote
|
||||
# workers:
|
||||
build:rbe --remote_download_toplevel
|
||||
|
||||
build:rbe_linux --config=rbe
|
||||
build:rbe_linux --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin"
|
||||
@ -354,13 +357,14 @@ build:rbe_cpu_linux --host_platform="@org_tensorflow//third_party/toolchains:rbe
|
||||
build:rbe_cpu_linux --platforms="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
|
||||
|
||||
build:rbe_linux_cuda_nvcc --config=rbe_linux
|
||||
build:rbe_linux_cuda_nvcc --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain"
|
||||
build:rbe_linux_cuda_nvcc --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010,@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010-gpu"
|
||||
build:rbe_linux_cuda_nvcc --host_platform="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
|
||||
build:rbe_linux_cuda_nvcc --platforms="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
|
||||
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/cuda10.1-cudnn7"
|
||||
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/tensorrt6.0"
|
||||
build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc --host_platform="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc --platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
|
||||
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
|
||||
build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
|
||||
build:rbe_linux_cuda_nvcc --repo_env=TF_NEED_TENSORRT=1
|
||||
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_VERSION=10
|
||||
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDNN_VERSION=7
|
||||
@ -377,18 +381,17 @@ build:rbe_linux_py2 --python_path="/usr/bin/python2"
|
||||
build:rbe_linux_py2 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py"
|
||||
|
||||
build:rbe_linux_py3 --config=rbe_linux
|
||||
build:rbe_linux_py3 --repo_env=PYTHON_BIN_PATH="/usr/bin/python3"
|
||||
build:rbe_linux_py3 --python_path="/usr/bin/python3"
|
||||
build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py3"
|
||||
build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-manylinux2010-py3_config_python"
|
||||
|
||||
build:rbe_win --config=rbe
|
||||
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_121:toolchain"
|
||||
build:rbe_win --extra_execution_platforms="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
|
||||
build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_121:cc-toolchain-x64_windows"
|
||||
build:rbe_win --host_javabase="@org_tensorflow//third_party/toolchains/preconfig/win_1803:windows_jdk8"
|
||||
build:rbe_win --host_platform="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
|
||||
build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win_1803:windows_jdk8"
|
||||
build:rbe_win --platforms="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
|
||||
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win/bazel_211:toolchain"
|
||||
build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win/bazel_211:cc-toolchain-x64_windows"
|
||||
build:rbe_win --host_javabase="@org_tensorflow//third_party/toolchains/preconfig/win:windows_jdk8"
|
||||
build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win:windows_jdk8"
|
||||
build:rbe_win --extra_execution_platforms="@org_tensorflow//third_party/toolchains/preconfig/win:rbe_windows_ltsc2019"
|
||||
build:rbe_win --host_platform="@org_tensorflow//third_party/toolchains/preconfig/win:rbe_windows_ltsc2019"
|
||||
build:rbe_win --platforms="@org_tensorflow//third_party/toolchains/preconfig/win:rbe_windows_ltsc2019"
|
||||
build:rbe_win --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe
|
||||
|
||||
# TODO(gunan): Remove once we use MSVC 2019 with latest patches.
|
||||
@ -396,9 +399,7 @@ build:rbe_win --define=override_eigen_strong_inline=true
|
||||
build:rbe_win --jobs=500
|
||||
|
||||
build:rbe_win_py37 --config=rbe
|
||||
build:rbe_win_py37 --repo_env=PYTHON_BIN_PATH=C:\\Python37\\python.exe
|
||||
build:rbe_win_py37 --repo_env=PYTHON_LIB_PATH=C:\\Python37\\lib\\site-packages
|
||||
build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py37
|
||||
build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO="@windows_py37_config_python"
|
||||
build:rbe_win_py37 --python_path=C:\\Python37\\python.exe
|
||||
|
||||
build:rbe_win_py38 --config=rbe
|
||||
@ -416,7 +417,6 @@ build:tensorflow_testing_rbe_linux --config=rbe_linux
|
||||
|
||||
common:tensorflow_testing_rbe_win --remote_instance_name=projects/tensorflow-testing/instances/windows
|
||||
build:tensorflow_testing_rbe_win --config=tensorflow_testing_rbe
|
||||
build:tensorflow_testing_rbe_win --config=rbe_win
|
||||
# END TF REMOTE BUILD EXECUTION OPTIONS
|
||||
|
||||
# Default options should come above this line
|
||||
|
@ -1 +1 @@
|
||||
1.2.1
|
||||
2.0.0
|
||||
|
44
.github/ISSUE_TEMPLATE/00-bug-issue.md
vendored
Normal file
44
.github/ISSUE_TEMPLATE/00-bug-issue.md
vendored
Normal file
@ -0,0 +1,44 @@
|
||||
---
|
||||
name: Bug Issue
|
||||
about: Use this template for reporting a bug
|
||||
labels: 'type:bug'
|
||||
|
||||
---
|
||||
|
||||
<em>Please make sure that this is a bug. As per our
|
||||
[GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md),
|
||||
we only address code/doc bugs, performance issues, feature requests and
|
||||
build/installation issues on GitHub. tag:bug_template</em>
|
||||
|
||||
**System information**
|
||||
- Have I written custom code (as opposed to using a stock
|
||||
example script provided in TensorFlow):
|
||||
- OS Platform and Distribution (e.g.,
|
||||
Linux Ubuntu 16.04):
|
||||
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
|
||||
the issue happens on mobile device:
|
||||
- TensorFlow installed from (source or
|
||||
binary): - TensorFlow version (use command below):
|
||||
- Python version: - Bazel
|
||||
version (if compiling from source):
|
||||
- GCC/Compiler version (if compiling from
|
||||
source):
|
||||
- CUDA/cuDNN version: - GPU model and memory:
|
||||
|
||||
You can collect some of this information using our environment capture
|
||||
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
|
||||
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
|
||||
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
|
||||
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
|
||||
|
||||
**Describe the current behavior**
|
||||
|
||||
**Describe the expected behavior**
|
||||
|
||||
**Standalone code to reproduce the issue**
|
||||
Provide a reproducible test case that is the bare minimum necessary to generate
|
||||
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
|
||||
|
||||
**Other info / logs** Include any logs or source code that would be helpful to
|
||||
diagnose the problem. If including tracebacks, please include the full
|
||||
traceback. Large logs and files should be attached.
|
@ -1,35 +0,0 @@
|
||||
---
|
||||
name: Bug/Performance Issue
|
||||
about: Use this template for reporting a bug or a performance issue.
|
||||
|
||||
---
|
||||
|
||||
<em>Please make sure that this is a bug. As per our [GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md), we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:bug_template</em>
|
||||
|
||||
**System information**
|
||||
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
|
||||
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
|
||||
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
|
||||
- TensorFlow installed from (source or binary):
|
||||
- TensorFlow version (use command below):
|
||||
- Python version:
|
||||
- Bazel version (if compiling from source):
|
||||
- GCC/Compiler version (if compiling from source):
|
||||
- CUDA/cuDNN version:
|
||||
- GPU model and memory:
|
||||
|
||||
You can collect some of this information using our environment capture
|
||||
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
|
||||
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
|
||||
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
|
||||
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
|
||||
|
||||
**Describe the current behavior**
|
||||
|
||||
**Describe the expected behavior**
|
||||
|
||||
**Code to reproduce the issue**
|
||||
Provide a reproducible test case that is the bare minimum necessary to generate the problem.
|
||||
|
||||
**Other info / logs**
|
||||
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
|
@ -1,7 +1,7 @@
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
name: Build/Installation Issue about: Use this template for build/installation
|
||||
issues labels: 'type:build/install'
|
||||
---
|
||||
name: Build/Installation Issue
|
||||
about: Use this template for build/installation issues
|
||||
labels: 'type:build/install'
|
||||
|
||||
---
|
||||
|
||||
|
@ -1,10 +1,11 @@
|
||||
---
|
||||
name: Documentation Issue
|
||||
about: Use this template for documentation related
|
||||
about: Use this template for documentation related issues
|
||||
labels: 'type:docs'
|
||||
|
||||
---
|
||||
|
||||
|
||||
Thank you for submitting a TensorFlow documentation issue. Per our GitHub
|
||||
policy, we only address code/doc bugs, performance issues, feature requests, and
|
||||
build/installation issues on GitHub.
|
||||
|
1
.github/ISSUE_TEMPLATE/30-feature-request.md
vendored
1
.github/ISSUE_TEMPLATE/30-feature-request.md
vendored
@ -1,6 +1,7 @@
|
||||
---
|
||||
name: Feature Request
|
||||
about: Use this template for raising a feature request
|
||||
labels: 'type:feature'
|
||||
|
||||
---
|
||||
|
||||
|
17
.github/ISSUE_TEMPLATE/40-tflite-op-request.md
vendored
17
.github/ISSUE_TEMPLATE/40-tflite-op-request.md
vendored
@ -1,11 +1,10 @@
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
name: TensorFlow Lite Op Request about: Use this template for reporting ops you
|
||||
are using or missing. labels: 'comp:lite'
|
||||
---
|
||||
name: TensorFlow Lite Op Request
|
||||
about: Use this template for reporting Lite ops you are using or missing
|
||||
labels: 'comp:lite'
|
||||
|
||||
---
|
||||
|
||||
|
||||
**System information**
|
||||
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
|
||||
- TensorFlow installed from (source or binary):
|
||||
@ -18,8 +17,14 @@ are using or missing. labels: 'comp:lite'
|
||||
# Copy and paste here
|
||||
```
|
||||
|
||||
**Standalone code to reproduce the issue**
|
||||
Provide a reproducible test case that is the bare minimum necessary to generate
|
||||
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
|
||||
|
||||
Also, please include a link to a GraphDef or the model if possible.
|
||||
|
||||
**Any other info / logs**
|
||||
|
||||
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
|
||||
Include any logs or source code that would be helpful to diagnose the problem.
|
||||
If including tracebacks, please include the full traceback. Large logs and files
|
||||
should be attached.
|
||||
|
8
.github/ISSUE_TEMPLATE/50-other-issues.md
vendored
8
.github/ISSUE_TEMPLATE/50-other-issues.md
vendored
@ -1,7 +1,7 @@
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
name: Other Issues about: Use this template for any other non-support related
|
||||
issues labels: 'type:others'
|
||||
---
|
||||
name: Other Issues
|
||||
about: Use this template for any other non-support related issues
|
||||
labels: 'type:others'
|
||||
|
||||
---
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
---
|
||||
name: TensorFlow Lite New Converter Issue
|
||||
about: Use this template for reporting issues during model conversion to TFLite.
|
||||
about: Use this template for reporting issues during model conversion to TFLite
|
||||
labels: 'TFLiteConverter'
|
||||
|
||||
---
|
||||
|
||||
@ -12,6 +13,7 @@ about: Use this template for reporting issues during model conversion to TFLite.
|
||||
|
||||
|
||||
**Command used to run the converter or code if you’re using the Python API**
|
||||
If possible, please share a link to Colab/Jupyter/any notebook.
|
||||
|
||||
```
|
||||
# Copy and paste here the exact command
|
||||
|
45
.github/ISSUE_TEMPLATE/80-performance-issue.md
vendored
Normal file
45
.github/ISSUE_TEMPLATE/80-performance-issue.md
vendored
Normal file
@ -0,0 +1,45 @@
|
||||
---
|
||||
name: Performance Issue
|
||||
about: Use this template for reporting a performance issue
|
||||
labels: 'type:performance'
|
||||
|
||||
---
|
||||
|
||||
<em>Please make sure that this is an issue related to performance of TensorFlow.
|
||||
As per our
|
||||
[GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md),
|
||||
we only address code/doc bugs, performance issues, feature requests and
|
||||
build/installation issues on GitHub. tag:performance_template</em>
|
||||
|
||||
**System information**
|
||||
- Have I written custom code (as opposed to using a stock
|
||||
example script provided in TensorFlow):
|
||||
- OS Platform and Distribution (e.g.,
|
||||
Linux Ubuntu 16.04):
|
||||
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
|
||||
the issue happens on mobile device:
|
||||
- TensorFlow installed from (source or
|
||||
binary): - TensorFlow version (use command below):
|
||||
- Python version: - Bazel
|
||||
version (if compiling from source):
|
||||
- GCC/Compiler version (if compiling from
|
||||
source):
|
||||
- CUDA/cuDNN version: - GPU model and memory:
|
||||
|
||||
You can collect some of this information using our environment capture
|
||||
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
|
||||
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
|
||||
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
|
||||
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
|
||||
|
||||
**Describe the current behavior**
|
||||
|
||||
**Describe the expected behavior**
|
||||
|
||||
**Standalone code to reproduce the issue**
|
||||
Provide a reproducible test case that is the bare minimum necessary to generate
|
||||
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
|
||||
|
||||
**Other info / logs** Include any logs or source code that would be helpful to
|
||||
diagnose the problem. If including tracebacks, please include the full
|
||||
traceback. Large logs and files should be attached.
|
5
.gitignore
vendored
5
.gitignore
vendored
@ -22,6 +22,7 @@ tensorflow/contrib/cmake/_build/
|
||||
/tensorflow/python/framework/fast_tensor_util.cpp
|
||||
/tensorflow/lite/gen/**
|
||||
/tensorflow/lite/tools/make/downloads/**
|
||||
/tensorflow/lite/tools/make/gen/**
|
||||
/api_init_files_list.txt
|
||||
/estimator_api_init_files_list.txt
|
||||
*.whl
|
||||
@ -37,7 +38,9 @@ gradleBuild
|
||||
*.pbxproj
|
||||
*.xcworkspace
|
||||
/*.podspec
|
||||
/tensorflow/lite/**/[ios|objc|swift]*/BUILD
|
||||
/tensorflow/lite/**/ios/BUILD
|
||||
/tensorflow/lite/**/objc/BUILD
|
||||
/tensorflow/lite/**/swift/BUILD
|
||||
/tensorflow/lite/examples/ios/simple/data/*.tflite
|
||||
/tensorflow/lite/examples/ios/simple/data/*.txt
|
||||
Podfile.lock
|
||||
|
18
README.md
18
README.md
@ -70,7 +70,7 @@ $ python
|
||||
3
|
||||
>>> hello = tf.constant('Hello, TensorFlow!')
|
||||
>>> hello.numpy()
|
||||
'Hello, TensorFlow!'
|
||||
b'Hello, TensorFlow!'
|
||||
```
|
||||
|
||||
For more examples, see the
|
||||
@ -130,18 +130,20 @@ Build Type | Status
|
||||
## Resources
|
||||
|
||||
* [TensorFlow.org](https://www.tensorflow.org)
|
||||
* [TensorFlow tutorials](https://www.tensorflow.org/tutorials/)
|
||||
* [TensorFlow official models](https://github.com/tensorflow/models/tree/master/official)
|
||||
* [TensorFlow examples](https://github.com/tensorflow/examples)
|
||||
* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/)
|
||||
* [TensorFlow Official Models](https://github.com/tensorflow/models/tree/master/official)
|
||||
* [TensorFlow Examples](https://github.com/tensorflow/examples)
|
||||
* [TensorFlow in Practice from Coursera](https://www.coursera.org/specializations/tensorflow-in-practice)
|
||||
* [TensorFlow: Data and Deployment from Coursera](https://www.coursera.org/specializations/tensorflow-data-and-deployment)
|
||||
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
|
||||
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
|
||||
* [TensorFlow blog](https://blog.tensorflow.org)
|
||||
* [TensorFlow Blog](https://blog.tensorflow.org)
|
||||
* [Learn ML with TensorFlow](https://www.tensorflow.org/resources/learn-ml)
|
||||
* [TensorFlow Twitter](https://twitter.com/tensorflow)
|
||||
* [TensorFlow YouTube](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
|
||||
* [TensorFlow roadmap](https://www.tensorflow.org/community/roadmap)
|
||||
* [TensorFlow white papers](https://www.tensorflow.org/about/bib)
|
||||
* [TensorBoard visualization toolkit](https://github.com/tensorflow/tensorboard)
|
||||
* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap)
|
||||
* [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
|
||||
* [TensorBoard Visualization Toolkit](https://github.com/tensorflow/tensorboard)
|
||||
|
||||
Learn more about the
|
||||
[TensorFlow community](https://www.tensorflow.org/community) and how to
|
||||
|
25
WORKSPACE
25
WORKSPACE
@ -113,3 +113,28 @@ http_archive(
|
||||
"https://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip",
|
||||
],
|
||||
)
|
||||
|
||||
# Required for dependency @com_github_grpc_grpc
|
||||
|
||||
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
|
||||
|
||||
grpc_deps()
|
||||
|
||||
load(
|
||||
"@build_bazel_rules_apple//apple:repositories.bzl",
|
||||
"apple_rules_dependencies",
|
||||
)
|
||||
|
||||
apple_rules_dependencies()
|
||||
|
||||
load(
|
||||
"@build_bazel_apple_support//lib:repositories.bzl",
|
||||
"apple_support_dependencies",
|
||||
)
|
||||
|
||||
apple_support_dependencies()
|
||||
|
||||
load("@upb//bazel:repository_defs.bzl", "bazel_version_repository")
|
||||
|
||||
bazel_version_repository(name = "bazel_version")
|
||||
|
||||
|
@ -49,8 +49,8 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
|
||||
_TF_WORKSPACE_ROOT = ''
|
||||
_TF_BAZELRC = ''
|
||||
_TF_CURRENT_BAZEL_VERSION = None
|
||||
_TF_MIN_BAZEL_VERSION = '1.2.1'
|
||||
_TF_MAX_BAZEL_VERSION = '1.2.1'
|
||||
_TF_MIN_BAZEL_VERSION = '2.0.0'
|
||||
_TF_MAX_BAZEL_VERSION = '2.0.0'
|
||||
|
||||
NCCL_LIB_PATHS = [
|
||||
'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
|
||||
@ -1390,9 +1390,8 @@ def main():
|
||||
else:
|
||||
environ_cp['TF_CONFIGURE_IOS'] = '0'
|
||||
|
||||
xla_enabled_by_default = is_linux() or is_macos()
|
||||
set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
|
||||
xla_enabled_by_default, 'xla')
|
||||
if environ_cp.get('TF_ENABLE_XLA', '1') == '1':
|
||||
write_to_bazelrc('build --config=xla')
|
||||
|
||||
set_action_env_var(
|
||||
environ_cp,
|
||||
|
@ -187,6 +187,12 @@ config_setting(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "fuchsia",
|
||||
values = {"cpu": "fuchsia"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "ios_x86_64",
|
||||
values = {
|
||||
@ -448,19 +454,66 @@ config_setting(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# Specifies via a config setting if this is a mobile build or not, makes
|
||||
# it easier to combine settings later.
|
||||
selects.config_setting_group(
|
||||
name = "mobile",
|
||||
match_any = [
|
||||
":android",
|
||||
":chromiumos",
|
||||
":emscripten",
|
||||
":ios",
|
||||
],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "lite_protos_legacy",
|
||||
values = {"define": "TENSORFLOW_PROTOS=lite"},
|
||||
visibility = ["//visibility:private"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "full_protos",
|
||||
values = {"define": "TENSORFLOW_PROTOS=full"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
selects.config_setting_group(
|
||||
name = "lite_protos",
|
||||
match_any = [":lite_protos_legacy"],
|
||||
)
|
||||
|
||||
selects.config_setting_group(
|
||||
name = "mobile_lite_protos",
|
||||
match_all = [
|
||||
":lite_protos",
|
||||
":mobile",
|
||||
],
|
||||
)
|
||||
|
||||
selects.config_setting_group(
|
||||
name = "mobile_full_protos",
|
||||
match_all = [
|
||||
":full_protos",
|
||||
":mobile",
|
||||
],
|
||||
)
|
||||
|
||||
# DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST!
|
||||
# Instead, please use public APIs or public build rules TF provides.
|
||||
# If you need functionality that is not exposed, we will work with you to expand our public APIs.
|
||||
package_group(
|
||||
name = "internal",
|
||||
packages = [
|
||||
# To pass open source testing in the pip Kokoros.
|
||||
"//bazel_pip/tensorflow/...",
|
||||
"//learning/brain/swift/x10/...",
|
||||
"//perftools/accelerators/xprof/api/...",
|
||||
"//third_party/py/autograph/...",
|
||||
"//third_party/swift/tensorflow/x10/...",
|
||||
"//tensorflow/...",
|
||||
"//tensorflow_estimator/python/estimator/...",
|
||||
"//tensorflow_models/official/...",
|
||||
"//third_party/py/autograph/...",
|
||||
"//third_party/swift/tensorflow/x10/...",
|
||||
],
|
||||
)
|
||||
|
||||
@ -494,8 +547,8 @@ cc_library(
|
||||
name = "grpc",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
":linux_s390x": ["@grpc//:grpc_unsecure"],
|
||||
"//conditions:default": ["@grpc"],
|
||||
":linux_s390x": ["@com_github_grpc_grpc//:grpc_unsecure"],
|
||||
"//conditions:default": ["@com_github_grpc_grpc//:grpc"],
|
||||
}),
|
||||
)
|
||||
|
||||
@ -503,8 +556,8 @@ cc_library(
|
||||
name = "grpc++",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
":linux_s390x": ["@grpc//:grpc++_unsecure"],
|
||||
"//conditions:default": ["@grpc//:grpc++"],
|
||||
":linux_s390x": ["@com_github_grpc_grpc//:grpc++_unsecure"],
|
||||
"//conditions:default": ["@com_github_grpc_grpc//:grpc++"],
|
||||
}),
|
||||
)
|
||||
|
||||
@ -909,7 +962,6 @@ py_library(
|
||||
"//conditions:default": [":tf_python_api_gen_v1"],
|
||||
}) + [
|
||||
":root_init_gen",
|
||||
":virtual_root_init_gen",
|
||||
"//tensorflow/python/keras/api:keras_python_api_gen",
|
||||
"//tensorflow/python/keras/api:keras_python_api_gen_compat_v1",
|
||||
"//tensorflow/python/keras/api:keras_python_api_gen_compat_v2",
|
||||
|
@ -35,9 +35,11 @@ import inspect as _inspect
|
||||
import logging as _logging
|
||||
import os as _os
|
||||
import site as _site
|
||||
import six as _six
|
||||
import sys as _sys
|
||||
|
||||
from tensorflow.python.tools import module_util as _module_util
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
|
||||
|
||||
# API IMPORTS PLACEHOLDER
|
||||
|
||||
@ -69,13 +71,13 @@ except ImportError:
|
||||
_logging.warning(
|
||||
"Limited tf.summary API due to missing TensorBoard installation.")
|
||||
|
||||
try:
|
||||
from tensorflow_estimator.python.estimator.api._v2 import estimator
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
|
||||
setattr(_current_module, "estimator", estimator)
|
||||
except ImportError:
|
||||
pass
|
||||
# Lazy-load estimator.
|
||||
_estimator_module = "tensorflow_estimator.python.estimator.api._v2.estimator"
|
||||
estimator = _LazyLoader("estimator", globals(), _estimator_module)
|
||||
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
|
||||
if _module_dir:
|
||||
_current_module.__path__ = [_module_dir] + _current_module.__path__
|
||||
setattr(_current_module, "estimator", estimator)
|
||||
|
||||
try:
|
||||
from .python.keras.api._v2 import keras
|
||||
@ -85,6 +87,13 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Explicitly import lazy-loaded modules to support autocompletion.
|
||||
# pylint: disable=g-import-not-at-top
|
||||
if not _six.PY2:
|
||||
import typing as _typing
|
||||
if _typing.TYPE_CHECKING:
|
||||
from tensorflow_estimator.python.estimator.api._v2 import estimator
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
# Enable TF2 behaviors
|
||||
from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top
|
||||
|
@ -22,12 +22,14 @@ import distutils as _distutils
|
||||
import inspect as _inspect
|
||||
import os as _os
|
||||
import site as _site
|
||||
import six as _six
|
||||
import sys as _sys
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
|
||||
from tensorflow.python.tools import module_util as _module_util
|
||||
from tensorflow.python.platform import tf_logging as _logging
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
|
||||
|
||||
# API IMPORTS PLACEHOLDER
|
||||
|
||||
@ -64,13 +66,14 @@ elif _tf_api_dir not in __path__:
|
||||
# reexport_tf_summary can get compat from sys.modules. Only needed if using
|
||||
# lazy loading.
|
||||
_current_module.compat.v2 # pylint: disable=pointless-statement
|
||||
try:
|
||||
from tensorflow_estimator.python.estimator.api._v1 import estimator
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
|
||||
setattr(_current_module, "estimator", estimator)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Lazy-load estimator.
|
||||
_estimator_module = "tensorflow_estimator.python.estimator.api._v1.estimator"
|
||||
estimator = _LazyLoader("estimator", globals(), _estimator_module)
|
||||
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
|
||||
if _module_dir:
|
||||
_current_module.__path__ = [_module_dir] + _current_module.__path__
|
||||
setattr(_current_module, "estimator", estimator)
|
||||
|
||||
try:
|
||||
from .python.keras.api._v1 import keras
|
||||
@ -80,6 +83,13 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Explicitly import lazy-loaded modules to support autocompletion.
|
||||
# pylint: disable=g-import-not-at-top
|
||||
if not _six.PY2:
|
||||
import typing as _typing
|
||||
if _typing.TYPE_CHECKING:
|
||||
from tensorflow_estimator.python.estimator.api._v1 import estimator
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top
|
||||
_CONTRIB_WARNING = """
|
||||
|
@ -154,7 +154,10 @@ tf_cuda_library(
|
||||
"c_api.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//tensorflow/c:__subpackages__"],
|
||||
visibility = [
|
||||
"//tensorflow/c:__subpackages__",
|
||||
"//third_party/llvm/llvm-project:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
":c_api_internal",
|
||||
":tf_attrtype",
|
||||
@ -242,7 +245,7 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:android_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
@ -536,6 +539,7 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/core/kernels:array",
|
||||
"//tensorflow/core/kernels:control_flow_ops",
|
||||
"//tensorflow/core/kernels:math",
|
||||
"//tensorflow/core/platform:resource_loader",
|
||||
],
|
||||
)
|
||||
|
||||
@ -697,4 +701,5 @@ tf_cuda_library(
|
||||
# TODO(b/74620627): remove when _USE_C_SHAPES is removed
|
||||
"//tensorflow/python:cpp_shape_inference_proto_cc",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -774,7 +774,7 @@ extern "C" {
|
||||
static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
|
||||
const char* op_type,
|
||||
const char* oper_name)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
|
||||
return new TF_OperationDescription(graph, op_type, oper_name);
|
||||
}
|
||||
|
||||
@ -1032,7 +1032,7 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
|
||||
|
||||
static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
|
||||
TF_Status* status)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
|
||||
Node* ret = nullptr;
|
||||
|
||||
if (desc->graph->name_map.count(desc->node_builder.node_name())) {
|
||||
@ -1706,7 +1706,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
|
||||
const TF_ImportGraphDefOptions* opts,
|
||||
TF_ImportGraphDefResults* tf_results,
|
||||
TF_Status* status)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
|
||||
const int last_node_id = graph->graph.num_node_ids();
|
||||
tensorflow::ImportGraphDefResults results;
|
||||
status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
|
||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/net.h"
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
@ -816,12 +817,15 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
|
||||
|
||||
const int num_inputs = input_shapes->num_items;
|
||||
NodeDef node_def;
|
||||
node_def.set_name(tfe_op->operation.Name());
|
||||
node_def.set_op(tfe_op->operation.Name());
|
||||
node_def.set_name(tfe_op->operation->Name());
|
||||
node_def.set_op(tfe_op->operation->Name());
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
node_def.add_input("dummy_input");
|
||||
}
|
||||
tfe_op->operation.Attrs().FillAttrValueMap(node_def.mutable_attr());
|
||||
tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
tfe_op->operation.get())
|
||||
->Attrs()
|
||||
.FillAttrValueMap(node_def.mutable_attr());
|
||||
|
||||
const tensorflow::OpRegistrationData* op_reg_data;
|
||||
status->status =
|
||||
|
@ -51,7 +51,7 @@ Status ProcessInputs(
|
||||
const TF_Graph* fn_body, const char* fn_name, int ninputs,
|
||||
const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
|
||||
std::unordered_map<const Node*, std::vector<int>>* input_nodes)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
|
||||
input_tensors->reserve(ninputs);
|
||||
for (int i = 0; i < ninputs; ++i) {
|
||||
Node* node = &inputs[i].oper->node;
|
||||
@ -87,7 +87,7 @@ Status ProcessInputs(
|
||||
Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
|
||||
int noutputs, const TF_Output* outputs,
|
||||
std::vector<OutputTensor>* output_tensors)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
|
||||
output_tensors->reserve(noutputs);
|
||||
for (int i = 0; i < noutputs; ++i) {
|
||||
Node* node = &outputs[i].oper->node;
|
||||
@ -111,7 +111,7 @@ Status ComputeBodyNodes(
|
||||
const TF_Operation* const* opers,
|
||||
const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
|
||||
std::vector<const Node*>* body_nodes)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
|
||||
if (num_opers == -1) {
|
||||
for (const Node* node : fn_body->graph.op_nodes()) {
|
||||
const auto& iter = input_nodes.find(node);
|
||||
|
@ -71,14 +71,14 @@ struct TF_Graph {
|
||||
TF_Graph();
|
||||
|
||||
tensorflow::mutex mu;
|
||||
tensorflow::Graph graph GUARDED_BY(mu);
|
||||
tensorflow::Graph graph TF_GUARDED_BY(mu);
|
||||
|
||||
// Runs shape inference.
|
||||
tensorflow::ShapeRefiner refiner GUARDED_BY(mu);
|
||||
tensorflow::ShapeRefiner refiner TF_GUARDED_BY(mu);
|
||||
|
||||
// Maps from name of an operation to the Node* in 'graph'.
|
||||
std::unordered_map<tensorflow::string, tensorflow::Node*> name_map
|
||||
GUARDED_BY(mu);
|
||||
TF_GUARDED_BY(mu);
|
||||
|
||||
// The keys of this map are all the active sessions using this graph. Each
|
||||
// value records whether the graph has been mutated since the corresponding
|
||||
@ -94,8 +94,8 @@ struct TF_Graph {
|
||||
// TODO(b/74949947): mutations currently trigger a warning instead of a bad
|
||||
// status, this should be reverted when possible.
|
||||
tensorflow::gtl::FlatMap<TF_Session*, tensorflow::string> sessions
|
||||
GUARDED_BY(mu);
|
||||
bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph
|
||||
TF_GUARDED_BY(mu);
|
||||
bool delete_requested TF_GUARDED_BY(mu); // set true by TF_DeleteGraph
|
||||
|
||||
// Used to link graphs contained in TF_WhileParams to the parent graph that
|
||||
// will eventually contain the full while loop.
|
||||
@ -123,7 +123,7 @@ struct TF_Session {
|
||||
tensorflow::Session* session;
|
||||
TF_Graph* const graph;
|
||||
|
||||
tensorflow::mutex mu ACQUIRED_AFTER(TF_Graph::mu);
|
||||
tensorflow::mutex mu TF_ACQUIRED_AFTER(TF_Graph::mu);
|
||||
int last_num_graph_nodes;
|
||||
|
||||
// If true, TF_SessionRun and similar methods will call
|
||||
@ -169,9 +169,9 @@ struct TF_ApiDefMap {
|
||||
}
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
||||
tensorflow::ApiDefMap api_def_map GUARDED_BY(lock);
|
||||
tensorflow::ApiDefMap api_def_map TF_GUARDED_BY(lock);
|
||||
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
||||
bool update_docs_called GUARDED_BY(lock);
|
||||
bool update_docs_called TF_GUARDED_BY(lock);
|
||||
tensorflow::mutex lock;
|
||||
};
|
||||
|
||||
@ -210,10 +210,10 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
|
||||
|
||||
void RecordMutation(TF_Graph* graph, const TF_Operation& op,
|
||||
const char* mutation_type)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(graph->mu);
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu);
|
||||
|
||||
bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status)
|
||||
LOCKS_EXCLUDED(session->graph->mu, session->mu);
|
||||
TF_LOCKS_EXCLUDED(session->graph->mu, session->mu);
|
||||
|
||||
std::string getTF_OutputDebugString(TF_Output node);
|
||||
|
||||
|
@ -45,6 +45,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/path.h"
|
||||
#include "tensorflow/core/platform/resource_loader.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||
@ -193,8 +195,9 @@ TEST(CAPI, LibraryLoadFunctions) {
|
||||
{
|
||||
// Load the library.
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_Library* lib =
|
||||
TF_LoadLibrary("tensorflow/c/test_op1.so", status);
|
||||
string lib_path = tensorflow::GetDataDependencyFilepath(
|
||||
tensorflow::io::JoinPath("tensorflow", "c", "test_op1.so"));
|
||||
TF_Library* lib = TF_LoadLibrary(lib_path.c_str(), status);
|
||||
TF_Code code = TF_GetCode(status);
|
||||
string status_msg(TF_Message(status));
|
||||
TF_DeleteStatus(status);
|
||||
@ -1350,9 +1353,9 @@ TEST_F(CApiColocationTest, ClearViaProto) {
|
||||
|
||||
TEST(CAPI, SavedModel) {
|
||||
// Load the saved model.
|
||||
const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123";
|
||||
const string saved_model_dir = tensorflow::io::JoinPath(
|
||||
tensorflow::testing::TensorFlowSrcRoot(), kSavedModel);
|
||||
const string saved_model_dir = tensorflow::GetDataDependencyFilepath(
|
||||
tensorflow::io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
|
||||
"half_plus_two", "00000123"));
|
||||
TF_SessionOptions* opt = TF_NewSessionOptions();
|
||||
TF_Buffer* run_options = TF_NewBufferFromString("", 0);
|
||||
TF_Buffer* metagraph = TF_NewBuffer();
|
||||
@ -1426,9 +1429,9 @@ TEST(CAPI, SavedModel) {
|
||||
}
|
||||
|
||||
TEST(CAPI, SavedModelNullArgsAreValid) {
|
||||
const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123";
|
||||
const string saved_model_dir = tensorflow::io::JoinPath(
|
||||
tensorflow::testing::TensorFlowSrcRoot(), kSavedModel);
|
||||
const string saved_model_dir = tensorflow::GetDataDependencyFilepath(
|
||||
tensorflow::io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
|
||||
"half_plus_two", "00000123"));
|
||||
TF_SessionOptions* opt = TF_NewSessionOptions();
|
||||
TF_Status* s = TF_NewStatus();
|
||||
const char* tags[] = {tensorflow::kSavedModelTagServe};
|
||||
|
@ -28,6 +28,8 @@ tf_cuda_library(
|
||||
"c_api_debug.cc",
|
||||
"c_api_experimental.h",
|
||||
"c_api_internal.h",
|
||||
"operation_interface.cc",
|
||||
"operation_interface.h",
|
||||
"tensor_handle_interface.h",
|
||||
],
|
||||
hdrs = ["c_api.h"],
|
||||
@ -56,6 +58,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/platform:casts",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
@ -92,6 +95,8 @@ filegroup(
|
||||
srcs = [
|
||||
"c_api_experimental.h",
|
||||
"c_api_internal.h",
|
||||
"dlpack.h",
|
||||
"operation_interface.h",
|
||||
"tensor_handle_interface.h",
|
||||
],
|
||||
visibility = [
|
||||
@ -104,6 +109,7 @@ tf_cuda_library(
|
||||
name = "c_api_internal",
|
||||
srcs = [
|
||||
"c_api_experimental.h",
|
||||
"operation_interface.h",
|
||||
"tensor_handle_interface.h",
|
||||
],
|
||||
hdrs = ["c_api_internal.h"],
|
||||
@ -128,6 +134,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||
"//tensorflow/core/common_runtime/eager:kernel_and_device",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
"@com_google_absl//absl/container:fixed_array",
|
||||
],
|
||||
)
|
||||
|
||||
@ -199,6 +206,7 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"//tensorflow/core/platform:casts",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
@ -256,8 +264,6 @@ tf_cuda_library(
|
||||
"//tensorflow/core/distributed_runtime:remote_device",
|
||||
"//tensorflow/core/distributed_runtime:server_lib",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
"//tensorflow/core/profiler/rpc:profiler_server",
|
||||
"//tensorflow/core/profiler/rpc/client:capture_profile",
|
||||
"//tensorflow/core:gpu_runtime",
|
||||
],
|
||||
alwayslink = 1,
|
||||
@ -323,10 +329,34 @@ filegroup(
|
||||
srcs = [
|
||||
"c_api.h",
|
||||
"c_api_experimental.h",
|
||||
"dlpack.h",
|
||||
],
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "dlpack",
|
||||
srcs = ["dlpack.cc"],
|
||||
hdrs = ["dlpack.h"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
":c_api_internal",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"@dlpack",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# TODO(karllessard): only used by //tensorflow/core:mobile_srcs_only_runtime
|
||||
# right now, remove this public rule when no longer needed (it should be
|
||||
# replaced by TF Lite)
|
||||
@ -340,6 +370,7 @@ filegroup(
|
||||
exclude = [
|
||||
"c_api_experimental.cc",
|
||||
"*test*",
|
||||
"*dlpack*",
|
||||
],
|
||||
),
|
||||
visibility = ["//visibility:public"],
|
||||
|
@ -27,7 +27,6 @@ limitations under the License.
|
||||
// clang-format on
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/container/fixed_array.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
@ -95,14 +94,6 @@ using tensorflow::string;
|
||||
|
||||
namespace {
|
||||
|
||||
const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
|
||||
const tensorflow::OpDef* op_def = op->operation.OpDef();
|
||||
if (op_def) return op_def;
|
||||
status->status =
|
||||
tensorflow::OpDefForOp(op->operation.Name().c_str(), &op_def);
|
||||
return op_def;
|
||||
}
|
||||
|
||||
bool IsCPU(
|
||||
absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> variant) {
|
||||
if (VariantDeviceIsCustom(variant)) {
|
||||
@ -883,12 +874,12 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
|
||||
TF_Status* status) {
|
||||
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
||||
TF_Status* status) {
|
||||
#if defined(IS_MOBILE_PLATFORM)
|
||||
status->status = tensorflow::Status::OK();
|
||||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
status->status = ctx->context->ClearRemoteExecutors();
|
||||
status->status = ctx->context->SyncExecutors();
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
@ -1125,9 +1116,13 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
|
||||
return retval;
|
||||
} else {
|
||||
tensorflow::Tensor tensor;
|
||||
if (IsCPU(handle_->device())) {
|
||||
if (IsCPU(handle_->device()) || handle_->HasLocalMirror(nullptr)) {
|
||||
const tensorflow::Tensor* src = nullptr;
|
||||
*status = handle_->Tensor(&src);
|
||||
if (handle_->HasLocalMirror(nullptr)) {
|
||||
*status = handle_->TensorFromDevice(nullptr, &src);
|
||||
} else {
|
||||
*status = handle_->Tensor(&src);
|
||||
}
|
||||
if (!status->ok()) return nullptr;
|
||||
tensor = *src;
|
||||
} else {
|
||||
@ -1135,6 +1130,13 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
|
||||
CHECK_NE(ctx, nullptr);
|
||||
*status = handle_->CopyToDevice(*ctx, ctx->HostCPU(), &tensor);
|
||||
if (!status->ok()) return nullptr;
|
||||
if (handle_->ImplicitMirroring()) {
|
||||
*status = handle_->AddEmptyLocalMirror(nullptr);
|
||||
if (!status->ok()) return nullptr;
|
||||
Tensor mirror = tensor;
|
||||
*status = handle_->SetTensor(std::move(mirror), nullptr);
|
||||
if (!status->ok()) return nullptr;
|
||||
}
|
||||
}
|
||||
return tensorflow::TF_TensorFromTensor(tensor, status);
|
||||
}
|
||||
@ -1199,18 +1201,11 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
|
||||
}
|
||||
|
||||
if (dtype == TF_STRING || dtype == TF_RESOURCE ||
|
||||
!tensorflow::DataTypeCanUseMemcpy(
|
||||
static_cast<tensorflow::DataType>(dtype))) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Trying to create a tensor with a pointer to non-pod memory.");
|
||||
deallocator(data, len, deallocator_arg);
|
||||
return nullptr;
|
||||
}
|
||||
// TODO(apassos) do we need to wrap the deallocator here to make sure to sync
|
||||
// the device?
|
||||
TF_ManagedBuffer* buf =
|
||||
new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
|
||||
new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
|
||||
/*owns_memory=*/false);
|
||||
|
||||
tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype),
|
||||
tensorflow::TensorShape(dimvec), buf);
|
||||
@ -1218,10 +1213,10 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||
tensorflow::TensorHandle* ret_handle;
|
||||
if (custom_device == nullptr) {
|
||||
status->status = tensorflow::TensorHandle::CreateLocalHandle(
|
||||
t, device, context, &ret_handle);
|
||||
std::move(t), device, device, context, &ret_handle);
|
||||
} else {
|
||||
status->status = tensorflow::TensorHandle::CreateLocalHandle(
|
||||
t, custom_device, context, &ret_handle);
|
||||
std::move(t), custom_device, context, &ret_handle);
|
||||
}
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
@ -1261,9 +1256,8 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
|
||||
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op> new_op(
|
||||
new TFE_Op{tensorflow::EagerOperation(ctx->context)});
|
||||
status->status =
|
||||
new_op->operation.Reset(op_or_function_name, nullptr, false, nullptr);
|
||||
new TFE_Op{std::make_unique<tensorflow::OperationInterface>(ctx)});
|
||||
status->status = new_op->operation->Reset(op_or_function_name, nullptr);
|
||||
if (!status->status.ok()) {
|
||||
new_op.reset();
|
||||
}
|
||||
@ -1273,49 +1267,51 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
void TFE_DeleteOp(TFE_Op* op) { delete op; }
|
||||
|
||||
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
|
||||
status->status = op->operation.SetDeviceName(device_name);
|
||||
status->status = op->operation->SetDeviceName(device_name);
|
||||
}
|
||||
|
||||
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
|
||||
tensorflow::Device* device = (op->operation.Device() == nullptr)
|
||||
? op->operation.EagerContext().HostCPU()
|
||||
: op->operation.Device();
|
||||
return device->name().c_str();
|
||||
return op->operation->DeviceName().c_str();
|
||||
}
|
||||
|
||||
void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
|
||||
op->operation.SetUseXla(enable);
|
||||
#ifndef TENSORFLOW_EAGER_USE_XLA
|
||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||
tensorflow::Status s = op->operation->SetUseXla(enable);
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Could not enable XLA compilation for op: " << s;
|
||||
}
|
||||
#else
|
||||
LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
|
||||
"built with XLA support.";
|
||||
#endif // TENSORFLOW_EAGER_USE_XLA
|
||||
}
|
||||
|
||||
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
|
||||
tensorflow::TensorHandle* h =
|
||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
input->handle.get())
|
||||
->Handle();
|
||||
op->operation.AddInput(h);
|
||||
status->status = op->operation.MaybeInferSingleInputAttrs(h);
|
||||
status->status = op->operation->AddInput(input->handle);
|
||||
}
|
||||
|
||||
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
|
||||
TF_Status* status) {
|
||||
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>> handles(
|
||||
num_inputs);
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
op->operation.AddInput(
|
||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
inputs[i]->handle.get())
|
||||
->Handle());
|
||||
handles[i].reset(inputs[i]->handle->Copy());
|
||||
}
|
||||
status->status = op->operation.InferInputListAttrs(num_inputs);
|
||||
status->status = op->operation->AddInputList(handles);
|
||||
}
|
||||
|
||||
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
|
||||
unsigned char* is_list, TF_Status* status) {
|
||||
TF_AttrType ret = TF_ATTR_INT;
|
||||
status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(),
|
||||
attr_name, &ret, is_list);
|
||||
const tensorflow::AttrTypeMap* attr_types_;
|
||||
bool is_function;
|
||||
status->status = tensorflow::AttrTypeMapForOp(op->operation->Name().c_str(),
|
||||
&attr_types_, &is_function);
|
||||
if (!status->status.ok()) {
|
||||
return ret;
|
||||
}
|
||||
status->status =
|
||||
tensorflow::AttrTypeByName(*attr_types_, attr_name, &ret, is_list);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@ -1336,221 +1332,169 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
|
||||
|
||||
void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
|
||||
size_t length) {
|
||||
op->operation.MutableAttrs()->Set(
|
||||
attr_name,
|
||||
tensorflow::StringPiece(static_cast<const char*>(value), length));
|
||||
auto s = op->operation->SetAttrString(
|
||||
attr_name, static_cast<const char*>(value), length);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
|
||||
op->operation.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
|
||||
auto s = op->operation->SetAttrInt(attr_name, value);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
|
||||
op->operation.MutableAttrs()->Set(attr_name, value);
|
||||
auto s = op->operation->SetAttrFloat(attr_name, value);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
|
||||
op->operation.MutableAttrs()->Set(attr_name, (value == 0) ? false : true);
|
||||
auto s = op->operation->SetAttrBool(attr_name, (value == 0) ? false : true);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
|
||||
op->operation.MutableAttrs()->Set(attr_name,
|
||||
static_cast<tensorflow::DataType>(value));
|
||||
auto s = op->operation->SetAttrType(attr_name, value);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
|
||||
const int num_dims, TF_Status* out_status) {
|
||||
if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
|
||||
TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
|
||||
tensorflow::strings::StrCat(
|
||||
"Value specified for `", attr_name, "` has ", num_dims,
|
||||
" dimensions which is over the limit of ",
|
||||
tensorflow::TensorShape::MaxDimensions(), ".")
|
||||
.c_str());
|
||||
return;
|
||||
}
|
||||
tensorflow::TensorShapeProto proto;
|
||||
if (num_dims < 0) {
|
||||
proto.set_unknown_rank(true);
|
||||
} else {
|
||||
for (int d = 0; d < num_dims; ++d) {
|
||||
proto.add_dim()->set_size(dims[d]);
|
||||
}
|
||||
}
|
||||
op->operation.MutableAttrs()->Set(attr_name, proto);
|
||||
out_status->status = op->operation->SetAttrShape(attr_name, dims, num_dims);
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
|
||||
const TFE_Op* value) {
|
||||
tensorflow::AttrValue attr_value;
|
||||
tensorflow::NameAttrList* func = attr_value.mutable_func();
|
||||
func->set_name(value->operation.Name());
|
||||
value->operation.Attrs().FillAttrValueMap(func->mutable_attr());
|
||||
op->operation.MutableAttrs()->Set(attr_name, attr_value);
|
||||
auto s = op->operation->SetAttrFunction(attr_name, value->operation);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
|
||||
const char* data, size_t length) {
|
||||
tensorflow::AttrValue attr_value;
|
||||
tensorflow::NameAttrList* func = attr_value.mutable_func();
|
||||
func->set_name(data, length);
|
||||
op->operation.MutableAttrs()->Set(attr_name, attr_value);
|
||||
auto s = op->operation->SetAttrFunctionName(attr_name, data, length);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
|
||||
TF_Status* status) {
|
||||
tensorflow::Tensor t;
|
||||
status->status = TF_TensorToTensor(tensor, &t);
|
||||
if (status->status.ok()) op->operation.MutableAttrs()->Set(attr_name, t);
|
||||
status->status = op->operation->SetAttrTensor(attr_name, tensor);
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
|
||||
const void* const* values, const size_t* lengths,
|
||||
int num_values) {
|
||||
std::vector<tensorflow::StringPiece> v(num_values);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
v[i] = tensorflow::StringPiece(static_cast<const char*>(values[i]),
|
||||
lengths[i]);
|
||||
auto s =
|
||||
op->operation->SetAttrStringList(attr_name, values, lengths, num_values);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
op->operation.MutableAttrs()->Set(attr_name, v);
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
|
||||
const float* values, int num_values) {
|
||||
op->operation.MutableAttrs()->Set(
|
||||
attr_name, tensorflow::gtl::ArraySlice<const float>(values, num_values));
|
||||
auto s = op->operation->SetAttrFloatList(attr_name, values, num_values);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
|
||||
const int64_t* values, int num_values) {
|
||||
op->operation.MutableAttrs()->Set(
|
||||
attr_name, tensorflow::gtl::ArraySlice<const int64>(
|
||||
reinterpret_cast<const int64*>(values), num_values));
|
||||
auto s = op->operation->SetAttrIntList(attr_name, values, num_values);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
|
||||
const TF_DataType* values, int num_values) {
|
||||
op->operation.MutableAttrs()->Set(
|
||||
attr_name,
|
||||
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
|
||||
reinterpret_cast<const tensorflow::DataType*>(values), num_values));
|
||||
auto s = op->operation->SetAttrTypeList(attr_name, values, num_values);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
|
||||
const unsigned char* values, int num_values) {
|
||||
std::unique_ptr<bool[]> b(new bool[num_values]);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
b[i] = values[i];
|
||||
auto s = op->operation->SetAttrBoolList(attr_name, values, num_values);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
op->operation.MutableAttrs()->Set(
|
||||
attr_name, tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
|
||||
const int64_t** dims, const int* num_dims,
|
||||
int num_values, TF_Status* out_status) {
|
||||
std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
|
||||
new tensorflow::TensorShapeProto[num_values]);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
const auto num_dims_i = num_dims[i];
|
||||
|
||||
if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
|
||||
TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
|
||||
tensorflow::strings::StrCat(
|
||||
"Value specified for `", attr_name, "` has ", num_dims_i,
|
||||
" dimensions which is over the limit of ",
|
||||
tensorflow::TensorShape::MaxDimensions(), ".")
|
||||
.c_str());
|
||||
return;
|
||||
}
|
||||
if (num_dims_i < 0) {
|
||||
proto[i].set_unknown_rank(true);
|
||||
} else {
|
||||
const int64_t* dims_i = dims[i];
|
||||
auto proto_i = &proto[i];
|
||||
for (int d = 0; d < num_dims_i; ++d) {
|
||||
proto_i->add_dim()->set_size(dims_i[d]);
|
||||
}
|
||||
}
|
||||
}
|
||||
op->operation.MutableAttrs()->Set(
|
||||
attr_name, tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
|
||||
proto.get(), num_values));
|
||||
out_status->status =
|
||||
op->operation->SetAttrShapeList(attr_name, dims, num_dims, num_values);
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
|
||||
const TFE_Op** value, int num_values) {
|
||||
std::unique_ptr<tensorflow::NameAttrList[]> funcs(
|
||||
new tensorflow::NameAttrList[num_values]);
|
||||
for (int i = 0; i < num_values; i++) {
|
||||
funcs[i].set_name(value[i]->operation.Name());
|
||||
value[i]->operation.Attrs().FillAttrValueMap(funcs[i].mutable_attr());
|
||||
auto s = op->operation->SetAttrFunctionList(attr_name, value, num_values);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
op->operation.MutableAttrs()->Set(
|
||||
attr_name, tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
|
||||
funcs.get(), num_values));
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrValueProto(const TFE_Op* op, const char* attr_name,
|
||||
const void* proto, size_t proto_len,
|
||||
TF_Status* status) {
|
||||
tensorflow::AttrValue attr_value;
|
||||
if (!attr_value.ParseFromArray(proto, proto_len)) {
|
||||
status->status =
|
||||
tensorflow::errors::InvalidArgument("Unparseable AttrValue proto");
|
||||
return;
|
||||
}
|
||||
if (op == nullptr || op->operation == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Got a null or uninitialized `op` argument");
|
||||
return;
|
||||
}
|
||||
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
op->operation.get());
|
||||
operation->MutableAttrs()->Set(attr_name, attr_value);
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
|
||||
const char* input_name,
|
||||
TF_Status* status) {
|
||||
const tensorflow::OpDef* op_def = GetOpDef(op, status);
|
||||
if (!status->status.ok()) {
|
||||
return -1;
|
||||
}
|
||||
tensorflow::AttrValueMap attrs;
|
||||
op->operation.Attrs().FillAttrValueMap(&attrs);
|
||||
tensorflow::NameRangeMap name_ranges;
|
||||
status->status = tensorflow::NameRangesForNode(
|
||||
tensorflow::AttrSlice(&attrs), *op_def, &name_ranges, nullptr);
|
||||
if (!status->status.ok()) {
|
||||
return -1;
|
||||
}
|
||||
auto iter = name_ranges.find(input_name);
|
||||
if (iter == name_ranges.end()) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Input '", input_name,
|
||||
"' not found");
|
||||
return -1;
|
||||
}
|
||||
return iter->second.second - iter->second.first;
|
||||
int ret = -1;
|
||||
status->status = op->operation->InputLength(input_name, &ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
|
||||
const char* output_name,
|
||||
TF_Status* status) {
|
||||
const tensorflow::OpDef* op_def = GetOpDef(op, status);
|
||||
if (!status->status.ok()) {
|
||||
return -1;
|
||||
}
|
||||
tensorflow::AttrValueMap attrs;
|
||||
op->operation.Attrs().FillAttrValueMap(&attrs);
|
||||
tensorflow::NameRangeMap name_ranges;
|
||||
status->status = tensorflow::NameRangesForNode(
|
||||
tensorflow::AttrSlice(&attrs), *op_def, nullptr, &name_ranges);
|
||||
if (!status->status.ok()) {
|
||||
return -1;
|
||||
}
|
||||
auto iter = name_ranges.find(output_name);
|
||||
if (iter == name_ranges.end()) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Output '", output_name, "' not found");
|
||||
return -1;
|
||||
}
|
||||
return iter->second.second - iter->second.first;
|
||||
int ret = -1;
|
||||
status->status = op->operation->OutputLength(output_name, &ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
|
||||
TF_Status* status) {
|
||||
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
|
||||
VLOG(1) << "Calling TFE_Execute() on op " << op;
|
||||
status->status = tensorflow::EagerExecute(&op->operation,
|
||||
handle_retvals.data(), num_retvals);
|
||||
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>> handles(
|
||||
*num_retvals);
|
||||
status->status = op->operation->Execute(&handles, num_retvals);
|
||||
if (!status->status.ok()) {
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < *num_retvals; ++i) {
|
||||
retvals[i] = new TFE_TensorHandle{
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(handle_retvals[i])};
|
||||
retvals[i] = new TFE_TensorHandle{std::move(handles[i])};
|
||||
}
|
||||
}
|
||||
|
||||
@ -1678,6 +1622,31 @@ void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); }
|
||||
|
||||
void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); }
|
||||
|
||||
void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) {
|
||||
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
op->operation.get());
|
||||
*attrs = TFE_OpAttrs(&operation->Attrs(), op->operation->Name().c_str());
|
||||
}
|
||||
|
||||
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
|
||||
tensorflow::AttrValueMap m;
|
||||
attrs->attributes->FillAttrValueMap(&m);
|
||||
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
op->operation.get());
|
||||
tensorflow::AttrBuilder* destination = operation->MutableAttrs();
|
||||
for (auto attribute : m) {
|
||||
destination->Set(attribute.first, attribute.second);
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
|
||||
TF_Status* status) {
|
||||
tensorflow::NameAttrList name_and_attrs;
|
||||
attrs->attributes->FillAttrValueMap(name_and_attrs.mutable_attr());
|
||||
name_and_attrs.set_name(attrs->name);
|
||||
status->status = MessageToBuffer(name_and_attrs, buf);
|
||||
}
|
||||
|
||||
namespace tensorflow {
|
||||
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
||||
const tensorflow::AttrValue& default_value,
|
||||
@ -1741,8 +1710,9 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
||||
namespace {
|
||||
class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
public:
|
||||
CustomDeviceAPI(TFE_CustomDevice device, void* info, string name)
|
||||
: device_(device), info_(info), name_(name) {}
|
||||
CustomDeviceAPI(TFE_Context* context, TFE_CustomDevice device, void* info,
|
||||
string name)
|
||||
: context_(context), device_(device), info_(info), name_(name) {}
|
||||
|
||||
~CustomDeviceAPI() override { device_.delete_device(info_); }
|
||||
|
||||
@ -1756,7 +1726,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(tensor)};
|
||||
TF_Status status;
|
||||
TFE_TensorHandle* result_handle =
|
||||
device_.copy_tensor_to_device(&tensor_handle, &status, info_);
|
||||
device_.copy_tensor_to_device(context_, &tensor_handle, &status, info_);
|
||||
if (!status.status.ok()) return status.status;
|
||||
*result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
result_handle->handle.get())
|
||||
@ -1775,7 +1745,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
TFE_TensorHandle tensor_handle{
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(tensor)};
|
||||
TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
|
||||
&tensor_handle, target_device_name.c_str(), &status, info_);
|
||||
context_, &tensor_handle, target_device_name.c_str(), &status, info_);
|
||||
if (!status.status.ok()) return status.status;
|
||||
*result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
result_handle->handle.get())
|
||||
@ -1797,10 +1767,10 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
op->Inputs()[i])});
|
||||
}
|
||||
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
|
||||
// TODO(allenl): figure out how to get attrs from EagerOperation
|
||||
TF_Status status;
|
||||
device_.execute(inputs.size(), inputs.data(), op->Name().c_str(),
|
||||
num_retvals, outputs.data(), &status, info_);
|
||||
TFE_OpAttrs attributes(&op->Attrs(), op->Name().c_str());
|
||||
device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(),
|
||||
&attributes, num_retvals, outputs.data(), &status, info_);
|
||||
if (status.status.ok()) {
|
||||
for (int i = 0; i < *num_retvals; ++i) {
|
||||
retvals[i] = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
@ -1818,6 +1788,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
}
|
||||
|
||||
private:
|
||||
TFE_Context* context_;
|
||||
TFE_CustomDevice device_;
|
||||
void* info_;
|
||||
string name_;
|
||||
@ -1825,8 +1796,10 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
} // namespace
|
||||
|
||||
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
|
||||
const char* device_name, void* device_info) {
|
||||
const char* device_name, void* device_info,
|
||||
TF_Status* status) {
|
||||
auto custom_device =
|
||||
std::make_unique<CustomDeviceAPI>(device, device_info, device_name);
|
||||
ctx->context->RegisterCustomDevice(device_name, std::move(custom_device));
|
||||
std::make_unique<CustomDeviceAPI>(ctx, device, device_info, device_name);
|
||||
status->status =
|
||||
ctx->context->RegisterCustomDevice(device_name, std::move(custom_device));
|
||||
}
|
||||
|
@ -25,34 +25,20 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
|
||||
#include "tensorflow/core/profiler/rpc/profiler_server.h"
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
|
||||
const char* raw_device_name, TF_Status* status) {
|
||||
if (op_to_reset) {
|
||||
status->status = op_to_reset->operation.Reset(
|
||||
op_or_function_name, raw_device_name, false, nullptr);
|
||||
status->status =
|
||||
op_to_reset->operation->Reset(op_or_function_name, raw_device_name);
|
||||
} else {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
"op_to_reset should not be nullptr");
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
|
||||
op->operation.ConsumeInput(
|
||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
|
||||
->Handle());
|
||||
}
|
||||
|
||||
void TFE_StartProfilerServer(int port) {
|
||||
// Release child thread intentionally. The child thread can be terminated by
|
||||
// terminating the main thread.
|
||||
tensorflow::StartProfilerServer(port).release();
|
||||
}
|
||||
|
||||
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
|
||||
ctx->context->SetShouldStoreGraphs(true);
|
||||
}
|
||||
@ -61,46 +47,6 @@ void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
|
||||
ctx->context->SetShouldStoreGraphs(false);
|
||||
}
|
||||
|
||||
bool TFE_ProfilerClientStartTracing(const char* service_addr,
|
||||
const char* logdir, const char* worker_list,
|
||||
bool include_dataset_ops, int duration_ms,
|
||||
int num_tracing_attempts,
|
||||
TF_Status* status) {
|
||||
tensorflow::Status s =
|
||||
tensorflow::profiler::ValidateHostPortPair(service_addr);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
return false;
|
||||
}
|
||||
s = tensorflow::profiler::Trace(service_addr, logdir, worker_list,
|
||||
include_dataset_ops, duration_ms,
|
||||
num_tracing_attempts);
|
||||
tensorflow::Set_TF_Status_from_Status(status, s);
|
||||
return s.ok();
|
||||
}
|
||||
|
||||
void TFE_ProfilerClientMonitor(const char* service_addr, int duration_ms,
|
||||
int monitoring_level, bool display_timestamp,
|
||||
TF_Buffer* result, TF_Status* status) {
|
||||
tensorflow::Status s =
|
||||
tensorflow::profiler::ValidateHostPortPair(service_addr);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
return;
|
||||
}
|
||||
string content;
|
||||
s = tensorflow::profiler::Monitor(service_addr, duration_ms, monitoring_level,
|
||||
display_timestamp, &content);
|
||||
void* data = tensorflow::port::Malloc(content.length());
|
||||
content.copy(static_cast<char*>(data), content.length(), 0);
|
||||
result->data = data;
|
||||
result->length = content.length();
|
||||
result->data_deallocator = [](void* data, size_t length) {
|
||||
tensorflow::port::Free(data);
|
||||
};
|
||||
tensorflow::Set_TF_Status_from_Status(status, s);
|
||||
}
|
||||
|
||||
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
|
||||
int64_t value) {
|
||||
cell->cell.IncrementBy(value);
|
||||
@ -568,8 +514,7 @@ void TFE_DeleteCancellationManager(
|
||||
void TFE_OpSetCancellationManager(TFE_Op* op,
|
||||
TFE_CancellationManager* cancellation_manager,
|
||||
TF_Status* status) {
|
||||
op->operation.SetCancellationManager(
|
||||
&cancellation_manager->cancellation_manager);
|
||||
status->status = op->operation->SetCancellationManager(cancellation_manager);
|
||||
}
|
||||
|
||||
TFE_Executor* TFE_NewExecutor(bool is_async) {
|
||||
@ -617,3 +562,22 @@ void TFE_TensorHandleEnableImplicitMirroring(TFE_TensorHandle* h,
|
||||
h->handle->EnableImplicitMirroring();
|
||||
status->status = tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
|
||||
TF_Buffer* buf, TF_Status* status) {
|
||||
auto* function_def = ctx->context->FindFunctionDef(function_name);
|
||||
if (function_def == nullptr) {
|
||||
status->status = tensorflow::errors::NotFound(
|
||||
"Unable to find FunctionDef with name: ", function_name);
|
||||
return;
|
||||
}
|
||||
string str = function_def->SerializeAsString();
|
||||
void* data = tensorflow::port::Malloc(str.length());
|
||||
str.copy(static_cast<char*>(data), str.length(), 0);
|
||||
buf->data = data;
|
||||
buf->length = str.length();
|
||||
buf->data_deallocator = [](void* data, size_t length) {
|
||||
tensorflow::port::Free(data);
|
||||
};
|
||||
status->status = tensorflow::Status::OK();
|
||||
}
|
||||
|
@ -34,18 +34,6 @@ TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset,
|
||||
const char* raw_device_name,
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
|
||||
TF_Status* status);
|
||||
|
||||
// Start a profiler grpc server which listens to specified port. It will start
|
||||
// the server on its own thread. It can be shutdown by terminating tensorflow.
|
||||
// It can be used in both Eager mode and graph mode. Creating multiple profiler
|
||||
// server is allowed. The service defined in
|
||||
// tensorflow/contrib/tpu/profiler/tpu_profiler.proto. Please use
|
||||
// tensorflow/contrib/tpu/profiler/capture_tpu_profile to capture trace file
|
||||
// following https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
|
||||
TF_CAPI_EXPORT extern void TFE_StartProfilerServer(int port);
|
||||
|
||||
// Enables only graph collection in RunMetadata on the functions executed from
|
||||
// this context.
|
||||
TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx);
|
||||
@ -54,29 +42,6 @@ TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx);
|
||||
// this context.
|
||||
TF_CAPI_EXPORT extern void TFE_ContextDisableGraphCollection(TFE_Context* ctx);
|
||||
|
||||
// Send a grpc request to profiler server (service_addr) to perform on-demand
|
||||
// profiling and save the result into logdir which can be visualized by
|
||||
// TensorBoard. worker_list is the list of worker TPUs separated by ','. Set
|
||||
// include_dataset_opts to false to profile longer traces. It will block the
|
||||
// caller thread until receives tracing result.
|
||||
// This API is designed for TensorBoard, for end user, please use
|
||||
// tensorflow/contrib/tpu/profiler/capture_tpu_profile instead following
|
||||
// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
|
||||
TF_CAPI_EXPORT extern bool TFE_ProfilerClientStartTracing(
|
||||
const char* service_addr, const char* logdir, const char* worker_list,
|
||||
bool include_dataset_ops, int duration_ms, int num_tracing_attempts,
|
||||
TF_Status* status);
|
||||
|
||||
// Send a grpc request to profiler server (service_addr) to perform on-demand
|
||||
// monitoring and return the result in a string. It will block the
|
||||
// caller thread until receiving the monitoring result.
|
||||
// This API is designed for TensorBoard, for end user, please use
|
||||
// tensorflow/contrib/tpu/profiler/capture_tpu_profile instead following
|
||||
// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
|
||||
TF_CAPI_EXPORT extern void TFE_ProfilerClientMonitor(
|
||||
const char* service_addr, int duration_ms, int monitoring_level,
|
||||
bool display_timestamp, TF_Buffer* result, TF_Status* status);
|
||||
|
||||
// TODO(fishx): Move these monitoring APIs into a separate file.
|
||||
// -----------------------------------------------------------------------------
|
||||
// Monitoring Counter APIs.
|
||||
@ -417,9 +382,11 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
||||
const char* worker_name,
|
||||
TF_Status* status);
|
||||
|
||||
// Clear pending streaming requests and error statuses on remote executors.
|
||||
TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
|
||||
TF_Status* status);
|
||||
// Sync pending nodes in local executors (including the context default executor
|
||||
// and thread executors) and streaming requests to remote executors, and get the
|
||||
// combined status.
|
||||
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
||||
TF_Status* status);
|
||||
|
||||
// If the TensorHandle is copied to another device as part of an op execution,
|
||||
// the copy is destroyed after the op has executed. Enabling implicit mirroring
|
||||
@ -456,26 +423,63 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||
TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
|
||||
TF_Buffer* buf);
|
||||
|
||||
#define TFE_CUSTOM_DEVICE_VERSION 0
|
||||
// APIs for generically dealing with op attributes (e.g. when forwarding them
|
||||
// through custom device implementations).
|
||||
//
|
||||
// TODO(allenl): Currently these are black boxes, but we should have some way to
|
||||
// inspect values. This would let people e.g. copy over most attributes and then
|
||||
// modify some based on their values.
|
||||
|
||||
// A reference to an op's name -> attribute mapping
|
||||
typedef struct TFE_OpAttrs TFE_OpAttrs;
|
||||
|
||||
// Fetch a struct with a reference to information about attributes of `op`.
|
||||
//
|
||||
// The `attrs` struct does not own any memory, and `op` must outlive it.
|
||||
TF_CAPI_EXPORT extern void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs);
|
||||
|
||||
// Add attributes in `attrs` to `op`.
|
||||
//
|
||||
// Does not overwrite or update existing attributes, but adds new ones.
|
||||
TF_CAPI_EXPORT extern void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs);
|
||||
|
||||
// Serialize `attrs` as a tensorflow::NameAttrList protocol buffer (into `buf`),
|
||||
// containing the op name and a map of its attributes.
|
||||
TF_CAPI_EXPORT extern void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs,
|
||||
TF_Buffer* buf,
|
||||
TF_Status* status);
|
||||
|
||||
// Set an op's attribute from a serialized AttrValue protocol buffer.
|
||||
//
|
||||
// Analogous to TF_SetAttrValueProto for building graph operations.
|
||||
TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op,
|
||||
const char* attr_name,
|
||||
const void* proto,
|
||||
size_t proto_len,
|
||||
TF_Status* status);
|
||||
|
||||
#define TFE_CUSTOM_DEVICE_VERSION 2
|
||||
|
||||
// Struct to be filled in
|
||||
typedef struct TFE_CustomDevice {
|
||||
int version = TFE_CUSTOM_DEVICE_VERSION;
|
||||
// Method to copy a tensor to the custom device.
|
||||
TFE_TensorHandle* (*copy_tensor_to_device)(TFE_TensorHandle* tensor,
|
||||
TFE_TensorHandle* (*copy_tensor_to_device)(TFE_Context* context,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status,
|
||||
void* device_info) = nullptr;
|
||||
|
||||
// Method to copy a tensor from the custom device to a target device.
|
||||
TFE_TensorHandle* (*copy_tensor_from_device)(TFE_TensorHandle* tensor,
|
||||
TFE_TensorHandle* (*copy_tensor_from_device)(TFE_Context* context,
|
||||
TFE_TensorHandle* tensor,
|
||||
const char* target_device_name,
|
||||
TF_Status* status,
|
||||
void* device_info);
|
||||
|
||||
// Method to execute an operation.
|
||||
// TODO(allenl) figure out a generic way of passing attrs here
|
||||
void (*execute)(int num_inputs, TFE_TensorHandle** inputs,
|
||||
const char* operation_name, int* num_outputs,
|
||||
void (*execute)(TFE_Context* context, int num_inputs,
|
||||
TFE_TensorHandle** inputs, const char* operation_name,
|
||||
const TFE_OpAttrs* attributes, int* num_outputs,
|
||||
TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
|
||||
|
||||
// Method to delete a device.
|
||||
@ -501,11 +505,26 @@ typedef struct TFE_CustomDevice {
|
||||
// devices, so executing tf.functions which contain operations placed on custom
|
||||
// devices will fail.
|
||||
//
|
||||
// `device_name` must not name an existing physical or custom device. It must
|
||||
// follow the format:
|
||||
//
|
||||
// /job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num>
|
||||
//
|
||||
// If the device is successfully registered, `status` is set to TF_OK. Otherwise
|
||||
// the device is not usable. In case of a bad status, `device.delete_device` is
|
||||
// still called on `device_info` (i.e. the caller does not retain ownership).
|
||||
//
|
||||
// This API is highly experimental, and in particular is expected to change when
|
||||
// it starts supporting operations with attributes and when tf.function support
|
||||
// is added.
|
||||
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
|
||||
const char* device_name, void* device_info);
|
||||
const char* device_name, void* device_info,
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
|
||||
const char* function_name,
|
||||
TF_Buffer* buf,
|
||||
TF_Status* status);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
|
@ -27,12 +27,12 @@ limitations under the License.
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
@ -89,7 +89,7 @@ struct TFE_TensorDebugInfo {
|
||||
};
|
||||
|
||||
struct TFE_Op {
|
||||
tensorflow::EagerOperation operation;
|
||||
std::unique_ptr<AbstractOperationInterface> operation;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringCounterCell {
|
||||
@ -236,4 +236,17 @@ struct TFE_Executor {
|
||||
tensorflow::EagerExecutor* unowned_executor;
|
||||
};
|
||||
|
||||
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
|
||||
// that sometimes do not require serialization.
|
||||
struct TFE_OpAttrs {
|
||||
explicit TFE_OpAttrs() : name(nullptr), attributes(nullptr) {}
|
||||
|
||||
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value,
|
||||
const char* op_name)
|
||||
: name(op_name), attributes(value) {}
|
||||
|
||||
const char* name;
|
||||
const tensorflow::AttrBuilder* attributes;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/cluster.pb.h"
|
||||
@ -127,7 +128,7 @@ void TestRemoteExecute(bool async) {
|
||||
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
|
||||
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
|
||||
|
||||
void TestRemoteExecuteSilentCopies(bool async) {
|
||||
void TestRemoteExecuteSilentCopies(bool async, bool remote) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(3);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
@ -166,10 +167,14 @@ void TestRemoteExecuteSilentCopies(bool async) {
|
||||
auto* h1_task2 =
|
||||
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_TensorHandleEnableImplicitMirroring(h1_task2, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Handles are on task0 (local), and task2, but op is on task1.
|
||||
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2);
|
||||
TFE_OpSetDevice(matmul, task1_name, status);
|
||||
if (remote) {
|
||||
TFE_OpSetDevice(matmul, task1_name, status);
|
||||
}
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* retvals[1];
|
||||
@ -177,6 +182,17 @@ void TestRemoteExecuteSilentCopies(bool async) {
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// TODO(gjn): Add support for waiting on async local mirrors
|
||||
if (!async) {
|
||||
auto remote_arg = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
h1_task2->handle.get())
|
||||
->Handle();
|
||||
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
matmul->operation.get());
|
||||
// The input handles should never change since they have been mirrored.
|
||||
ASSERT_EQ(op->GetInput(1), remote_arg);
|
||||
}
|
||||
|
||||
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
|
||||
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
@ -213,9 +229,17 @@ void TestRemoteExecuteSilentCopies(bool async) {
|
||||
worker_server2.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); }
|
||||
TEST(CAPI, RemoteExecuteSilentCopies) {
|
||||
TestRemoteExecuteSilentCopies(false, true);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
|
||||
TestRemoteExecuteSilentCopies(true);
|
||||
TestRemoteExecuteSilentCopies(true, true);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
|
||||
TestRemoteExecuteSilentCopies(false, false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
|
||||
TestRemoteExecuteSilentCopies(true, false);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
|
||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
@ -367,7 +369,7 @@ TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevicesAsync) {
|
||||
void TensorHandleSilentCopy(bool async,
|
||||
TFE_ContextDevicePlacementPolicy global_policy,
|
||||
TFE_ContextDevicePlacementPolicy thread_policy,
|
||||
bool mirror, bool cpu_op) {
|
||||
bool cpu_op) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
@ -390,12 +392,6 @@ void TensorHandleSilentCopy(bool async,
|
||||
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
|
||||
hcpu, ctx, gpu_device_name.c_str(), status.get());
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
if (mirror) {
|
||||
TFE_TensorHandleEnableImplicitMirroring(hcpu, status.get());
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
TFE_TensorHandleEnableImplicitMirroring(hgpu, status.get());
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
}
|
||||
|
||||
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
|
||||
if (cpu_op) {
|
||||
@ -419,21 +415,13 @@ void TensorHandleSilentCopy(bool async,
|
||||
auto arg1 = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
hgpu->handle.get())
|
||||
->Handle();
|
||||
if (mirror) {
|
||||
// The input handles should never change since they have been mirrored.
|
||||
ASSERT_EQ(matmul->operation.Inputs()[0], arg0);
|
||||
ASSERT_EQ(matmul->operation.Inputs()[1], arg1);
|
||||
} else {
|
||||
if (cpu_op) {
|
||||
ASSERT_EQ(matmul->operation.Inputs()[0], arg0);
|
||||
// The GPU handle should be replaced with a CPU copy
|
||||
ASSERT_NE(matmul->operation.Inputs()[1], arg1);
|
||||
} else {
|
||||
// The CPU handle should be replaced with a GPU copy
|
||||
ASSERT_NE(matmul->operation.Inputs()[0], arg0);
|
||||
ASSERT_EQ(matmul->operation.Inputs()[1], arg1);
|
||||
}
|
||||
}
|
||||
|
||||
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
matmul->operation.get());
|
||||
|
||||
// The input handles should never change since they have been mirrored.
|
||||
EXPECT_EQ(op->GetInput(0), arg0);
|
||||
EXPECT_EQ(op->GetInput(1), arg1);
|
||||
|
||||
TFE_DeleteOp(matmul);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
@ -450,27 +438,19 @@ void TensorHandleSilentCopy(bool async,
|
||||
}
|
||||
TEST(CAPI, TensorHandleSilentCopy) {
|
||||
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
|
||||
TFE_DEVICE_PLACEMENT_SILENT, false, false);
|
||||
TFE_DEVICE_PLACEMENT_SILENT, false);
|
||||
}
|
||||
TEST(CAPI, TensorHandleSilentCopyAsync) {
|
||||
TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_SILENT,
|
||||
TFE_DEVICE_PLACEMENT_SILENT, false, false);
|
||||
TFE_DEVICE_PLACEMENT_SILENT, false);
|
||||
}
|
||||
TEST(CAPI, TensorHandleSilentCopyLocalPolicy) {
|
||||
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_EXPLICIT,
|
||||
TFE_DEVICE_PLACEMENT_SILENT, false, false);
|
||||
TFE_DEVICE_PLACEMENT_SILENT, false);
|
||||
}
|
||||
TEST(CAPI, TensorHandleSilentCopyLocalPolicyAsync) {
|
||||
TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_EXPLICIT,
|
||||
TFE_DEVICE_PLACEMENT_SILENT, false, false);
|
||||
}
|
||||
TEST(CAPI, TensorHandleMirrorCopy) {
|
||||
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
|
||||
TFE_DEVICE_PLACEMENT_SILENT, true, false);
|
||||
}
|
||||
TEST(CAPI, TensorHandleMirrorCopyCpu) {
|
||||
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
|
||||
TFE_DEVICE_PLACEMENT_SILENT, true, true);
|
||||
TFE_DEVICE_PLACEMENT_SILENT, false);
|
||||
}
|
||||
|
||||
void SetAndGetOpDevices(bool async) {
|
||||
@ -606,6 +586,91 @@ TEST(CAPI, TensorHandleDevices) {
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
void ExecuteAdd(bool async, bool forward_input) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* n = TestMatrixTensorHandle100x100();
|
||||
// If a GPU exists, copy the handle to GPU so that we can exercise
|
||||
// unprotecting a mirror.
|
||||
std::string gpu_device_name;
|
||||
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
|
||||
TFE_TensorHandle* n_gpu =
|
||||
TFE_TensorHandleCopyToDevice(n, ctx, gpu_device_name.c_str(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_TensorHandleEnableImplicitMirroring(n_gpu, status);
|
||||
TFE_DeleteTensorHandle(n);
|
||||
n = n_gpu;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle100x100();
|
||||
|
||||
// Store pointer to raw buffer for validation of forwarding behaviour.
|
||||
TF_Tensor* orig = TFE_TensorHandleResolve(n, status);
|
||||
void* orig_ptr = TF_TensorData(orig);
|
||||
TF_DeleteTensor(orig);
|
||||
|
||||
TFE_Op* add_op = AddOp(ctx, n, m);
|
||||
std::string cpu_device_name;
|
||||
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
|
||||
TFE_OpSetDevice(add_op, cpu_device_name.c_str(), status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
if (forward_input) {
|
||||
TFE_DeleteTensorHandle(n);
|
||||
}
|
||||
|
||||
int num_retvals = 1;
|
||||
|
||||
if (async) {
|
||||
// Enqueue dummy ops so we backlog async execution & actually test async.
|
||||
for (int i = 0; i < 10000; ++i) {
|
||||
TFE_TensorHandle* dummy = nullptr;
|
||||
TFE_Execute(add_op, &dummy, &num_retvals, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(dummy);
|
||||
}
|
||||
}
|
||||
|
||||
TFE_TensorHandle* retval = nullptr;
|
||||
TFE_Execute(add_op, &retval, &num_retvals, status);
|
||||
EXPECT_EQ(1, num_retvals);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
if (!forward_input) {
|
||||
TFE_DeleteTensorHandle(n);
|
||||
}
|
||||
TFE_DeleteOp(add_op);
|
||||
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(retval, status);
|
||||
if (forward_input || async) {
|
||||
EXPECT_EQ(orig_ptr, TF_TensorData(t));
|
||||
} else {
|
||||
EXPECT_NE(orig_ptr, TF_TensorData(t));
|
||||
}
|
||||
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(m);
|
||||
TFE_DeleteTensorHandle(retval);
|
||||
TFE_DeleteContext(ctx);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
float result[100 * 100] = {0};
|
||||
EXPECT_EQ(sizeof(result), TF_TensorByteSize(t));
|
||||
memcpy(&result[0], TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
for (int i = 0; i < 100 * 100; ++i) {
|
||||
EXPECT_EQ(2.0f, result[i]);
|
||||
}
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
TEST(CAPI, ExecuteAdd) { ExecuteAdd(false, false); }
|
||||
TEST(CAPI, ExecuteAddAsync) { ExecuteAdd(true, false); }
|
||||
TEST(CAPI, ExecuteAddForward) { ExecuteAdd(false, true); }
|
||||
TEST(CAPI, ExecuteAddForwardAsync) { ExecuteAdd(true, true); }
|
||||
|
||||
void Execute_MatMul_CPU(bool async) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
@ -1244,6 +1309,14 @@ TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
|
||||
TFE_DeleteTensorHandle(h_shares_tensor);
|
||||
}
|
||||
|
||||
tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) {
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
tensorflow::down_cast<tensorflow::OperationInterface*>(op->operation.get())
|
||||
->Attrs()
|
||||
.FillAttrValueMap(&attr_values);
|
||||
return attr_values;
|
||||
}
|
||||
|
||||
TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
@ -1260,8 +1333,7 @@ TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
|
||||
TFE_OpAddInput(minOp, axis, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
minOp->operation.Attrs().FillAttrValueMap(&attr_values);
|
||||
tensorflow::AttrValueMap attr_values = ExtractAttrs(minOp);
|
||||
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
|
||||
EXPECT_NE(attr_found, attr_values.cend());
|
||||
EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
|
||||
@ -1300,8 +1372,7 @@ TEST(CAPI, TestTFE_OpInferSingleTypeInputListAttrs) {
|
||||
TFE_OpAddInputList(concatOp, inputs, 2, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
concatOp->operation.Attrs().FillAttrValueMap(&attr_values);
|
||||
tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp);
|
||||
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
|
||||
EXPECT_NE(attr_found, attr_values.cend());
|
||||
EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
|
||||
@ -1341,8 +1412,7 @@ TEST(CAPI, TestTFE_OpInferMixedTypeInputListAttrs) {
|
||||
TFE_OpAddInputList(assertOp, data, 3, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
assertOp->operation.Attrs().FillAttrValueMap(&attr_values);
|
||||
tensorflow::AttrValueMap attr_values = ExtractAttrs(assertOp);
|
||||
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
|
||||
EXPECT_NE(attr_found, attr_values.cend());
|
||||
EXPECT_EQ(attr_found->second.list().type(0), tensorflow::DataType::DT_BOOL);
|
||||
@ -1378,16 +1448,15 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
|
||||
TFE_TensorHandle* inputs[] = {input1, input2};
|
||||
TFE_OpAddInput(concatOp, dim, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
CHECK(concatOp->operation.OpDef());
|
||||
CHECK(concatOp->operation->OpDef());
|
||||
TFE_OpAddInput(concatOp, inputs[0], status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_FALSE(concatOp->operation.OpDef())
|
||||
EXPECT_FALSE(concatOp->operation->OpDef())
|
||||
<< "Inference context is still present";
|
||||
TFE_OpAddInput(concatOp, inputs[1], status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
concatOp->operation.Attrs().FillAttrValueMap(&attr_values);
|
||||
tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp);
|
||||
EXPECT_EQ(attr_values.find("T"), attr_values.end());
|
||||
EXPECT_EQ(attr_values.find("N"), attr_values.end());
|
||||
|
||||
@ -1474,4 +1543,88 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
TEST(CAPI, TestTFE_OpGetAttrs) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
|
||||
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
|
||||
TFE_OpAttrs attributes;
|
||||
TFE_OpGetAttrs(var_op, &attributes);
|
||||
|
||||
TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT);
|
||||
TFE_OpAddAttrs(copy_op, &attributes);
|
||||
unsigned char is_list = 0;
|
||||
ASSERT_EQ(TF_ATTR_TYPE,
|
||||
TFE_OpGetAttrType(copy_op, "dtype", &is_list, status));
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_EQ(TF_ATTR_SHAPE,
|
||||
TFE_OpGetAttrType(copy_op, "shape", &is_list, status));
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
copy_op->operation.get());
|
||||
op->Attrs().FillAttrValueMap(&attr_values);
|
||||
EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type());
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteOp(var_op);
|
||||
TFE_DeleteOp(copy_op);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
TEST(CAPI, TestTFE_OpAttrsSerialize) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
|
||||
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
|
||||
TFE_OpAttrs attributes;
|
||||
TFE_OpGetAttrs(var_op, &attributes);
|
||||
|
||||
TF_Buffer* serialized_attr_values = TF_NewBuffer();
|
||||
TFE_OpAttrsSerialize(&attributes, serialized_attr_values, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
tensorflow::NameAttrList name_and_attrs;
|
||||
ASSERT_TRUE(name_and_attrs.ParseFromArray(serialized_attr_values->data,
|
||||
serialized_attr_values->length));
|
||||
ASSERT_EQ("VarHandleOp", name_and_attrs.name());
|
||||
ASSERT_EQ(tensorflow::DT_INT64,
|
||||
name_and_attrs.attr().find("dtype")->second.type());
|
||||
TF_DeleteBuffer(serialized_attr_values);
|
||||
|
||||
TFE_Op* second_var_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
|
||||
string serialized_dtype;
|
||||
ASSERT_TRUE(name_and_attrs.attr().find("dtype")->second.SerializeToString(
|
||||
&serialized_dtype));
|
||||
TFE_OpSetAttrValueProto(
|
||||
second_var_op, "dtype",
|
||||
reinterpret_cast<const void*>(serialized_dtype.c_str()),
|
||||
serialized_dtype.length(), status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
second_var_op->operation.get());
|
||||
op->Attrs().FillAttrValueMap(&attr_values);
|
||||
EXPECT_EQ(tensorflow::DT_INT64, attr_values.find("dtype")->second.type());
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteOp(var_op);
|
||||
TFE_DeleteOp(second_var_op);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -131,6 +131,21 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2() {
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
|
||||
TFE_Op* op = TFE_NewOp(ctx, "AddV2", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(op, a, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(op, b, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
|
||||
|
||||
return op;
|
||||
}
|
||||
|
||||
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
|
||||
|
@ -42,6 +42,9 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2();
|
||||
// Return a tensor handle containing a 3x2 matrix of floats
|
||||
TFE_TensorHandle* TestMatrixTensorHandle3X2();
|
||||
|
||||
// Return an add op multiplying `a` by `b`.
|
||||
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
|
||||
|
||||
// Return a matmul op multiplying `a` by `b`.
|
||||
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
|
||||
|
||||
|
@ -21,16 +21,18 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace {
|
||||
|
||||
struct LoggingDevice {
|
||||
TFE_Context* ctx;
|
||||
tensorflow::string device_name;
|
||||
tensorflow::string underlying_device;
|
||||
// Set to true whenever a TensorHandle is copied onto the device
|
||||
bool* arrived_flag;
|
||||
// Set to true whenever an operation is executed
|
||||
bool* executed_flag;
|
||||
};
|
||||
|
||||
struct LoggedTensor {
|
||||
@ -45,7 +47,7 @@ void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
|
||||
}
|
||||
|
||||
TFE_TensorHandle* MakeLoggedTensorHandle(
|
||||
TFE_Context* ctx, const tensorflow::string& logging_device_name,
|
||||
TFE_Context* context, const tensorflow::string& logging_device_name,
|
||||
std::unique_ptr<LoggedTensor> t, TF_Status* status) {
|
||||
std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
@ -55,23 +57,25 @@ TFE_TensorHandle* MakeLoggedTensorHandle(
|
||||
}
|
||||
auto dtype = TFE_TensorHandleDataType(t->tensor);
|
||||
return TFE_NewTensorHandleFromDeviceMemory(
|
||||
ctx, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
|
||||
context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
|
||||
t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* CopyToLoggingDevice(TFE_TensorHandle* tensor,
|
||||
TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status, void* device_info) {
|
||||
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
|
||||
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
|
||||
tensor, dev->ctx, dev->underlying_device.c_str(), status);
|
||||
tensor, context, dev->underlying_device.c_str(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
auto dst = std::make_unique<LoggedTensor>(t);
|
||||
*(dev->arrived_flag) = true;
|
||||
return MakeLoggedTensorHandle(dev->ctx, dev->device_name, std::move(dst),
|
||||
return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
|
||||
status);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_TensorHandle* tensor,
|
||||
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
|
||||
TFE_TensorHandle* tensor,
|
||||
const char* target_device_name,
|
||||
TF_Status* status,
|
||||
void* device_info) {
|
||||
@ -80,13 +84,15 @@ TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_TensorHandle* tensor,
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
|
||||
const char* operation_name, int* num_outputs,
|
||||
void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
|
||||
TFE_TensorHandle** inputs, const char* operation_name,
|
||||
const TFE_OpAttrs* attributes, int* num_outputs,
|
||||
TFE_TensorHandle** outputs, TF_Status* s,
|
||||
void* device_info) {
|
||||
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
|
||||
TFE_Op* op(TFE_NewOp(dev->ctx, operation_name, s));
|
||||
TFE_Op* op(TFE_NewOp(context, operation_name, s));
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
TFE_OpAddAttrs(op, attributes);
|
||||
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
|
||||
for (int j = 0; j < num_inputs; ++j) {
|
||||
TFE_TensorHandle* input = inputs[j];
|
||||
@ -112,9 +118,10 @@ void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
|
||||
}
|
||||
for (int i = 0; i < *num_outputs; ++i) {
|
||||
auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
|
||||
outputs[i] = MakeLoggedTensorHandle(dev->ctx, dev->device_name,
|
||||
outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
|
||||
std::move(logged_tensor), s);
|
||||
}
|
||||
*(dev->executed_flag) = true;
|
||||
}
|
||||
|
||||
void DeleteLoggingDevice(void* device_info) {
|
||||
@ -122,18 +129,19 @@ void DeleteLoggingDevice(void* device_info) {
|
||||
}
|
||||
|
||||
void RegisterLoggingDevice(TFE_Context* context, const char* name,
|
||||
bool* arrived_flag) {
|
||||
bool* arrived_flag, bool* executed_flag,
|
||||
TF_Status* status) {
|
||||
TFE_CustomDevice custom_device;
|
||||
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
|
||||
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
|
||||
custom_device.delete_device = &DeleteLoggingDevice;
|
||||
custom_device.execute = &LoggingDeviceExecute;
|
||||
LoggingDevice* device = new LoggingDevice;
|
||||
device->ctx = context;
|
||||
device->arrived_flag = arrived_flag;
|
||||
device->executed_flag = executed_flag;
|
||||
device->device_name = name;
|
||||
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
TFE_RegisterCustomDevice(context, custom_device, name, device);
|
||||
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
|
||||
@ -144,13 +152,16 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context, name, &arrived);
|
||||
RegisterLoggingDevice(context, name, &arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
|
||||
ASSERT_FALSE(arrived);
|
||||
TFE_TensorHandle* hdevice =
|
||||
TFE_TensorHandleCopyToDevice(hcpu, context, name, status.get());
|
||||
ASSERT_TRUE(arrived);
|
||||
ASSERT_FALSE(executed);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> matmul(
|
||||
MatMulOp(context, hcpu, hdevice), TFE_DeleteOp);
|
||||
@ -160,6 +171,7 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_TRUE(executed);
|
||||
|
||||
TFE_DeleteTensorHandle(retval);
|
||||
TFE_DeleteTensorHandle(hcpu);
|
||||
@ -167,4 +179,220 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
|
||||
TFE_DeleteContext(context);
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, ResetOperation) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts, status.get()), TFE_DeleteContext);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
const char* custom_device_name =
|
||||
"/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> reused_op(
|
||||
TFE_NewOp(context.get(), "Identity", status.get()), TFE_DeleteOp);
|
||||
TFE_OpReset(reused_op.get(), "Identity", custom_device_name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_EQ(tensorflow::string(TFE_OpGetDevice(reused_op.get(), status.get())),
|
||||
tensorflow::string(custom_device_name));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpReset(reused_op.get(), "Identity",
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0", status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_EQ(tensorflow::string(TFE_OpGetDevice(reused_op.get(), status.get())),
|
||||
tensorflow::string("/job:localhost/replica:0/task:0/device:CPU:0"));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, MakeVariable) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a variable handle placed on the custom device.
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
|
||||
TFE_OpSetAttrString(op.get(), "container", "", 0);
|
||||
TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
|
||||
TFE_OpSetDevice(op.get(), name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_TensorHandle* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
executed = false;
|
||||
TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_TRUE(executed);
|
||||
auto handle_cleaner = tensorflow::gtl::MakeCleanup(
|
||||
[var_handle]() { TFE_DeleteTensorHandle(var_handle); });
|
||||
|
||||
// Assign to the variable, copying to the custom device.
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
|
||||
TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
|
||||
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
TFE_OpAddInput(op.get(), one.get(), status.get());
|
||||
TFE_OpSetDevice(op.get(), name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
executed = false;
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_TRUE(executed);
|
||||
|
||||
// Read the variable's value.
|
||||
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
TFE_OpSetDevice(op.get(), name, status.get());
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
executed = false;
|
||||
num_retvals = 1;
|
||||
TFE_TensorHandle* var_value = nullptr;
|
||||
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_TRUE(executed);
|
||||
auto value_cleaner = tensorflow::gtl::MakeCleanup(
|
||||
[var_value]() { TFE_DeleteTensorHandle(var_value); });
|
||||
ASSERT_EQ(tensorflow::string(name),
|
||||
tensorflow::string(
|
||||
TFE_TensorHandleBackingDeviceName(var_value, status.get())));
|
||||
TFE_TensorHandle* var_value_unpacked =
|
||||
reinterpret_cast<LoggedTensor*>(
|
||||
TFE_TensorHandleDevicePointer(var_value, status.get()))
|
||||
->tensor;
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value(
|
||||
TFE_TensorHandleResolve(var_value_unpacked, status.get()),
|
||||
TF_DeleteTensor);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_EQ(111., *static_cast<float*>(TF_TensorData(resolved_value.get())));
|
||||
|
||||
// Free the backing buffer for the variable.
|
||||
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
TFE_OpSetDevice(op.get(), name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a variable handle placed on the custom device.
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
|
||||
TFE_OpSetAttrString(op.get(), "container", "", 0);
|
||||
TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
|
||||
TFE_OpSetDevice(op.get(), name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_TensorHandle* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
executed = false;
|
||||
TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_TRUE(executed);
|
||||
auto handle_cleaner = tensorflow::gtl::MakeCleanup(
|
||||
[var_handle]() { TFE_DeleteTensorHandle(var_handle); });
|
||||
|
||||
// Assign to the variable, copying to the custom device.
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
|
||||
TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
|
||||
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
TFE_OpAddInput(op.get(), one.get(), status.get());
|
||||
TFE_OpSetDevice(op.get(), name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
executed = false;
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_TRUE(executed);
|
||||
|
||||
// Read the variable's value.
|
||||
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
executed = false;
|
||||
num_retvals = 1;
|
||||
TFE_TensorHandle* var_value = nullptr;
|
||||
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
|
||||
EXPECT_FALSE(TF_GetCode(status.get()) == TF_OK)
|
||||
<< "Execution should fail because the variable is being used on the "
|
||||
"wrong device.";
|
||||
// Free the backing buffer for the variable.
|
||||
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
TFE_OpSetDevice(op.get(), name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
RegisterLoggingDevice(context.get(), "/device:CUSTOM:0", &arrived, &executed,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
|
||||
<< TF_Message(status.get());
|
||||
|
||||
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
|
||||
<< TF_Message(status.get());
|
||||
|
||||
RegisterLoggingDevice(context.get(),
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
&arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
|
||||
<< TF_Message(status.get());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
334
tensorflow/c/eager/dlpack.cc
Normal file
334
tensorflow/c/eager/dlpack.cc
Normal file
@ -0,0 +1,334 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/dlpack.h"
|
||||
|
||||
#include "include/dlpack/dlpack.h" // TF:dlpack
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_reference.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
// Managing context for the DLManagedTensor, will manage the lifetime of
|
||||
// DLManagedTensor. When calling DLManagedTensor::deleter, it will notify the
|
||||
// original framework of destruction, and this context will be deleted also.
|
||||
struct TfDlManagedTensorCtx {
|
||||
TensorReference reference;
|
||||
std::vector<int64_t> shape;
|
||||
std::vector<int64_t> strides;
|
||||
DLManagedTensor tensor;
|
||||
|
||||
explicit TfDlManagedTensorCtx(const TensorReference& ref) : reference(ref) {}
|
||||
};
|
||||
|
||||
// Gets tensor from eager tensor handle.
|
||||
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || !h->handle->IsValid(&status->status)) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"The passed in handle is a nullptr");
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::TensorHandle* handle =
|
||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
|
||||
->Handle();
|
||||
|
||||
if (handle->IsRemote()) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"DLPack doesn't support remote tensor");
|
||||
return nullptr;
|
||||
}
|
||||
const tensorflow::Tensor* tensor;
|
||||
status->status = handle->Tensor(&tensor);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// Deleter for DLManagedTensor
|
||||
void DLManagedTensorDeleter(DLManagedTensor* arg) {
|
||||
TfDlManagedTensorCtx* owner =
|
||||
static_cast<TfDlManagedTensorCtx*>(arg->manager_ctx);
|
||||
owner->reference.Unref();
|
||||
delete owner;
|
||||
}
|
||||
|
||||
// Converts TF_DATAType to DLPack data type.
|
||||
DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) {
|
||||
DLDataType dtype;
|
||||
dtype.lanes = 1;
|
||||
dtype.bits = TF_DataTypeSize(data_type) * 8;
|
||||
switch (data_type) {
|
||||
case TF_DataType::TF_HALF:
|
||||
case TF_DataType::TF_FLOAT:
|
||||
case TF_DataType::TF_DOUBLE:
|
||||
dtype.code = DLDataTypeCode::kDLFloat;
|
||||
break;
|
||||
case TF_DataType::TF_INT8:
|
||||
case TF_DataType::TF_INT16:
|
||||
case TF_DataType::TF_INT32:
|
||||
case TF_DataType::TF_INT64:
|
||||
dtype.code = DLDataTypeCode::kDLInt;
|
||||
break;
|
||||
case TF_DataType::TF_BOOL:
|
||||
case TF_DataType::TF_UINT8:
|
||||
case TF_DataType::TF_UINT16:
|
||||
case TF_DataType::TF_UINT32:
|
||||
case TF_DataType::TF_UINT64:
|
||||
dtype.code = DLDataTypeCode::kDLUInt;
|
||||
break;
|
||||
case TF_DataType::TF_BFLOAT16:
|
||||
dtype.code = DLDataTypeCode::kDLBfloat;
|
||||
break;
|
||||
default:
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
DataType_Name(static_cast<DataType>(data_type)),
|
||||
" is not supported by dlpack");
|
||||
break;
|
||||
}
|
||||
return dtype;
|
||||
}
|
||||
|
||||
// Gets DLPack's DLContext from eager tensor handle.
|
||||
DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) {
|
||||
DLContext ctx;
|
||||
const char* device_name = h->handle->DeviceName(&status->status);
|
||||
DeviceNameUtils::ParsedName parsed_name;
|
||||
tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
|
||||
std::string device_type = parsed_name.type;
|
||||
int device_id = 0;
|
||||
if (parsed_name.has_id) {
|
||||
device_id = parsed_name.id;
|
||||
}
|
||||
|
||||
ctx.device_id = device_id;
|
||||
if (device_type == "CPU") {
|
||||
ctx.device_type = DLDeviceType::kDLCPU;
|
||||
} else if (device_type == "GPU") {
|
||||
ctx.device_type = DLDeviceType::kDLGPU;
|
||||
} else {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Unsupported Device Type for dlpack");
|
||||
}
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
// Converts DLContext to TF device name.
|
||||
absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx,
|
||||
TF_Status* status) {
|
||||
switch (ctx.device_type) {
|
||||
case DLDeviceType::kDLCPU:
|
||||
return "CPU:0";
|
||||
case DLDeviceType::kDLGPU:
|
||||
return absl::StrCat("GPU:", ctx.device_id);
|
||||
default:
|
||||
return absl::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
// Converts DLPack data type to TF_DATATYPE.
|
||||
Status TfDataTypeFormDlDataType(const DLDataType& dtype,
|
||||
TF_DataType* tf_dtype) {
|
||||
switch (dtype.code) {
|
||||
case DLDataTypeCode::kDLUInt:
|
||||
switch (dtype.bits) {
|
||||
case 8:
|
||||
*tf_dtype = TF_DataType::TF_UINT8;
|
||||
return Status::OK();
|
||||
case 16:
|
||||
*tf_dtype = TF_DataType::TF_UINT16;
|
||||
return Status::OK();
|
||||
case 32:
|
||||
*tf_dtype = TF_DataType::TF_UINT32;
|
||||
return Status::OK();
|
||||
case 64:
|
||||
*tf_dtype = TF_DataType::TF_UINT64;
|
||||
return Status::OK();
|
||||
default:
|
||||
return tensorflow::errors::InvalidArgument("Unsupported UInt bits: ",
|
||||
dtype.bits);
|
||||
}
|
||||
return Status::OK();
|
||||
case DLDataTypeCode::kDLInt:
|
||||
switch (dtype.bits) {
|
||||
case 8:
|
||||
*tf_dtype = TF_DataType::TF_INT8;
|
||||
return Status::OK();
|
||||
case 16:
|
||||
*tf_dtype = TF_DataType::TF_INT16;
|
||||
return Status::OK();
|
||||
case 32:
|
||||
*tf_dtype = TF_DataType::TF_INT32;
|
||||
return Status::OK();
|
||||
case 64:
|
||||
*tf_dtype = TF_DataType::TF_INT64;
|
||||
return Status::OK();
|
||||
default:
|
||||
return tensorflow::errors::InvalidArgument("Unsupported Int bits: ",
|
||||
dtype.bits);
|
||||
}
|
||||
return Status::OK();
|
||||
case DLDataTypeCode::kDLFloat:
|
||||
switch (dtype.bits) {
|
||||
case 16:
|
||||
*tf_dtype = TF_DataType::TF_HALF;
|
||||
return Status::OK();
|
||||
case 32:
|
||||
*tf_dtype = TF_DataType::TF_FLOAT;
|
||||
return Status::OK();
|
||||
case 64:
|
||||
*tf_dtype = TF_DataType::TF_DOUBLE;
|
||||
return Status::OK();
|
||||
default:
|
||||
return tensorflow::errors::InvalidArgument("Unsupported Float bits: ",
|
||||
dtype.bits);
|
||||
}
|
||||
break;
|
||||
case DLDataTypeCode::kDLBfloat:
|
||||
switch (dtype.bits) {
|
||||
case 16:
|
||||
*tf_dtype = TF_DataType::TF_BFLOAT16;
|
||||
return Status::OK();
|
||||
default:
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
"Unsupported BFloat bits: ", dtype.bits);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
return tensorflow::errors::InvalidArgument("Unsupported Type Codes: ",
|
||||
dtype.code);
|
||||
}
|
||||
}
|
||||
|
||||
// Wraps the deleter function of DLManagedTensor to match the function signature
|
||||
// TFE_NewTensorHandleFromDeviceMemory.
|
||||
void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
|
||||
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlmt_vptr);
|
||||
dlmt->deleter(const_cast<DLManagedTensor*>(dlmt));
|
||||
}
|
||||
|
||||
// Checks whether the stride array matches the layout of compact, row-majored
|
||||
// data.
|
||||
bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
|
||||
int ndim) {
|
||||
if (ndim >= 1 && stride_arr[ndim - 1] != 1) {
|
||||
return false;
|
||||
}
|
||||
for (int i = ndim - 2; i >= 0; --i) {
|
||||
if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
|
||||
DLManagedTensor* dlMTensor = static_cast<DLManagedTensor*>(dlm_ptr);
|
||||
if (dlMTensor->deleter != nullptr) {
|
||||
dlMTensor->deleter(dlMTensor);
|
||||
}
|
||||
}
|
||||
|
||||
void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
|
||||
const Tensor* tensor = GetTensorFromHandle(h, status);
|
||||
TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
|
||||
TensorReference tensor_ref(*tensor); // This will call buf_->Ref()
|
||||
|
||||
auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
|
||||
tf_dlm_tensor_ctx->reference = tensor_ref;
|
||||
|
||||
DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
|
||||
dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
|
||||
dlm_tensor->deleter = &DLManagedTensorDeleter;
|
||||
dlm_tensor->dl_tensor.ctx = GetDlContext(h, status);
|
||||
int ndim = tensor->dims();
|
||||
dlm_tensor->dl_tensor.ndim = ndim;
|
||||
dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
|
||||
dlm_tensor->dl_tensor.dtype = GetDlDataType(data_type, status);
|
||||
|
||||
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
|
||||
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
|
||||
shape_arr->resize(ndim);
|
||||
stride_arr->resize(ndim, 1);
|
||||
for (int i = 0; i < ndim; i++) {
|
||||
(*shape_arr)[i] = tensor->dim_size(i);
|
||||
}
|
||||
for (int i = ndim - 2; i >= 0; --i) {
|
||||
(*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
|
||||
}
|
||||
|
||||
dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
|
||||
// There are two ways to represent compact row-major data
|
||||
// 1) nullptr indicates tensor is compact and row-majored.
|
||||
// 2) fill in the strides array as the real case for compact row-major data.
|
||||
// Here we choose option 2, since some frameworks didn't handle the strides
|
||||
// argument properly.
|
||||
dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
|
||||
dlm_tensor->dl_tensor.byte_offset =
|
||||
0; // TF doesn't handle the strides and byte_offsets here
|
||||
return static_cast<void*>(dlm_tensor);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
|
||||
DLTensor* dl_tensor = &dlmt->dl_tensor;
|
||||
absl::optional<std::string> device_name =
|
||||
DeviceNameFromDlContext(dl_tensor->ctx, status);
|
||||
if (!device_name.has_value()) {
|
||||
status->status =
|
||||
tensorflow::errors::InvalidArgument("Unsupported Device Type");
|
||||
return nullptr;
|
||||
}
|
||||
TF_DataType dtype;
|
||||
Status s = TfDataTypeFormDlDataType(dl_tensor->dtype, &dtype);
|
||||
if (!s.ok()) {
|
||||
status->status = std::move(s);
|
||||
return nullptr;
|
||||
}
|
||||
int num_dims = dl_tensor->ndim;
|
||||
const int64_t* dims = dl_tensor->shape;
|
||||
void* data = dl_tensor->data;
|
||||
|
||||
size_t total_bytes = dl_tensor->dtype.bits / 8;
|
||||
for (int i = 0; i < num_dims; i++) {
|
||||
total_bytes *= dims[i];
|
||||
}
|
||||
|
||||
if (dl_tensor->strides != nullptr &&
|
||||
!IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides,
|
||||
num_dims)) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Invalid strides array from DLPack");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
|
||||
ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
|
||||
total_bytes, &DeallocatorWrapperFunc, &dlmt, status);
|
||||
|
||||
return handle;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
39
tensorflow/c/eager/dlpack.h
Normal file
39
tensorflow/c/eager/dlpack.h
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EAGER_DLPACK_H_
|
||||
#define TENSORFLOW_C_EAGER_DLPACK_H_
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// PyCapsule name for DLPack Tensor
|
||||
const char* const kDlTensorCapsuleName = "dltensor";
|
||||
|
||||
// Converts eager tensor handle to DLPack (DLManagedTensor*), and return the
|
||||
// void* for further PyCapsule construction.
|
||||
TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h,
|
||||
TF_Status* status);
|
||||
|
||||
// Converts DLPack (DLManagedTensor*) to eager tensor handle.
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm,
|
||||
TF_Status* status);
|
||||
|
||||
// Calls the destructor of DLManagedTensor, used in the destructor of PyCapsule.
|
||||
TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_DLPACK_H_
|
312
tensorflow/c/eager/operation_interface.cc
Normal file
312
tensorflow/c/eager/operation_interface.cc
Normal file
@ -0,0 +1,312 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
|
||||
#include "absl/container/fixed_array.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/common_runtime/eager/execute.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
OperationInterface::OperationInterface(TFE_Context* ctx)
|
||||
: operation_(ctx->context) {}
|
||||
|
||||
const string& OperationInterface::DeviceName() const {
|
||||
absl::variant<Device*, CustomDevice*> variant_device =
|
||||
(operation_.Device() == kVariantDeviceNull)
|
||||
? operation_.EagerContext().HostCPU()
|
||||
: operation_.Device();
|
||||
return absl::visit([](auto* d) -> const string& { return d->name(); },
|
||||
variant_device);
|
||||
}
|
||||
|
||||
Status OperationInterface::SetDeviceName(const char* name) {
|
||||
return operation_.SetDeviceName(name);
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrString(const char* attr_name,
|
||||
const char* data, size_t length) {
|
||||
operation_.MutableAttrs()->Set(attr_name, StringPiece(data, length));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrInt(const char* attr_name, int64_t value) {
|
||||
operation_.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrFloat(const char* attr_name, float value) {
|
||||
operation_.MutableAttrs()->Set(attr_name, value);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrBool(const char* attr_name, bool value) {
|
||||
operation_.MutableAttrs()->Set(attr_name, value);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrType(const char* attr_name,
|
||||
TF_DataType value) {
|
||||
operation_.MutableAttrs()->Set(attr_name, static_cast<DataType>(value));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrShape(const char* attr_name,
|
||||
const int64_t* dims,
|
||||
const int num_dims) {
|
||||
if (num_dims > TensorShape::MaxDimensions()) {
|
||||
return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
|
||||
num_dims,
|
||||
" dimensions which is over the limit of ",
|
||||
TensorShape::MaxDimensions(), ".");
|
||||
}
|
||||
|
||||
TensorShapeProto proto;
|
||||
if (num_dims < 0) {
|
||||
proto.set_unknown_rank(true);
|
||||
} else {
|
||||
for (int d = 0; d < num_dims; ++d) {
|
||||
proto.add_dim()->set_size(dims[d]);
|
||||
}
|
||||
}
|
||||
|
||||
operation_.MutableAttrs()->Set(attr_name, proto);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrFunction(
|
||||
const char* attr_name,
|
||||
const std::unique_ptr<AbstractOperationInterface>& value) {
|
||||
AttrValue attr_value;
|
||||
NameAttrList* func = attr_value.mutable_func();
|
||||
func->set_name(value->Name());
|
||||
OperationInterface* value_operation =
|
||||
tensorflow::down_cast<OperationInterface*>(value.get());
|
||||
value_operation->operation_.Attrs().FillAttrValueMap(func->mutable_attr());
|
||||
operation_.MutableAttrs()->Set(attr_name, attr_value);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrFunctionName(const char* attr_name,
|
||||
const char* data,
|
||||
size_t length) {
|
||||
AttrValue attr_value;
|
||||
NameAttrList* func = attr_value.mutable_func();
|
||||
func->set_name(data, length);
|
||||
operation_.MutableAttrs()->Set(attr_name, attr_value);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrTensor(const char* attr_name,
|
||||
TF_Tensor* tensor) {
|
||||
Tensor t;
|
||||
TF_RETURN_IF_ERROR(TF_TensorToTensor(tensor, &t));
|
||||
operation_.MutableAttrs()->Set(attr_name, t);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrStringList(const char* attr_name,
|
||||
const void* const* values,
|
||||
const size_t* lengths,
|
||||
int num_values) {
|
||||
std::vector<StringPiece> v(num_values);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
|
||||
}
|
||||
operation_.MutableAttrs()->Set(attr_name, v);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrFloatList(const char* attr_name,
|
||||
const float* values,
|
||||
int num_values) {
|
||||
operation_.MutableAttrs()->Set(
|
||||
attr_name, gtl::ArraySlice<const float>(values, num_values));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrIntList(const char* attr_name,
|
||||
const int64_t* values,
|
||||
int num_values) {
|
||||
operation_.MutableAttrs()->Set(
|
||||
attr_name, gtl::ArraySlice<const int64>(
|
||||
reinterpret_cast<const int64*>(values), num_values));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrTypeList(const char* attr_name,
|
||||
const TF_DataType* values,
|
||||
int num_values) {
|
||||
operation_.MutableAttrs()->Set(
|
||||
attr_name, gtl::ArraySlice<const DataType>(
|
||||
reinterpret_cast<const DataType*>(values), num_values));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrBoolList(const char* attr_name,
|
||||
const unsigned char* values,
|
||||
int num_values) {
|
||||
std::unique_ptr<bool[]> b(new bool[num_values]);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
b[i] = values[i];
|
||||
}
|
||||
operation_.MutableAttrs()->Set(
|
||||
attr_name, gtl::ArraySlice<const bool>(b.get(), num_values));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrShapeList(const char* attr_name,
|
||||
const int64_t** dims,
|
||||
const int* num_dims,
|
||||
int num_values) {
|
||||
std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
const auto num_dims_i = num_dims[i];
|
||||
|
||||
if (num_dims_i > TensorShape::MaxDimensions()) {
|
||||
return errors::InvalidArgument(
|
||||
strings::StrCat("Value specified for `", attr_name, "` has ",
|
||||
num_dims_i, " dimensions which is over the limit of ",
|
||||
TensorShape::MaxDimensions(), "."));
|
||||
}
|
||||
if (num_dims_i < 0) {
|
||||
proto[i].set_unknown_rank(true);
|
||||
} else {
|
||||
const int64_t* dims_i = dims[i];
|
||||
auto proto_i = &proto[i];
|
||||
for (int d = 0; d < num_dims_i; ++d) {
|
||||
proto_i->add_dim()->set_size(dims_i[d]);
|
||||
}
|
||||
}
|
||||
}
|
||||
operation_.MutableAttrs()->Set(
|
||||
attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrFunctionList(const char* attr_name,
|
||||
const TFE_Op** value,
|
||||
int num_values) {
|
||||
std::unique_ptr<NameAttrList[]> funcs(new NameAttrList[num_values]);
|
||||
for (int i = 0; i < num_values; i++) {
|
||||
auto value_operation =
|
||||
tensorflow::down_cast<OperationInterface*>(value[i]->operation.get());
|
||||
funcs[i].set_name(value_operation->operation_.Name());
|
||||
value_operation->operation_.Attrs().FillAttrValueMap(
|
||||
funcs[i].mutable_attr());
|
||||
}
|
||||
operation_.MutableAttrs()->Set(
|
||||
attr_name, gtl::ArraySlice<const NameAttrList>(funcs.get(), num_values));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const OpDef* OperationInterface::GetOpDef(Status* status) {
|
||||
const tensorflow::OpDef* op_def = operation_.OpDef();
|
||||
if (op_def) return op_def;
|
||||
*status = OpDefForOp(Name(), &op_def);
|
||||
return op_def;
|
||||
}
|
||||
|
||||
Status OperationInterface::InputLength(const char* input_name, int* length) {
|
||||
Status status;
|
||||
const tensorflow::OpDef* op_def = GetOpDef(&status);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
AttrValueMap attrs;
|
||||
operation_.Attrs().FillAttrValueMap(&attrs);
|
||||
NameRangeMap name_ranges;
|
||||
TF_RETURN_IF_ERROR(
|
||||
NameRangesForNode(AttrSlice(&attrs), *op_def, &name_ranges, nullptr));
|
||||
auto iter = name_ranges.find(input_name);
|
||||
if (iter == name_ranges.end()) {
|
||||
return errors::InvalidArgument("Input '", input_name, "' not found");
|
||||
}
|
||||
*length = iter->second.second - iter->second.first;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::OutputLength(const char* output_name, int* length) {
|
||||
Status status;
|
||||
const tensorflow::OpDef* op_def = GetOpDef(&status);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
AttrValueMap attrs;
|
||||
operation_.Attrs().FillAttrValueMap(&attrs);
|
||||
NameRangeMap name_ranges;
|
||||
TF_RETURN_IF_ERROR(
|
||||
NameRangesForNode(AttrSlice(&attrs), *op_def, nullptr, &name_ranges));
|
||||
auto iter = name_ranges.find(output_name);
|
||||
if (iter == name_ranges.end()) {
|
||||
return errors::InvalidArgument("Output '", output_name, "' not found");
|
||||
}
|
||||
*length = iter->second.second - iter->second.first;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::AddInput(
|
||||
const std::unique_ptr<AbstractTensorHandleInterface>& input) {
|
||||
TensorHandle* h =
|
||||
tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
|
||||
operation_.AddInput(h);
|
||||
return operation_.MaybeInferSingleInputAttrs(h);
|
||||
}
|
||||
|
||||
Status OperationInterface::AddInputList(
|
||||
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
|
||||
inputs) {
|
||||
for (auto& input : inputs) {
|
||||
TensorHandle* h =
|
||||
tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
|
||||
operation_.AddInput(h);
|
||||
}
|
||||
return operation_.InferInputListAttrs(inputs.size());
|
||||
}
|
||||
|
||||
Status OperationInterface::Execute(
|
||||
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
|
||||
int* num_retvals) {
|
||||
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
|
||||
TF_RETURN_IF_ERROR(
|
||||
EagerExecute(&operation_, handle_retvals.data(), num_retvals));
|
||||
for (int i = 0; i < *num_retvals; ++i) {
|
||||
retvals->at(i).reset(
|
||||
new tensorflow::TensorHandleInterface(handle_retvals[i]));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetCancellationManager(
|
||||
TFE_CancellationManager* cancellation_manager) {
|
||||
operation_.SetCancellationManager(
|
||||
&cancellation_manager->cancellation_manager);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetUseXla(bool enable) {
|
||||
operation_.SetUseXla(enable);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
188
tensorflow/c/eager/operation_interface.h
Normal file
188
tensorflow/c/eager/operation_interface.h
Normal file
@ -0,0 +1,188 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
|
||||
#define TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/container/fixed_array.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
|
||||
// Abstract interface to an operation.
|
||||
class AbstractOperationInterface {
|
||||
public:
|
||||
virtual ~AbstractOperationInterface() {}
|
||||
|
||||
virtual void Clear() = 0;
|
||||
virtual tensorflow::Status Reset(const char* op,
|
||||
const char* raw_device_name) = 0;
|
||||
|
||||
virtual const tensorflow::string& Name() const = 0;
|
||||
virtual const tensorflow::string& DeviceName() const = 0;
|
||||
virtual tensorflow::Status SetDeviceName(const char* name) = 0;
|
||||
|
||||
virtual tensorflow::Status AddInput(
|
||||
const std::unique_ptr<AbstractTensorHandleInterface>& input) = 0;
|
||||
virtual tensorflow::Status AddInputList(
|
||||
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
|
||||
inputs) = 0;
|
||||
virtual tensorflow::Status Execute(
|
||||
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
|
||||
int* num_retvals) = 0;
|
||||
virtual const tensorflow::OpDef* OpDef() const = 0;
|
||||
|
||||
virtual tensorflow::Status SetAttrString(const char* attr_name,
|
||||
const char* data, size_t length) = 0;
|
||||
virtual tensorflow::Status SetAttrInt(const char* attr_name,
|
||||
int64_t value) = 0;
|
||||
virtual tensorflow::Status SetAttrFloat(const char* attr_name,
|
||||
float value) = 0;
|
||||
virtual tensorflow::Status SetAttrBool(const char* attr_name, bool value) = 0;
|
||||
virtual tensorflow::Status SetAttrType(const char* attr_name,
|
||||
TF_DataType value) = 0;
|
||||
virtual tensorflow::Status SetAttrShape(const char* attr_name,
|
||||
const int64_t* dims,
|
||||
const int num_dims) = 0;
|
||||
virtual tensorflow::Status SetAttrFunction(
|
||||
const char* attr_name,
|
||||
const std::unique_ptr<AbstractOperationInterface>& value) = 0;
|
||||
virtual tensorflow::Status SetAttrFunctionName(const char* attr_name,
|
||||
const char* value,
|
||||
size_t length) = 0;
|
||||
virtual tensorflow::Status SetAttrTensor(const char* attr_name,
|
||||
TF_Tensor* tensor) = 0;
|
||||
virtual tensorflow::Status SetAttrStringList(const char* attr_name,
|
||||
const void* const* values,
|
||||
const size_t* lengths,
|
||||
int num_values) = 0;
|
||||
virtual tensorflow::Status SetAttrFloatList(const char* attr_name,
|
||||
const float* values,
|
||||
int num_values) = 0;
|
||||
virtual tensorflow::Status SetAttrIntList(const char* attr_name,
|
||||
const int64_t* values,
|
||||
int num_values) = 0;
|
||||
virtual tensorflow::Status SetAttrTypeList(const char* attr_name,
|
||||
const TF_DataType* values,
|
||||
int num_values) = 0;
|
||||
virtual tensorflow::Status SetAttrBoolList(const char* attr_name,
|
||||
const unsigned char* values,
|
||||
int num_values) = 0;
|
||||
virtual tensorflow::Status SetAttrShapeList(const char* attr_name,
|
||||
const int64_t** dims,
|
||||
const int* num_dims,
|
||||
int num_values) = 0;
|
||||
virtual tensorflow::Status SetAttrFunctionList(const char* attr_name,
|
||||
const TFE_Op** value,
|
||||
int num_values) = 0;
|
||||
|
||||
virtual tensorflow::Status InputLength(const char* input_name,
|
||||
int* length) = 0;
|
||||
virtual tensorflow::Status OutputLength(const char* output_name,
|
||||
int* length) = 0;
|
||||
|
||||
// Experimental
|
||||
virtual tensorflow::Status SetUseXla(bool enable) {
|
||||
return tensorflow::errors::Unimplemented("SetUseXla not implemented");
|
||||
}
|
||||
virtual tensorflow::Status SetCancellationManager(
|
||||
TFE_CancellationManager* cancellation_manager) {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"SetCancellationManager not implemented");
|
||||
}
|
||||
};
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class OpDef;
|
||||
|
||||
class OperationInterface : public AbstractOperationInterface {
|
||||
public:
|
||||
explicit OperationInterface(TFE_Context* ctx);
|
||||
~OperationInterface() override{};
|
||||
|
||||
void Clear() override { operation_.Clear(); }
|
||||
Status Reset(const char* op, const char* raw_device_name) override {
|
||||
return operation_.Reset(op, raw_device_name, false, nullptr);
|
||||
}
|
||||
|
||||
const string& Name() const override { return operation_.Name(); }
|
||||
const string& DeviceName() const override;
|
||||
Status SetDeviceName(const char* name) override;
|
||||
|
||||
Status AddInput(
|
||||
const std::unique_ptr<AbstractTensorHandleInterface>& input) override;
|
||||
Status AddInputList(
|
||||
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
|
||||
inputs) override;
|
||||
Status Execute(
|
||||
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
|
||||
int* num_retvals) override;
|
||||
const tensorflow::OpDef* OpDef() const override {
|
||||
return operation_.OpDef();
|
||||
};
|
||||
|
||||
Status SetAttrString(const char* attr_name, const char* data,
|
||||
size_t length) override;
|
||||
Status SetAttrInt(const char* attr_name, int64_t value) override;
|
||||
Status SetAttrFloat(const char* attr_name, float value) override;
|
||||
Status SetAttrBool(const char* attr_name, bool value) override;
|
||||
Status SetAttrType(const char* attr_name, TF_DataType value) override;
|
||||
Status SetAttrShape(const char* attr_name, const int64_t* dims,
|
||||
const int num_dims) override;
|
||||
Status SetAttrFunction(
|
||||
const char* attr_name,
|
||||
const std::unique_ptr<AbstractOperationInterface>& value) override;
|
||||
Status SetAttrFunctionName(const char* attr_name, const char* data,
|
||||
size_t length) override;
|
||||
Status SetAttrTensor(const char* attr_name, TF_Tensor* tensor) override;
|
||||
Status SetAttrStringList(const char* attr_name, const void* const* values,
|
||||
const size_t* lengths, int num_values) override;
|
||||
Status SetAttrFloatList(const char* attr_name, const float* values,
|
||||
int num_values) override;
|
||||
Status SetAttrIntList(const char* attr_name, const int64_t* values,
|
||||
int num_values) override;
|
||||
Status SetAttrTypeList(const char* attr_name, const TF_DataType* values,
|
||||
int num_values) override;
|
||||
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
|
||||
int num_values) override;
|
||||
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
|
||||
const int* num_dims, int num_values) override;
|
||||
Status SetAttrFunctionList(const char* attr_name, const TFE_Op** value,
|
||||
int num_values) override;
|
||||
|
||||
Status InputLength(const char* input_name, int* length) override;
|
||||
Status OutputLength(const char* output_name, int* length) override;
|
||||
|
||||
Status SetUseXla(bool enable) override;
|
||||
Status SetCancellationManager(
|
||||
TFE_CancellationManager* cancellation_manager) override;
|
||||
|
||||
// TODO(gjn): Remove once TFE_InferShapes is removed
|
||||
const tensorflow::AttrBuilder& Attrs() const { return operation_.Attrs(); }
|
||||
tensorflow::AttrBuilder* MutableAttrs() { return operation_.MutableAttrs(); }
|
||||
|
||||
const TensorHandle* GetInput(int i) const { return operation_.Inputs()[i]; }
|
||||
|
||||
private:
|
||||
const tensorflow::OpDef* GetOpDef(Status* status);
|
||||
EagerOperation operation_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
|
@ -27,14 +27,10 @@ namespace {
|
||||
|
||||
class DummyDevice : public DeviceBase {
|
||||
public:
|
||||
DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
|
||||
bool RequiresRecordingAccessedTensors() const override { return save_; }
|
||||
explicit DummyDevice(Env* env) : DeviceBase(env) {}
|
||||
Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
|
||||
return cpu_allocator();
|
||||
}
|
||||
|
||||
private:
|
||||
bool save_;
|
||||
};
|
||||
|
||||
void TestBitcastOp(Tensor* input_tensor, DataType out_type,
|
||||
@ -61,7 +57,7 @@ void TestBitcastOp(Tensor* input_tensor, DataType out_type,
|
||||
ASSERT_TRUE(status.ok()) << status.ToString();
|
||||
|
||||
OpKernelContext::Params params;
|
||||
DummyDevice dummy_device(nullptr, false);
|
||||
DummyDevice dummy_device(nullptr);
|
||||
params.device = &dummy_device;
|
||||
params.op_kernel = kernel.get();
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
|
@ -155,14 +155,10 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
|
||||
|
||||
class DummyDevice : public DeviceBase {
|
||||
public:
|
||||
DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
|
||||
bool RequiresRecordingAccessedTensors() const override { return save_; }
|
||||
explicit DummyDevice(Env* env) : DeviceBase(env) {}
|
||||
Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
|
||||
return cpu_allocator();
|
||||
}
|
||||
|
||||
private:
|
||||
bool save_;
|
||||
};
|
||||
|
||||
TEST(TestKernel, TestInputAndOutputCount) {
|
||||
@ -223,7 +219,7 @@ TEST(TestKernel, TestInputAndOutputCount) {
|
||||
|
||||
{
|
||||
OpKernelContext::Params p;
|
||||
DummyDevice dummy_device(nullptr, false);
|
||||
DummyDevice dummy_device(nullptr);
|
||||
p.device = &dummy_device;
|
||||
p.step_id = 43;
|
||||
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
@ -64,25 +65,41 @@ void deallocate_buffer(void* data, size_t len, void* arg) {
|
||||
}
|
||||
} // namespace tensorflow
|
||||
|
||||
namespace {
|
||||
TF_Tensor* CreateTensor(TF_ManagedBuffer* buf, TF_DataType dtype,
|
||||
const int64_t* dims, int num_dims, size_t len) {
|
||||
std::vector<tensorflow::int64> dimvec(num_dims);
|
||||
for (int i = 0; i < num_dims; ++i) {
|
||||
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
|
||||
}
|
||||
|
||||
// TODO(gjn): Make the choice of interface a compile-time configuration.
|
||||
tensorflow::TensorInterface ret(
|
||||
Tensor(static_cast<tensorflow::DataType>(dtype),
|
||||
tensorflow::TensorShape(dimvec), buf));
|
||||
buf->Unref();
|
||||
size_t elem_size = TF_DataTypeSize(dtype);
|
||||
if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
|
||||
return nullptr;
|
||||
}
|
||||
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
|
||||
int num_dims, size_t len) {
|
||||
void* data = tensorflow::allocate_tensor("TF_AllocateTensor", len,
|
||||
tensorflow::cpu_allocator());
|
||||
return TF_NewTensor(dtype, dims, num_dims, data, len,
|
||||
tensorflow::deallocate_buffer,
|
||||
tensorflow::cpu_allocator());
|
||||
TF_ManagedBuffer* buf =
|
||||
new TF_ManagedBuffer(data, len, tensorflow::deallocate_buffer,
|
||||
tensorflow::cpu_allocator(), /*owns_memory=*/true);
|
||||
return CreateTensor(buf, dtype, dims, num_dims, len);
|
||||
}
|
||||
|
||||
TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
|
||||
void* data, size_t len,
|
||||
void (*deallocator)(void* data, size_t len, void* arg),
|
||||
void* deallocator_arg) {
|
||||
std::vector<tensorflow::int64> dimvec(num_dims);
|
||||
for (int i = 0; i < num_dims; ++i) {
|
||||
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
|
||||
}
|
||||
|
||||
TF_ManagedBuffer* buf = nullptr;
|
||||
if (dtype != TF_STRING && dtype != TF_RESOURCE &&
|
||||
tensorflow::DataTypeCanUseMemcpy(
|
||||
@ -97,24 +114,17 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
|
||||
// Other types have the same representation, so copy only if it is safe to
|
||||
// do so.
|
||||
buf = new TF_ManagedBuffer(tensorflow::allocate_tensor("TF_NewTensor", len),
|
||||
len, tensorflow::deallocate_buffer, nullptr);
|
||||
len, tensorflow::deallocate_buffer, nullptr,
|
||||
/*owns_memory=*/true);
|
||||
std::memcpy(buf->data(), data, len);
|
||||
// Free the original buffer.
|
||||
deallocator(data, len, deallocator_arg);
|
||||
} else {
|
||||
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
|
||||
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
|
||||
/*owns_memory=*/false);
|
||||
}
|
||||
|
||||
// TODO(gjn): Make the choice of interface a compile-time configuration.
|
||||
tensorflow::TensorInterface ret(
|
||||
Tensor(static_cast<tensorflow::DataType>(dtype),
|
||||
tensorflow::TensorShape(dimvec), buf));
|
||||
buf->Unref();
|
||||
size_t elem_size = TF_DataTypeSize(dtype);
|
||||
if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
|
||||
return nullptr;
|
||||
}
|
||||
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
|
||||
return CreateTensor(buf, dtype, dims, num_dims, len);
|
||||
}
|
||||
|
||||
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* t) {
|
||||
|
@ -58,9 +58,9 @@ extern "C" {
|
||||
// start_offset: array[uint64]
|
||||
// data: byte[...]
|
||||
//
|
||||
// The string length (as a varint), followed by the contents of the string
|
||||
// is encoded at data[start_offset[i]]]. TF_StringEncode and TF_StringDecode
|
||||
// facilitate this encoding.
|
||||
// The string length (as a varint, start_offset[i + 1] - start_offset[i]),
|
||||
// followed by the contents of the string is encoded at data[start_offset[i]].
|
||||
// TF_StringEncode and TF_StringDecode facilitate this encoding.
|
||||
|
||||
typedef struct TF_Tensor TF_Tensor;
|
||||
|
||||
|
@ -38,11 +38,12 @@ class TF_ManagedBuffer : public tensorflow::TensorBuffer {
|
||||
public:
|
||||
TF_ManagedBuffer(void* data, size_t len,
|
||||
void (*deallocator)(void* data, size_t len, void* arg),
|
||||
void* deallocator_arg)
|
||||
void* deallocator_arg, bool owns_memory)
|
||||
: TensorBuffer(data),
|
||||
len_(len),
|
||||
deallocator_(deallocator),
|
||||
deallocator_arg_(deallocator_arg) {}
|
||||
deallocator_arg_(deallocator_arg),
|
||||
owns_memory_(owns_memory) {}
|
||||
|
||||
~TF_ManagedBuffer() override {
|
||||
(*deallocator_)(data(), len_, deallocator_arg_);
|
||||
@ -57,13 +58,13 @@ class TF_ManagedBuffer : public tensorflow::TensorBuffer {
|
||||
proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
|
||||
}
|
||||
|
||||
// Prevents input forwarding from mutating this buffer.
|
||||
bool OwnsMemory() const override { return false; }
|
||||
bool OwnsMemory() const override { return owns_memory_; }
|
||||
|
||||
private:
|
||||
const size_t len_;
|
||||
void (*const deallocator_)(void* data, size_t len, void* arg);
|
||||
void* const deallocator_arg_;
|
||||
bool owns_memory_;
|
||||
};
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -41,7 +41,7 @@ class ClientSession::Impl {
|
||||
std::shared_ptr<Graph> graph_;
|
||||
|
||||
mutable mutex mu_;
|
||||
mutable int last_num_graph_nodes_ GUARDED_BY(mu_) = 0;
|
||||
mutable int last_num_graph_nodes_ TF_GUARDED_BY(mu_) = 0;
|
||||
};
|
||||
|
||||
ClientSession::ClientSession(const Scope& scope, const string& target)
|
||||
|
@ -68,6 +68,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/platform:resource_loader",
|
||||
],
|
||||
)
|
||||
|
||||
@ -224,3 +225,15 @@ filegroup(
|
||||
"testdata/VarsAndArithmeticObjectGraph/**",
|
||||
]),
|
||||
)
|
||||
|
||||
exports_files(
|
||||
glob([
|
||||
"testdata/half_plus_two_pbtxt/**",
|
||||
"testdata/half_plus_two_main_op/**",
|
||||
"testdata/half_plus_two/**",
|
||||
"testdata/half_plus_two_v2/**",
|
||||
"testdata/x_plus_y_v2_debuginfo/**",
|
||||
"testdata/CyclicModule/**",
|
||||
"testdata/VarsAndArithmeticObjectGraph/**",
|
||||
]),
|
||||
)
|
||||
|
@ -21,15 +21,22 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/path.h"
|
||||
#include "tensorflow/core/platform/resource_loader.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
constexpr char kTestDataPbTxt[] =
|
||||
"cc/saved_model/testdata/half_plus_two_pbtxt/00000123";
|
||||
constexpr char kTestDataSharded[] =
|
||||
"cc/saved_model/testdata/half_plus_two/00000123";
|
||||
string TestDataPbTxt() {
|
||||
return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
|
||||
"half_plus_two_pbtxt", "00000123");
|
||||
}
|
||||
|
||||
string TestDataSharded() {
|
||||
return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
|
||||
"half_plus_two", "00000123");
|
||||
}
|
||||
|
||||
class ReaderTest : public ::testing::Test {
|
||||
protected:
|
||||
@ -49,8 +56,7 @@ class ReaderTest : public ::testing::Test {
|
||||
TEST_F(ReaderTest, TagMatch) {
|
||||
MetaGraphDef meta_graph_def;
|
||||
|
||||
const string export_dir =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
|
||||
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
|
||||
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
||||
&meta_graph_def));
|
||||
CheckMetaGraphDef(meta_graph_def);
|
||||
@ -59,8 +65,7 @@ TEST_F(ReaderTest, TagMatch) {
|
||||
TEST_F(ReaderTest, NoTagMatch) {
|
||||
MetaGraphDef meta_graph_def;
|
||||
|
||||
const string export_dir =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
|
||||
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
|
||||
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"},
|
||||
&meta_graph_def);
|
||||
EXPECT_FALSE(st.ok());
|
||||
@ -73,8 +78,7 @@ TEST_F(ReaderTest, NoTagMatch) {
|
||||
TEST_F(ReaderTest, NoTagMatchMultiple) {
|
||||
MetaGraphDef meta_graph_def;
|
||||
|
||||
const string export_dir =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
|
||||
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
|
||||
Status st = ReadMetaGraphDefFromSavedModel(
|
||||
export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);
|
||||
EXPECT_FALSE(st.ok());
|
||||
@ -87,8 +91,7 @@ TEST_F(ReaderTest, NoTagMatchMultiple) {
|
||||
TEST_F(ReaderTest, PbtxtFormat) {
|
||||
MetaGraphDef meta_graph_def;
|
||||
|
||||
const string export_dir =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPbTxt);
|
||||
const string export_dir = GetDataDependencyFilepath(TestDataPbTxt());
|
||||
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
||||
&meta_graph_def));
|
||||
CheckMetaGraphDef(meta_graph_def);
|
||||
@ -97,8 +100,7 @@ TEST_F(ReaderTest, PbtxtFormat) {
|
||||
TEST_F(ReaderTest, InvalidExportPath) {
|
||||
MetaGraphDef meta_graph_def;
|
||||
|
||||
const string export_dir =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path");
|
||||
const string export_dir = GetDataDependencyFilepath("missing-path");
|
||||
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
||||
&meta_graph_def);
|
||||
EXPECT_FALSE(st.ok());
|
||||
|
@ -114,14 +114,14 @@ class Coordinator {
|
||||
condition_variable wait_for_stop_;
|
||||
|
||||
mutex mu_;
|
||||
bool should_stop_ GUARDED_BY(mu_);
|
||||
bool should_stop_ TF_GUARDED_BY(mu_);
|
||||
|
||||
mutex status_lock_;
|
||||
Status status_ GUARDED_BY(status_lock_);
|
||||
Status status_ TF_GUARDED_BY(status_lock_);
|
||||
|
||||
mutable mutex runners_lock_;
|
||||
std::vector<std::unique_ptr<RunnerInterface>> runners_
|
||||
GUARDED_BY(runners_lock_);
|
||||
TF_GUARDED_BY(runners_lock_);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(Coordinator);
|
||||
};
|
||||
|
@ -119,8 +119,8 @@ class QueueRunner : public RunnerInterface {
|
||||
std::unique_ptr<thread::ThreadPool> thread_pool_;
|
||||
mutex mu_;
|
||||
int runs_ = 0;
|
||||
Status status_ GUARDED_BY(mu_);
|
||||
Status enqueue_status_ GUARDED_BY(mu_);
|
||||
Status status_ TF_GUARDED_BY(mu_);
|
||||
Status enqueue_status_ TF_GUARDED_BY(mu_);
|
||||
std::unique_ptr<BlockingCounter> counter_;
|
||||
|
||||
Coordinator* coord_;
|
||||
@ -131,7 +131,7 @@ class QueueRunner : public RunnerInterface {
|
||||
std::vector<std::function<void(Status)>> callbacks_;
|
||||
|
||||
mutable std::unique_ptr<mutex> cg_mu_;
|
||||
std::unique_ptr<CostGraphDef> cost_graph_ GUARDED_BY(cg_mu_);
|
||||
std::unique_ptr<CostGraphDef> cost_graph_ TF_GUARDED_BY(cg_mu_);
|
||||
RunOptions run_options_;
|
||||
};
|
||||
|
||||
|
@ -20,9 +20,11 @@ from __future__ import print_function as _print_function
|
||||
|
||||
import logging as _logging
|
||||
import os as _os
|
||||
import six as _six
|
||||
import sys as _sys
|
||||
|
||||
from tensorflow.python.tools import module_util as _module_util
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
|
||||
@ -36,20 +38,19 @@ try:
|
||||
from tensorboard.summary._tf import summary
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(summary)] + _current_module.__path__)
|
||||
# Make sure we get the correct summary module with lazy loading
|
||||
setattr(_current_module, "summary", summary)
|
||||
except ImportError:
|
||||
_logging.warning(
|
||||
"Limited tf.compat.v2.summary API due to missing TensorBoard "
|
||||
"installation.")
|
||||
|
||||
try:
|
||||
from tensorflow_estimator.python.estimator.api._v2 import estimator
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
|
||||
setattr(_current_module, "estimator", estimator)
|
||||
except ImportError:
|
||||
pass
|
||||
# Lazy-load estimator.
|
||||
_estimator_module = "tensorflow_estimator.python.estimator.api._v2.estimator"
|
||||
estimator = _LazyLoader("estimator", globals(), _estimator_module)
|
||||
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
|
||||
if _module_dir:
|
||||
_current_module.__path__ = [_module_dir] + _current_module.__path__
|
||||
setattr(_current_module, "estimator", estimator)
|
||||
|
||||
try:
|
||||
from tensorflow.python.keras.api._v2 import keras
|
||||
@ -59,6 +60,13 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Explicitly import lazy-loaded modules to support autocompletion.
|
||||
# pylint: disable=g-import-not-at-top
|
||||
if not _six.PY2:
|
||||
import typing as _typing
|
||||
if _typing.TYPE_CHECKING:
|
||||
from tensorflow_estimator.python.estimator.api._v2 import estimator
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
# We would like the following to work for fully enabling 2.0 in a 1.0 install:
|
||||
#
|
||||
|
@ -20,8 +20,10 @@ from __future__ import print_function as _print_function
|
||||
|
||||
import os as _os
|
||||
import sys as _sys
|
||||
import six as _six
|
||||
|
||||
from tensorflow.python.tools import module_util as _module_util
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
|
||||
@ -31,13 +33,14 @@ from tensorflow.python.tools import module_util as _module_util
|
||||
|
||||
# Hook external TensorFlow modules.
|
||||
_current_module = _sys.modules[__name__]
|
||||
try:
|
||||
from tensorflow_estimator.python.estimator.api._v1 import estimator
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
|
||||
setattr(_current_module, "estimator", estimator)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Lazy-load estimator.
|
||||
_estimator_module = "tensorflow_estimator.python.estimator.api._v1.estimator"
|
||||
estimator = _LazyLoader("estimator", globals(), _estimator_module)
|
||||
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
|
||||
if _module_dir:
|
||||
_current_module.__path__ = [_module_dir] + _current_module.__path__
|
||||
setattr(_current_module, "estimator", estimator)
|
||||
|
||||
try:
|
||||
from tensorflow.python.keras.api._v1 import keras
|
||||
@ -47,6 +50,14 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Explicitly import lazy-loaded modules to support autocompletion.
|
||||
# pylint: disable=g-import-not-at-top
|
||||
if not _six.PY2:
|
||||
import typing as _typing
|
||||
if _typing.TYPE_CHECKING:
|
||||
from tensorflow_estimator.python.estimator.api._v1 import estimator
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
|
||||
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
|
||||
_current_module.app.flags = flags # pylint: disable=undefined-variable
|
||||
|
@ -37,6 +37,7 @@ cc_library(
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"//tensorflow/compiler/mlir/lite/quantization/xla:quantize",
|
||||
"//tensorflow/compiler/tf2xla",
|
||||
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
|
||||
@ -64,6 +65,7 @@ cc_library(
|
||||
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:target",
|
||||
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
||||
"//tensorflow/core:regexp_internal",
|
||||
] + if_llvm_aarch64_available([
|
||||
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
|
||||
]),
|
||||
@ -84,6 +86,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/platform:resource_loader",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support", # fixdeps: keep
|
||||
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
#include "tensorflow/compiler/xla/cpu_function_runtime.h"
|
||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||
@ -288,8 +289,8 @@ Status GenVariableMethods(const tf2xla::Config& config,
|
||||
}
|
||||
|
||||
// Generates code implementing {Arg,Result}Names(), where T is one of
|
||||
// tf2xla::{Feed,Fetch}. Each feed or fetch name results in a C-style string
|
||||
// literal in the array, with nullptr terminating the array.
|
||||
// tf2xla::{Feed,Fetch,Variable}. Each feed or fetch name results in a C-style
|
||||
// string literal in the array, with nullptr terminating the array.
|
||||
template <typename T>
|
||||
string GenNameToIndexCode(const T& entries, bool generate) {
|
||||
// No need for a static array if we're not supposed to generate the data.
|
||||
@ -419,6 +420,16 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
|
||||
// Generate metadata.
|
||||
const string arg_names_code =
|
||||
GenNameToIndexCode(config.feed(), opts.gen_name_to_index);
|
||||
|
||||
auto variable_copy = config.variable();
|
||||
for (auto& var : variable_copy) {
|
||||
if (var.name().empty()) {
|
||||
var.set_name(var.node_name());
|
||||
}
|
||||
}
|
||||
const string variable_names_code =
|
||||
GenNameToIndexCode(variable_copy, opts.gen_name_to_index);
|
||||
|
||||
const string result_names_code =
|
||||
GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
|
||||
const string include_xla_data_proto =
|
||||
@ -507,6 +518,9 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
||||
// Number of input arguments for the compiled computation.
|
||||
static constexpr size_t kNumArgs = {{ARG_NUM}};
|
||||
|
||||
// Number of variables for the compiled computation.
|
||||
static constexpr size_t kNumVariables = {{VARIABLE_NUM}};
|
||||
|
||||
// Byte size of each argument buffer. There are kNumArgs entries.
|
||||
static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
|
||||
return BufferInfos()[ArgIndexToBufferIndex()[index]].size();
|
||||
@ -522,8 +536,10 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
||||
set_static_data_num_buffers(data, kNumBuffers);
|
||||
set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
|
||||
set_static_data_num_args(data, kNumArgs);
|
||||
set_static_data_num_variables(data, kNumVariables);
|
||||
set_static_data_result_index(data, kResultIndex);
|
||||
set_static_data_arg_names(data, StaticArgNames());
|
||||
set_static_data_variable_names(data, StaticVariableNames());
|
||||
set_static_data_result_names(data, StaticResultNames());
|
||||
set_static_data_program_shape(data, StaticProgramShape());
|
||||
set_static_data_hlo_profile_printer_data(
|
||||
@ -626,6 +642,9 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
||||
// Array of names of each positional argument, terminated by nullptr.
|
||||
static const char** StaticArgNames() {{ARG_NAMES_CODE}}
|
||||
|
||||
// Array of names of each positional variable, terminated by nullptr.
|
||||
static const char** StaticVariableNames() {{VARIABLE_NAMES_CODE}}
|
||||
|
||||
// Array of names of each positional result, terminated by nullptr.
|
||||
static const char** StaticResultNames() {{RESULT_NAMES_CODE}}
|
||||
|
||||
@ -654,6 +673,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
||||
{"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)},
|
||||
{"{{ARG_NAMES_CODE}}", arg_names_code},
|
||||
{"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())},
|
||||
{"{{VARIABLE_NUM}}", absl::StrCat(config.variable_size())},
|
||||
{"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")},
|
||||
{"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
|
||||
{"{{CLASS}}", opts.class_name},
|
||||
@ -673,6 +693,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
||||
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))},
|
||||
{"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
|
||||
metadata_result.program_shape_access_shim},
|
||||
{"{{VARIABLE_NAMES_CODE}}", variable_names_code},
|
||||
{"{{RESULT_INDEX}}", absl::StrCat(result_index)},
|
||||
{"{{RESULT_NAMES_CODE}}", result_names_code},
|
||||
{"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)},
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/aot/codegen.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@ -29,6 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/resource_loader.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -139,14 +141,23 @@ TEST_F(ParseCppClassTest, ParseFail) {
|
||||
|
||||
static void CompareWithGoldenFile(
|
||||
const string& tensorflow_relative_golden_file_name,
|
||||
const string& expected_contents) {
|
||||
const string& expected_contents, bool ignore_cr) {
|
||||
// Get rid of all CR characters, we may be running under windows.
|
||||
string sanitized_expected_contents(expected_contents);
|
||||
if (ignore_cr) {
|
||||
sanitized_expected_contents.erase(
|
||||
std::remove(sanitized_expected_contents.begin(),
|
||||
sanitized_expected_contents.end(), '\r'),
|
||||
sanitized_expected_contents.end());
|
||||
}
|
||||
|
||||
// To update the golden file, flip update_golden to true and run the
|
||||
// following:
|
||||
// bazel test --test_strategy=local \
|
||||
// third_party/tensorflow/compiler/aot:codegen_test
|
||||
const bool update_golden = false;
|
||||
const string golden_file_name = io::JoinPath(
|
||||
testing::TensorFlowSrcRoot(), tensorflow_relative_golden_file_name);
|
||||
string golden_file_name =
|
||||
GetDataDependencyFilepath(tensorflow_relative_golden_file_name);
|
||||
|
||||
if (update_golden) {
|
||||
TF_EXPECT_OK(
|
||||
@ -156,6 +167,11 @@ static void CompareWithGoldenFile(
|
||||
string golden_file_contents;
|
||||
TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name,
|
||||
&golden_file_contents));
|
||||
if (ignore_cr) {
|
||||
golden_file_contents.erase(std::remove(golden_file_contents.begin(),
|
||||
golden_file_contents.end(), '\r'),
|
||||
golden_file_contents.end());
|
||||
}
|
||||
EXPECT_EQ(golden_file_contents, expected_contents);
|
||||
}
|
||||
|
||||
@ -201,10 +217,16 @@ TEST(CodegenTest, Golden) {
|
||||
{},
|
||||
{BufferInfo::MakeTempBuffer(1),
|
||||
BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0),
|
||||
BufferInfo::MakeTempBuffer(2),
|
||||
BufferInfo::MakeTempBuffer(1),
|
||||
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
|
||||
BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)},
|
||||
5, {}));
|
||||
BufferInfo::MakeTempBuffer(1),
|
||||
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/2),
|
||||
BufferInfo::MakeTempBuffer(1),
|
||||
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/3),
|
||||
BufferInfo::MakeTempBuffer(1),
|
||||
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/4),
|
||||
BufferInfo::MakeTempBuffer(1), BufferInfo::MakeTempBuffer(120)},
|
||||
11, {}));
|
||||
compile_result.program_shape =
|
||||
xla::ShapeUtil::MakeProgramShape(
|
||||
{
|
||||
@ -229,14 +251,18 @@ TEST(CodegenTest, Golden) {
|
||||
// The other fields in metadata_result are tested as part of the generated
|
||||
// header test.
|
||||
|
||||
CompareWithGoldenFile("compiler/aot/codegen_test_o.golden",
|
||||
metadata_result.object_file_data);
|
||||
// This specific golden test checks a binary file. It can potentially run into
|
||||
// issues due to ABIs not being stable, but has not so far.
|
||||
// If we see any ABI issues, we should reconsider this specific test case.
|
||||
CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_o.golden",
|
||||
metadata_result.object_file_data, false);
|
||||
|
||||
string header;
|
||||
TF_ASSERT_OK(
|
||||
GenerateHeader(opts, config, compile_result, metadata_result, &header));
|
||||
|
||||
CompareWithGoldenFile("compiler/aot/codegen_test_h.golden", header);
|
||||
CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_h.golden", header,
|
||||
true);
|
||||
}
|
||||
} // namespace
|
||||
} // namespace tfcompile
|
||||
|
@ -55,14 +55,17 @@ namespace bar {
|
||||
// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): f32[1], (unknown): f32[1], (unknown): s32[5]) -> (u32[5,6], f32[1], s32[5])
|
||||
//
|
||||
// Memory stats:
|
||||
// arg bytes total: 104
|
||||
// arg bytes aligned: 192
|
||||
// arg bytes total: 392
|
||||
// arg bytes aligned: 576
|
||||
// temp bytes total: 126
|
||||
// temp bytes aligned: 320
|
||||
// temp bytes aligned: 512
|
||||
class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
public:
|
||||
// Number of input arguments for the compiled computation.
|
||||
static constexpr size_t kNumArgs = 2;
|
||||
static constexpr size_t kNumArgs = 5;
|
||||
|
||||
// Number of variables for the compiled computation.
|
||||
static constexpr size_t kNumVariables = 3;
|
||||
|
||||
// Byte size of each argument buffer. There are kNumArgs entries.
|
||||
static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
|
||||
@ -79,8 +82,10 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
set_static_data_num_buffers(data, kNumBuffers);
|
||||
set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
|
||||
set_static_data_num_args(data, kNumArgs);
|
||||
set_static_data_num_variables(data, kNumVariables);
|
||||
set_static_data_result_index(data, kResultIndex);
|
||||
set_static_data_arg_names(data, StaticArgNames());
|
||||
set_static_data_variable_names(data, StaticVariableNames());
|
||||
set_static_data_result_names(data, StaticResultNames());
|
||||
set_static_data_program_shape(data, StaticProgramShape());
|
||||
set_static_data_hlo_profile_printer_data(
|
||||
@ -295,16 +300,22 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
|
||||
private:
|
||||
// Number of buffers for the compiled computation.
|
||||
static constexpr size_t kNumBuffers = 6;
|
||||
static constexpr size_t kNumBuffers = 12;
|
||||
|
||||
static const ::xla::cpu_function_runtime::BufferInfo* BufferInfos() {
|
||||
static const ::xla::cpu_function_runtime::BufferInfo
|
||||
kBufferInfos[kNumBuffers] = {
|
||||
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({34ULL, 0ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({9ULL, ~0ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({386ULL, 1ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({13ULL, ~0ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({386ULL, 2ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({386ULL, 3ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({386ULL, 4ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({481ULL, ~0ULL})
|
||||
};
|
||||
return kBufferInfos;
|
||||
@ -312,13 +323,13 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
|
||||
static const ::tensorflow::int32* ArgIndexToBufferIndex() {
|
||||
static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
|
||||
1, 3
|
||||
1, 3, 5, 7, 9
|
||||
};
|
||||
return kArgIndexToBufferIndex;
|
||||
}
|
||||
|
||||
// The 0-based index of the result tuple in the temporary buffers.
|
||||
static constexpr size_t kResultIndex = 5;
|
||||
static constexpr size_t kResultIndex = 11;
|
||||
|
||||
// Array of names of each positional argument, terminated by nullptr.
|
||||
static const char** StaticArgNames() {
|
||||
@ -326,6 +337,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return kNames;
|
||||
}
|
||||
|
||||
// Array of names of each positional variable, terminated by nullptr.
|
||||
static const char** StaticVariableNames() {
|
||||
static const char* kNames[] = {"myvar_readonly", "myvar", "myvar2", nullptr};
|
||||
return kNames;
|
||||
}
|
||||
|
||||
// Array of names of each positional result, terminated by nullptr.
|
||||
static const char** StaticResultNames() {
|
||||
static const char* kNames[] = {"myfetch", nullptr};
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "llvm-c/Target.h"
|
||||
#include "tensorflow/compiler/aot/codegen.h"
|
||||
#include "tensorflow/compiler/aot/flags.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
@ -39,6 +40,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/strings/proto_serialization.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/regexp.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -105,14 +107,18 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
|
||||
.ValueOrDie();
|
||||
xla::XlaComputation computation;
|
||||
if (flags.mlir_components == "Bridge") {
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConvertGraphDefToXlaViaMlir(graph_def, config, &computation));
|
||||
TF_RETURN_IF_ERROR(ConvertGraphDefToXlaViaMlir(
|
||||
graph_def, config, &computation, flags.debug_info,
|
||||
flags.debug_info_path_begin_marker));
|
||||
} else if (flags.mlir_components.empty() || flags.mlir_components == "None") {
|
||||
TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
|
||||
client, &computation));
|
||||
} else {
|
||||
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
|
||||
}
|
||||
if (flags.quantize) {
|
||||
TF_RETURN_IF_ERROR(mlir::xla_hlo::XlaQuantize(config, &computation));
|
||||
}
|
||||
if (!flags.out_session_module.empty()) {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
|
||||
computation.Snapshot());
|
||||
@ -166,6 +172,23 @@ static void InitializeTargets() {
|
||||
LLVMInitializeX86AsmPrinter();
|
||||
}
|
||||
|
||||
// Replaces {{tag.type tag.name}} in the error message with tag_name.
|
||||
// TODO(bixia): We currently only handlge tag.type == "node".
|
||||
//
|
||||
// In the error message, a graph node is represented as {{tag.type, tag.name}},
|
||||
// to allow a Python debugger to insert source information about the graph node.
|
||||
// For example, a Python add expression may be represented as
|
||||
// {{node, x_y_sum}} = Add(x, y) in the error message. See routine interpolate
|
||||
// in tensorflow/python/framework/error_interpolation.py for more detail.
|
||||
static std::string InterpolateErrorMessage(std::string message) {
|
||||
// See _NAME_REGEX in tensorflow/python/framework/error_interpolation.py
|
||||
// Change "prefix {{node tag.name}} suffix" to "prefix tag.name suffix".
|
||||
static LazyRE2 pattern{"(.*){{node (.*)}}(.*)"};
|
||||
RE2::GlobalReplace(&message, *pattern, "\\1\\2\\3");
|
||||
|
||||
return message;
|
||||
}
|
||||
|
||||
Status Main(const MainFlags& flags) {
|
||||
absl::call_once(targets_init, &InitializeTargets);
|
||||
|
||||
@ -192,8 +215,13 @@ Status Main(const MainFlags& flags) {
|
||||
GraphDef graph_def;
|
||||
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
|
||||
CompileResult compile_result;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CompileGraph(std::move(graph_def), config, flags, &compile_result));
|
||||
|
||||
Status status =
|
||||
CompileGraph(std::move(graph_def), config, flags, &compile_result);
|
||||
if (!status.ok()) {
|
||||
return Status(status.code(),
|
||||
InterpolateErrorMessage(status.error_message()));
|
||||
}
|
||||
|
||||
// Write output files.
|
||||
Env* env = Env::Default();
|
||||
|
@ -24,6 +24,13 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
|
||||
"Input GraphDef file. If the file ends in '.pbtxt' it is expected to "
|
||||
"be in the human-readable proto text format, otherwise it is expected "
|
||||
"to be in the proto binary format."},
|
||||
{"debug_info", &flags->debug_info,
|
||||
"Graph debug info file. If the file ends in '.pbtxt' it is expected to "
|
||||
"be in the human-readable proto text format, otherwise it is expected "
|
||||
"to be in the proto binary format."},
|
||||
{"debug_info_path_begin_marker", &flags->debug_info_path_begin_marker,
|
||||
"If not none, only keep the file path in the debug information after the"
|
||||
" marker. The default value is empty"},
|
||||
{"config", &flags->config,
|
||||
"Input file containing Config proto. If the file ends in '.pbtxt' it "
|
||||
"is expected to be in the human-readable proto text format, otherwise "
|
||||
@ -70,6 +77,8 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
|
||||
"Output session module proto."},
|
||||
{"mlir_components", &flags->mlir_components,
|
||||
"The MLIR components to enable. Currently only Bridge is supported."},
|
||||
{"quantize", &flags->quantize,
|
||||
"If set, quantization will be applied before HLO code generation."},
|
||||
{"gen_name_to_index", &flags->gen_name_to_index,
|
||||
"Generate name-to-index data for Lookup{Arg,Result}Index methods."},
|
||||
{"gen_program_shape", &flags->gen_program_shape,
|
||||
|
@ -28,6 +28,8 @@ namespace tfcompile {
|
||||
|
||||
struct MainFlags {
|
||||
string graph;
|
||||
string debug_info;
|
||||
string debug_info_path_begin_marker;
|
||||
string config;
|
||||
bool dump_fetch_nodes = false;
|
||||
string target_triple;
|
||||
@ -40,6 +42,7 @@ struct MainFlags {
|
||||
string out_header;
|
||||
string out_session_module;
|
||||
string mlir_components;
|
||||
bool quantize = false;
|
||||
|
||||
// C++ codegen options
|
||||
bool gen_name_to_index = false;
|
||||
|
@ -1,11 +1,37 @@
|
||||
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:private"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
glob_lit_tests(
|
||||
data = [":filecheck_test_utilities"],
|
||||
driver = "@llvm-project//mlir:run_lit.sh",
|
||||
tags_override = {
|
||||
"test_error_message.lit.pbtxt": ["no_oss"], # TODO(b/150957738): to be fixed on oss.
|
||||
},
|
||||
test_file_exts = ["lit.pbtxt"],
|
||||
)
|
||||
|
||||
# Bundle together all of the test utilities that are used by tests.
|
||||
filegroup(
|
||||
name = "filecheck_test_utilities",
|
||||
testonly = True,
|
||||
srcs = [
|
||||
"test_error_message.lit.pbtxt.config.pbtxt",
|
||||
"test_error_message.lit.pbtxt.debug.pbtxt",
|
||||
"test_error_message.lit.pbtxt.fake_py.debug",
|
||||
],
|
||||
data = [
|
||||
"//tensorflow/compiler/aot:tfcompile",
|
||||
"@llvm-project//llvm:FileCheck",
|
||||
"@llvm-project//llvm:not",
|
||||
],
|
||||
)
|
||||
|
||||
# We disable some tfcompile tests in the open source build with the
|
||||
# "manual" tag to avoid making our OSS users build LLVM twice
|
||||
# (once for host and once for target).
|
||||
@ -60,6 +86,7 @@ genrule(
|
||||
testonly = 1,
|
||||
outs = [
|
||||
"test_graph_tfadd.pb",
|
||||
"test_debuginfo_tfadd.pb",
|
||||
"test_graph_tfadd_with_ckpt.ckpt",
|
||||
"test_graph_tfadd_with_ckpt.pb",
|
||||
"test_graph_tfadd_with_ckpt_saver.ckpt",
|
||||
@ -317,6 +344,7 @@ tf_library(
|
||||
testonly = 1,
|
||||
config = "test_graph_tfadd.config.pbtxt",
|
||||
cpp_class = "AddComp",
|
||||
debug_info = "test_debuginfo_tfadd.pb",
|
||||
graph = "test_graph_tfadd.pb",
|
||||
include_standard_runtime_deps = False,
|
||||
mlir_components = "Bridge",
|
||||
|
@ -30,6 +30,7 @@ from tensorflow.core.protobuf import saver_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import error_interpolation
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -184,7 +185,22 @@ def tfvariable_sequential_updates(_):
|
||||
array_ops.identity(updates, name='result')
|
||||
|
||||
|
||||
def write_graph(build_graph, out_dir):
|
||||
def export_debug_info(exported_graph):
|
||||
"""Exports debug information from a graph.
|
||||
|
||||
Args:
|
||||
exported_graph: A Graph that has been created by tracing a saveable view.
|
||||
|
||||
Returns:
|
||||
Corresponding GraphDebugInfo with traces for all ops in exported_graph.
|
||||
"""
|
||||
exported_operations = []
|
||||
for op in exported_graph.get_operations():
|
||||
exported_operations.append(('', op))
|
||||
return error_interpolation.create_graph_debug_info_def(exported_operations)
|
||||
|
||||
|
||||
def write_graph(build_graph, out_dir, debug_info=False):
|
||||
"""Build a graph using build_graph and write it out."""
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
@ -193,10 +209,19 @@ def write_graph(build_graph, out_dir):
|
||||
with open(filename, 'wb') as f:
|
||||
f.write(six.ensure_binary(g.as_graph_def().SerializeToString()))
|
||||
|
||||
if debug_info:
|
||||
filename_debuginfo = os.path.join(
|
||||
out_dir, 'test_debuginfo_%s.pb' % build_graph.__name__)
|
||||
test_debuginfo = export_debug_info(g)
|
||||
with open(filename_debuginfo, 'wb') as f:
|
||||
f.write(
|
||||
six.ensure_binary(
|
||||
test_debuginfo.SerializeToString(deterministic=True)))
|
||||
|
||||
|
||||
def main(_):
|
||||
control_flow_util.enable_control_flow_v2()
|
||||
write_graph(tfadd, FLAGS.out_dir)
|
||||
write_graph(tfadd, FLAGS.out_dir, debug_info=True)
|
||||
write_graph(tfadd_with_ckpt, FLAGS.out_dir)
|
||||
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
|
||||
write_graph(tfassert_eq, FLAGS.out_dir)
|
||||
|
69
tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt
Normal file
69
tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt
Normal file
@ -0,0 +1,69 @@
|
||||
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=Bridge --debug_info=%s.debug.pbtxt 2>&1 | FileCheck %s -dump-input-on-failure
|
||||
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=None 2>&1 | FileCheck -check-prefix=OLD %s -dump-input-on-failure
|
||||
|
||||
# Checks the error message produced by tfcompile with mlir_component
|
||||
# Checks that source debug information is used in the output error message and
|
||||
# the node x_y_sum = Add
|
||||
# CHECK: INVALID ARGUMENTS: Dimensions must be equal, but are 2 and 3 for 'x_y_sum = Add[T=DT_INT32](aot_feed_0/x, aot_feed_0/y)'
|
||||
# CHECK: math_ops.add(x, y, name='x_y_sum')
|
||||
# CHECK: build_graph(out_dir)
|
||||
|
||||
# Checks the error message produced by tfcompile without mlir_component
|
||||
# OLD: INVALID ARGUMENTS: Incompatible shapes: [2] vs. [3]
|
||||
# OLD: x_y_sum
|
||||
|
||||
node: {
|
||||
name: "x"
|
||||
op: "Placeholder"
|
||||
attr: {
|
||||
key: "shape"
|
||||
value: {
|
||||
shape: {
|
||||
dim: {
|
||||
size: -1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "y"
|
||||
op: "Placeholder"
|
||||
attr: {
|
||||
key: "shape"
|
||||
value: {
|
||||
shape: {
|
||||
dim: {
|
||||
size: -1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "x_y_sum"
|
||||
op: "Add"
|
||||
input: "x"
|
||||
input: "y"
|
||||
attr: {
|
||||
key: "T"
|
||||
value: {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
versions: {
|
||||
producer: 321
|
||||
}
|
@ -0,0 +1,16 @@
|
||||
# Text form of tensorflow.tf2xla.Config proto.
|
||||
feed {
|
||||
id { node_name: "x" }
|
||||
shape {
|
||||
dim { size: 2 }
|
||||
}
|
||||
}
|
||||
feed {
|
||||
id { node_name: "y" }
|
||||
shape {
|
||||
dim { size: 3 }
|
||||
}
|
||||
}
|
||||
fetch {
|
||||
id { node_name: "x_y_sum" }
|
||||
}
|
@ -0,0 +1,28 @@
|
||||
files: "org_tensorflow/tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt.fake_py.debug"
|
||||
traces: {
|
||||
key: "x@"
|
||||
value: {
|
||||
file_line_cols: {
|
||||
line: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
traces: {
|
||||
key: "x_y_sum@"
|
||||
value: {
|
||||
file_line_cols: {
|
||||
line: 3
|
||||
}
|
||||
file_line_cols: {
|
||||
line: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
traces: {
|
||||
key: "y@"
|
||||
value: {
|
||||
file_line_cols: {
|
||||
line: 2
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
x = value
|
||||
y = value
|
||||
math_ops.add(x, y, name='x_y_sum')
|
||||
build_graph(out_dir)
|
@ -26,6 +26,7 @@ def tf_library(
|
||||
name,
|
||||
graph,
|
||||
config,
|
||||
debug_info = None,
|
||||
freeze_checkpoint = None,
|
||||
freeze_saver = None,
|
||||
cpp_class = None,
|
||||
@ -191,12 +192,15 @@ def tf_library(
|
||||
|
||||
mlir_flag = "--mlir_components=" + mlir_components
|
||||
|
||||
srcs = [tfcompile_graph, config]
|
||||
debug_info_flag = ""
|
||||
if debug_info:
|
||||
srcs.append(debug_info)
|
||||
debug_info_flag = " --debug_info=$(location " + debug_info + ")"
|
||||
|
||||
native.genrule(
|
||||
name = ("gen_" + name),
|
||||
srcs = [
|
||||
tfcompile_graph,
|
||||
config,
|
||||
],
|
||||
srcs = srcs,
|
||||
outs = [
|
||||
header_file,
|
||||
metadata_object_file,
|
||||
@ -206,6 +210,7 @@ def tf_library(
|
||||
"CUDA_VISIBLE_DEVICES='' " +
|
||||
"$(location " + tfcompile_tool + ")" +
|
||||
" --graph=$(location " + tfcompile_graph + ")" +
|
||||
debug_info_flag +
|
||||
" --config=$(location " + config + ")" +
|
||||
" --entry_point=" + ep +
|
||||
" --cpp_class=" + cpp_class +
|
||||
@ -237,10 +242,7 @@ def tf_library(
|
||||
session_module_pb = name + "_session_module.pb"
|
||||
native.genrule(
|
||||
name = (name + "_session_module"),
|
||||
srcs = [
|
||||
tfcompile_graph,
|
||||
config,
|
||||
],
|
||||
srcs = srcs,
|
||||
outs = [
|
||||
session_module_pb,
|
||||
],
|
||||
@ -248,6 +250,7 @@ def tf_library(
|
||||
"CUDA_VISIBLE_DEVICES='' " +
|
||||
"$(location " + tfcompile_tool + ")" +
|
||||
" --graph=$(location " + tfcompile_graph + ")" +
|
||||
debug_info_flag +
|
||||
" --config=$(location " + config + ")" +
|
||||
" --entry_point=" + ep +
|
||||
" --cpp_class=" + cpp_class +
|
||||
@ -407,5 +410,6 @@ def target_llvm_triple():
|
||||
"//tensorflow:ios_x86_64": "x86_64-apple-ios",
|
||||
"//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
|
||||
"//tensorflow:macos": "x86_64-none-darwin",
|
||||
"//tensorflow:windows": "x86_64-none-windows",
|
||||
"//conditions:default": "x86_64-pc-linux",
|
||||
})
|
||||
|
@ -65,6 +65,7 @@ int main(int argc, char** argv) {
|
||||
flags.out_metadata_object = "out_helper.o";
|
||||
flags.out_header = "out.h";
|
||||
flags.entry_point = "entry";
|
||||
flags.debug_info_path_begin_marker = "";
|
||||
|
||||
std::vector<tensorflow::Flag> flag_list;
|
||||
AppendMainFlags(&flag_list, &flags);
|
||||
@ -81,12 +82,10 @@ int main(int argc, char** argv) {
|
||||
|
||||
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
|
||||
QCHECK(argc == 1) << "\nERROR: This command does not take any arguments "
|
||||
"other than flags\n\n"
|
||||
<< usage;
|
||||
"other than flags. See --help.\n\n";
|
||||
tensorflow::Status status = tensorflow::tfcompile::Main(flags);
|
||||
if (status.code() == tensorflow::error::INVALID_ARGUMENT) {
|
||||
std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n"
|
||||
<< usage;
|
||||
std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n";
|
||||
return 1;
|
||||
} else {
|
||||
TF_QCHECK_OK(status);
|
||||
|
@ -14,6 +14,10 @@ package_group(
|
||||
includes = [
|
||||
"//tensorflow/compiler/tf2xla:internal",
|
||||
],
|
||||
packages = [
|
||||
"//tensorflow/compiler/tests/...",
|
||||
"//tensorflow/python/...",
|
||||
],
|
||||
)
|
||||
|
||||
package_group(
|
||||
@ -180,6 +184,7 @@ XLA_DEVICE_DEPS = [
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:functional_ops_op_lib",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
|
@ -108,7 +108,7 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
|
||||
"(LRN, LRNGrad)."
|
||||
" BN: TF FusedBatchNorm* operations."
|
||||
" FUSIBLE: All TF operations that XLA can fuse (All the above). "
|
||||
"You can also put any TF operation name, e.g. 'FUSIBLE,Matmul'."),
|
||||
"You can also put any TF operation name, e.g. 'FUSIBLE,MatMul'."),
|
||||
Flag("tf_xla_clustering_debug",
|
||||
&mark_for_compilation_flags->tf_xla_clustering_debug,
|
||||
"Dump graphs during XLA compilation."),
|
||||
|
@ -20,6 +20,7 @@ XLA_OPS_DEPS = [
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
@ -41,6 +42,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
@ -170,8 +172,9 @@ class XlaExecutableClosureStore {
|
||||
|
||||
private:
|
||||
mutex mutex_;
|
||||
int64 key_counter_ GUARDED_BY(mutex_);
|
||||
absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_ GUARDED_BY(mutex_);
|
||||
int64 key_counter_ TF_GUARDED_BY(mutex_);
|
||||
absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_
|
||||
TF_GUARDED_BY(mutex_);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
|
||||
};
|
||||
@ -206,12 +209,14 @@ se::DeviceMemoryAllocator* GetAllocator(
|
||||
XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
|
||||
const std::vector<int>& constants,
|
||||
const std::vector<int>& resources,
|
||||
const NameAttrList& function)
|
||||
const NameAttrList& function,
|
||||
bool has_ref_vars)
|
||||
: OpKernel(ctx),
|
||||
constants_(constants),
|
||||
resources_(resources),
|
||||
function_(function),
|
||||
platform_info_(PlatformInfoFromContext(ctx)) {}
|
||||
platform_info_(PlatformInfoFromContext(ctx)),
|
||||
has_ref_vars_(has_ref_vars) {}
|
||||
|
||||
static Status BuildCompilationCache(OpKernelContext* ctx,
|
||||
const XlaPlatformInfo& platform_info,
|
||||
@ -350,8 +355,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
|
||||
{
|
||||
Status s = CompileToLocalExecutable(
|
||||
ctx, function_, /*has_ref_vars=*/true, platform_info_, resources_,
|
||||
constants_, /*lazy=*/false, &client, &variables, &kernel, &executable);
|
||||
ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_,
|
||||
resources_, constants_, /*lazy=*/false, &client, &variables, &kernel,
|
||||
&executable);
|
||||
if (!s.ok() && (platform_info_.device_type().type_string() == DEVICE_CPU ||
|
||||
platform_info_.device_type().type_string() == DEVICE_GPU)) {
|
||||
// Suggest auto jit if the failure was with GPU or CPU.
|
||||
@ -384,6 +390,18 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
run_options.set_allocator(allocator);
|
||||
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
|
||||
run_options.set_rng_seed(GetXLARandomSeed());
|
||||
xla::ThenExecuteFunction then_execute;
|
||||
if (ctx->op_device_context()) {
|
||||
then_execute = [&](se::Stream* stream, std::function<void()> fn) {
|
||||
Status status = ctx->op_device_context()->ThenExecute(
|
||||
down_cast<Device*>(ctx->device()), stream, std::move(fn));
|
||||
if (!status.ok()) {
|
||||
// This should never happen.
|
||||
LOG(ERROR) << "ThenExecute failed " << status;
|
||||
}
|
||||
};
|
||||
run_options.set_then_execute_function(&then_execute);
|
||||
}
|
||||
Env* env = Env::Default();
|
||||
auto start_time = env->NowMicros();
|
||||
|
||||
@ -462,7 +480,7 @@ bool HasRefVars(OpKernelConstruction* ctx) {
|
||||
|
||||
XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
|
||||
: XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
|
||||
FunctionAttr(ctx)) {}
|
||||
FunctionAttr(ctx), /*has_ref_vars=*/true) {}
|
||||
|
||||
XlaLocalLaunchOp::~XlaLocalLaunchOp() {
|
||||
VLOG(1) << "XlaLocalLaunchOp destroyed";
|
||||
@ -592,6 +610,18 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
|
||||
run_options.set_allocator(allocator);
|
||||
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
|
||||
run_options.set_rng_seed(GetXLARandomSeed());
|
||||
xla::ThenExecuteFunction then_execute;
|
||||
if (ctx->op_device_context()) {
|
||||
then_execute = [&](se::Stream* stream, std::function<void()> fn) {
|
||||
Status status = ctx->op_device_context()->ThenExecute(
|
||||
down_cast<Device*>(ctx->device()), stream, std::move(fn));
|
||||
if (!status.ok()) {
|
||||
// This should never happen.
|
||||
LOG(ERROR) << "ThenExecute failed " << status;
|
||||
}
|
||||
};
|
||||
run_options.set_then_execute_function(&then_execute);
|
||||
}
|
||||
Env* env = Env::Default();
|
||||
auto start_time = env->NowMicros();
|
||||
|
||||
|
@ -95,12 +95,15 @@ class XlaPlatformInfo {
|
||||
// in the GraphDef.
|
||||
// Currently, it is used by eager runtime. FunctionLibraryRuntime creates
|
||||
// this kernel when asked to create a kernel for an XLA-compiled function.
|
||||
//
|
||||
// `has_ref_vars`: whether the input computation can have reference variables.
|
||||
// TODO(cheshire): instead derive this information from the input graph.
|
||||
class XlaLocalLaunchBase : public OpKernel {
|
||||
public:
|
||||
XlaLocalLaunchBase(OpKernelConstruction* ctx,
|
||||
const std::vector<int>& constants,
|
||||
const std::vector<int>& resources,
|
||||
const NameAttrList& function);
|
||||
const NameAttrList& function, bool has_ref_vars);
|
||||
XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
|
||||
XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
|
||||
~XlaLocalLaunchBase() override = default;
|
||||
@ -115,6 +118,8 @@ class XlaLocalLaunchBase : public OpKernel {
|
||||
|
||||
const NameAttrList function_;
|
||||
const XlaPlatformInfo platform_info_;
|
||||
|
||||
bool has_ref_vars_;
|
||||
};
|
||||
|
||||
// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
|
||||
@ -160,7 +165,8 @@ class XlaCompileOp : public OpKernel {
|
||||
// error when compiling the cluster this _XlaCompile is supposed to compile.
|
||||
// If `cannot_compile_cluster_` is true then we avoid compiling this cluster
|
||||
// on any future calls to _XlaCompile.
|
||||
bool cannot_compile_cluster_ GUARDED_BY(cannot_compile_cluster_mu_) = false;
|
||||
bool cannot_compile_cluster_ TF_GUARDED_BY(cannot_compile_cluster_mu_) =
|
||||
false;
|
||||
|
||||
mutex cannot_compile_cluster_mu_;
|
||||
};
|
||||
|
@ -963,6 +963,22 @@ absl::optional<string> MarkForCompilationPassImpl::GetXlaScope(Node* node) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
// Returns true iff the attribute `attr_name` is attached to either the node or
|
||||
// to it's callee.
|
||||
static bool GetNodeOrFuncAttr(Node* node, FunctionLibraryDefinition* flib_def,
|
||||
const char* attr_name) {
|
||||
bool out = false;
|
||||
bool attr_value;
|
||||
if (TryGetNodeAttr(node->attrs(), attr_name, &attr_value)) {
|
||||
out |= attr_value;
|
||||
}
|
||||
|
||||
if (flib_def->GetAttr(*node, attr_name, &attr_value).ok()) {
|
||||
out |= attr_value;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
|
||||
auto ignore_resource_ops = [&](const Node& n, bool* ignore) {
|
||||
return IgnoreResourceOpForSafetyAnalysis(&device_info_cache_, n, ignore);
|
||||
@ -1016,16 +1032,9 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
|
||||
resource_var_operation_node_id = node->id();
|
||||
}
|
||||
|
||||
bool is_xla_compile_attr_true = false;
|
||||
|
||||
bool xla_compile_attr;
|
||||
if (TryGetNodeAttr(node->attrs(), kXlaCompileAttr, &xla_compile_attr)) {
|
||||
is_xla_compile_attr_true |= xla_compile_attr;
|
||||
}
|
||||
|
||||
if (flib_def_->GetAttr(*node, kXlaCompileAttr, &xla_compile_attr).ok()) {
|
||||
is_xla_compile_attr_true |= xla_compile_attr;
|
||||
}
|
||||
bool is_xla_compile_attr_true =
|
||||
GetNodeOrFuncAttr(node, flib_def_, kXlaCompileAttr) ||
|
||||
GetNodeOrFuncAttr(node, flib_def_, kXlaMustCompileAttr);
|
||||
|
||||
DeviceSet devices;
|
||||
devices.Insert(device);
|
||||
@ -1874,6 +1883,8 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
"EmptyTensorList",
|
||||
"ExtractImagePatches",
|
||||
"Igamma",
|
||||
"IgammaGradA",
|
||||
"RandomGammaGrad",
|
||||
"Igammac",
|
||||
"FFT",
|
||||
"FFT2D",
|
||||
@ -1900,6 +1911,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
"LinSpace",
|
||||
"ListDiff",
|
||||
"LogMatrixDeterminant",
|
||||
"LowerBound",
|
||||
"MatMul",
|
||||
"MatrixBandPart",
|
||||
"MatrixDiag",
|
||||
@ -1996,6 +2008,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
"StatelessRandomNormal",
|
||||
"StatelessRandomUniform",
|
||||
"StatelessRandomUniformInt",
|
||||
"StatelessRandomUniformFullInt",
|
||||
"StatelessTruncatedNormal",
|
||||
"StatelessWhile",
|
||||
"Svd",
|
||||
@ -2025,6 +2038,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
"TensorScatterUpdate",
|
||||
"TridiagonalSolve",
|
||||
"TruncatedNormal",
|
||||
"UpperBound",
|
||||
"UnsortedSegmentMax",
|
||||
"UnsortedSegmentMin",
|
||||
"UnsortedSegmentProd",
|
||||
|
@ -18,13 +18,15 @@ limitations under the License.
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
// The list of all registered `XlaActivityListener`s.
|
||||
struct XlaActivityListenerList {
|
||||
absl::Mutex mutex;
|
||||
std::vector<std::unique_ptr<XlaActivityListener>> listeners GUARDED_BY(mutex);
|
||||
std::vector<std::unique_ptr<XlaActivityListener>> listeners
|
||||
TF_GUARDED_BY(mutex);
|
||||
};
|
||||
|
||||
void FlushAllListeners();
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/xla_activity_listener.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
@ -33,6 +34,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/metrics.h"
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
@ -202,6 +204,52 @@ static bool ShouldBeMegamorphic(int64 compile_count, int64 execution_count) {
|
||||
execution_count < kMinExecutionsPerCompile * compile_count;
|
||||
}
|
||||
|
||||
// Creates a simple graph using the specified op as the only op apart from the
|
||||
// arg and retval nodes.
|
||||
static xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
|
||||
const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
|
||||
absl::Span<const DataType> result_types) {
|
||||
// TODO(b/74182462): We implement this by creating a new dummy Graph including
|
||||
// _Arg nodes, and let CompileGraph walk it. This could be optimized.
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
|
||||
Status status;
|
||||
// First create the actual node we care about computing.
|
||||
Node* main_node = graph->AddNode(node_def, &status);
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
|
||||
// Create dummy _Arg nodes. Link these to `node` and also via a control
|
||||
// dependency edge to the _SOURCE node.
|
||||
for (int64 i = 0; i < args.size(); ++i) {
|
||||
Node* node;
|
||||
string arg_name = absl::StrCat("_arg", i);
|
||||
Status status =
|
||||
NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp)
|
||||
.ControlInput(graph->source_node())
|
||||
.Attr("T", args[i].kind == XlaCompiler::Argument::kResource
|
||||
? DT_RESOURCE
|
||||
: args[i].type)
|
||||
.Attr("index", i)
|
||||
.Finalize(graph.get(), &node);
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
graph->AddEdge(node, 0, main_node, i);
|
||||
}
|
||||
|
||||
// Similarly with return values, create dummy _Retval nodes fed by `node`.
|
||||
for (int64 i = 0; i < result_types.size(); ++i) {
|
||||
Node* node;
|
||||
string retval_name = absl::StrCat("_retval", i);
|
||||
Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp)
|
||||
.Input(main_node, i)
|
||||
.Attr("T", result_types[i])
|
||||
.Attr("index", i)
|
||||
.Finalize(graph.get(), &node);
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
}
|
||||
FixupSourceAndSinkEdges(graph.get());
|
||||
return graph;
|
||||
}
|
||||
|
||||
Status XlaCompilationCache::CompileSingleOp(
|
||||
const XlaCompiler::Options& options,
|
||||
absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,
|
||||
@ -222,8 +270,11 @@ Status XlaCompilationCache::CompileSingleOp(
|
||||
for (int i = 0; i < result_dtypes.size(); ++i) {
|
||||
result_dtypes[i] = ctx->expected_output_dtype(i);
|
||||
}
|
||||
return compiler->CompileSingleOp(compile_options, ctx->op_kernel().def(),
|
||||
args, result_dtypes, result);
|
||||
|
||||
const NodeDef& node_def = ctx->op_kernel().def();
|
||||
TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));
|
||||
return compiler->CompileGraph(compile_options, node_def.name(),
|
||||
std::move(graph), args, result);
|
||||
};
|
||||
return CompileImpl(options, name, args, compile_op,
|
||||
/*compile_threshold=*/absl::nullopt,
|
||||
|
@ -151,19 +151,19 @@ class XlaCompilationCache : public ResourceBase {
|
||||
int64 request_count = 0;
|
||||
|
||||
// Did compilation succeed?
|
||||
Status compilation_status GUARDED_BY(mu);
|
||||
Status compilation_status TF_GUARDED_BY(mu);
|
||||
|
||||
// Output of the XlaCompiler.
|
||||
XlaCompiler::CompilationResult compilation_result GUARDED_BY(mu);
|
||||
XlaCompiler::CompilationResult compilation_result TF_GUARDED_BY(mu);
|
||||
|
||||
// The XLA executable compiled from <computation>. May be null if no
|
||||
// executable has been built.
|
||||
std::unique_ptr<xla::LocalExecutable> executable GUARDED_BY(mu);
|
||||
std::unique_ptr<xla::LocalExecutable> executable TF_GUARDED_BY(mu);
|
||||
};
|
||||
|
||||
mutex compile_cache_mu_;
|
||||
absl::flat_hash_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
|
||||
GUARDED_BY(compile_cache_mu_);
|
||||
TF_GUARDED_BY(compile_cache_mu_);
|
||||
|
||||
struct ClusterCompileStats {
|
||||
// Number of times the cluster has been (re-)compiled.
|
||||
@ -185,7 +185,7 @@ class XlaCompilationCache : public ResourceBase {
|
||||
|
||||
// Maps cluster names to compilation statistics for said cluster.
|
||||
absl::flat_hash_map<string, ClusterCompileStats> cluster_compile_stats_
|
||||
GUARDED_BY(cluster_compile_stats_mu_);
|
||||
TF_GUARDED_BY(cluster_compile_stats_mu_);
|
||||
|
||||
// The number of times a lazy compilation must be requested for a specific
|
||||
// signature before we attempt to compile it.
|
||||
|
@ -83,7 +83,7 @@ class XlaDeviceAllocatorState {
|
||||
std::unordered_map<std::pair<const xla::Backend*, int>,
|
||||
std::unique_ptr<XlaDeviceAllocator>,
|
||||
hash<std::pair<const xla::Backend*, int>>>
|
||||
allocators_ GUARDED_BY(allocator_mutex_);
|
||||
allocators_ TF_GUARDED_BY(allocator_mutex_);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaDeviceAllocatorState);
|
||||
};
|
||||
|
@ -137,7 +137,7 @@ class XlaDevice : public LocalDevice {
|
||||
~XlaDevice() override;
|
||||
|
||||
Allocator* GetAllocator(AllocatorAttributes attr) override
|
||||
LOCKS_EXCLUDED(mu_);
|
||||
TF_LOCKS_EXCLUDED(mu_);
|
||||
void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
|
||||
void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
|
||||
AsyncOpKernel::DoneCallback done) override;
|
||||
@ -145,18 +145,18 @@ class XlaDevice : public LocalDevice {
|
||||
void Sync(const DoneCallback& done) override;
|
||||
|
||||
Status TryGetDeviceContext(DeviceContext** out_context) override
|
||||
LOCKS_EXCLUDED(mu_);
|
||||
TF_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
Status MakeTensorFromProto(const TensorProto& tensor_proto,
|
||||
const AllocatorAttributes alloc_attrs,
|
||||
Tensor* tensor) override LOCKS_EXCLUDED(mu_);
|
||||
Tensor* tensor) override TF_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
// Allocate tensor on fast memory space. This is only applied to the new TPU
|
||||
// hardware which has faster read/write memory. If the hardware doesn't
|
||||
// have such memory space, we fallback to the ordinary memory space.
|
||||
Status MakeFastMemTensorFromProto(const TensorProto& tensor_proto,
|
||||
const AllocatorAttributes alloc_attrs,
|
||||
Tensor* tensor) LOCKS_EXCLUDED(mu_);
|
||||
Tensor* tensor) TF_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
const Metadata& metadata() { return xla_metadata_; }
|
||||
|
||||
@ -166,34 +166,35 @@ class XlaDevice : public LocalDevice {
|
||||
//
|
||||
// TODO(b/111859745): The Eager context needs to call this method to recover
|
||||
// from failures.
|
||||
Status EnsureDeviceContextOk() LOCKS_EXCLUDED(mu_);
|
||||
Status EnsureDeviceContextOk() TF_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
// Instructs this XlaDevice to set a GpuDeviceInfo, which holds extra
|
||||
// information for GPU and TPU devices.
|
||||
Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_);
|
||||
Status UseGpuDeviceInfo() TF_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
// Instructs this XlaDevice to return 'sync_on_completion' for
|
||||
// AllowsSyncOnCompletion().
|
||||
void SetAllowsSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_);
|
||||
bool AllowsSyncOnCompletion() const override LOCKS_EXCLUDED(mu_);
|
||||
void SetAllowsSyncOnCompletion(bool sync_on_completion)
|
||||
TF_LOCKS_EXCLUDED(mu_);
|
||||
bool AllowsSyncOnCompletion() const override TF_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
// Installs an error handling callback when RefreshStatus sees !status.ok().
|
||||
void SetHandleDeviceErrorCallback(std::function<Status()> callback);
|
||||
|
||||
Status RefreshStatus() override LOCKS_EXCLUDED(mu_);
|
||||
Status RefreshStatus() override TF_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
private:
|
||||
xla::StatusOr<xla::LocalClient*> GetOrCreateClient() const;
|
||||
Allocator* GetAllocatorLocked(AllocatorAttributes attr)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
Status EnsureStreamOkLocked(xla::Backend* backend, const string& name,
|
||||
std::shared_ptr<se::Stream>* stream,
|
||||
bool* stream_was_changed)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Return a pair of device context, the second one is fast_mem device context.
|
||||
xla::StatusOr<std::pair<XlaDeviceContext*, XlaDeviceContext*>>
|
||||
GetDeviceContextLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
GetDeviceContextLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
static Status GetMetadataFromDevice(DeviceBase* device,
|
||||
const XlaDevice::Metadata** metadata);
|
||||
@ -218,13 +219,13 @@ class XlaDevice : public LocalDevice {
|
||||
// Intra-op threads to spawn (from SessionOptions).
|
||||
const int intra_op_parallelism_threads_;
|
||||
// Memory allocator associated with this device.
|
||||
Allocator* xla_allocator_ GUARDED_BY(mu_) = nullptr; // Not owned.
|
||||
Allocator* xla_allocator_ TF_GUARDED_BY(mu_) = nullptr; // Not owned.
|
||||
|
||||
// Stream associated with this device. Operations enqueued on this
|
||||
// stream are executed on the device. Operations include data
|
||||
// copying back and forth between CPU and the device, and
|
||||
// computations enqueued by XLA.
|
||||
std::shared_ptr<se::Stream> stream_ GUARDED_BY(mu_);
|
||||
std::shared_ptr<se::Stream> stream_ TF_GUARDED_BY(mu_);
|
||||
// If false, only stream_ is valid and all computation and transfers use
|
||||
// stream_. If true, computation is performed by stream_ and transfers are
|
||||
// performed by host_to_device/device_to_device stream or borrowing a stream
|
||||
@ -232,36 +233,36 @@ class XlaDevice : public LocalDevice {
|
||||
const bool use_multiple_streams_;
|
||||
// If use_multiple_streams_, host to device transfers are performed using this
|
||||
// stream.
|
||||
std::shared_ptr<se::Stream> host_to_device_stream_ GUARDED_BY(mu_);
|
||||
std::shared_ptr<se::Stream> host_to_device_stream_ TF_GUARDED_BY(mu_);
|
||||
// If use_multiple_streams_, transfers between different devices are performed
|
||||
// using these streams.
|
||||
std::vector<std::shared_ptr<se::Stream>> device_to_device_streams_
|
||||
GUARDED_BY(mu_);
|
||||
TF_GUARDED_BY(mu_);
|
||||
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
|
||||
|
||||
// The device context accessed by all users of the XlaDevice, set by calls to
|
||||
// EnsureDeviceContextOk. If gpu_device_info_ is non-null, this pointer is
|
||||
// also filled in to that struct. XlaDeviceContext is a ref-counted object.
|
||||
XlaDeviceContext* device_context_ GUARDED_BY(mu_) = nullptr;
|
||||
XlaDeviceContext* device_context_ TF_GUARDED_BY(mu_) = nullptr;
|
||||
|
||||
// The device context will allocate memory on fast memory space on TPU.
|
||||
// XlaDeviceContext is a ref-counted object.
|
||||
XlaDeviceContext* fast_mem_device_context_ GUARDED_BY(mu_) = nullptr;
|
||||
XlaDeviceContext* fast_mem_device_context_ TF_GUARDED_BY(mu_) = nullptr;
|
||||
|
||||
// Holds extra information for GPU and TPU devices, e.g. the device context.
|
||||
bool use_gpu_device_info_ GUARDED_BY(mu_) = false;
|
||||
std::unique_ptr<GpuDeviceInfo> gpu_device_info_ GUARDED_BY(mu_);
|
||||
bool use_gpu_device_info_ TF_GUARDED_BY(mu_) = false;
|
||||
std::unique_ptr<GpuDeviceInfo> gpu_device_info_ TF_GUARDED_BY(mu_);
|
||||
|
||||
// Thread pool used for running closures
|
||||
std::unique_ptr<thread::ThreadPool> thread_pool_;
|
||||
|
||||
// True if the device allows XlaDevice::Sync to be called on completion
|
||||
// regardless of status.
|
||||
bool sync_on_completion_ GUARDED_BY(mu_) = true;
|
||||
bool sync_on_completion_ TF_GUARDED_BY(mu_) = true;
|
||||
|
||||
// A callback that will be invoked when RefreshStatus sees a status error.
|
||||
std::function<Status()> device_error_callback_ GUARDED_BY(mu_);
|
||||
std::function<Status()> device_error_callback_ TF_GUARDED_BY(mu_);
|
||||
|
||||
// Set of devices to use. This controls which of the devices on the given
|
||||
// platform will have resources allocated. For GPUs this will be
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
#include "tensorflow/core/framework/tensor_reference.h"
|
||||
#include "tensorflow/core/platform/mem.h"
|
||||
#include "tensorflow/stream_executor/platform/port.h"
|
||||
|
||||
|
@ -117,7 +117,7 @@ class XlaDeviceContext : public DeviceContext {
|
||||
bool use_fast_mem_;
|
||||
|
||||
absl::Mutex mu_;
|
||||
int next_stream_ GUARDED_BY(mu_) = 0;
|
||||
int next_stream_ TF_GUARDED_BY(mu_) = 0;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -20,15 +20,17 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr,
|
||||
const NodeDef& node_def) const {
|
||||
return CanCreateXlaKernel(node_def);
|
||||
bool XlaKernelCreator::CanCreateKernel(
|
||||
const FunctionLibraryRuntime& flr,
|
||||
const std::shared_ptr<const NodeProperties>& props) const {
|
||||
return CanCreateXlaKernel(props->node_def);
|
||||
}
|
||||
|
||||
Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr,
|
||||
const NodeDef& node_def,
|
||||
std::unique_ptr<OpKernel>* kernel) const {
|
||||
return CreateXlaKernel(flr, node_def, kernel);
|
||||
Status XlaKernelCreator::CreateKernel(
|
||||
FunctionLibraryRuntime* flr,
|
||||
const std::shared_ptr<const NodeProperties>& props,
|
||||
std::unique_ptr<OpKernel>* kernel) const {
|
||||
return CreateXlaKernel(flr, props->node_def, kernel);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -29,11 +29,13 @@ class XlaKernelCreator : public CustomKernelCreator {
|
||||
// Given a NodeDef 'node_def' and the function library runtime 'flr', returns
|
||||
// true if 'node_def' is a call to a compilable function defined in 'flr',
|
||||
// with the kXlaCompileAttr set.
|
||||
bool CanCreateKernel(const FunctionLibraryRuntime& flr,
|
||||
const NodeDef& node_def) const override;
|
||||
bool CanCreateKernel(
|
||||
const FunctionLibraryRuntime& flr,
|
||||
const std::shared_ptr<const NodeProperties>& props) const override;
|
||||
|
||||
// Given a supported NodeDef, returns a XlaLaunchOp that computes the node.
|
||||
Status CreateKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
|
||||
Status CreateKernel(FunctionLibraryRuntime* flr,
|
||||
const std::shared_ptr<const NodeProperties>& props,
|
||||
std::unique_ptr<OpKernel>* kernel) const override;
|
||||
};
|
||||
|
||||
|
@ -30,10 +30,12 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
NodeDef ToNodeDef(const string& text) {
|
||||
std::shared_ptr<NodeProperties> ToNodeProperties(const string& text) {
|
||||
NodeDef node_def;
|
||||
DataTypeVector dummy;
|
||||
EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
|
||||
return node_def;
|
||||
return std::make_shared<NodeProperties>(nullptr, std::move(node_def), dummy,
|
||||
dummy);
|
||||
}
|
||||
|
||||
// Create a FunctionDef that takes one resource and one regular param
|
||||
@ -98,11 +100,11 @@ TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) {
|
||||
(*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
|
||||
Init({fdef});
|
||||
XlaKernelCreator xla_kernel_creator;
|
||||
NodeDef callsite =
|
||||
ToNodeDef(R"pb(
|
||||
auto callsite =
|
||||
ToNodeProperties(R"pb(
|
||||
name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
|
||||
)pb");
|
||||
(*callsite.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
|
||||
(*(callsite->node_def.mutable_attr()))["_XlaMustCompile"] = BoolAttr(true);
|
||||
|
||||
// Note: need to set attribute on the created node.
|
||||
Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_);
|
||||
@ -127,13 +129,14 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) {
|
||||
Init({fdef});
|
||||
XlaKernelCreator xla_kernel_creator;
|
||||
|
||||
Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto(
|
||||
name: 'XTimesY'
|
||||
op: 'XTimesY'
|
||||
input: 'a'
|
||||
input: 'b'
|
||||
)proto"),
|
||||
&kernel_);
|
||||
Status status =
|
||||
xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto(
|
||||
name: 'XTimesY'
|
||||
op: 'XTimesY'
|
||||
input: 'a'
|
||||
input: 'b'
|
||||
)proto"),
|
||||
&kernel_);
|
||||
EXPECT_TRUE(errors::IsInternal(status)) << status.ToString();
|
||||
}
|
||||
|
||||
@ -143,13 +146,14 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) {
|
||||
Init({fdef});
|
||||
XlaKernelCreator xla_kernel_creator;
|
||||
|
||||
Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto(
|
||||
name: 'XTimesY'
|
||||
op: 'XTimesY'
|
||||
input: 'a'
|
||||
input: 'b'
|
||||
)proto"),
|
||||
&kernel_);
|
||||
Status status =
|
||||
xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto(
|
||||
name: 'XTimesY'
|
||||
op: 'XTimesY'
|
||||
input: 'a'
|
||||
input: 'b'
|
||||
)proto"),
|
||||
&kernel_);
|
||||
EXPECT_TRUE(errors::IsInternal(status)) << status.ToString();
|
||||
}
|
||||
|
||||
|
@ -104,7 +104,7 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
||||
BackwardsConstAnalysis(*((*fbody)->graph), &const_args,
|
||||
/*compile_time_const_nodes=*/nullptr, flr));
|
||||
|
||||
for (int i = 0; i < const_args.size(); ++i) {
|
||||
for (size_t i = 0; i < const_args.size(); ++i) {
|
||||
if (const_args[i]) {
|
||||
constant_arg_indices->push_back(i);
|
||||
}
|
||||
@ -113,7 +113,7 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
||||
// There can be hundreds of resource variables. Reserve the space for them.
|
||||
// We don't reserve for constants above as they are usually few.
|
||||
resource_arg_indices->reserve(arg_types.size());
|
||||
for (int i = 0; i < arg_types.size(); ++i) {
|
||||
for (size_t i = 0; i < arg_types.size(); ++i) {
|
||||
if (arg_types[i] == DT_RESOURCE) {
|
||||
resource_arg_indices->push_back(i);
|
||||
}
|
||||
@ -177,7 +177,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
|
||||
// 214 variables and a similar number of activations.
|
||||
SinglePassSearch constants_search(&constant_arg_indices);
|
||||
SinglePassSearch resources_search(&resource_arg_indices);
|
||||
for (int i = 0; i < fbody->arg_types.size(); ++i) {
|
||||
for (size_t i = 0; i < fbody->arg_types.size(); ++i) {
|
||||
if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
|
||||
// Compile-time constants and resource handles are expected to be in
|
||||
// host memory.
|
||||
@ -207,7 +207,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
|
||||
// XlaLaunch kernel keeps all outputs (including constants, which it copies),
|
||||
// in device memory except for resources.
|
||||
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
|
||||
for (int i = 0; i < fbody->ret_types.size(); ++i) {
|
||||
for (size_t i = 0; i < fbody->ret_types.size(); ++i) {
|
||||
if (fbody->ret_types[i] == DT_RESOURCE) {
|
||||
output_memory_types[i] = HOST_MEMORY;
|
||||
}
|
||||
@ -218,15 +218,17 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
|
||||
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
|
||||
Device* dev = flr->device();
|
||||
Status s;
|
||||
OpKernelConstruction construction(
|
||||
DeviceType(dev->device_type()), dev,
|
||||
dev->GetAllocator(AllocatorAttributes()), &node_def,
|
||||
&fbody->fdef.signature(), flr, dev->resource_manager(), fbody->arg_types,
|
||||
input_memory_types, fbody->ret_types, output_memory_types,
|
||||
flr->graph_def_version(), &s);
|
||||
auto props = std::make_shared<NodeProperties>(
|
||||
&fbody->fdef.signature(), node_def, fbody->arg_types, fbody->ret_types);
|
||||
OpKernelConstruction construction(DeviceType(dev->device_type()), dev,
|
||||
dev->GetAllocator(AllocatorAttributes()),
|
||||
flr, dev->resource_manager(), props,
|
||||
input_memory_types, output_memory_types,
|
||||
flr->graph_def_version(), &s);
|
||||
|
||||
*kernel = absl::make_unique<XlaLocalLaunchBase>(
|
||||
&construction, constant_arg_indices, resource_arg_indices, function);
|
||||
&construction, constant_arg_indices, resource_arg_indices, function,
|
||||
/*has_ref_vars=*/false);
|
||||
return s;
|
||||
}
|
||||
} // namespace tensorflow
|
||||
|
@ -18,7 +18,6 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_
|
||||
|
||||
#include "absl/base/thread_annotations.h"
|
||||
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
|
||||
#include "tensorflow/compiler/jit/xla_tensor.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
@ -30,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/stream_executor/device_memory_allocator.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -102,7 +102,7 @@ class VariableInfo {
|
||||
// `variables` is allowed to contain instances that don't track a resource
|
||||
// variable (i.e. variables[i].var() can be null for some i).
|
||||
Status LockVariables(absl::Span<VariableInfo> variables)
|
||||
EXCLUSIVE_LOCK_FUNCTION();
|
||||
TF_EXCLUSIVE_LOCK_FUNCTION();
|
||||
|
||||
// Helper class to perform the marshalling of TensorFlow inputs and outputs to
|
||||
// ShapedBuffers suitable for passing to an XLA computation.
|
||||
|
@ -122,7 +122,7 @@ class XlaTensor {
|
||||
std::shared_ptr<se::Event> definition_event_;
|
||||
// A list of all streams for which the tensor's content is defined for any
|
||||
// newly enqueued command.
|
||||
absl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_);
|
||||
absl::InlinedVector<se::Stream*, 2> streams_defined_on_ TF_GUARDED_BY(mu_);
|
||||
mutex mu_;
|
||||
};
|
||||
|
||||
|
@ -44,11 +44,9 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:AffineDialectRegistration",
|
||||
"@llvm-project//mlir:LoopDialectRegistration",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:MlirOptLib",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOpsDialectRegistration",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir/test:TestTransforms",
|
||||
],
|
||||
@ -76,12 +74,15 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/xla:hlo",
|
||||
"//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_copy_removal",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_parallel_loops",
|
||||
"//tensorflow/compiler/mlir/xla:xla_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
|
||||
"//tensorflow/compiler/mlir/xla:xla_lower",
|
||||
@ -102,11 +103,45 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mlir_graph_optimization_pass",
|
||||
srcs = ["mlir_graph_optimization_pass.cc"],
|
||||
hdrs = ["mlir_graph_optimization_pass.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||
"//tensorflow/compiler/mlir/tensorflow:device_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mlir_graph_optimization_pass_registration",
|
||||
srcs = [
|
||||
"mlir_graph_optimization_pass_registration.cc",
|
||||
],
|
||||
deps = [
|
||||
":mlir_graph_optimization_pass",
|
||||
"//tensorflow/core:core_cpu",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cc_binary(
|
||||
name = "tf-opt",
|
||||
deps = [
|
||||
":tf_mlir_opt_main",
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
|
||||
],
|
||||
)
|
||||
@ -116,8 +151,10 @@ tf_cc_binary(
|
||||
srcs = ["tf_mlir_translate_main.cc"],
|
||||
deps = [
|
||||
":init_mlir",
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_registration",
|
||||
@ -129,6 +166,7 @@ tf_cc_binary(
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TranslateClParser",
|
||||
|
3
tensorflow/compiler/mlir/g3doc/README.md
Normal file
3
tensorflow/compiler/mlir/g3doc/README.md
Normal file
@ -0,0 +1,3 @@
|
||||
# TensorFlow MLIR
|
||||
|
||||
These are the docs for: https://www.tensorflow.org/mlir
|
26
tensorflow/compiler/mlir/g3doc/_book.yaml
Normal file
26
tensorflow/compiler/mlir/g3doc/_book.yaml
Normal file
@ -0,0 +1,26 @@
|
||||
upper_tabs:
|
||||
# Tabs left of dropdown menu
|
||||
- include: /_upper_tabs_left.yaml
|
||||
- include: /api_docs/_upper_tabs_api.yaml
|
||||
# Dropdown menu
|
||||
- name: Resources
|
||||
path: /resources
|
||||
is_default: true
|
||||
menu:
|
||||
- include: /resources/_menu_toc.yaml
|
||||
lower_tabs:
|
||||
# Subsite tabs
|
||||
other:
|
||||
- name: Guide
|
||||
contents:
|
||||
- title: Overview
|
||||
path: /mlir/overview
|
||||
- heading: Dialects
|
||||
- title: Overview
|
||||
path: /mlir/dialects
|
||||
- title: TensorFlow
|
||||
path: /mlir/tf_ops
|
||||
- title: TensorFlow Lite
|
||||
path: /mlir/tfl_ops
|
||||
|
||||
- include: /_upper_tabs_right.yaml
|
54
tensorflow/compiler/mlir/g3doc/_index.yaml
Normal file
54
tensorflow/compiler/mlir/g3doc/_index.yaml
Normal file
@ -0,0 +1,54 @@
|
||||
book_path: /mlir/_book.yaml
|
||||
project_path: /mlir/_project.yaml
|
||||
description: <!--no description-->
|
||||
landing_page:
|
||||
custom_css_path: /site-assets/css/style.css
|
||||
rows:
|
||||
- heading: MLIR unifies the infrastructure for high-performance ML models in TensorFlow.
|
||||
items:
|
||||
- description: >
|
||||
The <a href="https://mlir.llvm.org/" class="external">MLIR</a> project defines a common
|
||||
intermediate representation (IR) that unifies the infrastructure required to execute high
|
||||
performance machine learning models in TensorFlow and similar ML frameworks. This project
|
||||
will include the application of HPC techniques, along with integration of
|
||||
search algorithms like reinforcement learning. MLIR aims to reduce the
|
||||
cost to bring up new hardware, and improve usability for existing
|
||||
TensorFlow users.
|
||||
|
||||
- code_block: |
|
||||
<pre class = "prettyprint">
|
||||
// Syntactically similar to LLVM:
|
||||
func @testFunction(%arg0: i32) {
|
||||
%x = call @thingToCall(%arg0) : (i32) -> i32
|
||||
br ^bb1
|
||||
^bb1:
|
||||
%y = addi %x, %x : i32
|
||||
return %y : i32
|
||||
}
|
||||
</pre>
|
||||
|
||||
- classname: devsite-landing-row-cards
|
||||
items:
|
||||
- heading: "Multi-Level Intermediate Representation for Compiler Infrastructure"
|
||||
youtube_id: qzljG6DKgic
|
||||
buttons:
|
||||
- label: Watch the video
|
||||
path: https://www.youtube.com/watch?v=qzljG6DKgic
|
||||
- heading: "A new intermediate representation and compiler framework"
|
||||
image_path: /resources/images/tf-logo-card-16x9.png
|
||||
path: https://blog.tensorflow.org/2019/04/mlir-new-intermediate-representation.html
|
||||
buttons:
|
||||
- label: Read on TensorFlow blog
|
||||
path: https://blog.tensorflow.org/2019/04/mlir-new-intermediate-representation.html
|
||||
- heading: MLIR on GitHub
|
||||
image_path: /resources/images/github-card-16x9.png
|
||||
path: https://github.com/llvm/llvm-project/tree/master/mlir
|
||||
buttons:
|
||||
- label: View on GitHub
|
||||
path: https://github.com/llvm/llvm-project/tree/master/mlir
|
||||
- heading: TensorFlow MLIR on GitHub
|
||||
image_path: /resources/images/github-card-16x9.png
|
||||
path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir
|
||||
buttons:
|
||||
- label: View on GitHub
|
||||
path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir
|
37
tensorflow/compiler/mlir/g3doc/dialects.md
Normal file
37
tensorflow/compiler/mlir/g3doc/dialects.md
Normal file
@ -0,0 +1,37 @@
|
||||
# MLIR dialects
|
||||
|
||||
## Overview
|
||||
|
||||
|
||||
To separate different hardware and software targets, MLIR has “dialects”,
|
||||
including:
|
||||
|
||||
* TensorFlow IR, which represents all things possible in TensorFlow graphs.
|
||||
* XLA HLO IR, which is designed to take advantage of XLA’s compilation
|
||||
abilities (with output to, among other things, TPUs).
|
||||
* An experimental affine dialect, which focuses on
|
||||
[polyhedral representations](https://en.wikipedia.org/wiki/Polytope_model)
|
||||
and optimizations.
|
||||
* LLVM IR, which has a 1:1 mapping between it and LLVM’s own representation,
|
||||
allowing MLIR to emit GPU and CPU code through LLVM.
|
||||
* TensorFlow Lite, which will translate to running code on mobile platforms.
|
||||
|
||||
Each dialect consists of a set of defined operations which have invariants
|
||||
placed on them, like: “This is a binary operator, and the inputs and outputs
|
||||
have the same types.”
|
||||
|
||||
## Adding to MLIR
|
||||
|
||||
MLIR has no fixed/built-in list of globally known operations (no “intrinsics”).
|
||||
Dialects can define entirely custom types, which is how MLIR can model things
|
||||
like the LLVM IR type system (which has first class aggregates), domain
|
||||
abstractions important for ML-optimized accelerators like quantized types, and
|
||||
even the Swift or Clang type systems (which are built around Swift/Clang
|
||||
declaration nodes) in the future.
|
||||
|
||||
If you want to connect a new low-level compiler, you would create a new dialect
|
||||
and the lowerings between the TensorFlow Graph dialect and your dialect.
|
||||
This smooths the path for hardware and compiler makers. You can even target
|
||||
dialects at different levels in the same model; the higher-level optimizers
|
||||
will respect the unfamiliar parts of the IR and wait for a lower level to handle
|
||||
it.
|
1
tensorflow/compiler/mlir/g3doc/images/mlir-infra.svg
Normal file
1
tensorflow/compiler/mlir/g3doc/images/mlir-infra.svg
Normal file
File diff suppressed because one or more lines are too long
After (image error) Size: 148 KiB |
36
tensorflow/compiler/mlir/g3doc/overview.md
Normal file
36
tensorflow/compiler/mlir/g3doc/overview.md
Normal file
@ -0,0 +1,36 @@
|
||||
# MLIR
|
||||
|
||||
## Overview
|
||||
|
||||
MLIR, or Multi-Level Intermediate Representation, is a representation format
|
||||
and library of compiler utilities that sits between the model representation
|
||||
and low-level compilers/executors that generate hardware-specific code.
|
||||
|
||||
MLIR is, at its heart, a flexible infrastructure for modern optimizing
|
||||
compilers. This means it consists of a specification for intermediate
|
||||
representations (IR) and a code toolkit to perform transformations on that
|
||||
representation. (In compiler parlance, as you move from higher-level
|
||||
representations to lower-level representations, these transformations can be
|
||||
called “lowerings”)
|
||||
|
||||
MLIR is highly influenced by [LLVM](https://llvm.org/) and unabashedly reuses
|
||||
many great ideas from it. It has a flexible type system, and allows
|
||||
representing, analyzing and transforming graphs combining multiple levels of
|
||||
abstraction in the same compilation unit. These abstractions include TensorFlow
|
||||
operations, nested polyhedral loop regions, and even LLVM instructions and fixed
|
||||
hardware operations and types.
|
||||
|
||||
We expect MLIR to be of interest to many groups, including:
|
||||
|
||||
* Compiler researchers and implementers looking to optimize performance and
|
||||
memory consumption of machine learning models
|
||||
* Hardware makers looking for a way to connect their hardware to TensorFlow,
|
||||
such as TPUs, portable neural hardware in phones, and other custom ASICs
|
||||
* People writing language bindings that want to take advantage of optimizing
|
||||
compilers and hardware acceleration.
|
||||
|
||||
The TensorFlow ecosystem contains a number of compilers and optimizers that
|
||||
operate at multiple levels of the software and hardware stack. We expect the
|
||||
gradual adoption of MLIR to simplify every aspect of this stack.
|
||||
|
||||
<img alt="MLIR overview diagram" src="./images/mlir-infra.svg"/>
|
@ -48,10 +48,11 @@ def _run_lit_test(name, data, size, tags, driver, features):
|
||||
" the driver parameter when running this test. If you require" +
|
||||
" custom driver support, please file an issue to request it.")
|
||||
|
||||
# Disable tests on windows for now, to enable testing rest of all xla and mlir.
|
||||
native.py_test(
|
||||
name = name,
|
||||
srcs = ["@llvm-project//llvm:lit"],
|
||||
tags = tags,
|
||||
tags = tags + ["no_windows"],
|
||||
args = [
|
||||
"tensorflow/compiler/mlir/" + paths.basename(data[-1]) + " --config-prefix=runlit -v",
|
||||
] + features,
|
||||
|
@ -30,6 +30,7 @@ filegroup(
|
||||
"ir/tfl_ops.td",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
|
||||
"@llvm-project//mlir:include/mlir/Transforms/LoopLikeInterface.td",
|
||||
],
|
||||
)
|
||||
@ -208,6 +209,7 @@ cc_library(
|
||||
"ir/tfl_ops.h.inc",
|
||||
"ir/tfl_ops_interface.cc.inc",
|
||||
"ir/tfl_ops_interface.h.inc",
|
||||
"runtime_verifiers.inc",
|
||||
"utils/attribute_utils.cc",
|
||||
],
|
||||
hdrs = [
|
||||
@ -222,15 +224,18 @@ cc_library(
|
||||
":validators",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:DerivedAttributeOpInterface",
|
||||
"@llvm-project//mlir:Dialect",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:SideEffects",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
# TODO(jpienaar): Move this out after splitting out LoopLikeOpInterface.
|
||||
"@llvm-project//mlir:Transforms",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
],
|
||||
alwayslink = 1,
|
||||
@ -302,15 +307,15 @@ cc_library(
|
||||
"transforms/optimize_functional_ops.cc",
|
||||
"transforms/prepare_composite_functions_tf.cc",
|
||||
"transforms/prepare_tf.cc",
|
||||
"transforms/runtime_type_verify.cc",
|
||||
"transforms/split_merged_operands.cc",
|
||||
"transforms/trim_functions_tf.cc",
|
||||
"transforms/unroll_batch_matmul.cc",
|
||||
"transforms/while_loop_outline.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"ir/tfl_ops_interface.h.inc",
|
||||
"transforms/dilated_conv.h",
|
||||
"transforms/passes.h",
|
||||
"transforms/unroll_batch_matmul.h",
|
||||
],
|
||||
deps = [
|
||||
":common",
|
||||
@ -323,6 +328,8 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||
"//tensorflow/compiler/mlir/tensorflow:unroll_batch_matmul_pass",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core:framework",
|
||||
@ -459,9 +466,9 @@ cc_library(
|
||||
)
|
||||
|
||||
tf_native_cc_binary(
|
||||
name = "operator-converter-gen",
|
||||
name = "converter-gen",
|
||||
srcs = [
|
||||
"operator_converter_gen.cc",
|
||||
"converter_gen.cc",
|
||||
],
|
||||
deps = [
|
||||
"@llvm-project//llvm:support",
|
||||
@ -471,14 +478,18 @@ tf_native_cc_binary(
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "operator_converter_inc",
|
||||
name = "converter_inc",
|
||||
tbl_outs = [
|
||||
(
|
||||
"", # This driver has no options.
|
||||
"--gen-operator-converters",
|
||||
"operator_converters.inc",
|
||||
),
|
||||
(
|
||||
"--gen-runtime-verifiers",
|
||||
"runtime_verifiers.inc",
|
||||
),
|
||||
],
|
||||
tblgen = ":operator-converter-gen",
|
||||
tblgen = ":converter-gen",
|
||||
td_file = "ir/tfl_ops.td",
|
||||
td_srcs = [
|
||||
":tensorflow_lite_ops_td_files",
|
||||
@ -508,6 +519,7 @@ cc_library(
|
||||
"@com_google_absl//absl/strings",
|
||||
"@flatbuffers",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
],
|
||||
@ -561,6 +573,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
@ -571,7 +584,7 @@ cc_library(
|
||||
"//tensorflow/lite/delegates/flex:whitelisted_flex_ops_lib",
|
||||
"//tensorflow/lite/kernels/internal:kernel_utils",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"//tensorflow/lite/tools/versioning:op_version",
|
||||
"//tensorflow/lite/tools/versioning",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
@ -581,8 +594,6 @@ cc_library(
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:QuantOpsDialectRegistration",
|
||||
"@llvm-project//mlir:StandardDialectRegistration",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Translation",
|
||||
@ -594,6 +605,7 @@ tf_cc_binary(
|
||||
name = "flatbuffer_translate",
|
||||
deps = [
|
||||
":flatbuffer_translate_lib",
|
||||
"@llvm-project//mlir:LoopOpsTransforms",
|
||||
"@llvm-project//mlir:MlirTranslateMain",
|
||||
],
|
||||
)
|
||||
@ -643,12 +655,14 @@ tf_cc_binary(
|
||||
"//tensorflow/compiler/mlir:init_mlir",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
@ -687,16 +701,16 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:decode_constant_pass",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:QuantOpsDialectRegistration",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
)
|
||||
@ -716,6 +730,7 @@ cc_library(
|
||||
":tensorflow_lite_quantize",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:decode_constant_pass",
|
||||
"//tensorflow/compiler/mlir/tensorflow:error_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
|
||||
@ -725,12 +740,10 @@ cc_library(
|
||||
"//tensorflow/lite/tools/optimize:quantize_weights",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Parser",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:QuantOpsDialectRegistration",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
|
@ -1,9 +1,9 @@
|
||||
# Experimental code for the new TF-Lite convertor, and MLIR dialects and utilities for TensorFlow Lite.
|
||||
# The new [MLIR](https://github.com/llvm/llvm-project/tree/master/mlir) based
|
||||
TensorFlow to TensorFlow Lite converter
|
||||
|
||||
This directory contains:
|
||||
|
||||
1. Experimental code for the new TF-Lite convertor.
|
||||
2. Code for the TF-lite dialect [MLIR](https://github.com/tensorflow/mlir).
|
||||
1. MLIR dialects, transformation passes and utilities for TensorFlow Lite.
|
||||
|
||||
## API:
|
||||
|
||||
@ -11,7 +11,8 @@ The API for converting TensorFlow models to TensorFlow Lite will be through
|
||||
`tf.lite.TFLiteConverter`. All the conversion code is open sourced, and
|
||||
the API will be integrated soon.
|
||||
|
||||
### The conversion process from TensorFlow to TensorFlow Lite includes the following major passes:
|
||||
### The conversion process from TensorFlow to TensorFlow Lite includes the
|
||||
following major passes:
|
||||
|
||||
- Import from GraphDef, in .pb or .pbtxt format, into MLIR.
|
||||
- Raise to Control-flow-graph. Converts TF Control Flow dialect to TF dialect.
|
||||
@ -28,3 +29,6 @@ TensorFlow Lite models).
|
||||
- The Export pass writes out TensorFlow Lite FlatBuffer format. This pass
|
||||
operates on MLIR TensorFlow Lite dialect and is simple/direct translation.
|
||||
|
||||
See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
|
||||
for the full list of MLIR passes for conversion from TensorFlow to
|
||||
TensorFlow Lite.
|
||||
|
@ -34,8 +34,9 @@ struct PassConfig {
|
||||
quant_specs(std::move(specs)),
|
||||
skip_control_dialect(false),
|
||||
form_clusters(false),
|
||||
inline_functions(true),
|
||||
unfold_batch_matmul(true) {}
|
||||
unfold_batch_matmul(true),
|
||||
legalize_tf_while(true),
|
||||
shape_inference(false) {}
|
||||
|
||||
// If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
|
||||
// added, which produces TF Lite ops.
|
||||
@ -55,12 +56,15 @@ struct PassConfig {
|
||||
// are formed by grouping consecutive ops of the same device, under a
|
||||
// `tf_device.launch` op.
|
||||
bool form_clusters;
|
||||
// Inline function calls within the main function in the MLIR module, prior
|
||||
// to legalization to TFLite.
|
||||
bool inline_functions;
|
||||
// if `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set
|
||||
// of tfl.fully_connected ops.
|
||||
bool unfold_batch_matmul;
|
||||
// Whether to legalize TF While to TFL While.
|
||||
// Note: This is staging step and will be removed.
|
||||
// TODO(b/137395003): Remove post switching legalization.
|
||||
bool legalize_tf_while;
|
||||
// Whether to do shape inference.
|
||||
bool shape_inference;
|
||||
};
|
||||
|
||||
} // namespace TFL
|
||||
|
@ -28,6 +28,9 @@ limitations under the License.
|
||||
#include "llvm/TableGen/Record.h"
|
||||
#include "llvm/TableGen/TableGenBackend.h"
|
||||
#include "mlir/TableGen/Attribute.h" // TF:llvm-project
|
||||
#include "mlir/TableGen/Format.h" // TF:llvm-project
|
||||
#include "mlir/TableGen/Operator.h" // TF:llvm-project
|
||||
#include "mlir/TableGen/Predicate.h" // TF:llvm-project
|
||||
|
||||
using llvm::DefInit;
|
||||
using llvm::dyn_cast;
|
||||
@ -41,6 +44,19 @@ using llvm::SmallVector;
|
||||
using llvm::StringInit;
|
||||
using llvm::StringRef;
|
||||
|
||||
enum ActionType {
|
||||
OpConv,
|
||||
RuntimeVerify,
|
||||
};
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
llvm::cl::opt<ActionType> action(
|
||||
llvm::cl::desc("Action to perform:"),
|
||||
llvm::cl::values(clEnumValN(OpConv, "gen-operator-converters",
|
||||
"Generate operator converters"),
|
||||
clEnumValN(RuntimeVerify, "gen-runtime-verifiers",
|
||||
"Generate TFLite runtime verifiers")));
|
||||
|
||||
// Returns the associated option name for the given op definition.
|
||||
static inline std::string GetOperatorOptionName(const Record &def) {
|
||||
assert(def.getName().startswith("TFL_") && "unexpected op prefix");
|
||||
@ -103,6 +119,12 @@ static void EmitOptionBuilders(const RecordKeeper &record_keeper,
|
||||
// conversion generation and so the simplicity was chosen over the
|
||||
// flexibility.
|
||||
StringRef arg_name = arg_values->getArgNameStr(i);
|
||||
// Skip any "intermiadiateXXX" attribute as they are specially handled
|
||||
// in the exporter. They are special because though they are attributes
|
||||
// in the MLIR they are expressed as tensors in the flatbuffer instead
|
||||
// of option.
|
||||
if (op_name == "LSTMOp" && arg_name.take_back(12) == "intermediate")
|
||||
continue;
|
||||
os << formatv(
|
||||
" auto {0} = Convert{1}ForOptionWriter(op.{0}(), fbb);\n",
|
||||
arg_name, mlir::tblgen::Attribute(arg_def).getAttrDefName());
|
||||
@ -148,17 +170,24 @@ static void EmitOperatorBuilders(const std::vector<Record *> &defs,
|
||||
for (const auto *def : defs) {
|
||||
StringRef op_name = def->getName().drop_front(4);
|
||||
|
||||
const bool has_intermediates = op_name == "LSTMOp";
|
||||
// Signature
|
||||
os << "static flatbuffers::Offset<tflite::Operator> "
|
||||
<< GetOperatorBuilderName(def->getName()) << "(mlir::TFL::" << op_name
|
||||
<< " tflOp, uint32_t opcode_index, "
|
||||
<< "const std::vector<int32_t>& operands,"
|
||||
<< "const std::vector<int32_t>& results,"
|
||||
<< (has_intermediates ? "const std::vector<int32_t>& intermediate_index,"
|
||||
: "")
|
||||
<< "flatbuffers::FlatBufferBuilder *fbb) {\n";
|
||||
|
||||
// Inputs & outputs
|
||||
os << " auto inputs = fbb->CreateVector(operands);\n"
|
||||
" auto outputs = fbb->CreateVector(results);\n\n";
|
||||
// Intermediates for LSTM.
|
||||
if (has_intermediates) {
|
||||
os << " auto intermediates = fbb->CreateVector(intermediate_index);\n";
|
||||
}
|
||||
|
||||
// Build the FlatBuffer operator
|
||||
os << " return tflite::CreateOperator(\n"
|
||||
@ -175,9 +204,9 @@ static void EmitOperatorBuilders(const std::vector<Record *> &defs,
|
||||
// Only builtin ops' builders are auto-generated. custom_options are only
|
||||
// used by custom or flex ops and those ops are handled manually.
|
||||
os << " /*custom_options=*/0, "
|
||||
"tflite::CustomOptionsFormat_FLEXBUFFERS,\n"
|
||||
" /*mutating_variable_inputs=*/0);\n"
|
||||
"}\n\n";
|
||||
<< "tflite::CustomOptionsFormat_FLEXBUFFERS,\n"
|
||||
<< " /*mutating_variable_inputs=*/0"
|
||||
<< (has_intermediates ? ", intermediates" : "") << ");\n}\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
@ -228,6 +257,7 @@ static void EmitGetBuiltinOpCode(const std::vector<Record *> &defs,
|
||||
// uint32_t opcode_index,
|
||||
// const std::vector<int32_t>& operands,
|
||||
// const std::vector<int32_t>& results,
|
||||
// const std::vector<int32_t>& intermediates,
|
||||
// flatbuffers::FlatBufferBuilder *fbb);
|
||||
static void EmitBuildOperator(const std::vector<Record *> &defs,
|
||||
raw_ostream *ostream) {
|
||||
@ -239,6 +269,7 @@ static void EmitBuildOperator(const std::vector<Record *> &defs,
|
||||
"uint32_t opcode_index, "
|
||||
"const std::vector<int32_t>& operands,"
|
||||
"const std::vector<int32_t>& results,"
|
||||
"const std::vector<int32_t>& intermediates,"
|
||||
"flatbuffers::FlatBufferBuilder *fbb) {\n";
|
||||
|
||||
for (const auto *def : defs) {
|
||||
@ -248,7 +279,8 @@ static void EmitBuildOperator(const std::vector<Record *> &defs,
|
||||
os << " if (auto tflOp = llvm::dyn_cast<mlir::TFL::" << op_name
|
||||
<< ">(op))\n"
|
||||
<< " return " << GetOperatorBuilderName(def->getName())
|
||||
<< "(tflOp, opcode_index, operands, results, fbb);\n";
|
||||
<< "(tflOp, opcode_index, operands, results, "
|
||||
<< (op_name == "LSTMOp" ? "intermediates, " : "") << "fbb);\n";
|
||||
}
|
||||
|
||||
os << " return llvm::None;\n"
|
||||
@ -291,6 +323,10 @@ static void EmitBuiltinOptionsToAttributes(const RecordKeeper &record_keeper,
|
||||
if (!arg_def) continue;
|
||||
if (arg_def->getDef()->isSubClassOf(attr_type)) {
|
||||
StringRef arg_name = arg_values->getArgNameStr(i);
|
||||
// Already handle this case in flatbuffer_import.cc.
|
||||
if (option_name == "LSTMOptions" &&
|
||||
arg_name.take_back(12) == "intermediate")
|
||||
continue;
|
||||
StringRef attr_type = mlir::tblgen::Attribute(arg_def).getAttrDefName();
|
||||
os << formatv(
|
||||
" attributes.emplace_back(builder.getNamedAttr(\"{0}\","
|
||||
@ -342,8 +378,101 @@ static bool OperatorWritersMain(raw_ostream &os, RecordKeeper &records) {
|
||||
return false;
|
||||
}
|
||||
|
||||
static void GenOperandResultVerifier(raw_ostream &os,
|
||||
llvm::ArrayRef<llvm::Init *> values,
|
||||
StringRef valueKind) {
|
||||
mlir::tblgen::FmtContext fctx;
|
||||
|
||||
bool first = true;
|
||||
for (auto static_value : llvm::enumerate(values)) {
|
||||
auto *definit = llvm::cast<llvm::DefInit>(static_value.value());
|
||||
auto *val = definit->getDef()->getValue("tflRuntimeTypePredicate");
|
||||
if (!val) continue;
|
||||
|
||||
// Create code block on first type to verify.
|
||||
if (first) {
|
||||
os << " {\n";
|
||||
os << " unsigned index = " << static_value.index() << ";\n";
|
||||
first = false;
|
||||
}
|
||||
|
||||
mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
|
||||
auto desc =
|
||||
definit->getDef()->getValueAsString("tflRuntimeTypeDescription");
|
||||
|
||||
// Emit a loop to check all the dynamic values in the pack.
|
||||
os << formatv(" for (Value v : top.getODS{0}{1}s({2})) {{\n",
|
||||
// Capitalize the first letter to match the function name
|
||||
valueKind.substr(0, 1).upper(), valueKind.substr(1),
|
||||
static_value.index());
|
||||
|
||||
os << " (void)v;\n"
|
||||
<< " if (!("
|
||||
<< tgfmt(pred.getCondition(), &fctx.withSelf("v.getType()")) << ")) {\n"
|
||||
<< formatv(
|
||||
" return op->emitOpError(\"{0} #\") << index "
|
||||
"<< \" must be {1}, but got \" << v.getType();\n",
|
||||
valueKind, desc)
|
||||
<< " }\n" // if
|
||||
<< " ++index;\n"
|
||||
<< " }\n"; // for
|
||||
}
|
||||
|
||||
// Emit closing brace if needed.
|
||||
if (!first) os << " }\n";
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
|
||||
emitSourceFileHeader("MLIR TFLite Runtime Verifiers", os);
|
||||
|
||||
// Retrieve all the definitions derived from TFL_Op and sort by record name.
|
||||
std::vector<Record *> defs = records.getAllDerivedDefinitions("Op");
|
||||
llvm::sort(defs, LessRecord());
|
||||
|
||||
// Iterate through all the ops defined.
|
||||
for (const auto *def : defs) {
|
||||
mlir::tblgen::Operator op(*def);
|
||||
if (!op.getTrait("TflRuntimeVerifyOpInterface::Trait")) continue;
|
||||
|
||||
mlir::tblgen::FmtContext verify_ctx;
|
||||
os << "::mlir::LogicalResult " << op.getCppClassName()
|
||||
<< "::VerifyTflRuntimeTypes(::mlir::Operation *op) {\n";
|
||||
os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
|
||||
verify_ctx.withOp("top");
|
||||
|
||||
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
|
||||
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
|
||||
auto &value = op.getOperand(i);
|
||||
// Skip from from first variadic operands for now. Else getOperand index
|
||||
// used below doesn't match.
|
||||
if (value.isVariadic()) break;
|
||||
if (!value.name.empty())
|
||||
verify_ctx.addSubst(value.name, formatv("op->getOperand({0})", i));
|
||||
}
|
||||
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
|
||||
auto &value = op.getResult(i);
|
||||
// Skip from from first variadic results for now. Else getResult index
|
||||
// used below doesn't match.
|
||||
if (value.isVariadic()) break;
|
||||
if (!value.name.empty())
|
||||
verify_ctx.addSubst(value.name, formatv("op->getResult({0})", i));
|
||||
}
|
||||
}
|
||||
GenOperandResultVerifier(os, def->getValueAsDag("arguments")->getArgs(),
|
||||
"operand");
|
||||
GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(),
|
||||
"result");
|
||||
os << " return mlir::success();\n}\n";
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
llvm::InitLLVM y(argc, argv);
|
||||
llvm::cl::ParseCommandLineOptions(argc, argv);
|
||||
return TableGenMain(argv[0], &OperatorWritersMain);
|
||||
if (action == ActionType::OpConv)
|
||||
return TableGenMain(argv[0], &OperatorWritersMain);
|
||||
return TableGenMain(argv[0], &RuntimeVerifierWriterMain);
|
||||
}
|
@ -46,7 +46,7 @@ limitations under the License.
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Diagnostics.h" // TF:llvm-project
|
||||
@ -76,6 +76,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
@ -124,6 +125,20 @@ static opt<bool, true> experimental_prune_unreachable_nodes_unconditionally_flg(
|
||||
llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally),
|
||||
llvm::cl::init(false));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static opt<std::string> input_arrays_flag(
|
||||
"input-arrays",
|
||||
llvm::cl::desc(
|
||||
"List of input tensors, if different from the default inputs"),
|
||||
llvm::cl::init(""));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static opt<std::string> output_arrays_flag(
|
||||
"output-arrays",
|
||||
llvm::cl::desc(
|
||||
"List of output tensors, if different from the default outputs"),
|
||||
llvm::cl::init(""));
|
||||
|
||||
namespace {
|
||||
bool IsScalar(const TensorT& tensor) {
|
||||
// TODO(b/138222071) We can't distinguish scalars and unranked tensors
|
||||
@ -532,6 +547,7 @@ bool IsCustomOp(const std::string& op_name) {
|
||||
// TODO(krzysd) Handle function calls
|
||||
StatusOr<Operation*> ConvertOp(
|
||||
const tflite::OperatorT& op, const std::vector<Value>& vals_map,
|
||||
const std::vector<mlir::TensorType>& intermediate_types,
|
||||
Value optional_arg_marker, const std::vector<std::string>& op_names,
|
||||
const std::vector<std::string>& func_names,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>>& tensors, Location loc,
|
||||
@ -590,6 +606,33 @@ StatusOr<Operation*> ConvertOp(
|
||||
op_state.addTypes({type});
|
||||
}
|
||||
|
||||
if (op_name == "tfl.lstm") {
|
||||
// TODO(b/147587779): add the right region if region is empty.
|
||||
op_state.addRegion();
|
||||
if (!op.intermediates.empty()) {
|
||||
if (op.intermediates.size() != 5) {
|
||||
auto err = errors::InvalidArgument(
|
||||
"operator has intermediate tensors but the number of them is not "
|
||||
"five.");
|
||||
return emitError(loc, err.ToString()), err;
|
||||
}
|
||||
// Create intermediate value
|
||||
|
||||
const llvm::SmallVector<llvm::StringRef, 5> kIntermediateNames = {
|
||||
"input_to_input_intermediate", "input_to_forget_intermediate",
|
||||
"input_to_cell_intermediate", "input_to_output_intermediate",
|
||||
"effective_hidden_scale_intermediate"};
|
||||
for (auto type_and_name :
|
||||
llvm::zip(intermediate_types, kIntermediateNames)) {
|
||||
mlir::TypeAttr type_attr =
|
||||
mlir::TypeAttr::get(std::get<0>(type_and_name));
|
||||
auto named_attr =
|
||||
builder.getNamedAttr(std::get<1>(type_and_name), type_attr);
|
||||
op_state.addAttribute(named_attr.first, named_attr.second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
|
||||
if (IsCustomOp(op_name)) {
|
||||
auto status = mlir::CustomOptionsToAttributes(op_name, op.custom_options,
|
||||
@ -610,43 +653,30 @@ StatusOr<Operation*> ConvertOp(
|
||||
return builder.createOperation(op_state);
|
||||
}
|
||||
|
||||
// Returns the output tensor indices for the given subgraph. If
|
||||
// ordered_output_arrays is provided, then return the tensor indices in
|
||||
// ordered_output_arrays.
|
||||
StatusOr<llvm::SmallVector<int32_t, 4>> GetOutputTensorIndices(
|
||||
const tflite::SubGraphT& subgraph, Location base_loc,
|
||||
const std::vector<std::string>& ordered_output_arrays) {
|
||||
if (ordered_output_arrays.empty()) {
|
||||
return llvm::SmallVector<int32_t, 4>(subgraph.outputs.begin(),
|
||||
subgraph.outputs.end());
|
||||
// Returns indices of the given tensors in the subgraph. Returns error if a
|
||||
// tensor name cannot be found in the subgraph.
|
||||
StatusOr<std::vector<int>> GetTensorIndices(
|
||||
const tflite::SubGraphT& subgraph,
|
||||
const std::vector<std::string>& tensor_names) {
|
||||
absl::flat_hash_map<std::string, int> name_to_index;
|
||||
for (auto index_and_tensor : llvm::enumerate(subgraph.tensors)) {
|
||||
name_to_index[index_and_tensor.value()->name] = index_and_tensor.index();
|
||||
}
|
||||
|
||||
llvm::SmallVector<int32_t, 4> outputs;
|
||||
outputs.resize(ordered_output_arrays.size());
|
||||
absl::flat_hash_map<std::string, int> output_order_map;
|
||||
for (auto output : llvm::enumerate(ordered_output_arrays)) {
|
||||
output_order_map[output.value()] = output.index();
|
||||
}
|
||||
std::vector<int> indices;
|
||||
indices.reserve(tensor_names.size());
|
||||
|
||||
int tensor_index = 0;
|
||||
int found_output_tensors = 0;
|
||||
for (const auto& tensor : subgraph.tensors) {
|
||||
auto found = output_order_map.find(tensor->name);
|
||||
if (found != output_order_map.end()) {
|
||||
const int output_index = found->second;
|
||||
outputs[output_index] = tensor_index;
|
||||
++found_output_tensors;
|
||||
for (const auto& name : tensor_names) {
|
||||
auto found = name_to_index.find(name);
|
||||
if (found != name_to_index.end()) {
|
||||
indices.push_back(found->second);
|
||||
} else {
|
||||
return errors::InvalidArgument("could not find tensor in subgraph: ",
|
||||
name);
|
||||
}
|
||||
++tensor_index;
|
||||
}
|
||||
|
||||
if (found_output_tensors != ordered_output_arrays.size()) {
|
||||
auto err = errors::InvalidArgument(
|
||||
"cannot find all nodes in ordered_output_arrays");
|
||||
return emitError(base_loc, err.ToString()), err;
|
||||
}
|
||||
|
||||
return outputs;
|
||||
return indices;
|
||||
}
|
||||
|
||||
// Given a list of tensor indices, returns a string of concatenated tensor names
|
||||
@ -661,15 +691,18 @@ mlir::NamedAttribute BuildTFEntryFunctionAttribute(
|
||||
name, builder->getStringAttr(llvm::join(tensor_names, ",")));
|
||||
}
|
||||
|
||||
// Given a list of output indices, traverses the subgraph and returns the set of
|
||||
// ops that are ancestors of the output tensors.
|
||||
// Traverses the subgraph from output_indices to input_indices and returns the
|
||||
// set of ops that are visited.
|
||||
StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
|
||||
const tflite::SubGraphT& subgraph, ArrayRef<int32_t> output_indices) {
|
||||
const tflite::SubGraphT& subgraph, ArrayRef<int32_t> input_indices,
|
||||
ArrayRef<int32_t> output_indices) {
|
||||
// Create a map from tensor index to defining op.
|
||||
absl::flat_hash_map<int32_t, const tflite::OperatorT*> defining_op;
|
||||
for (const auto& op : subgraph.operators) {
|
||||
for (int32_t output : op->outputs) {
|
||||
defining_op[output] = op.get();
|
||||
if (!llvm::is_contained(input_indices, output)) {
|
||||
defining_op[output] = op.get();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -718,18 +751,40 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
const std::vector<std::string>& op_names,
|
||||
const std::vector<std::string>& func_names,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
|
||||
Location base_loc, Builder builder,
|
||||
const std::vector<std::string>& ordered_output_arrays, bool is_entry_point,
|
||||
Location base_loc, Builder builder, bool is_entry_point,
|
||||
bool use_external_constant,
|
||||
const std::vector<std::string>& ordered_input_arrays,
|
||||
const std::vector<std::string>& ordered_output_arrays,
|
||||
bool experimental_prune_unreachable_nodes_unconditionally) {
|
||||
llvm::SmallVector<mlir::Type, 2> ret_types;
|
||||
llvm::SmallVector<mlir::Type, 4> input_types;
|
||||
|
||||
auto func_loc = mlir::NameLoc::get(builder.getIdentifier(name), base_loc);
|
||||
|
||||
// Construct function type
|
||||
for (auto input : subgraph.inputs) {
|
||||
auto& tensor = *subgraph.tensors.at(input);
|
||||
std::vector<int> func_inputs = subgraph.inputs;
|
||||
if (is_entry_point && !ordered_input_arrays.empty()) {
|
||||
if (!experimental_prune_unreachable_nodes_unconditionally) {
|
||||
// TODO(b/149922113): Resolve input-arrays/pruning flags interaction.
|
||||
return errors::InvalidArgument(
|
||||
"input-arrays should be used with experimental pruning flag");
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(func_inputs,
|
||||
GetTensorIndices(subgraph, ordered_input_arrays));
|
||||
}
|
||||
|
||||
// Add state variables to inputs.
|
||||
absl::flat_hash_set<int32_t> input_index_set(func_inputs.begin(),
|
||||
func_inputs.end());
|
||||
for (int i = 0; i < subgraph.tensors.size(); i++) {
|
||||
auto& tensor = *subgraph.tensors.at(i);
|
||||
if (tensor.is_variable && !input_index_set.contains(i)) {
|
||||
func_inputs.emplace_back(i);
|
||||
input_index_set.insert(i);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto input_or_variable : func_inputs) {
|
||||
auto& tensor = *subgraph.tensors.at(input_or_variable);
|
||||
// TODO(b/138222071) Graph inputs must have static shape per the exporter,
|
||||
// but we cannot differentiate scalars from unranked tensors.
|
||||
// Here we reverse the default assumption that shape = [] means unranked.
|
||||
@ -753,9 +808,11 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
}
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto func_outputs,
|
||||
GetOutputTensorIndices(subgraph, base_loc, ordered_output_arrays));
|
||||
std::vector<int> func_outputs = subgraph.outputs;
|
||||
if (is_entry_point && !ordered_output_arrays.empty()) {
|
||||
TF_ASSIGN_OR_RETURN(func_outputs,
|
||||
GetTensorIndices(subgraph, ordered_output_arrays));
|
||||
}
|
||||
|
||||
for (auto output : func_outputs) {
|
||||
bool is_constant = !is_op_output[output];
|
||||
@ -782,8 +839,8 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
Value maybe_optional_arg_marker = nullptr;
|
||||
|
||||
// Get or construct MLIR values for each input
|
||||
for (int i = 0, e = subgraph.inputs.size(); i < e; i++) {
|
||||
auto input_tensor = subgraph.inputs[i];
|
||||
for (int i = 0, e = func_inputs.size(); i < e; i++) {
|
||||
auto input_tensor = func_inputs[i];
|
||||
const auto& tensor = *subgraph.tensors.at(input_tensor);
|
||||
auto loc = TensorLoc(tensor, builder, base_loc);
|
||||
if (vals_map[input_tensor]) {
|
||||
@ -806,9 +863,9 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
// Set tf.entry_function attribute
|
||||
if (is_entry_point) {
|
||||
llvm::SmallVector<mlir::NamedAttribute, 2> attributes;
|
||||
if (!subgraph.inputs.empty()) {
|
||||
if (!func_inputs.empty()) {
|
||||
attributes.push_back(BuildTFEntryFunctionAttribute(
|
||||
subgraph, &builder, "inputs", subgraph.inputs));
|
||||
subgraph, &builder, "inputs", func_inputs));
|
||||
}
|
||||
if (!func_outputs.empty()) {
|
||||
attributes.push_back(BuildTFEntryFunctionAttribute(
|
||||
@ -820,7 +877,7 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
|
||||
if (experimental_prune_unreachable_nodes_unconditionally) {
|
||||
TF_ASSIGN_OR_RETURN(pruned_subgraph_ops,
|
||||
PruneSubgraph(subgraph, func_outputs));
|
||||
PruneSubgraph(subgraph, func_inputs, func_outputs));
|
||||
}
|
||||
|
||||
// Construct MLIR operators from TFLite operators
|
||||
@ -859,6 +916,18 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
}
|
||||
}
|
||||
|
||||
// Intermediate tensors for tfl.lstm are used to carry quantization range
|
||||
// in their types, so we only need and extract their types.
|
||||
std::vector<mlir::TensorType> intermediate_types;
|
||||
intermediate_types.reserve(5);
|
||||
for (auto intermediate : op->intermediates) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto type, GetTensorType(*subgraph.tensors[intermediate], builder,
|
||||
/*shapeless_are_scalars=*/true,
|
||||
/*is_constant=*/true));
|
||||
intermediate_types.emplace_back(type);
|
||||
}
|
||||
|
||||
// The NameLoc corresponding to the name of the first output tensor
|
||||
auto op_loc =
|
||||
op->outputs.empty()
|
||||
@ -868,8 +937,8 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
// to a valid Value
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto* mlir_op,
|
||||
ConvertOp(*op, vals_map, maybe_optional_arg_marker, op_names,
|
||||
func_names, subgraph.tensors, op_loc, op_builder));
|
||||
ConvertOp(*op, vals_map, intermediate_types, maybe_optional_arg_marker,
|
||||
op_names, func_names, subgraph.tensors, op_loc, op_builder));
|
||||
|
||||
// Add the results to the value maps. There are two cases: 1. the result
|
||||
// tensor does not have min/max values, the original op result is used
|
||||
@ -931,8 +1000,9 @@ std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
|
||||
|
||||
OwningModuleRef tflite::FlatBufferToMlir(
|
||||
absl::string_view buffer, MLIRContext* context, Location base_loc,
|
||||
const std::vector<std::string>& ordered_output_arrays,
|
||||
bool use_external_constant,
|
||||
const std::vector<std::string>& ordered_input_arrays,
|
||||
const std::vector<std::string>& ordered_output_arrays,
|
||||
bool experimental_prune_unreachable_nodes_unconditionally) {
|
||||
auto model_ptr =
|
||||
FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
|
||||
@ -971,33 +1041,25 @@ OwningModuleRef tflite::FlatBufferToMlir(
|
||||
builder.getStringAttr(model->description));
|
||||
}
|
||||
|
||||
if (!ordered_output_arrays.empty() && model->subgraphs.size() > 1) {
|
||||
// TODO(b/141485522): support more than one subgraph.
|
||||
return emitError(base_loc,
|
||||
"ordered_output_arrays does not support more than one "
|
||||
"subgraph yet"),
|
||||
nullptr;
|
||||
}
|
||||
|
||||
for (auto e : llvm::enumerate(model->subgraphs)) {
|
||||
auto& subgraph = e.value();
|
||||
std::string name = SubgraphName(e.index(), *subgraph);
|
||||
auto func_or_error = ConvertSubgraph(
|
||||
*subgraph, name, operator_names, func_names, model->buffers, base_loc,
|
||||
// Only the entry point needs pseudo_input_ops
|
||||
builder,
|
||||
// TODO(b/131175224,b/132239787) Support multiple entry points
|
||||
builder, ordered_output_arrays,
|
||||
/*is_entry_point=*/e.index() == 0,
|
||||
/*use_external_constant=*/use_external_constant,
|
||||
/*use_external_constant=*/use_external_constant, ordered_input_arrays,
|
||||
ordered_output_arrays,
|
||||
experimental_prune_unreachable_nodes_unconditionally);
|
||||
if (!func_or_error.ok()) {
|
||||
return emitError(base_loc, "could not translate function ")
|
||||
<< subgraph->name,
|
||||
<< subgraph->name << ": "
|
||||
<< func_or_error.status().error_message(),
|
||||
nullptr;
|
||||
}
|
||||
module.push_back(func_or_error.ConsumeValueOrDie());
|
||||
}
|
||||
// TFLite subgraphs do not necessarily have names,
|
||||
|
||||
return OwningModuleRef(module);
|
||||
}
|
||||
@ -1012,17 +1074,24 @@ static OwningModuleRef FlatBufferFileToMlirTrans(
|
||||
auto loc =
|
||||
mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context);
|
||||
|
||||
// Parses output_arrays_order from command line option.
|
||||
// Parses input/output names from command line options.
|
||||
std::vector<std::string> inputs;
|
||||
std::vector<std::string> outputs;
|
||||
if (!tensorflow::ParseOutputArrayInfo(output_arrays_string, &outputs).ok()) {
|
||||
// Use output parser since we only have tensor names.
|
||||
if (!tensorflow::ParseOutputArrayInfo(input_arrays_flag, &inputs).ok()) {
|
||||
return emitError(loc, "parsing input array info failed ")
|
||||
<< input_arrays_flag,
|
||||
nullptr;
|
||||
}
|
||||
if (!tensorflow::ParseOutputArrayInfo(output_arrays_flag, &outputs).ok()) {
|
||||
return emitError(loc, "parsing output array info failed ")
|
||||
<< output_arrays_string,
|
||||
<< output_arrays_flag,
|
||||
nullptr;
|
||||
}
|
||||
|
||||
return tflite::FlatBufferToMlir(
|
||||
absl::string_view(input->getBufferStart(), input->getBufferSize()),
|
||||
context, loc, outputs, use_external_constant,
|
||||
context, loc, use_external_constant, inputs, outputs,
|
||||
experimental_prune_unreachable_nodes_unconditionally);
|
||||
}
|
||||
|
||||
|
@ -35,9 +35,9 @@ namespace tflite {
|
||||
// are not ancestors of the output nodes will be pruned.
|
||||
mlir::OwningModuleRef FlatBufferToMlir(
|
||||
absl::string_view buffer, mlir::MLIRContext* context,
|
||||
mlir::Location base_loc,
|
||||
const std::vector<std::string>& ordered_output_arrays,
|
||||
bool use_external_constant = false,
|
||||
mlir::Location base_loc, bool use_external_constant = false,
|
||||
const std::vector<std::string>& ordered_input_arrays = {},
|
||||
const std::vector<std::string>& ordered_output_arrays = {},
|
||||
bool experimental_prune_unreachable_nodes_unconditionally = false);
|
||||
} // namespace tflite
|
||||
|
||||
|
@ -44,6 +44,7 @@ llvm::Optional<tflite::BuiltinOperator> GetBuiltinOpCode(Operation *mlir_op);
|
||||
llvm::Optional<flatbuffers::Offset<tflite::Operator>> CreateFlatBufferOperator(
|
||||
Operation *mlir_op, uint32_t opcode_index,
|
||||
const std::vector<int32_t> &operands, const std::vector<int32_t> &results,
|
||||
const std::vector<int32_t> &intermediates,
|
||||
flatbuffers::FlatBufferBuilder *fbb);
|
||||
|
||||
// Populates the array of mlir::NamedAttributes corresponding to the given
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user