Merge branch 'master' into Fix_FileWriter

commit 56cafc2c6c

.bazelrc (55 changes)
@@ -37,7 +37,6 @@
 # v2: Build TF v2
 #
 # Feature and Third party library support options:
-# xla: Build TF with XLA
 # using_cuda: CUDA is available to build system.
 # cuda: Build with full cuda support.
 # rocm: Build with AMD GPU support (rocm).
@@ -222,6 +221,19 @@ build --define=grpc_no_ares=true
 # archives in -whole_archive -no_whole_archive.
 build --noincompatible_remove_legacy_whole_archive
 
+# These are bazel 2.0's incompatible flags. Tensorflow needs to use bazel 2.0.0
+# to use cc_shared_library, as part of the Tensorflow Build Improvements RFC:
+# https://github.com/tensorflow/community/pull/179
+build --noincompatible_prohibit_aapt1
+
+# Enable XLA
+build --action_env=TF_ENABLE_XLA=1
+build --define=with_xla_support=true
+
+# Keep config XLA until all build scripts are cleaned up.
+build:xla --action_env=TF_ENABLE_XLA=1
+build:xla --define=with_xla_support=true
+
 # Modular TF build options
 build:dynamic_kernels --define=dynamic_loaded_kernels=true
 build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS
@@ -307,29 +319,29 @@ build:v2 --action_env=TF2_BEHAVIOR=1
 build --config=v2
 test --config=v2
 
-# Enable XLA
-build:xla --action_env=TF_ENABLE_XLA=1
-build:xla --define=with_xla_support=true
 
 # BEGIN TF REMOTE BUILD EXECUTION OPTIONS
 # Options when using remote execution
 # WARNING: THESE OPTIONS WONT WORK IF YOU DO NOT HAVE PROPER AUTHENTICATION AND PERMISSIONS
 
+# Flag to enable remote config
+common --experimental_repo_remote_exec
+
 build:rbe --action_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1
-build:rbe --auth_enabled=true
-build:rbe --auth_scope=https://www.googleapis.com/auth/cloud-source-tools
+build:rbe --google_default_credentials
 build:rbe --bes_backend=buildeventservice.googleapis.com
-build:rbe --bes_best_effort=false
 build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations"
 build:rbe --bes_timeout=600s
 build:rbe --define=EXECUTOR=remote
+build:rbe --distinct_host_configuration=false
 build:rbe --flaky_test_attempts=3
 build:rbe --jobs=200
 build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com
 build:rbe --remote_timeout=3600
 build:rbe --spawn_strategy=remote,worker,standalone,local
 test:rbe --test_env=USER=anon
-build:rbe --distinct_host_configuration=false
+# Attempt to minimize the amount of data transfer between bazel and the remote
+# workers:
+build:rbe --remote_download_toplevel
 
 build:rbe_linux --config=rbe
 build:rbe_linux --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin"
@@ -339,7 +351,6 @@ build:rbe_linux --host_java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8
 build:rbe_linux --java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8
 
 # Non-rbe settings we should include because we do not run configure
-build:rbe_linux --config=xla
 build:rbe_linux --config=avx_linux
 build:rbe_linux --config=short_logs
 # TODO(gunan): Check why we need this specified in rbe, but not in other builds.
@@ -354,13 +365,14 @@ build:rbe_cpu_linux --host_platform="@org_tensorflow//third_party/toolchains:rbe
 build:rbe_cpu_linux --platforms="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
 
 build:rbe_linux_cuda_nvcc --config=rbe_linux
-build:rbe_linux_cuda_nvcc --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain"
-build:rbe_linux_cuda_nvcc --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain-linux-x86_64"
-build:rbe_linux_cuda_nvcc --extra_execution_platforms="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010,@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010-gpu"
-build:rbe_linux_cuda_nvcc --host_platform="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
-build:rbe_linux_cuda_nvcc --platforms="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
-build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/cuda10.1-cudnn7"
-build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/tensorrt6.0"
+build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
+build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
+build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
+build:rbe_linux_cuda_nvcc --host_platform="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
+build:rbe_linux_cuda_nvcc --platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
+build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
+build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
+build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
 build:rbe_linux_cuda_nvcc --repo_env=TF_NEED_TENSORRT=1
 build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_VERSION=10
 build:rbe_linux_cuda_nvcc --repo_env=TF_CUDNN_VERSION=7
@@ -377,9 +389,8 @@ build:rbe_linux_py2 --python_path="/usr/bin/python2"
 build:rbe_linux_py2 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py"
 
 build:rbe_linux_py3 --config=rbe_linux
-build:rbe_linux_py3 --repo_env=PYTHON_BIN_PATH="/usr/bin/python3"
 build:rbe_linux_py3 --python_path="/usr/bin/python3"
-build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py3"
+build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-manylinux2010-py3_config_python"
 
 build:rbe_win --config=rbe
 build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_121:toolchain"
@@ -396,9 +407,7 @@ build:rbe_win --define=override_eigen_strong_inline=true
 build:rbe_win --jobs=500
 
 build:rbe_win_py37 --config=rbe
-build:rbe_win_py37 --repo_env=PYTHON_BIN_PATH=C:\\Python37\\python.exe
-build:rbe_win_py37 --repo_env=PYTHON_LIB_PATH=C:\\Python37\\lib\\site-packages
-build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py37
+build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO="@windows_py37_config_python"
 build:rbe_win_py37 --python_path=C:\\Python37\\python.exe
 
 build:rbe_win_py38 --config=rbe
(file name not captured)
@@ -1 +1 @@
-1.2.1
+2.0.0
.github/ISSUE_TEMPLATE/00-bug-issue.md (new file, 44 additions)
@@ -0,0 +1,44 @@
+---
+name: Bug Issue
+about: Use this template for reporting a bug
+labels: 'type:bug'
+
+---
+
+<em>Please make sure that this is a bug. As per our
+[GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md),
+we only address code/doc bugs, performance issues, feature requests and
+build/installation issues on GitHub. tag:bug_template</em>
+
+**System information**
+- Have I written custom code (as opposed to using a stock
+example script provided in TensorFlow):
+- OS Platform and Distribution (e.g.,
+Linux Ubuntu 16.04):
+- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
+the issue happens on mobile device:
+- TensorFlow installed from (source or
+binary): - TensorFlow version (use command below):
+- Python version: - Bazel
+version (if compiling from source):
+- GCC/Compiler version (if compiling from
+source):
+- CUDA/cuDNN version: - GPU model and memory:
+
+You can collect some of this information using our environment capture
+[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
+You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
+tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
+"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
+
+**Describe the current behavior**
+
+**Describe the expected behavior**
+
+**Standalone code to reproduce the issue**
+Provide a reproducible test case that is the bare minimum necessary to generate
+the problem. If possible, please share a link to Colab/Jupyter/any notebook.
+
+**Other info / logs** Include any logs or source code that would be helpful to
+diagnose the problem. If including tracebacks, please include the full
+traceback. Large logs and files should be attached.
(deleted file, name not captured)
@@ -1,35 +0,0 @@
----
-name: Bug/Performance Issue
-about: Use this template for reporting a bug or a performance issue.
-
----
-
-<em>Please make sure that this is a bug. As per our [GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md), we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:bug_template</em>
-
-**System information**
-- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
-- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
-- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
-- TensorFlow installed from (source or binary):
-- TensorFlow version (use command below):
-- Python version:
-- Bazel version (if compiling from source):
-- GCC/Compiler version (if compiling from source):
-- CUDA/cuDNN version:
-- GPU model and memory:
-
-You can collect some of this information using our environment capture
-[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
-You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
-tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
-"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
-
-**Describe the current behavior**
-
-**Describe the expected behavior**
-
-**Code to reproduce the issue**
-Provide a reproducible test case that is the bare minimum necessary to generate the problem.
-
-**Other info / logs**
-Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
@@ -1,6 +1,7 @@
 ---
 name: Build/Installation Issue
 about: Use this template for build/installation issues
+labels: 'type:build/install'
 
 ---
 
@@ -1,10 +1,11 @@
 ---
 name: Documentation Issue
-about: Use this template for documentation related
+about: Use this template for documentation related issues
 labels: 'type:docs'
 
 ---
 
 
 Thank you for submitting a TensorFlow documentation issue. Per our GitHub
 policy, we only address code/doc bugs, performance issues, feature requests, and
 build/installation issues on GitHub.
.github/ISSUE_TEMPLATE/30-feature-request.md (4 changes)
@@ -1,6 +1,6 @@
 ---
-name: Feature Request
-about: Use this template for raising a feature request
+name: Feature Request about: Use this template for raising a feature request
+labels: 'type:feature'
 
 ---
 
.github/ISSUE_TEMPLATE/40-tflite-op-request.md (12 changes)
@@ -1,10 +1,10 @@
 ---
 name: TensorFlow Lite Op Request
-about: Use this template for reporting ops you are using or missing.
+about: Use this template for reporting Lite ops you are using or missing
+labels: 'comp:lite'
 
 ---
 
 
 **System information**
 - OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
 - TensorFlow installed from (source or binary):
@@ -17,8 +17,14 @@ about: Use this template for reporting ops you are using or missing.
 # Copy and paste here
 ```
 
+**Standalone code to reproduce the issue**
+Provide a reproducible test case that is the bare minimum necessary to generate
+the problem. If possible, please share a link to Colab/Jupyter/any notebook.
+
 Also, please include a link to a GraphDef or the model if possible.
 
 **Any other info / logs**
 
-Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
+Include any logs or source code that would be helpful to diagnose the problem.
+If including tracebacks, please include the full traceback. Large logs and files
+should be attached.
.github/ISSUE_TEMPLATE/50-other-issues.md (1 change)
@@ -1,6 +1,7 @@
 ---
 name: Other Issues
 about: Use this template for any other non-support related issues
+labels: 'type:others'
 
 ---
 
@@ -1,6 +1,7 @@
 ---
 name: TensorFlow Lite New Converter Issue
-about: Use this template for reporting issues during model conversion to TFLite.
+about: Use this template for reporting issues during model conversion to TFLite
+labels: 'TFLiteConverter'
 
 ---
 
@@ -12,6 +13,7 @@ about: Use this template for reporting issues during model conversion to TFLite.
 
 
 **Command used to run the converter or code if you’re using the Python API**
+If possible, please share a link to Colab/Jupyter/any notebook.
 
 ```
 # Copy and paste here the exact command
.github/ISSUE_TEMPLATE/80-performance-issue.md (new file, 45 additions)
@@ -0,0 +1,45 @@
+---
+name: Performance Issue
+about: Use this template for reporting a performance issue
+labels: 'type:performance'
+
+---
+
+<em>Please make sure that this is an issue related to performance of TensorFlow.
+As per our
+[GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md),
+we only address code/doc bugs, performance issues, feature requests and
+build/installation issues on GitHub. tag:performance_template</em>
+
+**System information**
+- Have I written custom code (as opposed to using a stock
+example script provided in TensorFlow):
+- OS Platform and Distribution (e.g.,
+Linux Ubuntu 16.04):
+- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
+the issue happens on mobile device:
+- TensorFlow installed from (source or
+binary): - TensorFlow version (use command below):
+- Python version: - Bazel
+version (if compiling from source):
+- GCC/Compiler version (if compiling from
+source):
+- CUDA/cuDNN version: - GPU model and memory:
+
+You can collect some of this information using our environment capture
+[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
+You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
+tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
+"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
+
+**Describe the current behavior**
+
+**Describe the expected behavior**
+
+**Standalone code to reproduce the issue**
+Provide a reproducible test case that is the bare minimum necessary to generate
+the problem. If possible, please share a link to Colab/Jupyter/any notebook.
+
+**Other info / logs** Include any logs or source code that would be helpful to
+diagnose the problem. If including tracebacks, please include the full
+traceback. Large logs and files should be attached.
.gitignore (1 change)
@@ -22,6 +22,7 @@ tensorflow/contrib/cmake/_build/
 /tensorflow/python/framework/fast_tensor_util.cpp
 /tensorflow/lite/gen/**
 /tensorflow/lite/tools/make/downloads/**
+/tensorflow/lite/tools/make/gen/**
 /api_init_files_list.txt
 /estimator_api_init_files_list.txt
 *.whl
@@ -70,7 +70,7 @@ $ python
 3
 >>> hello = tf.constant('Hello, TensorFlow!')
 >>> hello.numpy()
-'Hello, TensorFlow!'
+b'Hello, TensorFlow!'
 ```
 
 For more examples, see the
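The updated README output reflects that in TensorFlow 2.x, `Tensor.numpy()` on a string tensor yields a `bytes` value, hence the `b'...'` literal. A quick check of that behavior, assuming a TensorFlow 2.x install is available:

```python
# Hedged sketch: shows the bytes result behind the README's updated output.
import tensorflow as tf

hello = tf.constant('Hello, TensorFlow!')
print(hello.numpy())  # b'Hello, TensorFlow!'
```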
WORKSPACE (29 changes)
@@ -1,13 +1,11 @@
 workspace(name = "org_tensorflow")
 
 load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
-load("//third_party:repo.bzl", "tf_http_archive")
 
-tf_http_archive(
+http_archive(
     name = "io_bazel_rules_closure",
     sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9",
     strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149",
-    patch_file = "@org_tensorflow//third_party:rules_closure.patch",
     urls = [
         "https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",
         "https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",  # 2019-06-13
@@ -115,3 +113,28 @@ http_archive(
         "https://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip",
     ],
 )
+
+# Required for dependency @com_github_grpc_grpc
+load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
+
+grpc_deps()
+
+load(
+    "@build_bazel_rules_apple//apple:repositories.bzl",
+    "apple_rules_dependencies",
+)
+
+apple_rules_dependencies()
+
+load(
+    "@build_bazel_apple_support//lib:repositories.bzl",
+    "apple_support_dependencies",
+)
+
+apple_support_dependencies()
+
+load("@upb//bazel:repository_defs.bzl", "bazel_version_repository")
+
+bazel_version_repository(name = "bazel_version")
@@ -49,8 +49,8 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
 _TF_WORKSPACE_ROOT = ''
 _TF_BAZELRC = ''
 _TF_CURRENT_BAZEL_VERSION = None
-_TF_MIN_BAZEL_VERSION = '1.2.1'
-_TF_MAX_BAZEL_VERSION = '1.2.1'
+_TF_MIN_BAZEL_VERSION = '2.0.0'
+_TF_MAX_BAZEL_VERSION = '2.0.0'
 
 NCCL_LIB_PATHS = [
     'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
@@ -1390,10 +1390,6 @@ def main():
   else:
     environ_cp['TF_CONFIGURE_IOS'] = '0'
 
-  xla_enabled_by_default = is_linux() or is_macos()
-  set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
-                xla_enabled_by_default, 'xla')
-
   set_action_env_var(
       environ_cp,
       'TF_NEED_OPENCL_SYCL',
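The hunk above narrows the supported Bazel range to exactly 2.0.0 by setting the same inclusive minimum and maximum. A hedged sketch of that kind of version gate (function names here are illustrative, not the actual configure.py helpers):

```python
# Illustrative version gate; `parse` and `check_bazel_version` are
# hypothetical names, not the helpers used by configure.py.
def parse(version):
  return tuple(int(part) for part in version.split("."))


def check_bazel_version(current, minimum="2.0.0", maximum="2.0.0"):
  if not (parse(minimum) <= parse(current) <= parse(maximum)):
    raise SystemExit(
        "Bazel %s is outside the supported range [%s, %s]."
        % (current, minimum, maximum))


check_bazel_version("2.0.0")  # accepted after this change; "1.2.1" would fail
```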
@@ -187,6 +187,12 @@ config_setting(
     visibility = ["//visibility:public"],
 )
 
+config_setting(
+    name = "fuchsia",
+    values = {"cpu": "fuchsia"},
+    visibility = ["//visibility:public"],
+)
+
 config_setting(
     name = "ios_x86_64",
     values = {
@@ -448,19 +454,66 @@ config_setting(
     visibility = ["//visibility:public"],
 )
 
+# Specifies via a config setting if this is a mobile build or not, makes
+# it easier to combine settings later.
+selects.config_setting_group(
+    name = "mobile",
+    match_any = [
+        ":android",
+        ":chromiumos",
+        ":emscripten",
+        ":ios",
+    ],
+)
+
+config_setting(
+    name = "lite_protos_legacy",
+    values = {"define": "TENSORFLOW_PROTOS=lite"},
+    visibility = ["//visibility:private"],
+)
+
+config_setting(
+    name = "full_protos",
+    values = {"define": "TENSORFLOW_PROTOS=full"},
+    visibility = ["//visibility:public"],
+)
+
+selects.config_setting_group(
+    name = "lite_protos",
+    match_any = [":lite_protos_legacy"],
+)
+
+selects.config_setting_group(
+    name = "mobile_lite_protos",
+    match_all = [
+        ":lite_protos",
+        ":mobile",
+    ],
+)
+
+selects.config_setting_group(
+    name = "mobile_full_protos",
+    match_all = [
+        ":full_protos",
+        ":mobile",
+    ],
+)
+
 # DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST!
 # Instead, please use public APIs or public build rules TF provides.
 # If you need functionality that is not exposed, we will work with you to expand our public APIs.
 package_group(
     name = "internal",
     packages = [
+        # To pass open source testing in the pip Kokoros.
+        "//bazel_pip/tensorflow/...",
         "//learning/brain/swift/x10/...",
         "//perftools/accelerators/xprof/api/...",
+        "//third_party/py/autograph/...",
+        "//third_party/swift/tensorflow/x10/...",
         "//tensorflow/...",
         "//tensorflow_estimator/python/estimator/...",
         "//tensorflow_models/official/...",
-        "//third_party/py/autograph/...",
-        "//third_party/swift/tensorflow/x10/...",
     ],
 )
 
@@ -494,8 +547,8 @@ cc_library(
     name = "grpc",
     visibility = ["//visibility:public"],
     deps = select({
-        ":linux_s390x": ["@grpc//:grpc_unsecure"],
-        "//conditions:default": ["@grpc"],
+        ":linux_s390x": ["@com_github_grpc_grpc//:grpc_unsecure"],
+        "//conditions:default": ["@com_github_grpc_grpc//:grpc"],
     }),
 )
 
@@ -503,8 +556,8 @@ cc_library(
     name = "grpc++",
     visibility = ["//visibility:public"],
     deps = select({
-        ":linux_s390x": ["@grpc//:grpc++_unsecure"],
-        "//conditions:default": ["@grpc//:grpc++"],
+        ":linux_s390x": ["@com_github_grpc_grpc//:grpc++_unsecure"],
+        "//conditions:default": ["@com_github_grpc_grpc//:grpc++"],
     }),
 )
 
@@ -589,7 +642,7 @@ tf_cc_shared_object(
         "//tensorflow/core:gpu_runtime_impl",
         "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
         "//tensorflow/core:lib_internal_impl",
-        "//tensorflow/core/profiler/lib:profiler_session_impl",
+        "//tensorflow/core/profiler:profiler_impl",
         "//tensorflow/stream_executor:stream_executor_impl",
         "//tensorflow:tf_framework_version_script.lds",
     ] + tf_additional_binary_deps(),
@@ -909,7 +962,6 @@ py_library(
         "//conditions:default": [":tf_python_api_gen_v1"],
     }) + [
         ":root_init_gen",
-        ":virtual_root_init_gen",
         "//tensorflow/python/keras/api:keras_python_api_gen",
         "//tensorflow/python/keras/api:keras_python_api_gen_compat_v1",
         "//tensorflow/python/keras/api:keras_python_api_gen_compat_v2",
@@ -35,9 +35,11 @@ import inspect as _inspect
 import logging as _logging
 import os as _os
 import site as _site
+import six as _six
 import sys as _sys
 
 from tensorflow.python.tools import module_util as _module_util
+from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
 
 # API IMPORTS PLACEHOLDER
 
@@ -69,13 +71,13 @@ except ImportError:
   _logging.warning(
       "Limited tf.summary API due to missing TensorBoard installation.")
 
-try:
-  from tensorflow_estimator.python.estimator.api._v2 import estimator
-  _current_module.__path__ = (
-      [_module_util.get_parent_dir(estimator)] + _current_module.__path__)
+# Lazy-load estimator.
+_estimator_module = "tensorflow_estimator.python.estimator.api._v2.estimator"
+estimator = _LazyLoader("estimator", globals(), _estimator_module)
+_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
+if _module_dir:
+  _current_module.__path__ = [_module_dir] + _current_module.__path__
 setattr(_current_module, "estimator", estimator)
-except ImportError:
-  pass
 
 try:
   from .python.keras.api._v2 import keras
@@ -85,6 +87,13 @@ try:
 except ImportError:
   pass
 
+# Explicitly import lazy-loaded modules to support autocompletion.
+# pylint: disable=g-import-not-at-top
+if not _six.PY2:
+  import typing as _typing
+  if _typing.TYPE_CHECKING:
+    from tensorflow_estimator.python.estimator.api._v2 import estimator
+# pylint: enable=g-import-not-at-top
+
 # Enable TF2 behaviors
 from tensorflow.python.compat import v2_compat as _compat  # pylint: disable=g-import-not-at-top
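The hunks above replace an eager import of the estimator API with TensorFlow's `LazyLoader`, so the submodule is only imported the first time an attribute is touched. A minimal, self-contained sketch of that pattern (an illustration only, not the actual `tensorflow.python.util.lazy_loader` implementation):

```python
# Sketch only: a stripped-down stand-in for a lazy module loader.
import importlib
import types


class _SimpleLazyLoader(types.ModuleType):
  """Defers importing `module_name` until an attribute is first accessed."""

  def __init__(self, local_name, parent_globals, module_name):
    self._local_name = local_name
    self._parent_globals = parent_globals
    self._module_name = module_name
    super().__init__(module_name)

  def _load(self):
    module = importlib.import_module(self._module_name)
    # Swap the placeholder for the real module so later lookups are direct.
    self._parent_globals[self._local_name] = module
    self.__dict__.update(module.__dict__)
    return module

  def __getattr__(self, item):
    return getattr(self._load(), item)


# The import of `json` only happens at the attribute access below.
json_mod = _SimpleLazyLoader("json_mod", globals(), "json")
print(json_mod.dumps({"lazy": True}))
```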
@@ -22,12 +22,14 @@ import distutils as _distutils
 import inspect as _inspect
 import os as _os
 import site as _site
+import six as _six
 import sys as _sys
 
 # pylint: disable=g-bad-import-order
 from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
 from tensorflow.python.tools import module_util as _module_util
 from tensorflow.python.platform import tf_logging as _logging
+from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
 
 # API IMPORTS PLACEHOLDER
 
@@ -64,13 +66,14 @@ elif _tf_api_dir not in __path__:
 # reexport_tf_summary can get compat from sys.modules. Only needed if using
 # lazy loading.
 _current_module.compat.v2  # pylint: disable=pointless-statement
-try:
-  from tensorflow_estimator.python.estimator.api._v1 import estimator
-  _current_module.__path__ = (
-      [_module_util.get_parent_dir(estimator)] + _current_module.__path__)
+# Lazy-load estimator.
+_estimator_module = "tensorflow_estimator.python.estimator.api._v1.estimator"
+estimator = _LazyLoader("estimator", globals(), _estimator_module)
+_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
+if _module_dir:
+  _current_module.__path__ = [_module_dir] + _current_module.__path__
 setattr(_current_module, "estimator", estimator)
-except ImportError:
-  pass
 
 try:
   from .python.keras.api._v1 import keras
@@ -80,6 +83,13 @@ try:
 except ImportError:
   pass
 
+# Explicitly import lazy-loaded modules to support autocompletion.
+# pylint: disable=g-import-not-at-top
+if not _six.PY2:
+  import typing as _typing
+  if _typing.TYPE_CHECKING:
+    from tensorflow_estimator.python.estimator.api._v1 import estimator
+# pylint: enable=g-import-not-at-top
+
 from tensorflow.python.util.lazy_loader import LazyLoader  # pylint: disable=g-import-not-at-top
 _CONTRIB_WARNING = """
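Both API templates also add a `typing.TYPE_CHECKING` guard so IDEs and type checkers still see the real estimator module even though it is loaded lazily at runtime. A small sketch of the idiom (module names below are placeholders, not the TensorFlow ones):

```python
# Sketch of the TYPE_CHECKING idiom: the guarded import is visible to static
# analysis and autocompletion but never executes at runtime.
import typing

if typing.TYPE_CHECKING:
  import json as _json_for_tooling  # placeholder; never imported at runtime

print("TYPE_CHECKING at runtime is", typing.TYPE_CHECKING)  # False
```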
@@ -536,6 +536,7 @@ tf_cuda_cc_test(
         "//tensorflow/core/kernels:array",
         "//tensorflow/core/kernels:control_flow_ops",
         "//tensorflow/core/kernels:math",
+        "//tensorflow/core/platform:resource_loader",
     ],
 )
 
@@ -647,12 +648,14 @@ tf_cuda_cc_test(
     deps = [
         ":c_api",
         ":kernels",
+        "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core/kernels:ops_testutil",
+        "//third_party/eigen3",
         "@com_google_absl//absl/container:inlined_vector",
     ],
 )
@@ -31,6 +31,7 @@ limitations under the License.
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/node_builder.h"
 #include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/casts.h"
 #include "tensorflow/core/platform/init_main.h"
 #include "tensorflow/core/platform/net.h"
 #include "tensorflow/core/platform/platform.h"
@@ -816,12 +817,15 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
 
   const int num_inputs = input_shapes->num_items;
   NodeDef node_def;
-  node_def.set_name(tfe_op->operation.Name());
-  node_def.set_op(tfe_op->operation.Name());
+  node_def.set_name(tfe_op->operation->Name());
+  node_def.set_op(tfe_op->operation->Name());
   for (int i = 0; i < num_inputs; ++i) {
     node_def.add_input("dummy_input");
   }
-  tfe_op->operation.Attrs().FillAttrValueMap(node_def.mutable_attr());
+  tensorflow::down_cast<tensorflow::OperationInterface*>(
+      tfe_op->operation.get())
+      ->Attrs()
+      .FillAttrValueMap(node_def.mutable_attr());
 
   const tensorflow::OpRegistrationData* op_reg_data;
   status->status =
@@ -45,6 +45,8 @@ limitations under the License.
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/path.h"
+#include "tensorflow/core/platform/resource_loader.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/protobuf/error_codes.pb.h"
 #include "tensorflow/core/protobuf/meta_graph.pb.h"
@@ -193,8 +195,9 @@ TEST(CAPI, LibraryLoadFunctions) {
   {
     // Load the library.
    TF_Status* status = TF_NewStatus();
-    TF_Library* lib =
-        TF_LoadLibrary("tensorflow/c/test_op1.so", status);
+    string lib_path = tensorflow::GetDataDependencyFilepath(
+        tensorflow::io::JoinPath("tensorflow", "c", "test_op1.so"));
+    TF_Library* lib = TF_LoadLibrary(lib_path.c_str(), status);
     TF_Code code = TF_GetCode(status);
     string status_msg(TF_Message(status));
     TF_DeleteStatus(status);
@@ -1350,9 +1353,9 @@ TEST_F(CApiColocationTest, ClearViaProto) {
 
 TEST(CAPI, SavedModel) {
   // Load the saved model.
-  const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123";
-  const string saved_model_dir = tensorflow::io::JoinPath(
-      tensorflow::testing::TensorFlowSrcRoot(), kSavedModel);
+  const string saved_model_dir = tensorflow::GetDataDependencyFilepath(
+      tensorflow::io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
+                               "half_plus_two", "00000123"));
   TF_SessionOptions* opt = TF_NewSessionOptions();
   TF_Buffer* run_options = TF_NewBufferFromString("", 0);
   TF_Buffer* metagraph = TF_NewBuffer();
@@ -1426,9 +1429,9 @@ TEST(CAPI, SavedModel) {
 }
 
 TEST(CAPI, SavedModelNullArgsAreValid) {
-  const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123";
-  const string saved_model_dir = tensorflow::io::JoinPath(
-      tensorflow::testing::TensorFlowSrcRoot(), kSavedModel);
+  const string saved_model_dir = tensorflow::GetDataDependencyFilepath(
+      tensorflow::io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
+                               "half_plus_two", "00000123"));
   TF_SessionOptions* opt = TF_NewSessionOptions();
   TF_Status* s = TF_NewStatus();
   const char* tags[] = {tensorflow::kSavedModelTagServe};
@@ -28,6 +28,8 @@ tf_cuda_library(
         "c_api_debug.cc",
         "c_api_experimental.h",
         "c_api_internal.h",
+        "operation_interface.cc",
+        "operation_interface.h",
         "tensor_handle_interface.h",
     ],
     hdrs = ["c_api.h"],
@@ -56,6 +58,7 @@ tf_cuda_library(
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:casts",
         "//tensorflow/core/platform:errors",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/profiler/lib:traceme",
@@ -92,6 +95,7 @@ filegroup(
     srcs = [
         "c_api_experimental.h",
         "c_api_internal.h",
+        "operation_interface.h",
         "tensor_handle_interface.h",
     ],
     visibility = [
@@ -104,6 +108,7 @@ tf_cuda_library(
     name = "c_api_internal",
     srcs = [
         "c_api_experimental.h",
+        "operation_interface.h",
         "tensor_handle_interface.h",
     ],
     hdrs = ["c_api_internal.h"],
@@ -128,6 +133,7 @@ tf_cuda_library(
         "//tensorflow/core/common_runtime/eager:eager_operation",
         "//tensorflow/core/common_runtime/eager:kernel_and_device",
         "//tensorflow/core/common_runtime/eager:tensor_handle",
+        "@com_google_absl//absl/container:fixed_array",
     ],
 )
 
@@ -199,6 +205,7 @@ tf_cuda_cc_test(
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
+        "//tensorflow/core/platform:casts",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -256,8 +263,6 @@ tf_cuda_library(
         "//tensorflow/core/distributed_runtime:remote_device",
         "//tensorflow/core/distributed_runtime:server_lib",
         "//tensorflow/core/distributed_runtime:worker_env",
-        "//tensorflow/core/profiler/rpc:profiler_server",
-        "//tensorflow/core/profiler/rpc/client:capture_profile",
         "//tensorflow/core:gpu_runtime",
     ],
     alwayslink = 1,
@ -27,7 +27,6 @@ limitations under the License.
|
|||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
#include "absl/algorithm/container.h"
|
#include "absl/algorithm/container.h"
|
||||||
#include "absl/container/fixed_array.h"
|
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "tensorflow/c/c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
#include "tensorflow/c/c_api_internal.h"
|
#include "tensorflow/c/c_api_internal.h"
|
||||||
@ -95,14 +94,6 @@ using tensorflow::string;
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
|
|
||||||
const tensorflow::OpDef* op_def = op->operation.OpDef();
|
|
||||||
if (op_def) return op_def;
|
|
||||||
status->status =
|
|
||||||
tensorflow::OpDefForOp(op->operation.Name().c_str(), &op_def);
|
|
||||||
return op_def;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool IsCPU(
|
bool IsCPU(
|
||||||
absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> variant) {
|
absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> variant) {
|
||||||
if (VariantDeviceIsCustom(variant)) {
|
if (VariantDeviceIsCustom(variant)) {
|
||||||
@ -883,12 +874,12 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
|||||||
#endif // !IS_MOBILE_PLATFORM
|
#endif // !IS_MOBILE_PLATFORM
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
|
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
#if defined(IS_MOBILE_PLATFORM)
|
#if defined(IS_MOBILE_PLATFORM)
|
||||||
status->status = tensorflow::Status::OK();
|
status->status = tensorflow::Status::OK();
|
||||||
#else // !defined(IS_MOBILE_PLATFORM)
|
#else // !defined(IS_MOBILE_PLATFORM)
|
||||||
status->status = ctx->context->ClearRemoteExecutors();
|
status->status = ctx->context->SyncExecutors();
|
||||||
#endif // !IS_MOBILE_PLATFORM
|
#endif // !IS_MOBILE_PLATFORM
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1074,6 +1065,10 @@ AbstractTensorHandleInterface* tensorflow::TensorHandleInterface::Copy() {
|
|||||||
return new TensorHandleInterface(handle_);
|
return new TensorHandleInterface(handle_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void tensorflow::TensorHandleInterface::EnableImplicitMirroring() {
|
||||||
|
handle_->EnableImplicitMirroring();
|
||||||
|
}
|
||||||
|
|
||||||
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
|
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
|
||||||
if (h == nullptr) {
|
if (h == nullptr) {
|
||||||
status->status = tensorflow::errors::InvalidArgument(
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
@ -1121,9 +1116,13 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
|
|||||||
return retval;
|
return retval;
|
||||||
} else {
|
} else {
|
||||||
tensorflow::Tensor tensor;
|
tensorflow::Tensor tensor;
|
||||||
if (IsCPU(handle_->device())) {
|
if (IsCPU(handle_->device()) || handle_->HasLocalMirror(nullptr)) {
|
||||||
const tensorflow::Tensor* src = nullptr;
|
const tensorflow::Tensor* src = nullptr;
|
||||||
|
if (handle_->HasLocalMirror(nullptr)) {
|
||||||
|
*status = handle_->TensorFromDevice(nullptr, &src);
|
||||||
|
} else {
|
||||||
*status = handle_->Tensor(&src);
|
*status = handle_->Tensor(&src);
|
||||||
|
}
|
||||||
if (!status->ok()) return nullptr;
|
if (!status->ok()) return nullptr;
|
||||||
tensor = *src;
|
tensor = *src;
|
||||||
} else {
|
} else {
|
||||||
@ -1131,6 +1130,13 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
|
|||||||
CHECK_NE(ctx, nullptr);
|
CHECK_NE(ctx, nullptr);
|
||||||
*status = handle_->CopyToDevice(*ctx, ctx->HostCPU(), &tensor);
|
*status = handle_->CopyToDevice(*ctx, ctx->HostCPU(), &tensor);
|
||||||
if (!status->ok()) return nullptr;
|
if (!status->ok()) return nullptr;
|
||||||
|
if (handle_->ImplicitMirroring()) {
|
||||||
|
*status = handle_->AddEmptyLocalMirror(nullptr);
|
||||||
|
if (!status->ok()) return nullptr;
|
||||||
|
Tensor mirror = tensor;
|
||||||
|
*status = handle_->SetTensor(std::move(mirror), nullptr);
|
||||||
|
if (!status->ok()) return nullptr;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return tensorflow::TF_TensorFromTensor(tensor, status);
|
return tensorflow::TF_TensorFromTensor(tensor, status);
|
||||||
}
|
}
|
||||||
@ -1195,31 +1201,23 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
|||||||
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
|
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (dtype == TF_STRING || dtype == TF_RESOURCE ||
|
|
||||||
!tensorflow::DataTypeCanUseMemcpy(
|
|
||||||
static_cast<tensorflow::DataType>(dtype))) {
|
|
||||||
status->status = tensorflow::errors::InvalidArgument(
|
|
||||||
"Trying to create a tensor with a pointer to non-pod memory.");
|
|
||||||
deallocator(data, len, deallocator_arg);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
// TODO(apassos) do we need to wrap the deallocator here to make sure to sync
|
// TODO(apassos) do we need to wrap the deallocator here to make sure to sync
|
||||||
// the device?
|
// the device?
|
||||||
TF_ManagedBuffer* buf =
|
TF_ManagedBuffer* buf =
|
||||||
new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
|
new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
|
||||||
|
/*owns_memory=*/false);
|
||||||
|
|
||||||
tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype),
|
tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype),
|
||||||
tensorflow::TensorShape(dimvec), buf);
|
tensorflow::TensorShape(dimvec), buf);
|
||||||
buf->Unref();
|
buf->Unref();
|
||||||
tensorflow::TensorHandle* ret_handle;
|
tensorflow::TensorHandle* ret_handle;
|
||||||
absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> variant_device;
|
|
||||||
if (custom_device == nullptr) {
|
if (custom_device == nullptr) {
|
||||||
variant_device = device;
|
|
||||||
} else {
|
|
||||||
variant_device = custom_device;
|
|
||||||
}
|
|
||||||
status->status = tensorflow::TensorHandle::CreateLocalHandle(
|
status->status = tensorflow::TensorHandle::CreateLocalHandle(
|
||||||
t, variant_device, context, &ret_handle);
|
t, device, context, &ret_handle);
|
||||||
|
} else {
|
||||||
|
status->status = tensorflow::TensorHandle::CreateLocalHandle(
|
||||||
|
t, custom_device, context, &ret_handle);
|
||||||
|
}
|
||||||
if (!status->status.ok()) {
|
if (!status->status.ok()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
@ -1258,9 +1256,8 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
|
|||||||
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
|
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
std::unique_ptr<TFE_Op> new_op(
|
std::unique_ptr<TFE_Op> new_op(
|
||||||
new TFE_Op{tensorflow::EagerOperation(ctx->context)});
|
new TFE_Op{std::make_unique<tensorflow::OperationInterface>(ctx)});
|
||||||
status->status =
|
status->status = new_op->operation->Reset(op_or_function_name, nullptr);
|
||||||
new_op->operation.Reset(op_or_function_name, nullptr, false, nullptr);
|
|
||||||
if (!status->status.ok()) {
|
if (!status->status.ok()) {
|
||||||
new_op.reset();
|
new_op.reset();
|
||||||
}
|
}
|
||||||
@@ -1270,49 +1267,51 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
 void TFE_DeleteOp(TFE_Op* op) { delete op; }

 void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
-  status->status = op->operation.SetDeviceName(device_name);
+  status->status = op->operation->SetDeviceName(device_name);
 }

 const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
-  tensorflow::Device* device = (op->operation.Device() == nullptr)
-                                   ? op->operation.EagerContext().HostCPU()
-                                   : op->operation.Device();
-  return device->name().c_str();
+  return op->operation->DeviceName().c_str();
 }

 void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
-  op->operation.SetUseXla(enable);
-#ifndef TENSORFLOW_EAGER_USE_XLA
+#ifdef TENSORFLOW_EAGER_USE_XLA
+  tensorflow::Status s = op->operation->SetUseXla(enable);
+  if (!s.ok()) {
+    LOG(ERROR) << "Could not enable XLA compilation for op: " << s;
+  }
+#else
   LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
                   "built with XLA support.";
 #endif  // TENSORFLOW_EAGER_USE_XLA
 }

 void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
-  tensorflow::TensorHandle* h =
-      tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
-          input->handle.get())
-          ->Handle();
-  op->operation.AddInput(h);
-  status->status = op->operation.MaybeInferSingleInputAttrs(h);
+  status->status = op->operation->AddInput(input->handle);
 }

 void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
                         TF_Status* status) {
+  absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>> handles(
+      num_inputs);
   for (int i = 0; i < num_inputs; ++i) {
-    op->operation.AddInput(
-        tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
-            inputs[i]->handle.get())
-            ->Handle());
+    handles[i].reset(inputs[i]->handle->Copy());
   }
-  status->status = op->operation.InferInputListAttrs(num_inputs);
+  status->status = op->operation->AddInputList(handles);
 }

 TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
                               unsigned char* is_list, TF_Status* status) {
   TF_AttrType ret = TF_ATTR_INT;
-  status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(),
-                                              attr_name, &ret, is_list);
+  const tensorflow::AttrTypeMap* attr_types_;
+  bool is_function;
+  status->status = tensorflow::AttrTypeMapForOp(op->operation->Name().c_str(),
+                                                &attr_types_, &is_function);
+  if (!status->status.ok()) {
+    return ret;
+  }
+  status->status =
+      tensorflow::AttrTypeByName(*attr_types_, attr_name, &ret, is_list);
   return ret;
 }

@@ -1333,221 +1332,169 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,

 void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
                          size_t length) {
-  op->operation.MutableAttrs()->Set(
-      attr_name,
-      tensorflow::StringPiece(static_cast<const char*>(value), length));
+  auto s = op->operation->SetAttrString(
+      attr_name, static_cast<const char*>(value), length);
+  if (!s.ok()) {
+    LOG(WARNING) << "Unable to set attribute: " << attr_name;
+  }
 }

 void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
-  op->operation.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
+  auto s = op->operation->SetAttrInt(attr_name, value);
+  if (!s.ok()) {
+    LOG(WARNING) << "Unable to set attribute: " << attr_name;
+  }
 }

 void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
-  op->operation.MutableAttrs()->Set(attr_name, value);
+  auto s = op->operation->SetAttrFloat(attr_name, value);
+  if (!s.ok()) {
+    LOG(WARNING) << "Unable to set attribute: " << attr_name;
+  }
 }

 void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
-  op->operation.MutableAttrs()->Set(attr_name, (value == 0) ? false : true);
+  auto s = op->operation->SetAttrBool(attr_name, (value == 0) ? false : true);
+  if (!s.ok()) {
+    LOG(WARNING) << "Unable to set attribute: " << attr_name;
+  }
 }

 void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
-  op->operation.MutableAttrs()->Set(attr_name,
-                                    static_cast<tensorflow::DataType>(value));
+  auto s = op->operation->SetAttrType(attr_name, value);
+  if (!s.ok()) {
+    LOG(WARNING) << "Unable to set attribute: " << attr_name;
+  }
 }

 void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
                         const int num_dims, TF_Status* out_status) {
-  if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
-    TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
-                 tensorflow::strings::StrCat(
-                     "Value specified for `", attr_name, "` has ", num_dims,
-                     " dimensions which is over the limit of ",
-                     tensorflow::TensorShape::MaxDimensions(), ".")
-                     .c_str());
-    return;
-  }
-  tensorflow::TensorShapeProto proto;
-  if (num_dims < 0) {
-    proto.set_unknown_rank(true);
-  } else {
-    for (int d = 0; d < num_dims; ++d) {
-      proto.add_dim()->set_size(dims[d]);
-    }
-  }
-  op->operation.MutableAttrs()->Set(attr_name, proto);
+  out_status->status = op->operation->SetAttrShape(attr_name, dims, num_dims);
 }

 void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
                            const TFE_Op* value) {
-  tensorflow::AttrValue attr_value;
-  tensorflow::NameAttrList* func = attr_value.mutable_func();
-  func->set_name(value->operation.Name());
-  value->operation.Attrs().FillAttrValueMap(func->mutable_attr());
-  op->operation.MutableAttrs()->Set(attr_name, attr_value);
+  auto s = op->operation->SetAttrFunction(attr_name, value->operation);
+  if (!s.ok()) {
+    LOG(WARNING) << "Unable to set attribute: " << attr_name;
+  }
 }

 void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
                                const char* data, size_t length) {
-  tensorflow::AttrValue attr_value;
-  tensorflow::NameAttrList* func = attr_value.mutable_func();
-  func->set_name(data, length);
-  op->operation.MutableAttrs()->Set(attr_name, attr_value);
+  auto s = op->operation->SetAttrFunctionName(attr_name, data, length);
+  if (!s.ok()) {
+    LOG(WARNING) << "Unable to set attribute: " << attr_name;
+  }
 }

 void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
                          TF_Status* status) {
-  tensorflow::Tensor t;
-  status->status = TF_TensorToTensor(tensor, &t);
-  if (status->status.ok()) op->operation.MutableAttrs()->Set(attr_name, t);
+  status->status = op->operation->SetAttrTensor(attr_name, tensor);
 }

 void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
                              const void* const* values, const size_t* lengths,
                              int num_values) {
-  std::vector<tensorflow::StringPiece> v(num_values);
-  for (int i = 0; i < num_values; ++i) {
-    v[i] = tensorflow::StringPiece(static_cast<const char*>(values[i]),
-                                   lengths[i]);
-  }
-  op->operation.MutableAttrs()->Set(attr_name, v);
+  auto s =
+      op->operation->SetAttrStringList(attr_name, values, lengths, num_values);
+  if (!s.ok()) {
+    LOG(WARNING) << "Unable to set attribute: " << attr_name;
+  }
 }

 void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
                             const float* values, int num_values) {
-  op->operation.MutableAttrs()->Set(
-      attr_name, tensorflow::gtl::ArraySlice<const float>(values, num_values));
+  auto s = op->operation->SetAttrFloatList(attr_name, values, num_values);
+  if (!s.ok()) {
+    LOG(WARNING) << "Unable to set attribute: " << attr_name;
+  }
 }

 void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
                           const int64_t* values, int num_values) {
-  op->operation.MutableAttrs()->Set(
-      attr_name, tensorflow::gtl::ArraySlice<const int64>(
-                     reinterpret_cast<const int64*>(values), num_values));
+  auto s = op->operation->SetAttrIntList(attr_name, values, num_values);
+  if (!s.ok()) {
+    LOG(WARNING) << "Unable to set attribute: " << attr_name;
+  }
 }

 void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
                            const TF_DataType* values, int num_values) {
-  op->operation.MutableAttrs()->Set(
-      attr_name,
-      tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
-          reinterpret_cast<const tensorflow::DataType*>(values), num_values));
+  auto s = op->operation->SetAttrTypeList(attr_name, values, num_values);
+  if (!s.ok()) {
+    LOG(WARNING) << "Unable to set attribute: " << attr_name;
+  }
 }

 void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
                            const unsigned char* values, int num_values) {
-  std::unique_ptr<bool[]> b(new bool[num_values]);
-  for (int i = 0; i < num_values; ++i) {
-    b[i] = values[i];
-  }
-  op->operation.MutableAttrs()->Set(
-      attr_name, tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
+  auto s = op->operation->SetAttrBoolList(attr_name, values, num_values);
+  if (!s.ok()) {
+    LOG(WARNING) << "Unable to set attribute: " << attr_name;
+  }
 }

 void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
                             const int64_t** dims, const int* num_dims,
                             int num_values, TF_Status* out_status) {
-  std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
-      new tensorflow::TensorShapeProto[num_values]);
-  for (int i = 0; i < num_values; ++i) {
-    const auto num_dims_i = num_dims[i];
-
-    if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
-      TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
-                   tensorflow::strings::StrCat(
-                       "Value specified for `", attr_name, "` has ", num_dims_i,
-                       " dimensions which is over the limit of ",
-                       tensorflow::TensorShape::MaxDimensions(), ".")
-                       .c_str());
-      return;
-    }
-    if (num_dims_i < 0) {
-      proto[i].set_unknown_rank(true);
-    } else {
-      const int64_t* dims_i = dims[i];
-      auto proto_i = &proto[i];
-      for (int d = 0; d < num_dims_i; ++d) {
-        proto_i->add_dim()->set_size(dims_i[d]);
-      }
-    }
-  }
-  op->operation.MutableAttrs()->Set(
-      attr_name, tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
-                     proto.get(), num_values));
+  out_status->status =
+      op->operation->SetAttrShapeList(attr_name, dims, num_dims, num_values);
 }

 void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
                                const TFE_Op** value, int num_values) {
-  std::unique_ptr<tensorflow::NameAttrList[]> funcs(
-      new tensorflow::NameAttrList[num_values]);
-  for (int i = 0; i < num_values; i++) {
-    funcs[i].set_name(value[i]->operation.Name());
-    value[i]->operation.Attrs().FillAttrValueMap(funcs[i].mutable_attr());
-  }
-  op->operation.MutableAttrs()->Set(
-      attr_name, tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
-                     funcs.get(), num_values));
+  auto s = op->operation->SetAttrFunctionList(attr_name, value, num_values);
+  if (!s.ok()) {
+    LOG(WARNING) << "Unable to set attribute: " << attr_name;
+  }
+}
+
+void TFE_OpSetAttrValueProto(const TFE_Op* op, const char* attr_name,
+                             const void* proto, size_t proto_len,
+                             TF_Status* status) {
+  tensorflow::AttrValue attr_value;
+  if (!attr_value.ParseFromArray(proto, proto_len)) {
+    status->status =
+        tensorflow::errors::InvalidArgument("Unparseable AttrValue proto");
+    return;
+  }
+  if (op == nullptr || op->operation == nullptr) {
+    status->status = tensorflow::errors::InvalidArgument(
+        "Got a null or uninitialized `op` argument");
+    return;
+  }
+  auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
+      op->operation.get());
+  operation->MutableAttrs()->Set(attr_name, attr_value);
 }

 TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
                                                const char* input_name,
                                                TF_Status* status) {
-  const tensorflow::OpDef* op_def = GetOpDef(op, status);
-  if (!status->status.ok()) {
-    return -1;
-  }
-  tensorflow::AttrValueMap attrs;
-  op->operation.Attrs().FillAttrValueMap(&attrs);
-  tensorflow::NameRangeMap name_ranges;
-  status->status = tensorflow::NameRangesForNode(
-      tensorflow::AttrSlice(&attrs), *op_def, &name_ranges, nullptr);
-  if (!status->status.ok()) {
-    return -1;
-  }
-  auto iter = name_ranges.find(input_name);
-  if (iter == name_ranges.end()) {
-    status->status = tensorflow::errors::InvalidArgument("Input '", input_name,
-                                                         "' not found");
-    return -1;
-  }
-  return iter->second.second - iter->second.first;
+  int ret = -1;
+  status->status = op->operation->InputLength(input_name, &ret);
+  return ret;
 }

 TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
                                                 const char* output_name,
                                                 TF_Status* status) {
-  const tensorflow::OpDef* op_def = GetOpDef(op, status);
-  if (!status->status.ok()) {
-    return -1;
-  }
-  tensorflow::AttrValueMap attrs;
-  op->operation.Attrs().FillAttrValueMap(&attrs);
-  tensorflow::NameRangeMap name_ranges;
-  status->status = tensorflow::NameRangesForNode(
-      tensorflow::AttrSlice(&attrs), *op_def, nullptr, &name_ranges);
-  if (!status->status.ok()) {
-    return -1;
-  }
-  auto iter = name_ranges.find(output_name);
-  if (iter == name_ranges.end()) {
-    status->status = tensorflow::errors::InvalidArgument(
-        "Output '", output_name, "' not found");
-    return -1;
-  }
-  return iter->second.second - iter->second.first;
+  int ret = -1;
+  status->status = op->operation->OutputLength(output_name, &ret);
+  return ret;
 }

 void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
                  TF_Status* status) {
-  absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
-  VLOG(1) << "Calling TFE_Execute() on op " << op;
-  status->status = tensorflow::EagerExecute(&op->operation,
-                                            handle_retvals.data(), num_retvals);
+  absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>> handles(
+      *num_retvals);
+  status->status = op->operation->Execute(&handles, num_retvals);
   if (!status->status.ok()) {
     return;
   }
   for (int i = 0; i < *num_retvals; ++i) {
-    retvals[i] = new TFE_TensorHandle{
-        std::make_unique<tensorflow::TensorHandleInterface>(handle_retvals[i])};
+    retvals[i] = new TFE_TensorHandle{std::move(handles[i])};
   }
 }

@@ -1675,6 +1622,31 @@ void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); }

 void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); }

+void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) {
+  auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
+      op->operation.get());
+  *attrs = TFE_OpAttrs(&operation->Attrs(), op->operation->Name().c_str());
+}
+
+void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
+  tensorflow::AttrValueMap m;
+  attrs->attributes->FillAttrValueMap(&m);
+  auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
+      op->operation.get());
+  tensorflow::AttrBuilder* destination = operation->MutableAttrs();
+  for (auto attribute : m) {
+    destination->Set(attribute.first, attribute.second);
+  }
+}
+
+void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
+                          TF_Status* status) {
+  tensorflow::NameAttrList name_and_attrs;
+  attrs->attributes->FillAttrValueMap(name_and_attrs.mutable_attr());
+  name_and_attrs.set_name(attrs->name);
+  status->status = MessageToBuffer(name_and_attrs, buf);
+}
+
 namespace tensorflow {
 void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
                           const tensorflow::AttrValue& default_value,
@@ -1794,10 +1766,10 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
                                               op->Inputs()[i])});
     }
     std::vector<TFE_TensorHandle*> outputs(*num_retvals);
-    // TODO(allenl): figure out how to get attrs from EagerOperation
     TF_Status status;
+    TFE_OpAttrs attributes(&op->Attrs(), op->Name().c_str());
     device_.execute(inputs.size(), inputs.data(), op->Name().c_str(),
-                    num_retvals, outputs.data(), &status, info_);
+                    &attributes, num_retvals, outputs.data(), &status, info_);
     if (status.status.ok()) {
       for (int i = 0; i < *num_retvals; ++i) {
         retvals[i] = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
@@ -25,34 +25,20 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/casts.h"
 #include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
-#include "tensorflow/core/profiler/rpc/profiler_server.h"

 using tensorflow::string;

 void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
                  const char* raw_device_name, TF_Status* status) {
   if (op_to_reset) {
-    status->status = op_to_reset->operation.Reset(
-        op_or_function_name, raw_device_name, false, nullptr);
+    status->status =
+        op_to_reset->operation->Reset(op_or_function_name, raw_device_name);
   } else {
     TF_SetStatus(status, TF_INVALID_ARGUMENT,
                  "op_to_reset should not be nullptr");
   }
 }

-void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
-  op->operation.ConsumeInput(
-      tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
-          ->Handle());
-}
-
-void TFE_StartProfilerServer(int port) {
-  // Release child thread intentionally. The child thread can be terminated by
-  // terminating the main thread.
-  tensorflow::StartProfilerServer(port).release();
-}
-
 void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
   ctx->context->SetShouldStoreGraphs(true);
 }
@@ -61,46 +47,6 @@ void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
   ctx->context->SetShouldStoreGraphs(false);
 }

-bool TFE_ProfilerClientStartTracing(const char* service_addr,
-                                    const char* logdir, const char* worker_list,
-                                    bool include_dataset_ops, int duration_ms,
-                                    int num_tracing_attempts,
-                                    TF_Status* status) {
-  tensorflow::Status s =
-      tensorflow::profiler::ValidateHostPortPair(service_addr);
-  if (!s.ok()) {
-    Set_TF_Status_from_Status(status, s);
-    return false;
-  }
-  s = tensorflow::profiler::Trace(service_addr, logdir, worker_list,
-                                  include_dataset_ops, duration_ms,
-                                  num_tracing_attempts);
-  tensorflow::Set_TF_Status_from_Status(status, s);
-  return s.ok();
-}
-
-void TFE_ProfilerClientMonitor(const char* service_addr, int duration_ms,
-                               int monitoring_level, bool display_timestamp,
-                               TF_Buffer* result, TF_Status* status) {
-  tensorflow::Status s =
-      tensorflow::profiler::ValidateHostPortPair(service_addr);
-  if (!s.ok()) {
-    Set_TF_Status_from_Status(status, s);
-    return;
-  }
-  string content;
-  s = tensorflow::profiler::Monitor(service_addr, duration_ms, monitoring_level,
-                                    display_timestamp, &content);
-  void* data = tensorflow::port::Malloc(content.length());
-  content.copy(static_cast<char*>(data), content.length(), 0);
-  result->data = data;
-  result->length = content.length();
-  result->data_deallocator = [](void* data, size_t length) {
-    tensorflow::port::Free(data);
-  };
-  tensorflow::Set_TF_Status_from_Status(status, s);
-}
-
 void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
                                           int64_t value) {
   cell->cell.IncrementBy(value);
@@ -568,8 +514,7 @@ void TFE_DeleteCancellationManager(
 void TFE_OpSetCancellationManager(TFE_Op* op,
                                   TFE_CancellationManager* cancellation_manager,
                                   TF_Status* status) {
-  op->operation.SetCancellationManager(
-      &cancellation_manager->cancellation_manager);
+  status->status = op->operation->SetCancellationManager(cancellation_manager);
 }

 TFE_Executor* TFE_NewExecutor(bool is_async) {
@@ -611,3 +556,28 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
     tensorflow::port::Free(data);
   };
 }
+
+void TFE_TensorHandleEnableImplicitMirroring(TFE_TensorHandle* h,
+                                             TF_Status* status) {
+  h->handle->EnableImplicitMirroring();
+  status->status = tensorflow::Status::OK();
+}
+
+void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
+                               TF_Buffer* buf, TF_Status* status) {
+  auto* function_def = ctx->context->FindFunctionDef(function_name);
+  if (function_def == nullptr) {
+    status->status = tensorflow::errors::NotFound(
+        "Unable to find FunctionDef with name: ", function_name);
+    return;
+  }
+  string str = function_def->SerializeAsString();
+  void* data = tensorflow::port::Malloc(str.length());
+  str.copy(static_cast<char*>(data), str.length(), 0);
+  buf->data = data;
+  buf->length = str.length();
+  buf->data_deallocator = [](void* data, size_t length) {
+    tensorflow::port::Free(data);
+  };
+  status->status = tensorflow::Status::OK();
+}
@@ -34,18 +34,6 @@ TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset,
                                        const char* raw_device_name,
                                        TF_Status* status);

-TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
-                                              TF_Status* status);
-
-// Start a profiler grpc server which listens to specified port. It will start
-// the server on its own thread. It can be shutdown by terminating tensorflow.
-// It can be used in both Eager mode and graph mode. Creating multiple profiler
-// server is allowed. The service defined in
-// tensorflow/contrib/tpu/profiler/tpu_profiler.proto. Please use
-// tensorflow/contrib/tpu/profiler/capture_tpu_profile to capture trace file
-// following https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
-TF_CAPI_EXPORT extern void TFE_StartProfilerServer(int port);
-
 // Enables only graph collection in RunMetadata on the functions executed from
 // this context.
 TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx);
@@ -54,29 +42,6 @@ TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx);
 // this context.
 TF_CAPI_EXPORT extern void TFE_ContextDisableGraphCollection(TFE_Context* ctx);

-// Send a grpc request to profiler server (service_addr) to perform on-demand
-// profiling and save the result into logdir which can be visualized by
-// TensorBoard. worker_list is the list of worker TPUs separated by ','. Set
-// include_dataset_opts to false to profile longer traces. It will block the
-// caller thread until receives tracing result.
-// This API is designed for TensorBoard, for end user, please use
-// tensorflow/contrib/tpu/profiler/capture_tpu_profile instead following
-// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
-TF_CAPI_EXPORT extern bool TFE_ProfilerClientStartTracing(
-    const char* service_addr, const char* logdir, const char* worker_list,
-    bool include_dataset_ops, int duration_ms, int num_tracing_attempts,
-    TF_Status* status);
-
-// Send a grpc request to profiler server (service_addr) to perform on-demand
-// monitoring and return the result in a string. It will block the
-// caller thread until receiving the monitoring result.
-// This API is designed for TensorBoard, for end user, please use
-// tensorflow/contrib/tpu/profiler/capture_tpu_profile instead following
-// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
-TF_CAPI_EXPORT extern void TFE_ProfilerClientMonitor(
-    const char* service_addr, int duration_ms, int monitoring_level,
-    bool display_timestamp, TF_Buffer* result, TF_Status* status);
-
 // TODO(fishx): Move these monitoring APIs into a separate file.
 // -----------------------------------------------------------------------------
 // Monitoring Counter APIs.
@@ -417,10 +382,18 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
                                                  const char* worker_name,
                                                  TF_Status* status);

-// Clear pending streaming requests and error statuses on remote executors.
-TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
-                                                           TF_Status* status);
+// Sync pending nodes in local executors (including the context default executor
+// and thread executors) and streaming requests to remote executors, and get the
+// combined status.
+TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
+                                                TF_Status* status);
+
+// If the TensorHandle is copied to another device as part of an op execution,
+// the copy is destroyed after the op has executed. Enabling implicit mirroring
+// causes the copy to be held as a mirror for the lifetime of the TensorHandle.
+TF_CAPI_EXPORT extern void TFE_TensorHandleEnableImplicitMirroring(
+    TFE_TensorHandle*, TF_Status*);

 // This function will block till the operation that produces `h` has
 // completed. This is only valid on local TFE_TensorHandles. The pointer
 // returned will be on the device in which the TFE_TensorHandle resides (so e.g.
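A minimal usage sketch for the two declarations above (not from this commit); `ctx` and `handle` are assumed to be a live TFE_Context* and TFE_TensorHandle*.

  TF_Status* status = TF_NewStatus();
  // Keep cross-device copies of `handle` alive as mirrors instead of dropping them.
  TFE_TensorHandleEnableImplicitMirroring(handle, status);
  if (TF_GetCode(status) == TF_OK) {
    // Drain local executors and pending remote streaming requests; the combined
    // status is reported through `status`.
    TFE_ContextAsyncWait(ctx, status);
  }
  TF_DeleteStatus(status);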
@@ -450,7 +423,42 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
 TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
                                                 TF_Buffer* buf);

-#define TFE_CUSTOM_DEVICE_VERSION 0
+// APIs for generically dealing with op attributes (e.g. when forwarding them
+// through custom device implementations).
+//
+// TODO(allenl): Currently these are black boxes, but we should have some way to
+// inspect values. This would let people e.g. copy over most attributes and then
+// modify some based on their values.
+
+// A reference to an op's name -> attribute mapping
+typedef struct TFE_OpAttrs TFE_OpAttrs;
+
+// Fetch a struct with a reference to information about attributes of `op`.
+//
+// The `attrs` struct does not own any memory, and `op` must outlive it.
+TF_CAPI_EXPORT extern void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs);
+
+// Add attributes in `attrs` to `op`.
+//
+// Does not overwrite or update existing attributes, but adds new ones.
+TF_CAPI_EXPORT extern void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs);
+
+// Serialize `attrs` as a tensorflow::NameAttrList protocol buffer (into `buf`),
+// containing the op name and a map of its attributes.
+TF_CAPI_EXPORT extern void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs,
+                                                TF_Buffer* buf,
+                                                TF_Status* status);
+
+// Set an op's attribute from a serialized AttrValue protocol buffer.
+//
+// Analogous to TF_SetAttrValueProto for building graph operations.
+TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op,
+                                                   const char* attr_name,
+                                                   const void* proto,
+                                                   size_t proto_len,
+                                                   TF_Status* status);
+
+#define TFE_CUSTOM_DEVICE_VERSION 1

 // Struct to be filled in
 typedef struct TFE_CustomDevice {
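A short sketch of the attribute-forwarding API above, modeled on the TestTFE_OpGetAttrs test later in this diff; `ctx` and `status` are assumed to exist, and "VarHandleOp" is simply the op that test uses.

  TFE_Op* src_op = TFE_NewOp(ctx, "VarHandleOp", status);
  TFE_OpSetAttrType(src_op, "dtype", TF_INT64);

  TFE_OpAttrs attributes;            // Non-owning reference; src_op must outlive it.
  TFE_OpGetAttrs(src_op, &attributes);

  TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
  TFE_OpAddAttrs(copy_op, &attributes);  // Adds "dtype" without overwriting existing attrs.

  TF_Buffer* buf = TF_NewBuffer();
  TFE_OpAttrsSerialize(&attributes, buf, status);  // buf now holds a NameAttrList proto.
  TF_DeleteBuffer(buf);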
@@ -467,10 +475,10 @@ typedef struct TFE_CustomDevice {
                        void* device_info);

   // Method to execute an operation.
-  // TODO(allenl) figure out a generic way of passing attrs here
   void (*execute)(int num_inputs, TFE_TensorHandle** inputs,
-                  const char* operation_name, int* num_outputs,
-                  TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
+                  const char* operation_name, const TFE_OpAttrs* attributes,
+                  int* num_outputs, TFE_TensorHandle** outputs, TF_Status* s,
+                  void* device_info);

   // Method to delete a device.
   void (*delete_device)(void* device_info);
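A hypothetical custom-device `execute` callback matching the new signature; the struct, its fields, and the wrapped-device handling are illustrative assumptions, not part of this change.

  typedef struct MyDeviceInfo {
    TFE_Context* ctx;
    const char* underlying_device;  // Physical device this wrapper delegates to.
  } MyDeviceInfo;

  void MyDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
                       const char* operation_name, const TFE_OpAttrs* attributes,
                       int* num_outputs, TFE_TensorHandle** outputs, TF_Status* s,
                       void* device_info) {
    MyDeviceInfo* info = (MyDeviceInfo*)device_info;
    TFE_Op* op = TFE_NewOp(info->ctx, operation_name, s);
    if (TF_GetCode(s) != TF_OK) return;
    TFE_OpAddAttrs(op, attributes);  // Forward the intercepted op's attributes.
    TFE_OpSetDevice(op, info->underlying_device, s);
    for (int i = 0; i < num_inputs; ++i) TFE_OpAddInput(op, inputs[i], s);
    TFE_Execute(op, outputs, num_outputs, s);
    TFE_DeleteOp(op);
  }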
@@ -501,6 +509,11 @@ typedef struct TFE_CustomDevice {
 void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
                               const char* device_name, void* device_info);

+TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
+                                                     const char* function_name,
+                                                     TF_Buffer* buf,
+                                                     TF_Status* status);
+
 #ifdef __cplusplus
 } /* end extern "C" */
 #endif
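An illustrative call sequence for TFE_ContextGetFunctionDef (not from this diff): it assumes a function named "my_fn" was previously registered on `ctx`, and parses the returned buffer with the standard protobuf API on the C++ side.

  TF_Buffer* buf = TF_NewBuffer();
  TFE_ContextGetFunctionDef(ctx, "my_fn", buf, status);
  if (TF_GetCode(status) == TF_OK) {
    tensorflow::FunctionDef fdef;
    fdef.ParseFromArray(buf->data, buf->length);  // buf holds a serialized FunctionDef.
  }
  TF_DeleteBuffer(buf);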
@@ -27,12 +27,12 @@ limitations under the License.
 #include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/c/eager/c_api.h"
 #include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/c/eager/operation_interface.h"
 #include "tensorflow/c/eager/tensor_handle_interface.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
-#include "tensorflow/core/common_runtime/eager/eager_operation.h"
 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
 #include "tensorflow/core/common_runtime/function.h"
@@ -89,7 +89,7 @@ struct TFE_TensorDebugInfo {
 };

 struct TFE_Op {
-  tensorflow::EagerOperation operation;
+  std::unique_ptr<AbstractOperationInterface> operation;
 };

 struct TFE_MonitoringCounterCell {
@@ -236,4 +236,17 @@ struct TFE_Executor {
   tensorflow::EagerExecutor* unowned_executor;
 };

+// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
+// that sometimes do not require serialization.
+struct TFE_OpAttrs {
+  explicit TFE_OpAttrs() : name(nullptr), attributes(nullptr) {}
+
+  explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value,
+                       const char* op_name)
+      : name(op_name), attributes(value) {}
+
+  const char* name;
+  const tensorflow::AttrBuilder* attributes;
+};
+
 #endif  // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/c/eager/c_api_internal.h"
 #include "tensorflow/c/eager/c_api_test_util.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
+#include "tensorflow/core/platform/casts.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/protobuf/cluster.pb.h"
@@ -127,7 +128,7 @@ void TestRemoteExecute(bool async) {
 TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
 TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }

-void TestRemoteExecuteSilentCopies(bool async) {
+void TestRemoteExecuteSilentCopies(bool async, bool remote) {
   tensorflow::ServerDef server_def = GetServerDef(3);

   // This server def has the task index set to 0.
@@ -166,10 +167,14 @@ void TestRemoteExecuteSilentCopies(bool async) {
   auto* h1_task2 =
       TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_TensorHandleEnableImplicitMirroring(h1_task2, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

   // Handles are on task0 (local), and task2, but op is on task1.
   TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2);
+  if (remote) {
   TFE_OpSetDevice(matmul, task1_name, status);
+  }
   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

   TFE_TensorHandle* retvals[1];
@@ -177,6 +182,17 @@ void TestRemoteExecuteSilentCopies(bool async) {
   TFE_Execute(matmul, &retvals[0], &num_retvals, status);
   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

+  // TODO(gjn): Add support for waiting on async local mirrors
+  if (!async) {
+    auto remote_arg = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
+                          h1_task2->handle.get())
+                          ->Handle();
+    auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
+        matmul->operation.get());
+    // The input handles should never change since they have been mirrored.
+    ASSERT_EQ(op->GetInput(1), remote_arg);
+  }
+
   auto* retval_task0 = TFE_TensorHandleCopyToDevice(
       retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@@ -213,9 +229,17 @@ void TestRemoteExecuteSilentCopies(bool async) {
   worker_server2.release();
 }

-TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); }
+TEST(CAPI, RemoteExecuteSilentCopies) {
+  TestRemoteExecuteSilentCopies(false, true);
+}
 TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
-  TestRemoteExecuteSilentCopies(true);
+  TestRemoteExecuteSilentCopies(true, true);
+}
+TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
+  TestRemoteExecuteSilentCopies(false, false);
+}
+TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
+  TestRemoteExecuteSilentCopies(true, false);
 }

 void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
@@ -17,12 +17,15 @@ limitations under the License.

 #include <string.h>

+#include <string>
+
 #include "absl/strings/match.h"
 #include "tensorflow/c/eager/c_api_experimental.h"
 #include "tensorflow/c/eager/c_api_internal.h"
 #include "tensorflow/c/eager/c_api_test_util.h"
 #include "tensorflow/core/framework/function.pb.h"
 #include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/casts.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/protobuf.h"
@@ -365,7 +368,8 @@ TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevicesAsync) {

 void TensorHandleSilentCopy(bool async,
                             TFE_ContextDevicePlacementPolicy global_policy,
-                            TFE_ContextDevicePlacementPolicy thread_policy) {
+                            TFE_ContextDevicePlacementPolicy thread_policy,
+                            bool cpu_op) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TFE_ContextOptions* opts = TFE_NewContextOptions();
@@ -376,26 +380,49 @@ void TensorHandleSilentCopy(bool async,
     TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx, thread_policy);
   }
   TFE_DeleteContextOptions(opts);
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

   TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
   TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

   // Disable the test if no GPU is present.
   string gpu_device_name;
   if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
     TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
         hcpu, ctx, gpu_device_name.c_str(), status.get());
-    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

     TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
+    if (cpu_op) {
+      string cpu_device_name;
+      ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
+      TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status.get());
+    } else {
     TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get());
-    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+    }
+    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

     TFE_TensorHandle* retvals[1];
     int num_retvals = 1;
     TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
-    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
+
+    // Validate if the input was replaced with a different TensorHandle
+    auto arg0 = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
+                    hcpu->handle.get())
+                    ->Handle();
+    auto arg1 = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
+                    hgpu->handle.get())
+                    ->Handle();
+
+    auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
+        matmul->operation.get());
+
+    // The input handles should never change since they have been mirrored.
+    EXPECT_EQ(op->GetInput(0), arg0);
+    EXPECT_EQ(op->GetInput(1), arg1);

     TFE_DeleteOp(matmul);
     TFE_DeleteTensorHandle(retvals[0]);
     TFE_DeleteTensorHandle(hgpu);
@@ -411,19 +438,19 @@ void TensorHandleSilentCopy(bool async,
 }
 TEST(CAPI, TensorHandleSilentCopy) {
   TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
-                         TFE_DEVICE_PLACEMENT_SILENT);
+                         TFE_DEVICE_PLACEMENT_SILENT, false);
 }
 TEST(CAPI, TensorHandleSilentCopyAsync) {
   TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_SILENT,
-                         TFE_DEVICE_PLACEMENT_SILENT);
+                         TFE_DEVICE_PLACEMENT_SILENT, false);
 }
 TEST(CAPI, TensorHandleSilentCopyLocalPolicy) {
   TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_EXPLICIT,
-                         TFE_DEVICE_PLACEMENT_SILENT);
+                         TFE_DEVICE_PLACEMENT_SILENT, false);
 }
 TEST(CAPI, TensorHandleSilentCopyLocalPolicyAsync) {
   TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_EXPLICIT,
-                         TFE_DEVICE_PLACEMENT_SILENT);
+                         TFE_DEVICE_PLACEMENT_SILENT, false);
 }

 void SetAndGetOpDevices(bool async) {
@@ -559,6 +586,91 @@ TEST(CAPI, TensorHandleDevices) {
   TFE_DeleteContext(ctx);
 }

+void ExecuteAdd(bool async, bool forward_input) {
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+
+  TFE_TensorHandle* n = TestMatrixTensorHandle100x100();
+  // If a GPU exists, copy the handle to GPU so that we can exercise
+  // unprotecting a mirror.
+  std::string gpu_device_name;
+  if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
+    TFE_TensorHandle* n_gpu =
+        TFE_TensorHandleCopyToDevice(n, ctx, gpu_device_name.c_str(), status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TFE_TensorHandleEnableImplicitMirroring(n_gpu, status);
+    TFE_DeleteTensorHandle(n);
+    n = n_gpu;
+  }
+
+  TFE_TensorHandle* m = TestMatrixTensorHandle100x100();
+
+  // Store pointer to raw buffer for validation of forwarding behaviour.
+  TF_Tensor* orig = TFE_TensorHandleResolve(n, status);
+  void* orig_ptr = TF_TensorData(orig);
+  TF_DeleteTensor(orig);
+
+  TFE_Op* add_op = AddOp(ctx, n, m);
+  std::string cpu_device_name;
+  ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
+  TFE_OpSetDevice(add_op, cpu_device_name.c_str(), status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  if (forward_input) {
+    TFE_DeleteTensorHandle(n);
+  }
+
+  int num_retvals = 1;
+
+  if (async) {
+    // Enqueue dummy ops so we backlog async execution & actually test async.
+    for (int i = 0; i < 10000; ++i) {
+      TFE_TensorHandle* dummy = nullptr;
+      TFE_Execute(add_op, &dummy, &num_retvals, status);
+      ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+      TFE_DeleteTensorHandle(dummy);
+    }
+  }
+
+  TFE_TensorHandle* retval = nullptr;
+  TFE_Execute(add_op, &retval, &num_retvals, status);
+  EXPECT_EQ(1, num_retvals);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  if (!forward_input) {
+    TFE_DeleteTensorHandle(n);
+  }
+  TFE_DeleteOp(add_op);
+
+  TF_Tensor* t = TFE_TensorHandleResolve(retval, status);
+  if (forward_input || async) {
+    EXPECT_EQ(orig_ptr, TF_TensorData(t));
+  } else {
+    EXPECT_NE(orig_ptr, TF_TensorData(t));
+  }
+
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteTensorHandle(m);
+  TFE_DeleteTensorHandle(retval);
+  TFE_DeleteContext(ctx);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  float result[100 * 100] = {0};
+  EXPECT_EQ(sizeof(result), TF_TensorByteSize(t));
+  memcpy(&result[0], TF_TensorData(t), TF_TensorByteSize(t));
+  TF_DeleteTensor(t);
+  for (int i = 0; i < 100 * 100; ++i) {
+    EXPECT_EQ(2.0f, result[i]);
+  }
+  TF_DeleteStatus(status);
+}
+TEST(CAPI, ExecuteAdd) { ExecuteAdd(false, false); }
+TEST(CAPI, ExecuteAddAsync) { ExecuteAdd(true, false); }
+TEST(CAPI, ExecuteAddForward) { ExecuteAdd(false, true); }
+TEST(CAPI, ExecuteAddForwardAsync) { ExecuteAdd(true, true); }
+
 void Execute_MatMul_CPU(bool async) {
   TF_Status* status = TF_NewStatus();
   TFE_ContextOptions* opts = TFE_NewContextOptions();
@@ -1197,6 +1309,14 @@ TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
   TFE_DeleteTensorHandle(h_shares_tensor);
 }

+tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) {
+  tensorflow::AttrValueMap attr_values;
+  tensorflow::down_cast<tensorflow::OperationInterface*>(op->operation.get())
+      ->Attrs()
+      .FillAttrValueMap(&attr_values);
+  return attr_values;
+}
+
 TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
   TF_Status* status = TF_NewStatus();
   TFE_ContextOptions* opts = TFE_NewContextOptions();
@@ -1213,8 +1333,7 @@ TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
   TFE_OpAddInput(minOp, axis, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

-  tensorflow::AttrValueMap attr_values;
-  minOp->operation.Attrs().FillAttrValueMap(&attr_values);
+  tensorflow::AttrValueMap attr_values = ExtractAttrs(minOp);
   tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
   EXPECT_NE(attr_found, attr_values.cend());
   EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
@@ -1253,8 +1372,7 @@ TEST(CAPI, TestTFE_OpInferSingleTypeInputListAttrs) {
   TFE_OpAddInputList(concatOp, inputs, 2, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

-  tensorflow::AttrValueMap attr_values;
-  concatOp->operation.Attrs().FillAttrValueMap(&attr_values);
+  tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp);
   tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
   EXPECT_NE(attr_found, attr_values.cend());
   EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
@@ -1294,8 +1412,7 @@ TEST(CAPI, TestTFE_OpInferMixedTypeInputListAttrs) {
   TFE_OpAddInputList(assertOp, data, 3, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

-  tensorflow::AttrValueMap attr_values;
-  assertOp->operation.Attrs().FillAttrValueMap(&attr_values);
+  tensorflow::AttrValueMap attr_values = ExtractAttrs(assertOp);
   tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
   EXPECT_NE(attr_found, attr_values.cend());
   EXPECT_EQ(attr_found->second.list().type(0), tensorflow::DataType::DT_BOOL);
@ -1331,16 +1448,15 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
|
|||||||
TFE_TensorHandle* inputs[] = {input1, input2};
|
TFE_TensorHandle* inputs[] = {input1, input2};
|
||||||
TFE_OpAddInput(concatOp, dim, status);
|
TFE_OpAddInput(concatOp, dim, status);
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
CHECK(concatOp->operation.OpDef());
|
CHECK(concatOp->operation->OpDef());
|
||||||
TFE_OpAddInput(concatOp, inputs[0], status);
|
TFE_OpAddInput(concatOp, inputs[0], status);
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
EXPECT_FALSE(concatOp->operation.OpDef())
|
EXPECT_FALSE(concatOp->operation->OpDef())
|
||||||
<< "Inference context is still present";
|
<< "Inference context is still present";
|
||||||
TFE_OpAddInput(concatOp, inputs[1], status);
|
TFE_OpAddInput(concatOp, inputs[1], status);
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
tensorflow::AttrValueMap attr_values;
|
tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp);
|
||||||
concatOp->operation.Attrs().FillAttrValueMap(&attr_values);
|
|
||||||
EXPECT_EQ(attr_values.find("T"), attr_values.end());
|
EXPECT_EQ(attr_values.find("T"), attr_values.end());
|
||||||
EXPECT_EQ(attr_values.find("N"), attr_values.end());
|
EXPECT_EQ(attr_values.find("N"), attr_values.end());
|
||||||
|
|
||||||
@ -1427,4 +1543,88 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
|
|||||||
TFE_DeleteContext(ctx);
|
TFE_DeleteContext(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(CAPI, TestTFE_OpGetAttrs) {
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_DeleteContextOptions(opts);
|
||||||
|
|
||||||
|
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||||
|
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
|
||||||
|
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
|
||||||
|
TFE_OpAttrs attributes;
|
||||||
|
TFE_OpGetAttrs(var_op, &attributes);
|
||||||
|
|
||||||
|
TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||||
|
TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT);
|
||||||
|
TFE_OpAddAttrs(copy_op, &attributes);
|
||||||
|
unsigned char is_list = 0;
|
||||||
|
ASSERT_EQ(TF_ATTR_TYPE,
|
||||||
|
TFE_OpGetAttrType(copy_op, "dtype", &is_list, status));
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
ASSERT_EQ(TF_ATTR_SHAPE,
|
||||||
|
TFE_OpGetAttrType(copy_op, "shape", &is_list, status));
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
|
tensorflow::AttrValueMap attr_values;
|
||||||
|
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||||
|
copy_op->operation.get());
|
||||||
|
op->Attrs().FillAttrValueMap(&attr_values);
|
||||||
|
EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type());
|
||||||
|
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
TFE_DeleteOp(var_op);
|
||||||
|
TFE_DeleteOp(copy_op);
|
||||||
|
TFE_DeleteContext(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CAPI, TestTFE_OpAttrsSerialize) {
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_DeleteContextOptions(opts);
|
||||||
|
|
||||||
|
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
|
||||||
|
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
|
||||||
|
TFE_OpAttrs attributes;
|
||||||
|
TFE_OpGetAttrs(var_op, &attributes);
|
||||||
|
|
||||||
|
TF_Buffer* serialized_attr_values = TF_NewBuffer();
|
||||||
|
TFE_OpAttrsSerialize(&attributes, serialized_attr_values, status);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
tensorflow::NameAttrList name_and_attrs;
|
||||||
|
ASSERT_TRUE(name_and_attrs.ParseFromArray(serialized_attr_values->data,
|
||||||
|
serialized_attr_values->length));
|
||||||
|
ASSERT_EQ("VarHandleOp", name_and_attrs.name());
|
||||||
|
ASSERT_EQ(tensorflow::DT_INT64,
|
||||||
|
name_and_attrs.attr().find("dtype")->second.type());
|
||||||
|
TF_DeleteBuffer(serialized_attr_values);
|
||||||
|
|
||||||
|
TFE_Op* second_var_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||||
|
|
||||||
|
string serialized_dtype;
|
||||||
|
ASSERT_TRUE(name_and_attrs.attr().find("dtype")->second.SerializeToString(
|
||||||
|
&serialized_dtype));
|
||||||
|
TFE_OpSetAttrValueProto(
|
||||||
|
second_var_op, "dtype",
|
||||||
|
reinterpret_cast<const void*>(serialized_dtype.c_str()),
|
||||||
|
serialized_dtype.length(), status);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
|
tensorflow::AttrValueMap attr_values;
|
||||||
|
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||||
|
second_var_op->operation.get());
|
||||||
|
op->Attrs().FillAttrValueMap(&attr_values);
|
||||||
|
EXPECT_EQ(tensorflow::DT_INT64, attr_values.find("dtype")->second.type());
|
||||||
|
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
TFE_DeleteOp(var_op);
|
||||||
|
TFE_DeleteOp(second_var_op);
|
||||||
|
TFE_DeleteContext(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -131,6 +131,21 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2() {
|
|||||||
return th;
|
return th;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
|
||||||
|
TFE_Op* op = TFE_NewOp(ctx, "AddV2", status);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_OpAddInput(op, a, status);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_OpAddInput(op, b, status);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
|
||||||
|
|
||||||
|
return op;
|
||||||
|
}
|
||||||
|
|
||||||
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
|
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
|
||||||
TF_Status* status = TF_NewStatus();
|
TF_Status* status = TF_NewStatus();
|
||||||
|
|
||||||
|
@ -42,6 +42,9 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2();
|
|||||||
// Return a tensor handle containing a 3x2 matrix of floats
|
// Return a tensor handle containing a 3x2 matrix of floats
|
||||||
TFE_TensorHandle* TestMatrixTensorHandle3X2();
|
TFE_TensorHandle* TestMatrixTensorHandle3X2();
|
||||||
|
|
||||||
|
// Return an add op multiplying `a` by `b`.
|
||||||
|
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
|
||||||
|
|
||||||
// Return a matmul op multiplying `a` by `b`.
|
// Return a matmul op multiplying `a` by `b`.
|
||||||
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
|
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
|
||||||
|
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -31,6 +32,8 @@ struct LoggingDevice {
|
|||||||
tensorflow::string underlying_device;
|
tensorflow::string underlying_device;
|
||||||
// Set to true whenever a TensorHandle is copied onto the device
|
// Set to true whenever a TensorHandle is copied onto the device
|
||||||
bool* arrived_flag;
|
bool* arrived_flag;
|
||||||
|
// Set to true whenever an operation is executed
|
||||||
|
bool* executed_flag;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct LoggedTensor {
|
struct LoggedTensor {
|
||||||
@ -81,12 +84,14 @@ TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_TensorHandle* tensor,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
|
void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
|
||||||
const char* operation_name, int* num_outputs,
|
const char* operation_name,
|
||||||
|
const TFE_OpAttrs* attributes, int* num_outputs,
|
||||||
TFE_TensorHandle** outputs, TF_Status* s,
|
TFE_TensorHandle** outputs, TF_Status* s,
|
||||||
void* device_info) {
|
void* device_info) {
|
||||||
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
|
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
|
||||||
TFE_Op* op(TFE_NewOp(dev->ctx, operation_name, s));
|
TFE_Op* op(TFE_NewOp(dev->ctx, operation_name, s));
|
||||||
if (TF_GetCode(s) != TF_OK) return;
|
if (TF_GetCode(s) != TF_OK) return;
|
||||||
|
TFE_OpAddAttrs(op, attributes);
|
||||||
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
|
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
|
||||||
for (int j = 0; j < num_inputs; ++j) {
|
for (int j = 0; j < num_inputs; ++j) {
|
||||||
TFE_TensorHandle* input = inputs[j];
|
TFE_TensorHandle* input = inputs[j];
|
||||||
@ -115,6 +120,7 @@ void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
|
|||||||
outputs[i] = MakeLoggedTensorHandle(dev->ctx, dev->device_name,
|
outputs[i] = MakeLoggedTensorHandle(dev->ctx, dev->device_name,
|
||||||
std::move(logged_tensor), s);
|
std::move(logged_tensor), s);
|
||||||
}
|
}
|
||||||
|
*(dev->executed_flag) = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void DeleteLoggingDevice(void* device_info) {
|
void DeleteLoggingDevice(void* device_info) {
|
||||||
@ -122,7 +128,7 @@ void DeleteLoggingDevice(void* device_info) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void RegisterLoggingDevice(TFE_Context* context, const char* name,
|
void RegisterLoggingDevice(TFE_Context* context, const char* name,
|
||||||
bool* arrived_flag) {
|
bool* arrived_flag, bool* executed_flag) {
|
||||||
TFE_CustomDevice custom_device;
|
TFE_CustomDevice custom_device;
|
||||||
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
|
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
|
||||||
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
|
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
|
||||||
@ -131,6 +137,7 @@ void RegisterLoggingDevice(TFE_Context* context, const char* name,
|
|||||||
LoggingDevice* device = new LoggingDevice;
|
LoggingDevice* device = new LoggingDevice;
|
||||||
device->ctx = context;
|
device->ctx = context;
|
||||||
device->arrived_flag = arrived_flag;
|
device->arrived_flag = arrived_flag;
|
||||||
|
device->executed_flag = executed_flag;
|
||||||
device->device_name = name;
|
device->device_name = name;
|
||||||
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
|
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||||
TFE_RegisterCustomDevice(context, custom_device, name, device);
|
TFE_RegisterCustomDevice(context, custom_device, name, device);
|
||||||
@ -144,13 +151,15 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
|
|||||||
TFE_DeleteContextOptions(opts);
|
TFE_DeleteContextOptions(opts);
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
bool arrived = false;
|
bool arrived = false;
|
||||||
|
bool executed = false;
|
||||||
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||||
RegisterLoggingDevice(context, name, &arrived);
|
RegisterLoggingDevice(context, name, &arrived, &executed);
|
||||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
|
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
|
||||||
ASSERT_FALSE(arrived);
|
ASSERT_FALSE(arrived);
|
||||||
TFE_TensorHandle* hdevice =
|
TFE_TensorHandle* hdevice =
|
||||||
TFE_TensorHandleCopyToDevice(hcpu, context, name, status.get());
|
TFE_TensorHandleCopyToDevice(hcpu, context, name, status.get());
|
||||||
ASSERT_TRUE(arrived);
|
ASSERT_TRUE(arrived);
|
||||||
|
ASSERT_FALSE(executed);
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> matmul(
|
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> matmul(
|
||||||
MatMulOp(context, hcpu, hdevice), TFE_DeleteOp);
|
MatMulOp(context, hcpu, hdevice), TFE_DeleteOp);
|
||||||
@ -160,6 +169,7 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
|
|||||||
int num_retvals = 1;
|
int num_retvals = 1;
|
||||||
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
|
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
ASSERT_TRUE(executed);
|
||||||
|
|
||||||
TFE_DeleteTensorHandle(retval);
|
TFE_DeleteTensorHandle(retval);
|
||||||
TFE_DeleteTensorHandle(hcpu);
|
TFE_DeleteTensorHandle(hcpu);
|
||||||
@ -167,4 +177,118 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
|
|||||||
TFE_DeleteContext(context);
|
TFE_DeleteContext(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(CUSTOM_DEVICE, ResetOperation) {
|
||||||
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||||
|
TFE_NewContext(opts, status.get()), TFE_DeleteContext);
|
||||||
|
TFE_DeleteContextOptions(opts);
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
bool arrived = false;
|
||||||
|
bool executed = false;
|
||||||
|
const char* custom_device_name =
|
||||||
|
"/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||||
|
RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed);
|
||||||
|
|
||||||
|
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> reused_op(
|
||||||
|
TFE_NewOp(context.get(), "Identity", status.get()), TFE_DeleteOp);
|
||||||
|
TFE_OpReset(reused_op.get(), "Identity", custom_device_name, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
ASSERT_EQ(tensorflow::string(TFE_OpGetDevice(reused_op.get(), status.get())),
|
||||||
|
tensorflow::string(custom_device_name));
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
TFE_OpReset(reused_op.get(), "Identity",
|
||||||
|
"/job:localhost/replica:0/task:0/device:CPU:0", status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
ASSERT_EQ(tensorflow::string(TFE_OpGetDevice(reused_op.get(), status.get())),
|
||||||
|
tensorflow::string("/job:localhost/replica:0/task:0/device:CPU:0"));
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CUSTOM_DEVICE, MakeVariable) {
|
||||||
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
|
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||||
|
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||||
|
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||||
|
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
bool arrived = false;
|
||||||
|
bool executed = false;
|
||||||
|
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||||
|
RegisterLoggingDevice(context.get(), name, &arrived, &executed);
|
||||||
|
|
||||||
|
// Create a variable handle placed on the custom device.
|
||||||
|
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||||
|
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||||
|
TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
|
||||||
|
TFE_OpSetAttrString(op.get(), "container", "", 0);
|
||||||
|
TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
|
||||||
|
TFE_OpSetDevice(op.get(), name, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
TFE_TensorHandle* var_handle = nullptr;
|
||||||
|
int num_retvals = 1;
|
||||||
|
executed = false;
|
||||||
|
TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
ASSERT_TRUE(executed);
|
||||||
|
auto handle_cleaner = tensorflow::gtl::MakeCleanup(
|
||||||
|
[var_handle]() { TFE_DeleteTensorHandle(var_handle); });
|
||||||
|
|
||||||
|
// Assign to the variable, copying to the custom device.
|
||||||
|
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
|
||||||
|
TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
|
||||||
|
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
|
||||||
|
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||||
|
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||||
|
TFE_OpAddInput(op.get(), one.get(), status.get());
|
||||||
|
TFE_OpSetDevice(op.get(), name, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
executed = false;
|
||||||
|
num_retvals = 0;
|
||||||
|
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
ASSERT_TRUE(executed);
|
||||||
|
|
||||||
|
// Read the variable's value.
|
||||||
|
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
|
||||||
|
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||||
|
TFE_OpSetDevice(op.get(), name, status.get());
|
||||||
|
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
executed = false;
|
||||||
|
num_retvals = 1;
|
||||||
|
TFE_TensorHandle* var_value = nullptr;
|
||||||
|
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
ASSERT_TRUE(executed);
|
||||||
|
auto value_cleaner = tensorflow::gtl::MakeCleanup(
|
||||||
|
[var_value]() { TFE_DeleteTensorHandle(var_value); });
|
||||||
|
ASSERT_EQ(tensorflow::string(name),
|
||||||
|
tensorflow::string(
|
||||||
|
TFE_TensorHandleBackingDeviceName(var_value, status.get())));
|
||||||
|
TFE_TensorHandle* var_value_unpacked =
|
||||||
|
reinterpret_cast<LoggedTensor*>(
|
||||||
|
TFE_TensorHandleDevicePointer(var_value, status.get()))
|
||||||
|
->tensor;
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value(
|
||||||
|
TFE_TensorHandleResolve(var_value_unpacked, status.get()),
|
||||||
|
TF_DeleteTensor);
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
ASSERT_EQ(111., *static_cast<float*>(TF_TensorData(resolved_value.get())));
|
||||||
|
|
||||||
|
// Free the backing buffer for the variable.
|
||||||
|
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
|
||||||
|
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||||
|
TFE_OpSetDevice(op.get(), name, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
num_retvals = 0;
|
||||||
|
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
312
tensorflow/c/eager/operation_interface.cc
Normal file
312
tensorflow/c/eager/operation_interface.cc
Normal file
@ -0,0 +1,312 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/eager/operation_interface.h"
|
||||||
|
|
||||||
|
#include "absl/container/fixed_array.h"
|
||||||
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_internal.h"
|
||||||
|
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||||
|
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||||
|
#include "tensorflow/core/common_runtime/eager/execute.h"
|
||||||
|
#include "tensorflow/core/platform/casts.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
OperationInterface::OperationInterface(TFE_Context* ctx)
|
||||||
|
: operation_(ctx->context) {}
|
||||||
|
|
||||||
|
const string& OperationInterface::DeviceName() const {
|
||||||
|
absl::variant<Device*, CustomDevice*> variant_device =
|
||||||
|
(operation_.Device() == kVariantDeviceNull)
|
||||||
|
? operation_.EagerContext().HostCPU()
|
||||||
|
: operation_.Device();
|
||||||
|
return absl::visit([](auto* d) -> const string& { return d->name(); },
|
||||||
|
variant_device);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetDeviceName(const char* name) {
|
||||||
|
return operation_.SetDeviceName(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrString(const char* attr_name,
|
||||||
|
const char* data, size_t length) {
|
||||||
|
operation_.MutableAttrs()->Set(attr_name, StringPiece(data, length));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrInt(const char* attr_name, int64_t value) {
|
||||||
|
operation_.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrFloat(const char* attr_name, float value) {
|
||||||
|
operation_.MutableAttrs()->Set(attr_name, value);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrBool(const char* attr_name, bool value) {
|
||||||
|
operation_.MutableAttrs()->Set(attr_name, value);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrType(const char* attr_name,
|
||||||
|
TF_DataType value) {
|
||||||
|
operation_.MutableAttrs()->Set(attr_name, static_cast<DataType>(value));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrShape(const char* attr_name,
|
||||||
|
const int64_t* dims,
|
||||||
|
const int num_dims) {
|
||||||
|
if (num_dims > TensorShape::MaxDimensions()) {
|
||||||
|
return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
|
||||||
|
num_dims,
|
||||||
|
" dimensions which is over the limit of ",
|
||||||
|
TensorShape::MaxDimensions(), ".");
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorShapeProto proto;
|
||||||
|
if (num_dims < 0) {
|
||||||
|
proto.set_unknown_rank(true);
|
||||||
|
} else {
|
||||||
|
for (int d = 0; d < num_dims; ++d) {
|
||||||
|
proto.add_dim()->set_size(dims[d]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
operation_.MutableAttrs()->Set(attr_name, proto);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrFunction(
|
||||||
|
const char* attr_name,
|
||||||
|
const std::unique_ptr<AbstractOperationInterface>& value) {
|
||||||
|
AttrValue attr_value;
|
||||||
|
NameAttrList* func = attr_value.mutable_func();
|
||||||
|
func->set_name(value->Name());
|
||||||
|
OperationInterface* value_operation =
|
||||||
|
tensorflow::down_cast<OperationInterface*>(value.get());
|
||||||
|
value_operation->operation_.Attrs().FillAttrValueMap(func->mutable_attr());
|
||||||
|
operation_.MutableAttrs()->Set(attr_name, attr_value);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrFunctionName(const char* attr_name,
|
||||||
|
const char* data,
|
||||||
|
size_t length) {
|
||||||
|
AttrValue attr_value;
|
||||||
|
NameAttrList* func = attr_value.mutable_func();
|
||||||
|
func->set_name(data, length);
|
||||||
|
operation_.MutableAttrs()->Set(attr_name, attr_value);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrTensor(const char* attr_name,
|
||||||
|
TF_Tensor* tensor) {
|
||||||
|
Tensor t;
|
||||||
|
TF_RETURN_IF_ERROR(TF_TensorToTensor(tensor, &t));
|
||||||
|
operation_.MutableAttrs()->Set(attr_name, t);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrStringList(const char* attr_name,
|
||||||
|
const void* const* values,
|
||||||
|
const size_t* lengths,
|
||||||
|
int num_values) {
|
||||||
|
std::vector<StringPiece> v(num_values);
|
||||||
|
for (int i = 0; i < num_values; ++i) {
|
||||||
|
v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
|
||||||
|
}
|
||||||
|
operation_.MutableAttrs()->Set(attr_name, v);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrFloatList(const char* attr_name,
|
||||||
|
const float* values,
|
||||||
|
int num_values) {
|
||||||
|
operation_.MutableAttrs()->Set(
|
||||||
|
attr_name, gtl::ArraySlice<const float>(values, num_values));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrIntList(const char* attr_name,
|
||||||
|
const int64_t* values,
|
||||||
|
int num_values) {
|
||||||
|
operation_.MutableAttrs()->Set(
|
||||||
|
attr_name, gtl::ArraySlice<const int64>(
|
||||||
|
reinterpret_cast<const int64*>(values), num_values));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrTypeList(const char* attr_name,
|
||||||
|
const TF_DataType* values,
|
||||||
|
int num_values) {
|
||||||
|
operation_.MutableAttrs()->Set(
|
||||||
|
attr_name, gtl::ArraySlice<const DataType>(
|
||||||
|
reinterpret_cast<const DataType*>(values), num_values));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrBoolList(const char* attr_name,
|
||||||
|
const unsigned char* values,
|
||||||
|
int num_values) {
|
||||||
|
std::unique_ptr<bool[]> b(new bool[num_values]);
|
||||||
|
for (int i = 0; i < num_values; ++i) {
|
||||||
|
b[i] = values[i];
|
||||||
|
}
|
||||||
|
operation_.MutableAttrs()->Set(
|
||||||
|
attr_name, gtl::ArraySlice<const bool>(b.get(), num_values));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrShapeList(const char* attr_name,
|
||||||
|
const int64_t** dims,
|
||||||
|
const int* num_dims,
|
||||||
|
int num_values) {
|
||||||
|
std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
|
||||||
|
for (int i = 0; i < num_values; ++i) {
|
||||||
|
const auto num_dims_i = num_dims[i];
|
||||||
|
|
||||||
|
if (num_dims_i > TensorShape::MaxDimensions()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
strings::StrCat("Value specified for `", attr_name, "` has ",
|
||||||
|
num_dims_i, " dimensions which is over the limit of ",
|
||||||
|
TensorShape::MaxDimensions(), "."));
|
||||||
|
}
|
||||||
|
if (num_dims_i < 0) {
|
||||||
|
proto[i].set_unknown_rank(true);
|
||||||
|
} else {
|
||||||
|
const int64_t* dims_i = dims[i];
|
||||||
|
auto proto_i = &proto[i];
|
||||||
|
for (int d = 0; d < num_dims_i; ++d) {
|
||||||
|
proto_i->add_dim()->set_size(dims_i[d]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
operation_.MutableAttrs()->Set(
|
||||||
|
attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrFunctionList(const char* attr_name,
|
||||||
|
const TFE_Op** value,
|
||||||
|
int num_values) {
|
||||||
|
std::unique_ptr<NameAttrList[]> funcs(new NameAttrList[num_values]);
|
||||||
|
for (int i = 0; i < num_values; i++) {
|
||||||
|
auto value_operation =
|
||||||
|
tensorflow::down_cast<OperationInterface*>(value[i]->operation.get());
|
||||||
|
funcs[i].set_name(value_operation->operation_.Name());
|
||||||
|
value_operation->operation_.Attrs().FillAttrValueMap(
|
||||||
|
funcs[i].mutable_attr());
|
||||||
|
}
|
||||||
|
operation_.MutableAttrs()->Set(
|
||||||
|
attr_name, gtl::ArraySlice<const NameAttrList>(funcs.get(), num_values));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
const OpDef* OperationInterface::GetOpDef(Status* status) {
|
||||||
|
const tensorflow::OpDef* op_def = operation_.OpDef();
|
||||||
|
if (op_def) return op_def;
|
||||||
|
*status = OpDefForOp(Name(), &op_def);
|
||||||
|
return op_def;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::InputLength(const char* input_name, int* length) {
|
||||||
|
Status status;
|
||||||
|
const tensorflow::OpDef* op_def = GetOpDef(&status);
|
||||||
|
if (!status.ok()) {
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
AttrValueMap attrs;
|
||||||
|
operation_.Attrs().FillAttrValueMap(&attrs);
|
||||||
|
NameRangeMap name_ranges;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
NameRangesForNode(AttrSlice(&attrs), *op_def, &name_ranges, nullptr));
|
||||||
|
auto iter = name_ranges.find(input_name);
|
||||||
|
if (iter == name_ranges.end()) {
|
||||||
|
return errors::InvalidArgument("Input '", input_name, "' not found");
|
||||||
|
}
|
||||||
|
*length = iter->second.second - iter->second.first;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::OutputLength(const char* output_name, int* length) {
|
||||||
|
Status status;
|
||||||
|
const tensorflow::OpDef* op_def = GetOpDef(&status);
|
||||||
|
if (!status.ok()) {
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
AttrValueMap attrs;
|
||||||
|
operation_.Attrs().FillAttrValueMap(&attrs);
|
||||||
|
NameRangeMap name_ranges;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
NameRangesForNode(AttrSlice(&attrs), *op_def, nullptr, &name_ranges));
|
||||||
|
auto iter = name_ranges.find(output_name);
|
||||||
|
if (iter == name_ranges.end()) {
|
||||||
|
return errors::InvalidArgument("Output '", output_name, "' not found");
|
||||||
|
}
|
||||||
|
*length = iter->second.second - iter->second.first;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::AddInput(
|
||||||
|
const std::unique_ptr<AbstractTensorHandleInterface>& input) {
|
||||||
|
TensorHandle* h =
|
||||||
|
tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
|
||||||
|
operation_.AddInput(h);
|
||||||
|
return operation_.MaybeInferSingleInputAttrs(h);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::AddInputList(
|
||||||
|
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
|
||||||
|
inputs) {
|
||||||
|
for (auto& input : inputs) {
|
||||||
|
TensorHandle* h =
|
||||||
|
tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
|
||||||
|
operation_.AddInput(h);
|
||||||
|
}
|
||||||
|
return operation_.InferInputListAttrs(inputs.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::Execute(
|
||||||
|
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
|
||||||
|
int* num_retvals) {
|
||||||
|
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
EagerExecute(&operation_, handle_retvals.data(), num_retvals));
|
||||||
|
for (int i = 0; i < *num_retvals; ++i) {
|
||||||
|
retvals->at(i).reset(
|
||||||
|
new tensorflow::TensorHandleInterface(handle_retvals[i]));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetCancellationManager(
|
||||||
|
TFE_CancellationManager* cancellation_manager) {
|
||||||
|
operation_.SetCancellationManager(
|
||||||
|
&cancellation_manager->cancellation_manager);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetUseXla(bool enable) {
|
||||||
|
operation_.SetUseXla(enable);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
188
tensorflow/c/eager/operation_interface.h
Normal file
188
tensorflow/c/eager/operation_interface.h
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
|
||||||
|
#define TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "absl/container/fixed_array.h"
|
||||||
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
|
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||||
|
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||||
|
|
||||||
|
// Abstract interface to an operation.
|
||||||
|
class AbstractOperationInterface {
|
||||||
|
public:
|
||||||
|
virtual ~AbstractOperationInterface() {}
|
||||||
|
|
||||||
|
virtual void Clear() = 0;
|
||||||
|
virtual tensorflow::Status Reset(const char* op,
|
||||||
|
const char* raw_device_name) = 0;
|
||||||
|
|
||||||
|
virtual const tensorflow::string& Name() const = 0;
|
||||||
|
virtual const tensorflow::string& DeviceName() const = 0;
|
||||||
|
virtual tensorflow::Status SetDeviceName(const char* name) = 0;
|
||||||
|
|
||||||
|
virtual tensorflow::Status AddInput(
|
||||||
|
const std::unique_ptr<AbstractTensorHandleInterface>& input) = 0;
|
||||||
|
virtual tensorflow::Status AddInputList(
|
||||||
|
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
|
||||||
|
inputs) = 0;
|
||||||
|
virtual tensorflow::Status Execute(
|
||||||
|
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
|
||||||
|
int* num_retvals) = 0;
|
||||||
|
virtual const tensorflow::OpDef* OpDef() const = 0;
|
||||||
|
|
||||||
|
virtual tensorflow::Status SetAttrString(const char* attr_name,
|
||||||
|
const char* data, size_t length) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrInt(const char* attr_name,
|
||||||
|
int64_t value) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrFloat(const char* attr_name,
|
||||||
|
float value) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrBool(const char* attr_name, bool value) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrType(const char* attr_name,
|
||||||
|
TF_DataType value) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrShape(const char* attr_name,
|
||||||
|
const int64_t* dims,
|
||||||
|
const int num_dims) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrFunction(
|
||||||
|
const char* attr_name,
|
||||||
|
const std::unique_ptr<AbstractOperationInterface>& value) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrFunctionName(const char* attr_name,
|
||||||
|
const char* value,
|
||||||
|
size_t length) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrTensor(const char* attr_name,
|
||||||
|
TF_Tensor* tensor) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrStringList(const char* attr_name,
|
||||||
|
const void* const* values,
|
||||||
|
const size_t* lengths,
|
||||||
|
int num_values) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrFloatList(const char* attr_name,
|
||||||
|
const float* values,
|
||||||
|
int num_values) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrIntList(const char* attr_name,
|
||||||
|
const int64_t* values,
|
||||||
|
int num_values) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrTypeList(const char* attr_name,
|
||||||
|
const TF_DataType* values,
|
||||||
|
int num_values) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrBoolList(const char* attr_name,
|
||||||
|
const unsigned char* values,
|
||||||
|
int num_values) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrShapeList(const char* attr_name,
|
||||||
|
const int64_t** dims,
|
||||||
|
const int* num_dims,
|
||||||
|
int num_values) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrFunctionList(const char* attr_name,
|
||||||
|
const TFE_Op** value,
|
||||||
|
int num_values) = 0;
|
||||||
|
|
||||||
|
virtual tensorflow::Status InputLength(const char* input_name,
|
||||||
|
int* length) = 0;
|
||||||
|
virtual tensorflow::Status OutputLength(const char* output_name,
|
||||||
|
int* length) = 0;
|
||||||
|
|
||||||
|
// Experimental
|
||||||
|
virtual tensorflow::Status SetUseXla(bool enable) {
|
||||||
|
return tensorflow::errors::Unimplemented("SetUseXla not implemented");
|
||||||
|
}
|
||||||
|
virtual tensorflow::Status SetCancellationManager(
|
||||||
|
TFE_CancellationManager* cancellation_manager) {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"SetCancellationManager not implemented");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class OpDef;
|
||||||
|
|
||||||
|
class OperationInterface : public AbstractOperationInterface {
|
||||||
|
public:
|
||||||
|
explicit OperationInterface(TFE_Context* ctx);
|
||||||
|
~OperationInterface() override{};
|
||||||
|
|
||||||
|
void Clear() override { operation_.Clear(); }
|
||||||
|
Status Reset(const char* op, const char* raw_device_name) override {
|
||||||
|
return operation_.Reset(op, raw_device_name, false, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
const string& Name() const override { return operation_.Name(); }
|
||||||
|
const string& DeviceName() const override;
|
||||||
|
Status SetDeviceName(const char* name) override;
|
||||||
|
|
||||||
|
Status AddInput(
|
||||||
|
const std::unique_ptr<AbstractTensorHandleInterface>& input) override;
|
||||||
|
Status AddInputList(
|
||||||
|
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
|
||||||
|
inputs) override;
|
||||||
|
Status Execute(
|
||||||
|
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
|
||||||
|
int* num_retvals) override;
|
||||||
|
const tensorflow::OpDef* OpDef() const override {
|
||||||
|
return operation_.OpDef();
|
||||||
|
};
|
||||||
|
|
||||||
|
Status SetAttrString(const char* attr_name, const char* data,
|
||||||
|
size_t length) override;
|
||||||
|
Status SetAttrInt(const char* attr_name, int64_t value) override;
|
||||||
|
Status SetAttrFloat(const char* attr_name, float value) override;
|
||||||
|
Status SetAttrBool(const char* attr_name, bool value) override;
|
||||||
|
Status SetAttrType(const char* attr_name, TF_DataType value) override;
|
||||||
|
Status SetAttrShape(const char* attr_name, const int64_t* dims,
|
||||||
|
const int num_dims) override;
|
||||||
|
Status SetAttrFunction(
|
||||||
|
const char* attr_name,
|
||||||
|
const std::unique_ptr<AbstractOperationInterface>& value) override;
|
||||||
|
Status SetAttrFunctionName(const char* attr_name, const char* data,
|
||||||
|
size_t length) override;
|
||||||
|
Status SetAttrTensor(const char* attr_name, TF_Tensor* tensor) override;
|
||||||
|
Status SetAttrStringList(const char* attr_name, const void* const* values,
|
||||||
|
const size_t* lengths, int num_values) override;
|
||||||
|
Status SetAttrFloatList(const char* attr_name, const float* values,
|
||||||
|
int num_values) override;
|
||||||
|
Status SetAttrIntList(const char* attr_name, const int64_t* values,
|
||||||
|
int num_values) override;
|
||||||
|
Status SetAttrTypeList(const char* attr_name, const TF_DataType* values,
|
||||||
|
int num_values) override;
|
||||||
|
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
|
||||||
|
int num_values) override;
|
||||||
|
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
|
||||||
|
const int* num_dims, int num_values) override;
|
||||||
|
Status SetAttrFunctionList(const char* attr_name, const TFE_Op** value,
|
||||||
|
int num_values) override;
|
||||||
|
|
||||||
|
Status InputLength(const char* input_name, int* length) override;
|
||||||
|
Status OutputLength(const char* output_name, int* length) override;
|
||||||
|
|
||||||
|
Status SetUseXla(bool enable) override;
|
||||||
|
Status SetCancellationManager(
|
||||||
|
TFE_CancellationManager* cancellation_manager) override;
|
||||||
|
|
||||||
|
// TODO(gjn): Remove once TFE_InferShapes is removed
|
||||||
|
const tensorflow::AttrBuilder& Attrs() const { return operation_.Attrs(); }
|
||||||
|
tensorflow::AttrBuilder* MutableAttrs() { return operation_.MutableAttrs(); }
|
||||||
|
|
||||||
|
const TensorHandle* GetInput(int i) const { return operation_.Inputs()[i]; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
const tensorflow::OpDef* GetOpDef(Status* status);
|
||||||
|
EagerOperation operation_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
|
@ -55,6 +55,14 @@ class AbstractTensorHandleInterface {
|
|||||||
|
|
||||||
// Return a copy of the handle.
|
// Return a copy of the handle.
|
||||||
virtual AbstractTensorHandleInterface* Copy() = 0;
|
virtual AbstractTensorHandleInterface* Copy() = 0;
|
||||||
|
|
||||||
|
// Maintain mirror tensors for any implicit copies to local devices. This
|
||||||
|
// setting is offered on a per tensor handle basis to avoid potential memory
|
||||||
|
// over utilization due to holding on to mirrors as well as the original
|
||||||
|
// tensor. Note this setting overrides the context mirroring policy whereby if
|
||||||
|
// the mirroring policy is MIRRORING_NONE, we will still continue to mirror
|
||||||
|
// this tensor.
|
||||||
|
virtual void EnableImplicitMirroring() = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -77,6 +85,8 @@ class TensorHandleInterface : public AbstractTensorHandleInterface {
|
|||||||
|
|
||||||
AbstractTensorHandleInterface* Copy() override;
|
AbstractTensorHandleInterface* Copy() override;
|
||||||
|
|
||||||
|
void EnableImplicitMirroring() override;
|
||||||
|
|
||||||
// TODO(gjn): This is not a very generic interface, but is needed for specific
|
// TODO(gjn): This is not a very generic interface, but is needed for specific
|
||||||
// use cases.
|
// use cases.
|
||||||
TensorHandle* Handle() { return handle_; }
|
TensorHandle* Handle() { return handle_; }
|
||||||
|
@ -24,12 +24,16 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#include "absl/container/inlined_vector.h"
|
#include "absl/container/inlined_vector.h"
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/c/c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
#include "tensorflow/c/tf_datatype.h"
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
#include "tensorflow/c/tf_tensor.h"
|
#include "tensorflow/c/tf_tensor.h"
|
||||||
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||||
#include "tensorflow/core/framework/allocator.h"
|
#include "tensorflow/core/framework/allocator.h"
|
||||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/core/framework/device_base.h"
|
#include "tensorflow/core/framework/device_base.h"
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/tf_tensor.h"
|
#include "tensorflow/c/tf_tensor.h"
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
#include "tensorflow/c/tf_status_helper.h"
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
@ -64,25 +65,41 @@ void deallocate_buffer(void* data, size_t len, void* arg) {
|
|||||||
}
|
}
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
TF_Tensor* CreateTensor(TF_ManagedBuffer* buf, TF_DataType dtype,
|
||||||
|
const int64_t* dims, int num_dims, size_t len) {
|
||||||
|
std::vector<tensorflow::int64> dimvec(num_dims);
|
||||||
|
for (int i = 0; i < num_dims; ++i) {
|
||||||
|
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(gjn): Make the choice of interface a compile-time configuration.
|
||||||
|
tensorflow::TensorInterface ret(
|
||||||
|
Tensor(static_cast<tensorflow::DataType>(dtype),
|
||||||
|
tensorflow::TensorShape(dimvec), buf));
|
||||||
|
buf->Unref();
|
||||||
|
size_t elem_size = TF_DataTypeSize(dtype);
|
||||||
|
if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
|
TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
|
||||||
int num_dims, size_t len) {
|
int num_dims, size_t len) {
|
||||||
void* data = tensorflow::allocate_tensor("TF_AllocateTensor", len,
|
void* data = tensorflow::allocate_tensor("TF_AllocateTensor", len,
|
||||||
tensorflow::cpu_allocator());
|
tensorflow::cpu_allocator());
|
||||||
return TF_NewTensor(dtype, dims, num_dims, data, len,
|
TF_ManagedBuffer* buf =
|
||||||
tensorflow::deallocate_buffer,
|
new TF_ManagedBuffer(data, len, tensorflow::deallocate_buffer,
|
||||||
tensorflow::cpu_allocator());
|
tensorflow::cpu_allocator(), /*owns_memory=*/true);
|
||||||
|
return CreateTensor(buf, dtype, dims, num_dims, len);
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
|
TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
|
||||||
void* data, size_t len,
|
void* data, size_t len,
|
||||||
void (*deallocator)(void* data, size_t len, void* arg),
|
void (*deallocator)(void* data, size_t len, void* arg),
|
||||||
void* deallocator_arg) {
|
void* deallocator_arg) {
|
||||||
std::vector<tensorflow::int64> dimvec(num_dims);
|
|
||||||
for (int i = 0; i < num_dims; ++i) {
|
|
||||||
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
TF_ManagedBuffer* buf = nullptr;
|
TF_ManagedBuffer* buf = nullptr;
|
||||||
if (dtype != TF_STRING && dtype != TF_RESOURCE &&
|
if (dtype != TF_STRING && dtype != TF_RESOURCE &&
|
||||||
tensorflow::DataTypeCanUseMemcpy(
|
tensorflow::DataTypeCanUseMemcpy(
|
||||||
@ -97,24 +114,17 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
|
|||||||
// Other types have the same representation, so copy only if it is safe to
|
// Other types have the same representation, so copy only if it is safe to
|
||||||
// do so.
|
// do so.
|
||||||
buf = new TF_ManagedBuffer(tensorflow::allocate_tensor("TF_NewTensor", len),
|
buf = new TF_ManagedBuffer(tensorflow::allocate_tensor("TF_NewTensor", len),
|
||||||
len, tensorflow::deallocate_buffer, nullptr);
|
len, tensorflow::deallocate_buffer, nullptr,
|
||||||
|
/*owns_memory=*/true);
|
||||||
std::memcpy(buf->data(), data, len);
|
std::memcpy(buf->data(), data, len);
|
||||||
// Free the original buffer.
|
// Free the original buffer.
|
||||||
deallocator(data, len, deallocator_arg);
|
deallocator(data, len, deallocator_arg);
|
||||||
} else {
|
} else {
|
||||||
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
|
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
|
||||||
|
/*owns_memory=*/false);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(gjn): Make the choice of interface a compile-time configuration.
|
return CreateTensor(buf, dtype, dims, num_dims, len);
|
||||||
tensorflow::TensorInterface ret(
|
|
||||||
Tensor(static_cast<tensorflow::DataType>(dtype),
|
|
||||||
tensorflow::TensorShape(dimvec), buf));
|
|
||||||
buf->Unref();
|
|
||||||
size_t elem_size = TF_DataTypeSize(dtype);
|
|
||||||
if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* t) {
|
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* t) {
|
||||||
|
@ -38,11 +38,12 @@ class TF_ManagedBuffer : public tensorflow::TensorBuffer {
|
|||||||
public:
|
public:
|
||||||
TF_ManagedBuffer(void* data, size_t len,
|
TF_ManagedBuffer(void* data, size_t len,
|
||||||
void (*deallocator)(void* data, size_t len, void* arg),
|
void (*deallocator)(void* data, size_t len, void* arg),
|
||||||
void* deallocator_arg)
|
void* deallocator_arg, bool owns_memory)
|
||||||
: TensorBuffer(data),
|
: TensorBuffer(data),
|
||||||
len_(len),
|
len_(len),
|
||||||
deallocator_(deallocator),
|
deallocator_(deallocator),
|
||||||
deallocator_arg_(deallocator_arg) {}
|
deallocator_arg_(deallocator_arg),
|
||||||
|
owns_memory_(owns_memory) {}
|
||||||
|
|
||||||
~TF_ManagedBuffer() override {
|
~TF_ManagedBuffer() override {
|
||||||
(*deallocator_)(data(), len_, deallocator_arg_);
|
(*deallocator_)(data(), len_, deallocator_arg_);
|
||||||
@ -57,13 +58,13 @@ class TF_ManagedBuffer : public tensorflow::TensorBuffer {
|
|||||||
proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
|
proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prevents input forwarding from mutating this buffer.
|
bool OwnsMemory() const override { return owns_memory_; }
|
||||||
bool OwnsMemory() const override { return false; }
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const size_t len_;
|
const size_t len_;
|
||||||
void (*const deallocator_)(void* data, size_t len, void* arg);
|
void (*const deallocator_)(void* data, size_t len, void* arg);
|
||||||
void* const deallocator_arg_;
|
void* const deallocator_arg_;
|
||||||
|
bool owns_memory_;
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
@ -68,6 +68,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
"//tensorflow/core:testlib",
|
"//tensorflow/core:testlib",
|
||||||
|
"//tensorflow/core/platform:resource_loader",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -224,3 +225,15 @@ filegroup(
|
|||||||
"testdata/VarsAndArithmeticObjectGraph/**",
|
"testdata/VarsAndArithmeticObjectGraph/**",
|
||||||
]),
|
]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
exports_files(
|
||||||
|
glob([
|
||||||
|
"testdata/half_plus_two_pbtxt/**",
|
||||||
|
"testdata/half_plus_two_main_op/**",
|
||||||
|
"testdata/half_plus_two/**",
|
||||||
|
"testdata/half_plus_two_v2/**",
|
||||||
|
"testdata/x_plus_y_v2_debuginfo/**",
|
||||||
|
"testdata/CyclicModule/**",
|
||||||
|
"testdata/VarsAndArithmeticObjectGraph/**",
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
@ -21,15 +21,22 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/path.h"
+#include "tensorflow/core/platform/resource_loader.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
 namespace {
 
-constexpr char kTestDataPbTxt[] =
-    "cc/saved_model/testdata/half_plus_two_pbtxt/00000123";
-constexpr char kTestDataSharded[] =
-    "cc/saved_model/testdata/half_plus_two/00000123";
+string TestDataPbTxt() {
+  return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
+                      "half_plus_two_pbtxt", "00000123");
+}
+
+string TestDataSharded() {
+  return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
+                      "half_plus_two", "00000123");
+}
 
 class ReaderTest : public ::testing::Test {
  protected:
@ -49,8 +56,7 @@ class ReaderTest : public ::testing::Test {
 TEST_F(ReaderTest, TagMatch) {
   MetaGraphDef meta_graph_def;
 
-  const string export_dir =
-      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
+  const string export_dir = GetDataDependencyFilepath(TestDataSharded());
   TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
                                               &meta_graph_def));
   CheckMetaGraphDef(meta_graph_def);
@ -59,8 +65,7 @@ TEST_F(ReaderTest, TagMatch) {
 TEST_F(ReaderTest, NoTagMatch) {
   MetaGraphDef meta_graph_def;
 
-  const string export_dir =
-      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
+  const string export_dir = GetDataDependencyFilepath(TestDataSharded());
   Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"},
                                              &meta_graph_def);
   EXPECT_FALSE(st.ok());
@ -73,8 +78,7 @@ TEST_F(ReaderTest, NoTagMatch) {
 TEST_F(ReaderTest, NoTagMatchMultiple) {
   MetaGraphDef meta_graph_def;
 
-  const string export_dir =
-      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
+  const string export_dir = GetDataDependencyFilepath(TestDataSharded());
   Status st = ReadMetaGraphDefFromSavedModel(
       export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);
   EXPECT_FALSE(st.ok());
@ -87,8 +91,7 @@ TEST_F(ReaderTest, NoTagMatchMultiple) {
 TEST_F(ReaderTest, PbtxtFormat) {
   MetaGraphDef meta_graph_def;
 
-  const string export_dir =
-      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPbTxt);
+  const string export_dir = GetDataDependencyFilepath(TestDataPbTxt());
   TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
                                               &meta_graph_def));
   CheckMetaGraphDef(meta_graph_def);
@ -97,8 +100,7 @@ TEST_F(ReaderTest, PbtxtFormat) {
 TEST_F(ReaderTest, InvalidExportPath) {
   MetaGraphDef meta_graph_def;
 
-  const string export_dir =
-      io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path");
+  const string export_dir = GetDataDependencyFilepath("missing-path");
   Status st = ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
                                              &meta_graph_def);
   EXPECT_FALSE(st.ok());
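The reader_test hunks above replace source-root-relative string constants with runfiles lookups. A minimal standalone sketch of that pattern, assuming only the two TensorFlow helpers the patch itself uses (io::JoinPath and GetDataDependencyFilepath); the function name SavedModelPath is illustrative, not part of the change:

#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/resource_loader.h"

namespace tensorflow {

// Build the path piecewise, then resolve it against the test's data
// dependencies instead of testing::TensorFlowSrcRoot().
string SavedModelPath() {
  return GetDataDependencyFilepath(
      io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
                   "half_plus_two", "00000123"));
}

}  // namespace tensorflow

Building the path piecewise rather than hard-coding slashes is presumably what makes the tests portable across platform path separators, which is why the constants become functions.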
@ -20,9 +20,11 @@ from __future__ import print_function as _print_function
 
 import logging as _logging
 import os as _os
+import six as _six
 import sys as _sys
 
 from tensorflow.python.tools import module_util as _module_util
+from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
 
 # pylint: disable=g-bad-import-order
 
@ -36,20 +38,19 @@ try:
   from tensorboard.summary._tf import summary
   _current_module.__path__ = (
       [_module_util.get_parent_dir(summary)] + _current_module.__path__)
-  # Make sure we get the correct summary module with lazy loading
   setattr(_current_module, "summary", summary)
 except ImportError:
   _logging.warning(
       "Limited tf.compat.v2.summary API due to missing TensorBoard "
       "installation.")
 
-try:
-  from tensorflow_estimator.python.estimator.api._v2 import estimator
-  _current_module.__path__ = (
-      [_module_util.get_parent_dir(estimator)] + _current_module.__path__)
-  setattr(_current_module, "estimator", estimator)
-except ImportError:
-  pass
+# Lazy-load estimator.
+_estimator_module = "tensorflow_estimator.python.estimator.api._v2.estimator"
+estimator = _LazyLoader("estimator", globals(), _estimator_module)
+_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
+if _module_dir:
+  _current_module.__path__ = [_module_dir] + _current_module.__path__
+setattr(_current_module, "estimator", estimator)
 
 try:
   from tensorflow.python.keras.api._v2 import keras
@ -59,6 +60,13 @@ try:
 except ImportError:
   pass
 
+# Explicitly import lazy-loaded modules to support autocompletion.
+# pylint: disable=g-import-not-at-top
+if not _six.PY2:
+  import typing as _typing
+  if _typing.TYPE_CHECKING:
+    from tensorflow_estimator.python.estimator.api._v2 import estimator
+# pylint: enable=g-import-not-at-top
 
 # We would like the following to work for fully enabling 2.0 in a 1.0 install:
 #
@ -20,8 +20,10 @@ from __future__ import print_function as _print_function
 
 import os as _os
 import sys as _sys
+import six as _six
 
 from tensorflow.python.tools import module_util as _module_util
+from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
 
 # pylint: disable=g-bad-import-order
 
@ -31,13 +33,14 @@ from tensorflow.python.tools import module_util as _module_util
 
 # Hook external TensorFlow modules.
 _current_module = _sys.modules[__name__]
-try:
-  from tensorflow_estimator.python.estimator.api._v1 import estimator
-  _current_module.__path__ = (
-      [_module_util.get_parent_dir(estimator)] + _current_module.__path__)
-  setattr(_current_module, "estimator", estimator)
-except ImportError:
-  pass
+# Lazy-load estimator.
+_estimator_module = "tensorflow_estimator.python.estimator.api._v1.estimator"
+estimator = _LazyLoader("estimator", globals(), _estimator_module)
+_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
+if _module_dir:
+  _current_module.__path__ = [_module_dir] + _current_module.__path__
+setattr(_current_module, "estimator", estimator)
 
 try:
   from tensorflow.python.keras.api._v1 import keras
@ -47,6 +50,14 @@ try:
 except ImportError:
   pass
 
+# Explicitly import lazy-loaded modules to support autocompletion.
+# pylint: disable=g-import-not-at-top
+if not _six.PY2:
+  import typing as _typing
+  if _typing.TYPE_CHECKING:
+    from tensorflow_estimator.python.estimator.api._v1 import estimator
+# pylint: enable=g-import-not-at-top
+
 
 from tensorflow.python.platform import flags  # pylint: disable=g-import-not-at-top
 _current_module.app.flags = flags  # pylint: disable=undefined-variable
@ -84,6 +84,7 @@ tf_cc_test(
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core/platform:resource_loader",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:support",  # fixdeps: keep
        "@llvm-project//llvm:x86_code_gen",  # fixdeps: keep
@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/aot/codegen.h"
 
+#include <algorithm>
 #include <string>
 #include <vector>
 
@ -29,6 +30,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/resource_loader.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
@ -139,23 +141,40 @@ TEST_F(ParseCppClassTest, ParseFail) {
 
 static void CompareWithGoldenFile(
     const string& tensorflow_relative_golden_file_name,
-    const string& expected_contents) {
+    const string& expected_contents, bool ignore_cr) {
+  // Get rid of all CR characters, we may be running under windows.
+  string sanitized_expected_contents(expected_contents);
+  if (ignore_cr) {
+    sanitized_expected_contents.erase(
+        std::remove(sanitized_expected_contents.begin(),
                    sanitized_expected_contents.end(), '\r'),
+        sanitized_expected_contents.end());
+  }
+
   // To update the golden file, flip update_golden to true and run the
   // following:
   // bazel test --test_strategy=local \
   //   third_party/tensorflow/compiler/aot:codegen_test
   const bool update_golden = false;
-  const string golden_file_name = io::JoinPath(
-      testing::TensorFlowSrcRoot(), tensorflow_relative_golden_file_name);
+  string golden_file_name;
 
   if (update_golden) {
+    golden_file_name = io::JoinPath(testing::TensorFlowSrcRoot(),
+                                    tensorflow_relative_golden_file_name);
     TF_EXPECT_OK(
         WriteStringToFile(Env::Default(), golden_file_name, expected_contents));
   }
 
+  golden_file_name =
+      GetDataDependencyFilepath(tensorflow_relative_golden_file_name);
   string golden_file_contents;
   TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name,
                                 &golden_file_contents));
+  if (ignore_cr) {
+    golden_file_contents.erase(std::remove(golden_file_contents.begin(),
                                           golden_file_contents.end(), '\r'),
                               golden_file_contents.end());
+  }
   EXPECT_EQ(golden_file_contents, expected_contents);
 }
 
@ -229,14 +248,18 @@ TEST(CodegenTest, Golden) {
   // The other fields in metadata_result are tested as part of the generated
   // header test.
 
-  CompareWithGoldenFile("compiler/aot/codegen_test_o.golden",
-                        metadata_result.object_file_data);
+  // This specific golden test checks a binary file. It can potentially run into
+  // issues due to ABIs not being stable, but has not so far.
+  // If we see any ABI issues, we should reconsider this specific test case.
+  CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_o.golden",
+                        metadata_result.object_file_data, false);
 
   string header;
   TF_ASSERT_OK(
       GenerateHeader(opts, config, compile_result, metadata_result, &header));
 
-  CompareWithGoldenFile("compiler/aot/codegen_test_h.golden", header);
+  CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_h.golden", header,
+                        true);
 }
 }  // namespace
 }  // namespace tfcompile
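Both sides of the golden-file comparison above are scrubbed of carriage returns with the standard erase/remove idiom, so CRLF checkouts on Windows still match LF-only generated text. A minimal standalone sketch of that idiom, with a hypothetical helper name (StripCarriageReturns is not part of the patch):

#include <algorithm>
#include <string>

// Remove every '\r' in place; std::remove compacts the kept characters to the
// front and erase() drops the leftover tail.
void StripCarriageReturns(std::string* s) {
  s->erase(std::remove(s->begin(), s->end(), '\r'), s->end());
}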
@ -14,6 +14,10 @@ package_group(
     includes = [
         "//tensorflow/compiler/tf2xla:internal",
     ],
+    packages = [
+        "//tensorflow/compiler/tests/...",
+        "//tensorflow/python/...",
+    ],
 )
 
 package_group(
@ -676,12 +676,10 @@ Status Encapsulator::Subgraph::AddFunctionCallNode(
 Status Encapsulator::GetFunctionNameAttr(Node const* node, string* attr) const {
   AttrSlice attrs = node->attrs();
   attr->clear();
-  bool found_group_attribute = false;
   for (const auto& node_attr : attrs) {
     if (node_attr.first == group_attribute_) {
       TF_RETURN_IF_ERROR(AttrValueHasType(node_attr.second, "string"));
       *attr = node_attr.second.s();
-      found_group_attribute = true;
       break;
     }
   }
@ -790,7 +788,6 @@ Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) {
 
   TF_RETURN_IF_ERROR(CopySubgraphNodes(&node_images));
   TF_RETURN_IF_ERROR(CopySubgraphEdges(node_images, &src_arg_pairs));
-
   MarkGuaranteedConstants(*graph_in_, src_arg_pairs);
 
   for (auto& entry : subgraphs_) {
@ -108,7 +108,7 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
          "(LRN, LRNGrad)."
          " BN: TF FusedBatchNorm* operations."
          " FUSIBLE: All TF operations that XLA can fuse (All the above). "
-         "You can also put any TF operation name, e.g. 'FUSIBLE,Matmul'."),
+         "You can also put any TF operation name, e.g. 'FUSIBLE,MatMul'."),
     Flag("tf_xla_clustering_debug",
          &mark_for_compilation_flags->tf_xla_clustering_debug,
          "Dump graphs during XLA compilation."),
@ -20,6 +20,7 @@ XLA_OPS_DEPS = [
     "//tensorflow/compiler/tf2xla:common",
     "//tensorflow/compiler/tf2xla:tf2xla_util",
     "//tensorflow/compiler/tf2xla:xla_compiler",
+    "//tensorflow/compiler/xla:executable_run_options",
     "//tensorflow/compiler/xla:status_macros",
     "//tensorflow/compiler/xla:statusor",
     "//tensorflow/compiler/xla/client:client_library",
@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/executable_run_options.h"
 #include "tensorflow/compiler/xla/service/compiler.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
@ -41,6 +42,7 @@ limitations under the License.
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/casts.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
@ -206,12 +208,14 @@ se::DeviceMemoryAllocator* GetAllocator(
 XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
                                        const std::vector<int>& constants,
                                        const std::vector<int>& resources,
-                                       const NameAttrList& function)
+                                       const NameAttrList& function,
+                                       bool has_ref_vars)
     : OpKernel(ctx),
       constants_(constants),
      resources_(resources),
      function_(function),
-      platform_info_(PlatformInfoFromContext(ctx)) {}
+      platform_info_(PlatformInfoFromContext(ctx)),
+      has_ref_vars_(has_ref_vars) {}
 
 static Status BuildCompilationCache(OpKernelContext* ctx,
                                     const XlaPlatformInfo& platform_info,
@ -350,8 +354,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
 
   {
     Status s = CompileToLocalExecutable(
-        ctx, function_, /*has_ref_vars=*/true, platform_info_, resources_,
-        constants_, /*lazy=*/false, &client, &variables, &kernel, &executable);
+        ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_,
+        resources_, constants_, /*lazy=*/false, &client, &variables, &kernel,
+        &executable);
     if (!s.ok() && (platform_info_.device_type().type_string() == DEVICE_CPU ||
                     platform_info_.device_type().type_string() == DEVICE_GPU)) {
       // Suggest auto jit if the failure was with GPU or CPU.
@ -384,6 +389,18 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   run_options.set_allocator(allocator);
   run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
   run_options.set_rng_seed(GetXLARandomSeed());
+  xla::ThenExecuteFunction then_execute;
+  if (ctx->op_device_context()) {
+    then_execute = [&](se::Stream* stream, std::function<void()> fn) {
+      Status status = ctx->op_device_context()->ThenExecute(
+          down_cast<Device*>(ctx->device()), stream, std::move(fn));
+      if (!status.ok()) {
+        // This should never happen.
+        LOG(ERROR) << "ThenExecute failed " << status;
+      }
+    };
+    run_options.set_then_execute_function(&then_execute);
+  }
   Env* env = Env::Default();
   auto start_time = env->NowMicros();
 
|
|||||||
|
|
||||||
XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
|
XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
|
||||||
: XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
|
: XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
|
||||||
FunctionAttr(ctx)) {}
|
FunctionAttr(ctx), /*has_ref_vars=*/true) {}
|
||||||
|
|
||||||
XlaLocalLaunchOp::~XlaLocalLaunchOp() {
|
XlaLocalLaunchOp::~XlaLocalLaunchOp() {
|
||||||
VLOG(1) << "XlaLocalLaunchOp destroyed";
|
VLOG(1) << "XlaLocalLaunchOp destroyed";
|
||||||
@ -592,6 +609,18 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
|
|||||||
run_options.set_allocator(allocator);
|
run_options.set_allocator(allocator);
|
||||||
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
|
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
|
||||||
run_options.set_rng_seed(GetXLARandomSeed());
|
run_options.set_rng_seed(GetXLARandomSeed());
|
||||||
|
xla::ThenExecuteFunction then_execute;
|
||||||
|
if (ctx->op_device_context()) {
|
||||||
|
then_execute = [&](se::Stream* stream, std::function<void()> fn) {
|
||||||
|
Status status = ctx->op_device_context()->ThenExecute(
|
||||||
|
down_cast<Device*>(ctx->device()), stream, std::move(fn));
|
||||||
|
if (!status.ok()) {
|
||||||
|
// This should never happen.
|
||||||
|
LOG(ERROR) << "ThenExecute failed " << status;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
run_options.set_then_execute_function(&then_execute);
|
||||||
|
}
|
||||||
Env* env = Env::Default();
|
Env* env = Env::Default();
|
||||||
auto start_time = env->NowMicros();
|
auto start_time = env->NowMicros();
|
||||||
|
|
||||||
|
@ -95,12 +95,15 @@ class XlaPlatformInfo {
|
|||||||
// in the GraphDef.
|
// in the GraphDef.
|
||||||
// Currently, it is used by eager runtime. FunctionLibraryRuntime creates
|
// Currently, it is used by eager runtime. FunctionLibraryRuntime creates
|
||||||
// this kernel when asked to create a kernel for an XLA-compiled function.
|
// this kernel when asked to create a kernel for an XLA-compiled function.
|
||||||
|
//
|
||||||
|
// `has_ref_vars`: whether the input computation can have reference variables.
|
||||||
|
// TODO(cheshire): instead derive this information from the input graph.
|
||||||
class XlaLocalLaunchBase : public OpKernel {
|
class XlaLocalLaunchBase : public OpKernel {
|
||||||
public:
|
public:
|
||||||
XlaLocalLaunchBase(OpKernelConstruction* ctx,
|
XlaLocalLaunchBase(OpKernelConstruction* ctx,
|
||||||
const std::vector<int>& constants,
|
const std::vector<int>& constants,
|
||||||
const std::vector<int>& resources,
|
const std::vector<int>& resources,
|
||||||
const NameAttrList& function);
|
const NameAttrList& function, bool has_ref_vars);
|
||||||
XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
|
XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
|
||||||
XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
|
XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
|
||||||
~XlaLocalLaunchBase() override = default;
|
~XlaLocalLaunchBase() override = default;
|
||||||
@ -115,6 +118,8 @@ class XlaLocalLaunchBase : public OpKernel {
|
|||||||
|
|
||||||
const NameAttrList function_;
|
const NameAttrList function_;
|
||||||
const XlaPlatformInfo platform_info_;
|
const XlaPlatformInfo platform_info_;
|
||||||
|
|
||||||
|
bool has_ref_vars_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
|
// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
|
||||||
|
@ -963,6 +963,22 @@ absl::optional<string> MarkForCompilationPassImpl::GetXlaScope(Node* node) {
   return absl::nullopt;
 }
 
+// Returns true iff the attribute `attr_name` is attached to either the node or
+// to it's callee.
+static bool GetNodeOrFuncAttr(Node* node, FunctionLibraryDefinition* flib_def,
+                              const char* attr_name) {
+  bool out = false;
+  bool attr_value;
+  if (TryGetNodeAttr(node->attrs(), attr_name, &attr_value)) {
+    out |= attr_value;
+  }
+
+  if (flib_def->GetAttr(*node, attr_name, &attr_value).ok()) {
+    out |= attr_value;
+  }
+  return out;
+}
+
 Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
   auto ignore_resource_ops = [&](const Node& n, bool* ignore) {
     return IgnoreResourceOpForSafetyAnalysis(&device_info_cache_, n, ignore);
@ -1016,16 +1032,9 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
       resource_var_operation_node_id = node->id();
     }
 
-    bool is_xla_compile_attr_true = false;
-
-    bool xla_compile_attr;
-    if (TryGetNodeAttr(node->attrs(), kXlaCompileAttr, &xla_compile_attr)) {
-      is_xla_compile_attr_true |= xla_compile_attr;
-    }
-
-    if (flib_def_->GetAttr(*node, kXlaCompileAttr, &xla_compile_attr).ok()) {
-      is_xla_compile_attr_true |= xla_compile_attr;
-    }
+    bool is_xla_compile_attr_true =
+        GetNodeOrFuncAttr(node, flib_def_, kXlaCompileAttr) ||
+        GetNodeOrFuncAttr(node, flib_def_, kXlaMustCompileAttr);
 
     DeviceSet devices;
     devices.Insert(device);
@ -1874,6 +1883,8 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
       "EmptyTensorList",
       "ExtractImagePatches",
       "Igamma",
+      "IgammaGradA",
+      "RandomGammaGrad",
       "Igammac",
       "FFT",
       "FFT2D",
@ -1996,6 +2007,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
       "StatelessRandomNormal",
      "StatelessRandomUniform",
      "StatelessRandomUniformInt",
+      "StatelessRandomUniformFullInt",
      "StatelessTruncatedNormal",
      "StatelessWhile",
      "Svd",
@ -20,15 +20,17 @@ limitations under the License.
 
 namespace tensorflow {
 
-bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr,
-                                       const NodeDef& node_def) const {
-  return CanCreateXlaKernel(node_def);
+bool XlaKernelCreator::CanCreateKernel(
+    const FunctionLibraryRuntime& flr,
+    const std::shared_ptr<const NodeProperties>& props) const {
+  return CanCreateXlaKernel(props->node_def);
 }
 
-Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr,
-                                      const NodeDef& node_def,
+Status XlaKernelCreator::CreateKernel(
+    FunctionLibraryRuntime* flr,
+    const std::shared_ptr<const NodeProperties>& props,
                                       std::unique_ptr<OpKernel>* kernel) const {
-  return CreateXlaKernel(flr, node_def, kernel);
+  return CreateXlaKernel(flr, props->node_def, kernel);
 }
 
 namespace {
@ -29,11 +29,13 @@ class XlaKernelCreator : public CustomKernelCreator {
   // Given a NodeDef 'node_def' and the function library runtime 'flr', returns
   // true if 'node_def' is a call to a compilable function defined in 'flr',
   // with the kXlaCompileAttr set.
-  bool CanCreateKernel(const FunctionLibraryRuntime& flr,
-                       const NodeDef& node_def) const override;
+  bool CanCreateKernel(
+      const FunctionLibraryRuntime& flr,
+      const std::shared_ptr<const NodeProperties>& props) const override;
 
   // Given a supported NodeDef, returns a XlaLaunchOp that computes the node.
-  Status CreateKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
+  Status CreateKernel(FunctionLibraryRuntime* flr,
+                      const std::shared_ptr<const NodeProperties>& props,
                       std::unique_ptr<OpKernel>* kernel) const override;
 };
 
@ -30,10 +30,12 @@ limitations under the License.
 
 namespace tensorflow {
 
-NodeDef ToNodeDef(const string& text) {
+std::shared_ptr<NodeProperties> ToNodeProperties(const string& text) {
   NodeDef node_def;
+  DataTypeVector dummy;
   EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
-  return node_def;
+  return std::make_shared<NodeProperties>(nullptr, std::move(node_def), dummy,
+                                          dummy);
 }
 
 // Create a FunctionDef that takes one resource and one regular param
@ -98,11 +100,11 @@ TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) {
   (*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
   Init({fdef});
   XlaKernelCreator xla_kernel_creator;
-  NodeDef callsite =
-      ToNodeDef(R"pb(
+  auto callsite =
+      ToNodeProperties(R"pb(
         name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
       )pb");
-  (*callsite.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
+  (*(callsite->node_def.mutable_attr()))["_XlaMustCompile"] = BoolAttr(true);
 
   // Note: need to set attribute on the created node.
   Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_);
@ -127,7 +129,8 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) {
   Init({fdef});
   XlaKernelCreator xla_kernel_creator;
 
-  Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto(
+  Status status =
+      xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto(
     name: 'XTimesY'
     op: 'XTimesY'
     input: 'a'
@ -143,7 +146,8 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) {
   Init({fdef});
   XlaKernelCreator xla_kernel_creator;
 
-  Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto(
+  Status status =
+      xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto(
     name: 'XTimesY'
     op: 'XTimesY'
     input: 'a'
@ -104,7 +104,7 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
       BackwardsConstAnalysis(*((*fbody)->graph), &const_args,
                              /*compile_time_const_nodes=*/nullptr, flr));
 
-  for (int i = 0; i < const_args.size(); ++i) {
+  for (size_t i = 0; i < const_args.size(); ++i) {
     if (const_args[i]) {
       constant_arg_indices->push_back(i);
     }
@ -113,7 +113,7 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
   // There can be hundreds of resource variables. Reserve the space for them.
   // We don't reserve for constants above as they are usually few.
   resource_arg_indices->reserve(arg_types.size());
-  for (int i = 0; i < arg_types.size(); ++i) {
+  for (size_t i = 0; i < arg_types.size(); ++i) {
     if (arg_types[i] == DT_RESOURCE) {
       resource_arg_indices->push_back(i);
     }
@ -177,7 +177,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
   // 214 variables and a similar number of activations.
   SinglePassSearch constants_search(&constant_arg_indices);
   SinglePassSearch resources_search(&resource_arg_indices);
-  for (int i = 0; i < fbody->arg_types.size(); ++i) {
+  for (size_t i = 0; i < fbody->arg_types.size(); ++i) {
     if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
       // Compile-time constants and resource handles are expected to be in
       // host memory.
@ -207,7 +207,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
   // XlaLaunch kernel keeps all outputs (including constants, which it copies),
   // in device memory except for resources.
   MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
-  for (int i = 0; i < fbody->ret_types.size(); ++i) {
+  for (size_t i = 0; i < fbody->ret_types.size(); ++i) {
     if (fbody->ret_types[i] == DT_RESOURCE) {
       output_memory_types[i] = HOST_MEMORY;
     }
@ -218,15 +218,17 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
   TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
   Device* dev = flr->device();
   Status s;
-  OpKernelConstruction construction(
-      DeviceType(dev->device_type()), dev,
-      dev->GetAllocator(AllocatorAttributes()), &node_def,
-      &fbody->fdef.signature(), flr, dev->resource_manager(), fbody->arg_types,
-      input_memory_types, fbody->ret_types, output_memory_types,
-      flr->graph_def_version(), &s);
+  auto props = std::make_shared<NodeProperties>(
+      &fbody->fdef.signature(), node_def, fbody->arg_types, fbody->ret_types);
+  OpKernelConstruction construction(DeviceType(dev->device_type()), dev,
+                                    dev->GetAllocator(AllocatorAttributes()),
+                                    flr, dev->resource_manager(), props,
+                                    input_memory_types, output_memory_types,
                                     flr->graph_def_version(), &s);
 
   *kernel = absl::make_unique<XlaLocalLaunchBase>(
-      &construction, constant_arg_indices, resource_arg_indices, function);
+      &construction, constant_arg_indices, resource_arg_indices, function,
+      /*has_ref_vars=*/false);
   return s;
 }
 }  // namespace tensorflow
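The int to size_t loop-index changes above are signed/unsigned hygiene rather than behavior changes. A tiny standalone sketch of the same fix, with illustrative names: indexing with size_t matches the size_t returned by size() and avoids -Wsign-compare warnings.

#include <cstddef>
#include <vector>

int CountSet(const std::vector<bool>& flags) {
  int count = 0;
  // size() returns size_t, so the loop index should be size_t as well.
  for (size_t i = 0; i < flags.size(); ++i) {
    if (flags[i]) ++count;
  }
  return count;
}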
@ -44,11 +44,9 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core/platform:logging",
         "@llvm-project//llvm:support",
-        "@llvm-project//mlir:AffineDialectRegistration",
-        "@llvm-project//mlir:LoopDialectRegistration",
+        "@llvm-project//mlir:AllPassesAndDialects",
         "@llvm-project//mlir:MlirOptLib",
         "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:QuantOpsDialectRegistration",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir/test:TestTransforms",
     ],
@ -106,7 +104,9 @@ tf_cc_binary(
     name = "tf-opt",
     deps = [
         ":tf_mlir_opt_main",
+        "//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
         "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration",
+        "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
         "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
     ],
 )
@ -116,8 +116,10 @@ tf_cc_binary(
     srcs = ["tf_mlir_translate_main.cc"],
     deps = [
         ":init_mlir",
+        "//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
         "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
         "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
+        "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
         "//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
         "//tensorflow/compiler/mlir/tensorflow:translate_lib",
         "//tensorflow/compiler/mlir/tensorflow:translate_registration",
@ -129,6 +131,7 @@ tf_cc_binary(
         "//tensorflow/stream_executor/lib",
         "@com_google_absl//absl/strings",
         "@llvm-project//llvm:support",
+        "@llvm-project//mlir:AllPassesAndDialects",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TranslateClParser",
tensorflow/compiler/mlir/g3doc/README.md (new file, 3 lines)
@ -0,0 +1,3 @@
+# TensorFlow MLIR
+
+These are the docs for: https://www.tensorflow.org/mlir

tensorflow/compiler/mlir/g3doc/_book.yaml (new file, 26 lines)
@ -0,0 +1,26 @@
+upper_tabs:
+# Tabs left of dropdown menu
+- include: /_upper_tabs_left.yaml
+- include: /api_docs/_upper_tabs_api.yaml
+# Dropdown menu
+- name: Resources
+  path: /resources
+  is_default: true
+  menu:
+  - include: /resources/_menu_toc.yaml
+  lower_tabs:
+  # Subsite tabs
+  other:
+  - name: Guide
+    contents:
+    - title: Overview
+      path: /mlir/overview
+    - heading: Dialects
+    - title: Overview
+      path: /mlir/dialects
+    - title: TensorFlow
+      path: /mlir/tf_ops
+    - title: TensorFlow Lite
+      path: /mlir/tfl_ops
+
+- include: /_upper_tabs_right.yaml

tensorflow/compiler/mlir/g3doc/_index.yaml (new file, 54 lines)
@ -0,0 +1,54 @@
+book_path: /mlir/_book.yaml
+project_path: /mlir/_project.yaml
+description: <!--no description-->
+landing_page:
+  custom_css_path: /site-assets/css/style.css
+  rows:
+  - heading: MLIR unifies the infrastructure for high-performance ML models in TensorFlow.
+    items:
+    - description: >
+        The <a href="https://mlir.llvm.org/" class="external">MLIR</a> project defines a common
+        intermediate representation (IR) that unifies the infrastructure required to execute high
+        performance machine learning models in TensorFlow and similar ML frameworks. This project
+        will include the application of HPC techniques, along with integration of
+        search algorithms like reinforcement learning. MLIR aims to reduce the
+        cost to bring up new hardware, and improve usability for existing
+        TensorFlow users.
+
+    - code_block: |
+        <pre class = "prettyprint">
+        // Syntactically similar to LLVM:
+        func @testFunction(%arg0: i32) {
+          %x = call @thingToCall(%arg0) : (i32) -> i32
+          br ^bb1
+        ^bb1:
+          %y = addi %x, %x : i32
+          return %y : i32
+        }
+        </pre>
+
+  - classname: devsite-landing-row-cards
+    items:
+    - heading: "Multi-Level Intermediate Representation for Compiler Infrastructure"
+      youtube_id: qzljG6DKgic
+      buttons:
+      - label: Watch the video
+        path: https://www.youtube.com/watch?v=qzljG6DKgic
+    - heading: "A new intermediate representation and compiler framework"
+      image_path: /resources/images/tf-logo-card-16x9.png
+      path: https://blog.tensorflow.org/2019/04/mlir-new-intermediate-representation.html
+      buttons:
+      - label: Read on TensorFlow blog
+        path: https://blog.tensorflow.org/2019/04/mlir-new-intermediate-representation.html
+    - heading: MLIR on GitHub
+      image_path: /resources/images/github-card-16x9.png
+      path: https://github.com/llvm/llvm-project/tree/master/mlir
+      buttons:
+      - label: View on GitHub
+        path: https://github.com/llvm/llvm-project/tree/master/mlir
+    - heading: TensorFlow MLIR on GitHub
+      image_path: /resources/images/github-card-16x9.png
+      path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir
+      buttons:
+      - label: View on GitHub
+        path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir

tensorflow/compiler/mlir/g3doc/dialects.md (new file, 37 lines)
@ -0,0 +1,37 @@
+# MLIR dialects
+
+## Overview
+
+
+To separate different hardware and software targets, MLIR has “dialects”,
+including:
+
+*   TensorFlow IR, which represents all things possible in TensorFlow graphs.
+*   XLA HLO IR, which is designed to take advantage of XLA’s compilation
+    abilities (with output to, among other things, TPUs).
+*   An experimental affine dialect, which focuses on
+    [polyhedral representations](https://en.wikipedia.org/wiki/Polytope_model)
+    and optimizations.
+*   LLVM IR, which has a 1:1 mapping between it and LLVM’s own representation,
+    allowing MLIR to emit GPU and CPU code through LLVM.
+*   TensorFlow Lite, which will translate to running code on mobile platforms.
+
+Each dialect consists of a set of defined operations which have invariants
+placed on them, like: “This is a binary operator, and the inputs and outputs
+have the same types.”
+
+## Adding to MLIR
+
+MLIR has no fixed/built-in list of globally known operations (no “intrinsics”).
+Dialects can define entirely custom types, which is how MLIR can model things
+like the LLVM IR type system (which has first class aggregates), domain
+abstractions important for ML-optimized accelerators like quantized types, and
+even the Swift or Clang type systems (which are built around Swift/Clang
+declaration nodes) in the future.
+
+If you want to connect a new low-level compiler, you would create a new dialect
+and the lowerings between the TensorFlow Graph dialect and your dialect.
+This smooths the path for hardware and compiler makers. You can even target
+dialects at different levels in the same model; the higher-level optimizers
+will respect the unfamiliar parts of the IR and wait for a lower level to handle
+it.

tensorflow/compiler/mlir/g3doc/images/mlir-infra.svg (new file, 148 KiB; diff suppressed because one or more lines are too long)

tensorflow/compiler/mlir/g3doc/overview.md (new file, 36 lines)
@ -0,0 +1,36 @@
+# MLIR
+
+## Overview
+
+MLIR, or Multi-Level Intermediate Representation, is a representation format
+and library of compiler utilities that sits between the model representation
+and low-level compilers/executors that generate hardware-specific code.
+
+MLIR is, at its heart, a flexible infrastructure for modern optimizing
+compilers. This means it consists of a specification for intermediate
+representations (IR) and a code toolkit to perform transformations on that
+representation. (In compiler parlance, as you move from higher-level
+representations to lower-level representations, these transformations can be
+called “lowerings”)
+
+MLIR is highly influenced by [LLVM](https://llvm.org/) and unabashedly reuses
+many great ideas from it. It has a flexible type system, and allows
+representing, analyzing and transforming graphs combining multiple levels of
+abstraction in the same compilation unit. These abstractions include TensorFlow
+operations, nested polyhedral loop regions, and even LLVM instructions and fixed
+hardware operations and types.
+
+We expect MLIR to be of interest to many groups, including:
+
+*   Compiler researchers and implementers looking to optimize performance and
+    memory consumption of machine learning models
+*   Hardware makers looking for a way to connect their hardware to TensorFlow,
+    such as TPUs, portable neural hardware in phones, and other custom ASICs
+*   People writing language bindings that want to take advantage of optimizing
+    compilers and hardware acceleration.
+
+The TensorFlow ecosystem contains a number of compilers and optimizers that
+operate at multiple levels of the software and hardware stack. We expect the
+gradual adoption of MLIR to simplify every aspect of this stack.
+
+<img alt="MLIR overview diagram" src="./images/mlir-infra.svg"/>
@ -48,10 +48,11 @@ def _run_lit_test(name, data, size, tags, driver, features):
         " the driver parameter when running this test. If you require" +
         " custom driver support, please file an issue to request it.")
 
+    # Disable tests on windows for now, to enable testing rest of all xla and mlir.
     native.py_test(
         name = name,
         srcs = ["@llvm-project//llvm:lit"],
-        tags = tags,
+        tags = tags + ["no_windows"],
         args = [
             "tensorflow/compiler/mlir/" + paths.basename(data[-1]) + " --config-prefix=runlit -v",
         ] + features,
@ -208,6 +208,7 @@ cc_library(
|
|||||||
"ir/tfl_ops.h.inc",
|
"ir/tfl_ops.h.inc",
|
||||||
"ir/tfl_ops_interface.cc.inc",
|
"ir/tfl_ops_interface.cc.inc",
|
||||||
"ir/tfl_ops_interface.h.inc",
|
"ir/tfl_ops_interface.h.inc",
|
||||||
|
"runtime_verifiers.inc",
|
||||||
"utils/attribute_utils.cc",
|
"utils/attribute_utils.cc",
|
||||||
],
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
@ -231,6 +232,7 @@ cc_library(
|
|||||||
# TODO(jpienaar): Move this out after splitting out LoopLikeOpInterface.
|
# TODO(jpienaar): Move this out after splitting out LoopLikeOpInterface.
|
||||||
"@llvm-project//mlir:Transforms",
|
"@llvm-project//mlir:Transforms",
|
||||||
"//tensorflow/compiler/mlir/tensorflow",
|
"//tensorflow/compiler/mlir/tensorflow",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||||
"//tensorflow/lite/schema:schema_fbs",
|
"//tensorflow/lite/schema:schema_fbs",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
@ -302,12 +304,14 @@ cc_library(
|
|||||||
"transforms/optimize_functional_ops.cc",
|
"transforms/optimize_functional_ops.cc",
|
||||||
"transforms/prepare_composite_functions_tf.cc",
|
"transforms/prepare_composite_functions_tf.cc",
|
||||||
"transforms/prepare_tf.cc",
|
"transforms/prepare_tf.cc",
|
||||||
|
"transforms/runtime_type_verify.cc",
|
||||||
"transforms/split_merged_operands.cc",
|
"transforms/split_merged_operands.cc",
|
||||||
"transforms/trim_functions_tf.cc",
|
"transforms/trim_functions_tf.cc",
|
||||||
"transforms/unroll_batch_matmul.cc",
|
"transforms/unroll_batch_matmul.cc",
|
||||||
"transforms/while_loop_outline.cc",
|
"transforms/while_loop_outline.cc",
|
||||||
],
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
|
"ir/tfl_ops_interface.h.inc",
|
||||||
"transforms/dilated_conv.h",
|
"transforms/dilated_conv.h",
|
||||||
"transforms/passes.h",
|
"transforms/passes.h",
|
||||||
"transforms/unroll_batch_matmul.h",
|
"transforms/unroll_batch_matmul.h",
|
||||||
@ -323,6 +327,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/mlir/tensorflow",
|
"//tensorflow/compiler/mlir/tensorflow",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
|
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
|
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||||
"//tensorflow/compiler/xla:status",
|
"//tensorflow/compiler/xla:status",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
@ -459,9 +464,9 @@ cc_library(
|
|||||||
)
|
)
|
||||||
|
|
||||||
tf_native_cc_binary(
|
tf_native_cc_binary(
|
||||||
name = "operator-converter-gen",
|
name = "converter-gen",
|
||||||
srcs = [
|
srcs = [
|
||||||
"operator_converter_gen.cc",
|
"converter_gen.cc",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"@llvm-project//llvm:support",
|
"@llvm-project//llvm:support",
|
||||||
@ -471,14 +476,18 @@ tf_native_cc_binary(
|
|||||||
)
|
)
|
||||||
|
|
||||||
gentbl(
|
gentbl(
|
||||||
name = "operator_converter_inc",
|
name = "converter_inc",
|
||||||
tbl_outs = [
|
tbl_outs = [
|
||||||
(
|
(
|
||||||
"", # This driver has no options.
|
"--gen-operator-converters",
|
||||||
"operator_converters.inc",
|
"operator_converters.inc",
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
"--gen-runtime-verifiers",
|
||||||
|
"runtime_verifiers.inc",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
tblgen = ":operator-converter-gen",
|
tblgen = ":converter-gen",
|
||||||
td_file = "ir/tfl_ops.td",
|
td_file = "ir/tfl_ops.td",
|
||||||
td_srcs = [
|
td_srcs = [
|
||||||
":tensorflow_lite_ops_td_files",
|
":tensorflow_lite_ops_td_files",
|
||||||
@ -561,6 +570,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
|
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
|
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
@ -581,8 +591,6 @@ cc_library(
|
|||||||
"@llvm-project//llvm:support",
|
"@llvm-project//llvm:support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:QuantOps",
|
"@llvm-project//mlir:QuantOps",
|
||||||
"@llvm-project//mlir:QuantOpsDialectRegistration",
|
|
||||||
"@llvm-project//mlir:StandardDialectRegistration",
|
|
||||||
"@llvm-project//mlir:StandardOps",
|
"@llvm-project//mlir:StandardOps",
|
||||||
"@llvm-project//mlir:Support",
|
"@llvm-project//mlir:Support",
|
||||||
"@llvm-project//mlir:Translation",
|
"@llvm-project//mlir:Translation",
|
||||||
@ -594,6 +602,7 @@ tf_cc_binary(
|
|||||||
name = "flatbuffer_translate",
|
name = "flatbuffer_translate",
|
||||||
deps = [
|
deps = [
|
||||||
":flatbuffer_translate_lib",
|
":flatbuffer_translate_lib",
|
||||||
|
"@llvm-project//mlir:LoopOpsTransforms",
|
||||||
"@llvm-project//mlir:MlirTranslateMain",
|
"@llvm-project//mlir:MlirTranslateMain",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -643,12 +652,14 @@ tf_cc_binary(
|
|||||||
"//tensorflow/compiler/mlir:init_mlir",
|
"//tensorflow/compiler/mlir:init_mlir",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
|
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core/platform:errors",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite/schema:schema_fbs",
|
"//tensorflow/lite/schema:schema_fbs",
|
||||||
"//tensorflow/stream_executor/lib",
|
"//tensorflow/stream_executor/lib",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@llvm-project//llvm:support",
|
"@llvm-project//llvm:support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
|
"@llvm-project//mlir:Pass",
|
||||||
"@llvm-project//mlir:Support",
|
"@llvm-project//mlir:Support",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -696,7 +707,6 @@ cc_library(
|
|||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
"@llvm-project//mlir:QuantOps",
|
"@llvm-project//mlir:QuantOps",
|
||||||
"@llvm-project//mlir:QuantOpsDialectRegistration",
|
|
||||||
"@llvm-project//mlir:Transforms",
|
"@llvm-project//mlir:Transforms",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -730,7 +740,6 @@ cc_library(
|
|||||||
"@llvm-project//mlir:Parser",
|
"@llvm-project//mlir:Parser",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
"@llvm-project//mlir:QuantOps",
|
"@llvm-project//mlir:QuantOps",
|
||||||
"@llvm-project//mlir:QuantOpsDialectRegistration",
|
|
||||||
"@llvm-project//mlir:Support",
|
"@llvm-project//mlir:Support",
|
||||||
"@llvm-project//mlir:Transforms",
|
"@llvm-project//mlir:Transforms",
|
||||||
],
|
],
|
||||||

@@ -35,7 +35,8 @@ struct PassConfig {
         skip_control_dialect(false),
         form_clusters(false),
         inline_functions(true),
-        unfold_batch_matmul(true) {}
+        unfold_batch_matmul(true),
+        legalize_tf_while(true) {}
 
   // If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
   // added, which produces TF Lite ops.
@@ -61,6 +62,10 @@ struct PassConfig {
   // if `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set
   // of tfl.fully_connected ops.
   bool unfold_batch_matmul;
+  // Whether to legalize TF While to TFL While.
+  // Note: This is staging step and will be removed.
+  // TODO(b/137395003): Remove post switching legalization.
+  bool legalize_tf_while;
 };
 
 }  // namespace TFL

@@ -28,6 +28,9 @@ limitations under the License.
 #include "llvm/TableGen/Record.h"
 #include "llvm/TableGen/TableGenBackend.h"
 #include "mlir/TableGen/Attribute.h"  // TF:llvm-project
+#include "mlir/TableGen/Format.h"  // TF:llvm-project
+#include "mlir/TableGen/Operator.h"  // TF:llvm-project
+#include "mlir/TableGen/Predicate.h"  // TF:llvm-project
 
 using llvm::DefInit;
 using llvm::dyn_cast;
@@ -41,6 +44,19 @@ using llvm::SmallVector;
 using llvm::StringInit;
 using llvm::StringRef;
 
+enum ActionType {
+  OpConv,
+  RuntimeVerify,
+};
+
+// NOLINTNEXTLINE
+llvm::cl::opt<ActionType> action(
+    llvm::cl::desc("Action to perform:"),
+    llvm::cl::values(clEnumValN(OpConv, "gen-operator-converters",
+                                "Generate operator converters"),
+                     clEnumValN(RuntimeVerify, "gen-runtime-verifiers",
+                                "Generate TFLite runtime verifiers")));
+
 // Returns the associated option name for the given op definition.
 static inline std::string GetOperatorOptionName(const Record &def) {
   assert(def.getName().startswith("TFL_") && "unexpected op prefix");
@@ -342,8 +358,101 @@ static bool OperatorWritersMain(raw_ostream &os, RecordKeeper &records) {
   return false;
 }
 
+static void GenOperandResultVerifier(raw_ostream &os,
+                                     llvm::ArrayRef<llvm::Init *> values,
+                                     StringRef valueKind) {
+  mlir::tblgen::FmtContext fctx;
+
+  bool first = true;
+  for (auto static_value : llvm::enumerate(values)) {
+    auto *definit = llvm::cast<llvm::DefInit>(static_value.value());
+    auto *val = definit->getDef()->getValue("tflRuntimeTypePredicate");
+    if (!val) continue;
+
+    // Create code block on first type to verify.
+    if (first) {
+      os << "  {\n";
+      os << "    unsigned index = " << static_value.index() << ";\n";
+      first = false;
+    }
+
+    mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
+    auto desc =
+        definit->getDef()->getValueAsString("tflRuntimeTypeDescription");
+
+    // Emit a loop to check all the dynamic values in the pack.
+    os << formatv("    for (Value v : top.getODS{0}{1}s({2})) {{\n",
+                  // Capitalize the first letter to match the function name
+                  valueKind.substr(0, 1).upper(), valueKind.substr(1),
+                  static_value.index());
+
+    os << "      (void)v;\n"
+       << "      if (!("
+       << tgfmt(pred.getCondition(), &fctx.withSelf("v.getType()")) << ")) {\n"
+       << formatv(
+              "        return op->emitOpError(\"{0} #\") << index "
+              "<< \" must be {1}, but got \" << v.getType();\n",
+              valueKind, desc)
+       << "      }\n"  // if
+       << "      ++index;\n"
+       << "    }\n";  // for
+  }
+
+  // Emit closing brace if needed.
+  if (!first) os << "  }\n";
+}
+
+// NOLINTNEXTLINE
+static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
+  emitSourceFileHeader("MLIR TFLite Runtime Verifiers", os);
+
+  // Retrieve all the definitions derived from TFL_Op and sort by record name.
+  std::vector<Record *> defs = records.getAllDerivedDefinitions("Op");
+  llvm::sort(defs, LessRecord());
+
+  // Iterate through all the ops defined.
+  for (const auto *def : defs) {
+    mlir::tblgen::Operator op(*def);
+    if (!op.getTrait("TflRuntimeVerifyOpInterface::Trait")) continue;
+
+    mlir::tblgen::FmtContext verify_ctx;
+    os << "::mlir::LogicalResult " << op.getCppClassName()
+       << "::VerifyTflRuntimeTypes(::mlir::Operation *op) {\n";
+    os << "  auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
+    verify_ctx.withOp("top");
+
+    for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
+      for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
+        auto &value = op.getOperand(i);
+        // Skip from from first variadic operands for now. Else getOperand index
+        // used below doesn't match.
+        if (value.isVariadic()) break;
+        if (!value.name.empty())
+          verify_ctx.addSubst(value.name, formatv("op->getOperand({0})", i));
+      }
+      for (int i = 0, e = op.getNumResults(); i < e; ++i) {
+        auto &value = op.getResult(i);
+        // Skip from from first variadic results for now. Else getResult index
+        // used below doesn't match.
+        if (value.isVariadic()) break;
+        if (!value.name.empty())
+          verify_ctx.addSubst(value.name, formatv("op->getResult({0})", i));
+      }
+    }
+    GenOperandResultVerifier(os, def->getValueAsDag("arguments")->getArgs(),
+                             "operand");
+    GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(),
+                             "result");
+    os << "  return mlir::success();\n}\n";
+  }
+
+  return false;
+}
+
 int main(int argc, char **argv) {
   llvm::InitLLVM y(argc, argv);
   llvm::cl::ParseCommandLineOptions(argc, argv);
+  if (action == ActionType::OpConv)
     return TableGenMain(argv[0], &OperatorWritersMain);
+  return TableGenMain(argv[0], &RuntimeVerifierWriterMain);
 }
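
Note: the gen-runtime-verifiers backend above stamps out one VerifyTflRuntimeTypes function per op that declares the TflRuntimeVerifyOpInterface trait. A rough sketch of the kind of C++ it emits, for a hypothetical TFL_FooOp whose first operand pack must hold 32-bit float tensors; the op name and the concrete type predicate are illustrative assumptions, not taken from this change:

    // Sketch of generated output (hypothetical FooOp, hypothetical predicate).
    ::mlir::LogicalResult FooOp::VerifyTflRuntimeTypes(::mlir::Operation *op) {
      auto top = cast<FooOp>(op); (void)top;
      {
        unsigned index = 0;
        // One loop per operand/result pack that carries a runtime predicate.
        for (Value v : top.getODSOperands(0)) {
          (void)v;
          if (!(v.getType().cast<ShapedType>().getElementType().isF32())) {
            return op->emitOpError("operand #")
                   << index << " must be 32-bit float tensor, but got "
                   << v.getType();
          }
          ++index;
        }
      }
      return mlir::success();
    }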
@@ -46,7 +46,7 @@ limitations under the License.
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
 #include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:llvm-project
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Diagnostics.h"  // TF:llvm-project
@@ -76,6 +76,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/lite/model.h"
 #include "tensorflow/lite/schema/schema_generated.h"
 
@@ -124,6 +125,20 @@ static opt<bool, true> experimental_prune_unreachable_nodes_unconditionally_flg(
     llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally),
     llvm::cl::init(false));
 
+// NOLINTNEXTLINE
+static opt<std::string> input_arrays_flag(
+    "input-arrays",
+    llvm::cl::desc(
+        "List of input tensors, if different from the default inputs"),
+    llvm::cl::init(""));
+
+// NOLINTNEXTLINE
+static opt<std::string> output_arrays_flag(
+    "output-arrays",
+    llvm::cl::desc(
+        "List of output tensors, if different from the default outputs"),
+    llvm::cl::init(""));
+
 namespace {
 bool IsScalar(const TensorT& tensor) {
   // TODO(b/138222071) We can't distinguish scalars and unranked tensors
@@ -590,6 +605,11 @@ StatusOr<Operation*> ConvertOp(
     op_state.addTypes({type});
   }
 
+  if (op_name == "tfl.lstm") {
+    // TODO(b/147587779): add the right region if region is empty.
+    op_state.addRegion();
+  }
+
   llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
   if (IsCustomOp(op_name)) {
     auto status = mlir::CustomOptionsToAttributes(op_name, op.custom_options,
@@ -610,43 +630,30 @@ StatusOr<Operation*> ConvertOp(
   return builder.createOperation(op_state);
 }
 
-// Returns the output tensor indices for the given subgraph. If
-// ordered_output_arrays is provided, then return the tensor indices in
-// ordered_output_arrays.
-StatusOr<llvm::SmallVector<int32_t, 4>> GetOutputTensorIndices(
-    const tflite::SubGraphT& subgraph, Location base_loc,
-    const std::vector<std::string>& ordered_output_arrays) {
-  if (ordered_output_arrays.empty()) {
-    return llvm::SmallVector<int32_t, 4>(subgraph.outputs.begin(),
-                                         subgraph.outputs.end());
+// Returns indices of the given tensors in the subgraph. Returns error if a
+// tensor name cannot be found in the subgraph.
+StatusOr<std::vector<int>> GetTensorIndices(
+    const tflite::SubGraphT& subgraph,
+    const std::vector<std::string>& tensor_names) {
+  absl::flat_hash_map<std::string, int> name_to_index;
+  for (auto index_and_tensor : llvm::enumerate(subgraph.tensors)) {
+    name_to_index[index_and_tensor.value()->name] = index_and_tensor.index();
   }
 
-  llvm::SmallVector<int32_t, 4> outputs;
-  outputs.resize(ordered_output_arrays.size());
-  absl::flat_hash_map<std::string, int> output_order_map;
-  for (auto output : llvm::enumerate(ordered_output_arrays)) {
-    output_order_map[output.value()] = output.index();
+  std::vector<int> indices;
+  indices.reserve(tensor_names.size());
+
+  for (const auto& name : tensor_names) {
+    auto found = name_to_index.find(name);
+    if (found != name_to_index.end()) {
+      indices.push_back(found->second);
+    } else {
+      return errors::InvalidArgument("could not find tensor in subgraph: ",
+                                     name);
+    }
   }
 
-  int tensor_index = 0;
-  int found_output_tensors = 0;
-  for (const auto& tensor : subgraph.tensors) {
-    auto found = output_order_map.find(tensor->name);
-    if (found != output_order_map.end()) {
-      const int output_index = found->second;
-      outputs[output_index] = tensor_index;
-      ++found_output_tensors;
-    }
-    ++tensor_index;
-  }
-
-  if (found_output_tensors != ordered_output_arrays.size()) {
-    auto err = errors::InvalidArgument(
-        "cannot find all nodes in ordered_output_arrays");
-    return emitError(base_loc, err.ToString()), err;
-  }
-
-  return outputs;
+  return indices;
 }
 
 // Given a list of tensor indices, returns a string of concatenated tensor names
@@ -661,17 +668,20 @@ mlir::NamedAttribute BuildTFEntryFunctionAttribute(
       name, builder->getStringAttr(llvm::join(tensor_names, ",")));
 }
 
-// Given a list of output indices, traverses the subgraph and returns the set of
-// ops that are ancestors of the output tensors.
+// Traverses the subgraph from output_indices to input_indices and returns the
+// set of ops that are visited.
 StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
-    const tflite::SubGraphT& subgraph, ArrayRef<int32_t> output_indices) {
+    const tflite::SubGraphT& subgraph, ArrayRef<int32_t> input_indices,
+    ArrayRef<int32_t> output_indices) {
   // Create a map from tensor index to defining op.
   absl::flat_hash_map<int32_t, const tflite::OperatorT*> defining_op;
   for (const auto& op : subgraph.operators) {
     for (int32_t output : op->outputs) {
+      if (!llvm::is_contained(input_indices, output)) {
       defining_op[output] = op.get();
+      }
     }
   }
 
   std::vector<const tflite::OperatorT*> queue;
   for (int32_t output : output_indices) {
@@ -718,18 +728,40 @@ StatusOr<FuncOp> ConvertSubgraph(
     const std::vector<std::string>& op_names,
     const std::vector<std::string>& func_names,
     const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
-    Location base_loc, Builder builder,
-    const std::vector<std::string>& ordered_output_arrays, bool is_entry_point,
+    Location base_loc, Builder builder, bool is_entry_point,
     bool use_external_constant,
+    const std::vector<std::string>& ordered_input_arrays,
+    const std::vector<std::string>& ordered_output_arrays,
     bool experimental_prune_unreachable_nodes_unconditionally) {
   llvm::SmallVector<mlir::Type, 2> ret_types;
   llvm::SmallVector<mlir::Type, 4> input_types;
 
   auto func_loc = mlir::NameLoc::get(builder.getIdentifier(name), base_loc);
 
-  // Construct function type
-  for (auto input : subgraph.inputs) {
-    auto& tensor = *subgraph.tensors.at(input);
+  std::vector<int> func_inputs = subgraph.inputs;
+  if (is_entry_point && !ordered_input_arrays.empty()) {
+    if (!experimental_prune_unreachable_nodes_unconditionally) {
+      // TODO(b/149922113): Resolve input-arrays/pruning flags interaction.
+      return errors::InvalidArgument(
+          "input-arrays should be used with experimental pruning flag");
+    }
+    TF_ASSIGN_OR_RETURN(func_inputs,
+                        GetTensorIndices(subgraph, ordered_input_arrays));
+  }
+
+  // Add state variables to inputs.
+  absl::flat_hash_set<int32_t> input_index_set(func_inputs.begin(),
+                                               func_inputs.end());
+  for (int i = 0; i < subgraph.tensors.size(); i++) {
+    auto& tensor = *subgraph.tensors.at(i);
+    if (tensor.is_variable && !input_index_set.contains(i)) {
+      func_inputs.emplace_back(i);
+      input_index_set.insert(i);
+    }
+  }
+
+  for (auto input_or_variable : func_inputs) {
+    auto& tensor = *subgraph.tensors.at(input_or_variable);
     // TODO(b/138222071) Graph inputs must have static shape per the exporter,
     // but we cannot differentiate scalars from unranked tensors.
     // Here we reverse the default assumption that shape = [] means unranked.
@@ -753,9 +785,11 @@ StatusOr<FuncOp> ConvertSubgraph(
     }
   }
 
-  TF_ASSIGN_OR_RETURN(
-      auto func_outputs,
-      GetOutputTensorIndices(subgraph, base_loc, ordered_output_arrays));
+  std::vector<int> func_outputs = subgraph.outputs;
+  if (is_entry_point && !ordered_output_arrays.empty()) {
+    TF_ASSIGN_OR_RETURN(func_outputs,
+                        GetTensorIndices(subgraph, ordered_output_arrays));
+  }
 
   for (auto output : func_outputs) {
     bool is_constant = !is_op_output[output];
@@ -782,8 +816,8 @@ StatusOr<FuncOp> ConvertSubgraph(
   Value maybe_optional_arg_marker = nullptr;
 
   // Get or construct MLIR values for each input
-  for (int i = 0, e = subgraph.inputs.size(); i < e; i++) {
-    auto input_tensor = subgraph.inputs[i];
+  for (int i = 0, e = func_inputs.size(); i < e; i++) {
+    auto input_tensor = func_inputs[i];
     const auto& tensor = *subgraph.tensors.at(input_tensor);
     auto loc = TensorLoc(tensor, builder, base_loc);
     if (vals_map[input_tensor]) {
@@ -806,9 +840,9 @@ StatusOr<FuncOp> ConvertSubgraph(
   // Set tf.entry_function attribute
   if (is_entry_point) {
     llvm::SmallVector<mlir::NamedAttribute, 2> attributes;
-    if (!subgraph.inputs.empty()) {
+    if (!func_inputs.empty()) {
       attributes.push_back(BuildTFEntryFunctionAttribute(
-          subgraph, &builder, "inputs", subgraph.inputs));
+          subgraph, &builder, "inputs", func_inputs));
     }
     if (!func_outputs.empty()) {
       attributes.push_back(BuildTFEntryFunctionAttribute(
@@ -820,7 +854,7 @@ StatusOr<FuncOp> ConvertSubgraph(
   absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
   if (experimental_prune_unreachable_nodes_unconditionally) {
     TF_ASSIGN_OR_RETURN(pruned_subgraph_ops,
-                        PruneSubgraph(subgraph, func_outputs));
+                        PruneSubgraph(subgraph, func_inputs, func_outputs));
   }
 
   // Construct MLIR operators from TFLite operators
@@ -931,8 +965,9 @@ std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
 
 OwningModuleRef tflite::FlatBufferToMlir(
     absl::string_view buffer, MLIRContext* context, Location base_loc,
-    const std::vector<std::string>& ordered_output_arrays,
     bool use_external_constant,
+    const std::vector<std::string>& ordered_input_arrays,
+    const std::vector<std::string>& ordered_output_arrays,
     bool experimental_prune_unreachable_nodes_unconditionally) {
   auto model_ptr =
       FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
@@ -971,33 +1006,25 @@ OwningModuleRef tflite::FlatBufferToMlir(
         builder.getStringAttr(model->description));
   }
 
-  if (!ordered_output_arrays.empty() && model->subgraphs.size() > 1) {
-    // TODO(b/141485522): support more than one subgraph.
-    return emitError(base_loc,
-                     "ordered_output_arrays does not support more than one "
-                     "subgraph yet"),
-           nullptr;
-  }
-
   for (auto e : llvm::enumerate(model->subgraphs)) {
     auto& subgraph = e.value();
     std::string name = SubgraphName(e.index(), *subgraph);
     auto func_or_error = ConvertSubgraph(
         *subgraph, name, operator_names, func_names, model->buffers, base_loc,
-        // Only the entry point needs pseudo_input_ops
+        builder,
         // TODO(b/131175224,b/132239787) Support multiple entry points
-        builder, ordered_output_arrays,
         /*is_entry_point=*/e.index() == 0,
-        /*use_external_constant=*/use_external_constant,
+        /*use_external_constant=*/use_external_constant, ordered_input_arrays,
+        ordered_output_arrays,
         experimental_prune_unreachable_nodes_unconditionally);
     if (!func_or_error.ok()) {
       return emitError(base_loc, "could not translate function ")
-                 << subgraph->name,
+                 << subgraph->name << ": "
+                 << func_or_error.status().error_message(),
             nullptr;
     }
     module.push_back(func_or_error.ConsumeValueOrDie());
   }
-  // TFLite subgraphs do not necessarily have names,
 
   return OwningModuleRef(module);
 }
@@ -1012,17 +1039,24 @@ static OwningModuleRef FlatBufferFileToMlirTrans(
   auto loc =
       mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context);
 
-  // Parses output_arrays_order from command line option.
+  // Parses input/output names from command line options.
+  std::vector<std::string> inputs;
   std::vector<std::string> outputs;
-  if (!tensorflow::ParseOutputArrayInfo(output_arrays_string, &outputs).ok()) {
+  // Use output parser since we only have tensor names.
+  if (!tensorflow::ParseOutputArrayInfo(input_arrays_flag, &inputs).ok()) {
+    return emitError(loc, "parsing input array info failed ")
+               << input_arrays_flag,
+           nullptr;
+  }
+  if (!tensorflow::ParseOutputArrayInfo(output_arrays_flag, &outputs).ok()) {
     return emitError(loc, "parsing output array info failed ")
-               << output_arrays_string,
+               << output_arrays_flag,
            nullptr;
   }
 
   return tflite::FlatBufferToMlir(
       absl::string_view(input->getBufferStart(), input->getBufferSize()),
-      context, loc, outputs, use_external_constant,
+      context, loc, use_external_constant, inputs, outputs,
       experimental_prune_unreachable_nodes_unconditionally);
 }
 
@@ -35,9 +35,9 @@ namespace tflite {
 // are not ancestors of the output nodes will be pruned.
 mlir::OwningModuleRef FlatBufferToMlir(
     absl::string_view buffer, mlir::MLIRContext* context,
-    mlir::Location base_loc,
-    const std::vector<std::string>& ordered_output_arrays,
-    bool use_external_constant = false,
+    mlir::Location base_loc, bool use_external_constant = false,
+    const std::vector<std::string>& ordered_input_arrays = {},
+    const std::vector<std::string>& ordered_output_arrays = {},
     bool experimental_prune_unreachable_nodes_unconditionally = false);
 }  // namespace tflite
 
@@ -42,7 +42,7 @@ limitations under the License.
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/ToolOutputFile.h"
 #include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:llvm-project
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Function.h"  // TF:llvm-project
 #include "mlir/IR/Location.h"  // TF:llvm-project
@@ -122,8 +122,6 @@ bool emit_custom_ops;
 bool emit_select_tf_ops;
 bool lower_tensor_list_ops;
 bool strip_debug_info;
-// NOLINTNEXTLINE
-std::string output_arrays_string;
 
 // NOLINTNEXTLINE
 static opt<bool, true> emit_builtin_tflite_ops_flag(
@@ -156,11 +154,6 @@ static opt<bool, true> strip_debug_info_flag(
     "strip-debug-info", llvm::cl::desc("Strip debug info during export"),
     llvm::cl::location(strip_debug_info), llvm::cl::init(false));
 
-// NOLINTNEXTLINE
-static opt<std::string, true> output_arrays_flag(
-    "output-arrays", llvm::cl::desc("List of output tensors"),
-    llvm::cl::location(output_arrays_string), llvm::cl::init(""));
-
 ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex";
 
 // Use initial buffer size in flatbuffer builder to be same as the initial size
@@ -172,7 +165,7 @@ constexpr size_t kInitialBufferSize = 10240;
 // `isSigned` is set to false for other types.
 static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
                                                   bool is_signed = true) {
-  if (!is_signed && type.isInteger(8)) {
+  if (!is_signed && type.isSignlessInteger(8)) {
     return tflite::TensorType_UINT8;
   }
   if (!is_signed) {

@@ -27,7 +27,5 @@ extern bool emit_custom_ops;
 extern bool lower_tensor_list_ops;
 // The flag to control whether debug info gets stripped on export.
 extern bool strip_debug_info;
-// The flag to control the output array info of tflite graph.
-extern std::string output_arrays_string;
 
 #endif  // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_FLAGS_H_

@@ -71,4 +71,23 @@ def TFL_SparseOp : OpInterface<"SparseOpInterface"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// TFL runtime type verification of operand/result types.
+
+def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
+  let description = [{
+    Interface to verify TFLite runtime op verification.
+
+    This verifies that the converted TFLite ops has operand/result type
+    supported by the TFLite runtime.
+  }];
+
+  let methods = [
+    StaticInterfaceMethod<
+      [{Returns whether the op's operands/results are supported by runtime.}],
+      "LogicalResult", "VerifyTflRuntimeTypes", (ins "Operation*":$op)
+    >,
+  ];
+}
+
 #endif  // TFL_OP_INTERFACES

@@ -23,9 +23,10 @@ limitations under the License.
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/FormatVariadic.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Matchers.h"  // TF:llvm-project
@@ -36,6 +37,7 @@ limitations under the License.
 #include "mlir/Support/LLVM.h"  // TF:llvm-project
 #include "mlir/Support/LogicalResult.h"  // TF:llvm-project
 #include "mlir/Transforms/InliningUtils.h"  // TF:llvm-project
+#include "mlir/Transforms/RegionUtils.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 
 namespace mlir {
@@ -273,7 +275,7 @@ Attribute ConstFoldBinaryOp(
     return ConstFoldBinaryOp<FloatAttr>(result_type, operands[0], operands[1],
                                         float_calculate, is_commutative);
 
-  if (elemType.isa<IntegerType>())
+  if (elemType.isSignlessInteger())
     return ConstFoldBinaryOp<IntegerAttr>(result_type, operands[0], operands[1],
                                           int_calculate, is_commutative);
 
@@ -721,12 +723,11 @@ static LogicalResult Verify(PackOp op) {
   }
 
   // Make sure all inputs have the same shape and element type.
-  // TODO(rahulsp): Simplify once b/135032064 is fixed.
-  for (Value operand : op.getOperands()) {
-    auto other_type = operand.getType().cast<ShapedType>();
-    if (input_type != other_type)
+  // TODO(b/135032063): Simplify once fixed.
+  for (Type operand_type : op.getOperandTypes()) {
+    if (failed(mlir::verifyCompatibleShape(input_type, operand_type)))
       return op.emitOpError("operands should be of the same type. got ")
-             << input_type << ", " << other_type;
+             << input_type << ", " << operand_type;
   }
 
   return success();
@@ -1106,10 +1107,10 @@ static LogicalResult VerifySplitOpOutputTypes(
   for (int64_t i = 0; i < num_splits; ++i) {
     auto expected_output_type = get_expected_output_type(i);
     Value output = op->getResult(i);
-    auto output_type = output.getType().dyn_cast<RankedTensorType>();
-    if (!output_type || output_type != expected_output_type)
+    if (failed(verifyCompatibleShape(output.getType(), expected_output_type)))
       return op->emitOpError()
-             << "output #" << i << " should be " << expected_output_type;
+             << "output #" << i << " should be " << expected_output_type
+             << " instead got " << output.getType();
   }
   return success();
 }
@@ -1559,7 +1560,7 @@ OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
          limit_tensor.getType().getRank() == 0 &&
          delta_tensor.getType().getRank() == 0);
   Type elem_type = getType().cast<ShapedType>().getElementType();
-  if (elem_type.isa<IntegerType>()) {
+  if (elem_type.isSignlessInteger()) {
     auto start_attr = start_tensor.getValue<IntegerAttr>({});
     auto limit_attr = limit_tensor.getValue<IntegerAttr>({});
     auto delta_attr = delta_tensor.getValue<IntegerAttr>({});
@@ -1661,7 +1662,7 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
 
   // Do not try to fold elements attr of a quant type because
   // DenseElementsAttr does not support it.
-  if (!getType().cast<ShapedType>().getElementType().isIntOrFloat())
+  if (!getType().cast<ShapedType>().getElementType().isSignlessIntOrFloat())
     return nullptr;
 
   assert(perm_tensor.getType().getRank() == 1);
@@ -1741,47 +1742,108 @@ static LogicalResult Verify(TransposeOp op) {
   return success();
 }
 
+LogicalResult Verify(WhileOp op) {
+  if (op.getNumOperands() != op.getNumResults())
+    return op.emitOpError(llvm::formatv(
+        "number of operands does not match number of results ({0} != {1})",
+        op.getNumOperands(), op.getNumResults()));
+  // TODO(jpienaar): Verify operand, result & block arguments types
+  return success();
+}
+
 namespace {
-struct WhileResultOperandsMatch : public OpRewritePattern<WhileOp> {
+// Canonicalize While op so that results and operands match and external values
+// are via implicit capture rather than via block args.
+struct WhileResultOperandsMatchAndImplicitCapture
+    : public OpRewritePattern<WhileOp> {
   using OpRewritePattern<WhileOp>::OpRewritePattern;
 
   PatternMatchResult matchAndRewrite(WhileOp while_op,
                                      PatternRewriter &rewriter) const override {
-    auto size = while_op.body().front().getArguments().size();
-    Operation *op = while_op.getOperation();
-    auto old_size = op->getNumResults();
-    // No change needed as the number of operands match the number of results.
-    if (size == old_size) return matchFailure();
+    // Replace values simply passed through the body with extern values. The
+    // block arguments of body and while match and so the corresponding cond
+    // argument can be easily found.
+    bool unchanged = true;
+    auto &body_block = while_op.body().front();
+    auto &cond_block = while_op.cond().front();
+    auto &yield = *body_block.getTerminator();
+    for (auto ba : body_block.getArguments()) {
+      if (ba == yield.getOperand(ba.getArgNumber())) {
+        unchanged = false;
+        auto value = while_op.getOperand(ba.getArgNumber());
+        ba.replaceAllUsesWith(value);
+        cond_block.getArgument(ba.getArgNumber()).replaceAllUsesWith(value);
+      }
+    }
 
-    // Collect the new types by combining results of old op with additional
-    // operand results.
+    // The While ops operands and result types need to match
+    SmallVector<Value, 4> new_operands;
+    SmallVector<Value, 4> new_body_yield;
+    SmallVector<bool, 4> const_operand(while_op.getNumOperands(), false);
     llvm::SmallVector<Type, 4> types;
-    types.reserve(size);
-    for (auto type : while_op.getResultTypes()) types.push_back(type);
-    for (auto arg : while_op.body().front().getArguments().drop_front(old_size))
-      types.push_back(arg.getType());
-    // Collect operands.
-    llvm::SmallVector<Value, 8> operands;
-    operands.reserve(while_op.getNumOperands());
-    for (auto operand : while_op.getOperands()) operands.push_back(operand);
+    new_operands.reserve(while_op.getNumOperands());
+    new_body_yield.reserve(while_op.getNumOperands());
+    types.reserve(while_op.getNumOperands());
+
+    // Remove block arguments not used in either cond or body. This leaves the
+    // block arguments of body and cond matching still.
+    int arg_index = 0;
+    for (int while_index = 0, e = while_op.getNumOperands(); while_index < e;
+         ++while_index) {
+      auto value = while_op.getOperand(while_index);
+      if (body_block.getArgument(arg_index).use_empty() &&
+          cond_block.getArgument(arg_index).use_empty() &&
+          // This could be relaxed and casts inserted.
+          while_op.getResult(while_index).getType() == value.getType()) {
+        unchanged = false;
+        body_block.eraseArgument(arg_index);
+        cond_block.eraseArgument(arg_index);
+
+        // Mark operand as constant and replace all uses with input to while.
+        while_op.getResult(while_index).replaceAllUsesWith(value);
+        const_operand[while_index] = true;
+      } else {
+        new_operands.push_back(value);
+        new_body_yield.push_back(yield.getOperand(while_index));
+        auto type = while_op.getResult(while_index).getType();
+        types.push_back(type);
+        ++arg_index;
+      }
+    }
+
+    // Done if no values removed from blocks and operands & results match.
+    if (unchanged) return matchFailure();
 
     // Replace with new While with matching operands and results.
+    Operation *op = while_op.getOperation();
     Operation *new_op = rewriter.insert(
-        Operation::create(op->getLoc(), op->getName(), types, operands,
+        Operation::create(op->getLoc(), op->getName(), types, new_operands,
                           op->getAttrs(), {}, /*numRegions=*/2,
                           /*resizableOperandList=*/true));
 
     for (int i = 0; i < 2; ++i) new_op->getRegion(i).takeBody(op->getRegion(i));
-    rewriter.replaceOp(op,
-                       new_op->getResults().take_front(op->getNumResults()));
+    int new_index = 0;
+    for (int op_index = 0, e = op->getNumResults(); op_index < e; ++op_index) {
+      if (const_operand[op_index]) continue;
+      op->getResult(op_index).replaceAllUsesWith(new_op->getResult(new_index));
+      ++new_index;
+    }
+    rewriter.eraseOp(op);
+
+    Block &new_body_block = cast<WhileOp>(new_op).body().front();
+    rewriter.setInsertionPointToEnd(&new_body_block);
+    rewriter.replaceOpWithNewOp<YieldOp>(new_body_block.getTerminator(),
+                                         new_body_yield);
+
     return matchSuccess();
   }
 };
 
 }  // namespace
 
 void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
-  results.insert<WhileResultOperandsMatch>(context);
+  results.insert<WhileResultOperandsMatchAndImplicitCapture>(context);
 }
 
 Region &WhileOp::getLoopBody() { return body(); }
@@ -1809,6 +1871,7 @@ LogicalResult WhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) {
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.cc.inc"
 #define GET_OP_CLASSES
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
+#include "tensorflow/compiler/mlir/lite/runtime_verifiers.inc"
 
 Operation *TensorFlowLiteDialect::materializeConstant(OpBuilder &builder,
                                                        Attribute value,
File diff suppressed because it is too large
@@ -17,6 +17,7 @@ cc_library(
     ],
     deps = [
         "//tensorflow/compiler/mlir/lite:common",
+        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
         "//tensorflow/compiler/mlir/lite:tf_tfl_passes",
         "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer",
         "//tensorflow/compiler/mlir/tensorflow",
@ -27,6 +27,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
|
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
|
||||||
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
|
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
|
||||||
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
|
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
|
||||||
|
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
@ -62,6 +63,41 @@ const char kDetectionPostProcessOp[] =
|
|||||||
"'detections_per_class' type: 'int' default_value { i : 100 }} attr { "
|
"'detections_per_class' type: 'int' default_value { i : 100 }} attr { "
|
||||||
"name: 'use_regular_nms' type: 'bool' default_value { b : false }}";
|
"name: 'use_regular_nms' type: 'bool' default_value { b : false }}";
|
||||||
|
|
||||||
|
const char kUnidirectionalSequenceLstmOp[] =
|
||||||
|
"name: 'UnidirectionalSequenceLstm' input_arg: {name: 'Input' type: "
|
||||||
|
"DT_FLOAT} input_arg: { name: 'InputToInputWeights' type: DT_FLOAT } "
|
||||||
|
"input_arg: { name: 'InputToForgetWeights' type: DT_FLOAT } input_arg: { "
|
||||||
|
"name: 'InputToCellWeights' type: DT_FLOAT} input_arg: { name: "
|
||||||
|
"'InputToOutputWeights' type: DT_FLOAT } input_arg: { name: "
|
||||||
|
"'RecurrentToInputWeights' type: DT_FLOAT} input_arg: { name: "
|
||||||
|
"'RecurrentToForgetWeights' type: DT_FLOAT} input_arg: { name: "
|
||||||
|
"'RecurrentToCellWeights' type: DT_FLOAT } input_arg: { name: "
|
||||||
|
"'RecurrentToOutputWeights' type: DT_FLOAT } input_arg: { name: "
|
||||||
|
"'CellToInputWeights' type: DT_FLOAT} input_arg: { name: "
|
||||||
|
"'CellToForgetWeights' type: DT_FLOAT } input_arg: { name: "
|
||||||
|
"'CellToOutputWeights' type: DT_FLOAT } input_arg: { name: 'InputGateBias' "
|
||||||
|
"type: DT_FLOAT } input_arg: { name: 'ForgetGateBias' type: DT_FLOAT } "
|
||||||
|
"input_arg: { name: 'kCellGateBias' type: DT_FLOAT } input_arg: { name: "
|
||||||
|
"'OutputGateBias' type: DT_FLOAT } input_arg: { name: 'ProjectionWeights' "
|
||||||
|
"type: DT_FLOAT } input_arg: { name: 'ProjectionBias' type: DT_FLOAT } "
|
||||||
|
"input_arg: { name: 'InputActivationState' type: DT_FLOAT} input_arg: { "
|
||||||
|
"name: 'InputCellStateTensor' type: DT_FLOAT } "
|
||||||
|
"output_arg: { name: 'Concat' type: DT_FLOAT} "
|
||||||
|
"output_arg: { name: "
|
||||||
|
"'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: DT_FLOAT} "
|
||||||
|
"attr : { name: '_tflite_input_indices' type: 'list(int)'}";
|
||||||
|
|
||||||
|
const char kUnidirectionalSequenceRnnOp[] =
|
||||||
|
"name: 'UnidirectionalSequenceRnn' input_arg: {name: 'Input' type: "
|
||||||
|
"DT_FLOAT} input_arg: { name: 'Weights' type: DT_FLOAT } "
|
||||||
|
"input_arg: { name: 'RecurrentWeights' type: DT_FLOAT } input_arg: { "
|
||||||
|
"name: 'Bias' type: DT_FLOAT} "
|
||||||
|
"input_arg: { name: 'HiddenState' type: DT_FLOAT} "
|
||||||
|
"output_arg: { name: "
|
||||||
|
"'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: "
|
||||||
|
"DT_FLOAT} "
|
||||||
|
"attr : { name: '_tflite_input_indices' type: 'list(int)'}";
|
||||||
|
|
||||||
// Converts the toco::IODataType to tensorflow::DataType. Only contains the
|
// Converts the toco::IODataType to tensorflow::DataType. Only contains the
|
||||||
// conversion mapping for constants defined in TFLite Python API.
|
// conversion mapping for constants defined in TFLite Python API.
|
||||||
DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
|
DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
|
||||||
@ -259,6 +295,8 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
|||||||
std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
|
std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
|
||||||
toco_flags.custom_opdefs().end());
|
toco_flags.custom_opdefs().end());
|
||||||
extra_tf_opdefs.push_back(kDetectionPostProcessOp);
|
extra_tf_opdefs.push_back(kDetectionPostProcessOp);
|
||||||
|
extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp);
|
||||||
|
extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp);
|
||||||
TF_RETURN_IF_ERROR(RegisterCustomBuiltinOps(extra_tf_opdefs));
|
TF_RETURN_IF_ERROR(RegisterCustomBuiltinOps(extra_tf_opdefs));
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
@ -277,6 +315,11 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
|||||||
pass_config.lower_tensor_list_ops = true;
|
pass_config.lower_tensor_list_ops = true;
|
||||||
|
|
||||||
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
|
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
|
||||||
|
// Convert back to outlined while format for export back to flatbuffer.
|
||||||
|
if (pass_config.legalize_tf_while) {
|
||||||
|
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
|
||||||
|
}
|
||||||
|
pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
|
||||||
|
|
||||||
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
|
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
|
||||||
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
|
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
|
||||||
|
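
Note: a minimal sketch of how a caller might assemble this pipeline around the passes registered above. The module and pass-manager variable names and the error handling are assumptions for illustration; only the AddTFToTFLConversionPasses, CreateWhileOutlinePass, and CreateRuntimeTypeVerifyPass calls come from this change:

    // Sketch only: build the TF -> TFLite pipeline and append the new passes.
    mlir::PassManager pm(module.getContext());  // `module` assumed to exist.
    tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
    if (pass_config.legalize_tf_while) {
      // While ops were canonicalized to implicit-capture form; outline them
      // again so the flatbuffer exporter sees the region-based form.
      pm.addPass(mlir::TFL::CreateWhileOutlinePass());
    }
    // Check every converted op's operand/result types against the TFLite
    // runtime, using the TableGen-generated VerifyTflRuntimeTypes hooks.
    pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
    if (failed(pm.run(module))) {
      return errors::Internal("pass pipeline failed");  // illustrative handling
    }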
@@ -25,7 +25,7 @@ limitations under the License.
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/QuantOps/FakeQuantSupport.h"  // TF:llvm-project
 #include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/AffineExpr.h"  // TF:llvm-project
 #include "mlir/IR/AffineMap.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
@@ -61,11 +61,9 @@ TfLiteStatus QuantizeModel(
   std::string serialized_model(
       reinterpret_cast<const char*>(input_builder.GetBufferPointer()),
       input_builder.GetSize());
-  std::vector<std::string> output_arrays_order;
 
-  OwningModuleRef module =
-      tflite::FlatBufferToMlir(serialized_model, &context,
-                               UnknownLoc::get(&context), output_arrays_order);
+  OwningModuleRef module = tflite::FlatBufferToMlir(serialized_model, &context,
+                                                    UnknownLoc::get(&context));
   if (!module) {
     error_reporter->Report("Couldn't import flatbuffer to MLIR.");
     return kTfLiteError;
|
@ -26,7 +26,7 @@ limitations under the License.
|
|||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
|||||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
|
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
|
||||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||||
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
|
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||||
@ -150,7 +150,8 @@ struct QuantizationPattern : public RewritePattern {
|
|||||||
|
|
||||||
explicit QuantizationPattern(MLIRContext* context, bool enable_verify,
|
explicit QuantizationPattern(MLIRContext* context, bool enable_verify,
|
||||||
float error_tolerance, bool single_layer_verify)
|
float error_tolerance, bool single_layer_verify)
|
||||||
: RewritePattern(DQ::getOperationName(), 1, context),
|
// Set the score to a large number so it is always preferred.
|
||||||
|
: RewritePattern(DQ::getOperationName(), 300, context),
|
||||||
enable_verify(enable_verify),
|
enable_verify(enable_verify),
|
||||||
error_tolerance(error_tolerance),
|
error_tolerance(error_tolerance),
|
||||||
single_layer_verify(single_layer_verify) {}
|
single_layer_verify(single_layer_verify) {}
|
||||||
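Context for the benefit change above: the second argument of the RewritePattern base-class constructor is the pattern benefit, which the greedy rewrite driver uses to rank competing patterns, so raising it from 1 to 300 is what makes this pattern "always preferred". A rough standalone sketch of the same wiring (the pattern and op names here are made up, and the matchAndRewrite signature follows present-day MLIR, which may differ slightly from the tree this diff targets):

    #include "mlir/IR/PatternMatch.h"

    namespace {
    // Illustrative only: a pattern rooted at a hypothetical "my.dequantize" op
    // that registers itself with benefit 300 instead of the default benefit 1.
    struct AlwaysPreferredPattern : public mlir::RewritePattern {
      explicit AlwaysPreferredPattern(mlir::MLIRContext* context)
          : mlir::RewritePattern("my.dequantize", /*benefit=*/300, context) {}

      mlir::LogicalResult matchAndRewrite(
          mlir::Operation* op, mlir::PatternRewriter& rewriter) const override {
        // A real pattern would rewrite `op` here; this sketch only shows how
        // the benefit is passed to the base class.
        return mlir::failure();
      }
    };
    }  // namespace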
@ -190,7 +191,7 @@ struct QuantizationPattern : public RewritePattern {
       auto ele_type = operand.getType().cast<TensorType>().getElementType();
       if (auto op_inst = dyn_cast_or_null<DQ>(operand.getDefiningOp())) {
         inputs.push_back(op_inst.input());
-      } else if (ele_type.isa<IntegerType>()) {
+      } else if (ele_type.isSignlessInteger()) {
         // If the operand is an integer tensor, then it doesn't require the
         // DQ op in the pattern.
         inputs.push_back(operand);
@ -224,7 +225,7 @@ struct QuantizationPattern : public RewritePattern {
         auto user = llvm::cast<Q>(*result.user_begin());
         outputs_replaced.insert({user.output(), enumerated_result.index()});
         output_types.push_back(user.getType());
-      } else if (result_ele_type.template isa<IntegerType>()) {
+      } else if (result_ele_type.isSignlessInteger()) {
         // If the result is an integer tensor, then it doesn't require the
         // D op in the pattern.
         outputs_replaced.insert({result, enumerated_result.index()});
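Both hunks make the same adjustment: MLIR now distinguishes signless (i32), signed (si32) and unsigned (ui32) integer types, and isa<IntegerType>() matches all three, so the narrower isSignlessInteger() keeps the old behaviour of treating only plain integer tensors as pass-through. A minimal sketch of the distinction, assuming the standard mlir::Type helpers (the helper name is illustrative, not part of the pass):

    #include "mlir/IR/Types.h"

    // Accepts i8/i32/i64 style element types but rejects si32/ui32, which is
    // exactly the property the quantization pattern checks for above.
    bool IsPlainIntegerElementType(mlir::Type ele_type) {
      return ele_type.isSignlessInteger();
    }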
@ -26,7 +26,7 @@ limitations under the License.
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/MLIRContext.h"  // TF:llvm-project
 #include "mlir/IR/PatternMatch.h"  // TF:llvm-project
@ -48,11 +48,9 @@ TfLiteStatus SparsifyModel(const tflite::ModelT& input_model,
   std::string serialized_model(
       reinterpret_cast<const char*>(input_builder.GetBufferPointer()),
       input_builder.GetSize());
-  std::vector<std::string> output_arrays_order;

-  OwningModuleRef module =
-      tflite::FlatBufferToMlir(serialized_model, &context,
-                               UnknownLoc::get(&context), output_arrays_order);
+  OwningModuleRef module = tflite::FlatBufferToMlir(serialized_model, &context,
+                                                    UnknownLoc::get(&context));
   if (!module) {
     error_reporter->Report("Couldn't import flatbuffer to MLIR.");
     return kTfLiteError;
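SparsifyModel gets the same two-part update as QuantizeModel earlier in the diff: the StandardOps header moves under IR/, and the flatbuffer importer is called without an explicit output-array ordering, so the local output_arrays_order vector disappears. Based only on the hunks above, the resulting call shape at both sites is:

    // serialized_model, context and error_reporter are the surrounding locals
    // from the two functions shown in this diff.
    mlir::OwningModuleRef module = tflite::FlatBufferToMlir(
        serialized_model, &context, mlir::UnknownLoc::get(&context));
    if (!module) {
      error_reporter->Report("Couldn't import flatbuffer to MLIR.");
      return kTfLiteError;
    }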
@ -15,5 +15,6 @@ filegroup(
     data = [
         "//tensorflow/compiler/mlir:tf-opt",
         "@llvm-project//llvm:FileCheck",
+        "@llvm-project//llvm:not",
     ],
 )
@ -1,4 +1,4 @@
-// RUN: tf-opt %s -test-constant-fold | FileCheck %s --dump-input-on-failure
+// RUN: tf-opt %s -canonicalize | FileCheck %s --dump-input-on-failure

 // CHECK-LABEL: @add_float
 func @add_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) {
@ -10,9 +10,7 @@ glob_lit_tests(
     driver = "@llvm-project//mlir:run_lit.sh",
     test_file_exts = [
         "pbtxt",
-        # TODO(fengliuai): reenable these tests after the fused loc is
-        # supported in the diagnostic handler.
-        # "py",
+        # "py", TODO(b/150304798)
     ],
 )

@ -27,6 +27,20 @@ func @testDilatedConvWithNonZeroSTBPadding(%arg0: tensor<1x128x128x3xf32>, %arg1
 // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
 }

+func @testDilatedConvWithNonTrivialDilations(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
+  %cst = constant dense<[2, 2]> : tensor<2xi32>
+  %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
+  %1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", dilations = [1, 2, 2, 1], strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
+  %2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
+  return %2 : tensor<1x128x128x8xf32>
+
+// CHECK-LABEL: testDilatedConvWithNonTrivialDilations
+// CHECK: [[STB:%.*]] = "tf.SpaceToBatchND"
+// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"
+// CHECK-NEXT: [[RESULT:%.*]] = "tf.BatchToSpaceND"
+// CHECK-NEXT: return [[RESULT]]
+}
+
 func @testDilatedDepthWiseConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
   %cst = constant dense<[2, 2]> : tensor<2xi32>
   %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
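The new testDilatedConvWithNonTrivialDilations case above is a negative test: when the inner Conv2D already carries dilations other than one, the SpaceToBatchND / Conv2D / BatchToSpaceND triple must be left untouched, which is exactly what its CHECK lines require. A pattern implementing the fusion therefore needs a guard along these lines (a sketch only, not the actual pass code; the attribute name comes from the test itself):

    // True only if every entry of the Conv2D `dilations` attribute equals 1,
    // i.e. the op can still absorb the block shape as its dilation factors.
    bool HasTrivialDilations(mlir::Operation* conv) {
      auto dilations = conv->getAttrOfType<mlir::ArrayAttr>("dilations");
      if (!dilations) return true;  // a missing attribute defaults to all ones
      for (mlir::Attribute d : dilations)
        if (d.cast<mlir::IntegerAttr>().getInt() != 1) return false;
      return true;
    }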
@ -104,7 +118,7 @@ func @testDilatedDepthWiseConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1:

 func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
   %cst = constant dense<[2, 2]> : tensor<2xi32>
-  %cst_0 = constant dense<3> : tensor<i32>
+  %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
   %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
   %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
   %2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
@ -115,7 +129,7 @@ func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: ten

 // CHECK-LABEL: testDilatedConvWithExpandSqueeze1
 // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
-// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
+// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
 // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
 // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
 // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
@ -125,7 +139,7 @@ func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: ten

 func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
   %cst = constant dense<[2, 2]> : tensor<2xi32>
-  %cst_0 = constant dense<3> : tensor<i32>
+  %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
   %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
   %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
   %2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
@ -136,7 +150,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %

 // CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze1
 // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
-// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
+// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
 // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
 // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
 // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
@ -146,7 +160,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %

 func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> {
   %cst = constant dense<[2, 2]> : tensor<2xi32>
-  %cst_0 = constant dense<3> : tensor<i32>
+  %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
   %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
   %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32>
   %2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
@ -157,7 +171,7 @@ func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: ten

 // CHECK-LABEL: testDilatedConvWithExpandSqueeze2
 // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
-// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
+// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
 // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
 // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
 // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
@ -167,7 +181,7 @@ func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: ten

 func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> {
   %cst = constant dense<[2, 2]> : tensor<2xi32>
-  %cst_0 = constant dense<3> : tensor<i32>
+  %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
   %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
   %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32>
   %2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
@ -178,7 +192,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %

 // CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze2
 // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
-// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
+// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
 // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
 // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
 // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
@ -188,7 +202,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %

 func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
   %cst = constant dense<[2, 2]> : tensor<2xi32>
-  %cst_0 = constant dense<3> : tensor<i32>
+  %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
   %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
   %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
   %2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
@ -200,7 +214,7 @@ func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: ten

 // CHECK-LABEL: testDilatedConvWithExpandSqueeze3
 // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
-// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
+// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
 // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
 // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
 // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
@ -210,7 +224,7 @@ func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: ten

 func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
   %cst = constant dense<[2, 2]> : tensor<2xi32>
-  %cst_0 = constant dense<3> : tensor<i32>
+  %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
   %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
   %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
   %2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
@ -222,10 +236,29 @@ func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %

 // CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze3
 // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
-// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
+// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
 // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
 // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
 // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
 // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
 // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
 }
+
+func @testDilatedConvWithDifferentExpandSqueezeAxis(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128x1xf32> {
+  %cst = constant dense<[2, 2]> : tensor<2xi32>
+  %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
+  %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
+  %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
+  %2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
+  %3 = "tf.Squeeze"(%2) {squeeze_dims = [2]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64x1xf32>
+  %4 = "tf.BatchToSpaceND"(%3, %cst, %arg1) : (tensor<4x64x64x1xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x1xf32>
+  return %4 : tensor<1x128x128x1xf32>
+
+// CHECK-LABEL: testDilatedConvWithDifferentExpandSqueezeAxis
+// CHECK: [[STB:%.*]] = "tf.SpaceToBatchND"
+// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"
+// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"
+// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"
+// CHECK-NEXT: [[RESULT:%.*]] = "tf.BatchToSpaceND"
+// CHECK-NEXT: return [[RESULT]]
+}
@ -178,15 +178,20 @@ func @inputsAfterOutputs() {

 // -----

-// expected-error@+1 {{Found malformed ophint regions: missing inputs or outputs.}}
 module {
-func @extractOphintFailure() {
+func @extractOphintSame() {
   %0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x16x1xf32>
   %1 = call @AnotherFunc(%0) : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
   %2 = "tf.Sigmoid"(%1) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
   %3 = "tf.Mul"(%2, %1) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x16x1xf32>, tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
   %4 = "tf.Identity"(%3) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "d4b1eb00b81211e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation-d4b1eb00b81211e99426dc4a3e957995-0-None-None"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
   return

+// CHECK: [[VAL_0:%.*]] = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x16x1xf32>
+// CHECK: [[VAL_1:%.*]] = call @AnotherFunc([[VAL_0]]) : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
+// CHECK: [[VAL_2:%.*]] = "tf.Sigmoid"([[VAL_1]]) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
+// CHECK: [[VAL_3:%.*]] = "tf.Mul"([[VAL_2]], [[VAL_1]]) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x16x1xf32>, tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
+// CHECK: [[VAL_4:%.*]] = "tf.Identity"([[VAL_3]]) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "d4b1eb00b81211e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation-d4b1eb00b81211e99426dc4a3e957995-0-None-None"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
 }

 func @AnotherFunc(%arg0: tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> {
@ -0,0 +1,13 @@
+// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate -input-arrays=squared_difference --experimental-prune-unreachable-nodes-unconditionally --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
+// Tests -input-arrays flag.
+
+func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+  %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
+  %1 = "tfl.squared_difference"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("squared_difference")
+  %2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul")
+  return %2 : tensor<4xf32>
+
+// CHECK-LABEL: main
+// CHECK-NOT: tfl.squared_difference
+// CHECK: tfl.mul %[[CONST:.*]], %arg0
+}
@ -0,0 +1,15 @@
+// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
+// Ensure lstm roundtrip exactly
+
+func @main(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>, %arg4: tensor<4 x f32>, %arg5: tensor<4 x f32>, %arg6: tensor<4 x f32>, %arg7: tensor<4 x f32>, %arg8: tensor<4 x f32>, %arg9: tensor<4 x f32>, %arg10: tensor<4 x f32>, %arg11: tensor<4 x f32>, %arg12: tensor<4 x f32>, %arg13: tensor<4 x f32>, %arg14: tensor<4 x f32>, %arg15: tensor<4 x f32>, %arg16: tensor<4 x f32>, %arg17: tensor<4 x f32>, %arg18: tensor<4 x f32>, %arg19: tensor<4 x f32>, %arg20: tensor<4 x f32>, %arg21: tensor<4 x f32>) -> tensor<4 x f32> {
+  %cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
+  %cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
+  %24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  return %24 : tensor<4xf32>
+// CHECK-LABEL: main
+// seperate lines since there is no region for this op. third_party/tensorflow/compiler/mlir/lite/ir/tfl_ops.td: 3252
+// CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg22, %arg23, %arg18, %arg19, %arg20, %arg21) ( {
+// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[RES0]]
+
+}
@ -1,25 +1,31 @@
 // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s

 // Check to see if function references in while loops are preserved
-func @main(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
 // TODO(b/138222071) Expect first output to be a scalar
 // CHECK: %{{.*}}:2 = "tf.While"(%{{.*}}, %{{.*}}) {body = @body, cond = @cond, is_stateless = false} : (tensor<i32>, tensor<1xf32>) -> (tensor<*xi32>, tensor<1xf32>)
+func @main(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
   // While %arg0 is greater than zero, element wise add %arg1 with itself.
-  %0:2 = "tf.While"(%arg0, %arg1) {
-    cond = @cond, body = @body, is_stateless = false
-  } : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>)
+  %0:2 = "tfl.while"(%arg0, %arg1) ( {
+  ^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>): // no predecessors
+    %1 = call @cond(%arg2, %arg3) : (tensor<*xi32>, tensor<*xf32>) -> tensor<i1>
+    "tfl.yield"(%1) : (tensor<i1>) -> ()
+  }, {
+  ^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>): // no predecessors
+    %1:2 = call @body(%arg2, %arg3) : (tensor<*xi32>, tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>)
+    "tfl.yield"(%1#0, %1#1) : (tensor<*xi32>, tensor<*xf32>) -> ()
+  }) {is_stateless = false} : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>)
   return %0#1 : tensor<1xf32>
 }

 func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> {
-  %0 = "std.constant" () {value = dense<0> : tensor<i32>} : () -> tensor<i32> loc("Const")
-  %1 = "tfl.greater"(%arg0, %0) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
-  return %1 : tensor<i1>
+  %cst = constant dense<0> : tensor<i32> loc("Const")
+  %0 = "tfl.greater"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
+  return %0 : tensor<i1>
 }

 func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>) {
-  %0 = "std.constant" () {value = dense<1> : tensor<i32>} : () -> tensor<i32> loc("Const")
-  %1 = "tfl.sub"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
-  %2 = tfl.add %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
-  return %1, %2 : tensor<*xi32>, tensor<*xf32>
+  %cst = constant dense<1> : tensor<i32> loc("Const")
+  %0 = "tfl.sub"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
+  %1 = tfl.add %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
+  return %0, %1 : tensor<*xi32>, tensor<*xf32>
 }
@ -1,5 +1,6 @@
 // RUN: tf-opt --tfl-legalize-tf-while %s -o - | FileCheck %s --dump-input-on-failure
 // RUN: tf-opt --tfl-legalize-tf-while %s -o - --tfl-legalize-tf-while --inline --mlir-disable-inline-simplify | FileCheck %s --dump-input-on-failure --check-prefix=INLINE
+// RUN: tf-opt --tfl-legalize-tf-while %s -o - --tfl-legalize-tf-while --inline | FileCheck %s --dump-input-on-failure --check-prefix=CANON

 func @while_main(%arg0: tensor<?x256x256xf32>) -> (tensor<i32>, tensor<256x256xf32>, tensor<?x256x256xf32>) attributes {tf.entry_function = {inputs = "input", outputs = "Identity,Identity_1,Identity_2"}} {
   %cst = constant dense<1.000000e+00> : tensor<256x256xf32>
@ -51,3 +52,25 @@ func @while_cond_10_frozen0(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: t
 // INLINE: yield
 // INLINE: while_body
 // INLINE: while_cond
+
+// CANON-LABEL: func @while_main
+// CANON-SAME: ([[VAL_0:%.*]]: tensor<?x256x256xf32>)
+// CANON-SAME: (tensor<i32>, tensor<256x256xf32>, tensor<?x256x256xf32>)
+// CANON: [[VAL_1:%.*]] = constant dense<1.000000e+00> : tensor<256x256xf32>
+// CANON: [[VAL_2:%.*]] = constant dense<0> : tensor<i32>
+// CANON: [[VAL_3:%.*]] = constant dense<10> : tensor<i32>
+// CANON: [[VAL_4:%.*]] = constant dense<1> : tensor<i32>
+// CANON: [[VAL_5:%.*]] = "tf.Const"() {value = dense<2.560000e+02> : tensor<256x256xf32>} : () -> tensor<?x?xf32>
+// CANON: [[VAL_6:%.*]]:3 = "tfl.while"([[VAL_2]], [[VAL_2]], [[VAL_0]]) ( {
+// CANON: ^bb0([[VAL_7:%.*]]: tensor<*xi32>, [[VAL_8:%.*]]: tensor<*xi32>, [[VAL_9:%.*]]: tensor<*xf32>):
+// CANON: [[VAL_10:%.*]] = "tf.Less"([[VAL_8]], [[VAL_3]])
+// CANON: "tfl.yield"([[VAL_10]]) : (tensor<*xi1>) -> ()
+// CANON: }, {
+// CANON: ^bb0([[VAL_11:%.*]]: tensor<*xi32>, [[VAL_12:%.*]]: tensor<*xi32>, [[VAL_13:%.*]]: tensor<*xf32>):
+// CANON: [[VAL_14:%.*]] = "tf.AddV2"([[VAL_12]], [[VAL_4]])
+// CANON: [[VAL_15:%.*]] = "tf.AddV2"([[VAL_13]], [[VAL_5]])
+// CANON: [[VAL_16:%.*]] = "tf.AddV2"([[VAL_11]], [[VAL_4]])
+// CANON: "tfl.yield"([[VAL_16]], [[VAL_14]], [[VAL_15]]) : (tensor<*xi32>, tensor<*xi32>, tensor<*xf32>) -> ()
+// CANON: }) {is_stateless = true} : (tensor<i32>, tensor<i32>, tensor<?x256x256xf32>) -> (tensor<i32>, tensor<i32>, tensor<?x256x256xf32>)
+// CANON: return [[VAL_17:%.*]]#1, [[VAL_1]], [[VAL_17]]#2 : tensor<i32>, tensor<256x256xf32>, tensor<?x256x256xf32>
+// CANON: }
@ -123,6 +123,17 @@ func @softmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
 // CHECK: "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<8x16xf32>) -> tensor<8x16xf32>
 }

+func @softplus(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+
+// CHECK-LABEL: softplus
+// CHECK-NEXT: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<f32>
+// CHECK-NEXT: %[[exp:.*]] = "tfl.exp"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
+// CHECK-NEXT: %[[add:.*]] = "tfl.add"(%[[exp]], %[[cst]]) {fused_activation_function = "NONE"} : (tensor<8x16xf32>, tensor<f32>) -> tensor<8x16xf32>
+// CHECK-NEXT: %[[log:.*]] = "tfl.log"(%[[add]]) : (tensor<8x16xf32>) -> tensor<8x16xf32>
+}
+
 func @fakeQuantArgsFalse(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> {
   %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {min = -0.1 : f32, max = 0.2 : f32, num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
   return %0 : tensor<8x8x8x8xf32>
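The new softplus case pins down the decomposition now expected from the TF-to-TFLite legalization: since there is no dedicated TFLite softplus op, tf.Softplus is expanded into tfl.exp, a tfl.add against the constant 1.0, and tfl.log, which is the usual identity

    $\operatorname{softplus}(x) = \log(1 + e^{x})$

matching the CHECK-NEXT sequence above line for line.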
@ -739,6 +750,15 @@ func @matrix_diag_v3(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
 // CHECK: return [[VAL_6]] : tensor<8x16x16xf32>
 }

+func @matrix_set_diag(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
+  %0 = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
+  return %0 : tensor<3x3xi32>
+
+// CHECK-LABEL: func @matrix_set_diag(
+// CHECK: [[VAL_0:%.*]] = "tfl.matrix_set_diag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
+// CHECK: return [[VAL_0]]
+}
+
 func @maximum(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> {
   %0 = "tf.Maximum"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
   return %0 : tensor<8x16xf32>
@ -1364,3 +1384,99 @@ func @reciprocal_i64(%arg0: tensor<8xi64>) -> tensor<8xi64> {
 // CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xi64>, tensor<8xi64>) -> tensor<8xi64>
 // CHECK: return
 }
+
+func @random_uniform() -> tensor<2x5xf32> {
+  %0 = "tf.Const"() { value = dense<[2, 5]> : tensor<2xi32> } : () -> tensor<2xi32>
+  %1 = "tf.RandomUniform"(%0) { seed = 1, seed2 = 0} : (tensor<2xi32>) -> tensor<2x5xf32>
+  return %1 : tensor<2x5xf32>
+
+// CHECK-LABEL: random_uniform
+// CHECK: %[[CST:.*]] = constant dense
+// CHECK: return %[[CST:.*]] : tensor<2x5xf32>
+}
+
+func @random_uniform_no_fold(%arg0: tensor<2xi32>) -> tensor<2x5xf32> {
+  %1 = "tf.RandomUniform"(%arg0) { seed = 0, seed2 = 0} : (tensor<2xi32>) -> tensor<2x5xf32>
+  return %1 : tensor<2x5xf32>
+
+// CHECK-LABEL: random_uniform_no_fold
+// CHECK: %[[RANDOM:.*]] = "tf.RandomUniform"
+}
+
+func @random_uniform_no_fold2(%arg0: tensor<2xi32>) -> tensor<*xf32> {
+  %1 = "tf.RandomUniform"(%arg0) { seed = 1, seed2 = 2} : (tensor<2xi32>) -> tensor<*xf32>
+  return %1 : tensor<*xf32>
+
+// CHECK-LABEL: random_uniform_no_fold2
+// CHECK: %[[RANDOM:.*]] = "tf.RandomUniform"
+}
+
+func @random_uniform_no_fold3() -> tensor<2x5xf64> {
+  %0 = "tf.Const"() { value = dense<[2, 5]> : tensor<2xi32> } : () -> tensor<2xi32>
+  %1 = "tf.RandomUniform"(%0) { seed = 1, seed2 = 0} : (tensor<2xi32>) -> tensor<2x5xf64>
+  return %1 : tensor<2x5xf64>
+
+// CHECK-LABEL: random_uniform_no_fold3
+// CHECK: %[[RANDOM:.*]] = "tf.RandomUniform"
+}
+
+func @LstmWithoutProjection(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x16xf32>) {
+  %1 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16x28xf32>} : () -> tensor<16x28xf32>
+  %2 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16x16xf32>} : () -> tensor<16x16xf32>
+  %3 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16xf32>} : () -> tensor<16xf32>
+  %4 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<1x16xf32>} : () -> tensor<1x16xf32>
+  %5 = "tf.Const"() {device = "", dtype = f32, value = dense<-1.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
+  %6:3 = "tf.UnidirectionalSequenceLstm"(%arg, %1, %1, %1, %1, %2, %2, %2, %2, %3, %3, %3, %3, %3, %3, %3, %5, %5, %4, %4) {_tflite_input_indices = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 18, 19], device = ""} : (tensor<28x1x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1x16xf32>, tensor<1x16xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<28x1x16xf32>)
+  return %6#2 : tensor<28x1x16xf32>
+}
+
+// CHECK: func @LstmWithoutProjection([[VAL_0:%.*]]: tensor<28x1x28xf32>) -> tensor<28x1x16xf32> {
+// CHECK: [[VAL_1:%.*]] = constant dense<0.000000e+00> : tensor<16x28xf32>
+// CHECK: [[VAL_2:%.*]] = constant dense<0.000000e+00> : tensor<16x16xf32>
+// CHECK: [[VAL_3:%.*]] = constant dense<0.000000e+00> : tensor<16xf32>
+// CHECK: [[VAL_4:%.*]] = constant dense<0.000000e+00> : tensor<1x16xf32>
+// CHECK: [[VAL_5:%.*]] = constant unit
+// CHECK: [[VAL_6:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_1]], [[VAL_1]], [[VAL_1]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_5]], [[VAL_5]], [[VAL_4]], [[VAL_4]], [[VAL_5]], [[VAL_5]], [[VAL_5]], [[VAL_5]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, none, none, tensor<1x16xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x16xf32>
+// CHECK: return [[VAL_6]] : tensor<28x1x16xf32>
+// CHECK: }
+
+func @LstmWithProjection(%arg: tensor<28x1x16xf32>) -> (tensor<28x1x8xf32>) {
+  %1 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16x16xf32>} : () -> tensor<16x16xf32>
+  %2 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16x8xf32>} : () -> tensor<16x8xf32>
+  %3 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16xf32>} : () -> tensor<16xf32>
+  %4 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<1x16xf32>} : () -> tensor<1x16xf32>
+  %5 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<8x16xf32>} : () -> tensor<8x16xf32>
+  %6 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<1x8xf32>} : () -> tensor<1x8xf32>
+  %7 = "tf.Const"() {device = "", dtype = f32, value = dense<-1.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
+  %8:3 = "tf.UnidirectionalSequenceLstm"(%arg, %1, %1, %1, %1, %2, %2, %2, %2, %7, %7, %7, %3, %3, %3, %3, %5, %7, %6, %4) {_tflite_input_indices = [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 18, 19], device = ""} : (tensor<28x1x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<8x16xf32>, tensor<1xf32>, tensor<1x8xf32>, tensor<1x16xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<28x1x8xf32>)
+  return %8#2 : tensor<28x1x8xf32>
+}
+
+// CHECK-LABEL: func @LstmWithProjection(
+// CHECK-SAME: [[VAL_7:%.*]]: tensor<28x1x16xf32>) -> tensor<28x1x8xf32> {
+// CHECK: [[VAL_8:%.*]] = constant dense<0.000000e+00> : tensor<16x16xf32>
+// CHECK: [[VAL_9:%.*]] = constant dense<0.000000e+00> : tensor<16x8xf32>
+// CHECK: [[VAL_10:%.*]] = constant dense<0.000000e+00> : tensor<16xf32>
+// CHECK: [[VAL_11:%.*]] = constant dense<0.000000e+00> : tensor<1x16xf32>
+// CHECK: [[VAL_12:%.*]] = constant dense<0.000000e+00> : tensor<8x16xf32>
+// CHECK: [[VAL_13:%.*]] = constant dense<0.000000e+00> : tensor<1x8xf32>
+// CHECK: [[VAL_14:%.*]] = constant unit
+// CHECK: [[VAL_15:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_12]], [[VAL_14]], [[VAL_13]], [[VAL_11]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_14]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, none, none, none, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<8x16xf32>, none, tensor<1x8xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x8xf32>
+// CHECK: return [[VAL_15]] : tensor<28x1x8xf32>
+// CHECK: }
+
+func @UnidirectionalRnn(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x28xf32>) {
+  %1 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<28x28xf32>} : () -> tensor<28x28xf32>
+  %2 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<28xf32>} : () -> tensor<28xf32>
+  %3 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<1x28xf32>} : () -> tensor<1x28xf32>
+  %4:2 = "tf.UnidirectionalSequenceRnn"(%arg, %1, %1, %2, %3) {_tflite_input_indices = [0, 1, 2, 3, 4], device = ""} : (tensor<28x1x28xf32>, tensor<28x28xf32>, tensor<28x28xf32>, tensor<28xf32>, tensor<1x28xf32>) -> (tensor<*xf32>, tensor<28x1x28xf32>)
+  return %4#1 : tensor<28x1x28xf32>
+}
+
+// CHECK: func @UnidirectionalRnn([[VAL_0:%.*]]: tensor<28x1x28xf32>) -> tensor<28x1x28xf32> {
+// CHECK: [[VAL_1:%.*]] = constant dense<0.000000e+00> : tensor<28x28xf32>
+// CHECK: [[VAL_2:%.*]] = constant dense<0.000000e+00> : tensor<28xf32>
+// CHECK: [[VAL_3:%.*]] = constant dense<0.000000e+00> : tensor<1x28xf32>
+// CHECK: [[VAL_4:%.*]] = "tfl.unidirectional_sequence_rnn"([[VAL_0]], [[VAL_1]], [[VAL_1]], [[VAL_2]], [[VAL_3]]) {fused_activation_function = "TANH", time_major = true} : (tensor<28x1x28xf32>, tensor<28x28xf32>, tensor<28x28xf32>, tensor<28xf32>, tensor<1x28xf32>) -> tensor<28x1x28xf32>
+// CHECK: return [[VAL_4]] : tensor<28x1x28xf32>
+// CHECK: }
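Taken together, the four random_uniform cases outline when tf.RandomUniform may be folded to a constant: the positive case has a constant shape operand and a statically shaped f32 result, while the no_fold variants cover a non-constant shape, an unranked result, and an f64 result. A rough sketch of that precondition as the tests imply it (the helper name is illustrative, not the pass code):

    // Fold only when the shape is a known constant and the result is a
    // statically shaped f32 tensor; otherwise leave tf.RandomUniform alone.
    bool LooksFoldable(mlir::DenseIntElementsAttr shape_attr,
                       mlir::ShapedType result_type) {
      return shape_attr && result_type.hasStaticShape() &&
             result_type.getElementType().isF32();
    }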
|
@@ -1,17 +1,15 @@
 // Test to verify translation & export work as intended with runtime.
 
-// RUN: not mlir-tflite-runner --dump-interpreter-state %s 2>&1 | FileCheck %s --check-prefix ERROR --dump-input-on-failure
 // RUN: tf-opt --mlir-print-debuginfo --canonicalize --tfl-while-loop-outline %s | mlir-tflite-runner --dump-interpreter-state 2>&1 | FileCheck %s --dump-input-on-failure
 
-// ERROR: number of operands and results don't match
-
 // Verify value computed:
 // ----------------------
 // CHECK: result: Tensor<type: FLOAT32, shape: 1, values: 96>
+// CHECK: pconst: Tensor<type: INT32, shape: , values: 1>
 
 // Verify tensors in interpreter state:
 // ------------------------------------
-// CHECK: Tensor 0 dec kTfLiteInt32 kTfLiteMmapRo 4 bytes
+// CHECK: Tensor 0 pconst kTfLiteInt32 kTfLiteMmapRo 4 bytes
 // CHECK-NEXT: Tensor 1 N kTfLiteInt32 kTfLiteMmapRo 4 bytes
 // CHECK-NEXT: Tensor 2 val kTfLiteFloat32 kTfLiteMmapRo 4 bytes
 // CHECK-NEXT: Tensor 3 std.constant kTfLiteInt32 kTfLiteMmapRo 4 bytes
@@ -24,12 +22,12 @@
 // ------------------------------------
 // CHECK: Operator Builtin Code {{[0-9]*}} WHILE
 
-func @main() -> tensor<1xf32>
+func @main() -> (tensor<1xf32>, tensor<i32>)
-attributes {tf.entry_function = {outputs = "result"}} {
+attributes {tf.entry_function = {outputs = "result,pconst"}} {
 %cst = constant dense<1> : tensor<i32> loc("dec")
 %arg0 = constant dense<5> : tensor<i32> loc("N")
 %arg1 = constant dense<3.0> : tensor<1xf32> loc("val")
-%0:2 = "tfl.while"(%arg0, %arg1, %cst) ( {
+%0:3 = "tfl.while"(%arg0, %arg1, %cst) ( {
 ^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>, %arg4: tensor<i32>):
 %cst_0 = constant dense<0> : tensor<i32>
 %1 = "tfl.greater"(%arg2, %cst_0) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
@@ -40,7 +38,7 @@ func @main() -> tensor<1xf32>
 (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
 %2 = tfl.add %arg3, %arg3 {fused_activation_function = "NONE"} : tensor<*xf32>
 "tfl.yield"(%1, %2, %arg4) : (tensor<*xi32>, tensor<*xf32>, tensor<i32>) -> ()
-}) : (tensor<i32>, tensor<1xf32>, tensor<i32>) -> (tensor<i32>, tensor<1xf32>)
+}) : (tensor<i32>, tensor<1xf32>, tensor<i32>) -> (tensor<i32>, tensor<1xf32>, tensor<i32>)
-return %0#1 : tensor<1xf32>
+return %0#1, %0#2 : tensor<1xf32>, tensor<i32>
 }
 
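The expected interpreter values in the hunks above follow directly from the loop the test builds: `val` starts at 3.0 and doubles once per iteration while `N` counts down from 5 by `dec` = 1, and the new `pconst` output simply forwards that constant through the while op. A minimal sketch of that arithmetic (plain Python, not part of the diff):

```python
def run_while(n=5, val=3.0, dec=1):
    # cond region: tfl.greater(N, 0); body: val = tfl.add(val, val), N = tfl.sub(N, dec)
    while n > 0:
        val = val + val
        n = n - dec
    return val, dec  # exported as outputs "result,pconst"

result, pconst = run_while()
assert result == 96.0  # CHECK: result: Tensor<type: FLOAT32, shape: 1, values: 96>
assert pconst == 1     # CHECK: pconst: Tensor<type: INT32, shape: , values: 1>
```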
@@ -34,14 +34,14 @@
 // CHECK-NEXT: shape: [ ],
 // CHECK-NEXT: type: INT32,
 // CHECK-NEXT: buffer: 3,
-// CHECK-NEXT: name: "tf.While",
+// CHECK-NEXT: name: "tfl.while",
 // CHECK-NEXT: quantization: {
 // CHECK-EMPTY:
 // CHECK-NEXT: }
 // CHECK-NEXT: }, {
 // CHECK-NEXT: shape: [ 1 ],
 // CHECK-NEXT: buffer: 4,
-// CHECK-NEXT: name: "tf.While:1",
+// CHECK-NEXT: name: "tfl.while:1",
 // CHECK-NEXT: quantization: {
 // CHECK-EMPTY:
 // CHECK-NEXT: }
@@ -193,22 +193,27 @@
 // CHECK-NEXT: }
 
 func @main(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
-// While %arg0 is greater than zero, element wise add %arg1 with itself.
-%0:2 = "tf.While"(%arg0, %arg1) {
-cond = @cond, body = @body, is_stateless = false
-} : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>)
+%0:2 = "tfl.while"(%arg0, %arg1) ( {
+^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>): // no predecessors
+%1 = call @cond(%arg2, %arg3) : (tensor<*xi32>, tensor<*xf32>) -> tensor<i1>
+"tfl.yield"(%1) : (tensor<i1>) -> ()
+}, {
+^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>): // no predecessors
+%1:2 = call @body(%arg2, %arg3) : (tensor<*xi32>, tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>)
+"tfl.yield"(%1#0, %1#1) : (tensor<*xi32>, tensor<*xf32>) -> ()
+}) {is_stateless = false} : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>)
 return %0#1 : tensor<1xf32>
 }
 
 func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> {
-%0 = "std.constant" () {value = dense<0> : tensor<i32>} : () -> tensor<i32> loc("Const")
+%cst = constant dense<0> : tensor<i32> loc("Const")
-%1 = "tfl.greater"(%arg0, %0) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
+%0 = "tfl.greater"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
-return %1 : tensor<i1>
+return %0 : tensor<i1>
 }
 
 func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>) {
-%0 = "std.constant" () {value = dense<1> : tensor<i32>} : () -> tensor<i32> loc("Const")
+%cst = constant dense<1> : tensor<i32> loc("Const")
-%1 = "tfl.sub"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
+%0 = "tfl.sub"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
-%2 = tfl.add %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
+%1 = tfl.add %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
-return %1, %2 : tensor<*xi32>, tensor<*xf32>
+return %0, %1 : tensor<*xi32>, tensor<*xf32>
 }
@@ -1,4 +1,4 @@
-// RUN: tf-opt -split-input-file -verify-diagnostics %s | FileCheck %s --dump-input-on-failure
+// RUN: tf-opt -split-input-file -verify-diagnostics -tfl-runtime-verify %s | FileCheck %s --dump-input-on-failure
 
 // Unary math ops
 // -----
@@ -593,6 +593,21 @@ func @testUnidirectionalSequenceLstmWithInvalidNoneType(%arg0: tensor<? x f32>,
 return %0 : tensor<?xf32>
 }
 
+// -----
+
+// CHECK-LABEL: testLstmQuantizedType
+func @testLstmQuantizedType(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, %arg1: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg2: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg3: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg4: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg5: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg6: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg7: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg8: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg9: tensor<2048x!quant.uniform<i32:f32, 0.01>>, %arg10: tensor<2048x!quant.uniform<i32:f32, 0.01>>, %arg11: tensor<2048x!quant.uniform<i32:f32, 0.01>>, %arg12: tensor<2048x!quant.uniform<i32:f32, 0.01>>, %arg13: tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, %arg14: tensor<640x!quant.uniform<i32:f32, 9.9999999747524271E-7>>, %arg15: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg16: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg17: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg18: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg19: tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, %arg20: tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> {
+%cst = constant unit
+%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ( {
+}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 9.9999999747524271E-7>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
+return %0 : tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
+// CHECK: %[[RES0:.*]] = constant unit
+// CHECK: %[[RES1:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %[[RES0]], %[[RES0]], %[[RES0]], %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ( {
+// CHECK-NEXT: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.00999999977 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 9.9999999747524271E-7>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
+// CHECK: return %[[RES1]]
+}
+
 // -----
 
 // CHECK-LABEL: testLstm
@@ -878,6 +893,14 @@ func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
 
 // -----
 
+func @packUnranked(%arg0: tensor<2xi32>, %arg1: tensor<*xi32>) -> tensor<2x2xi32> {
+// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32}
+%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<*xi32>) -> tensor<2x2xi32>
+return %0 : tensor<2x2xi32>
+}
+
+// -----
+
 func @packInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<1x4x2xi32> {
 // CHECK: "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32}
 %0 = "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<1x4x2xi32>
@@ -1632,6 +1655,7 @@ func @testSplitOpWithMismatchTensorTypeNonSplitDim(%arg0 : tensor<16x4xf32>) ->
 
 // -----
 
+// CHECK-LABEL:testSplitOpWithValidTensorType
 func @testSplitOpWithValidTensorType(%arg0 : tensor<16x4xf32>) -> (tensor<8x4xf32>, tensor<8x4xf32>, tensor<16x2xf32>, tensor<16x2xf32>, tensor<16x2xf32>) {
 %split_dim_0 = constant dense<0> : tensor<i32>
 %0, %1 = "tfl.split"(%split_dim_0, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<16x4xf32>) -> (tensor<8x4xf32>, tensor<8x4xf32>)
@@ -1639,6 +1663,9 @@ func @testSplitOpWithValidTensorType(%arg0 : tensor<16x4xf32>) -> (tensor<8x4xf3
 %2, %3 = "tfl.split"(%split_dim_1, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<16x4xf32>) -> (tensor<16x2xf32>, tensor<16x2xf32>)
 %split_dim_2 = constant dense<1> : tensor<1xi32>
 %4, %5 = "tfl.split"(%split_dim_2, %arg0) {num_splits = 2 : i32} : (tensor<1xi32>, tensor<16x4xf32>) -> (tensor<16x2xf32>, tensor<16x2xf32>)
+%6:2 = "tfl.split"(%split_dim_2, %arg0) {num_splits = 2 : i32} : (tensor<1xi32>, tensor<16x4xf32>) -> (tensor<16x2xf32>, tensor<16x?xf32>)
+%7:2 = "tfl.split"(%split_dim_2, %arg0) {num_splits = 2 : i32} : (tensor<1xi32>, tensor<16x4xf32>) -> (tensor<?x2xf32>, tensor<16x?xf32>)
+%8:2 = "tfl.split"(%split_dim_2, %arg0) {num_splits = 2 : i32} : (tensor<1xi32>, tensor<16x4xf32>) -> (tensor<16x2xf32>, tensor<*xf32>)
 return %0, %1, %2, %3, %4 : tensor<8x4xf32>, tensor<8x4xf32>, tensor<16x2xf32>, tensor<16x2xf32>, tensor<16x2xf32>
 }
 
@@ -1984,3 +2011,32 @@ func @testDensify(%arg0: tensor<? x f32>) -> tensor<? x f32> {
 %0 = "tfl.densify"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
 return %0 : tensor<? x f32>
 }
+
+// -----
+
+func @WhileOp_cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> {
+%cst = constant dense<0> : tensor<i32> loc("Const")
+%0 = "tfl.greater"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
+return %0 : tensor<i1>
+}
+
+func @WhileOp_body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>) {
+%cst = constant dense<1> : tensor<i32> loc("Const1")
+%0 = "tfl.sub"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
+%1 = tfl.add %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
+return %0, %1 : tensor<*xi32>, tensor<*xf32>
+}
+
+func @main(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<i32> {
+// expected-error @+1 {{number of operands does not match number of results}}
+%0:1 = "tfl.while"(%arg0, %arg1) ( {
+^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>): // no predecessors
+%1 = call @WhileOp_cond(%arg2, %arg3) : (tensor<*xi32>, tensor<*xf32>) -> tensor<i1>
+"tfl.yield"(%1) : (tensor<i1>) -> ()
+}, {
+^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>): // no predecessors
+%1:2 = call @WhileOp_body(%arg2, %arg3) : (tensor<*xi32>, tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>)
+"tfl.yield"(%1#0, %1#1) : (tensor<*xi32>, tensor<*xf32>) -> ()
+}) : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>)
+return %0#0 : tensor<i32>
+}
@@ -717,6 +717,31 @@ func @expandDimsToReshape(%arg0: tensor<6x6x256xf32>) -> tensor<6x6x256x1xf32> {
 // CHECK: return %[[RESULT]]
 }
 
+// CHECK-LABEL: convertTrivialTransposeToReshape
+func @convertTrivialTransposeToReshape(%arg0: tensor<6x6x256x1xf32>) -> tensor<1x6x6x256xf32> {
+%cst = constant dense<[3, 0, 1, 2]> : tensor<4xi32>
+%0 = "tfl.transpose"(%arg0, %cst) : (tensor<6x6x256x1xf32>, tensor<4xi32>) -> tensor<1x6x6x256xf32>
+return %0 : tensor<1x6x6x256xf32>
+
+// CHECK: [[CONST:.*]] = constant dense<[1, 6, 6, 256]> : tensor<4xi32>
+// CHECK: %[[RESULT:.*]] = "tfl.reshape"(%arg0, %[[CONST:.*]]) : (tensor<6x6x256x1xf32>, tensor<4xi32>) -> tensor<1x6x6x256xf32>
+// CHECK: return %[[RESULT]]
+}
+
+// CHECK-LABEL: doNotConvertNonTrivialTransposeToReshape
+func @doNotConvertNonTrivialTransposeToReshape(%arg0: tensor<6x6x256x1xf32>) -> tensor<1x6x6x256xf32> {
+// Note: The dimension 0 and 1 are swapped, so it's not trivial
+// (elements are not in the same order).
+%cst = constant dense<[3, 1, 0, 2]> : tensor<4xi32>
+%0 = "tfl.transpose"(%arg0, %cst) : (tensor<6x6x256x1xf32>, tensor<4xi32>) -> tensor<1x6x6x256xf32>
+return %0 : tensor<1x6x6x256xf32>
+
+// CHECK: [[CONST:.*]] = constant dense<[3, 1, 0, 2]> : tensor<4xi32>
+// CHECK: %[[RESULT:.*]] = "tfl.transpose"(%arg0, %[[CONST:.*]])
+// CHECK: return %[[RESULT]]
+}
+
 // CHECK-LABEL: Relu1
 func @Relu1(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
 %cst = constant dense<-1.0> : tensor<f32>
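The two new optimize tests above hinge on one property: a transpose whose permutation only moves size-1 dimensions leaves the flattened element order untouched, so it can be folded into a reshape, while a permutation that swaps two non-unit dimensions cannot. A small numpy sketch of that property (illustration only, not part of the diff):

```python
import numpy as np

x = np.arange(6 * 6 * 256, dtype=np.float32).reshape(6, 6, 256, 1)

trivial = np.transpose(x, [3, 0, 1, 2])        # only the size-1 axis moves
assert np.array_equal(trivial.ravel(), x.ravel())        # element order kept
assert np.array_equal(trivial, x.reshape(1, 6, 6, 256))  # a reshape suffices

non_trivial = np.transpose(x, [3, 1, 0, 2])    # also swaps two size-6 axes
assert not np.array_equal(non_trivial.ravel(), x.ravel())
```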
@@ -96,3 +96,40 @@ func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
 %0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
 return %0 : tensor<*xf32>
 }
+
+// -----
+
+// Verify unused if with functions without side-effects are removed.
+
+func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
+attributes {tf.entry_function = {inputs = "input", outputs = "Conv2D"}} {
+%cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
+%cst_0 = constant dense<1.000000e+00> : tensor<f32>
+%cst_1 = constant dense<0.000000e+00> : tensor<8xf32>
+%cst_2 = constant dense<0.000000e+00> : tensor<8x3x3x3xf32>
+%0 = "tfl.sub"(%arg0, %cst_0) {fused_activation_function = "NONE"} : (tensor<3x15x14x3xf32>, tensor<f32>) -> tensor<3x15x14x3xf32>
+%1 = "tfl.greater_equal"(%arg0, %0) : (tensor<3x15x14x3xf32>, tensor<3x15x14x3xf32>) -> tensor<3x15x14x3xi1>
+%2 = "tf.All"(%1, %cst) {Tidx = i32, device = "/device:CPU:0", keep_dims = false} : (tensor<3x15x14x3xi1>, tensor<4xi32>) -> tensor<i1>
+%3 = "tf.If"(%2, %2, %arg0, %0) {Tcond = i1,
+else_branch = @_functionalize_if_else_branch_00, is_stateless = false,
+then_branch = @_functionalize_if_then_branch_00} :
+(tensor<i1>, tensor<i1>, tensor<3x15x14x3xf32>, tensor<3x15x14x3xf32>) -> tensor<i1>
+%4 = "tfl.conv_2d"(%arg0, %cst_2, %cst_1) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<3x15x14x3xf32>, tensor<8x3x3x3xf32>, tensor<8xf32>) -> tensor<3x15x14x8xf32>
+return %4 : tensor<3x15x14x8xf32>
+}
+
+func @_functionalize_if_else_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
+%cst = constant dense<false> : tensor<i1>
+return %cst : tensor<i1>
+}
+
+func @_functionalize_if_then_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
+%cst = constant dense<true> : tensor<i1>
+return %cst : tensor<i1>
+}
+
+// CHECK: func @main
+// CHECK-NOT: tf.If
+// CHECK: return
+// CHECK-NOT: func else_branch
+// CHECK-NOT: func then_branch
|
|||||||
// -----
|
// -----
|
||||||
|
|
||||||
module {
|
module {
|
||||||
func @inference_standard_lstm_7410(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.signature.is_stateful} {
|
func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
|
||||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
|
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
|
||||||
%1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
|
%1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
|
||||||
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
|
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
|
||||||
@ -165,7 +165,7 @@ func @inference_standard_lstm_7410(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10x
|
|||||||
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
|
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK: func @inference_standard_lstm_7410([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<?x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.signature.is_stateful} {
|
// CHECK: func @inference_standard_lstm_time_major([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x?x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
|
||||||
// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
||||||
// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
|
// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
|
||||||
// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
||||||
@@ -181,7 +181,127 @@ func @inference_standard_lstm_7410(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10x
 // CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
 // CHECK: [[VAL_19:%.*]] = constant unit
 // CHECK: [[VAL_20:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) ( {
-// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
+// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x?x10xf32>
-// CHECK: return [[VAL_21:%.*]] : tensor<?x8x10xf32>
+// CHECK: return [[VAL_21:%.*]] : tensor<8x?x10xf32>
+// CHECK: }
+}
+
+// -----
+
+module {
+func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} {
+%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32>
+%1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32>
+%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<8x8x40xf32>, tensor<10x40xf32>) -> tensor<8x8x10xf32>
+%3 = "tf.Add"(%2, %arg1) : (tensor<8x8x10xf32>, tensor<?x10xf32>) -> tensor<8x8x10xf32>
+%4 = "tf.Add"(%2, %arg2) : (tensor<8x8x10xf32>, tensor<?x10xf32>) -> tensor<8x8x10xf32>
+%5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
+%6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
+}
+
+// CHECK: func @inference_standard_lstm_non_time_major([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} {
+// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
+// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_0]], [[VAL_6]]) : (tensor<8x8x8xf32>, tensor<3xi64>) -> tensor<8x8x8xf32>
+// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
+// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_8]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
+// CHECK: [[VAL_10:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
+// CHECK: [[VAL_11:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_10]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32>
+// CHECK: [[VAL_12:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
+// CHECK: [[VAL_13:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK: [[VAL_14:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_12]], [[VAL_13]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
+// CHECK: [[VAL_15:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
+// CHECK: [[VAL_16:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK: [[VAL_17:%.*]]:4 = "tf.SplitV"([[VAL_11]], [[VAL_15]], [[VAL_16]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
+// CHECK: [[VAL_18:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
+// CHECK: [[VAL_19:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK: [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
+// CHECK: [[VAL_21:%.*]] = constant unit
+// CHECK: [[VAL_22:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) ( {
+// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
+// CHECK: [[VAL_23:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
+// CHECK: [[VAL_24:%.*]] = "tf.Transpose"([[VAL_25:%.*]], [[VAL_23]]) : (tensor<8x8x10xf32>, tensor<3xi64>) -> tensor<8x8x10xf32>
+// CHECK: return [[VAL_24]] : tensor<8x8x10xf32>
+// CHECK: }
+}
+
+// -----
+
+module {
+func @inference_standard_lstm_time_major_go_backwards(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
+%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
+%1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
+%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
+%3 = "tf.Add"(%2, %arg1) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
+%4 = "tf.Add"(%2, %arg2) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x?x10xf32>
+%5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
+%6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
+}
+
+// CHECK: func @inference_standard_lstm_time_major_go_backwards([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x?x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
+// CHECK: [[VAL_6:%.*]] = constant dense<0> : tensor<1xi32>
+// CHECK: [[VAL_7:%.*]] = "tf.ReverseV2"([[VAL_0]], [[VAL_6]]) : (tensor<?x8x8xf32>, tensor<1xi32>) -> tensor<?x8x8xf32>
+// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
+// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_8]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
+// CHECK: [[VAL_10:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
+// CHECK: [[VAL_11:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_10]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32>
+// CHECK: [[VAL_12:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
+// CHECK: [[VAL_13:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK: [[VAL_14:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_12]], [[VAL_13]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
+// CHECK: [[VAL_15:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
+// CHECK: [[VAL_16:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK: [[VAL_17:%.*]]:4 = "tf.SplitV"([[VAL_11]], [[VAL_15]], [[VAL_16]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
+// CHECK: [[VAL_18:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
+// CHECK: [[VAL_19:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK: [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
+// CHECK: [[VAL_21:%.*]] = constant unit
+// CHECK: [[VAL_22:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) ( {
+// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x?x10xf32>
+// CHECK: return [[VAL_23:%.*]] : tensor<8x?x10xf32>
+// CHECK: }
+
 }
+
+// -----
+
+module {
+func @inference_standard_lstm_non_time_major_go_backwards(%arg0: tensor<8x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
+%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32>
+%1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32>
+%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<8x8x40xf32>, tensor<10x40xf32>) -> tensor<8x8x10xf32>
+%3 = "tf.Add"(%2, %arg1) : (tensor<8x8x10xf32>, tensor<?x10xf32>) -> tensor<8x8x10xf32>
+%4 = "tf.Add"(%2, %arg2) : (tensor<8x8x10xf32>, tensor<?x10xf32>) -> tensor<8x8x10xf32>
+%5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
+%6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
+}
+
+// CHECK: func @inference_standard_lstm_non_time_major_go_backwards([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
+// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
+// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_0]], [[VAL_6]]) : (tensor<8x8x8xf32>, tensor<3xi64>) -> tensor<8x8x8xf32>
+// CHECK: [[VAL_8:%.*]] = constant dense<0> : tensor<1xi32>
+// CHECK: [[VAL_9:%.*]] = "tf.ReverseV2"([[VAL_7]], [[VAL_8]]) : (tensor<8x8x8xf32>, tensor<1xi32>) -> tensor<8x8x8xf32>
+// CHECK: [[VAL_10:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
+// CHECK: [[VAL_11:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_10]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
+// CHECK: [[VAL_12:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
+// CHECK: [[VAL_13:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_12]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32>
+// CHECK: [[VAL_14:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
+// CHECK: [[VAL_15:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK: [[VAL_16:%.*]]:4 = "tf.SplitV"([[VAL_11]], [[VAL_14]], [[VAL_15]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
+// CHECK: [[VAL_17:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
+// CHECK: [[VAL_18:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK: [[VAL_19:%.*]]:4 = "tf.SplitV"([[VAL_13]], [[VAL_17]], [[VAL_18]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
+// CHECK: [[VAL_20:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
+// CHECK: [[VAL_21:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK: [[VAL_22:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_20]], [[VAL_21]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
+// CHECK: [[VAL_23:%.*]] = constant unit
+// CHECK: [[VAL_24:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_16]]#0, [[VAL_16]]#1, [[VAL_16]]#2, [[VAL_16]]#3, [[VAL_19]]#0, [[VAL_19]]#1, [[VAL_19]]#2, [[VAL_19]]#3, [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_22]]#0, [[VAL_22]]#1, [[VAL_22]]#2, [[VAL_22]]#3, [[VAL_23]], [[VAL_23]], [[VAL_1]], [[VAL_2]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) ( {
+// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
+// CHECK: [[VAL_25:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
+// CHECK: [[VAL_26:%.*]] = "tf.Transpose"([[VAL_27:%.*]], [[VAL_25]]) : (tensor<8x8x10xf32>, tensor<3xi64>) -> tensor<8x8x10xf32>
+// CHECK: return [[VAL_26]] : tensor<8x8x10xf32>
+// CHECK: }
+
+}
@@ -622,3 +622,16 @@ func @QuantizeSharedBiases2(
 // CHECK: %{{.*}} = tfl.add %{{.*}}, %[[dq_0]]
 // CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])
 }
+
+// CHECK-LABEL: ReturnQuantizedResult
+func @ReturnQuantizedResult(%arg0: tensor<1x224x224x3xf32>, %arg1: tensor<32x3x3x3xf32>, %arg2: tensor<32xf32>) -> (tensor<1x112x112x32xf32>, tensor<1x112x112x32xf32>) {
+%0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %arg2) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
+%1 = "tfl.quantize"(%0) {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
+%2 = "tfl.dequantize"(%1) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>) -> (tensor<1x112x112x32xf32>)
+return %0, %2 : tensor<1x112x112x32xf32>, tensor<1x112x112x32xf32>
+
+// CHECK: %[[dw:.*]] = "tfl.depthwise_conv_2d"(%arg0, %arg1, %arg2)
+// CHECK: %[[q:.*]] = "tfl.quantize"(%[[dw]])
+// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
+// CHECK: return %[[dq]], %[[dq]]
+}
@@ -1,4 +1,4 @@
-// RUN: tf-opt -tfl-prepare-tf %s | FileCheck %s
+// RUN: tf-opt -tfl-prepare-tf %s | FileCheck %s --dump-input-on-failure
 
 func @conv(tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<256x3x32x32xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x16x30x30xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) {
 ^bb0(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>, %arg2: tensor<256x3x32x32xf32>) :
@@ -117,6 +117,37 @@ func @fusedBatchNormV3(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor
 // CHECK: "tf.FusedBatchNormV3"(%[[BATCHNORM1_a]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
 }
 
+func @batchNormWithGlobalNormalization(
+%t:tensor<1x10x10x3xf32>, %m:tensor<3xf32>, %v:tensor<3xf32>, %beta:tensor<3xf32>, %gamma:tensor<3xf32>) -> (tensor<1x10x10x3xf32>) {
+%0 = "tf.BatchNormWithGlobalNormalization"(%t, %m, %v, %beta, %gamma) {T = "tfdtype$DT_FLOAT", variance_epsilon = 0.001 : f32, scale_after_normalization = false} : (tensor<1x10x10x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<1x10x10x3xf32>)
+return %0 : tensor<1x10x10x3xf32>
+// CHECK-LABEL: batchNormWithGlobalNormalization
+// CHECK: %[[EPSILON:.*]] = constant dense<1.000000e-03>
+// CHECK: %[[VARIANCE:.*]] = "tf.Add"(%[[ARG_V:.*]], %[[EPSILON]])
+// CHECK: %[[RSQRT:.*]] = "tf.Rsqrt"(%[[VARIANCE]])
+// CHECK: %[[MUL1:.*]] = "tf.Mul"(%[[ARG_T:.*]], %[[RSQRT]])
+// CHECK: %[[MUL2:.*]] = "tf.Mul"(%[[ARG_M:.*]], %[[RSQRT]])
+// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[ARG_BETA:.*]], %[[MUL2]])
+// CHECK: %[[RESULT:.*]] = "tf.Add"(%[[MUL1]], %[[SUB]])
+// CHECK: return %[[RESULT]]
+}
+
+func @batchNormWithGlobalNormalizationWithScaleAfterNormalization(
+%t:tensor<1x10x10x3xf32>, %m:tensor<3xf32>, %v:tensor<3xf32>, %beta:tensor<3xf32>, %gamma:tensor<3xf32>) -> (tensor<1x10x10x3xf32>) {
+%0 = "tf.BatchNormWithGlobalNormalization"(%t, %m, %v, %beta, %gamma) {T = "tfdtype$DT_FLOAT", variance_epsilon = 0.001 : f32, scale_after_normalization = true} : (tensor<1x10x10x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<1x10x10x3xf32>)
+return %0 : tensor<1x10x10x3xf32>
+// CHECK-LABEL: batchNormWithGlobalNormalizationWithScaleAfterNormalization
+// CHECK: %[[EPSILON:.*]] = constant dense<1.000000e-03>
+// CHECK: %[[VARIANCE:.*]] = "tf.Add"(%[[ARG_V:.*]], %[[EPSILON]])
+// CHECK: %[[RSQRT:.*]] = "tf.Rsqrt"(%[[VARIANCE]])
+// CHECK: %[[MUL0:.*]] = "tf.Mul"(%[[RSQRT]], %[[ARG_GAMMA:.*]])
+// CHECK: %[[MUL1:.*]] = "tf.Mul"(%[[ARG_T:.*]], %[[MUL0]])
+// CHECK: %[[MUL2:.*]] = "tf.Mul"(%[[ARG_M:.*]], %[[MUL0]])
+// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[ARG_BETA:.*]], %[[MUL2]])
+// CHECK: %[[RESULT:.*]] = "tf.Add"(%[[MUL1]], %[[SUB]])
+// CHECK: return %[[RESULT]]
+}
+
 // CHECK-LABEL: fakeQuantPerChannelForActivation
 func @fakeQuantPerChannelForActivation(%arg0: tensor<8x3xf32>) -> (tensor<8x3xf32>) {
 %arg1 = constant dense<[0.0, -1.0, 1.0]> : tensor<3xf32>
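The CHECK lines above spell out the decomposition the pass performs: BatchNormWithGlobalNormalization becomes Add/Rsqrt/Mul/Sub/Add, i.e. y = t * m + (beta - mean * m) with m = rsqrt(variance + epsilon), additionally scaled by gamma when scale_after_normalization is true. A small numpy sketch checking that this matches the textbook formula (illustration only, not part of the diff):

```python
import numpy as np

rng = np.random.default_rng(0)
t = rng.normal(size=(1, 10, 10, 3)).astype(np.float32)
mean = rng.normal(size=3).astype(np.float32)
variance = rng.uniform(0.5, 2.0, size=3).astype(np.float32)
beta = rng.normal(size=3).astype(np.float32)
gamma = rng.normal(size=3).astype(np.float32)
eps = 0.001

reference = (t - mean) / np.sqrt(variance + eps) * gamma + beta
m = 1.0 / np.sqrt(variance + eps) * gamma   # Rsqrt(Add(variance, eps)) then Mul by gamma
decomposed = t * m + (beta - mean * m)      # Mul, Mul, Sub, Add as in the CHECK lines
assert np.allclose(reference, decomposed, atol=1e-5)
```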
@@ -422,6 +453,30 @@ func @placeholder_with_default(%arg0: tensor<3xf32>) -> tensor<3xf32> {
 // CHECK: return %arg0 : tensor<3xf32>
 }
 
+// CHECK-LABEL: @StridedSliceEllipsisMaskBefore
+func @StridedSliceEllipsisMaskBefore(%arg0: tensor<21x15x7xf32>) -> tensor<21x15x2xf32> {
+%cst = constant dense<0> : tensor<2xi32>
+%cst_0 = constant dense<1> : tensor<2xi32>
+%0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 0 : i64, ellipsis_mask = 1 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<21x15x2xf32>
+return %0 : tensor<21x15x2xf32>
+
+// CHECK: %[[CST:.*]] = constant dense<0> : tensor<3xi32>
+// CHECK: %[[CST_0:.*]] = constant dense<1> : tensor<3xi32>
+// CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST]], %[[CST_0]]) {begin_mask = 3 : i64, ellipsis_mask = 0 : i64, end_mask = 3 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<21x15x2xf32>
+}
+
+// CHECK-LABEL: @StridedSliceEllipsisMaskAfter
+func @StridedSliceEllipsisMaskAfter(%arg0: tensor<21x15x7xf32>) -> tensor<5x15x7xf32> {
+%cst = constant dense<0> : tensor<2xi32>
+%cst_0 = constant dense<1> : tensor<2xi32>
+%0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 0 : i64, ellipsis_mask = 2 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<5x15x7xf32>
+return %0 : tensor<5x15x7xf32>
+
+// CHECK: %[[CST:.*]] = constant dense<0> : tensor<3xi32>
+// CHECK: %[[CST_0:.*]] = constant dense<1> : tensor<3xi32>
+// CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST]], %[[CST_0]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<5x15x7xf32>
+}
+
 // CHECK-LABEL: @NoPadStridedSliceNonNewAxisMask
 func @NoPadStridedSliceNonNewAxisMask(%arg0: tensor<1x2x3x1xf32>) -> tensor<1x2x3x1xf32> {
 %cst = constant dense<0> : tensor<4xi32>
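What the two StridedSlice rewrites above assert is that an ellipsis is just shorthand for full slices over the dimensions it skips, which is why the pass can drop ellipsis_mask and set the corresponding begin_mask/end_mask bits instead. A small numpy sketch of that equivalence (the concrete slice bounds 0:2 and 0:5 are illustrative, not taken from the test):

```python
import numpy as np

x = np.arange(21 * 15 * 7, dtype=np.float32).reshape(21, 15, 7)

# ellipsis leading (ellipsis_mask = 1): x[..., 0:2] == x[:, :, 0:2]
assert np.array_equal(x[..., 0:2], x[:, :, 0:2])
assert x[..., 0:2].shape == (21, 15, 2)

# ellipsis after the first spec (ellipsis_mask = 2): x[0:5, ...] == x[0:5, :, :]
assert np.array_equal(x[0:5, ...], x[0:5, :, :])
assert x[0:5, ...].shape == (5, 15, 7)
```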
@@ -456,3 +511,34 @@ func @PadStridedSliceNewAxisMask2(%arg0: tensor<4x64x64x1xf32>) -> tensor<1x4x64
 %1 = "tf.StridedSlice"(%0, %cst, %cst, %cst_0) {Index = i32, T = f32, _output_shapes = ["tfshape$dim { size: 1 } dim { size: 4 } dim { size: 64 } dim { size: 64 }"], begin_mask = 6 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 1 : i64, shrink_axis_mask = 0 : i64} : (tensor<4x64x64xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x4x64x64xf32>
 return %1 : tensor<1x4x64x64xf32>
 }
+
+// CHECK-LABEL: @MatrixSetDiagV2Conversion
+func @MatrixSetDiagV2Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
+%cst = constant dense<0> : tensor<i32>
+%0 = "tf.MatrixSetDiagV2"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
+return %0 : tensor<3x3xi32>
+
+// CHECK: %[[RES:.*]] = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
+// CHECK: return %[[RES]]
+}
+
+// CHECK-LABEL: @MatrixSetDiagV2NonZeroK
+func @MatrixSetDiagV2NonZeroK(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
+%cst = constant dense<1> : tensor<i32>
+%0 = "tf.MatrixSetDiagV2"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
+return %0 : tensor<3x3xi32>
+
+// CHECK: %[[CST:.*]] = constant dense<1> : tensor<i32>
+// CHECK: %[[RES:.*]] = "tf.MatrixSetDiagV2"(%arg0, %arg1, %[[CST]]) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
+// CHECK: return %[[RES]]
+}
+
+// CHECK-LABEL: @MatrixSetDiagV3Conversion
+func @MatrixSetDiagV3Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
+%cst = constant dense<0> : tensor<i32>
+%0 = "tf.MatrixSetDiagV3"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
+return %0 : tensor<3x3xi32>
+
+// CHECK: %[[RES:.*]] = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
+// CHECK: return %[[RES]]
+}
@ -2,39 +2,44 @@
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-quantize -tfl-numeric-verify | FileCheck --check-prefix=DEBUG %s

// CHECK-LABEL: QuantizeFloatConst
-func @QuantizeFloatConst() -> tensor<f32> {
+func @QuantizeFloatConst() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
%0 = constant dense<-0.1> : tensor<2x2xf32>
-%1 = "tfl.quantize"(%0) {qtype = tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
+%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
-%2 = "tfl.dequantize"(%1) : (tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<f32>
+return %1 : tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
-return %2 : tensor<f32>

-// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<0> : tensor<2x2xi8>}
+// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<0> : tensor<2x2xi8>}
-// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[cst]])
+// CHECK: return %[[cst]]
-// CHECK: return %[[dq]] : tensor<f32>
}

// CHECK-LABEL: QuantizeDenseFloatConst
-func @QuantizeDenseFloatConst() -> tensor<2x2xf32> {
+func @QuantizeDenseFloatConst() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
%0 = constant dense<[[-0.1, 1.0], [1.0, 3.0]]> : tensor<2x2xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
-%2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<2x2xf32>
+return %1 : tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
-return %2 : tensor<2x2xf32>

// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<{{\[\[}}0, -1], {{\[}}-1, -1]]> : tensor<2x2xi8>}
-// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[cst]])
+// CHECK: return %[[cst]]
-// CHECK: return %[[dq]] : tensor<2x2xf32>
}

// CHECK-LABEL: QuantizeSplatFloatConst
-func @QuantizeSplatFloatConst() -> tensor<2x2xf32> {
+func @QuantizeSplatFloatConst() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
%0 = constant dense<3.0> : tensor<2x2xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
+return %1 : tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>

+// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<-1> : tensor<2x2xi8>}
+// CHECK: return %[[cst]]
+}

+// CHECK-LABEL: NotQuantizeFloatConst
+func @NotQuantizeFloatConst() -> tensor<2x2xf32> {
+%0 = constant dense<-0.1> : tensor<2x2xf32>
+%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
%2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<2x2xf32>
return %2 : tensor<2x2xf32>

-// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<-1> : tensor<2x2xi8>}
+// CHECK: %[[cst:.*]] = constant dense<-1.000000e-01> : tensor<2x2xf32>
-// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[cst]])
+// CHECK: return %[[cst]] : tensor<2x2xf32>
-// CHECK: return %[[dq]] : tensor<2x2xf32>
}

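The quantized constant payloads in the CHECK lines above follow from the affine parameters in the element type: scale s = 7.8431372549019615e-4 and zero point 128, so q = clamp(round(x / s) + 128, 0, 255), and tfl.pseudo_qconst prints that uint8 storage reinterpreted as i8. A minimal NumPy sketch of that arithmetic; quantize_u8_as_i8 is a local illustrative helper, and the pass's exact rounding mode may differ:

import numpy as np

SCALE = 7.8431372549019615e-4
ZERO_POINT = 128

def quantize_u8_as_i8(x):
    # q = clamp(round(x / scale) + zero_point, 0, 255), then reinterpret the
    # uint8 storage as i8, which is how tfl.pseudo_qconst prints the payload.
    q = np.clip(np.round(np.asarray(x, dtype=np.float64) / SCALE) + ZERO_POINT, 0, 255)
    return q.astype(np.uint8).astype(np.int8)

print(quantize_u8_as_i8(-0.1))                       # 0, matching dense<0>
print(quantize_u8_as_i8([[-0.1, 1.0], [1.0, 3.0]]))  # [[0 -1] [-1 -1]]
print(quantize_u8_as_i8(3.0))                        # -1, matching dense<-1>
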
// CHECK-LABEL: DequantizeAndQuantize
@ -71,7 +76,7 @@ func @QuantizeConv2D(tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>
// DEBUG: %[[act:.*]] = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x224x224x3xf32>
// DEBUG: %[[f_conv:.*]] = "tfl.conv_2d"(%[[act]], %[[wt]], %[[bias]])
// DEBUG: %[[q_conv:.*]] = "tfl.conv_2d"
-// DEBUG: "tfl.NumericVerify"(%[[q_conv]], %[[f_conv]]) {tolerance = 1.000000e-01 : f32}
+// DEBUG: "tfl.NumericVerify"(%[[q_conv]], %[[f_conv]]) {tolerance = 5.000000e+00 : f32}
// DEBUG: return %[[q_conv]] : tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
}

@ -236,8 +241,8 @@ func @QuantizeSplit(%arg: tensor<4x!quant.uniform<u8:f32, 1.0>>, %cst: tensor<i3

// DEUBG: %[[f_split:.*]]:2 = "tfl.split"
// DEUBG: %[[q_split:.*]]:2 = "tfl.split"
-// DEUBG: "tfl.NumericVerify"(%[[q_split]]#1, %[[f_split]]#1) {tolerance = 1.000000e-01 : f32}
+// DEUBG: "tfl.NumericVerify"(%[[q_split]]#1, %[[f_split]]#1) {tolerance = 5.000000e+00 : f32}
-// DEUBG: "tfl.NumericVerify"(%[[q_split]]#0, %[[f_split]]#0) {tolerance = 1.000000e-01 : f32}
+// DEUBG: "tfl.NumericVerify"(%[[q_split]]#0, %[[f_split]]#0) {tolerance = 5.000000e+00 : f32}
}
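
The DEBUG checks above reflect the tolerance attribute on tfl.NumericVerify being loosened from 1.000000e-01 to 5.000000e+00. A rough NumPy sketch of the kind of element-wise comparison such a verification performs; verify_numeric is a local illustrative helper, not the actual TFLite kernel:

import numpy as np

def verify_numeric(quantized_out, float_reference, tolerance):
    # Element-wise absolute error between the dequantized result of the
    # quantized graph and the float reference; True means "within tolerance".
    err = np.abs(np.asarray(quantized_out, dtype=np.float64) -
                 np.asarray(float_reference, dtype=np.float64))
    return err <= tolerance

float_ref  = np.array([0.10, 0.52, 3.90])
quant_path = np.array([0.12, 0.50, 4.25])

print(verify_numeric(quant_path, float_ref, tolerance=0.1))  # [ True  True False]
print(verify_numeric(quant_path, float_ref, tolerance=5.0))  # [ True  True  True]

The looser bound keeps the same comparison but tolerates much larger quantization error before an element is flagged.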

// CHECK-LABEL: QuantizeSplitUnusedResults