Merge master
commit 0c2afe176a
Changed files:
.bazelrc
.bazelversion
.github/ISSUE_TEMPLATE/
  00-bug-issue.md
  00-bug-performance-issue.md
  10-build-installation-issue.md
  20-documentation-issue.md
  30-feature-request.md
  40-tflite-op-request.md
  50-other-issues.md
  60-tflite-converter-issue.md
  80-performance-issue.md
.pylintrc
WORKSPACE
configure.py
tensorflow/
  BUILD
  api_template.__init__.py
  api_template_v1.__init__.py
  c/
    BUILD
    c_api_experimental.cc
    c_api_experimental.h
    c_api_experimental_test.cc
    c_api_test.cc
    eager/
      BUILD
      c_api.cc
      c_api_debug.cc
      c_api_experimental.cc
      c_api_experimental.h
      c_api_experimental_test.cc
      c_api_internal.h
      c_api_test.cc
      c_api_test_util.cc
      c_api_test_util.h
      custom_device_test.cc
      operation_interface.cc
      operation_interface.h
      tensor_handle_interface.h
    experimental/filesystem/
      kernels_test.cc
    tf_tensor.cc
    tf_tensor_internal.h
  cc/
  compat_template.__init__.py
  compat_template_v1.__init__.py
  compiler/
    aot/
    jit/
      BUILD
      compilability_check_util.cc
      encapsulate_subgraphs_pass.cc
      flags.cc
      kernels/
      mark_for_compilation_pass.cc
      xla_device.cc
      xla_kernel_creator.cc
      xla_kernel_creator.h
      xla_kernel_creator_test.cc
      xla_kernel_creator_util.cc
    mlir/
      BUILD
      glob_lit_test.bzl
      lite/
        BUILD
        common/
        converter_gen.cc
        flatbuffer_translate.cc
        ir/
        mlir_tflite_runner.cc
        python/
        quantization/
        tests/
.bazelrc (29 changes)
@@ -69,6 +69,7 @@
# rbe_linux_py3: Linux Python 3 RBE config
#
# rbe_win_py37: Windows Python 3.7 RBE config
# rbe_win_py38: Windows Python 3.8 RBE config
#
# tensorflow_testing_rbe_linux: RBE options to use RBE with tensorflow-testing project on linux
# tensorflow_testing_rbe_win: RBE options to use RBE with tensorflow-testing project on windows
@@ -221,6 +222,11 @@ build --define=grpc_no_ares=true
# archives in -whole_archive -no_whole_archive.
build --noincompatible_remove_legacy_whole_archive

# These are bazel 2.0's incompatible flags. Tensorflow needs to use bazel 2.0.0
# to use cc_shared_library, as part of the Tensorflow Build Improvements RFC:
# https://github.com/tensorflow/community/pull/179
build --noincompatible_prohibit_aapt1

# Modular TF build options
build:dynamic_kernels --define=dynamic_loaded_kernels=true
build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS
@@ -313,22 +319,26 @@ build:xla --define=with_xla_support=true
# BEGIN TF REMOTE BUILD EXECUTION OPTIONS
# Options when using remote execution
# WARNING: THESE OPTIONS WONT WORK IF YOU DO NOT HAVE PROPER AUTHENTICATION AND PERMISSIONS

# Flag to enable remote config
common --experimental_repo_remote_exec

build:rbe --action_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1
build:rbe --auth_enabled=true
build:rbe --auth_scope=https://www.googleapis.com/auth/cloud-source-tools
build:rbe --google_default_credentials
build:rbe --bes_backend=buildeventservice.googleapis.com
build:rbe --bes_best_effort=false
build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations"
build:rbe --bes_timeout=600s
build:rbe --define=EXECUTOR=remote
build:rbe --distinct_host_configuration=false
build:rbe --flaky_test_attempts=3
build:rbe --jobs=200
build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com
build:rbe --remote_timeout=3600
build:rbe --spawn_strategy=remote,worker,standalone,local
test:rbe --test_env=USER=anon

build:rbe --distinct_host_configuration=false
# Attempt to minimize the amount of data transfer between bazel and the remote
# workers:
build:rbe --remote_download_toplevel

build:rbe_linux --config=rbe
build:rbe_linux --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin"
@@ -355,7 +365,7 @@ build:rbe_cpu_linux --platforms="@org_tensorflow//third_party/toolchains:rbe_ubu
build:rbe_linux_cuda_nvcc --config=rbe_linux
build:rbe_linux_cuda_nvcc --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain"
build:rbe_linux_cuda_nvcc --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010,@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010-gpu"
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010,@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
build:rbe_linux_cuda_nvcc --host_platform="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
build:rbe_linux_cuda_nvcc --platforms="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/cuda10.1-cudnn7"
@@ -392,6 +402,7 @@ build:rbe_win --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe

# TODO(gunan): Remove once we use MSVC 2019 with latest patches.
build:rbe_win --define=override_eigen_strong_inline=true
build:rbe_win --jobs=500

build:rbe_win_py37 --config=rbe
build:rbe_win_py37 --repo_env=PYTHON_BIN_PATH=C:\\Python37\\python.exe
@@ -399,6 +410,12 @@ build:rbe_win_py37 --repo_env=PYTHON_LIB_PATH=C:\\Python37\\lib\\site-packages
build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py37
build:rbe_win_py37 --python_path=C:\\Python37\\python.exe

build:rbe_win_py38 --config=rbe
build:rbe_win_py38 --repo_env=PYTHON_BIN_PATH=C:\\Python38\\python.exe
build:rbe_win_py38 --repo_env=PYTHON_LIB_PATH=C:\\Python38\\lib\\site-packages
build:rbe_win_py38 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py38
build:rbe_win_py38 --python_path=C:\\Python38\\python.exe

# These you may need to change for your own GCP project.
build:tensorflow_testing_rbe --project_id=tensorflow-testing
common:tensorflow_testing_rbe_linux --remote_instance_name=projects/tensorflow-testing/instances/default_instance
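Each of these stanzas only takes effect when its config name is requested explicitly; for example, a Windows Python 3.8 RBE build would be invoked as `bazel build --config=rbe_win_py38 <target>`, which stacks the py38 options on top of the shared `--config=rbe` block above.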
.bazelversion
@@ -1 +1 @@
1.2.1
2.0.0
.github/ISSUE_TEMPLATE/00-bug-issue.md (new file, 44 lines)
@@ -0,0 +1,44 @@
---
name: Bug Issue
about: Use this template for reporting a bug
labels: 'type:bug'

---

<em>Please make sure that this is a bug. As per our
[GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md),
we only address code/doc bugs, performance issues, feature requests and
build/installation issues on GitHub. tag:bug_template</em>

**System information**
- Have I written custom code (as opposed to using a stock
example script provided in TensorFlow):
- OS Platform and Distribution (e.g.,
Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
the issue happens on mobile device:
- TensorFlow installed from (source or
binary): - TensorFlow version (use command below):
- Python version: - Bazel
version (if compiling from source):
- GCC/Compiler version (if compiling from
source):
- CUDA/cuDNN version: - GPU model and memory:

You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`

**Describe the current behavior**

**Describe the expected behavior**

**Standalone code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.

**Other info / logs** Include any logs or source code that would be helpful to
diagnose the problem. If including tracebacks, please include the full
traceback. Large logs and files should be attached.
.github/ISSUE_TEMPLATE/00-bug-performance-issue.md (deleted file)
@@ -1,35 +0,0 @@
---
name: Bug/Performance Issue
about: Use this template for reporting a bug or a performance issue.

---

<em>Please make sure that this is a bug. As per our [GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md), we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:bug_template</em>

**System information**
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary):
- TensorFlow version (use command below):
- Python version:
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version:
- GPU model and memory:

You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`

**Describe the current behavior**

**Describe the expected behavior**

**Code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate the problem.

**Other info / logs**
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
.github/ISSUE_TEMPLATE/10-build-installation-issue.md
@@ -1,6 +1,7 @@
---
name: Build/Installation Issue
about: Use this template for build/installation issues
labels: 'type:build/install'

---

.github/ISSUE_TEMPLATE/20-documentation-issue.md
@@ -1,10 +1,11 @@
---
name: Documentation Issue
about: Use this template for documentation related
about: Use this template for documentation related issues
labels: 'type:docs'

---


Thank you for submitting a TensorFlow documentation issue. Per our GitHub
policy, we only address code/doc bugs, performance issues, feature requests, and
build/installation issues on GitHub.

.github/ISSUE_TEMPLATE/30-feature-request.md (4 changes)
@@ -1,6 +1,6 @@
---
name: Feature Request
about: Use this template for raising a feature request
name: Feature Request about: Use this template for raising a feature request
labels: 'type:feature'

---

.github/ISSUE_TEMPLATE/40-tflite-op-request.md (12 changes)
@@ -1,10 +1,10 @@
---
name: TensorFlow Lite Op Request
about: Use this template for reporting ops you are using or missing.
about: Use this template for reporting Lite ops you are using or missing
labels: 'comp:lite'

---


**System information**
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- TensorFlow installed from (source or binary):
@@ -17,8 +17,14 @@ about: Use this template for reporting ops you are using or missing.
# Copy and paste here
```

**Standalone code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.

Also, please include a link to a GraphDef or the model if possible.

**Any other info / logs**

Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
Include any logs or source code that would be helpful to diagnose the problem.
If including tracebacks, please include the full traceback. Large logs and files
should be attached.
.github/ISSUE_TEMPLATE/50-other-issues.md (1 change)
@@ -1,6 +1,7 @@
---
name: Other Issues
about: Use this template for any other non-support related issues
labels: 'type:others'

---

.github/ISSUE_TEMPLATE/60-tflite-converter-issue.md
@@ -1,6 +1,7 @@
---
name: TensorFlow Lite New Converter Issue
about: Use this template for reporting issues during model conversion to TFLite.
about: Use this template for reporting issues during model conversion to TFLite
labels: 'TFLiteConverter'

---

@@ -12,6 +13,7 @@ about: Use this template for reporting issues during model conversion to TFLite.


**Command used to run the converter or code if you’re using the Python API**
If possible, please share a link to Colab/Jupyter/any notebook.

```
# Copy and paste here the exact command
.github/ISSUE_TEMPLATE/80-performance-issue.md (new file, 45 lines)
@@ -0,0 +1,45 @@
---
name: Performance Issue
about: Use this template for reporting a performance issue
labels: 'type:performance'

---

<em>Please make sure that this is an issue related to performance of TensorFlow.
As per our
[GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md),
we only address code/doc bugs, performance issues, feature requests and
build/installation issues on GitHub. tag:performance_template</em>

**System information**
- Have I written custom code (as opposed to using a stock
example script provided in TensorFlow):
- OS Platform and Distribution (e.g.,
Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
the issue happens on mobile device:
- TensorFlow installed from (source or
binary): - TensorFlow version (use command below):
- Python version: - Bazel
version (if compiling from source):
- GCC/Compiler version (if compiling from
source):
- CUDA/cuDNN version: - GPU model and memory:

You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`

**Describe the current behavior**

**Describe the expected behavior**

**Standalone code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.

**Other info / logs** Include any logs or source code that would be helpful to
diagnose the problem. If including tracebacks, please include the full
traceback. Large logs and files should be attached.
WORKSPACE (29 changes)
@@ -1,13 +1,11 @@
workspace(name = "org_tensorflow")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("//third_party:repo.bzl", "tf_http_archive")

tf_http_archive(
http_archive(
    name = "io_bazel_rules_closure",
    sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9",
    strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149",
    patch_file = "@org_tensorflow//third_party:rules_closure.patch",
    urls = [
        "https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",
        "https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",  # 2019-06-13
@@ -115,3 +113,28 @@ http_archive(
        "https://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip",
    ],
)

# Required for dependency @com_github_grpc_grpc

load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")

grpc_deps()

load(
    "@build_bazel_rules_apple//apple:repositories.bzl",
    "apple_rules_dependencies",
)

apple_rules_dependencies()

load(
    "@build_bazel_apple_support//lib:repositories.bzl",
    "apple_support_dependencies",
)

apple_support_dependencies()

load("@upb//bazel:repository_defs.bzl", "bazel_version_repository")

bazel_version_repository(name = "bazel_version")
configure.py
@@ -49,8 +49,8 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
_TF_WORKSPACE_ROOT = ''
_TF_BAZELRC = ''
_TF_CURRENT_BAZEL_VERSION = None
_TF_MIN_BAZEL_VERSION = '1.2.1'
_TF_MAX_BAZEL_VERSION = '1.2.1'
_TF_MIN_BAZEL_VERSION = '2.0.0'
_TF_MAX_BAZEL_VERSION = '2.0.0'

NCCL_LIB_PATHS = [
    'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
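The hunk above moves the supported Bazel range from exactly 1.2.1 to exactly 2.0.0; because _TF_MIN_BAZEL_VERSION equals _TF_MAX_BAZEL_VERSION, configure.py effectively pins a single version. A minimal Python sketch of such a range gate (the helper names are illustrative, not configure.py's actual functions):

```python
# Illustrative sketch of the [min, max] version gate configure.py enforces;
# these helpers are hypothetical, not configure.py's real implementation.
def _version_tuple(version):
    # '2.0.0' -> (2, 0, 0), so the comparison is numeric, not lexicographic.
    return tuple(int(part) for part in version.split('.'))

def bazel_version_ok(current, min_version='2.0.0', max_version='2.0.0'):
    return (_version_tuple(min_version)
            <= _version_tuple(current)
            <= _version_tuple(max_version))

assert bazel_version_ok('2.0.0')
assert not bazel_version_ok('1.2.1')  # rejected after this commit
```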
tensorflow/BUILD
@@ -187,6 +187,12 @@ config_setting(
    visibility = ["//visibility:public"],
)

config_setting(
    name = "fuchsia",
    values = {"cpu": "fuchsia"},
    visibility = ["//visibility:public"],
)

config_setting(
    name = "ios_x86_64",
    values = {
@@ -448,19 +454,66 @@ config_setting(
    visibility = ["//visibility:public"],
)

# Specifies via a config setting if this is a mobile build or not, makes
# it easier to combine settings later.
selects.config_setting_group(
    name = "mobile",
    match_any = [
        ":android",
        ":chromiumos",
        ":emscripten",
        ":ios",
    ],
)

config_setting(
    name = "lite_protos_legacy",
    values = {"define": "TENSORFLOW_PROTOS=lite"},
    visibility = ["//visibility:private"],
)

config_setting(
    name = "full_protos",
    values = {"define": "TENSORFLOW_PROTOS=full"},
    visibility = ["//visibility:public"],
)

selects.config_setting_group(
    name = "lite_protos",
    match_any = [":lite_protos_legacy"],
)

selects.config_setting_group(
    name = "mobile_lite_protos",
    match_all = [
        ":lite_protos",
        ":mobile",
    ],
)

selects.config_setting_group(
    name = "mobile_full_protos",
    match_all = [
        ":full_protos",
        ":mobile",
    ],
)

# DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST!
# Instead, please use public APIs or public build rules TF provides.
# If you need functionality that is not exposed, we will work with you to expand our public APIs.
package_group(
    name = "internal",
    packages = [
        # To pass open source testing in the pip Kokoros.
        "//bazel_pip/tensorflow/...",
        "//learning/brain/swift/x10/...",
        "//perftools/accelerators/xprof/api/...",
        "//third_party/py/autograph/...",
        "//third_party/swift/tensorflow/x10/...",
        "//tensorflow/...",
        "//tensorflow_estimator/python/estimator/...",
        "//tensorflow_models/official/...",
        "//third_party/py/autograph/...",
        "//third_party/swift/tensorflow/x10/...",
    ],
)
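For context, `selects.config_setting_group` (from bazel-skylib) produces a target that behaves like an AND/OR composition of plain `config_setting`s and can be matched directly in a `select()`. A hedged Starlark sketch of a consumer (the rule and dependency names here are hypothetical, not targets added by this commit):

```starlark
# Hypothetical BUILD rule choosing a proto implementation via the new groups.
cc_library(
    name = "proto_support",
    deps = select({
        "//tensorflow:mobile_lite_protos": [":lite_proto_impl"],
        "//tensorflow:mobile_full_protos": [":full_proto_impl"],
        "//conditions:default": [":full_proto_impl"],
    }),
)
```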
@@ -494,8 +547,8 @@ cc_library(
    name = "grpc",
    visibility = ["//visibility:public"],
    deps = select({
        ":linux_s390x": ["@grpc//:grpc_unsecure"],
        "//conditions:default": ["@grpc"],
        ":linux_s390x": ["@com_github_grpc_grpc//:grpc_unsecure"],
        "//conditions:default": ["@com_github_grpc_grpc//:grpc"],
    }),
)

@@ -503,8 +556,8 @@ cc_library(
    name = "grpc++",
    visibility = ["//visibility:public"],
    deps = select({
        ":linux_s390x": ["@grpc//:grpc++_unsecure"],
        "//conditions:default": ["@grpc//:grpc++"],
        ":linux_s390x": ["@com_github_grpc_grpc//:grpc++_unsecure"],
        "//conditions:default": ["@com_github_grpc_grpc//:grpc++"],
    }),
)

@@ -589,6 +642,7 @@ tf_cc_shared_object(
        "//tensorflow/core:gpu_runtime_impl",
        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
        "//tensorflow/core:lib_internal_impl",
        "//tensorflow/core/profiler:profiler_impl",
        "//tensorflow/stream_executor:stream_executor_impl",
        "//tensorflow:tf_framework_version_script.lds",
    ] + tf_additional_binary_deps(),
@@ -908,7 +962,6 @@ py_library(
        "//conditions:default": [":tf_python_api_gen_v1"],
    }) + [
        ":root_init_gen",
        ":virtual_root_init_gen",
        "//tensorflow/python/keras/api:keras_python_api_gen",
        "//tensorflow/python/keras/api:keras_python_api_gen_compat_v1",
        "//tensorflow/python/keras/api:keras_python_api_gen_compat_v2",
tensorflow/api_template.__init__.py
@@ -35,9 +35,11 @@ import inspect as _inspect
import logging as _logging
import os as _os
import site as _site
import six as _six
import sys as _sys

from tensorflow.python.tools import module_util as _module_util
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader

# API IMPORTS PLACEHOLDER

@@ -69,13 +71,13 @@ except ImportError:
  _logging.warning(
      "Limited tf.summary API due to missing TensorBoard installation.")

try:
  from tensorflow_estimator.python.estimator.api._v2 import estimator
  _current_module.__path__ = (
      [_module_util.get_parent_dir(estimator)] + _current_module.__path__)
  setattr(_current_module, "estimator", estimator)
except ImportError:
  pass
# Lazy-load estimator.
_estimator_module = "tensorflow_estimator.python.estimator.api._v2.estimator"
estimator = _LazyLoader("estimator", globals(), _estimator_module)
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
if _module_dir:
  _current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "estimator", estimator)

try:
  from .python.keras.api._v2 import keras
@@ -85,6 +87,13 @@ try:
except ImportError:
  pass

# Explicitly import lazy-loaded modules to support autocompletion.
# pylint: disable=g-import-not-at-top
if not _six.PY2:
  import typing as _typing
  if _typing.TYPE_CHECKING:
    from tensorflow_estimator.python.estimator.api._v2 import estimator
# pylint: enable=g-import-not-at-top

# Enable TF2 behaviors
from tensorflow.python.compat import v2_compat as _compat  # pylint: disable=g-import-not-at-top
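The change above stops importing tensorflow_estimator eagerly and instead registers it under a `_LazyLoader`, so `import tensorflow` no longer pays the estimator import cost up front. A simplified sketch of the pattern (this mirrors, but is not, the real `tensorflow.python.util.lazy_loader.LazyLoader`):

```python
import importlib
import types

class LazyLoader(types.ModuleType):
    """Defers importing a module until the first attribute access."""

    def __init__(self, local_name, parent_globals, import_name):
        self._local_name = local_name
        self._parent_globals = parent_globals
        super(LazyLoader, self).__init__(import_name)

    def _load(self):
        module = importlib.import_module(self.__name__)
        # Replace the shim in the parent namespace so later lookups are direct.
        self._parent_globals[self._local_name] = module
        self.__dict__.update(module.__dict__)
        return module

    def __getattr__(self, item):
        # Only reached when the attribute is not yet cached: import now.
        return getattr(self._load(), item)

# Usage mirroring the template: nothing is imported until first use.
json = LazyLoader("json", globals(), "json")
print(json.dumps({"lazy": True}))
```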
tensorflow/api_template_v1.__init__.py
@@ -22,12 +22,14 @@ import distutils as _distutils
import inspect as _inspect
import os as _os
import site as _site
import six as _six
import sys as _sys

# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
from tensorflow.python.tools import module_util as _module_util
from tensorflow.python.platform import tf_logging as _logging
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader

# API IMPORTS PLACEHOLDER

@@ -64,13 +66,14 @@ elif _tf_api_dir not in __path__:
# reexport_tf_summary can get compat from sys.modules. Only needed if using
# lazy loading.
_current_module.compat.v2  # pylint: disable=pointless-statement
try:
  from tensorflow_estimator.python.estimator.api._v1 import estimator
  _current_module.__path__ = (
      [_module_util.get_parent_dir(estimator)] + _current_module.__path__)
  setattr(_current_module, "estimator", estimator)
except ImportError:
  pass

# Lazy-load estimator.
_estimator_module = "tensorflow_estimator.python.estimator.api._v1.estimator"
estimator = _LazyLoader("estimator", globals(), _estimator_module)
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
if _module_dir:
  _current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "estimator", estimator)

try:
  from .python.keras.api._v1 import keras
@@ -80,6 +83,13 @@ try:
except ImportError:
  pass

# Explicitly import lazy-loaded modules to support autocompletion.
# pylint: disable=g-import-not-at-top
if not _six.PY2:
  import typing as _typing
  if _typing.TYPE_CHECKING:
    from tensorflow_estimator.python.estimator.api._v1 import estimator
# pylint: enable=g-import-not-at-top

from tensorflow.python.util.lazy_loader import LazyLoader  # pylint: disable=g-import-not-at-top
_CONTRIB_WARNING = """
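Both templates pair the lazy loader with a guarded `typing.TYPE_CHECKING` import, so IDEs and type checkers still resolve the `estimator` symbol even though the runtime import is deferred. A minimal sketch of the idiom (`mypackage.heavy` is a hypothetical stand-in for the estimator API):

```python
import typing as _typing

if _typing.TYPE_CHECKING:
    # Seen only by static analyzers; never executed at runtime, so the
    # expensive import stays deferred while autocompletion still works.
    from mypackage import heavy
```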
tensorflow/c/BUILD
@@ -57,6 +57,7 @@ filegroup(
    name = "pywrap_required_hdrs",
    srcs = [
        "c_api_internal.h",
        "python_api.h",
        "tf_status_helper.h",
        "tf_status_internal.h",
        "tf_tensor_internal.h",
@@ -98,6 +99,17 @@ tf_cuda_library(
    ],
)

filegroup(
    name = "pywrap_tf_session_hdrs",
    srcs = [
        "python_api.h",
    ],
    visibility = [
        "//tensorflow/core:__pkg__",
        "//tensorflow/python:__pkg__",
    ],
)

cc_library(
    name = "tf_attrtype",
    hdrs = ["tf_attrtype.h"],
@@ -524,6 +536,7 @@ tf_cuda_cc_test(
        "//tensorflow/core/kernels:array",
        "//tensorflow/core/kernels:control_flow_ops",
        "//tensorflow/core/kernels:math",
        "//tensorflow/core/platform:resource_loader",
    ],
)

@@ -536,6 +549,7 @@ tf_cc_test(
        "//tensorflow:macos": ["-headerpad_max_install_names"],
        "//conditions:default": [],
    }),
    tags = ["notsan"],  # b/149031034
    # We must ensure that the dependencies can be dynamically linked since
    # the shared library must be able to use core:framework.
    # linkstatic = tf_kernel_tests_linkstatic(),
@@ -634,12 +648,14 @@ tf_cuda_cc_test(
    deps = [
        ":c_api",
        ":kernels",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/kernels:ops_testutil",
        "//third_party/eigen3",
        "@com_google_absl//absl/container:inlined_vector",
    ],
)
tensorflow/c/c_api_experimental.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/platform.h"
@@ -519,72 +520,6 @@ TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
  return createTFEDequeue(ctx, TF_VARIANT, queue, status);
}

void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) {
  auto* status = TF_NewStatus();
  TF_Tensor* t = TFE_TensorHandleResolve(handle, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  tensorflow::Tensor dst;
  TF_CHECK_OK(TF_TensorToTensor(t, &dst));
  LOG(INFO) << dst.DebugString();

  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
}

void TFE_OpPrintDebugString(TFE_Op* op) {
  VLOG(1) << "TFE_OpPrintDebugString() over " << op;
  LOG(INFO) << op->operation.DebugString();
}

struct TFE_ExecuteOpNotification {
  TFE_ExecuteOpNotification() : status(TF_NewStatus(), TF_DeleteStatus) {}
  tensorflow::Notification n;
  std::unique_ptr<tensorflow::Thread> thread;
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status;
};

TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(TFE_Op* op,
                                                    TFE_TensorHandle** retvals,
                                                    int* num_retvals,
                                                    TF_Status* status) {
  TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification;

  n->thread.reset(op->operation.EagerContext().TFEnv()->StartThread(
      tensorflow::ThreadOptions(), "ExecuteOpThread",
      [op, retvals, num_retvals, n]() {
        TFE_Execute(op, retvals, num_retvals, n->status.get());
        n->n.Notify();
      }));

  return n;
}

void TFE_ExecuteOpNotificationWaitAndDelete(
    TFE_ExecuteOpNotification* notification, TF_Status* status) {
  if (notification == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "Passed in notification is a nullptr.");

    return;
  }
  if (notification->thread == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "Passed in notification didn't start a thread correctly. Cleaning up "
        "this notification. Please re-execute the operation to get a new "
        "notification.");

    delete notification;
    return;
  }

  notification->n.WaitForNotification();

  status->status = notification->status->status;

  delete notification;
}

void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
  status->status = tensorflow::errors::Internal(errMsg);
}
@@ -882,12 +817,15 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,

  const int num_inputs = input_shapes->num_items;
  NodeDef node_def;
  node_def.set_name(tfe_op->operation.Name());
  node_def.set_op(tfe_op->operation.Name());
  node_def.set_name(tfe_op->operation->Name());
  node_def.set_op(tfe_op->operation->Name());
  for (int i = 0; i < num_inputs; ++i) {
    node_def.add_input("dummy_input");
  }
  tfe_op->operation.Attrs().FillAttrValueMap(node_def.mutable_attr());
  tensorflow::down_cast<tensorflow::OperationInterface*>(
      tfe_op->operation.get())
      ->Attrs()
      .FillAttrValueMap(node_def.mutable_attr());

  const tensorflow::OpRegistrationData* op_reg_data;
  status->status =
tensorflow/c/c_api_experimental.h
@@ -188,31 +188,6 @@ TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session,
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
    TF_Session* session, int tensor_id, TF_Status* status);

// Prints `handle` in a human readable format to standard output for debugging.
TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString(
    TFE_TensorHandle* handle);

TF_CAPI_EXPORT extern void TFE_OpPrintDebugString(TFE_Op* op);

typedef struct TFE_ExecuteOpNotification TFE_ExecuteOpNotification;

// Allows invoking a kernel asynchronously, and explicitly returns a
// notification that can be waited upon. This always executes the kernel in a
// new thread.
// 1. `retvals` and `num_retvals` can only be consumed after
// `TFE_ExecuteOp` returns successfully. They shouldn't be used
// if the return is unsuccessful
// 2. These new APIs cannot be used together with the TFE context level async
// support.
TF_CAPI_EXPORT extern TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(
    TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
    TF_Status* status);

// Waits to complete the op execution, and cleans up the notification.
// Errors reported by op execution are set in `status`.
TF_CAPI_EXPORT extern void TFE_ExecuteOpNotificationWaitAndDelete(
    TFE_ExecuteOpNotification* notification, TF_Status* status);

TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
                                                      const char* errMsg);

tensorflow/c/c_api_experimental_test.cc
@@ -84,127 +84,6 @@ TEST(CAPI_EXPERIMENTAL, IsStateful) {
  EXPECT_EQ(id, 0);
}

TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Simple) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* m = TestMatrixTensorHandle();

  TFE_Op* matmul_op = MatMulOp(ctx, m, m);

  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;

  auto* r =
      TFE_ExecuteOpInNewThread(matmul_op, &retvals[0], &num_retvals, status);

  TFE_ExecuteOpNotificationWaitAndDelete(r, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  float product[4] = {0};
  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  EXPECT_EQ(7, product[0]);
  EXPECT_EQ(10, product[1]);
  EXPECT_EQ(15, product[2]);
  EXPECT_EQ(22, product[3]);

  TFE_DeleteOp(matmul_op);
  TFE_DeleteTensorHandle(m);

  TFE_DeleteTensorHandle(retvals[0]);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}

// Perform a send/recv test. Recv blocks, so they need to be executed
// asynchronously.
TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Blocking) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  // Returns a 2x2 float32 Tensor on the CPU, with data 1., 2., 3., 4.
  TFE_TensorHandle* m = TestMatrixTensorHandle();

  // Build a send op.
  TFE_Op* send_op = TFE_NewOp(ctx, "_Send", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(send_op, m, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  string tensor_name = "Tensor";
  TFE_OpSetAttrType(send_op, "T", TF_FLOAT);
  TFE_OpSetAttrString(send_op, "tensor_name", tensor_name.c_str(),
                      tensor_name.size());
  string send_device = "/job:localhost/replica:0/task:0/device:CPU:0";
  TFE_OpSetAttrString(send_op, "send_device", send_device.c_str(),
                      send_device.size());
  TFE_OpSetAttrInt(send_op, "send_device_incarnation", 1234);
  string recv_device = "/job:localhost/replica:0/task:0/device:CPU:0";
  TFE_OpSetAttrString(send_op, "recv_device", recv_device.c_str(),
                      recv_device.size());
  TFE_OpSetAttrBool(send_op, "client_terminated", true);

  // Build a recv op.
  TFE_Op* recv_op = TFE_NewOp(ctx, "_Recv", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_OpSetAttrType(recv_op, "tensor_type", TF_FLOAT);
  TFE_OpSetAttrString(recv_op, "tensor_name", tensor_name.c_str(),
                      tensor_name.size());
  TFE_OpSetAttrString(recv_op, "send_device", send_device.c_str(),
                      send_device.size());
  TFE_OpSetAttrInt(recv_op, "send_device_incarnation", 1234);
  TFE_OpSetAttrString(recv_op, "recv_device", recv_device.c_str(),
                      recv_device.size());
  TFE_OpSetAttrBool(recv_op, "client_terminated", true);

  TFE_TensorHandle* send_retvals;
  int send_num_retvals = 0;
  auto* send_result = TFE_ExecuteOpInNewThread(send_op, &send_retvals,
                                               &send_num_retvals, status);

  TFE_TensorHandle* recv_retvals[1] = {nullptr};
  int recv_num_retvals = 1;
  auto* recv_result = TFE_ExecuteOpInNewThread(recv_op, &recv_retvals[0],
                                               &recv_num_retvals, status);

  TFE_ExecuteOpNotificationWaitAndDelete(send_result, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_ExecuteOpNotificationWaitAndDelete(recv_result, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_Tensor* t = TFE_TensorHandleResolve(recv_retvals[0], status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  float product[4] = {0};
  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  EXPECT_EQ(1, product[0]);
  EXPECT_EQ(2, product[1]);
  EXPECT_EQ(3, product[2]);
  EXPECT_EQ(4, product[3]);

  TFE_DeleteOp(send_op);
  TFE_DeleteOp(recv_op);
  TFE_DeleteTensorHandle(m);

  TFE_DeleteTensorHandle(recv_retvals[0]);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}

class ShapeInferenceTest : public ::testing::Test {
 protected:
  ShapeInferenceTest()
tensorflow/c/c_api_test.cc
@@ -45,6 +45,8 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
@@ -193,8 +195,9 @@ TEST(CAPI, LibraryLoadFunctions) {
  {
    // Load the library.
    TF_Status* status = TF_NewStatus();
    TF_Library* lib =
        TF_LoadLibrary("tensorflow/c/test_op1.so", status);
    string lib_path = tensorflow::GetDataDependencyFilepath(
        tensorflow::io::JoinPath("tensorflow", "c", "test_op1.so"));
    TF_Library* lib = TF_LoadLibrary(lib_path.c_str(), status);
    TF_Code code = TF_GetCode(status);
    string status_msg(TF_Message(status));
    TF_DeleteStatus(status);
@@ -1350,9 +1353,9 @@ TEST_F(CApiColocationTest, ClearViaProto) {

TEST(CAPI, SavedModel) {
  // Load the saved model.
  const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123";
  const string saved_model_dir = tensorflow::io::JoinPath(
      tensorflow::testing::TensorFlowSrcRoot(), kSavedModel);
  const string saved_model_dir = tensorflow::GetDataDependencyFilepath(
      tensorflow::io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
                               "half_plus_two", "00000123"));
  TF_SessionOptions* opt = TF_NewSessionOptions();
  TF_Buffer* run_options = TF_NewBufferFromString("", 0);
  TF_Buffer* metagraph = TF_NewBuffer();
@@ -1426,9 +1429,9 @@ TEST(CAPI, SavedModel) {
}

TEST(CAPI, SavedModelNullArgsAreValid) {
  const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123";
  const string saved_model_dir = tensorflow::io::JoinPath(
      tensorflow::testing::TensorFlowSrcRoot(), kSavedModel);
  const string saved_model_dir = tensorflow::GetDataDependencyFilepath(
      tensorflow::io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
                               "half_plus_two", "00000123"));
  TF_SessionOptions* opt = TF_NewSessionOptions();
  TF_Status* s = TF_NewStatus();
  const char* tags[] = {tensorflow::kSavedModelTagServe};
tensorflow/c/eager/BUILD
@@ -2,6 +2,7 @@

load(
    "//tensorflow:tensorflow.bzl",
    "tf_cc_test",
    "tf_copts",
    "tf_cuda_cc_test",
    "tf_cuda_library",
@@ -27,6 +28,8 @@ tf_cuda_library(
        "c_api_debug.cc",
        "c_api_experimental.h",
        "c_api_internal.h",
        "operation_interface.cc",
        "operation_interface.h",
        "tensor_handle_interface.h",
    ],
    hdrs = ["c_api.h"],
@@ -55,6 +58,7 @@ tf_cuda_library(
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core/platform:casts",
        "//tensorflow/core/platform:errors",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/profiler/lib:traceme",
@@ -81,8 +85,6 @@ tf_cuda_library(
        "//tensorflow/core/distributed_runtime:remote_device",
        "//tensorflow/core/distributed_runtime:server_lib",
        "//tensorflow/core/distributed_runtime:worker_env",
        "//tensorflow/core/profiler/lib:profiler_lib",
        "//tensorflow/core/profiler/lib:profiler_session",
        "//tensorflow/core:gpu_runtime",
    ],
    alwayslink = 1,
@@ -93,6 +95,7 @@ filegroup(
    srcs = [
        "c_api_experimental.h",
        "c_api_internal.h",
        "operation_interface.h",
        "tensor_handle_interface.h",
    ],
    visibility = [
@@ -105,6 +108,7 @@ tf_cuda_library(
    name = "c_api_internal",
    srcs = [
        "c_api_experimental.h",
        "operation_interface.h",
        "tensor_handle_interface.h",
    ],
    hdrs = ["c_api_internal.h"],
@@ -129,7 +133,7 @@ tf_cuda_library(
        "//tensorflow/core/common_runtime/eager:eager_operation",
        "//tensorflow/core/common_runtime/eager:kernel_and_device",
        "//tensorflow/core/common_runtime/eager:tensor_handle",
        "//tensorflow/core/profiler/lib:profiler_session",
        "@com_google_absl//absl/container:fixed_array",
    ],
)

@@ -258,8 +262,6 @@ tf_cuda_library(
        "//tensorflow/core/distributed_runtime:remote_device",
        "//tensorflow/core/distributed_runtime:server_lib",
        "//tensorflow/core/distributed_runtime:worker_env",
        "//tensorflow/core/profiler/rpc:profiler_server",
        "//tensorflow/core/profiler/rpc/client:capture_profile",
        "//tensorflow/core:gpu_runtime",
    ],
    alwayslink = 1,
@@ -289,6 +291,27 @@ tf_cuda_cc_test(
    ],
)

tf_cc_test(
    name = "custom_device_test",
    size = "small",
    srcs = [
        "custom_device_test.cc",
    ],
    deps = [
        ":c_api",
        ":c_api_experimental",
        ":c_api_test_util",
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_test_util",
        "//tensorflow/cc/profiler",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "tape",
    hdrs = ["tape.h"],
@@ -301,7 +324,10 @@ cc_library(

filegroup(
    name = "headers",
    srcs = ["c_api.h"],
    srcs = [
        "c_api.h",
        "c_api_experimental.h",
    ],
    visibility = ["//tensorflow:__subpackages__"],
)

tensorflow/c/eager/c_api.cc
@@ -27,7 +27,6 @@ limitations under the License.
// clang-format on

#include "absl/algorithm/container.h"
#include "absl/container/fixed_array.h"
#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
@@ -44,6 +43,7 @@ limitations under the License.
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/platform.h"  // NOLINT
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/device_filters.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -94,15 +94,12 @@ using tensorflow::string;

namespace {

const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
  const tensorflow::OpDef* op_def = op->operation.OpDef();
  if (op_def) return op_def;
  status->status =
      tensorflow::OpDefForOp(op->operation.Name().c_str(), &op_def);
  return op_def;
}

bool IsCPU(const tensorflow::Device* d) {
bool IsCPU(
    absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> variant) {
  if (VariantDeviceIsCustom(variant)) {
    return false;
  }
  tensorflow::Device* d = absl::get<tensorflow::Device*>(variant);
  return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
}

@@ -265,9 +262,9 @@ tensorflow::Status GetReplacedFromExistingWorkers(
}

tensorflow::Status CreateRemoteContexts(
    const std::vector<string>& remote_workers, tensorflow::uint64 context_id,
    tensorflow::uint64 context_view_id, int keep_alive_secs,
    const tensorflow::ServerDef& server_def,
    TFE_Context* ctx, const std::vector<string>& remote_workers,
    tensorflow::uint64 context_id, tensorflow::uint64 context_view_id,
    int keep_alive_secs, const tensorflow::ServerDef& server_def,
    tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
    const bool lazy_copy_remote_function_inputs,
    const tensorflow::eager::CreateContextRequest& base_request) {
@@ -296,7 +293,7 @@ tensorflow::Status CreateRemoteContexts(
      continue;
    }

    tensorflow::eager::CreateContextRequest request(base_request);
    tensorflow::eager::CreateContextRequest request;
    tensorflow::eager::CreateContextResponse* response =
        new tensorflow::eager::CreateContextResponse();
    request.set_context_id(context_id);
@@ -304,6 +301,21 @@ tensorflow::Status CreateRemoteContexts(
    *request.mutable_server_def() = server_def;
    request.mutable_server_def()->set_job_name(parsed_name.job);
    request.mutable_server_def()->set_task_index(parsed_name.task);
    request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
        server_def.default_session_config());

    std::vector<bool> filtered_device_mask;
    ctx->context->FilterDevicesForRemoteWorkers(
        remote_worker, base_request.cluster_device_attributes(),
        &filtered_device_mask);
    DCHECK_EQ(filtered_device_mask.size(),
              base_request.cluster_device_attributes_size());
    for (int i = 0; i < filtered_device_mask.size(); i++) {
      if (filtered_device_mask[i]) {
        const auto& da = base_request.cluster_device_attributes(i);
        *request.add_cluster_device_attributes() = da;
      }
    }
    request.set_async(async);
    request.set_keep_alive_secs(keep_alive_secs);
    request.set_lazy_copy_remote_function_inputs(
@@ -325,13 +337,34 @@ tensorflow::Status CreateRemoteContexts(
}

tensorflow::Status UpdateRemoteContexts(
    const std::vector<string>& remote_workers, tensorflow::uint64 context_id,
    TFE_Context* ctx, const std::vector<string>& remote_workers,
    const std::vector<string>& added_workers,
    const std::vector<string>& removed_workers, tensorflow::uint64 context_id,
    tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
    tensorflow::eager::EagerClientCache* remote_eager_workers,
    const tensorflow::eager::CreateContextRequest& base_request) {
  int num_remote_workers = remote_workers.size();
  tensorflow::BlockingCounter counter(num_remote_workers);
  std::vector<tensorflow::Status> statuses(num_remote_workers);

  int cluster_device_count = base_request.cluster_device_attributes_size();
  std::unordered_set<string> added_or_removed(added_workers.begin(),
                                              added_workers.end());
  std::copy(removed_workers.begin(), removed_workers.end(),
            std::inserter(added_or_removed, added_or_removed.end()));
  // Whether each device is in the updated (added or removed) workers
  std::vector<bool> device_added_or_removed(cluster_device_count);
  for (int i = 0; i < base_request.cluster_device_attributes_size(); i++) {
    const auto& da = base_request.cluster_device_attributes().at(i);
    tensorflow::DeviceNameUtils::ParsedName pn;
    tensorflow::DeviceNameUtils::ParseFullName(da.name(), &pn);
    string task_name;
    tensorflow::DeviceNameUtils::GetTaskName(pn, &task_name);
    if (added_or_removed.find(task_name) != added_or_removed.end()) {
      device_added_or_removed[i] = true;
    }
  }

  for (int i = 0; i < num_remote_workers; i++) {
    const string& remote_worker = remote_workers[i];
    tensorflow::DeviceNameUtils::ParsedName parsed_name;
@@ -354,17 +387,42 @@ tensorflow::Status UpdateRemoteContexts(
      continue;
    }

    std::vector<bool> filtered_device_mask;
    ctx->context->FilterDevicesForRemoteWorkers(
        remote_worker, base_request.cluster_device_attributes(),
        &filtered_device_mask);
    DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);

    // If any of the devices that match the device filters are in the set of
    // added or removed workers, we must send a complete UpdateContextRequest.
    // Otherwise, only send a simple request to increment context view ID.
    std::vector<bool> added_or_removed_filtered_devices(cluster_device_count);
    std::transform(device_added_or_removed.begin(),
                   device_added_or_removed.end(), filtered_device_mask.begin(),
                   added_or_removed_filtered_devices.begin(),
                   std::logical_and<bool>());
    const bool full_update_request =
        std::accumulate(added_or_removed_filtered_devices.begin(),
                        added_or_removed_filtered_devices.end(), false,
                        std::logical_or<bool>());

    tensorflow::eager::UpdateContextRequest request;
    auto* response = new tensorflow::eager::UpdateContextResponse();

    *request.mutable_server_def() = server_def;
    request.mutable_server_def()->set_job_name(parsed_name.job);
    request.mutable_server_def()->set_task_index(parsed_name.task);
    for (const auto& da : base_request.cluster_device_attributes()) {
      *request.add_cluster_device_attributes() = da;
    }
    request.set_context_id(context_id);
    request.set_context_view_id(context_view_id);
    if (full_update_request) {
      *request.mutable_server_def() = server_def;
      request.mutable_server_def()->set_job_name(parsed_name.job);
      request.mutable_server_def()->set_task_index(parsed_name.task);
      request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
          server_def.default_session_config());
      for (int i = 0; i < cluster_device_count; i++) {
        if (filtered_device_mask[i]) {
          const auto& da = base_request.cluster_device_attributes(i);
          *request.add_cluster_device_attributes() = da;
        }
      }
    }

    eager_client->UpdateContextAsync(
        &request, response,
@@ -525,15 +583,12 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
  for (const auto& da : local_device_attributes) {
    *base_request.add_cluster_device_attributes() = da;
  }
  base_request.mutable_server_def()
      ->mutable_default_session_config()
      ->MergeFrom(server_def.default_session_config());

  // Initialize remote eager workers.
  // TODO(b/138847548) Create remote eager contexts in async mode by default.
  if (reset_context) {
    LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
        remote_workers, context_id, context_view_id, keep_alive_secs,
        ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
        server_def, remote_eager_workers.get(), context->Executor().Async(),
        context->LazyCopyFunctionRemoteInputs(), base_request));
  } else {
@@ -543,7 +598,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
    // we must set their context_view_id to the existing master's
    // context_view_id + 1.
    LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
        added_workers, context_id, context_view_id + 1, keep_alive_secs,
        ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
        server_def, remote_eager_workers.get(), context->Executor().Async(),
        context->LazyCopyFunctionRemoteInputs(), base_request));
    if (!existing_workers.empty()) {
@@ -553,8 +608,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
      }
    }
    LOG_AND_RETURN_IF_ERROR(UpdateRemoteContexts(
        existing_workers, context_id, context_view_id + 1, server_def,
        remote_eager_workers.get(), base_request));
        ctx, existing_workers, added_workers, removed_workers, context_id,
        context_view_id + 1, server_def, remote_eager_workers.get(),
        base_request));
  }
}
@ -709,6 +765,22 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
||||
"Invalid tensorflow.ServerDef protocol buffer");
|
||||
return;
|
||||
}
|
||||
if (server_def.has_cluster_device_filters()) {
|
||||
const auto& cdf = server_def.cluster_device_filters();
|
||||
for (const auto& jdf : cdf.jobs()) {
|
||||
const string& remote_prefix = "/job:" + jdf.name() + "/task:";
|
||||
for (const auto& tdf : jdf.tasks()) {
|
||||
const int32_t task_index = tdf.first;
|
||||
std::vector<string> device_filters(tdf.second.device_filters_size());
|
||||
for (int i = 0; i < tdf.second.device_filters_size(); i++) {
|
||||
device_filters[i] = tdf.second.device_filters(i);
|
||||
}
|
||||
const string remote_worker = remote_prefix + std::to_string(task_index);
|
||||
status->status =
|
||||
ctx->context->SetRemoteDeviceFilters(remote_worker, device_filters);
|
||||
}
|
||||
}
|
||||
}
|
||||
status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
|
||||
ctx, /*reset_context=*/true);
|
||||
#endif // !IS_MOBILE_PLATFORM
@@ -733,6 +805,11 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
status->status = tensorflow::errors::InvalidArgument(
"Trying to update a context with invalid context id.");
}
if (server_def.has_cluster_device_filters()) {
LOG(WARNING) << "Device filters can only be specified when initializing "
"the cluster. Any changes in device filters are ignored "
"when updating the server def.";
}
// TODO(haoyuzhang): Check server_def compatibility before the update
status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
ctx, /*reset_context=*/false);
@@ -797,6 +874,15 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
#endif // !IS_MOBILE_PLATFORM
}

TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
status->status = tensorflow::Status::OK();
#else // !defined(IS_MOBILE_PLATFORM)
status->status = ctx->context->ClearRemoteExecutors();
#endif // !IS_MOBILE_PLATFORM
}

void TFE_ContextSetThreadLocalDevicePlacementPolicy(
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
ctx->context->SetThreadLocalDevicePlacementPolicy(
@@ -928,6 +1014,9 @@ const char* tensorflow::TensorHandleInterface::DeviceName(
if (!IsValid(status)) {
return nullptr;
}
if (VariantDeviceIsCustom(handle_->device())) {
return absl::get<CustomDevice*>(handle_->device())->name().c_str();
}
tensorflow::Device* d = handle_->op_device();
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
@@ -948,9 +1037,15 @@ const char* tensorflow::TensorHandleInterface::BackingDeviceName(
if (!IsValid(status)) {
return nullptr;
}
tensorflow::Device* d = handle_->device();
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
if (VariantDeviceIsCustom(handle_->device())) {
return absl::get<tensorflow::CustomDevice*>(handle_->device())
->name()
.c_str();
} else {
tensorflow::Device* d = absl::get<tensorflow::Device*>(handle_->device());
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
}
}

TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
@@ -970,6 +1065,10 @@ AbstractTensorHandleInterface* tensorflow::TensorHandleInterface::Copy() {
return new TensorHandleInterface(handle_);
}

void tensorflow::TensorHandleInterface::EnableImplicitMirroring() {
handle_->EnableImplicitMirroring();
}

TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
@@ -984,6 +1083,18 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
if (!IsValid(status)) {
return nullptr;
}
if (VariantDeviceIsCustom(handle_->device())) {
tensorflow::CustomDevice* custom_device =
absl::get<tensorflow::CustomDevice*>(handle_->device());
tensorflow::TensorHandle* copy;
*status = custom_device->CopyTensorFromDevice(
handle_, "/job:localhost/task:0/replica:0/device:CPU:0", &copy);
if (status->ok()) {
return TensorHandleInterface(copy).Resolve(status);
} else {
return nullptr;
}
}

// TODO(agarwal): move this implementation inside TFE_TensorHandle.
if (handle_->IsRemote()) {
@@ -1005,9 +1116,13 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
return retval;
} else {
tensorflow::Tensor tensor;
if (IsCPU(handle_->device())) {
if (IsCPU(handle_->device()) || handle_->HasLocalMirror(nullptr)) {
const tensorflow::Tensor* src = nullptr;
*status = handle_->Tensor(&src);
if (handle_->HasLocalMirror(nullptr)) {
*status = handle_->TensorFromDevice(nullptr, &src);
} else {
*status = handle_->Tensor(&src);
}
if (!status->ok()) return nullptr;
tensor = *src;
} else {
@@ -1015,6 +1130,13 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
CHECK_NE(ctx, nullptr);
*status = handle_->CopyToDevice(*ctx, ctx->HostCPU(), &tensor);
if (!status->ok()) return nullptr;
if (handle_->ImplicitMirroring()) {
*status = handle_->AddEmptyLocalMirror(nullptr);
if (!status->ok()) return nullptr;
Tensor mirror = tensor;
*status = handle_->SetTensor(std::move(mirror), nullptr);
if (!status->ok()) return nullptr;
}
}
return tensorflow::TF_TensorFromTensor(tensor, status);
}
@@ -1029,6 +1151,11 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
tensorflow::TensorHandle* handle =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle();
if (VariantDeviceIsCustom(handle->device())) {
const tensorflow::Tensor* t;
status->status = handle->Tensor(&t);
return t->data();
}

if (handle->IsRemote()) {
status->status = tensorflow::errors::InvalidArgument(
@@ -1036,8 +1163,9 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
"handle.");
return nullptr;
}
if (handle->device() != nullptr) {
status->status = handle->device()->Sync();
tensorflow::Device* device(absl::get<tensorflow::Device*>(handle->device()));
if (device != nullptr) {
status->status = device->Sync();
if (!status->status.ok()) {
return nullptr;
}
@@ -1056,37 +1184,40 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
const int64_t* dims, int num_dims, void* data, size_t len,
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg, TF_Status* status) {
tensorflow::Device* device;
tensorflow::Device* device = nullptr;
tensorflow::EagerContext* context = ctx->context;
status->status = context->FindDeviceFromName(device_name, &device);
tensorflow::CustomDevice* custom_device = nullptr;
if (!status->status.ok()) {
deallocator(data, len, deallocator_arg);
return nullptr;
status->status =
context->FindCustomDeviceFromName(device_name, &custom_device);
if (!status->status.ok()) {
deallocator(data, len, deallocator_arg);
return nullptr;
}
}
std::vector<tensorflow::int64> dimvec(num_dims);
for (int i = 0; i < num_dims; ++i) {
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
}

if (dtype == TF_STRING || dtype == TF_RESOURCE ||
!tensorflow::DataTypeCanUseMemcpy(
static_cast<tensorflow::DataType>(dtype))) {
status->status = tensorflow::errors::InvalidArgument(
"Trying to create a tensor with a pointer to non-pod memory.");
deallocator(data, len, deallocator_arg);
return nullptr;
}
// TODO(apassos) do we need to wrap the deallocator here to make sure to sync
// the device?
TF_ManagedBuffer* buf =
new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
/*owns_memory=*/false);

tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf);
buf->Unref();
tensorflow::TensorHandle* ret_handle;
status->status = tensorflow::TensorHandle::CreateLocalHandle(
t, device, context, &ret_handle);
if (custom_device == nullptr) {
status->status = tensorflow::TensorHandle::CreateLocalHandle(
t, device, context, &ret_handle);
} else {
status->status = tensorflow::TensorHandle::CreateLocalHandle(
t, custom_device, context, &ret_handle);
}
if (!status->status.ok()) {
return nullptr;
}
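Note: a minimal usage sketch for the wrapper above, assuming the caller keeps a CPU buffer alive for the handle's lifetime; the no-op deallocator and device string are illustrative.

    static void NoOpDeallocator(void* data, size_t len, void* arg) {}

    static float values[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    int64_t dims[2] = {2, 2};
    TF_Status* status = TF_NewStatus();
    TFE_TensorHandle* h = TFE_NewTensorHandleFromDeviceMemory(
        ctx, "/job:localhost/replica:0/task:0/device:CPU:0", TF_FLOAT, dims, 2,
        values, sizeof(values), NoOpDeallocator, /*deallocator_arg=*/nullptr,
        status);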
@@ -1125,9 +1256,8 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) {
std::unique_ptr<TFE_Op> new_op(
new TFE_Op{tensorflow::EagerOperation(ctx->context)});
status->status =
new_op->operation.Reset(op_or_function_name, nullptr, false, nullptr);
new TFE_Op{std::make_unique<tensorflow::OperationInterface>(ctx)});
status->status = new_op->operation->Reset(op_or_function_name, nullptr);
if (!status->status.ok()) {
new_op.reset();
}
@@ -1137,49 +1267,51 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
void TFE_DeleteOp(TFE_Op* op) { delete op; }

void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
status->status = op->operation.SetDeviceName(device_name);
status->status = op->operation->SetDeviceName(device_name);
}

const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
tensorflow::Device* device = (op->operation.Device() == nullptr)
? op->operation.EagerContext().HostCPU()
: op->operation.Device();
return device->name().c_str();
return op->operation->DeviceName().c_str();
}

void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
op->operation.SetUseXla(enable);
#ifndef TENSORFLOW_EAGER_USE_XLA
#ifdef TENSORFLOW_EAGER_USE_XLA
tensorflow::Status s = op->operation->SetUseXla(enable);
if (!s.ok()) {
LOG(ERROR) << "Could not enable XLA compilation for op: " << s;
}
#else
LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
"built with XLA support.";
#endif // TENSORFLOW_EAGER_USE_XLA
}

void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
tensorflow::TensorHandle* h =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
input->handle.get())
->Handle();
op->operation.AddInput(h);
status->status = op->operation.MaybeInferSingleInputAttrs(h);
status->status = op->operation->AddInput(input->handle);
}

void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status) {
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>> handles(
num_inputs);
for (int i = 0; i < num_inputs; ++i) {
op->operation.AddInput(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
inputs[i]->handle.get())
->Handle());
handles[i].reset(inputs[i]->handle->Copy());
}
status->status = op->operation.InferInputListAttrs(num_inputs);
status->status = op->operation->AddInputList(handles);
}

TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
unsigned char* is_list, TF_Status* status) {
TF_AttrType ret = TF_ATTR_INT;
status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(),
attr_name, &ret, is_list);
const tensorflow::AttrTypeMap* attr_types_;
bool is_function;
status->status = tensorflow::AttrTypeMapForOp(op->operation->Name().c_str(),
&attr_types_, &is_function);
if (!status->status.ok()) {
return ret;
}
status->status =
tensorflow::AttrTypeByName(*attr_types_, attr_name, &ret, is_list);
return ret;
}
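Note: a brief sketch of the query path above; the op and attribute name are assumptions.

    unsigned char is_list = 0;
    TF_AttrType type = TFE_OpGetAttrType(matmul, "transpose_a", &is_list, status);
    // For MatMul, "transpose_a" is a scalar bool attribute, so this should
    // yield TF_ATTR_BOOL with is_list == 0.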

@@ -1200,221 +1332,150 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,

void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
size_t length) {
op->operation.MutableAttrs()->Set(
attr_name,
tensorflow::StringPiece(static_cast<const char*>(value), length));
auto s = op->operation->SetAttrString(
attr_name, static_cast<const char*>(value), length);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}

void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
op->operation.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
auto s = op->operation->SetAttrInt(attr_name, value);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}

void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
op->operation.MutableAttrs()->Set(attr_name, value);
auto s = op->operation->SetAttrFloat(attr_name, value);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}

void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
op->operation.MutableAttrs()->Set(attr_name, (value == 0) ? false : true);
auto s = op->operation->SetAttrBool(attr_name, (value == 0) ? false : true);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}

void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
op->operation.MutableAttrs()->Set(attr_name,
static_cast<tensorflow::DataType>(value));
auto s = op->operation->SetAttrType(attr_name, value);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}

void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
const int num_dims, TF_Status* out_status) {
if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat(
"Value specified for `", attr_name, "` has ", num_dims,
" dimensions which is over the limit of ",
tensorflow::TensorShape::MaxDimensions(), ".")
.c_str());
return;
}
tensorflow::TensorShapeProto proto;
if (num_dims < 0) {
proto.set_unknown_rank(true);
} else {
for (int d = 0; d < num_dims; ++d) {
proto.add_dim()->set_size(dims[d]);
}
}
op->operation.MutableAttrs()->Set(attr_name, proto);
out_status->status = op->operation->SetAttrShape(attr_name, dims, num_dims);
}

void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
const TFE_Op* value) {
tensorflow::AttrValue attr_value;
tensorflow::NameAttrList* func = attr_value.mutable_func();
func->set_name(value->operation.Name());
value->operation.Attrs().FillAttrValueMap(func->mutable_attr());
op->operation.MutableAttrs()->Set(attr_name, attr_value);
auto s = op->operation->SetAttrFunction(attr_name, value->operation);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}

void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
const char* data, size_t length) {
tensorflow::AttrValue attr_value;
tensorflow::NameAttrList* func = attr_value.mutable_func();
func->set_name(data, length);
op->operation.MutableAttrs()->Set(attr_name, attr_value);
auto s = op->operation->SetAttrFunctionName(attr_name, data, length);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}

void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
TF_Status* status) {
tensorflow::Tensor t;
status->status = TF_TensorToTensor(tensor, &t);
if (status->status.ok()) op->operation.MutableAttrs()->Set(attr_name, t);
status->status = op->operation->SetAttrTensor(attr_name, tensor);
}

void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
const void* const* values, const size_t* lengths,
int num_values) {
std::vector<tensorflow::StringPiece> v(num_values);
for (int i = 0; i < num_values; ++i) {
v[i] = tensorflow::StringPiece(static_cast<const char*>(values[i]),
lengths[i]);
auto s =
op->operation->SetAttrStringList(attr_name, values, lengths, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
op->operation.MutableAttrs()->Set(attr_name, v);
}

void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
const float* values, int num_values) {
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<const float>(values, num_values));
auto s = op->operation->SetAttrFloatList(attr_name, values, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}

void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
const int64_t* values, int num_values) {
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<const int64>(
reinterpret_cast<const int64*>(values), num_values));
auto s = op->operation->SetAttrIntList(attr_name, values, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}

void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
const TF_DataType* values, int num_values) {
op->operation.MutableAttrs()->Set(
attr_name,
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
reinterpret_cast<const tensorflow::DataType*>(values), num_values));
auto s = op->operation->SetAttrTypeList(attr_name, values, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}

void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
const unsigned char* values, int num_values) {
std::unique_ptr<bool[]> b(new bool[num_values]);
for (int i = 0; i < num_values; ++i) {
b[i] = values[i];
auto s = op->operation->SetAttrBoolList(attr_name, values, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
}

void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
const int64_t** dims, const int* num_dims,
int num_values, TF_Status* out_status) {
std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
new tensorflow::TensorShapeProto[num_values]);
for (int i = 0; i < num_values; ++i) {
const auto num_dims_i = num_dims[i];

if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat(
"Value specified for `", attr_name, "` has ", num_dims_i,
" dimensions which is over the limit of ",
tensorflow::TensorShape::MaxDimensions(), ".")
.c_str());
return;
}
if (num_dims_i < 0) {
proto[i].set_unknown_rank(true);
} else {
const int64_t* dims_i = dims[i];
auto proto_i = &proto[i];
for (int d = 0; d < num_dims_i; ++d) {
proto_i->add_dim()->set_size(dims_i[d]);
}
}
}
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
proto.get(), num_values));
out_status->status =
op->operation->SetAttrShapeList(attr_name, dims, num_dims, num_values);
}

void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
const TFE_Op** value, int num_values) {
std::unique_ptr<tensorflow::NameAttrList[]> funcs(
new tensorflow::NameAttrList[num_values]);
for (int i = 0; i < num_values; i++) {
funcs[i].set_name(value[i]->operation.Name());
value[i]->operation.Attrs().FillAttrValueMap(funcs[i].mutable_attr());
auto s = op->operation->SetAttrFunctionList(attr_name, value, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
funcs.get(), num_values));
}

TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
const char* input_name,
TF_Status* status) {
const tensorflow::OpDef* op_def = GetOpDef(op, status);
if (!status->status.ok()) {
return -1;
}
tensorflow::AttrValueMap attrs;
op->operation.Attrs().FillAttrValueMap(&attrs);
tensorflow::NameRangeMap name_ranges;
status->status = tensorflow::NameRangesForNode(
tensorflow::AttrSlice(&attrs), *op_def, &name_ranges, nullptr);
if (!status->status.ok()) {
return -1;
}
auto iter = name_ranges.find(input_name);
if (iter == name_ranges.end()) {
status->status = tensorflow::errors::InvalidArgument("Input '", input_name,
"' not found");
return -1;
}
return iter->second.second - iter->second.first;
int ret = -1;
status->status = op->operation->InputLength(input_name, &ret);
return ret;
}

TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
const char* output_name,
TF_Status* status) {
const tensorflow::OpDef* op_def = GetOpDef(op, status);
if (!status->status.ok()) {
return -1;
}
tensorflow::AttrValueMap attrs;
op->operation.Attrs().FillAttrValueMap(&attrs);
tensorflow::NameRangeMap name_ranges;
status->status = tensorflow::NameRangesForNode(
tensorflow::AttrSlice(&attrs), *op_def, nullptr, &name_ranges);
if (!status->status.ok()) {
return -1;
}
auto iter = name_ranges.find(output_name);
if (iter == name_ranges.end()) {
status->status = tensorflow::errors::InvalidArgument(
"Output '", output_name, "' not found");
return -1;
}
return iter->second.second - iter->second.first;
int ret = -1;
status->status = op->operation->OutputLength(output_name, &ret);
return ret;
}

void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
VLOG(1) << "Calling TFE_Execute() on op " << op;
status->status = tensorflow::EagerExecute(&op->operation,
handle_retvals.data(), num_retvals);
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>> handles(
*num_retvals);
status->status = op->operation->Execute(&handles, num_retvals);
if (!status->status.ok()) {
return;
}
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle_retvals[i])};
retvals[i] = new TFE_TensorHandle{std::move(handles[i])};
}
}
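Note: the execute path above is exercised end to end by call sequences like the following hedged sketch; `ctx` and the input handles `a` and `b` are assumed to exist.

    TF_Status* status = TF_NewStatus();
    TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
    TFE_OpAddInput(matmul, a, status);
    TFE_OpAddInput(matmul, b, status);
    TFE_OpSetAttrBool(matmul, "transpose_a", 0);
    TFE_TensorHandle* retvals[1] = {nullptr};
    int num_retvals = 1;
    TFE_Execute(matmul, &retvals[0], &num_retvals, status);
    TFE_DeleteOp(matmul);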

@@ -1427,8 +1488,42 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
tensorflow::EagerContext* context = ctx->context;
status->status = context->FindDeviceFromName(device_name, &device);
if (!status->status.ok()) {
tensorflow::CustomDevice* dev;
status->status = context->FindCustomDeviceFromName(device_name, &dev);
if (status->status.ok()) {
status->status = dev->CopyTensorToDevice(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
h->handle.get())
->Handle(),
&handle);
if (status->status.ok()) {
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
}
}
return nullptr;
}
// Handle tensor handles currently in custom devices
const char* handle_device_name = h->handle->DeviceName(&status->status);
if (!status->status.ok()) {
return nullptr;
}
tensorflow::CustomDevice* dev;
status->status = context->FindCustomDeviceFromName(handle_device_name, &dev);
if (status->status.ok()) {
status->status = dev->CopyTensorFromDevice(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
h->handle.get())
->Handle(),
device_name, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
}
return nullptr;
}

// Handle regular case.
status->status = tensorflow::EagerCopyToDevice(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle(),
@@ -1508,6 +1603,23 @@ void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); }

void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); }

void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) {
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
op->operation.get());
*attrs = TFE_OpAttrs(&operation->Attrs());
}

void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
tensorflow::AttrValueMap m;
attrs->attributes->FillAttrValueMap(&m);
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
op->operation.get());
tensorflow::AttrBuilder* destination = operation->MutableAttrs();
for (auto attribute : m) {
destination->Set(attribute.first, attribute.second);
}
}

namespace tensorflow {
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
const tensorflow::AttrValue& default_value,
@@ -1567,3 +1679,96 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
}
}
} // namespace tensorflow

namespace {
class CustomDeviceAPI : public tensorflow::CustomDevice {
public:
CustomDeviceAPI(TFE_CustomDevice device, void* info, string name)
: device_(device), info_(info), name_(name) {}

~CustomDeviceAPI() override { device_.delete_device(info_); }

const string& name() override { return name_; }

tensorflow::Status CopyTensorToDevice(
tensorflow::TensorHandle* tensor,
tensorflow::TensorHandle** result) override {
tensor->Ref();
TFE_TensorHandle tensor_handle{
std::make_unique<tensorflow::TensorHandleInterface>(tensor)};
TF_Status status;
TFE_TensorHandle* result_handle =
device_.copy_tensor_to_device(&tensor_handle, &status, info_);
if (!status.status.ok()) return status.status;
*result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
result_handle->handle.get())
->Handle();
(*result)->Ref();
delete result_handle;
return status.status;
}

tensorflow::Status CopyTensorFromDevice(
tensorflow::TensorHandle* tensor,
const tensorflow::string& target_device_name,
tensorflow::TensorHandle** result) override {
TF_Status status;
tensor->Ref();
TFE_TensorHandle tensor_handle{
std::make_unique<tensorflow::TensorHandleInterface>(tensor)};
TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
&tensor_handle, target_device_name.c_str(), &status, info_);
if (!status.status.ok()) return status.status;
*result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
result_handle->handle.get())
->Handle();
(*result)->Ref();
delete result_handle;
return status.status;
}

tensorflow::Status Execute(tensorflow::EagerOperation* op,
tensorflow::TensorHandle** retvals,
int* num_retvals) override {
std::vector<TFE_TensorHandle*> inputs;
inputs.reserve(op->Inputs().size());
for (int i = 0; i < op->Inputs().size(); ++i) {
op->Inputs()[i]->Ref();
inputs.push_back(new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(
op->Inputs()[i])});
}
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
TF_Status status;
TFE_OpAttrs attributes(&op->Attrs());
device_.execute(inputs.size(), inputs.data(), op->Name().c_str(),
&attributes, num_retvals, outputs.data(), &status, info_);
if (status.status.ok()) {
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
outputs[i]->handle.get())
->Handle();
retvals[i]->Ref();
delete outputs[i];
}
}

for (auto inp : inputs) {
delete inp;
}
return status.status;
}

private:
TFE_CustomDevice device_;
void* info_;
string name_;
};
} // namespace

void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
const char* device_name, void* device_info) {
auto custom_device =
std::make_unique<CustomDeviceAPI>(device, device_info, device_name);
ctx->context->RegisterCustomDevice(device_name, std::move(custom_device));
}

@@ -66,7 +66,7 @@ TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
}

#ifdef TENSORFLOW_EAGER_USE_XLA
tensorflow::Device* device = handle_->device();
tensorflow::Device* device = absl::get<Device*>(handle_->device());

// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
tensorflow::XlaDevice* xla_device =

@@ -25,55 +25,20 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
#include "tensorflow/core/profiler/rpc/profiler_server.h"

using tensorflow::string;

void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status) {
if (op_to_reset) {
status->status = op_to_reset->operation.Reset(
op_or_function_name, raw_device_name, false, nullptr);
status->status =
op_to_reset->operation->Reset(op_or_function_name, raw_device_name);
} else {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"op_to_reset should not be nullptr");
}
}

void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
op->operation.ConsumeInput(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle());
}

TFE_Profiler* TFE_NewProfiler() { return new TFE_Profiler(); }

bool TFE_ProfilerIsOk(TFE_Profiler* profiler) {
return profiler->profiler->Status().ok();
}

void TFE_DeleteProfiler(TFE_Profiler* profiler) { delete profiler; }

void TFE_ProfilerSerializeToString(TFE_Profiler* profiler, TF_Buffer* buf,
TF_Status* status) {
string content;
status->status = profiler->profiler->SerializeToString(&content);
void* data = tensorflow::port::Malloc(content.length());
content.copy(static_cast<char*>(data), content.length(), 0);
buf->data = data;
buf->length = content.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
}

void TFE_StartProfilerServer(int port) {
// Release child thread intentionally. The child thread can be terminated by
// terminating the main thread.
tensorflow::StartProfilerServer(port).release();
}

void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
ctx->context->SetShouldStoreGraphs(true);
}
@@ -82,46 +47,6 @@ void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
ctx->context->SetShouldStoreGraphs(false);
}

bool TFE_ProfilerClientStartTracing(const char* service_addr,
const char* logdir, const char* worker_list,
bool include_dataset_ops, int duration_ms,
int num_tracing_attempts,
TF_Status* status) {
tensorflow::Status s =
tensorflow::profiler::ValidateHostPortPair(service_addr);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return false;
}
s = tensorflow::profiler::Trace(service_addr, logdir, worker_list,
include_dataset_ops, duration_ms,
num_tracing_attempts);
tensorflow::Set_TF_Status_from_Status(status, s);
return s.ok();
}

void TFE_ProfilerClientMonitor(const char* service_addr, int duration_ms,
int monitoring_level, bool display_timestamp,
TF_Buffer* result, TF_Status* status) {
tensorflow::Status s =
tensorflow::profiler::ValidateHostPortPair(service_addr);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return;
}
string content;
s = tensorflow::profiler::Monitor(service_addr, duration_ms, monitoring_level,
display_timestamp, &content);
void* data = tensorflow::port::Malloc(content.length());
content.copy(static_cast<char*>(data), content.length(), 0);
result->data = data;
result->length = content.length();
result->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
tensorflow::Set_TF_Status_from_Status(status, s);
}

void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
int64_t value) {
cell->cell.IncrementBy(value);
@@ -589,8 +514,7 @@ void TFE_DeleteCancellationManager(
void TFE_OpSetCancellationManager(TFE_Op* op,
TFE_CancellationManager* cancellation_manager,
TF_Status* status) {
op->operation.SetCancellationManager(
&cancellation_manager->cancellation_manager);
status->status = op->operation->SetCancellationManager(cancellation_manager);
}
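Note: a hedged sketch of wiring a cancellation manager to an op via the call above; `op` is assumed to exist, and the API names come from this experimental header.

    TFE_CancellationManager* manager = TFE_NewCancellationManager();
    TFE_OpSetCancellationManager(op, manager, status);
    // ... later, possibly from another thread:
    TFE_CancellationManagerStartCancel(manager);
    TFE_DeleteCancellationManager(manager);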

TFE_Executor* TFE_NewExecutor(bool is_async) {
@@ -632,3 +556,28 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
tensorflow::port::Free(data);
};
}

void TFE_TensorHandleEnableImplicitMirroring(TFE_TensorHandle* h,
TF_Status* status) {
h->handle->EnableImplicitMirroring();
status->status = tensorflow::Status::OK();
}

void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
TF_Buffer* buf, TF_Status* status) {
auto* function_def = ctx->context->FindFunctionDef(function_name);
if (function_def == nullptr) {
status->status = tensorflow::errors::NotFound(
"Unable to find FunctionDef with name: ", function_name);
return;
}
string str = function_def->SerializeAsString();
void* data = tensorflow::port::Malloc(str.length());
str.copy(static_cast<char*>(data), str.length(), 0);
buf->data = data;
buf->length = str.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
status->status = tensorflow::Status::OK();
}

@@ -27,42 +27,13 @@ extern "C" {
// creating a new op every time. If `raw_device_name` is `NULL` or empty, it
// does not set the device name. If it's not `NULL`, then it attempts to parse
// and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster
// than seperately calling it because if the existing op has the same
// than separately calling it because if the existing op has the same
// `raw_device_name`, it skips parsing and just leave as it is.
TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset,
const char* op_or_function_name,
const char* raw_device_name,
TF_Status* status);

TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
TF_Status* status);

// A profiler which will start profiling when creating the object and will stop
// when the object is destroyed. It will profile all operations run under the
// given TFE_Context. Multiple instance of it can be created, but at most one
// of them will profile for each TFE_Context.
// Thread-safety: TFE_Profiler is thread-safe.
typedef struct TFE_Profiler TFE_Profiler;

TF_CAPI_EXPORT extern TFE_Profiler* TFE_NewProfiler();
TF_CAPI_EXPORT extern bool TFE_ProfilerIsOk(TFE_Profiler* profiler);
TF_CAPI_EXPORT extern void TFE_DeleteProfiler(TFE_Profiler* profiler);

// The output string is a binary string of tensorflow.tpu.Trace. User can write
// the string to file for offline analysis by tensorboard.
TF_CAPI_EXPORT extern void TFE_ProfilerSerializeToString(TFE_Profiler* profiler,
TF_Buffer* buf,
TF_Status* status);

// Start a profiler grpc server which listens to specified port. It will start
// the server on its own thread. It can be shutdown by terminating tensorflow.
// It can be used in both Eager mode and graph mode. Creating multiple profiler
// server is allowed. The service defined in
// tensorflow/contrib/tpu/profiler/tpu_profiler.proto. Please use
// tensorflow/contrib/tpu/profiler/capture_tpu_profile to capture trace file
// following https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
TF_CAPI_EXPORT extern void TFE_StartProfilerServer(int port);

// Enables only graph collection in RunMetadata on the functions executed from
// this context.
TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx);
@@ -71,29 +42,6 @@ TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx);
// this context.
TF_CAPI_EXPORT extern void TFE_ContextDisableGraphCollection(TFE_Context* ctx);

// Send a grpc request to profiler server (service_addr) to perform on-demand
// profiling and save the result into logdir which can be visualized by
// TensorBoard. worker_list is the list of worker TPUs separated by ','. Set
// include_dataset_opts to false to profile longer traces. It will block the
// caller thread until receives tracing result.
// This API is designed for TensorBoard, for end user, please use
// tensorflow/contrib/tpu/profiler/capture_tpu_profile instead following
// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
TF_CAPI_EXPORT extern bool TFE_ProfilerClientStartTracing(
const char* service_addr, const char* logdir, const char* worker_list,
bool include_dataset_ops, int duration_ms, int num_tracing_attempts,
TF_Status* status);

// Send a grpc request to profiler server (service_addr) to perform on-demand
// monitoring and return the result in a string. It will block the
// caller thread until receiving the monitoring result.
// This API is designed for TensorBoard, for end user, please use
// tensorflow/contrib/tpu/profiler/capture_tpu_profile instead following
// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
TF_CAPI_EXPORT extern void TFE_ProfilerClientMonitor(
const char* service_addr, int duration_ms, int monitoring_level,
bool display_timestamp, TF_Buffer* result, TF_Status* status);

// TODO(fishx): Move these monitoring APIs into a separate file.
// -----------------------------------------------------------------------------
// Monitoring Counter APIs.
@@ -434,6 +382,16 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
const char* worker_name,
TF_Status* status);

// Clear pending streaming requests and error statuses on remote executors.
TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
TF_Status* status);

// If the TensorHandle is copied to another device as part of an op execution,
// the copy is destroyed after the op has executed. Enabling implicit mirroring
// causes the copy to be held as a mirror for the lifetime of the TensorHandle.
TF_CAPI_EXPORT extern void TFE_TensorHandleEnableImplicitMirroring(
TFE_TensorHandle*, TF_Status*);
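Note: a short usage sketch, assuming `hcpu` is a local CPU handle and a GPU is present; the device string is illustrative.

    TFE_TensorHandleEnableImplicitMirroring(hcpu, status);
    // The copy below is now kept as a mirror for hcpu's lifetime, so later
    // GPU-placed ops can reuse it instead of copying again.
    TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
        hcpu, ctx, "/job:localhost/replica:0/task:0/device:GPU:0", status);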

// This function will block till the operation that produces `h` has
// completed. This is only valid on local TFE_TensorHandles. The pointer
// returned will be on the device in which the TFE_TensorHandle resides (so e.g.
@@ -463,6 +421,82 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
TF_Buffer* buf);

// APIs for generically dealing with op attributes (e.g. when forwarding them
// through custom device implementations).
//
// TODO(allenl): Currently these are black boxes, but we should have some way to
// inspect values. This would let people e.g. copy over most attributes and then
// modify some based on their values.

// A reference to an op's name -> attribute mapping
typedef struct TFE_OpAttrs TFE_OpAttrs;

// Fetch a struct with a reference to information about attributes of `op`.
//
// The `attrs` struct does not own any memory, and `op` must outlive it.
TF_CAPI_EXPORT extern void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs);

// Add attributes in `attrs` to `op`.
//
// Does not overwrite or update existing attributes, but adds new ones.
TF_CAPI_EXPORT extern void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs);
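Note: forwarding attributes between ops, as a custom device might, looks roughly like this sketch; `original_op` and the op name are assumptions.

    TFE_OpAttrs attrs;
    TFE_OpGetAttrs(original_op, &attrs);  // borrows; original_op must outlive attrs
    TFE_Op* wrapped = TFE_NewOp(ctx, "MatMul", status);
    TFE_OpAddAttrs(wrapped, &attrs);      // copies each name -> value entry over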

#define TFE_CUSTOM_DEVICE_VERSION 1

// Struct to be filled in
typedef struct TFE_CustomDevice {
int version = TFE_CUSTOM_DEVICE_VERSION;
// Method to copy a tensor to the custom device.
TFE_TensorHandle* (*copy_tensor_to_device)(TFE_TensorHandle* tensor,
TF_Status* status,
void* device_info) = nullptr;

// Method to copy a tensor from the custom device to a target device.
TFE_TensorHandle* (*copy_tensor_from_device)(TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info);

// Method to execute an operation.
void (*execute)(int num_inputs, TFE_TensorHandle** inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int* num_outputs, TFE_TensorHandle** outputs, TF_Status* s,
void* device_info);

// Method to delete a device.
void (*delete_device)(void* device_info);
} TFE_CustomDevice;

// Registers a custom device for use with eager execution.
//
// Eager operations may be placed on this device, e.g. `with
// tf.device("CUSTOM"):` from Python if `device_name` for this call is
// "/job:localhost/replica:0/task:0/device:CUSTOM:0".
//
// The custom device defines copy operations for moving TensorHandles on and
// off, and an execution operation for named operations. Often execution will
// simply wrap op execution on one or more physical devices.
//
// device_info is an opaque caller-defined type stored with the custom device
// which is passed to the functions referenced in the TFE_CustomDevice struct
// `device` (execute, delete_device, etc.). It can for example contain the
// names of wrapped devices.
//
// There are currently no graph semantics implemented for registered custom
// devices, so executing tf.functions which contain operations placed on custom
// devices will fail.
//
// This API is highly experimental, and in particular is expected to change when
// it starts supporting operations with attributes and when tf.function support
// is added.
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
const char* device_name, void* device_info);
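Note: a skeletal registration sketch under stated assumptions: the callbacks below only illustrate the wiring and do not implement a working device; a real implementation would copy tensors onto wrapped physical devices and run `operation_name` there. All names are illustrative.

    struct LoggingDeviceInfo { int op_count = 0; };

    TFE_TensorHandle* CopyToLoggingDevice(TFE_TensorHandle* tensor,
                                          TF_Status* status, void* device_info) {
      return tensor;  // placeholder; should return a handle on the custom device
    }

    TFE_TensorHandle* CopyFromLoggingDevice(TFE_TensorHandle* tensor,
                                            const char* target_device_name,
                                            TF_Status* status,
                                            void* device_info) {
      return tensor;  // placeholder
    }

    void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
                              const char* operation_name,
                              const TFE_OpAttrs* attributes, int* num_outputs,
                              TFE_TensorHandle** outputs, TF_Status* s,
                              void* device_info) {
      static_cast<LoggingDeviceInfo*>(device_info)->op_count++;
      // A real device would execute operation_name and fill outputs here.
    }

    void DeleteLoggingDevice(void* device_info) {
      delete static_cast<LoggingDeviceInfo*>(device_info);
    }

    // Registration (inside some setup function):
    TFE_CustomDevice device;
    device.copy_tensor_to_device = &CopyToLoggingDevice;
    device.copy_tensor_from_device = &CopyFromLoggingDevice;
    device.execute = &LoggingDeviceExecute;
    device.delete_device = &DeleteLoggingDevice;
    TFE_RegisterCustomDevice(ctx, device,
                             "/job:localhost/replica:0/task:0/device:CUSTOM:0",
                             new LoggingDeviceInfo);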

TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
const char* function_name,
TF_Buffer* buf,
TF_Status* status);
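Note: a usage sketch for the declaration above; "my_fn" is an assumed function name.

    TF_Buffer* buf = TF_NewBuffer();
    TFE_ContextGetFunctionDef(ctx, "my_fn", buf, status);
    if (TF_GetCode(status) == TF_OK) {
      tensorflow::FunctionDef fdef;
      fdef.ParseFromArray(buf->data, buf->length);
    }
    TF_DeleteBuffer(buf);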

#ifdef __cplusplus
} /* end extern "C" */
#endif

@@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/trace_events.pb.h"

using tensorflow::string;

@@ -39,88 +38,6 @@ static bool HasSubstr(absl::string_view base, absl::string_view substr) {
return ok;
}

void ExecuteWithProfiling(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status);
TFE_Profiler* profiler = TFE_NewProfiler();
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);

TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_Op* matmul = MatMulOp(ctx, m, m);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;

// Run op on GPU if it is present.
string gpu_device_name;
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
}

TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(m);

ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TF_Buffer* profiler_result = TF_NewBuffer();
if (async) {
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteExecutor(executor);
}
TFE_ProfilerSerializeToString(profiler, profiler_result, status);
TFE_DeleteProfiler(profiler);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
profiler::Trace profile_proto;
EXPECT_TRUE(profile_proto.ParseFromString(
{reinterpret_cast<const char*>(profiler_result->data),
profiler_result->length}));
string profile_proto_str = profile_proto.DebugString();
#ifndef TENSORFLOW_USE_ROCM
// TODO(rocm): enable once GPU profiling is supported in ROCm mode
if (!gpu_device_name.empty()) {
EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:GPU:0"));
}
#endif
// "/host:CPU" is collected by TraceMe
EXPECT_TRUE(HasSubstr(profile_proto_str, "/host:CPU"));
EXPECT_TRUE(HasSubstr(profile_proto_str, "MatMul"));
TF_DeleteBuffer(profiler_result);

TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteContext(ctx);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TF_DeleteStatus(status);
}
TEST(CAPI, ExecuteWithTracing) { ExecuteWithProfiling(false); }
TEST(CAPI, ExecuteWithTracingAsync) { ExecuteWithProfiling(true); }

TEST(CAPI, MultipleProfilerSession) {
TFE_Profiler* profiler1 = TFE_NewProfiler();
EXPECT_TRUE(TFE_ProfilerIsOk(profiler1));

TFE_Profiler* profiler2 = TFE_NewProfiler();
EXPECT_FALSE(TFE_ProfilerIsOk(profiler2));

TFE_DeleteProfiler(profiler1);
TFE_DeleteProfiler(profiler2);
}

TEST(CAPI, MonitoringCounter0) {
TF_Status* status = TF_NewStatus();
auto* counter =

@@ -27,12 +27,12 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -48,7 +48,6 @@ limitations under the License.
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/profiler/lib/profiler_session.h"
#include "tensorflow/core/public/version.h"

struct TFE_ContextOptions {
@@ -90,13 +89,7 @@ struct TFE_TensorDebugInfo {
};

struct TFE_Op {
tensorflow::EagerOperation operation;
};

struct TFE_Profiler {
explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); }

std::unique_ptr<tensorflow::ProfilerSession> profiler;
std::unique_ptr<AbstractOperationInterface> operation;
};

struct TFE_MonitoringCounterCell {
@@ -243,4 +236,13 @@ struct TFE_Executor {
tensorflow::EagerExecutor* unowned_executor;
};

struct TFE_OpAttrs {
explicit TFE_OpAttrs() : attributes(nullptr) {}

explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value)
: attributes(value) {}

const tensorflow::AttrBuilder* attributes;
};

#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_

@@ -17,12 +17,15 @@ limitations under the License.

#include <string.h>

#include <string>

#include "absl/strings/match.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -363,34 +366,79 @@ TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevicesAsync) {
TensorHandleCopyBetweenTwoGPUDevices(true);
}

void TensorHandleSilentCopy(bool async) {
void TensorHandleSilentCopy(bool async,
TFE_ContextDevicePlacementPolicy global_policy,
TFE_ContextDevicePlacementPolicy thread_policy,
bool mirror, bool cpu_op) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, global_policy);
TFE_Context* ctx = TFE_NewContext(opts, status.get());
if (thread_policy != global_policy) {
TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx, thread_policy);
}
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

// Disable the test if no GPU is present.
string gpu_device_name;
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
if (mirror) {
TFE_TensorHandleEnableImplicitMirroring(hcpu, status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
TFE_TensorHandleEnableImplicitMirroring(hgpu, status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
}

TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
if (cpu_op) {
string cpu_device_name;
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status.get());
} else {
TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get());
}
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

// Validate if the input was replaced with a different TensorHandle
auto arg0 = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
hcpu->handle.get())
->Handle();
auto arg1 = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
hgpu->handle.get())
->Handle();

auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
matmul->operation.get());
if (mirror) {
// The input handles should never change since they have been mirrored.
ASSERT_EQ(op->GetInput(0), arg0);
ASSERT_EQ(op->GetInput(1), arg1);
} else {
if (cpu_op) {
ASSERT_EQ(op->GetInput(0), arg0);
// The GPU handle should be replaced with a CPU copy
ASSERT_NE(op->GetInput(1), arg1);
} else {
// The CPU handle should be replaced with a GPU copy
ASSERT_NE(op->GetInput(0), arg0);
ASSERT_EQ(op->GetInput(1), arg1);
}
}
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteTensorHandle(hgpu);
@@ -404,57 +452,29 @@ void TensorHandleSilentCopy(bool async) {
TFE_DeleteExecutor(executor);
TFE_DeleteContext(ctx);
}

TEST(CAPI, TensorHandleSilentCopy) { TensorHandleSilentCopy(false); }
TEST(CAPI, TensorHandleSilentCopyAsync) { TensorHandleSilentCopy(true); }

void TensorHandleSilentCopyLocal(bool async) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
TFE_DEVICE_PLACEMENT_EXPLICIT);
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx,
TFE_DEVICE_PLACEMENT_SILENT);
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Disable the test if no GPU is present.
string gpu_device_name;
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteTensorHandle(hgpu);
}

TF_DeleteTensor(t);
TFE_DeleteTensorHandle(hcpu);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteExecutor(executor);
TFE_DeleteContext(ctx);
TEST(CAPI, TensorHandleSilentCopy) {
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
TFE_DEVICE_PLACEMENT_SILENT, false, false);
}
TEST(CAPI, TensorHandleSilentCopyLocal) { TensorHandleSilentCopyLocal(false); }
TEST(CAPI, TensorHandleSilentCopyLocalAsync) {
TensorHandleSilentCopyLocal(true);
TEST(CAPI, TensorHandleSilentCopyAsync) {
TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_SILENT,
TFE_DEVICE_PLACEMENT_SILENT, false, false);
}
TEST(CAPI, TensorHandleSilentCopyLocalPolicy) {
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_EXPLICIT,
TFE_DEVICE_PLACEMENT_SILENT, false, false);
}
TEST(CAPI, TensorHandleSilentCopyLocalPolicyAsync) {
TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_EXPLICIT,
TFE_DEVICE_PLACEMENT_SILENT, false, false);
}
TEST(CAPI, TensorHandleMirrorCopy) {
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
TFE_DEVICE_PLACEMENT_SILENT, true, false);
}
TEST(CAPI, TensorHandleMirrorCopyCpu) {
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
TFE_DEVICE_PLACEMENT_SILENT, true, true);
}

void SetAndGetOpDevices(bool async) {
@@ -590,6 +610,91 @@ TEST(CAPI, TensorHandleDevices) {
TFE_DeleteContext(ctx);
}

void ExecuteAdd(bool async, bool forward_input) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);

TFE_TensorHandle* n = TestMatrixTensorHandle100x100();
// If a GPU exists, copy the handle to GPU so that we can exercise
// unprotecting a mirror.
std::string gpu_device_name;
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_TensorHandle* n_gpu =
TFE_TensorHandleCopyToDevice(n, ctx, gpu_device_name.c_str(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandleEnableImplicitMirroring(n_gpu, status);
TFE_DeleteTensorHandle(n);
n = n_gpu;
}

TFE_TensorHandle* m = TestMatrixTensorHandle100x100();

// Store pointer to raw buffer for validation of forwarding behaviour.
|
||||
TF_Tensor* orig = TFE_TensorHandleResolve(n, status);
|
||||
void* orig_ptr = TF_TensorData(orig);
|
||||
TF_DeleteTensor(orig);
|
||||
|
||||
TFE_Op* add_op = AddOp(ctx, n, m);
|
||||
std::string cpu_device_name;
|
||||
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
|
||||
TFE_OpSetDevice(add_op, cpu_device_name.c_str(), status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
if (forward_input) {
|
||||
TFE_DeleteTensorHandle(n);
|
||||
}
|
||||
|
||||
int num_retvals = 1;
|
||||
|
||||
if (async) {
|
||||
// Enqueue dummy ops so we backlog async execution & actually test async.
|
||||
for (int i = 0; i < 10000; ++i) {
|
||||
TFE_TensorHandle* dummy = nullptr;
|
||||
TFE_Execute(add_op, &dummy, &num_retvals, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(dummy);
|
||||
}
|
||||
}
|
||||
|
||||
TFE_TensorHandle* retval = nullptr;
|
||||
TFE_Execute(add_op, &retval, &num_retvals, status);
|
||||
EXPECT_EQ(1, num_retvals);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
if (!forward_input) {
|
||||
TFE_DeleteTensorHandle(n);
|
||||
}
|
||||
TFE_DeleteOp(add_op);
|
||||
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(retval, status);
|
||||
if (forward_input || async) {
|
||||
EXPECT_EQ(orig_ptr, TF_TensorData(t));
|
||||
} else {
|
||||
EXPECT_NE(orig_ptr, TF_TensorData(t));
|
||||
}
|
||||
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(m);
|
||||
TFE_DeleteTensorHandle(retval);
|
||||
TFE_DeleteContext(ctx);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
float result[100 * 100] = {0};
|
||||
EXPECT_EQ(sizeof(result), TF_TensorByteSize(t));
|
||||
memcpy(&result[0], TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
for (int i = 0; i < 100 * 100; ++i) {
|
||||
EXPECT_EQ(2.0f, result[i]);
|
||||
}
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
TEST(CAPI, ExecuteAdd) { ExecuteAdd(false, false); }
|
||||
TEST(CAPI, ExecuteAddAsync) { ExecuteAdd(true, false); }
|
||||
TEST(CAPI, ExecuteAddForward) { ExecuteAdd(false, true); }
|
||||
TEST(CAPI, ExecuteAddForwardAsync) { ExecuteAdd(true, true); }
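
A note on the forwarding assertions above: when the last client reference to an input handle is released before execution (or execution is deferred on the async queue), the runtime may reuse the input's buffer for the output, and the test detects this by comparing raw data pointers. Below is a hedged sketch of the same check, not part of the change itself; it assumes the helpers from c_api_test_util.h and an already-initialized context.

// Hedged sketch: the buffer-forwarding check used by ExecuteAdd.
void ForwardingCheckSketch(TFE_Context* ctx, TF_Status* status) {
  TFE_TensorHandle* n = TestMatrixTensorHandle100x100();
  TFE_TensorHandle* m = TestMatrixTensorHandle100x100();

  TF_Tensor* orig = TFE_TensorHandleResolve(n, status);
  void* orig_ptr = TF_TensorData(orig);  // Raw buffer backing `n`.
  TF_DeleteTensor(orig);

  TFE_Op* add = AddOp(ctx, n, m);
  TFE_DeleteTensorHandle(n);  // Drop the last client reference to `n`.

  TFE_TensorHandle* out = nullptr;
  int num_retvals = 1;
  TFE_Execute(add, &out, &num_retvals, status);
  TFE_DeleteOp(add);

  // With no other reference alive, the kernel may have reused `n`'s buffer
  // for the output; pointer equality is how the tests above detect that.
  TF_Tensor* result = TFE_TensorHandleResolve(out, status);
  bool forwarded = (TF_TensorData(result) == orig_ptr);
  (void)forwarded;

  TF_DeleteTensor(result);
  TFE_DeleteTensorHandle(m);
  TFE_DeleteTensorHandle(out);
}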

void Execute_MatMul_CPU(bool async) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -1228,6 +1333,14 @@ TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
  TFE_DeleteTensorHandle(h_shares_tensor);
}

tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) {
  tensorflow::AttrValueMap attr_values;
  tensorflow::down_cast<tensorflow::OperationInterface*>(op->operation.get())
      ->Attrs()
      .FillAttrValueMap(&attr_values);
  return attr_values;
}

TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -1244,8 +1357,7 @@ TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
  TFE_OpAddInput(minOp, axis, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  tensorflow::AttrValueMap attr_values;
  minOp->operation.Attrs().FillAttrValueMap(&attr_values);
  tensorflow::AttrValueMap attr_values = ExtractAttrs(minOp);
  tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
  EXPECT_NE(attr_found, attr_values.cend());
  EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
@ -1284,8 +1396,7 @@ TEST(CAPI, TestTFE_OpInferSingleTypeInputListAttrs) {
  TFE_OpAddInputList(concatOp, inputs, 2, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  tensorflow::AttrValueMap attr_values;
  concatOp->operation.Attrs().FillAttrValueMap(&attr_values);
  tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp);
  tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
  EXPECT_NE(attr_found, attr_values.cend());
  EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
@ -1325,8 +1436,7 @@ TEST(CAPI, TestTFE_OpInferMixedTypeInputListAttrs) {
  TFE_OpAddInputList(assertOp, data, 3, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  tensorflow::AttrValueMap attr_values;
  assertOp->operation.Attrs().FillAttrValueMap(&attr_values);
  tensorflow::AttrValueMap attr_values = ExtractAttrs(assertOp);
  tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
  EXPECT_NE(attr_found, attr_values.cend());
  EXPECT_EQ(attr_found->second.list().type(0), tensorflow::DataType::DT_BOOL);
@ -1362,16 +1472,15 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
  TFE_TensorHandle* inputs[] = {input1, input2};
  TFE_OpAddInput(concatOp, dim, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  CHECK(concatOp->operation.OpDef());
  CHECK(concatOp->operation->OpDef());
  TFE_OpAddInput(concatOp, inputs[0], status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  EXPECT_FALSE(concatOp->operation.OpDef())
  EXPECT_FALSE(concatOp->operation->OpDef())
      << "Inference context is still present";
  TFE_OpAddInput(concatOp, inputs[1], status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  tensorflow::AttrValueMap attr_values;
  concatOp->operation.Attrs().FillAttrValueMap(&attr_values);
  tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp);
  EXPECT_EQ(attr_values.find("T"), attr_values.end());
  EXPECT_EQ(attr_values.find("N"), attr_values.end());

@ -1458,4 +1567,40 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
  TFE_DeleteContext(ctx);
}

TEST(CAPI, TestTFE_OpGetAttrs) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
  TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
  TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
  TFE_OpAttrs attributes;
  TFE_OpGetAttrs(var_op, &attributes);

  TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
  TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT);
  TFE_OpAddAttrs(copy_op, &attributes);
  unsigned char is_list = 0;
  ASSERT_EQ(TF_ATTR_TYPE,
            TFE_OpGetAttrType(copy_op, "dtype", &is_list, status));
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  ASSERT_EQ(TF_ATTR_SHAPE,
            TFE_OpGetAttrType(copy_op, "shape", &is_list, status));
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  tensorflow::AttrValueMap attr_values;
  auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
      copy_op->operation.get());
  op->Attrs().FillAttrValueMap(&attr_values);
  EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type());

  TF_DeleteStatus(status);
  TFE_DeleteOp(var_op);
  TFE_DeleteOp(copy_op);
  TFE_DeleteContext(ctx);
}

}  // namespace

@ -131,6 +131,21 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2() {
  return th;
}

TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
  TF_Status* status = TF_NewStatus();

  TFE_Op* op = TFE_NewOp(ctx, "AddV2", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(op, a, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(op, b, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
  TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));

  return op;
}

TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
  TF_Status* status = TF_NewStatus();

@ -42,6 +42,9 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2();
// Return a tensor handle containing a 3x2 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle3X2();

// Return an add op computing `a` + `b`.
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);

// Return a matmul op multiplying `a` by `b`.
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
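
For orientation, a hedged sketch of how these helpers are meant to compose in a test; it assumes an initialized `ctx` and `status` as set up elsewhere in these tests, and is not part of the change itself.

// Hedged usage sketch for the AddOp test helper declared above.
TFE_TensorHandle* a = TestMatrixTensorHandle();
TFE_TensorHandle* b = TestMatrixTensorHandle();
TFE_Op* add = AddOp(ctx, a, b);  // Builds an AddV2 op with attr T taken from `a`.
TFE_TensorHandle* sum = nullptr;
int num_retvals = 1;
TFE_Execute(add, &sum, &num_retvals, status);
TFE_DeleteOp(add);
TFE_DeleteTensorHandle(a);
TFE_DeleteTensorHandle(b);
TFE_DeleteTensorHandle(sum);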

294 tensorflow/c/eager/custom_device_test.cc Normal file
@ -0,0 +1,294 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// A simple logging device to test custom device registration.
#include <memory>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"

namespace {

struct LoggingDevice {
  TFE_Context* ctx;
  tensorflow::string device_name;
  tensorflow::string underlying_device;
  // Set to true whenever a TensorHandle is copied onto the device
  bool* arrived_flag;
  // Set to true whenever an operation is executed
  bool* executed_flag;
};

struct LoggedTensor {
  TFE_TensorHandle* tensor;
  LoggedTensor() = delete;
  explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
  ~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
};

void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
  delete reinterpret_cast<LoggedTensor*>(data);
}

TFE_TensorHandle* MakeLoggedTensorHandle(
    TFE_Context* ctx, const tensorflow::string& logging_device_name,
    std::unique_ptr<LoggedTensor> t, TF_Status* status) {
  std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
  if (TF_GetCode(status) != TF_OK) return nullptr;
  for (int i = 0; i < shape.size(); ++i) {
    shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
  }
  auto dtype = TFE_TensorHandleDataType(t->tensor);
  return TFE_NewTensorHandleFromDeviceMemory(
      ctx, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
      t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
}

TFE_TensorHandle* CopyToLoggingDevice(TFE_TensorHandle* tensor,
                                      TF_Status* status, void* device_info) {
  LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
  TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
      tensor, dev->ctx, dev->underlying_device.c_str(), status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  auto dst = std::make_unique<LoggedTensor>(t);
  *(dev->arrived_flag) = true;
  return MakeLoggedTensorHandle(dev->ctx, dev->device_name, std::move(dst),
                                status);
}

TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_TensorHandle* tensor,
                                              const char* target_device_name,
                                              TF_Status* status,
                                              void* device_info) {
  TF_SetStatus(status, TF_INTERNAL,
               "Trying to copy a tensor out of a logging device.");
  return nullptr;
}

void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
                          const char* operation_name,
                          const TFE_OpAttrs* attributes, int* num_outputs,
                          TFE_TensorHandle** outputs, TF_Status* s,
                          void* device_info) {
  LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
  TFE_Op* op(TFE_NewOp(dev->ctx, operation_name, s));
  if (TF_GetCode(s) != TF_OK) return;
  TFE_OpAddAttrs(op, attributes);
  TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
  for (int j = 0; j < num_inputs; ++j) {
    TFE_TensorHandle* input = inputs[j];
    const char* input_device = TFE_TensorHandleDeviceName(input, s);
    if (TF_GetCode(s) != TF_OK) return;
    if (dev->device_name == input_device) {
      LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
          TFE_TensorHandleDevicePointer(input, s));
      if (TF_GetCode(s) != TF_OK) return;
      TFE_OpAddInput(op, t->tensor, s);
    } else {
      TFE_OpAddInput(op, input, s);
    }
    if (TF_GetCode(s) != TF_OK) return;
  }
  std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
  TFE_Execute(op, op_outputs.data(), num_outputs, s);
  TFE_DeleteOp(op);
  if (TF_GetCode(s) != TF_OK) return;
  std::vector<TFE_TensorHandle*> unwrapped_outputs;
  for (auto* handle : op_outputs) {
    unwrapped_outputs.push_back(handle);
  }
  for (int i = 0; i < *num_outputs; ++i) {
    auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
    outputs[i] = MakeLoggedTensorHandle(dev->ctx, dev->device_name,
                                        std::move(logged_tensor), s);
  }
  *(dev->executed_flag) = true;
}

void DeleteLoggingDevice(void* device_info) {
  delete reinterpret_cast<LoggingDevice*>(device_info);
}

void RegisterLoggingDevice(TFE_Context* context, const char* name,
                           bool* arrived_flag, bool* executed_flag) {
  TFE_CustomDevice custom_device;
  custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
  custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
  custom_device.delete_device = &DeleteLoggingDevice;
  custom_device.execute = &LoggingDeviceExecute;
  LoggingDevice* device = new LoggingDevice;
  device->ctx = context;
  device->arrived_flag = arrived_flag;
  device->executed_flag = executed_flag;
  device->device_name = name;
  device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
  TFE_RegisterCustomDevice(context, custom_device, name, device);
}

TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* context = TFE_NewContext(opts, status.get());
  TFE_DeleteContextOptions(opts);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  bool arrived = false;
  bool executed = false;
  const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  RegisterLoggingDevice(context, name, &arrived, &executed);
  TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
  ASSERT_FALSE(arrived);
  TFE_TensorHandle* hdevice =
      TFE_TensorHandleCopyToDevice(hcpu, context, name, status.get());
  ASSERT_TRUE(arrived);
  ASSERT_FALSE(executed);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> matmul(
      MatMulOp(context, hcpu, hdevice), TFE_DeleteOp);
  TFE_OpSetDevice(matmul.get(), name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_TensorHandle* retval;
  int num_retvals = 1;
  TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_TRUE(executed);

  TFE_DeleteTensorHandle(retval);
  TFE_DeleteTensorHandle(hcpu);
  TFE_DeleteTensorHandle(hdevice);
  TFE_DeleteContext(context);
}

TEST(CUSTOM_DEVICE, ResetOperation) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts, status.get()), TFE_DeleteContext);
  TFE_DeleteContextOptions(opts);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  bool arrived = false;
  bool executed = false;
  const char* custom_device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed);

  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> reused_op(
      TFE_NewOp(context.get(), "Identity", status.get()), TFE_DeleteOp);
  TFE_OpReset(reused_op.get(), "Identity", custom_device_name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_EQ(tensorflow::string(TFE_OpGetDevice(reused_op.get(), status.get())),
            tensorflow::string(custom_device_name));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_OpReset(reused_op.get(), "Identity",
              "/job:localhost/replica:0/task:0/device:CPU:0", status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_EQ(tensorflow::string(TFE_OpGetDevice(reused_op.get(), status.get())),
            tensorflow::string("/job:localhost/replica:0/task:0/device:CPU:0"));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}

TEST(CUSTOM_DEVICE, MakeVariable) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  bool arrived = false;
  bool executed = false;
  const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  RegisterLoggingDevice(context.get(), name, &arrived, &executed);

  // Create a variable handle placed on the custom device.
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
  TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
  TFE_OpSetAttrString(op.get(), "container", "", 0);
  TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
  TFE_OpSetDevice(op.get(), name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_TensorHandle* var_handle = nullptr;
  int num_retvals = 1;
  executed = false;
  TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_TRUE(executed);
  auto handle_cleaner = tensorflow::gtl::MakeCleanup(
      [var_handle]() { TFE_DeleteTensorHandle(var_handle); });

  // Assign to the variable, copying to the custom device.
  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
      TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
  op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
  TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
  TFE_OpAddInput(op.get(), var_handle, status.get());
  TFE_OpAddInput(op.get(), one.get(), status.get());
  TFE_OpSetDevice(op.get(), name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  executed = false;
  num_retvals = 0;
  TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_TRUE(executed);

  // Read the variable's value.
  op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
  TFE_OpAddInput(op.get(), var_handle, status.get());
  TFE_OpSetDevice(op.get(), name, status.get());
  TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  executed = false;
  num_retvals = 1;
  TFE_TensorHandle* var_value = nullptr;
  TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_TRUE(executed);
  auto value_cleaner = tensorflow::gtl::MakeCleanup(
      [var_value]() { TFE_DeleteTensorHandle(var_value); });
  ASSERT_EQ(tensorflow::string(name),
            tensorflow::string(
                TFE_TensorHandleBackingDeviceName(var_value, status.get())));
  TFE_TensorHandle* var_value_unpacked =
      reinterpret_cast<LoggedTensor*>(
          TFE_TensorHandleDevicePointer(var_value, status.get()))
          ->tensor;
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value(
      TFE_TensorHandleResolve(var_value_unpacked, status.get()),
      TF_DeleteTensor);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_EQ(111., *static_cast<float*>(TF_TensorData(resolved_value.get())));

  // Free the backing buffer for the variable.
  op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
  TFE_OpAddInput(op.get(), var_handle, status.get());
  TFE_OpSetDevice(op.get(), name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  num_retvals = 0;
  TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}

}  // namespace

312 tensorflow/c/eager/operation_interface.cc Normal file
@ -0,0 +1,312 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/operation_interface.h"

#include "absl/container/fixed_array.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

OperationInterface::OperationInterface(TFE_Context* ctx)
    : operation_(ctx->context) {}

const string& OperationInterface::DeviceName() const {
  absl::variant<Device*, CustomDevice*> variant_device =
      (operation_.Device() == kVariantDeviceNull)
          ? operation_.EagerContext().HostCPU()
          : operation_.Device();
  return absl::visit([](auto* d) -> const string& { return d->name(); },
                     variant_device);
}

Status OperationInterface::SetDeviceName(const char* name) {
  return operation_.SetDeviceName(name);
}

Status OperationInterface::SetAttrString(const char* attr_name,
                                         const char* data, size_t length) {
  operation_.MutableAttrs()->Set(attr_name, StringPiece(data, length));
  return Status::OK();
}

Status OperationInterface::SetAttrInt(const char* attr_name, int64_t value) {
  operation_.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
  return Status::OK();
}

Status OperationInterface::SetAttrFloat(const char* attr_name, float value) {
  operation_.MutableAttrs()->Set(attr_name, value);
  return Status::OK();
}

Status OperationInterface::SetAttrBool(const char* attr_name, bool value) {
  operation_.MutableAttrs()->Set(attr_name, value);
  return Status::OK();
}

Status OperationInterface::SetAttrType(const char* attr_name,
                                       TF_DataType value) {
  operation_.MutableAttrs()->Set(attr_name, static_cast<DataType>(value));
  return Status::OK();
}

Status OperationInterface::SetAttrShape(const char* attr_name,
                                        const int64_t* dims,
                                        const int num_dims) {
  if (num_dims > TensorShape::MaxDimensions()) {
    return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
                                   num_dims,
                                   " dimensions which is over the limit of ",
                                   TensorShape::MaxDimensions(), ".");
  }

  TensorShapeProto proto;
  if (num_dims < 0) {
    proto.set_unknown_rank(true);
  } else {
    for (int d = 0; d < num_dims; ++d) {
      proto.add_dim()->set_size(dims[d]);
    }
  }

  operation_.MutableAttrs()->Set(attr_name, proto);

  return Status::OK();
}

Status OperationInterface::SetAttrFunction(
    const char* attr_name,
    const std::unique_ptr<AbstractOperationInterface>& value) {
  AttrValue attr_value;
  NameAttrList* func = attr_value.mutable_func();
  func->set_name(value->Name());
  OperationInterface* value_operation =
      tensorflow::down_cast<OperationInterface*>(value.get());
  value_operation->operation_.Attrs().FillAttrValueMap(func->mutable_attr());
  operation_.MutableAttrs()->Set(attr_name, attr_value);
  return Status::OK();
}

Status OperationInterface::SetAttrFunctionName(const char* attr_name,
                                               const char* data,
                                               size_t length) {
  AttrValue attr_value;
  NameAttrList* func = attr_value.mutable_func();
  func->set_name(data, length);
  operation_.MutableAttrs()->Set(attr_name, attr_value);
  return Status::OK();
}

Status OperationInterface::SetAttrTensor(const char* attr_name,
                                         TF_Tensor* tensor) {
  Tensor t;
  TF_RETURN_IF_ERROR(TF_TensorToTensor(tensor, &t));
  operation_.MutableAttrs()->Set(attr_name, t);
  return Status::OK();
}

Status OperationInterface::SetAttrStringList(const char* attr_name,
                                             const void* const* values,
                                             const size_t* lengths,
                                             int num_values) {
  std::vector<StringPiece> v(num_values);
  for (int i = 0; i < num_values; ++i) {
    v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
  }
  operation_.MutableAttrs()->Set(attr_name, v);

  return Status::OK();
}

Status OperationInterface::SetAttrFloatList(const char* attr_name,
                                            const float* values,
                                            int num_values) {
  operation_.MutableAttrs()->Set(
      attr_name, gtl::ArraySlice<const float>(values, num_values));
  return Status::OK();
}

Status OperationInterface::SetAttrIntList(const char* attr_name,
                                          const int64_t* values,
                                          int num_values) {
  operation_.MutableAttrs()->Set(
      attr_name, gtl::ArraySlice<const int64>(
                     reinterpret_cast<const int64*>(values), num_values));
  return Status::OK();
}

Status OperationInterface::SetAttrTypeList(const char* attr_name,
                                           const TF_DataType* values,
                                           int num_values) {
  operation_.MutableAttrs()->Set(
      attr_name, gtl::ArraySlice<const DataType>(
                     reinterpret_cast<const DataType*>(values), num_values));
  return Status::OK();
}

Status OperationInterface::SetAttrBoolList(const char* attr_name,
                                           const unsigned char* values,
                                           int num_values) {
  std::unique_ptr<bool[]> b(new bool[num_values]);
  for (int i = 0; i < num_values; ++i) {
    b[i] = values[i];
  }
  operation_.MutableAttrs()->Set(
      attr_name, gtl::ArraySlice<const bool>(b.get(), num_values));
  return Status::OK();
}

Status OperationInterface::SetAttrShapeList(const char* attr_name,
                                            const int64_t** dims,
                                            const int* num_dims,
                                            int num_values) {
  std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
  for (int i = 0; i < num_values; ++i) {
    const auto num_dims_i = num_dims[i];

    if (num_dims_i > TensorShape::MaxDimensions()) {
      return errors::InvalidArgument(
          strings::StrCat("Value specified for `", attr_name, "` has ",
                          num_dims_i, " dimensions which is over the limit of ",
                          TensorShape::MaxDimensions(), "."));
    }
    if (num_dims_i < 0) {
      proto[i].set_unknown_rank(true);
    } else {
      const int64_t* dims_i = dims[i];
      auto proto_i = &proto[i];
      for (int d = 0; d < num_dims_i; ++d) {
        proto_i->add_dim()->set_size(dims_i[d]);
      }
    }
  }
  operation_.MutableAttrs()->Set(
      attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
  return Status::OK();
}

Status OperationInterface::SetAttrFunctionList(const char* attr_name,
                                               const TFE_Op** value,
                                               int num_values) {
  std::unique_ptr<NameAttrList[]> funcs(new NameAttrList[num_values]);
  for (int i = 0; i < num_values; i++) {
    auto value_operation =
        tensorflow::down_cast<OperationInterface*>(value[i]->operation.get());
    funcs[i].set_name(value_operation->operation_.Name());
    value_operation->operation_.Attrs().FillAttrValueMap(
        funcs[i].mutable_attr());
  }
  operation_.MutableAttrs()->Set(
      attr_name, gtl::ArraySlice<const NameAttrList>(funcs.get(), num_values));
  return Status::OK();
}

const OpDef* OperationInterface::GetOpDef(Status* status) {
  const tensorflow::OpDef* op_def = operation_.OpDef();
  if (op_def) return op_def;
  *status = OpDefForOp(Name(), &op_def);
  return op_def;
}

Status OperationInterface::InputLength(const char* input_name, int* length) {
  Status status;
  const tensorflow::OpDef* op_def = GetOpDef(&status);
  if (!status.ok()) {
    return status;
  }
  AttrValueMap attrs;
  operation_.Attrs().FillAttrValueMap(&attrs);
  NameRangeMap name_ranges;
  TF_RETURN_IF_ERROR(
      NameRangesForNode(AttrSlice(&attrs), *op_def, &name_ranges, nullptr));
  auto iter = name_ranges.find(input_name);
  if (iter == name_ranges.end()) {
    return errors::InvalidArgument("Input '", input_name, "' not found");
  }
  *length = iter->second.second - iter->second.first;
  return Status::OK();
}

Status OperationInterface::OutputLength(const char* output_name, int* length) {
  Status status;
  const tensorflow::OpDef* op_def = GetOpDef(&status);
  if (!status.ok()) {
    return status;
  }
  AttrValueMap attrs;
  operation_.Attrs().FillAttrValueMap(&attrs);
  NameRangeMap name_ranges;
  TF_RETURN_IF_ERROR(
      NameRangesForNode(AttrSlice(&attrs), *op_def, nullptr, &name_ranges));
  auto iter = name_ranges.find(output_name);
  if (iter == name_ranges.end()) {
    return errors::InvalidArgument("Output '", output_name, "' not found");
  }
  *length = iter->second.second - iter->second.first;
  return Status::OK();
}
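
These two methods back the input/output length queries in the eager C API. A hedged sketch of the call pattern follows, assuming the `TFE_OpGetInputLength`/`TFE_OpGetOutputLength` accessors exercised by the `TestTFE_OpGetInputAndOutputLengths*` tests above; `input1` and `input2` are hypothetical handles.

// Hedged sketch: querying argument lengths through the C API.
TFE_Op* identity_n = TFE_NewOp(ctx, "IdentityN", status);
TFE_TensorHandle* inputs[] = {input1, input2};  // hypothetical handles
TFE_OpAddInputList(identity_n, inputs, 2, status);
// IdentityN declares list arguments named "input" and "output".
int input_len = TFE_OpGetInputLength(identity_n, "input", status);    // 2
int output_len = TFE_OpGetOutputLength(identity_n, "output", status);  // 2
TFE_DeleteOp(identity_n);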

Status OperationInterface::AddInput(
    const std::unique_ptr<AbstractTensorHandleInterface>& input) {
  TensorHandle* h =
      tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
  operation_.AddInput(h);
  return operation_.MaybeInferSingleInputAttrs(h);
}

Status OperationInterface::AddInputList(
    const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
        inputs) {
  for (auto& input : inputs) {
    TensorHandle* h =
        tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
    operation_.AddInput(h);
  }
  return operation_.InferInputListAttrs(inputs.size());
}

Status OperationInterface::Execute(
    absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
    int* num_retvals) {
  absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
  TF_RETURN_IF_ERROR(
      EagerExecute(&operation_, handle_retvals.data(), num_retvals));
  for (int i = 0; i < *num_retvals; ++i) {
    retvals->at(i).reset(
        new tensorflow::TensorHandleInterface(handle_retvals[i]));
  }
  return Status::OK();
}

Status OperationInterface::SetCancellationManager(
    TFE_CancellationManager* cancellation_manager) {
  operation_.SetCancellationManager(
      &cancellation_manager->cancellation_manager);
  return Status::OK();
}

Status OperationInterface::SetUseXla(bool enable) {
  operation_.SetUseXla(enable);
  return Status::OK();
}

}  // namespace tensorflow

188 tensorflow/c/eager/operation_interface.h Normal file
@ -0,0 +1,188 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
#define TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_

#include <memory>

#include "absl/container/fixed_array.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"

// Abstract interface to an operation.
class AbstractOperationInterface {
 public:
  virtual ~AbstractOperationInterface() {}

  virtual void Clear() = 0;
  virtual tensorflow::Status Reset(const char* op,
                                   const char* raw_device_name) = 0;

  virtual const tensorflow::string& Name() const = 0;
  virtual const tensorflow::string& DeviceName() const = 0;
  virtual tensorflow::Status SetDeviceName(const char* name) = 0;

  virtual tensorflow::Status AddInput(
      const std::unique_ptr<AbstractTensorHandleInterface>& input) = 0;
  virtual tensorflow::Status AddInputList(
      const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
          inputs) = 0;
  virtual tensorflow::Status Execute(
      absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
      int* num_retvals) = 0;
  virtual const tensorflow::OpDef* OpDef() const = 0;

  virtual tensorflow::Status SetAttrString(const char* attr_name,
                                           const char* data, size_t length) = 0;
  virtual tensorflow::Status SetAttrInt(const char* attr_name,
                                        int64_t value) = 0;
  virtual tensorflow::Status SetAttrFloat(const char* attr_name,
                                          float value) = 0;
  virtual tensorflow::Status SetAttrBool(const char* attr_name, bool value) = 0;
  virtual tensorflow::Status SetAttrType(const char* attr_name,
                                         TF_DataType value) = 0;
  virtual tensorflow::Status SetAttrShape(const char* attr_name,
                                          const int64_t* dims,
                                          const int num_dims) = 0;
  virtual tensorflow::Status SetAttrFunction(
      const char* attr_name,
      const std::unique_ptr<AbstractOperationInterface>& value) = 0;
  virtual tensorflow::Status SetAttrFunctionName(const char* attr_name,
                                                 const char* value,
                                                 size_t length) = 0;
  virtual tensorflow::Status SetAttrTensor(const char* attr_name,
                                           TF_Tensor* tensor) = 0;
  virtual tensorflow::Status SetAttrStringList(const char* attr_name,
                                               const void* const* values,
                                               const size_t* lengths,
                                               int num_values) = 0;
  virtual tensorflow::Status SetAttrFloatList(const char* attr_name,
                                              const float* values,
                                              int num_values) = 0;
  virtual tensorflow::Status SetAttrIntList(const char* attr_name,
                                            const int64_t* values,
                                            int num_values) = 0;
  virtual tensorflow::Status SetAttrTypeList(const char* attr_name,
                                             const TF_DataType* values,
                                             int num_values) = 0;
  virtual tensorflow::Status SetAttrBoolList(const char* attr_name,
                                             const unsigned char* values,
                                             int num_values) = 0;
  virtual tensorflow::Status SetAttrShapeList(const char* attr_name,
                                              const int64_t** dims,
                                              const int* num_dims,
                                              int num_values) = 0;
  virtual tensorflow::Status SetAttrFunctionList(const char* attr_name,
                                                 const TFE_Op** value,
                                                 int num_values) = 0;

  virtual tensorflow::Status InputLength(const char* input_name,
                                         int* length) = 0;
  virtual tensorflow::Status OutputLength(const char* output_name,
                                          int* length) = 0;

  // Experimental
  virtual tensorflow::Status SetUseXla(bool enable) {
    return tensorflow::errors::Unimplemented("SetUseXla not implemented");
  }
  virtual tensorflow::Status SetCancellationManager(
      TFE_CancellationManager* cancellation_manager) {
    return tensorflow::errors::Unimplemented(
        "SetCancellationManager not implemented");
  }
};

namespace tensorflow {

class OpDef;

class OperationInterface : public AbstractOperationInterface {
 public:
  explicit OperationInterface(TFE_Context* ctx);
  ~OperationInterface() override {}

  void Clear() override { operation_.Clear(); }
  Status Reset(const char* op, const char* raw_device_name) override {
    return operation_.Reset(op, raw_device_name, false, nullptr);
  }

  const string& Name() const override { return operation_.Name(); }
  const string& DeviceName() const override;
  Status SetDeviceName(const char* name) override;

  Status AddInput(
      const std::unique_ptr<AbstractTensorHandleInterface>& input) override;
  Status AddInputList(
      const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
          inputs) override;
  Status Execute(
      absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
      int* num_retvals) override;
  const tensorflow::OpDef* OpDef() const override {
    return operation_.OpDef();
  }

  Status SetAttrString(const char* attr_name, const char* data,
                       size_t length) override;
  Status SetAttrInt(const char* attr_name, int64_t value) override;
  Status SetAttrFloat(const char* attr_name, float value) override;
  Status SetAttrBool(const char* attr_name, bool value) override;
  Status SetAttrType(const char* attr_name, TF_DataType value) override;
  Status SetAttrShape(const char* attr_name, const int64_t* dims,
                      const int num_dims) override;
  Status SetAttrFunction(
      const char* attr_name,
      const std::unique_ptr<AbstractOperationInterface>& value) override;
  Status SetAttrFunctionName(const char* attr_name, const char* data,
                             size_t length) override;
  Status SetAttrTensor(const char* attr_name, TF_Tensor* tensor) override;
  Status SetAttrStringList(const char* attr_name, const void* const* values,
                           const size_t* lengths, int num_values) override;
  Status SetAttrFloatList(const char* attr_name, const float* values,
                          int num_values) override;
  Status SetAttrIntList(const char* attr_name, const int64_t* values,
                        int num_values) override;
  Status SetAttrTypeList(const char* attr_name, const TF_DataType* values,
                         int num_values) override;
  Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
                         int num_values) override;
  Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
                          const int* num_dims, int num_values) override;
  Status SetAttrFunctionList(const char* attr_name, const TFE_Op** value,
                             int num_values) override;

  Status InputLength(const char* input_name, int* length) override;
  Status OutputLength(const char* output_name, int* length) override;

  Status SetUseXla(bool enable) override;
  Status SetCancellationManager(
      TFE_CancellationManager* cancellation_manager) override;

  // TODO(gjn): Remove once TFE_InferShapes is removed
  const tensorflow::AttrBuilder& Attrs() const { return operation_.Attrs(); }
  tensorflow::AttrBuilder* MutableAttrs() { return operation_.MutableAttrs(); }

  const TensorHandle* GetInput(int i) const { return operation_.Inputs()[i]; }

 private:
  const tensorflow::OpDef* GetOpDef(Status* status);
  EagerOperation operation_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_

@ -55,6 +55,14 @@ class AbstractTensorHandleInterface {

  // Return a copy of the handle.
  virtual AbstractTensorHandleInterface* Copy() = 0;

  // Maintain mirror tensors for any implicit copies to local devices. This
  // setting is offered on a per tensor handle basis to avoid potential memory
  // over-utilization due to holding on to mirrors as well as the original
  // tensor. Note that this setting overrides the context mirroring policy:
  // even if the context policy is MIRRORING_NONE, this tensor will still be
  // mirrored.
  virtual void EnableImplicitMirroring() = 0;
};
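
A hedged sketch of the per-handle opt-in through the C API entry point used by the tests earlier in this change; it assumes an initialized `status` and the test helper from c_api_test_util.h.

// Hedged sketch: opt one handle into implicit mirroring, even when the
// context-level policy is MIRRORING_NONE.
TFE_TensorHandle* h = TestMatrixTensorHandle();
TFE_TensorHandleEnableImplicitMirroring(h, status);
// Later silent copies of `h` to another device keep a mirror alongside the
// original tensor instead of replacing the handle's underlying tensor.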

namespace tensorflow {

@ -77,6 +85,8 @@ class TensorHandleInterface : public AbstractTensorHandleInterface {

  AbstractTensorHandleInterface* Copy() override;

  void EnableImplicitMirroring() override;

  // TODO(gjn): This is not a very generic interface, but is needed for specific
  // use cases.
  TensorHandle* Handle() { return handle_; }

@ -1569,7 +1569,7 @@ TEST_P(ModularFileSystemTest, TestRoundTrip) {
  if (!status.ok())
    GTEST_SKIP() << "NewRandomAccessFile() not supported: " << status;

  char scratch[64 /* big enough to accomodate test_data */] = {0};
  char scratch[64 /* big enough to accommodate test_data */] = {0};
  StringPiece result;
  status = read_file->Read(0, test_data.size(), &result, scratch);
  EXPECT_PRED2(UnimplementedOrReturnsCode, status, Code::OK);

@ -24,12 +24,16 @@ limitations under the License.

#include <memory>
#include <string>
#include <utility>

#include "absl/container/inlined_vector.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_base.h"

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h"

#include <memory>
#include <vector>

#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
@ -64,25 +65,41 @@ void deallocate_buffer(void* data, size_t len, void* arg) {
}
}  // namespace tensorflow

namespace {
TF_Tensor* CreateTensor(TF_ManagedBuffer* buf, TF_DataType dtype,
                        const int64_t* dims, int num_dims, size_t len) {
  std::vector<tensorflow::int64> dimvec(num_dims);
  for (int i = 0; i < num_dims; ++i) {
    dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
  }

  // TODO(gjn): Make the choice of interface a compile-time configuration.
  tensorflow::TensorInterface ret(
      Tensor(static_cast<tensorflow::DataType>(dtype),
             tensorflow::TensorShape(dimvec), buf));
  buf->Unref();
  size_t elem_size = TF_DataTypeSize(dtype);
  if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
    return nullptr;
  }
  return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
}
}  // namespace

TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
                             int num_dims, size_t len) {
  void* data = tensorflow::allocate_tensor("TF_AllocateTensor", len,
                                           tensorflow::cpu_allocator());
  return TF_NewTensor(dtype, dims, num_dims, data, len,
                      tensorflow::deallocate_buffer,
                      tensorflow::cpu_allocator());
  TF_ManagedBuffer* buf =
      new TF_ManagedBuffer(data, len, tensorflow::deallocate_buffer,
                           tensorflow::cpu_allocator(), /*owns_memory=*/true);
  return CreateTensor(buf, dtype, dims, num_dims, len);
}

TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
                        void* data, size_t len,
                        void (*deallocator)(void* data, size_t len, void* arg),
                        void* deallocator_arg) {
  std::vector<tensorflow::int64> dimvec(num_dims);
  for (int i = 0; i < num_dims; ++i) {
    dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
  }

  TF_ManagedBuffer* buf = nullptr;
  if (dtype != TF_STRING && dtype != TF_RESOURCE &&
      tensorflow::DataTypeCanUseMemcpy(
@ -97,24 +114,17 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
    // Other types have the same representation, so copy only if it is safe to
    // do so.
    buf = new TF_ManagedBuffer(tensorflow::allocate_tensor("TF_NewTensor", len),
                               len, tensorflow::deallocate_buffer, nullptr);
                               len, tensorflow::deallocate_buffer, nullptr,
                               /*owns_memory=*/true);
    std::memcpy(buf->data(), data, len);
    // Free the original buffer.
    deallocator(data, len, deallocator_arg);
  } else {
    buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
    buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
                               /*owns_memory=*/false);
  }

  // TODO(gjn): Make the choice of interface a compile-time configuration.
  tensorflow::TensorInterface ret(
      Tensor(static_cast<tensorflow::DataType>(dtype),
             tensorflow::TensorShape(dimvec), buf));
  buf->Unref();
  size_t elem_size = TF_DataTypeSize(dtype);
  if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
    return nullptr;
  }
  return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
  return CreateTensor(buf, dtype, dims, num_dims, len);
}

TF_Tensor* TF_TensorMaybeMove(TF_Tensor* t) {

@ -38,11 +38,12 @@ class TF_ManagedBuffer : public tensorflow::TensorBuffer {
 public:
  TF_ManagedBuffer(void* data, size_t len,
                   void (*deallocator)(void* data, size_t len, void* arg),
                   void* deallocator_arg)
                   void* deallocator_arg, bool owns_memory)
      : TensorBuffer(data),
        len_(len),
        deallocator_(deallocator),
        deallocator_arg_(deallocator_arg) {}
        deallocator_arg_(deallocator_arg),
        owns_memory_(owns_memory) {}

  ~TF_ManagedBuffer() override {
    (*deallocator_)(data(), len_, deallocator_arg_);
@ -57,13 +58,13 @@ class TF_ManagedBuffer : public tensorflow::TensorBuffer {
    proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
  }

  // Prevents input forwarding from mutating this buffer.
  bool OwnsMemory() const override { return false; }
  bool OwnsMemory() const override { return owns_memory_; }

 private:
  const size_t len_;
  void (*const deallocator_)(void* data, size_t len, void* arg);
  void* const deallocator_arg_;
  bool owns_memory_;
};
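
The new flag feeds the forwarding decision: only buffers the runtime allocated itself report OwnsMemory() == true and may be reused in place for op outputs, while client-provided memory stays protected. A hedged construction sketch, assuming a `len` already computed; `client_data`, `client_deallocator`, and `client_arg` are hypothetical names.

// Runtime-owned buffer: eligible to be forwarded into an op's output.
void* data =
    tensorflow::allocate_tensor("sketch", len, tensorflow::cpu_allocator());
auto* runtime_buf =
    new TF_ManagedBuffer(data, len, tensorflow::deallocate_buffer,
                         tensorflow::cpu_allocator(), /*owns_memory=*/true);
// Client-owned buffer: OwnsMemory() is false, so it is never mutated in place.
auto* client_buf = new TF_ManagedBuffer(client_data, len, client_deallocator,
                                        client_arg, /*owns_memory=*/false);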
|
||||
|
||||
namespace tensorflow {

@ -41,6 +41,16 @@ filegroup(
],
)

filegroup(
name = "pywrap_required_hdrs",
srcs = [
"training/coordinator.h",
],
visibility = [
"//tensorflow/python:__pkg__",
],
)

cc_library(
name = "gradients",
srcs = [

@ -15,13 +15,12 @@ limitations under the License.

#include <vector>

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/lib/strings/strcat.h"

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

namespace tensorflow {
namespace ops {
namespace {

@ -90,15 +89,25 @@ Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);

Status QuantizeAndDequantizeV2Grad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
grad_outputs->push_back(Identity(scope, grad_inputs[0]));
grad_outputs->push_back(NoGradient());
grad_outputs->push_back(NoGradient());
Status QuantizeAndDequantizeV2GradHelper(const Scope& scope,
const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
Input input = Shape(scope, op.input(0));
Input input_min = op.input(1);
Input input_max = op.input(2);
int64 axis;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
auto qdq_v2_grad = QuantizeAndDequantizeV2Grad(
scope, grad_inputs[0], input, input_min, input_max,
QuantizeAndDequantizeV2Grad::Axis(axis));
grad_outputs->push_back(qdq_v2_grad.input_backprop);
grad_outputs->push_back(qdq_v2_grad.input_min_backprop);
grad_outputs->push_back(qdq_v2_grad.input_max_backprop);
return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", QuantizeAndDequantizeV2Grad);
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2",
QuantizeAndDequantizeV2GradHelper);
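
The helper above follows the standard gradient-registration pattern: a function keyed by op name is looked up at graph-construction time, reads op attributes (here "axis"), and pushes one backprop per input. A self-contained C++ sketch of that registry shape, with hypothetical names, not TensorFlow's real gradient API:

#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Sketch of what REGISTER_GRADIENT_OP does: gradient functions are stored
// in a registry keyed by op name. Everything here is illustrative.
using GradFn = std::function<void(const std::vector<double>& grad_in,
                                  std::vector<double>* grad_out)>;

std::map<std::string, GradFn>& GradRegistry() {
  static std::map<std::string, GradFn> registry;
  return registry;
}

struct GradRegistrar {
  GradRegistrar(const std::string& op, GradFn fn) {
    GradRegistry().emplace(op, std::move(fn));
  }
};

// Analogous to the helper above: the data input gets the incoming gradient
// forwarded; input_min/input_max get their own backprops (zeros here).
static GradRegistrar qdq_v2("QuantizeAndDequantizeV2",
                            [](const std::vector<double>& g,
                               std::vector<double>* out) {
                              out->assign({g[0], 0.0, 0.0});
                            });

int main() {
  std::vector<double> out;
  GradRegistry()["QuantizeAndDequantizeV2"]({2.5}, &out);
  std::cout << out[0] << " " << out[1] << " " << out[2] << "\n";  // 2.5 0 0
}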

Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,

@ -68,6 +68,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/platform:resource_loader",
],
)

@ -224,3 +225,15 @@ filegroup(
"testdata/VarsAndArithmeticObjectGraph/**",
]),
)

exports_files(
glob([
"testdata/half_plus_two_pbtxt/**",
"testdata/half_plus_two_main_op/**",
"testdata/half_plus_two/**",
"testdata/half_plus_two_v2/**",
"testdata/x_plus_y_v2_debuginfo/**",
"testdata/CyclicModule/**",
"testdata/VarsAndArithmeticObjectGraph/**",
]),
)

@ -21,15 +21,22 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

constexpr char kTestDataPbTxt[] =
"cc/saved_model/testdata/half_plus_two_pbtxt/00000123";
constexpr char kTestDataSharded[] =
"cc/saved_model/testdata/half_plus_two/00000123";
string TestDataPbTxt() {
return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
"half_plus_two_pbtxt", "00000123");
}

string TestDataSharded() {
return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
"half_plus_two", "00000123");
}

class ReaderTest : public ::testing::Test {
protected:

@ -49,8 +56,7 @@ class ReaderTest : public ::testing::Test {
TEST_F(ReaderTest, TagMatch) {
MetaGraphDef meta_graph_def;

const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
&meta_graph_def));
CheckMetaGraphDef(meta_graph_def);

@ -59,8 +65,7 @@ TEST_F(ReaderTest, TagMatch) {
TEST_F(ReaderTest, NoTagMatch) {
MetaGraphDef meta_graph_def;

const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"},
&meta_graph_def);
EXPECT_FALSE(st.ok());

@ -73,8 +78,7 @@ TEST_F(ReaderTest, NoTagMatch) {
TEST_F(ReaderTest, NoTagMatchMultiple) {
MetaGraphDef meta_graph_def;

const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
Status st = ReadMetaGraphDefFromSavedModel(
export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);
EXPECT_FALSE(st.ok());

@ -87,8 +91,7 @@ TEST_F(ReaderTest, NoTagMatchMultiple) {
TEST_F(ReaderTest, PbtxtFormat) {
MetaGraphDef meta_graph_def;

const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPbTxt);
const string export_dir = GetDataDependencyFilepath(TestDataPbTxt());
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
&meta_graph_def));
CheckMetaGraphDef(meta_graph_def);

@ -97,8 +100,7 @@ TEST_F(ReaderTest, PbtxtFormat) {
TEST_F(ReaderTest, InvalidExportPath) {
MetaGraphDef meta_graph_def;

const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path");
const string export_dir = GetDataDependencyFilepath("missing-path");
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
&meta_graph_def);
EXPECT_FALSE(st.ok());
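
These hunks all make the same substitution: test data is resolved through the runfiles tree provided by the test runner instead of a hard-coded source root. A rough standalone C++ sketch of such a helper, assuming Bazel's conventional TEST_SRCDIR environment variable (the helper name mirrors the one above, but the implementation is illustrative, not TensorFlow's):

#include <cstdlib>
#include <filesystem>
#include <iostream>
#include <string>

// Resolve a workspace-relative path against the runfiles root the test
// runner exports, falling back to the current directory for local runs.
std::string GetDataDependencyFilepath(const std::string& relative) {
  const char* srcdir = std::getenv("TEST_SRCDIR");  // Bazel convention
  std::filesystem::path root = srcdir ? srcdir : ".";
  return (root / relative).string();
}

int main() {
  std::cout << GetDataDependencyFilepath(
                   "tensorflow/cc/saved_model/testdata/half_plus_two/00000123")
            << "\n";
}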

@ -20,9 +20,11 @@ from __future__ import print_function as _print_function

import logging as _logging
import os as _os
import six as _six
import sys as _sys

from tensorflow.python.tools import module_util as _module_util
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader

# pylint: disable=g-bad-import-order

@ -36,20 +38,19 @@ try:
from tensorboard.summary._tf import summary
_current_module.__path__ = (
[_module_util.get_parent_dir(summary)] + _current_module.__path__)
# Make sure we get the correct summary module with lazy loading
setattr(_current_module, "summary", summary)
except ImportError:
_logging.warning(
"Limited tf.compat.v2.summary API due to missing TensorBoard "
"installation.")

try:
from tensorflow_estimator.python.estimator.api._v2 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
setattr(_current_module, "estimator", estimator)
except ImportError:
pass
# Lazy-load estimator.
_estimator_module = "tensorflow_estimator.python.estimator.api._v2.estimator"
estimator = _LazyLoader("estimator", globals(), _estimator_module)
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
if _module_dir:
_current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "estimator", estimator)

try:
from tensorflow.python.keras.api._v2 import keras

@ -59,6 +60,13 @@ try:
except ImportError:
pass

# Explicitly import lazy-loaded modules to support autocompletion.
# pylint: disable=g-import-not-at-top
if not _six.PY2:
import typing as _typing
if _typing.TYPE_CHECKING:
from tensorflow_estimator.python.estimator.api._v2 import estimator
# pylint: enable=g-import-not-at-top

# We would like the following to work for fully enabling 2.0 in a 1.0 install:
#

@ -20,8 +20,10 @@ from __future__ import print_function as _print_function

import os as _os
import sys as _sys
import six as _six

from tensorflow.python.tools import module_util as _module_util
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader

# pylint: disable=g-bad-import-order

@ -31,13 +33,14 @@ from tensorflow.python.tools import module_util as _module_util

# Hook external TensorFlow modules.
_current_module = _sys.modules[__name__]
try:
from tensorflow_estimator.python.estimator.api._v1 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
setattr(_current_module, "estimator", estimator)
except ImportError:
pass

# Lazy-load estimator.
_estimator_module = "tensorflow_estimator.python.estimator.api._v1.estimator"
estimator = _LazyLoader("estimator", globals(), _estimator_module)
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
if _module_dir:
_current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "estimator", estimator)

try:
from tensorflow.python.keras.api._v1 import keras

@ -47,6 +50,14 @@ try:
except ImportError:
pass

# Explicitly import lazy-loaded modules to support autocompletion.
# pylint: disable=g-import-not-at-top
if not _six.PY2:
import typing as _typing
if _typing.TYPE_CHECKING:
from tensorflow_estimator.python.estimator.api._v1 import estimator
# pylint: enable=g-import-not-at-top

from tensorflow.python.platform import flags  # pylint: disable=g-import-not-at-top
_current_module.app.flags = flags  # pylint: disable=undefined-variable

@ -84,6 +84,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:resource_loader",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",  # fixdeps: keep
"@llvm-project//llvm:x86_code_gen",  # fixdeps: keep

@ -134,54 +135,108 @@ cc_library(
# tfcompile.bzl correctly handles usage from outside of the package that it is
# defined in.

# A simple test of tf_library from a text protobuf, mostly to enable the
# benchmark_test.
# A simple test of tf_library from a text protobuf, to enable benchmark_test.
# This test uses an incomplete graph with a node that is not defined. The
# compilation works because the undefined node is a feed node.
tf_library(
name = "test_graph_tfadd",
testonly = 1,
config = "test_graph_tfadd.config.pbtxt",
cpp_class = "AddComp",
graph = "test_graph_tfadd.pbtxt",
mlir_components = "None",
tags = [
"manual",
],
)

tf_library(
name = "test_graph_tfadd_mlir_bridge",
testonly = 1,
config = "test_graph_tfadd.config.pbtxt",
cpp_class = "AddComp",
graph = "test_graph_tfadd.pbtxt",
mlir_components = "Bridge",
tags = [
"manual",
],
)

# A test of tf_library that includes a graph with an unknown op, but where
# the compilation works because the unknown op is not needed for the fetches.
# the compilation works because the node with the unknown op is not needed
# for the fetches.
tf_library(
name = "test_graph_tfunknownop",
testonly = 1,
config = "test_graph_tfunknownop.config.pbtxt",
cpp_class = "UnknownOpAddComp",
graph = "test_graph_tfunknownop.pbtxt",
mlir_components = "None",
tags = [
"manual",
],
)

tf_library(
name = "test_graph_tfunknownop_mlir_bridge",
testonly = 1,
config = "test_graph_tfunknownop.config.pbtxt",
cpp_class = "UnknownOpAddComp",
graph = "test_graph_tfunknownop.pbtxt",
mlir_components = "Bridge",
tags = [
"manual",
],
)

# A test of tf_library that includes a graph with an unknown op, but where
# the compilation works because the op between the unknown op and the
# fetches is a feed.
# the compilation works because the node with the unknown op is only used as
# an input of a feed node.
tf_library(
name = "test_graph_tfunknownop2",
testonly = 1,
config = "test_graph_tfunknownop2.config.pbtxt",
cpp_class = "UnknownOpAddComp",
graph = "test_graph_tfunknownop.pbtxt",
mlir_components = "None",
tags = [
"manual",
],
)

tf_library(
name = "test_graph_tfunknownop2_mlir_bridge",
testonly = 1,
config = "test_graph_tfunknownop2.config.pbtxt",
cpp_class = "UnknownOpAddComp",
graph = "test_graph_tfunknownop.pbtxt",
mlir_components = "Bridge",
tags = [
"manual",
],
)

# A test of tf_library that includes a graph with an unknown op, but where
# the compilation works because the unknown op is fed.
# the compilation works because the node with the unknown op is a feed node.
tf_library(
name = "test_graph_tfunknownop3",
testonly = 1,
config = "test_graph_tfunknownop3.config.pbtxt",
cpp_class = "UnknownOpAddComp",
graph = "test_graph_tfunknownop.pbtxt",
mlir_components = "None",
tags = [
"manual",
],
)

tf_library(
name = "test_graph_tfunknownop3_mlir_bridge",
testonly = 1,
config = "test_graph_tfunknownop3.config.pbtxt",
cpp_class = "UnknownOpAddComp",
graph = "test_graph_tfunknownop.pbtxt",
mlir_components = "Bridge",
tags = [
"manual",
],

@ -261,9 +316,13 @@ test_suite(
tests = [
":benchmark_test",
":codegen_test",
":test_graph_tfadd_mlir_bridge_test",
":test_graph_tfadd_test",
":test_graph_tfunknownop2_mlir_bridge_test",
":test_graph_tfunknownop2_test",
":test_graph_tfunknownop3_mlir_bridge_test",
":test_graph_tfunknownop3_test",
":test_graph_tfunknownop_mlir_bridge_test",
":test_graph_tfunknownop_test",
"//tensorflow/compiler/aot/tests:all_tests",
],

@ -15,6 +15,7 @@ limitations under the License.

#include "tensorflow/compiler/aot/codegen.h"

#include <algorithm>
#include <string>
#include <vector>

@ -29,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {

@ -139,23 +141,40 @@ TEST_F(ParseCppClassTest, ParseFail) {

static void CompareWithGoldenFile(
const string& tensorflow_relative_golden_file_name,
const string& expected_contents) {
const string& expected_contents, bool ignore_cr) {
// Get rid of all CR characters; we may be running under Windows.
string sanitized_expected_contents(expected_contents);
if (ignore_cr) {
sanitized_expected_contents.erase(
std::remove(sanitized_expected_contents.begin(),
sanitized_expected_contents.end(), '\r'),
sanitized_expected_contents.end());
}

// To update the golden file, flip update_golden to true and run the
// following:
// bazel test --test_strategy=local \
// third_party/tensorflow/compiler/aot:codegen_test
const bool update_golden = false;
const string golden_file_name = io::JoinPath(
testing::TensorFlowSrcRoot(), tensorflow_relative_golden_file_name);
string golden_file_name;

if (update_golden) {
golden_file_name = io::JoinPath(testing::TensorFlowSrcRoot(),
tensorflow_relative_golden_file_name);
TF_EXPECT_OK(
WriteStringToFile(Env::Default(), golden_file_name, expected_contents));
}

golden_file_name =
GetDataDependencyFilepath(tensorflow_relative_golden_file_name);
string golden_file_contents;
TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name,
&golden_file_contents));
if (ignore_cr) {
golden_file_contents.erase(std::remove(golden_file_contents.begin(),
golden_file_contents.end(), '\r'),
golden_file_contents.end());
}
EXPECT_EQ(golden_file_contents, expected_contents);
}
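
The ignore_cr handling above is the classic erase/remove idiom, applied so golden-file comparisons do not depend on whether a Windows checkout introduced CRLF line endings. A minimal sketch:

#include <algorithm>
#include <cassert>
#include <string>

// Drop every '\r' so CRLF and LF inputs compare equal.
std::string StripCarriageReturns(std::string s) {
  s.erase(std::remove(s.begin(), s.end(), '\r'), s.end());
  return s;
}

int main() {
  assert(StripCarriageReturns("a\r\nb\r\n") == "a\nb\n");
}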

@ -229,14 +248,18 @@ TEST(CodegenTest, Golden) {
// The other fields in metadata_result are tested as part of the generated
// header test.

CompareWithGoldenFile("compiler/aot/codegen_test_o.golden",
metadata_result.object_file_data);
// This specific golden test checks a binary file. It can potentially run into
// issues due to ABIs not being stable, but has not so far.
// If we see any ABI issues, we should reconsider this specific test case.
CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_o.golden",
metadata_result.object_file_data, false);

string header;
TF_ASSERT_OK(
GenerateHeader(opts, config, compile_result, metadata_result, &header));

CompareWithGoldenFile("compiler/aot/codegen_test_h.golden", header);
CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_h.golden", header,
true);
}
} // namespace
} // namespace tfcompile

@ -107,12 +107,11 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
if (flags.mlir_components == "Bridge") {
TF_RETURN_IF_ERROR(
ConvertGraphDefToXlaViaMlir(graph_def, config, &computation));
} else {
if (!flags.mlir_components.empty()) {
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
}
} else if (flags.mlir_components.empty() || flags.mlir_components == "None") {
TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
client, &computation));
} else {
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
}
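
The restructured chain above treats an empty flag and "None" as the same non-MLIR path and rejects every other value up front, so new spellings fail loudly instead of silently falling through. A standalone C++ sketch of that dispatch shape (illustrative names, not the real tfcompile types):

#include <iostream>
#include <string>

enum class Lowering { kMlirBridge, kClassic };

// "" and "None" select the classic path; "Bridge" selects MLIR; anything
// else is an error rather than a silent default.
bool SelectLowering(const std::string& mlir_components, Lowering* out,
                    std::string* error) {
  if (mlir_components == "Bridge") {
    *out = Lowering::kMlirBridge;
  } else if (mlir_components.empty() || mlir_components == "None") {
    *out = Lowering::kClassic;
  } else {
    *error = "Unknown mlir_components " + mlir_components;
    return false;
  }
  return true;
}

int main() {
  Lowering l;
  std::string err;
  std::cout << SelectLowering("Bridge", &l, &err) << "\n";        // 1
  std::cout << SelectLowering("Foo", &l, &err) << " " << err << "\n";
}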
if (!flags.out_session_module.empty()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,

@ -98,6 +98,7 @@ tf_library(
# compile but the others in this directory succeed, you may need to
# expand the "required by all tf_library targets" list in tfcompile.bzl.
include_standard_runtime_deps = False,
mlir_components = "None",
tags = [
"manual",
],

@ -110,6 +111,7 @@ tf_library(
cpp_class = "AddWithCkptComp",
freeze_checkpoint = "test_graph_tfadd_with_ckpt.ckpt",
graph = "test_graph_tfadd_with_ckpt.pb",
mlir_components = "None",
tags = [
"manual",
],

@ -123,6 +125,7 @@ tf_library(
freeze_checkpoint = "test_graph_tfadd_with_ckpt_saver.ckpt",
freeze_saver = "test_graph_tfadd_with_ckpt_saver.saver",
graph = "test_graph_tfadd_with_ckpt_saver.pb",
mlir_components = "None",
tags = [
"manual",
],

@ -134,6 +137,7 @@ tf_library(
config = "test_graph_tfassert_eq.config.pbtxt",
cpp_class = "AssertComp",
graph = "test_graph_tfassert_eq.pb",
mlir_components = "None",
tags = [
"manual",
],

@ -145,6 +149,7 @@ tf_library(
config = "test_graph_tfcond.config.pbtxt",
cpp_class = "CondComp",
graph = "test_graph_tfcond.pb",
mlir_components = "None",
tags = [
"manual",
],

@ -156,6 +161,7 @@ tf_library(
config = "test_graph_tffunction.config.pbtxt",
cpp_class = "FunctionComp",
graph = "test_graph_tffunction.pb",
mlir_components = "None",
tags = [
"manual",
],

@ -167,6 +173,7 @@ tf_library(
config = "test_graph_tfgather.config.pbtxt",
cpp_class = "GatherComp",
graph = "test_graph_tfgather.pb",
mlir_components = "None",
tags = [
"manual",
],

@ -178,6 +185,7 @@ tf_library(
config = "test_graph_tfmatmul.config.pbtxt",
cpp_class = "foo::bar::MatMulComp",
graph = "test_graph_tfmatmul.pb",
mlir_components = "None",
tags = [
"manual",
],

@ -189,6 +197,7 @@ tf_library(
config = "test_graph_tfmatmulandadd.config.pbtxt",
cpp_class = "::foo::bar::MatMulAndAddComp",
graph = "test_graph_tfmatmulandadd.pb",
mlir_components = "None",
tags = [
"manual",
],

@ -202,6 +211,7 @@ tf_library(
cpp_class = "MatMulAndAddCompWithProfiling",
enable_xla_hlo_profiling = True,
graph = "test_graph_tfmatmulandadd.pb",
mlir_components = "None",
tags = [
"manual",
],

@ -213,6 +223,7 @@ tf_library(
config = "test_graph_tfsplits.config.pbtxt",
cpp_class = "SplitsComp",
graph = "test_graph_tfsplits.pb",
mlir_components = "None",
tags = [
"manual",
],

@ -224,6 +235,7 @@ tf_library(
config = "test_graph_tftop_k.config.pbtxt",
cpp_class = "TopKComp",
graph = "test_graph_tftop_k.pb",
mlir_components = "None",
tags = [
"manual",
],

@ -235,6 +247,7 @@ tf_library(
config = "test_graph_tfvariable.config.pbtxt",
cpp_class = "VariableComp",
graph = "test_graph_tfvariable.pb",
mlir_components = "None",
tags = [
"manual",
],

@ -246,6 +259,7 @@ tf_library(
config = "test_graph_tfvariable_readonly.config.pbtxt",
cpp_class = "VariableReadonlyComp",
graph = "test_graph_tfvariable_readonly.pb",
mlir_components = "None",
tags = [
"manual",
],

@ -257,6 +271,7 @@ tf_library(
config = "test_graph_tfvariable_sequential_updates.config.pbtxt",
cpp_class = "VariableSequentialUpdatesComp",
graph = "test_graph_tfvariable_sequential_updates.pb",
mlir_components = "None",
tags = [
"manual",
],

@ -349,6 +364,18 @@ tf_library(
],
)

tf_library(
name = "test_graph_tffunction_mlir_bridge",
testonly = 1,
config = "test_graph_tffunction.config.pbtxt",
cpp_class = "FunctionComp",
graph = "test_graph_tffunction.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)

tf_library(
name = "test_graph_tfassert_eq_mlir_bridge",
testonly = 1,

@ -484,6 +511,7 @@ tf_cc_test(
":test_graph_tfadd_with_ckpt_saver_mlir_bridge",
":test_graph_tfassert_eq_mlir_bridge",
":test_graph_tfcond_mlir_bridge",
":test_graph_tffunction_mlir_bridge",
":test_graph_tfgather_mlir_bridge",
":test_graph_tfmatmul_mlir_bridge",
":test_graph_tfmatmulandadd_mlir_bridge",

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfcond_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tffunction_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfgather_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_mlir_bridge.h"

@ -429,8 +430,6 @@ TEST(TFCompileTest, MatMulAndAdd1) {
}
}

// TODO(bixia): the following tests failed with MLIR bridge.
#if !defined(ENABLE_MLIR_BRIDGE_TEST)
TEST(TFCompileTest, Function) {
// The function is equivalent to an addition
FunctionComp add_fn;

@ -445,7 +444,6 @@ TEST(TFCompileTest, Function) {
EXPECT_EQ(add_fn.result0_data()[0], 3);
EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
}
#endif

TEST(TFCompileTest, Splits) {
Eigen::ThreadPool tp(1);

@ -37,7 +37,7 @@ def tf_library(
tfcompile_tool = "//tensorflow/compiler/aot:tfcompile",
include_standard_runtime_deps = True,
enable_xla_hlo_profiling = False,
mlir_components = None,
mlir_components = "None",
deps = None,
tags = []):
"""Runs tfcompile to compile a TensorFlow graph into executable code.

@ -88,8 +88,8 @@ def tf_library(
enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated
program, and emit metadata that lets us pretty-print the gathered
profile counters.
mlir_components: When the value is "Bridge", use MLIR to translate
GraphDef to HLO.
mlir_components: When the value is "None", no components use MLIR. When
the value is "Bridge", use MLIR to translate GraphDef to HLO.
deps: a list of deps to include on the build rules for the generated
library, added to the standard deps if standard_runtime_deps is True.
tags: tags to apply to subsidiary build rules.

@ -189,10 +189,7 @@ def tf_library(
else:
profiling_flag = ""

if mlir_components:
mlir_flag = "--mlir_components=" + mlir_components
else:
mlir_flag = ""
mlir_flag = "--mlir_components=" + mlir_components

native.genrule(
name = ("gen_" + name),

@ -159,7 +159,9 @@ XLA_DEVICE_DEPS = [
":common",
":xla_launch_util",
":xla_tensor",
"@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:optional",
"//tensorflow/compiler/jit/ops:xla_ops",

@ -266,9 +266,9 @@ bool RecursiveCompilabilityChecker::IsCompilableCall(
s = lib_runtime->Instantiate(function.name(), AttrSlice(&function.attr()),
&handle);
}

if (!s.ok()) {
std::string uncompilable_reason = "could not instantiate call";
std::string uncompilable_reason =
absl::StrCat("could not instantiate call: '", function.name(), "'");
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
encapsulating_function, uncompilable_nodes);
VLOG(2) << "Rejecting " << call_def.DebugString() << ": "

@ -676,12 +676,10 @@ Status Encapsulator::Subgraph::AddFunctionCallNode(
Status Encapsulator::GetFunctionNameAttr(Node const* node, string* attr) const {
AttrSlice attrs = node->attrs();
attr->clear();
bool found_group_attribute = false;
for (const auto& node_attr : attrs) {
if (node_attr.first == group_attribute_) {
TF_RETURN_IF_ERROR(AttrValueHasType(node_attr.second, "string"));
*attr = node_attr.second.s();
found_group_attribute = true;
break;
}
}

@ -790,7 +788,6 @@ Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) {

TF_RETURN_IF_ERROR(CopySubgraphNodes(&node_images));
TF_RETURN_IF_ERROR(CopySubgraphEdges(node_images, &src_arg_pairs));

MarkGuaranteedConstants(*graph_in_, src_arg_pairs);

for (auto& entry : subgraphs_) {

@ -108,7 +108,7 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
"(LRN, LRNGrad)."
" BN: TF FusedBatchNorm* operations."
" FUSIBLE: All TF operations that XLA can fuse (All the above). "
"You can also put any TF operation name, e.g. 'FUSIBLE,Matmul'."),
"You can also put any TF operation name, e.g. 'FUSIBLE,MatMul'."),
Flag("tf_xla_clustering_debug",
&mark_for_compilation_flags->tf_xla_clustering_debug,
"Dump graphs during XLA compilation."),

@ -186,6 +186,10 @@ void AllocateAndParseFlags() {
&build_ops_flags->tf_xla_check_cluster_output_numerics,
"If true then insert CheckNumerics nodes to to check all cluster "
|
||||
"outputs."),
|
||||
Flag("tf_xla_disable_constant_folding",
|
||||
&build_ops_flags->tf_xla_disable_constant_folding,
|
||||
"If true then disables constant folding on TF graph before XLA "
|
||||
"compilation."),
|
||||
|
||||
Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand,
|
||||
"Switch a device into 'on-demand' mode, where instead of "
|
||||
|
@ -20,6 +20,7 @@ XLA_OPS_DEPS = [
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
@ -41,6 +42,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
@ -206,12 +208,14 @@ se::DeviceMemoryAllocator* GetAllocator(
|
||||
XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
|
||||
const std::vector<int>& constants,
|
||||
const std::vector<int>& resources,
|
||||
const NameAttrList& function)
|
||||
const NameAttrList& function,
|
||||
bool has_ref_vars)
|
||||
: OpKernel(ctx),
|
||||
constants_(constants),
|
||||
resources_(resources),
|
||||
function_(function),
|
||||
platform_info_(PlatformInfoFromContext(ctx)) {}
|
||||
platform_info_(PlatformInfoFromContext(ctx)),
|
||||
has_ref_vars_(has_ref_vars) {}
|
||||
|
||||
static Status BuildCompilationCache(OpKernelContext* ctx,
|
||||
const XlaPlatformInfo& platform_info,
|
||||
@ -350,8 +354,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
|
||||
{
|
||||
Status s = CompileToLocalExecutable(
|
||||
ctx, function_, /*has_ref_vars=*/true, platform_info_, resources_,
|
||||
constants_, /*lazy=*/false, &client, &variables, &kernel, &executable);
|
||||
ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_,
|
||||
resources_, constants_, /*lazy=*/false, &client, &variables, &kernel,
|
||||
&executable);
|
||||
if (!s.ok() && (platform_info_.device_type().type_string() == DEVICE_CPU ||
|
||||
platform_info_.device_type().type_string() == DEVICE_GPU)) {
|
||||
// Suggest auto jit if the failure was with GPU or CPU.
|
||||
@ -384,6 +389,18 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
run_options.set_allocator(allocator);
|
||||
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
|
||||
run_options.set_rng_seed(GetXLARandomSeed());
|
||||
xla::ThenExecuteFunction then_execute;
|
||||
if (ctx->op_device_context()) {
|
||||
then_execute = [&](se::Stream* stream, std::function<void()> fn) {
|
||||
Status status = ctx->op_device_context()->ThenExecute(
|
||||
down_cast<Device*>(ctx->device()), stream, std::move(fn));
|
||||
if (!status.ok()) {
|
||||
// This should never happen.
|
||||
LOG(ERROR) << "ThenExecute failed " << status;
|
||||
}
|
||||
};
|
||||
run_options.set_then_execute_function(&then_execute);
|
||||
}
|
Env* env = Env::Default();
auto start_time = env->NowMicros();

@ -462,7 +479,7 @@ bool HasRefVars(OpKernelConstruction* ctx) {

XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
: XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
FunctionAttr(ctx)) {}
FunctionAttr(ctx), /*has_ref_vars=*/true) {}

XlaLocalLaunchOp::~XlaLocalLaunchOp() {
VLOG(1) << "XlaLocalLaunchOp destroyed";

@ -592,6 +609,18 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
run_options.set_allocator(allocator);
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
run_options.set_rng_seed(GetXLARandomSeed());
xla::ThenExecuteFunction then_execute;
if (ctx->op_device_context()) {
then_execute = [&](se::Stream* stream, std::function<void()> fn) {
Status status = ctx->op_device_context()->ThenExecute(
down_cast<Device*>(ctx->device()), stream, std::move(fn));
if (!status.ok()) {
// This should never happen.
LOG(ERROR) << "ThenExecute failed " << status;
}
};
run_options.set_then_execute_function(&then_execute);
}
Env* env = Env::Default();
auto start_time = env->NowMicros();

@ -95,12 +95,15 @@ class XlaPlatformInfo {
// in the GraphDef.
// Currently, it is used by eager runtime. FunctionLibraryRuntime creates
// this kernel when asked to create a kernel for an XLA-compiled function.
//
// `has_ref_vars`: whether the input computation can have reference variables.
// TODO(cheshire): instead derive this information from the input graph.
class XlaLocalLaunchBase : public OpKernel {
public:
XlaLocalLaunchBase(OpKernelConstruction* ctx,
const std::vector<int>& constants,
const std::vector<int>& resources,
const NameAttrList& function);
const NameAttrList& function, bool has_ref_vars);
XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
~XlaLocalLaunchBase() override = default;

@ -115,6 +118,8 @@ class XlaLocalLaunchBase : public OpKernel {

const NameAttrList function_;
const XlaPlatformInfo platform_info_;

bool has_ref_vars_;
};

// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph

@ -963,6 +963,22 @@ absl::optional<string> MarkForCompilationPassImpl::GetXlaScope(Node* node) {
return absl::nullopt;
}

// Returns true iff the attribute `attr_name` is attached to either the node or
// to its callee.
static bool GetNodeOrFuncAttr(Node* node, FunctionLibraryDefinition* flib_def,
const char* attr_name) {
bool out = false;
bool attr_value;
if (TryGetNodeAttr(node->attrs(), attr_name, &attr_value)) {
out |= attr_value;
}

if (flib_def->GetAttr(*node, attr_name, &attr_value).ok()) {
out |= attr_value;
}
return out;
}

Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
auto ignore_resource_ops = [&](const Node& n, bool* ignore) {
return IgnoreResourceOpForSafetyAnalysis(&device_info_cache_, n, ignore);

@ -1016,16 +1032,9 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
resource_var_operation_node_id = node->id();
}

bool is_xla_compile_attr_true = false;

bool xla_compile_attr;
if (TryGetNodeAttr(node->attrs(), kXlaCompileAttr, &xla_compile_attr)) {
is_xla_compile_attr_true |= xla_compile_attr;
}

if (flib_def_->GetAttr(*node, kXlaCompileAttr, &xla_compile_attr).ok()) {
is_xla_compile_attr_true |= xla_compile_attr;
}
bool is_xla_compile_attr_true =
GetNodeOrFuncAttr(node, flib_def_, kXlaCompileAttr) ||
GetNodeOrFuncAttr(node, flib_def_, kXlaMustCompileAttr);

DeviceSet devices;
devices.Insert(device);

@ -1874,6 +1883,8 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"EmptyTensorList",
"ExtractImagePatches",
"Igamma",
"IgammaGradA",
"RandomGammaGrad",
"Igammac",
"FFT",
"FFT2D",

@ -1996,6 +2007,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"StatelessRandomNormal",
"StatelessRandomUniform",
"StatelessRandomUniformInt",
"StatelessRandomUniformFullInt",
"StatelessTruncatedNormal",
"StatelessWhile",
"Svd",

@ -20,7 +20,9 @@ limitations under the License.
#include <unordered_set>
#include <utility>

#include "absl/base/call_once.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device_context.h"

@ -386,14 +388,33 @@ Status XlaDevice::TryGetDeviceContext(DeviceContext** out_context) {
return Status::OK();
}

// Warn about XLA_CPU/XLA_GPU exactly once.
static void ShowXlaDeviceDeprecationWarning(
absl::string_view compilation_device_name) {
static absl::once_flag once;
if (absl::StrContains(compilation_device_name, "CPU") ||
absl::StrContains(compilation_device_name, "GPU")) {
absl::call_once(once, [] {
LOG(WARNING)
<< "XLA_GPU and XLA_CPU devices are deprecated and will be "
"removed in subsequent releases. Instead, use either "
"@tf.function(experimental_compile=True) for must-compile "
"semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
"for auto-clustering best-effort compilation.";
});
}
}
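
The helper above uses absl::call_once so the deprecation warning fires at most once per process, no matter how many ops execute on the device. An equivalent warn-once sketch using only the standard library:

#include <iostream>
#include <mutex>
#include <string_view>

// First qualifying call logs; every later call is a no-op.
void MaybeWarnDeprecated(std::string_view device_name) {
  static std::once_flag once;
  if (device_name.find("CPU") != std::string_view::npos ||
      device_name.find("GPU") != std::string_view::npos) {
    std::call_once(once, [] {
      std::cerr << "XLA_GPU and XLA_CPU devices are deprecated.\n";
    });
  }
}

int main() {
  MaybeWarnDeprecated("XLA_GPU");  // logs
  MaybeWarnDeprecated("XLA_CPU");  // silent: already warned once
}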

void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":"
<< op_kernel->type_string();
ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
op_kernel->Compute(context);
}

void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) {
ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
<< op_kernel->type_string();
op_kernel->ComputeAsync(context, done);

@ -20,15 +20,17 @@ limitations under the License.

namespace tensorflow {

bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr,
const NodeDef& node_def) const {
return CanCreateXlaKernel(node_def);
bool XlaKernelCreator::CanCreateKernel(
const FunctionLibraryRuntime& flr,
const std::shared_ptr<const NodeProperties>& props) const {
return CanCreateXlaKernel(props->node_def);
}

Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr,
const NodeDef& node_def,
std::unique_ptr<OpKernel>* kernel) const {
return CreateXlaKernel(flr, node_def, kernel);
Status XlaKernelCreator::CreateKernel(
FunctionLibraryRuntime* flr,
const std::shared_ptr<const NodeProperties>& props,
std::unique_ptr<OpKernel>* kernel) const {
return CreateXlaKernel(flr, props->node_def, kernel);
}

namespace {

@ -29,11 +29,13 @@ class XlaKernelCreator : public CustomKernelCreator {
// Given a NodeDef 'node_def' and the function library runtime 'flr', returns
// true if 'node_def' is a call to a compilable function defined in 'flr',
// with the kXlaCompileAttr set.
bool CanCreateKernel(const FunctionLibraryRuntime& flr,
const NodeDef& node_def) const override;
bool CanCreateKernel(
const FunctionLibraryRuntime& flr,
const std::shared_ptr<const NodeProperties>& props) const override;

// Given a supported NodeDef, returns a XlaLaunchOp that computes the node.
Status CreateKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
Status CreateKernel(FunctionLibraryRuntime* flr,
const std::shared_ptr<const NodeProperties>& props,
std::unique_ptr<OpKernel>* kernel) const override;
};

@ -30,10 +30,12 @@ limitations under the License.

namespace tensorflow {

NodeDef ToNodeDef(const string& text) {
std::shared_ptr<NodeProperties> ToNodeProperties(const string& text) {
NodeDef node_def;
DataTypeVector dummy;
EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
return node_def;
return std::make_shared<NodeProperties>(nullptr, std::move(node_def), dummy,
dummy);
}
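
The test helper now builds a NodeProperties once and hands out a shared_ptr, matching the new kernel-creation interface that takes properties instead of a bare NodeDef. A toy C++ sketch of the same build-once/share-everywhere shape (NodeDef and NodeProperties here are stand-ins, not TensorFlow's real structs):

#include <iostream>
#include <memory>
#include <string>
#include <utility>

struct NodeDef {
  std::string name;
  std::string op;
};

// Immutable bundle built once; consumers share it by cheap refcounted handle.
struct NodeProperties {
  explicit NodeProperties(NodeDef def) : node_def(std::move(def)) {}
  const NodeDef node_def;
};

std::shared_ptr<const NodeProperties> ToNodeProperties(NodeDef def) {
  return std::make_shared<const NodeProperties>(std::move(def));
}

int main() {
  auto props = ToNodeProperties({"XTimesY", "XTimesY"});
  auto another_ref = props;  // sharing, not copying the definition
  std::cout << props->node_def.op << " refs=" << props.use_count() << "\n";
}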

// Create a FunctionDef that takes one resource and one regular param

@ -98,11 +100,11 @@ TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) {
(*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
Init({fdef});
XlaKernelCreator xla_kernel_creator;
NodeDef callsite =
ToNodeDef(R"pb(
auto callsite =
ToNodeProperties(R"pb(
name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
)pb");
(*callsite.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
(*(callsite->node_def.mutable_attr()))["_XlaMustCompile"] = BoolAttr(true);

// Note: need to set attribute on the created node.
Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_);

@ -127,13 +129,14 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) {
Init({fdef});
XlaKernelCreator xla_kernel_creator;

Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto(
name: 'XTimesY'
op: 'XTimesY'
input: 'a'
input: 'b'
)proto"),
&kernel_);
Status status =
xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto(
name: 'XTimesY'
op: 'XTimesY'
input: 'a'
input: 'b'
)proto"),
&kernel_);
EXPECT_TRUE(errors::IsInternal(status)) << status.ToString();
}

@ -143,13 +146,14 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) {
Init({fdef});
XlaKernelCreator xla_kernel_creator;

Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto(
name: 'XTimesY'
op: 'XTimesY'
input: 'a'
input: 'b'
)proto"),
&kernel_);
Status status =
xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto(
name: 'XTimesY'
op: 'XTimesY'
input: 'a'
input: 'b'
)proto"),
&kernel_);
EXPECT_TRUE(errors::IsInternal(status)) << status.ToString();
}

@ -104,7 +104,7 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
BackwardsConstAnalysis(*((*fbody)->graph), &const_args,
/*compile_time_const_nodes=*/nullptr, flr));

for (int i = 0; i < const_args.size(); ++i) {
for (size_t i = 0; i < const_args.size(); ++i) {
if (const_args[i]) {
constant_arg_indices->push_back(i);
}

@ -113,7 +113,7 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
// There can be hundreds of resource variables. Reserve the space for them.
// We don't reserve for constants above as they are usually few.
resource_arg_indices->reserve(arg_types.size());
for (int i = 0; i < arg_types.size(); ++i) {
for (size_t i = 0; i < arg_types.size(); ++i) {
if (arg_types[i] == DT_RESOURCE) {
resource_arg_indices->push_back(i);
}

@ -143,11 +143,11 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
}
string message = absl::StrCat(
"Function invoked by the following node is not compilable: ",
node_def.ShortDebugString(), ".\n");
absl::StrAppend(&message, "Uncompilable nodes:\n");
SummarizeNodeDef(node_def), ".\n");
absl::StrAppend(&message, "Uncompilable nodes:");
for (const auto& node_info : uncompilable_node_info) {
string node_message =
absl::StrCat("\t", node_info.name, ": ",
absl::StrCat("\n", node_info.name, ": ",
node_info.uncompilable_reason, "\n", "\tStacktrace:\n");
for (const auto& stack_frame : node_info.stack_trace) {
absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n",

@ -156,7 +156,6 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
absl::StrAppend(&message, node_message);
}
VLOG(1) << message;
// node_def is calling a function that XLA can't compile.
return errors::InvalidArgument(message);
}

@ -178,7 +177,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
// 214 variables and a similar number of activations.
SinglePassSearch constants_search(&constant_arg_indices);
SinglePassSearch resources_search(&resource_arg_indices);
for (int i = 0; i < fbody->arg_types.size(); ++i) {
for (size_t i = 0; i < fbody->arg_types.size(); ++i) {
if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
// Compile-time constants and resource handles are expected to be in
// host memory.

@ -208,7 +207,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
// XlaLaunch kernel keeps all outputs (including constants, which it copies),
// in device memory except for resources.
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
for (int i = 0; i < fbody->ret_types.size(); ++i) {
for (size_t i = 0; i < fbody->ret_types.size(); ++i) {
if (fbody->ret_types[i] == DT_RESOURCE) {
output_memory_types[i] = HOST_MEMORY;
}

@ -219,15 +218,17 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
Device* dev = flr->device();
Status s;
OpKernelConstruction construction(
DeviceType(dev->device_type()), dev,
dev->GetAllocator(AllocatorAttributes()), &node_def,
&fbody->fdef.signature(), flr, dev->resource_manager(), fbody->arg_types,
input_memory_types, fbody->ret_types, output_memory_types,
flr->graph_def_version(), &s);
auto props = std::make_shared<NodeProperties>(
&fbody->fdef.signature(), node_def, fbody->arg_types, fbody->ret_types);
OpKernelConstruction construction(DeviceType(dev->device_type()), dev,
dev->GetAllocator(AllocatorAttributes()),
flr, dev->resource_manager(), props,
input_memory_types, output_memory_types,
flr->graph_def_version(), &s);

*kernel = absl::make_unique<XlaLocalLaunchBase>(
&construction, constant_arg_indices, resource_arg_indices, function);
&construction, constant_arg_indices, resource_arg_indices, function,
/*has_ref_vars=*/false);
return s;
}
} // namespace tensorflow

@ -44,11 +44,9 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"@llvm-project//llvm:support",
"@llvm-project//mlir:AffineDialectRegistration",
"@llvm-project//mlir:LoopDialectRegistration",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOpsDialectRegistration",
"@llvm-project//mlir:Support",
"@llvm-project//mlir/test:TestTransforms",
],

@ -66,6 +64,8 @@ cc_library(
"//tensorflow/compiler/mlir/lite:tensorflow_lite_optimize",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
"//tensorflow/compiler/mlir/lite/quantization:quantization_passes",
"//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_to_quant",
"//tensorflow/compiler/mlir/lite/quantization/xla:hlo_xla_quantization_passes",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",

@ -104,7 +104,9 @@ tf_cc_binary(
name = "tf-opt",
deps = [
":tf_mlir_opt_main",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
],
)

@ -114,8 +116,10 @@ tf_cc_binary(
srcs = ["tf_mlir_translate_main.cc"],
deps = [
":init_mlir",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/compiler/mlir/tensorflow:translate_registration",

@ -127,6 +131,7 @@ tf_cc_binary(
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TranslateClParser",

@ -48,10 +48,11 @@ def _run_lit_test(name, data, size, tags, driver, features):
" the driver parameter when running this test. If you require" +
" custom driver support, please file an issue to request it.")

# Disable tests on windows for now, to enable testing rest of all xla and mlir.
native.py_test(
name = name,
srcs = ["@llvm-project//llvm:lit"],
tags = tags,
tags = tags + ["no_windows"],
args = [
"tensorflow/compiler/mlir/" + paths.basename(data[-1]) + " --config-prefix=runlit -v",
] + features,

@ -26,9 +26,11 @@ package_group(
filegroup(
name = "tensorflow_lite_ops_td_files",
srcs = [
"ir/tfl_op_interfaces.td",
"ir/tfl_ops.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/Transforms/LoopLikeInterface.td",
],
)

@ -43,10 +45,29 @@ gentbl(
"-gen-op-defs",
"ir/tfl_ops.cc.inc",
),
(
"-gen-struct-attr-decls",
"ir/tfl_structs.h.inc",
),
(
"-gen-struct-attr-defs",
"ir/tfl_structs.cc.inc",
),
(
"-gen-op-doc",
"g3doc/tfl_ops.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ir/tfl_ops.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
],
)

gentbl(
name = "tensorflow_lite_op_interfaces_inc_gen",
tbl_outs = [
(
"-gen-op-interface-decls",
"ir/tfl_ops_interface.h.inc",

@ -57,7 +78,7 @@ gentbl(
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ir/tfl_ops.td",
td_file = "ir/tfl_op_interfaces.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
],

@ -187,6 +208,7 @@ cc_library(
"ir/tfl_ops.h.inc",
"ir/tfl_ops_interface.cc.inc",
"ir/tfl_ops_interface.h.inc",
"runtime_verifiers.inc",
"utils/attribute_utils.cc",
],
hdrs = [

@ -199,8 +221,6 @@ cc_library(
deps = [
":tensorflow_lite_ops_inc_gen",
":validators",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/lite/schema:schema_fbs",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:Dialect",

@ -209,6 +229,11 @@ cc_library(
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
# TODO(jpienaar): Move this out after splitting out LoopLikeOpInterface.
"@llvm-project//mlir:Transforms",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/lite/schema:schema_fbs",
],
alwayslink = 1,
)

@ -274,15 +299,19 @@ cc_library(
"transforms/generated_prepare_tf.inc",
"transforms/legalize_ophint_func_op.cc",
"transforms/legalize_tf.cc",
"transforms/legalize_tf_while.cc",
"transforms/lower_static_tensor_list.cc",
"transforms/optimize_functional_ops.cc",
"transforms/prepare_composite_functions_tf.cc",
"transforms/prepare_tf.cc",
"transforms/runtime_type_verify.cc",
"transforms/split_merged_operands.cc",
"transforms/trim_functions_tf.cc",
"transforms/unroll_batch_matmul.cc",
"transforms/while_loop_outline.cc",
],
hdrs = [
"ir/tfl_ops_interface.h.inc",
"transforms/dilated_conv.h",
"transforms/passes.h",
"transforms/unroll_batch_matmul.h",

@ -293,10 +322,12 @@ cc_library(
":stateful_ops_utils",
":tensorflow_lite",
":validators",
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:framework",

@ -376,6 +407,24 @@ cc_library(
alwayslink = 1,
)

cc_library(
name = "tensorflow_lite_d2s",
srcs = [
"transforms/dense_to_sparse.cc",
],
hdrs = [
"transforms/passes.h",
],
deps = [
":tensorflow_lite",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
],
alwayslink = 1,
)

filegroup(
name = "generated_op_quant_spec_getters",
srcs = [

@ -387,6 +436,8 @@ genrule(
name = "op_quant_spec_getters_inc",
srcs = [
"ir/tfl_ops.td",
"ir/tfl_op_interfaces.td",
"@llvm-project//mlir:include/mlir/Transforms/LoopLikeInterface.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
],
outs = [

@ -413,9 +464,9 @@ cc_library(
)

tf_native_cc_binary(
name = "operator-converter-gen",
name = "converter-gen",
srcs = [
"operator_converter_gen.cc",
"converter_gen.cc",
],
deps = [
"@llvm-project//llvm:support",

@ -425,14 +476,18 @@ tf_native_cc_binary(
)

gentbl(
name = "operator_converter_inc",
name = "converter_inc",
tbl_outs = [
(
"",  # This driver has no options.
"--gen-operator-converters",
"operator_converters.inc",
),
(
"--gen-runtime-verifiers",
"runtime_verifiers.inc",
),
],
tblgen = ":operator-converter-gen",
tblgen = ":converter-gen",
td_file = "ir/tfl_ops.td",
td_srcs = [
":tensorflow_lite_ops_td_files",

@ -515,6 +570,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",

@ -535,8 +591,6 @@ cc_library(
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:QuantOpsDialectRegistration",
"@llvm-project//mlir:StandardDialectRegistration",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Translation",

@ -597,12 +651,14 @@ tf_cc_binary(
"//tensorflow/compiler/mlir:init_mlir",
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:errors",
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
],
)

@ -634,6 +690,7 @@ cc_library(
],
deps = [
":common",
":tensorflow_lite_d2s",
":tensorflow_lite_legalize_tf",
":tensorflow_lite_optimize",
":tensorflow_lite_quantize",

@ -649,7 +706,6 @@ cc_library(
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:QuantOpsDialectRegistration",
"@llvm-project//mlir:Transforms",
],
)

@ -683,7 +739,6 @@ cc_library(
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:QuantOpsDialectRegistration",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],

@ -35,7 +35,8 @@ struct PassConfig {
skip_control_dialect(false),
form_clusters(false),
inline_functions(true),
unfold_batch_matmul(true) {}
unfold_batch_matmul(true),
legalize_tf_while(true) {}

// If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
// added, which produces TF Lite ops.

@ -61,6 +62,10 @@ struct PassConfig {
// if `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set
// of tfl.fully_connected ops.
bool unfold_batch_matmul;
// Whether to legalize TF While to TFL While.
// Note: This is staging step and will be removed.
|
||||
// TODO(b/137395003): Remove post switching legalization.
|
||||
bool legalize_tf_while;
|
||||
};
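Caller-side, the new flag is consumed the way the converter pipeline does later in this commit (see the graphdef_to_tfl_flatbuffer.cc hunk below); a minimal sketch, assuming a PassConfig instance named pass_config and a PassManager pm:

// Sketch only: with the staging flag on, tf.While is legalized to the
// region-based tfl.while, then cond/body are outlined again before export.
if (pass_config.legalize_tf_while) {
  pm.addPass(mlir::TFL::CreateWhileOutlinePass());
}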

} // namespace TFL

@ -28,6 +28,9 @@ limitations under the License.
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include "mlir/TableGen/Attribute.h"  // TF:llvm-project
#include "mlir/TableGen/Format.h"  // TF:llvm-project
#include "mlir/TableGen/Operator.h"  // TF:llvm-project
#include "mlir/TableGen/Predicate.h"  // TF:llvm-project

using llvm::DefInit;
using llvm::dyn_cast;
@ -41,6 +44,19 @@ using llvm::SmallVector;
using llvm::StringInit;
using llvm::StringRef;

enum ActionType {
  OpConv,
  RuntimeVerify,
};

// NOLINTNEXTLINE
llvm::cl::opt<ActionType> action(
    llvm::cl::desc("Action to perform:"),
    llvm::cl::values(clEnumValN(OpConv, "gen-operator-converters",
                                "Generate operator converters"),
                     clEnumValN(RuntimeVerify, "gen-runtime-verifiers",
                                "Generate TFLite runtime verifiers")));

// Returns the associated option name for the given op definition.
static inline std::string GetOperatorOptionName(const Record &def) {
  assert(def.getName().startswith("TFL_") && "unexpected op prefix");
@ -342,8 +358,101 @@ static bool OperatorWritersMain(raw_ostream &os, RecordKeeper &records) {
  return false;
}

static void GenOperandResultVerifier(raw_ostream &os,
                                     llvm::ArrayRef<llvm::Init *> values,
                                     StringRef valueKind) {
  mlir::tblgen::FmtContext fctx;

  bool first = true;
  for (auto static_value : llvm::enumerate(values)) {
    auto *definit = llvm::cast<llvm::DefInit>(static_value.value());
    auto *val = definit->getDef()->getValue("tflRuntimeTypePredicate");
    if (!val) continue;

    // Create code block on first type to verify.
    if (first) {
      os << " {\n";
      os << " unsigned index = " << static_value.index() << ";\n";
      first = false;
    }

    mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
    auto desc =
        definit->getDef()->getValueAsString("tflRuntimeTypeDescription");

    // Emit a loop to check all the dynamic values in the pack.
    os << formatv(" for (Value v : top.getODS{0}{1}s({2})) {{\n",
                  // Capitalize the first letter to match the function name
                  valueKind.substr(0, 1).upper(), valueKind.substr(1),
                  static_value.index());

    os << " (void)v;\n"
       << " if (!("
       << tgfmt(pred.getCondition(), &fctx.withSelf("v.getType()")) << ")) {\n"
       << formatv(
              " return op->emitOpError(\"{0} #\") << index "
              "<< \" must be {1}, but got \" << v.getType();\n",
              valueKind, desc)
       << " }\n"  // if
       << " ++index;\n"
       << " }\n";  // for
  }

  // Emit closing brace if needed.
  if (!first) os << " }\n";
}
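To make the emitted text concrete: for an op whose operand #0 carries a tflRuntimeTypePredicate for 32-bit float, the strings above expand to roughly the following generated C++ (a reconstruction from the formatv strings, not diff content; whitespace approximated):

{
  unsigned index = 0;
  for (Value v : top.getODSOperands(0)) {
    (void)v;
    if (!(v.getType().isF32())) {
      return op->emitOpError("operand #") << index
             << " must be 32-bit float, but got " << v.getType();
    }
    ++index;
  }
}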

// NOLINTNEXTLINE
static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
  emitSourceFileHeader("MLIR TFLite Runtime Verifiers", os);

  // Retrieve all the definitions derived from TFL_Op and sort by record name.
  std::vector<Record *> defs = records.getAllDerivedDefinitions("Op");
  llvm::sort(defs, LessRecord());

  // Iterate through all the ops defined.
  for (const auto *def : defs) {
    mlir::tblgen::Operator op(*def);
    if (!op.getTrait("TflRuntimeVerifyOpInterface::Trait")) continue;

    mlir::tblgen::FmtContext verify_ctx;
    os << "::mlir::LogicalResult " << op.getCppClassName()
       << "::VerifyTflRuntimeTypes(::mlir::Operation *op) {\n";
    os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
    verify_ctx.withOp("top");

    for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
      auto &value = op.getOperand(i);
      // Skip from the first variadic operand onwards for now. Else the
      // getOperand index used below doesn't match.
      if (value.isVariadic()) break;
      if (!value.name.empty())
        verify_ctx.addSubst(value.name, formatv("op->getOperand({0})", i));
    }
    for (int i = 0, e = op.getNumResults(); i < e; ++i) {
      auto &value = op.getResult(i);
      // Skip from the first variadic result onwards for now. Else the
      // getResult index used below doesn't match.
      if (value.isVariadic()) break;
      if (!value.name.empty())
        verify_ctx.addSubst(value.name, formatv("op->getResult({0})", i));
    }
    GenOperandResultVerifier(os, def->getValueAsDag("arguments")->getArgs(),
                             "operand");
    GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(),
                             "result");
    os << " return mlir::success();\n}\n";
  }

  return false;
}

int main(int argc, char **argv) {
  llvm::InitLLVM y(argc, argv);
  llvm::cl::ParseCommandLineOptions(argc, argv);
  return TableGenMain(argv[0], &OperatorWritersMain);
  if (action == ActionType::OpConv)
    return TableGenMain(argv[0], &OperatorWritersMain);
  return TableGenMain(argv[0], &RuntimeVerifierWriterMain);
}

@ -90,6 +90,7 @@ using mlir::MLIRContext;
using mlir::ModuleOp;
using mlir::NoneType;
using mlir::Operation;
using mlir::Region;
using mlir::StringAttr;
using mlir::TensorType;
using mlir::TranslateFromMLIRRegistration;
@ -309,7 +310,7 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
  return true;
}

static std::unique_ptr<::tensorflow::NodeDef> getTensorFlowNodeDef(
static std::unique_ptr<::tensorflow::NodeDef> GetTensorFlowNodeDef(
    ::mlir::Operation* inst) {
  // We pass empty string for the original node_def name since Flex runtime
  // does not care about this being set correctly on node_def. There is no
@ -425,6 +426,11 @@ class Translator {
      mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);

  // Build while operator where cond & body are regions.
  Optional<BufferOffset<tflite::Operator>> BuildWhileOperator(
      mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);

  // Builds custom operators.
  // Templated on a) data type of custom_option to be stored into flatbuffer,
  // and b) TFL custom op type.
@ -472,7 +478,10 @@ class Translator {
      Operation* inst, const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);

  Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(FuncOp fn);
  // Build a subgraph with a given name out of the region either corresponding
  // to a function's body or while op.
  Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(
      const std::string& name, Region* region);

  // Builds Metadata with the given `name` and buffer `content`.
  BufferOffset<tflite::Metadata> BuildMetadata(StringRef name,
@ -539,9 +548,14 @@ Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
    attr = cst.value();
  } else if (auto cst = dyn_cast<tfl::QConstOp>(inst)) {
    attr = cst.value();
  } else if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
    attr = cst.value();
  } else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
    attr = cst.value();
  } else {
    return empty_buffer_;
  }

  tensorflow::Tensor tensor;
  auto status = tensorflow::ConvertToTensor(attr, &tensor);
  if (!status.ok()) {
@ -595,6 +609,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
  };

  std::vector<int32_t> shape;
  std::vector<int32_t> shape_signature;
  if (type.hasStaticShape()) {
    llvm::ArrayRef<int64_t> shape_ref = type.getShape();
    if (mlir::failed(check_shape(shape_ref))) return llvm::None;
@ -612,7 +627,25 @@

    shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
  }
  } else if (type.hasRank()) {
    llvm::ArrayRef<int64_t> shape_ref = type.getShape();
    if (mlir::failed(check_shape(shape_ref))) return llvm::None;

    shape.reserve(shape_ref.size());
    for (auto& dim : shape_ref) {
      shape.push_back(dim == -1 ? 1 : dim);
    }
    shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
  }

  if (auto* inst = value.getDefiningOp()) {
    if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
      // CreateSparsityParameters(cst.s_param());
    } else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
      // CreateSparsityParameters(cst.s_param());
    }
  }

  Type element_type = type.getElementType();
  tflite::TensorType tflite_element_type =
      GetTFLiteType(type.getElementType()).ValueOrDie();
@ -649,10 +682,19 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
      break;
    }
  }
  return tflite::CreateTensor(
      builder_, builder_.CreateVector(shape), tflite_element_type,
      (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
      /*is_variable=*/is_variable);

  if (shape_signature.empty()) {
    return tflite::CreateTensor(
        builder_, builder_.CreateVector(shape), tflite_element_type,
        (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
        /*is_variable=*/is_variable);
  } else {
    return tflite::CreateTensor(
        builder_, builder_.CreateVector(shape), tflite_element_type,
        (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
        /*is_variable=*/is_variable, /*sparsity=*/0,
        /*shape_signature=*/builder_.CreateVector(shape_signature));
  }
}
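// Worked example (reconstructed from the logic above, not part of the diff):
// for a value of type tensor<?x4xf32>, hasStaticShape() is false but
// hasRank() is true, so the exporter records
//   shape           = {1, 4}    (dynamic dimensions clamped to 1)
//   shape_signature = {-1, 4}   (original extents, -1 marking dynamic dims)
// and, because shape_signature is non-empty, it emits the extended
// CreateTensor overload that carries /*shape_signature=*/.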

BufferOffset<tflite::Operator> Translator::BuildIfOperator(
@ -687,6 +729,32 @@ BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
                                builtin_options);
}

Optional<BufferOffset<tflite::Operator>> Translator::BuildWhileOperator(
    mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
    const std::vector<int32_t>& results) {
  auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE);
  auto get_call_index = [&](mlir::Block& b) -> Optional<int> {
    if (b.getOperations().size() != 2) return llvm::None;
    if (auto call_op = dyn_cast<mlir::CallOp>(b.front()))
      return subgraph_index_map_.at(call_op.callee().str());
    return llvm::None;
  };
  auto body_subgraph_index = get_call_index(op.body().front());
  auto cond_subgraph_index = get_call_index(op.cond().front());
  if (!body_subgraph_index || !cond_subgraph_index)
    return op.emitOpError("only single call cond/body while export supported"),
           llvm::None;
  auto builtin_options =
      tflite::CreateWhileOptions(builder_, *cond_subgraph_index,
                                 *body_subgraph_index)
          .Union();
  auto inputs = builder_.CreateVector(operands);
  auto outputs = builder_.CreateVector(results);
  return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
                                tflite::BuiltinOptions_WhileOptions,
                                builtin_options);
}
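// For reference (reconstructed, not diff content): the exporter above only
// accepts the canonical outlined form, i.e. each of cond/body holds a single
// block with exactly two operations, a call to the outlined function plus the
// terminator, roughly:
//   %0:2 = "tfl.while"(%arg0, %arg1) (
//     { ... %c = call @while_cond(...) ... "tfl.yield"(%c) ... },      // cond
//     { ... %r:2 = call @while_body(...) ... "tfl.yield"(%r#0, %r#1) ... })
// Anything else hits the "only single call cond/body" error path.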

template <typename CustomOptionType, typename TFLOp>
BufferOffset<tflite::Operator> Translator::BuildCustomOperator(
    const CustomOptionType& custom_option, const std::string& opcode_name,
@ -908,6 +976,16 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
      return BuildMaxUnpooling2DOperator(inst, max_unpooling_op, operands,
                                         results);
    }
    if (auto whileOp = dyn_cast<mlir::TFL::WhileOp>(inst)) {
      if (inst->getNumOperands() != inst->getNumResults()) {
        inst->emitOpError(
            "number of operands and results don't match, only canonical "
            "TFL While supported");
        return llvm::None;
      }
      return BuildWhileOperator(whileOp, operands, results);
    }

    inst->emitOpError("is not a supported TFLite op");
    return llvm::None;
  }
@ -944,7 +1022,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
    // we emit op as flex.
    // if custom is enabled
    // we emit the op as custom.
    auto node_def = getTensorFlowNodeDef(inst);
    auto node_def = GetTensorFlowNodeDef(inst);
    if (!node_def) {
      return llvm::None;
    }
@ -1047,9 +1125,12 @@ bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) {
  return absl::c_find(operand_indices, operand_index) != operand_indices.end();
}

Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
    const std::string& name, Region* region) {
  bool has_input_attr = false;
  InitializeNamesFromAttribute(fn, &has_input_attr);
  if (auto fn = dyn_cast<FuncOp>(region->getParentOp())) {
    InitializeNamesFromAttribute(fn, &has_input_attr);
  }
  std::vector<BufferOffset<tflite::Tensor>> tensors;
  llvm::DenseMap<Value, int> tensor_index_map;

@ -1081,7 +1162,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
  };

  std::vector<BufferOffset<tflite::Operator>> operators;
  auto& bb = fn.getBlocks().front();
  auto& bb = region->front();

  // Main function's arguments are first passed to `input` op so they don't
  // have associated tensor and buffer. Build FlatBuffer tensor and buffer for
@ -1141,7 +1222,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
  return tflite::CreateSubGraph(
      builder_, builder_.CreateVector(tensors), builder_.CreateVector(inputs),
      builder_.CreateVector(outputs), builder_.CreateVector(operators),
      /*name=*/builder_.CreateString(fn.getName().str()));
      /*name=*/builder_.CreateString(name));
}

BufferOffset<tflite::Metadata> Translator::BuildMetadata(StringRef name,
@ -1184,35 +1265,36 @@ Optional<std::string> Translator::Translate(
}

Optional<std::string> Translator::TranslateInternal() {
  // Create a list of functions in the module with main function being the
  // first function in the list. This is required as the first subgraph in the
  // model is entry point for the model.
  std::vector<FuncOp> functions;
  functions.reserve(std::distance(module_.begin(), module_.end()));
  // A list of named regions in the module with the main function being the
  // first in the list. The main function must come first because the first
  // subgraph in the model is the entry point for the model.
  std::vector<std::pair<std::string, Region*>> named_regions;
  named_regions.reserve(std::distance(module_.begin(), module_.end()));

  int subgraph_idx = 0;
  FuncOp main_fn = module_.lookupSymbol<FuncOp>("main");
  subgraph_index_map_[main_fn.getName().str()] = subgraph_idx++;
  functions.push_back(main_fn);
  for (auto fn : module_.getOps<FuncOp>()) {
    if (fn == main_fn) continue;
  named_regions.emplace_back("main", &main_fn.getBody());
  // Walk over the module, collecting functions and while ops.
  module_.walk([&](FuncOp fn) {
    if (fn != main_fn) {
      subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
      named_regions.emplace_back(fn.getName().str(), &fn.getBody());
    }
  });

    subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
    functions.push_back(fn);
  }

  // Build subgraph for each of the functions.
  // Build subgraph for each of the named regions.
  std::vector<BufferOffset<tflite::SubGraph>> subgraphs;
  subgraphs.reserve(functions.size());
  subgraphs.reserve(named_regions.size());
  int first_failed_func = -1;
  for (int i = 0; i < functions.size(); ++i) {
    auto subgraph_or = BuildSubGraph(functions[i]);
  for (auto it : llvm::enumerate(named_regions)) {
    auto subgraph_or = BuildSubGraph(it.value().first, it.value().second);
    if (!subgraph_or) {
      if (first_failed_func == -1)
        // Record the index of the first function that cannot be converted.
        // Record the index of the first region that cannot be converted.
        // Keep looping through all subgraphs in the module to make sure that
        // we collect the list of missing ops from the entire module.
        first_failed_func = i;
        first_failed_func = it.index();
    } else {
      subgraphs.push_back(*subgraph_or);
    }
@ -1233,9 +1315,10 @@ Optional<std::string> Translator::TranslateInternal() {
           "-emit-custom-ops flag): " +
           failed_custom_ops_list;

    return functions[first_failed_func].emitError("failed while converting: '")
               << functions[first_failed_func].getName() << "\'\n"
               << err,
    auto& failed_region = named_regions[first_failed_func];
    return failed_region.second->getParentOp()->emitError()
               << "failed while converting: '" << failed_region.first
               << "': " << err,
           llvm::None;
  }

tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td (new file, 93 lines)
@ -0,0 +1,93 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is the operation interface definition file for TensorFlow Lite.

#ifndef TFL_OP_INTERFACES
#define TFL_OP_INTERFACES

include "mlir/IR/OpBase.td"

//===----------------------------------------------------------------------===//
// TFL op interface for stateful operands.

def TFL_StatefulOp : OpInterface<"StatefulOpInterface"> {
  let description = [{
    Interface for ops that are stateful and need to identify stateful operands.

    Stateful operands correspond to TF's variables semantics. An op that has 1
    or more stateful operands is a stateful op.
  }];

  let methods = [
    InterfaceMethod<
      [{Returns the indices of stateful operands.}],
      "std::vector<int>", "GetStatefulOperands", (ins)
    >,
  ];
}

//===----------------------------------------------------------------------===//
// TFL op interface for output channel index.

def TFL_ChannelDimIndexInterface : OpInterface<"ChannelDimIndexInterface"> {
  let description = [{
    Interface for defining the index of the output channel dimension.
  }];

  let methods = [
    InterfaceMethod<
      [{Returns the dimension index of the output channels.}],
      "int", "GetChannelDimIndex", (ins)
    >,
  ];
}

//===----------------------------------------------------------------------===//
// TFL op interface for sparse operands.

def TFL_SparseOp : OpInterface<"SparseOpInterface"> {
  let description = [{
    Interface for ops that support sparse computation.
  }];

  let methods = [
    InterfaceMethod<
      [{Returns the indices of sparse operands.}],
      "std::vector<int>", "GetSparseOperands", (ins)
    >,
  ];
}

//===----------------------------------------------------------------------===//
// TFL runtime type verification of operand/result types.

def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
  let description = [{
    Interface for TFLite runtime type verification.

    This verifies that the converted TFLite ops have operand/result types
    supported by the TFLite runtime.
  }];

  let methods = [
    StaticInterfaceMethod<
      [{Returns whether the op's operands/results are supported by runtime.}],
      "LogicalResult", "VerifyTflRuntimeTypes", (ins "Operation*":$op)
    >,
  ];
}

#endif // TFL_OP_INTERFACES
@ -23,6 +23,7 @@ limitations under the License.
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
@ -36,9 +37,11 @@ limitations under the License.
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
#include "mlir/Transforms/InliningUtils.h"  // TF:llvm-project
#include "mlir/Transforms/RegionUtils.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"

namespace mlir {
#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
namespace TFL {

//===----------------------------------------------------------------------===//
@ -52,11 +55,15 @@ struct TensorFlowLiteInlinerInterface : public DialectInlinerInterface {
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  bool isLegalToInline(Operation *, Region *,
  bool isLegalToInline(Operation *op, Region *dest,
                       BlockAndValueMapping &) const final {
    // No TFLite op restricts inlining today, revise as needed in the future.
    return true;
  }
  bool isLegalToInline(Region *dest, Region *src,
                       BlockAndValueMapping &valueMapping) const final {
    return isa<WhileOp>(dest->getParentOp());
  }
};

TensorFlowLiteDialect::TensorFlowLiteDialect(mlir::MLIRContext *context)
@ -1101,10 +1108,10 @@ static LogicalResult VerifySplitOpOutputTypes(
  for (int64_t i = 0; i < num_splits; ++i) {
    auto expected_output_type = get_expected_output_type(i);
    Value output = op->getResult(i);
    auto output_type = output.getType().dyn_cast<RankedTensorType>();
    if (!output_type || output_type != expected_output_type)
    if (failed(verifyCompatibleShape(output.getType(), expected_output_type)))
      return op->emitOpError()
             << "output #" << i << " should be " << expected_output_type;
             << "output #" << i << " should be " << expected_output_type
             << " instead got " << output.getType();
  }
  return success();
}
@ -1736,6 +1743,128 @@ static LogicalResult Verify(TransposeOp op) {
  return success();
}

LogicalResult Verify(WhileOp op) {
  if (op.getNumOperands() != op.getNumResults())
    return op.emitOpError(llvm::formatv(
        "number of operands does not match number of results ({0} != {1})",
        op.getNumOperands(), op.getNumResults()));
  // TODO(jpienaar): Verify operand, result & block arguments types
  return success();
}
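// Example diagnostic (constructed from the format string above): a tfl.while
// with 2 operands and 3 results fails verification with
//   'tfl.while' op number of operands does not match number of results (2 != 3)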

namespace {
// Canonicalize While op so that results and operands match and external values
// are via implicit capture rather than via block args.
struct WhileResultOperandsMatchAndImplicitCapture
    : public OpRewritePattern<WhileOp> {
  using OpRewritePattern<WhileOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(WhileOp while_op,
                                     PatternRewriter &rewriter) const override {
    // Replace values simply passed through the body with extern values. The
    // block arguments of body and while match and so the corresponding cond
    // argument can be easily found.
    bool unchanged = true;
    auto &body_block = while_op.body().front();
    auto &cond_block = while_op.cond().front();
    auto &yield = *body_block.getTerminator();
    for (auto ba : body_block.getArguments()) {
      if (ba == yield.getOperand(ba.getArgNumber())) {
        unchanged = false;
        auto value = while_op.getOperand(ba.getArgNumber());
        ba.replaceAllUsesWith(value);
        cond_block.getArgument(ba.getArgNumber()).replaceAllUsesWith(value);
      }
    }

    // The While op's operand and result types need to match.
    SmallVector<Value, 4> new_operands;
    SmallVector<Value, 4> new_body_yield;
    SmallVector<bool, 4> const_operand(while_op.getNumOperands(), false);
    llvm::SmallVector<Type, 4> types;
    new_operands.reserve(while_op.getNumOperands());
    new_body_yield.reserve(while_op.getNumOperands());
    types.reserve(while_op.getNumOperands());

    // Remove block arguments not used in either cond or body. This leaves the
    // block arguments of body and cond matching still.
    int arg_index = 0;
    for (int while_index = 0, e = while_op.getNumOperands(); while_index < e;
         ++while_index) {
      auto value = while_op.getOperand(while_index);
      if (body_block.getArgument(arg_index).use_empty() &&
          cond_block.getArgument(arg_index).use_empty() &&
          // This could be relaxed and casts inserted.
          while_op.getResult(while_index).getType() == value.getType()) {
        unchanged = false;
        body_block.eraseArgument(arg_index);
        cond_block.eraseArgument(arg_index);

        // Mark operand as constant and replace all uses with input to while.
        while_op.getResult(while_index).replaceAllUsesWith(value);
        const_operand[while_index] = true;
      } else {
        new_operands.push_back(value);
        new_body_yield.push_back(yield.getOperand(while_index));
        auto type = while_op.getResult(while_index).getType();
        types.push_back(type);
        ++arg_index;
      }
    }

    // Done if no values removed from blocks and operands & results match.
    if (unchanged) return matchFailure();

    // Replace with new While with matching operands and results.
    Operation *op = while_op.getOperation();
    Operation *new_op = rewriter.insert(
        Operation::create(op->getLoc(), op->getName(), types, new_operands,
                          op->getAttrs(), {}, /*numRegions=*/2,
                          /*resizableOperandList=*/true));

    for (int i = 0; i < 2; ++i) new_op->getRegion(i).takeBody(op->getRegion(i));
    int new_index = 0;
    for (int op_index = 0, e = op->getNumResults(); op_index < e; ++op_index) {
      if (const_operand[op_index]) continue;
      op->getResult(op_index).replaceAllUsesWith(new_op->getResult(new_index));
      ++new_index;
    }
    rewriter.eraseOp(op);

    Block &new_body_block = cast<WhileOp>(new_op).body().front();
    rewriter.setInsertionPointToEnd(&new_body_block);
    rewriter.replaceOpWithNewOp<YieldOp>(new_body_block.getTerminator(),
                                         new_body_yield);

    return matchSuccess();
  }
};

} // namespace

void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  results.insert<WhileResultOperandsMatchAndImplicitCapture>(context);
}
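// Net effect of the pattern, sketched (not diff content): if the body yields
// a block argument unchanged and neither region otherwise uses it, the
// canonicalized while drops that operand/result pair entirely; users of the
// old result are rewired to the original input value, which is now implicitly
// captured rather than threaded through block arguments.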

Region &WhileOp::getLoopBody() { return body(); }

bool WhileOp::isDefinedOutsideOfLoop(Value value) {
  // TODO(jpienaar): This is overly conservative and disables anything other
  // than constant hoisting initially.
  return false;
}

LogicalResult WhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) {
  if (ops.empty()) return success();

  // Move the hoisted value to just before the while.
  Operation *while_op = this->getOperation();
  for (auto op : ops) op->moveBefore(while_op);

  return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
@ -1743,6 +1872,7 @@ static LogicalResult Verify(TransposeOp op) {
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.cc.inc"
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
#include "tensorflow/compiler/mlir/lite/runtime_verifiers.inc"

Operation *TensorFlowLiteDialect::materializeConstant(OpBuilder &builder,
                                                      Attribute value,

@ -27,10 +27,12 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/Support/Functional.h"  // TF:llvm-project
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "mlir/Transforms/LoopLikeInterface.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace mlir {
#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.h.inc"
namespace TFL {

class TensorFlowLiteDialect : public Dialect {

File diff suppressed because it is too large
@ -41,13 +41,20 @@ limitations under the License.
#include "tensorflow/lite/delegates/flex/delegate.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/optional_debug_tools.h"

using llvm::cl::desc;
using llvm::cl::init;
using llvm::cl::opt;

// NOLINTNEXTLINE
static opt<std::string> inputFileName(llvm::cl::Positional,
                                      llvm::cl::desc("<input file>"),
                                      llvm::cl::init("-"));
static opt<std::string> input_filename(llvm::cl::Positional,
                                       desc("<input file>"), init("-"));

// NOLINTNEXTLINE
static opt<bool> dump_state("dump-interpreter-state",
                            desc("dump interpreter state post execution"),
                            init(false));
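// Usage sketch (binary name assumed from this file's build target, not
// stated in the diff):
//   mlir_tflite_runner input.mlir --dump-interpreter-state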

// TODO(jpienaar): Move these functions to some debug utils.
static std::string TfLiteTensorDimString(const TfLiteTensor& tensor) {
@ -82,9 +89,9 @@ int main(int argc, char** argv) {
  llvm::InitLLVM y(argc, argv);
  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR TFLite runner\n");

  auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(inputFileName.c_str());
  auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(input_filename.c_str());
  if (std::error_code error = file_or_err.getError()) {
    LOG(ERROR) << argv[0] << ": could not open input file '" << inputFileName
    LOG(ERROR) << argv[0] << ": could not open input file '" << input_filename
               << "': " << error.message() << "\n";
    return 1;
  }
@ -133,5 +140,7 @@ int main(int argc, char** argv) {
           TfLiteTensorString(out).c_str());
  }

  if (dump_state) tflite::PrintInterpreterState(interpreter.get());

  return 0;
}

@ -17,6 +17,7 @@ cc_library(
    ],
    deps = [
        "//tensorflow/compiler/mlir/lite:common",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
        "//tensorflow/compiler/mlir/lite:tf_tfl_passes",
        "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer",
        "//tensorflow/compiler/mlir/tensorflow",

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/core/framework/graph.pb.h"
@ -277,6 +278,11 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
  pass_config.lower_tensor_list_ops = true;

  tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
  // Convert back to outlined while format for export back to flatbuffer.
  if (pass_config.legalize_tf_while) {
    pm.addPass(mlir::TFL::CreateWhileOutlinePass());
  }
  pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());

  auto status = ConvertTFExecutorToTFLOrFlatbuffer(
      module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,

@ -23,6 +23,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
@ -39,6 +40,8 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/core/platform/logging.h"

#define DEBUG_TYPE "quantization-driver"

namespace mlir {
namespace quant {
namespace {
@ -281,6 +284,37 @@ class QuantizationDriver {
    cached.first->second = InitializeState(op, index, res, /*as_result=*/true);
  }

  void DumpStates(Operation *current_op) {
    if (current_op) {
      llvm::errs() << "\n\n\n" << current_op->getName() << "\n";
    }
    fn_.walk([&](Operation *op) {
      if (llvm::isa<quant::QuantizeCastOp>(op) ||
          llvm::isa<quant::DequantizeCastOp>(op) || llvm::isa<ConstantOp>(op))
        return;
      if (current_op == op) llvm::errs() << "===>>>";
      llvm::errs() << op->getName() << " : (";
      for (auto i = 0; i < op->getNumOperands(); ++i) {
        if (auto params = GetOperandQuantState(op, i).params)
          params.print(llvm::errs());
        else
          op->getOperand(i).getType().cast<ShapedType>().getElementType().print(
              llvm::errs());
        llvm::errs() << ",";
      }
      llvm::errs() << ") -> (";
      for (auto i = 0; i < op->getNumResults(); ++i) {
        if (auto params = GetResultQuantState(op, i).params)
          params.print(llvm::errs());
        else
          op->getResult(i).getType().cast<ShapedType>().getElementType().print(
              llvm::errs());
        llvm::errs() << ",";
      }
      llvm::errs() << ")\n";
    });
  }
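  // Sample trace line (illustrative values, following the stream calls
  // above), as printed when the walk reaches the op currently being
  // propagated:
  //   ===>>>tfl.conv_2d : (!quant.uniform<i8:f32, 0.02:-3>,f32,f32,) -> (f32,)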

  FuncOp fn_;
  OpBuilder builder_;
  bool is_signed_;
@ -350,7 +384,7 @@ int QuantizationDriver::InitializeState(Operation *op, int index, Value val,
}

bool QuantizationDriver::SetConstantResultParams(Operation *op) {
  ElementsAttr attr;
  DenseFPElementsAttr attr;
  Value res = op->getResult(0);
  if (!matchPattern(res, m_Constant(&attr))) {
    return false;
@ -712,6 +746,8 @@ bool QuantizationDriver::PropagateParams() {
    Operation *op = work_list_.back();
    work_list_.pop_back();

    LLVM_DEBUG(DumpStates(op));

    // This op has been quantized, so we should not consider it again.
    if (llvm::is_contained(quantized_, op)) continue;
    quantized_.insert(op);
@ -736,12 +772,23 @@ bool QuantizationDriver::PropagateParams() {
    }

    // Use the final state to set all the operands' parameters.
    for (int i = 0, e = op->getNumOperands(); i != e; ++i)
      changed |= SetOperandParams(op, i, params);
    for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
      if (auto type = op->getOperand(i).getType().dyn_cast<ShapedType>()) {
        // Without this check, it would accidentally propagate the quantization
        // information via shared non-float tensors.
        if (type.getElementType().isa<FloatType>())
          changed |= SetOperandParams(op, i, params);
      }
    }

    // Use the final state to set all the results' parameters.
    for (int res = 0, e = op->getNumResults(); res != e; ++res)
      changed |= SetResultParams(op, res, params);
      if (auto type = op->getResult(res).getType().dyn_cast<ShapedType>()) {
        // Without this check, it would accidentally propagate the quantization
        // information via shared non-float tensors.
        if (type.getElementType().isa<FloatType>())
          changed |= SetResultParams(op, res, params);
      }
    }

    // TODO(fengliuai): make the bit width configurable.

@ -70,7 +70,8 @@ class FixedResultUniformScale {
  QuantizedType GetResultQuantizedType(int index) {
    auto op = this->getOperation();
    auto result_type =
        op->getResult(index).getType().template cast<TensorType>();
        op->getResult(index).getType().template cast<ShapedType>();
    if (!result_type.getElementType().template isa<FloatType>()) return {};
    Builder builder(op->getContext());
    IntegerType storage_type = builder.getIntegerType(BitWidth);
    const double scale = static_cast<double>(ScaleMantissa) *

@ -399,7 +399,7 @@ static bool PreferResultScale(Operation* op) {
  for (auto operand : op->getOperands()) {
    if (auto operand_type = operand.getType().dyn_cast<ShapedType>()) {
      if (operand_type.getElementType().isa<FloatType>()) {
        if (float_operands++ > 1) return true;
        if (++float_operands > 1) return true;
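        // Pre-increment makes the early exit fire on the second float
        // operand; the old post-increment form did not trigger until a
        // third float operand was seen.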
      }
    }
  }
@ -459,7 +459,7 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
  }

  // Step 2: backward pass: For the ops skipped in the forward pass, propagate
  // its results scale backwards.
  // its results scale backwards as far as possible.
  func.walk([&](quant::StatisticsOp stats_op) {
    if (redundant_stats_ops.find(stats_op) == redundant_stats_ops.end()) {
      all_stats_ops.push_back(stats_op);
@ -471,8 +471,7 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
    all_stats_ops.pop_back();

    if (auto def = stats_op.arg().getDefiningOp()) {
      if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() &&
          PreferResultScale(def)) {
      if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>()) {
        for (auto input : def->getOperands()) {
          if (auto next_stats = llvm::dyn_cast_or_null<quant::StatisticsOp>(
                  input.getDefiningOp())) {

@ -150,7 +150,8 @@ struct QuantizationPattern : public RewritePattern {

  explicit QuantizationPattern(MLIRContext* context, bool enable_verify,
                               float error_tolerance, bool single_layer_verify)
      : RewritePattern(DQ::getOperationName(), 1, context),
      // Set the score to a large number so it is always preferred.
      : RewritePattern(DQ::getOperationName(), 300, context),
        enable_verify(enable_verify),
        error_tolerance(error_tolerance),
        single_layer_verify(single_layer_verify) {}
@ -167,9 +168,12 @@ struct QuantizationPattern : public RewritePattern {
      return matchFailure();
    }

    // If it is terminator or not quantizable, we shouldn't rewrite.
    // If it is a terminator, is not quantizable, or is any op from the MLIR
    // quant ops dialect, we shouldn't rewrite.
    if (quantized_op->isKnownTerminator() ||
        quantized_op->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
        quantized_op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
        llvm::isa<quant::QuantizeCastOp>(quantized_op) ||
        llvm::isa<quant::DequantizeCastOp>(quantized_op)) {
      return matchFailure();
    }

tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD (new file, 36 lines)
@ -0,0 +1,36 @@
package(
    default_visibility = [
        ":friends",
    ],
    licenses = ["notice"],  # Apache 2.0
)

package_group(
    name = "friends",
    includes = ["//third_party/mlir:subpackages"],
    packages = [
        "//tensorflow/compiler/mlir/...",
        "//tensorflow/compiler/mlir/lite/...",
    ],
)

cc_library(
    name = "tf_to_quant",
    srcs = [
        "tf_to_quant.cc",
    ],
    hdrs = [
        "passes.h",
    ],
    deps = [
        "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/tensorflow",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",
    ],
    alwayslink = 1,
)
@ -0,0 +1,32 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_TENSORFLOW_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_TENSORFLOW_PASSES_H_

#include <memory>

#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project

namespace mlir {
namespace TF {

// Legalize the tf ops to the quant ops, so the quantization passes can work.
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass();

} // namespace TF
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_TENSORFLOW_PASSES_H_

@ -0,0 +1,19 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")

package(licenses = ["notice"])

glob_lit_tests(
    data = [":test_utilities"],
    driver = "@llvm-project//mlir:run_lit.sh",
    test_file_exts = ["mlir"],
)

# Bundle together all of the test utilities that are used by tests.
filegroup(
    name = "test_utilities",
    testonly = True,
    data = [
        "//tensorflow/compiler/mlir:tf-opt",
        "@llvm-project//llvm:FileCheck",
    ],
)

@ -0,0 +1,148 @@
// RUN: tf-opt -tf-to-quant %s | FileCheck %s

// CHECK-LABEL: fakeQuantPerChannelForActivation
func @fakeQuantPerChannelForActivation(%arg0: tensor<8x3xf32>) -> (tensor<8x3xf32>) {
  %arg1 = constant dense<[0.0, -1.0, 1.0]> : tensor<3xf32>
  %arg2 = constant dense<[255.0, 254.0, 256.0]> : tensor<3xf32>
  %0 = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8x3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<8x3xf32>
  return %0 : tensor<8x3xf32>

// CHECK: %[[fq:.*]] = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %cst, %cst_0)
// CHECK: %[[q:.*]] = "quant.qcast"(%[[fq]]) : (tensor<8x3xf32>) -> tensor<8x3x!quant.uniform<i8:f32:1, {1.000000e+00:-128,1.000000e+00:-127,1.000000e+00:-128}>>
// CHECK: %[[dq:.*]] = "quant.dcast"(%[[q]])
// CHECK: return %[[dq]]
}

// CHECK-LABEL: fakeQuantForActivation
func @fakeQuantForActivation(tensor<8xf32>) -> (tensor<8xf32>) {
^bb0(%arg0: tensor<8xf32>):
  %arg1 = constant dense<0.0> : tensor<f32>
  %arg2 = constant dense<255.0> : tensor<f32>
  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
  return %0 : tensor<8xf32>

// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0)
// CHECK: %1 = "quant.qcast"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: %2 = "quant.dcast"(%1)
// CHECK: return %2
}

// CHECK-LABEL: fakeQuantForActivationNoDuplication
func @fakeQuantForActivationNoDuplication(tensor<8xf32>) -> (tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>) {
^bb0(%arg0: tensor<8xf32>):
  %arg1 = constant dense<0.0> : tensor<f32>
  %arg2 = constant dense<255.0> : tensor<f32>
  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
  %1 = "quant.qcast"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
  return %1 : tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>

// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) {narrow_range = false, num_bits = 3 : i64}
// CHECK: %1 = "quant.qcast"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: return %1
}

// CHECK-LABEL: fakeQuantFolded
func @fakeQuantFolded() -> (tensor<8xf32>) {
  %in = constant dense<0.0> : tensor<8xf32>
  %min = constant dense<0.0> : tensor<f32>
  %max = constant dense<255.0> : tensor<f32>
  %mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
  %maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
  %rst = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
  return %rst : tensor<8xf32>

// CHECK: %[[CONSTANT:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<8xf32>}
// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT]]) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
// CHECK: return %[[DEQUANTIZE]] : tensor<8xf32>
}

// CHECK-LABEL: fakeQuantNotFolded
func @fakeQuantNotFolded(tensor<8xf32>, tensor<f32>, tensor<f32>) -> (tensor<8xf32>) {
^bb0(%arg0: tensor<8xf32>, %arg3: tensor<f32>, %arg4: tensor<f32>):
  %1 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg3, %arg4) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
  return %1 : tensor<8xf32>

// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2)
// CHECK: return %0 : tensor<8xf32>
}

// CHECK-LABEL: fakeQuantWithConv2D
func @fakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
^bb0(%arg: tensor<256x32x32x3xf32>) :
  %in = constant dense<0.0> : tensor<3x3x3x16xf32>
  %min = constant dense<0.0> : tensor<f32>
  %max = constant dense<255.0> : tensor<f32>
  %mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
  %maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
  %fq = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<f32>, tensor<f32>) -> tensor<3x3x3x16xf32>
  %rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
  return %rst : tensor<256x30x30x16xf32>

// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[DEQUANTIZE]])
// CHECK: return %[[CONV]]
}

// CHECK-LABEL: perChannelFakeQuantWithConv2D
func @perChannelFakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
^bb0(%arg: tensor<256x32x32x3xf32>) :
  %in = constant dense<0.0> : tensor<3x3x3x16xf32>
  %min = constant dense<0.0> : tensor<16xf32>
  %max = constant dense<255.0> : tensor<16xf32>
  %mini = "tf.Identity"(%min) : (tensor<16xf32>) -> tensor<16xf32>
  %maxi = "tf.Identity"(%max) : (tensor<16xf32>) -> tensor<16xf32>
  %fq = "tf.FakeQuantWithMinMaxVarsPerChannel"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<3x3x3x16xf32>
  %rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
  return %rst : tensor<256x30x30x16xf32>

// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32:3,
// CHECK-SAME: {1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,
// CHECK-SAME: 1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128}>>
// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[DEQUANTIZE]])
// CHECK: return %[[CONV]] : tensor<256x30x30x16xf32>
}

// CHECK-LABEL: fakeQuantWithDepthwiseConv2D
func @fakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
^bb0(%arg: tensor<256x32x32x3xf32>) :
  %in = constant dense<0.0> : tensor<3x3x3x16xf32>
  %min = constant dense<0.0> : tensor<f32>
  %max = constant dense<255.0> : tensor<f32>
  %mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
  %maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
  %fq = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<f32>, tensor<f32>) -> tensor<3x3x3x16xf32>
  %rst = "tf.DepthwiseConv2dNative"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
  return %rst : tensor<256x30x30x16xf32>

// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[DEQUANTIZE]])
// CHECK: return %[[CONV]]
}

// CHECK-LABEL: perChannelFakeQuantWithDepthwiseConv2D
func @perChannelFakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
^bb0(%arg: tensor<256x32x32x3xf32>) :
  %in = constant dense<0.0> : tensor<3x3x3x16xf32>
  %min = constant dense<0.0> : tensor<16xf32>
  %max = constant dense<255.0> : tensor<16xf32>
  %mini = "tf.Identity"(%min) : (tensor<16xf32>) -> tensor<16xf32>
  %maxi = "tf.Identity"(%max) : (tensor<16xf32>) -> tensor<16xf32>
  %fq = "tf.FakeQuantWithMinMaxVarsPerChannel"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<3x3x3x16xf32>
  %rst = "tf.DepthwiseConv2dNative"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
  return %rst : tensor<256x30x30x16xf32>

// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32:3,
// CHECK-SAME: {1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,
// CHECK-SAME: 1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128}>>
// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[DEQUANTIZE]])
// CHECK: return %[[CONV]]
}

@ -0,0 +1,162 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TF {

//===----------------------------------------------------------------------===//
// The pass to legalize the quantization emulation ops from TF.
//
namespace {

// Legalize TF quantization emulation ops to their counterparts in the Quant
// ops dialect.
struct LegalizeTFToQuant : public FunctionPass<LegalizeTFToQuant> {
  explicit LegalizeTFToQuant() = default;
  LegalizeTFToQuant(const LegalizeTFToQuant &) {}

  /// Performs the lowering to Quant ops dialect.
  void runOnFunction() override;
};
||||
|
||||
// TODO(fengliuai): move this rule to PreparePatterns.td
|
||||
// TODO(b/140968741): propagate the sign from the command line. Currently all
|
||||
// the FakeQuant is assumed to targeting UIN8, but per-channel kernel is
|
||||
// actually INT8.
|
||||
// Inserts a "tfl.quantize" and "tfl.dequantize" op pair (QDQs) after the
|
||||
// "tf.FakeQuantWithMinMaxVarsOp" to be constant folded. Since the constant
|
||||
// folding logic will use a "std.constant" op to replace the
|
||||
// "tf.FakeQuantWithMinMaxVarsOp", the "tfl.quantize" op is used to preserve
|
||||
// the quantization parameters as a TypeAttr and "tfl.dequantize" op used to
|
||||
// convert the output type to the next op. Here are the transformations:
|
||||
//
|
||||
// input min cst max cst input min cst max cst
|
||||
// \ | | \ | |
|
||||
// \ (tf.Identity) (tf.Identity) => \ (tf.Identity) (tf.Identity)
|
||||
// \ | | \ | |
|
||||
// tf.FakeQuantWithMinMaxVars tf.FakeQuantWithMinMaxVars
|
||||
// | |
|
||||
// tf.quantize
|
||||
// |
|
||||
// tf.dequantize
|
||||
// |
|
||||
// If the input is a constant, the result pattern will eventually converted to
|
||||
//
|
||||
// quant-emulated input
|
||||
// |
|
||||
// tf.quantize
|
||||
// |
|
||||
// tf.dequantize
|
||||
// |
|
||||
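// As a concrete sketch of the per-tensor case (an 8-bit, non-narrow-range
// setup is assumed here; the fakeQuant tests earlier in this diff pin down
// the exact quantized types), the rewrite turns
//
//   %fq = "tf.FakeQuantWithMinMaxVars"(%in, %min, %max)
//       {num_bits = 8, narrow_range = false} : (...) -> tensor<8xf32>
//
// into
//
//   %q  = "quant.qcast"(%fq) : (tensor<8xf32>)
//       -> tensor<8x!quant.uniform<i8:f32, ...>>
//   %dq = "quant.dcast"(%q) : (...) -> tensor<8xf32>
//
// with all of %fq's original users rewired to consume %dq.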
template <typename TFFakeQuantOp, bool PerAxis>
struct InsertQuantOpsAfterTFFakeQuantOp
    : public OpRewritePattern<TFFakeQuantOp> {
  using BaseType = InsertQuantOpsAfterTFFakeQuantOp<TFFakeQuantOp, PerAxis>;

  explicit InsertQuantOpsAfterTFFakeQuantOp(MLIRContext *ctx)
      : OpRewritePattern<TFFakeQuantOp>(ctx) {}

  PatternMatchResult matchAndRewrite(TFFakeQuantOp tf_op,
                                     PatternRewriter &rewriter) const override {
    // We don't want to insert quantize/dequantize if a quantize op already
    // exists.
    auto res = tf_op.outputs();
    if (!res.hasOneUse() || isa<quant::QuantizeCastOp>(*res.user_begin()))
      return this->matchFailure();

    // Extract the min/max constant values from the operands. We also handle
    // the special case that there are tf.Identity ops between the min/max
    // constants and the tf.FakeQuantWithMinMaxVarsOp.
    Value min = tf_op.min(), max = tf_op.max();
    DenseFPElementsAttr min_value, max_value;
    if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min.getDefiningOp())) {
      id1.replaceAllUsesWith(id1.input());
      min = tf_op.min();
      rewriter.eraseOp(id1);
    }
    if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max.getDefiningOp())) {
      id2.replaceAllUsesWith(id2.input());
      max = tf_op.max();
      rewriter.eraseOp(id2);
    }
    if (!matchPattern(min, m_Constant(&min_value))) return this->matchFailure();
    if (!matchPattern(max, m_Constant(&max_value))) return this->matchFailure();

    int quant_dim = -1;
    if (PerAxis) {
      // This is a special case: following tf.FakeQuantWithMinMaxVarsPerChannel,
      // the quantization dimension is the last dimension.
      quant_dim = res.getType().template cast<ShapedType>().getRank() - 1;
    }
    // Use the min/max from the operands and the num_bits and narrow_range
    // attributes to create the quantization parameter for the new quantize op.
    rewriter.setInsertionPointAfter(tf_op);
    IntegerAttr num_bits =
        rewriter.getI64IntegerAttr(tf_op.num_bits().getSExtValue());
    BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range());
    Type res_type = tf_op.getType();
    TypeAttr qtype = quant::GetQuantizedTypeAttr(
        rewriter, res_type, min_value, max_value, quant_dim, num_bits,
        narrow_range, /*is_signed=*/true);
    if (!qtype) return this->matchFailure();

    // Finally, use the quantization parameter to create the quantize and
    // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp
    // and its users.
    Value value = tf_op.outputs();
    auto quantize = rewriter.create<quant::QuantizeCastOp>(
        tf_op.getLoc(), qtype.getValue(), value);
    auto dequantize = rewriter.create<quant::DequantizeCastOp>(
        tf_op.getLoc(), res_type, quantize.getResult());
    value.replaceAllUsesWith(dequantize);
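    // replaceAllUsesWith above also rewired the operand of the new quantize
    // op to the dequantize result, so point that operand back at the original
    // FakeQuant output.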
    quantize.getOperation()->replaceUsesOfWith(dequantize, value);

    return this->matchSuccess();
  }
};

using PreparePerTensorFakeQuant =
    InsertQuantOpsAfterTFFakeQuantOp<TF::FakeQuantWithMinMaxVarsOp, false>;

using PreparePerChannelFakeQuant =
    InsertQuantOpsAfterTFFakeQuantOp<TF::FakeQuantWithMinMaxVarsPerChannelOp,
                                     true>;

// TODO(fengliuai): add support for tf.QuantizeAndDequantize* legalization.

void LegalizeTFToQuant::runOnFunction() {
  OwningRewritePatternList patterns;
  auto func = getFunction();
  auto *ctx = func.getContext();
  patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant>(ctx);
  applyPatternsGreedily(func, patterns);
}
} // namespace

// Creates an instance of the TensorFlow dialect to QuantOps dialect pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass() {
  return std::make_unique<LegalizeTFToQuant>();
}

static PassRegistration<LegalizeTFToQuant> pass(
    "tf-to-quant", "Legalize TF to quant ops dialect");

} // namespace TF
} // namespace mlir
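A minimal lit-style driver for the pass registered above would follow the same RUN convention as the other quantization tests in this diff (the input file is assumed for illustration; the fakeQuant CHECK lines earlier in this section are the kind of output it would verify):

    // RUN: tf-opt -tf-to-quant %s | FileCheck %s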
40
tensorflow/compiler/mlir/lite/quantization/xla/BUILD
Normal file
@ -0,0 +1,40 @@
package(
    default_visibility = [
        ":friends",
    ],
    licenses = ["notice"],  # Apache 2.0
)

package_group(
    name = "friends",
    includes = ["//third_party/mlir:subpackages"],
    packages = [
        "//tensorflow/compiler/mlir/...",
        "//tensorflow/compiler/mlir/lite/...",
    ],
)

cc_library(
    name = "hlo_xla_quantization_passes",
    srcs = [
        "materialize.cc",
        "op_quant_spec.inc",
        "propagate.cc",
    ],
    hdrs = [
        "passes.h",
    ],
    deps = [
        "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/xla:hlo",
        "//tensorflow/compiler/xla/client/lib:quantize",
        "@com_google_absl//absl/memory",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:StandardOps",
    ],
    alwayslink = 1,
)
174
tensorflow/compiler/mlir/lite/quantization/xla/materialize.cc
Normal file
@ -0,0 +1,174 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass quantizes the constants and rewrites the
// quantization ops into xla_hlo primitive ops.
#include <cstdint>
#include <iterator>
#include <numeric>
#include <string>

#include "absl/memory/memory.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/IR/Value.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/xla/client/lib/quantize.h"

//===----------------------------------------------------------------------===//
// The pass to materialize the quantization results by xla primitive ops.
//
namespace mlir {
namespace xla_hlo {

namespace {

// This pattern matches the "constant->qcast->dcast" chain and replaces it
// with "quantized constant->xla_hlo.dequantize". If it only matches a
// "non-constant->qcast->dcast" chain, it removes the "qcast->dcast" pair.
// The chain is matched as a whole to bypass the type checks of the normal
// xla_hlo ops.
// TODO(fengliuai): make this pass work for bf16 input.
class RewriteDequantize : public OpRewritePattern<quant::DequantizeCastOp> {
 public:
  explicit RewriteDequantize(int64_t size, MLIRContext *context)
      : OpRewritePattern<quant::DequantizeCastOp>(context), size_(size) {}

  PatternMatchResult matchAndRewrite(quant::DequantizeCastOp op,
                                     PatternRewriter &rewriter) const override {
    // quant.dcast
    // xla_hlo dequantize only takes min/max, so let's recover them from
    // the quantization parameters.
    Value dcast = op.arg();
    auto type = quant::QuantizedType::getQuantizedElementType(dcast.getType());
    if (!type || !type.isa<quant::UniformQuantizedType>()) {
      return matchFailure();
    }
    auto qtype = type.cast<quant::UniformQuantizedType>();
    double scale = qtype.getScale();
    int64_t zero_point = qtype.getZeroPoint();
    float min = scale * (qtype.getStorageTypeMin() - zero_point);
    float max = scale * (qtype.getStorageTypeMax() - zero_point);
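    // As a worked example, with the u8 parameters used by the tests later in
    // this diff (scale = 0.0078431373, zero_point = 128, storage range
    // [0, 255]):
    //   min = 0.0078431373 * (0 - 128)   = -1.00392163
    //   max = 0.0078431373 * (255 - 128) =  0.996078431
    // which match the min_range/max_range attributes checked on the
    // xla_hlo.dequantize op in the quantize_rewrite test.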

    // quant.qcast
    auto qcast =
        llvm::dyn_cast_or_null<quant::QuantizeCastOp>(dcast.getDefiningOp());
    if (!qcast) return matchFailure();

    // constant
    DenseFPElementsAttr attr;
    // If it isn't a floating-point constant or the size is too small, let's
    // remove the quantization. Also, the last dimension size should be a
    // multiple of 4, so the shape isn't broken during packing and unpacking.
    if (!matchPattern(qcast.arg(), m_Constant(&attr)) ||
        attr.getNumElements() <= size_ ||
        attr.getType().getDimSize(attr.getType().getRank() - 1) % 4 != 0) {
      op.getResult().replaceAllUsesWith(qcast.arg());
      return matchSuccess();
    }
    // TODO(fengliuai): implement transpose if it has high dimension.

    // Create the quantized result.
    auto quantized_result =
        quant::Quantize(attr, qtype).dyn_cast_or_null<DenseIntElementsAttr>();
    if (!quantized_result) {
      return matchFailure();
    }

    // Pack the uint8 bits to uint32. The shape is changed from
    // [n0, n1, ..., nk] to [n0, n1, ..., nk / 4].
    std::vector<uint8_t> raw_data;
    for (auto d : quantized_result.getValues<uint8_t>()) {
      raw_data.push_back(d);
    }
    // The packing might increase the data size with padding.
    auto packed_data = xla::PackToUint32<uint8_t>(raw_data);
    auto packed_shape = attr.getType().getShape().vec();
    int lower_dims = std::accumulate(
        packed_shape.begin(),
        std::next(packed_shape.begin(), packed_shape.size() - 1), 1,
        std::multiplies<int>());
    packed_shape[packed_shape.size() - 1] = packed_data.size() / lower_dims;
    auto packed_type =
        RankedTensorType::get(packed_shape, rewriter.getIntegerType(32));
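    // For instance, the 2x4 constant in the quantize_rewrite test below
    // quantizes its first row [-1.0, -0.5, 0.0, 0.0] to the uint8 bytes
    // {1, 64, 128, 128}, which pack into the single uint32 word 0x01408080
    // (21004416), so the packed constant becomes a tensor<2x1xi32>. (The byte
    // order shown is inferred from the test's expected output, not from the
    // xla::PackToUint32 contract.)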

    auto packed_quantized_result =
        DenseElementsAttr::get<uint32_t>(packed_type, packed_data);
    auto quantized_constant =
        rewriter.create<ConstantOp>(qcast.getLoc(), packed_quantized_result);

    // Create the xla dequantize op with bf16 output.
    auto dequantized_type = RankedTensorType::get(attr.getType().getShape(),
                                                  rewriter.getBF16Type());
    auto dequantize = rewriter.create<DequantizeOp>(
        qcast.getLoc(), dequantized_type, quantized_constant,
        rewriter.getF32FloatAttr(min), rewriter.getF32FloatAttr(max),
        rewriter.getStringAttr("MIN_COMBINED"), rewriter.getBoolAttr(false),
        rewriter.getBoolAttr(false));

    // Convert the bf16 output back to f32.
    rewriter.replaceOpWithNewOp<ConvertOp>(op, op.getResult().getType(),
                                           dequantize);
    return matchSuccess();
  }

 private:
  int64_t size_;
};
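
// End to end, this pattern turns the IR from the quantize_rewrite test later
// in this diff
//
//   %w  = constant dense<...> : tensor<2x4xf32>
//   %q  = "quant.qcast"(%w) : (tensor<2x4xf32>)
//       -> tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
//   %dq = "quant.dcast"(%q) : (...) -> tensor<2x4xf32>
//
// into
//
//   %qcst = constant dense<...> : tensor<2x1xi32>
//   %bf16 = "xla_hlo.dequantize"(%qcst) {mode = "MIN_COMBINED", ...}
//       : (tensor<2x1xi32>) -> tensor<2x4xbf16>
//   %f32  = "xla_hlo.convert"(%bf16) : (tensor<2x4xbf16>) -> tensor<2x4xf32>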

// Materialize the quantization results by hlo primitive ops.
struct MaterializeToXlaPass : public FunctionPass<MaterializeToXlaPass> {
  explicit MaterializeToXlaPass() = default;
  MaterializeToXlaPass(const MaterializeToXlaPass &) {}

  void runOnFunction() override;
};

void MaterializeToXlaPass::runOnFunction() {
  FuncOp func = getFunction();
  MLIRContext *ctx = &getContext();

  OwningRewritePatternList patterns;
  // TODO(fengliuai): make the size 6 configurable.
  patterns.insert<RewriteDequantize>(6, ctx);

  applyPatternsGreedily(func, patterns);
}

} // namespace

// Creates an instance of the xla_hlo dialect quantization materialization
// pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateMaterializeToXlaPass() {
  return std::make_unique<MaterializeToXlaPass>();
}

static PassRegistration<MaterializeToXlaPass> pass(
    "xla-hlo-materialize-quant",
    "Materialize the quantization results by xla primitive ops");

} // namespace xla_hlo
} // namespace mlir
@ -0,0 +1,7 @@
// TODO(fengliuai): automatically generate this file
// TODO(fengliuai): add all the xla_hlo ops

static std::unique_ptr<quant::OpQuantSpec> GetOpQuantSpec(mlir::Operation *op) {
  auto spec = absl::make_unique<quant::OpQuantSpec>();
  return spec;
}
37
tensorflow/compiler/mlir/lite/quantization/xla/passes.h
Normal file
@ -0,0 +1,37 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_PASSES_H_

#include <memory>

#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project

namespace mlir {
namespace xla_hlo {

// Propagates the quantization information to all the tensors according to the
// op quant spec.
std::unique_ptr<OpPassBase<FuncOp>> CreatePropagateQuantPass();

// Rewrites the graph and quantizes the constants.
std::unique_ptr<OpPassBase<FuncOp>> CreateMaterializeToXlaPass();

} // namespace xla_hlo
} // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_PASSES_H_
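Since both factory functions are also registered as command-line passes (see the PassRegistration calls in materialize.cc and propagate.cc), a hedged sketch of running them in their intended order — propagate first, then materialize — on some input.mlir (file name assumed) is:

    tf-opt -xla-hlo-propagate-quant -xla-hlo-materialize-quant input.mlir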
78
tensorflow/compiler/mlir/lite/quantization/xla/propagate.cc
Normal file
@ -0,0 +1,78 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass applies quantization propagation on the xla_hlo
// dialect.
#include <iterator>
#include <string>

#include "absl/memory/memory.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/IR/Value.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"

// NOLINTNEXTLINE
static llvm::cl::opt<bool> disable_per_channel(
    "xla-disable-per-channel", llvm::cl::value_desc("bool"),
    llvm::cl::desc("Whether to disable per-channel quantized weights."),
    llvm::cl::init(false));
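// A hypothetical invocation combining this flag with the pass registered
// below (input file assumed for illustration):
//   tf-opt -xla-hlo-propagate-quant -xla-disable-per-channel input.mlir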

//===----------------------------------------------------------------------===//
// The quantization propagation Pass.
//
namespace mlir {
namespace xla_hlo {

namespace {

// Applies quantization propagation on the input function. During the
// propagation, two constraints are respected:
// - The quantization types (params) already present on the ops in the function
// - The quantization spec for the ops
// The propagation should assign quantization types to all the tensors while
// keeping both constraints satisfied.
struct PropagateQuantPass : public FunctionPass<PropagateQuantPass> {
  explicit PropagateQuantPass() = default;
  PropagateQuantPass(const PropagateQuantPass &) {}

  void runOnFunction() override;
};
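
// As an illustration (mirroring the @mul case in the propagation test later
// in this diff), the pass annotates a float weight feeding xla_hlo.mul with
// a u8 qcast/dcast pair:
//
//   %w   = constant dense<...> : tensor<2x2xf32>
//   %q   = "quant.qcast"(%w) : (tensor<2x2xf32>)
//        -> tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
//   %dq  = "quant.dcast"(%q) : (...) -> tensor<2x2xf32>
//   %mul = xla_hlo.mul %arg0, %dq : tensor<2x2xf32>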

#include "tensorflow/compiler/mlir/lite/quantization/xla/op_quant_spec.inc"

void PropagateQuantPass::runOnFunction() {
  FuncOp func = getFunction();
  // XLA only supports uint8/uint16 quantization for now.
  ApplyQuantizationParamsPropagation(func, /*is_signed*/ false,
                                     disable_per_channel, GetOpQuantSpec);
}

} // namespace

// Creates an instance of the xla_hlo dialect quantization propagation pass.
std::unique_ptr<OpPassBase<FuncOp>> CreatePropagateQuantPass() {
  return std::make_unique<PropagateQuantPass>();
}

static PassRegistration<PropagateQuantPass> pass(
    "xla-hlo-propagate-quant", "Propagate quantization information");

} // namespace xla_hlo
} // namespace mlir
19
tensorflow/compiler/mlir/lite/quantization/xla/tests/BUILD
Normal file
@ -0,0 +1,19 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")

package(licenses = ["notice"])

glob_lit_tests(
    data = [":test_utilities"],
    driver = "@llvm-project//mlir:run_lit.sh",
    test_file_exts = ["mlir"],
)

# Bundle together all of the test utilities that are used by tests.
filegroup(
    name = "test_utilities",
    testonly = True,
    data = [
        "//tensorflow/compiler/mlir:tf-opt",
        "@llvm-project//llvm:FileCheck",
    ],
)
@ -0,0 +1,54 @@
// RUN: tf-opt -xla-hlo-materialize-quant %s | FileCheck %s

// CHECK-LABEL: func @quantize_rewrite
func @quantize_rewrite(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> {
// CHECK: %[[qcst:.*]] = constant dense<{{\[\[}}21004416], [-1056997248]]> : tensor<2x1xi32>
// CHECK-NEXT: %[[dq:.*]] = "xla_hlo.dequantize"(%[[qcst]]) {is_16bits = false, max_range = 0.996078431 : f32, min_range = -1.00392163 : f32,
// CHECK-SAME: mode = "MIN_COMBINED", transpose_output = false} : (tensor<2x1xi32>) -> tensor<2x4xbf16>
// CHECK-NEXT: %[[cast:.*]] = "xla_hlo.convert"(%[[dq]]) : (tensor<2x4xbf16>) -> tensor<2x4xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %[[cast]] : tensor<2x4xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x4xf32>

  %w = constant dense<[[-1.0, -0.5, 0.0, 0.0], [0.5, 1.0, 0.0, 0.0]]> : tensor<2x4xf32>
  %q = "quant.qcast"(%w) : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
  %dq = "quant.dcast"(%q) : (tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x4xf32>
  %mul = xla_hlo.mul %arg0, %dq : tensor<2x4xf32>
  return %mul: tensor<2x4xf32>
}

// CHECK-LABEL: func @quantize_small
func @quantize_small(%arg0: tensor<1x4xf32>) -> tensor<1x4xf32> {
// CHECK: %[[w:.*]] = constant dense<1.000000e+00> : tensor<1x4xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %[[w]] : tensor<1x4xf32>
// CHECK-NEXT: return %[[mul]] : tensor<1x4xf32>

  %w = constant dense<1.0> : tensor<1x4xf32>
  %q = "quant.qcast"(%w) : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
  %dq = "quant.dcast"(%q) : (tensor<1x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<1x4xf32>
  %mul = xla_hlo.mul %arg0, %dq : tensor<1x4xf32>
  return %mul: tensor<1x4xf32>
}

// CHECK-LABEL: func @quantize_non_cst
func @quantize_non_cst(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> {
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %arg0 : tensor<2x4xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x4xf32>

  %q = "quant.qcast"(%arg0) : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
  %dq = "quant.dcast"(%q) : (tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x4xf32>
  %mul = xla_hlo.mul %arg0, %dq : tensor<2x4xf32>
  return %mul: tensor<2x4xf32>
}

// CHECK-LABEL: func @quantize_non_4x
func @quantize_non_4x(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> {
// CHECK: %[[w:.*]] = constant dense<1.000000e+00> : tensor<2x5xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %[[w]] : tensor<2x5xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x5xf32>

  %w = constant dense<1.0> : tensor<2x5xf32>
  %q = "quant.qcast"(%w) : (tensor<2x5xf32>) -> tensor<2x5x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
  %dq = "quant.dcast"(%q) : (tensor<2x5x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x5xf32>
  %mul = xla_hlo.mul %arg0, %dq : tensor<2x5xf32>
  return %mul: tensor<2x5xf32>
}
@ -0,0 +1,25 @@
// RUN: tf-opt -xla-hlo-propagate-quant %s | FileCheck %s

// CHECK-LABEL: func @mul
func @mul(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[w:.*]] = constant dense<{{\[\[}}-1.000000e+00, -5.000000e-01], [5.000000e-01, 1.000000e+00]]> : tensor<2x2xf32>
// CHECK-NEXT: %[[q:.*]] = "quant.qcast"(%[[w]]) : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
// CHECK-NEXT: %[[dq:.*]] = "quant.dcast"(%[[q]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x2xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %[[dq]] : tensor<2x2xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x2xf32>
  %w = constant dense<[[-1.0, -0.5], [0.5, 1.0]]> : tensor<2x2xf32>
  %mul = xla_hlo.mul %arg0, %w : tensor<2x2xf32>
  return %mul: tensor<2x2xf32>
}

// CHECK-LABEL: func @add
func @add(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[b:.*]] = constant dense<1.000000e+00> : tensor<2xf32>
// CHECK-NEXT: %[[q:.*]] = "quant.qcast"(%[[b]]) : (tensor<2xf32>) -> tensor<2x!quant.uniform<u8:f32, 0.0039215686274509803>>
// CHECK-NEXT: %[[dq:.*]] = "quant.dcast"(%[[q]]) : (tensor<2x!quant.uniform<u8:f32, 0.0039215686274509803>>) -> tensor<2xf32>
// CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg0, %[[dq]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: return %[[add]] : tensor<2x2xf32>
  %b = constant dense<1.0> : tensor<2xf32>
  %add = "xla_hlo.add"(%arg0, %b) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
  return %add: tensor<2x2xf32>
}
@ -15,5 +15,6 @@ filegroup(
    data = [
        "//tensorflow/compiler/mlir:tf-opt",
        "@llvm-project//llvm:FileCheck",
        "@llvm-project//llvm:not",
    ],
)

@ -1,4 +1,4 @@
// RUN: tf-opt %s -test-constant-fold | FileCheck %s --dump-input-on-failure
// RUN: tf-opt %s -canonicalize | FileCheck %s --dump-input-on-failure

// CHECK-LABEL: @add_float
func @add_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) {
@ -10,9 +10,7 @@ glob_lit_tests(
    driver = "@llvm-project//mlir:run_lit.sh",
    test_file_exts = [
        "pbtxt",
        # TODO(fengliuai): reenable these tests after the fused loc is
        # supported in the diagnostic handler.
        # "py",
        "py",
    ],
)

@ -178,15 +178,20 @@ func @inputsAfterOutputs() {

// -----

// expected-error@+1 {{Found malformed ophint regions: missing inputs or outputs.}}
module {
func @extractOphintFailure() {
func @extractOphintSame() {
  %0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x16x1xf32>
  %1 = call @AnotherFunc(%0) : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
  %2 = "tf.Sigmoid"(%1) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
  %3 = "tf.Mul"(%2, %1) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x16x1xf32>, tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
  %4 = "tf.Identity"(%3) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "d4b1eb00b81211e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation-d4b1eb00b81211e99426dc4a3e957995-0-None-None"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
  return

// CHECK: [[VAL_0:%.*]] = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x16x1xf32>
// CHECK: [[VAL_1:%.*]] = call @AnotherFunc([[VAL_0]]) : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
// CHECK: [[VAL_2:%.*]] = "tf.Sigmoid"([[VAL_1]]) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
// CHECK: [[VAL_3:%.*]] = "tf.Mul"([[VAL_2]], [[VAL_1]]) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x16x1xf32>, tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
// CHECK: [[VAL_4:%.*]] = "tf.Identity"([[VAL_3]]) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "d4b1eb00b81211e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation-d4b1eb00b81211e99426dc4a3e957995-0-None-None"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
}

func @AnotherFunc(%arg0: tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> {
@ -1,25 +1,31 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s

// Check to see if function references in while loops are preserved
func @main(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
// TODO(b/138222071) Expect first output to be a scalar
// CHECK: %{{.*}}:2 = "tf.While"(%{{.*}}, %{{.*}}) {body = @body, cond = @cond, is_stateless = false} : (tensor<i32>, tensor<1xf32>) -> (tensor<*xi32>, tensor<1xf32>)

func @main(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
  // While %arg0 is greater than zero, element wise add %arg1 with itself.
  %0:2 = "tf.While"(%arg0, %arg1) {
    cond = @cond, body = @body, is_stateless = false
  } : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>)
  %0:2 = "tfl.while"(%arg0, %arg1) ( {
  ^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>):  // no predecessors
    %1 = call @cond(%arg2, %arg3) : (tensor<*xi32>, tensor<*xf32>) -> tensor<i1>
    "tfl.yield"(%1) : (tensor<i1>) -> ()
  }, {
  ^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>):  // no predecessors
    %1:2 = call @body(%arg2, %arg3) : (tensor<*xi32>, tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>)
    "tfl.yield"(%1#0, %1#1) : (tensor<*xi32>, tensor<*xf32>) -> ()
  }) {is_stateless = false} : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>)
  return %0#1 : tensor<1xf32>
}

func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> {
  %0 = "std.constant" () {value = dense<0> : tensor<i32>} : () -> tensor<i32> loc("Const")
  %1 = "tfl.greater"(%arg0, %0) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
  return %1 : tensor<i1>
  %cst = constant dense<0> : tensor<i32> loc("Const")
  %0 = "tfl.greater"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
  return %0 : tensor<i1>
}

func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>) {
  %0 = "std.constant" () {value = dense<1> : tensor<i32>} : () -> tensor<i32> loc("Const")
  %1 = "tfl.sub"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
  %2 = tfl.add %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
  return %1, %2 : tensor<*xi32>, tensor<*xf32>
  %cst = constant dense<1> : tensor<i32> loc("Const")
  %0 = "tfl.sub"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
  %1 = tfl.add %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
  return %0, %1 : tensor<*xi32>, tensor<*xf32>
}
Some files were not shown because too many files have changed in this diff